# cl-lmm/src/utils/args.py

from dataclasses import dataclass, field
from typing import Optional
from trl import ScriptArguments, ModelConfig
@dataclass
class ContinualScriptArguments(ScriptArguments):
    """Script-level arguments for continual-learning runs.

    Extends ``trl.ScriptArguments``. ``dataset_name`` is declared here as a
    list of names so one run can reference several datasets. Presumably this
    overrides an inherited ``dataset_name`` field; confirm against the
    installed ``trl`` version.
    """

    # Dataset names for the run. default_factory builds a fresh list for every
    # instance, so instances never share one mutable default.
    # NOTE(review): list order presumably equals task order; confirm with caller.
    dataset_name: list[str] = field(
        default_factory=lambda: list(("cifar10", "cifar100", "imagenet2012"))
    )
    # Name of the dataset split used for generation.
    dataset_generation_split: str = field(default="generation")


@dataclass
class ContinualModelConfig(ModelConfig):
    """Model configuration for continual-learning runs.

    Extends ``trl.ModelConfig`` with a single extra, optional setting.
    """

    # Optional PEFT type identifier; None when not set.
    # TODO(review): the accepted values are determined by the consuming code,
    # not by this class. Document them here once confirmed.
    peft_type: Optional[str] = field(default=None)


@dataclass
class ContinualRegularizationArguments:
    """Regularization arguments for continual learning.

    Holds one ``*_enable`` flag and one ``*_lambda`` value for each supported
    regularizer (EWC, LwF). By default every flag is False and every lambda is
    0.0. This class only stores the values. How a flag and its lambda are
    combined (for example, whether the lambda is ignored when the flag is
    False) is decided by the training code that reads them.
    """

    # EWC
    # Presumably the weight of the EWC penalty term; confirm in the trainer.
    ewc_lambda: float = 0.0
    # Switch that turns EWC regularization on.
    ewc_enable: bool = False

    # LwF
    # Presumably the weight of the LwF distillation term; confirm in the trainer.
    lwf_lambda: float = 0.0
    # Switch that turns LwF regularization on.
    lwf_enable: bool = False


# Backward-compatible alias for the original, misspelled class name so that
# existing imports of ``ContiunalRegularizationArguments`` keep working.
ContiunalRegularizationArguments = ContinualRegularizationArguments