Refactor code structure for improved readability and maintainability
parent 188ea7df6e
commit c100a59f0e
.vscode/settings.json (vendored): 6 lines changed
@@ -4,6 +4,8 @@
         "src/peft_repo/src/"
     ],
     "python.analysis.exclude": [
-        "dataset/**"
-    ]
+        "dataset/**/*"
+    ],
+    "python.languageServer": "Default",
+    "python.terminal.activateEnvInCurrentTerminal": true
 }
src/peft_library/regularizations/__init__.py (new file): 15 lines added
@@ -0,0 +1,15 @@
+
+class RegularizationMethod:
+    """RegularizationMethod implements regularization strategies.
+    A RegularizationMethod is a callable.
+    The method `update` is called to update the loss, typically at the end
+    of an experience.
+    """
+    def pre_adapt(self, agent, exp):
+        pass  # implementation may be empty if adapt is not needed
+
+    def post_adapt(self, agent, exp):
+        pass  # implementation may be empty if adapt is not needed
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError()
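For orientation, here is a minimal sketch of how a subclass of RegularizationMethod is expected to hook into a training step by adjusting the loss in adapt(). The import path and the `output` dict with a 'loss' key are assumptions taken from the EWC/LWF files added below, not part of this commit.

# Hypothetical sketch, not part of this commit: a toy subclass that adds
# an L2 penalty, mirroring how EWC/LWF modify output['loss'] in adapt().
from peft_library.regularizations import RegularizationMethod  # assumed import path


class L2Penalty(RegularizationMethod):
    def __init__(self, weight=0.1):
        self.weight = weight

    def adapt(self, output, model, **kwargs):
        # `output` is assumed to be a dict holding a scalar 'loss' tensor.
        l2 = sum(p.pow(2).sum() for p in model.parameters() if p.requires_grad)
        output['loss'] += self.weight * l2
        return output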
src/peft_library/regularizations/ewc.py (new file): 51 lines added
@@ -0,0 +1,51 @@
+from . import RegularizationMethod
+import torch
+
+class EWC(RegularizationMethod):
+    """Elastic Weight Consolidation (EWC).
+
+    The method penalizes changes to parameters that were important for
+    previous experiences, as measured by the Fisher information.
+    """
+
+    def __init__(self, EWC_lambda=1, temperature=2):
+        """
+        :param EWC_lambda: weight of the EWC penalty added to the loss.
+        :param temperature: softmax temperature; stored for interface
+            compatibility but not used by this method.
+        """
+        self.EWC_lambda = EWC_lambda
+        self.temperature = temperature
+        self.fisher = {}
+        self.optpar = {}
+        """ `fisher` maps each parameter name to its accumulated Fisher
+        information (squared gradients) and `optpar` maps it to the
+        reference parameter values saved at the start of the experience.
+        Both are filled by `init_epoch` and `update_fisher` and consumed
+        by `adapt` to build the quadratic penalty.
+        """
+    def adapt(self, output, model, **kwargs):
+        ewc_loss = 0
+        for n, p in model.named_parameters():
+            if p.requires_grad:
+                dev = p.device
+                penalty = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+                ewc_loss += penalty.sum()
+        output['loss'] += ewc_loss
+        return output
+
+    def init_epoch(self, model):
+        """Reset the Fisher information and snapshot the reference parameters."""
+        self.optpar = {}
+        self.fisher = {}
+        for n, p in model.module.base_model.model.named_parameters():
+            if p.requires_grad:
+                self.fisher[n] = torch.zeros(p.data.shape)
+                self.optpar[n] = p.clone().cpu().data
+    def update_fisher(self, model):
+        """Accumulate squared gradients into the Fisher information."""
+        for n, p in model.module.base_model.model.named_parameters():
+            if p.requires_grad:
+                fisher = self.fisher[n]
+                fisher += p.grad.data.pow(2).cpu()
+                self.fisher[n] = fisher
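One plausible wiring of these methods into a training loop is sketched below: init_epoch() zeroes the Fisher estimate and snapshots the reference parameters, update_fisher() accumulates squared gradients after each backward pass, and adapt() adds the quadratic penalty to the loss. The dataloader, optimizer, and DDP-wrapped model (hence model.module.base_model.model) are assumptions, not part of this commit.

# Hypothetical EWC wiring (sketch only; the names below are assumed).
ewc = EWC(EWC_lambda=0.1)

ewc.init_epoch(model)                     # snapshot params, zero the Fisher estimate
for batch in dataloader:
    output = model(**batch)               # assumed to return {'loss': ..., ...}
    # adapt() iterates named_parameters(), so pass the same module that
    # init_epoch() used to build the fisher/optpar dictionaries.
    output = ewc.adapt(output, model=model.module.base_model.model)
    output['loss'].backward()
    ewc.update_fisher(model)              # accumulate squared gradients
    optimizer.step()
    optimizer.zero_grad()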
src/peft_library/regularizations/lwf.py (new file): 54 lines added
@@ -0,0 +1,54 @@
+from . import RegularizationMethod
+import torch
+
+class LWF(RegularizationMethod):
+    """Learning without Forgetting (LwF).
+
+    The method applies knowledge distillation to mitigate forgetting.
+    The teacher is the model checkpoint after the last experience.
+    """
+
+    def __init__(self, LWF_lambda=1, temperature=2):
+        """
+        :param LWF_lambda: weight of the distillation term that is added
+            to the training loss.
+        :param temperature: softmax temperature for distillation.
+        """
+        self.LWF_lambda = LWF_lambda
+        self.temperature = temperature
+        self.previous_logits = {}
+        """ In Avalanche, targets of different experiences are not ordered.
+        As a result, some units may be allocated even though their
+        corresponding class has never been seen by the model.
+        Knowledge distillation uses only units corresponding
+        to old classes.
+        """
+    def adapt(self, output, **kwargs):
+        def modified_kl_div(old, new):
+            return -torch.mean(torch.sum(old * torch.log(new), 1))
+
+        def smooth(logits, temp, dim):
+            log = logits ** (1 / temp)
+            return log / torch.sum(log, dim).unsqueeze(1)
+
+        lwf_loss = []
+
+        soft = torch.nn.Softmax(dim=1)
+
+        previous_keys = self.previous_logits.keys()
+
+        for index, question_id in enumerate(kwargs['question_ids']):
+            if question_id in previous_keys:
+                previous_logits = self.previous_logits[question_id]
+                current_logits = output['logits'][index]
+                short_index = min(len(previous_logits), len(current_logits))
+                previous_logits = previous_logits[:short_index]
+                current_logits = current_logits[:short_index]
+                lwf_loss.append(modified_kl_div(old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=self.temperature, dim=1), new=smooth(logits=soft(current_logits), temp=self.temperature, dim=1)))
+        if len(lwf_loss) > 0:
+            output['loss'] += self.LWF_lambda * torch.stack(lwf_loss, dim=0).sum(dim=0)
+        return output
+
+    def update_previous_logits(self, question_id, logits):
+        """Update the previous logits for the given question id."""
+        self.previous_logits[question_id] = logits
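The LWF loss only applies to samples whose question_id already has cached teacher logits, so a typical flow (assumed, not part of this commit) is: cache logits with update_previous_logits() at the end of an experience, then pass question_ids into adapt() during the next one.

# Hypothetical LWF wiring (sketch only; question_ids and the output dict
# layout are assumptions based on adapt() above).
lwf = LWF(LWF_lambda=0.5, temperature=2)

# End of experience t: cache the current logits as the teacher.
for question_id, logits in zip(question_ids, output['logits']):
    lwf.update_previous_logits(question_id, logits.detach().cpu())

# Experience t+1: distill against the cached logits and train as usual.
output = lwf.adapt(output, question_ids=question_ids)
output['loss'].backward()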
@@ -47,7 +47,7 @@ if __name__ == "__main__":
     accelerator = create_accelerator_and_postprocess(training_args)
 
     if model_args.peft_type == "MMOELORA":
-        from peft_library.tuners import MMOELoraConfig
+        from peft.tuners import MMOELoraConfig
 
         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
 
@@ -18,3 +18,17 @@ class ContinualModelConfig(ModelConfig):
     """Model configuration for continual learning."""
 
     peft_type: Optional[str] = None
+
+
+@dataclass
+class ContiunalRegularizationArguments:
+    """Regularization arguments for continual learning."""
+
+    # EWC
+    ewc_lambda: float = 0.0
+    ewc_enable: bool = False
+
+    # LWF
+    lwf_lambda: float = 0.0
+    lwf_enable: bool = False
+
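A possible way these flags feed into the regularizers, assuming the dataclass is parsed with HfArgumentParser like the other argument groups (the import paths and the parsing call site are assumptions; only the field names come from this commit):

# Hypothetical argument wiring (sketch only).
from transformers import HfArgumentParser

from peft_library.regularizations.ewc import EWC  # assumed import path
from peft_library.regularizations.lwf import LWF  # assumed import path

(reg_args,) = HfArgumentParser(ContiunalRegularizationArguments).parse_args_into_dataclasses()

regularization = None
if reg_args.ewc_enable:
    regularization = EWC(EWC_lambda=reg_args.ewc_lambda)
elif reg_args.lwf_enable:
    regularization = LWF(LWF_lambda=reg_args.lwf_lambda)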
src/utils/trainer.py: 1450 lines changed (diff suppressed because it is too large)