34 lines
785 B
Python
34 lines
785 B
Python
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
from trl import ScriptArguments, ModelConfig
|
|
|
|
|
|
@dataclass
class ContinualScriptArguments(ScriptArguments):
    """Script-level options for a continual-learning run.

    Extends TRL's ``ScriptArguments``. ``dataset_name`` is redeclared here as a
    list of dataset names.
    NOTE(review): the parent presumably declares it as a single string;
    confirm against the installed TRL version.
    """

    # Ordered list of dataset names. ``default_factory`` builds a fresh list
    # for every instance, so mutating one instance's list never affects another.
    # NOTE(review): presumably the tasks are visited in this order; confirm
    # against the training loop.
    dataset_name: list[str] = field(
        default_factory=lambda: [
            "cifar10",
            "cifar100",
            "imagenet2012",
        ]
    )

    # Split name for the generation phase. This meaning is inferred from the
    # field name; confirm with the code that reads it.
    dataset_generation_split: str = field(default="generation")
|
|
|
|
|
|
@dataclass
class ContinualModelConfig(ModelConfig):
    """Model options for a continual-learning run.

    Extends TRL's ``ModelConfig`` with one extra optional field.
    """

    # Optional PEFT method identifier; ``None`` unless set explicitly.
    # NOTE(review): the accepted values are not defined here. Confirm which
    # strings the consuming code recognizes.
    peft_type: Optional[str] = field(default=None)
|
|
|
|
|
|
@dataclass
class ContinualRegularizationArguments:
    """Regularization options for continual learning.

    Holds an on/off flag and a weight for each of two regularizers, EWC and
    LwF. This class only stores the values. It performs no validation and
    does not itself tie a ``*_enable`` flag to the matching ``*_lambda``;
    that is left to the consuming code.
    """

    # --- EWC (presumably Elastic Weight Consolidation) ---
    # Presumably the coefficient of the EWC term; defaults to 0.0.
    ewc_lambda: float = 0.0
    # Whether the EWC term is enabled; off by default.
    ewc_enable: bool = False

    # --- LwF (presumably Learning without Forgetting) ---
    # Presumably the coefficient of the LwF term; defaults to 0.0.
    lwf_lambda: float = 0.0
    # Whether the LwF term is enabled; off by default.
    lwf_enable: bool = False


# Backward-compatible alias for the original, misspelled class name, so that
# existing imports of ``ContiunalRegularizationArguments`` keep working.
ContiunalRegularizationArguments = ContinualRegularizationArguments
|