Refactor code structure for improved readability and maintainability

YunyaoZhou 2025-05-16 19:50:41 +08:00
parent 188ea7df6e
commit c100a59f0e
7 changed files with 874 additions and 718 deletions

View File

@@ -4,6 +4,8 @@
         "src/peft_repo/src/"
     ],
     "python.analysis.exclude": [
-        "dataset/**"
-    ]
+        "dataset/**/*"
+    ],
+    "python.languageServer": "Default",
+    "python.terminal.activateEnvInCurrentTerminal": true
 }

View File

@@ -0,0 +1,15 @@
class RegularizationMethod:
    """RegularizationMethod implements a regularization strategy.

    A RegularizationMethod is a callable. Subclasses implement `adapt`,
    which is called to update the loss, typically at the end of an
    experience.
    """

    def pre_adapt(self, agent, exp):
        pass  # implementation may be empty if no pre-adaptation is needed

    def post_adapt(self, agent, exp):
        pass  # implementation may be empty if no post-adaptation is needed

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()
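
A minimal sketch of how a concrete strategy is expected to plug into this base class; the `NullRegularization` subclass and the commented training-step call pattern are hypothetical and only illustrate the intended usage.

# Hypothetical subclass, assuming the RegularizationMethod class above is
# in scope. It adds nothing to the loss and only shows the call pattern.
from typing import Any, Dict

class NullRegularization(RegularizationMethod):
    """A no-op strategy that returns the loss dictionary unchanged."""

    def adapt(self, output: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return output

    def __call__(self, *args, **kwargs):
        return self.adapt(*args, **kwargs)

# Intended call pattern inside a training step (illustrative only):
#   output = model(**batch)             # output carries 'loss' and 'logits'
#   output = regularizer.adapt(output)  # adds the penalty to output['loss']
#   output['loss'].backward()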

View File

@@ -0,0 +1,51 @@
from . import RegularizationMethod
import torch


class EWC(RegularizationMethod):
    """Elastic Weight Consolidation (EWC).

    The method mitigates forgetting by penalizing changes to parameters that
    were important for previous experiences, weighted by an estimate of the
    Fisher information collected after each experience.
    """

    def __init__(self, EWC_lambda=1, temperature=2):
        """
        :param EWC_lambda: weight of the EWC penalty added to the loss.
        :param temperature: kept for interface compatibility with the other
            regularization methods; it is not used by the EWC penalty.
        """
        self.EWC_lambda = EWC_lambda
        self.temperature = temperature
        self.fisher = {}
        self.optpar = {}

    def adapt(self, output, model, **kwargs):
        ewc_loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                dev = p.device
                # Quadratic penalty on the drift from the parameters saved
                # after the previous experience, scaled by the Fisher term.
                penalty = (
                    self.EWC_lambda
                    * self.fisher[n].to(dev)
                    * (p.data - self.optpar[n].to(dev)).pow(2)
                )
                ewc_loss += penalty.sum()
        output['loss'] += ewc_loss
        return output

    def init_epoch(self, model):
        """Snapshot the current parameters and reset the Fisher estimate."""
        optpar = {}
        fisher = {}
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                fisher[n] = torch.zeros(p.data.shape)
                optpar[n] = p.clone().cpu().data
        # Store the snapshots so that `adapt` and `update_fisher` can use them.
        self.fisher = fisher
        self.optpar = optpar

    def update_fisher(self, model):
        """Accumulate squared gradients into the Fisher information estimate."""
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad and p.grad is not None:
                self.fisher[n] += p.grad.data.pow(2).cpu()
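
A hedged sketch of one plausible call sequence for this class: the penalty added in `adapt` is EWC_lambda * F[i] * (theta[i] - theta_star[i])^2, so `init_epoch` should run before leaving an experience and `update_fisher` after backward passes on that experience's data. The names `model`, `dataloader`, and `optimizer` are assumptions, and `model` is assumed to be wrapped so that `model.module.base_model.model` resolves (for example, DDP around a PEFT model).

# Illustrative wiring only; `model`, `dataloader`, and `optimizer` are
# assumptions, and `model` is assumed to be wrapped so that
# `model.module.base_model.model` resolves (e.g. DDP around a PEFT model).
ewc = EWC(EWC_lambda=0.1)

# Before leaving the current experience: snapshot parameters, reset Fisher.
ewc.init_epoch(model)

# Estimate the Fisher information from squared gradients on this experience.
for batch in dataloader:
    output = model(**batch)
    output['loss'].backward()
    ewc.update_fisher(model)
    model.zero_grad()

# While training on the next experience: add the quadratic penalty before
# the backward pass. Pass the unwrapped module so that parameter names
# match the keys stored by `init_epoch`.
for batch in dataloader:
    output = model(**batch)
    output = ewc.adapt(output, model=model.module.base_model.model)
    output['loss'].backward()
    optimizer.step()
    optimizer.zero_grad()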

View File

@@ -0,0 +1,54 @@
from . import RegularizationMethod
import torch


class LWF(RegularizationMethod):
    """Learning without Forgetting (LwF).

    The method applies knowledge distillation to mitigate forgetting.
    The teacher is the model checkpoint after the last experience.
    """

    def __init__(self, LWF_lambda=1, temperature=2):
        """
        :param LWF_lambda: weight of the distillation term added to the loss.
        :param temperature: softmax temperature for distillation.
        """
        self.LWF_lambda = LWF_lambda
        self.temperature = temperature
        self.previous_logits = {}
        """In Avalanche, targets of different experiences are not ordered.
        As a result, some units may be allocated even though their
        corresponding class has never been seen by the model.
        Knowledge distillation uses only units corresponding to old classes.
        """

    def adapt(self, output, **kwargs):
        def modified_kl_div(old, new):
            return -torch.mean(torch.sum(old * torch.log(new), 1))

        def smooth(logits, temp, dim):
            # Temperature smoothing applied to probabilities, then
            # renormalized along `dim`.
            log = logits ** (1 / temp)
            return log / torch.sum(log, dim).unsqueeze(1)

        lwf_loss = []
        soft = torch.nn.Softmax(dim=1)
        previous_keys = self.previous_logits.keys()
        for index, question_id in enumerate(iterable=kwargs['question_ids']):
            if question_id in previous_keys:
                previous_logits = self.previous_logits[question_id]
                current_logits = output['logits'][index]
                # Distill only over the overlapping prefix of the two sequences.
                short_index = min(len(previous_logits), len(current_logits))
                previous_logits = previous_logits[:short_index]
                current_logits = current_logits[:short_index]
                lwf_loss.append(
                    modified_kl_div(
                        old=smooth(
                            logits=soft(previous_logits).to(current_logits.device),
                            temp=2,
                            dim=1,
                        ),
                        new=smooth(logits=soft(current_logits), temp=2, dim=1),
                    )
                )
        if len(lwf_loss) > 0:
            output['loss'] += self.LWF_lambda * torch.stack(
                tensors=lwf_loss, dim=0
            ).sum(dim=0)
        return output

    def update_previous_logits(self, question_id, logits):
        """Update the previous logits for the given question id."""
        self.previous_logits[question_id] = logits
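
A hedged usage sketch based on the `adapt` signature above: teacher logits are recorded per question id after an experience and distilled against during later experiences. The `model` and `dataloader` names and the `question_ids` batch key are assumptions.

# Illustrative wiring only; `model`, `dataloader`, and the 'question_ids'
# batch key are assumptions based on the adapt() signature above.
import torch

lwf = LWF(LWF_lambda=0.5, temperature=2)

# After finishing an experience: store the teacher logits per question id.
with torch.no_grad():
    for batch in dataloader:
        question_ids = batch.pop('question_ids')
        output = model(**batch)
        for i, qid in enumerate(question_ids):
            lwf.update_previous_logits(qid, output['logits'][i].detach().cpu())

# While training on the next experience: distill against the stored logits.
for batch in dataloader:
    question_ids = batch.pop('question_ids')
    output = model(**batch)
    output = lwf.adapt(output, question_ids=question_ids)
    output['loss'].backward()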

View File

@@ -47,7 +47,7 @@ if __name__ == "__main__":
     accelerator = create_accelerator_and_postprocess(training_args)
 
     if model_args.peft_type == "MMOELORA":
-        from peft_library.tuners import MMOELoraConfig
+        from peft.tuners import MMOELoraConfig
 
         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)

View File

@@ -18,3 +18,17 @@ class ContinualModelConfig(ModelConfig):
     """Model configuration for continual learning."""
 
     peft_type: Optional[str] = None
+
+
+@dataclass
+class ContiunalRegularizationArguments:
+    """Regularization arguments for continual learning."""
+
+    # EWC
+    ewc_lambda: float = 0.0
+    ewc_enable: bool = False
+
+    # LWF
+    lwf_lambda: float = 0.0
+    lwf_enable: bool = False
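
A sketch of how these flags could be consumed; the use of `HfArgumentParser` and the import locations are assumptions, not part of this diff.

# Possible consumption of the new arguments; HfArgumentParser usage and the
# import locations are assumptions. ContiunalRegularizationArguments, EWC,
# and LWF are assumed to be importable from this repository's modules.
from transformers import HfArgumentParser

parser = HfArgumentParser(ContiunalRegularizationArguments)
(reg_args,) = parser.parse_args_into_dataclasses()

regularization = None
if reg_args.ewc_enable:
    regularization = EWC(EWC_lambda=reg_args.ewc_lambda)
elif reg_args.lwf_enable:
    regularization = LWF(LWF_lambda=reg_args.lwf_lambda)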

File diff suppressed because it is too large