from dataclasses import dataclass, field
from typing import Optional

from trl import ModelConfig, ScriptArguments


@dataclass
class ContinualScriptArguments(ScriptArguments):
    """Script arguments for continual learning.

    Extends TRL's ``ScriptArguments`` so that ``dataset_name`` holds a list of
    dataset names instead of a single one. It also adds the name of an extra
    dataset split.
    """

    # Dataset names, one per entry. Presumably consumed in list order as a
    # sequence of tasks -- TODO confirm against the training script.
    # default_factory gives every instance its own list, so one instance
    # mutating it cannot affect another.
    dataset_name: list[str] = field(
        default_factory=lambda: ["cifar10", "cifar100", "imagenet2012"]
    )
    # Name of an additional dataset split; how it is used (e.g. for
    # generation/evaluation) is decided by the caller -- verify there.
    dataset_generation_split: str = "generation"


@dataclass
class ContinualModelConfig(ModelConfig):
    """Model configuration for continual learning.

    Extends TRL's ``ModelConfig`` with an optional ``peft_type`` selector.
    """

    # PEFT method identifier; ``None`` leaves the choice to the caller.
    # Accepted values are not validated here -- presumably interpreted by
    # the training script.
    peft_type: Optional[str] = None


@dataclass
class ContiunalRegularizationArguments:
    """Regularization arguments for continual learning.

    Each technique has an ``*_enable`` flag and a ``*_lambda`` weight. Both
    are plain values here; this class does not validate or combine them.

    NOTE: the class name is misspelled ("Contiunal"). It is kept for
    backward compatibility. New code should use the correctly spelled alias
    ``ContinualRegularizationArguments``, defined right after this class.
    """

    # EWC -- presumably Elastic Weight Consolidation; confirm with the trainer.
    ewc_lambda: float = 0.0  # weight of the EWC term (0.0 by default)
    ewc_enable: bool = False  # whether EWC is switched on

    # LWF -- presumably Learning without Forgetting; confirm with the trainer.
    lwf_lambda: float = 0.0  # weight of the LwF term (0.0 by default)
    lwf_enable: bool = False  # whether LwF is switched on


# Correctly spelled alias for the misspelled class above. Both names refer to
# the same class object, so isinstance checks and HfArgumentParser behave
# identically whichever name is used.
ContinualRegularizationArguments = ContiunalRegularizationArguments