Refactor code structure for improved readability and maintainability

parent 188ea7df6e
commit c100a59f0e
.vscode/settings.json (vendored, 6 changed lines)

@@ -4,6 +4,8 @@
     "src/peft_repo/src/"
   ],
   "python.analysis.exclude": [
-    "dataset/**"
-  ]
+    "dataset/**/*"
+  ],
+  "python.languageServer": "Default",
+  "python.terminal.activateEnvInCurrentTerminal": true
 }
src/peft_library/regularizations/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
+class RegularizationMethod:
+    """RegularizationMethod implements regularization strategies.
+
+    A RegularizationMethod is a callable. The method `update` is called to
+    update the loss, typically at the end of an experience.
+    """
+
+    def pre_adapt(self, agent, exp):
+        pass  # implementation may be empty if adapt is not needed
+
+    def post_adapt(self, agent, exp):
+        pass  # implementation may be empty if adapt is not needed
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError()
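For context, a minimal hypothetical subclass (not part of this commit) shows how the interface is meant to be used; the `output`-dict format is an assumption, and note that the EWC and LWF classes added below expose an `adapt` method rather than overriding `__call__`:

from peft_library.regularizations import RegularizationMethod  # import path assumes this repo layout


class L2Penalty(RegularizationMethod):
    """Toy example: add a weighted L2 penalty to the batch loss."""

    def __init__(self, weight=0.01):
        self.weight = weight

    def __call__(self, output, model, **kwargs):
        # `output` is assumed to be a dict with a 'loss' tensor, as in the
        # EWC/LWF implementations below.
        penalty = sum(p.pow(2).sum() for p in model.parameters() if p.requires_grad)
        output["loss"] = output["loss"] + self.weight * penalty
        return output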
src/peft_library/regularizations/ewc.py (new file, 51 lines)

@@ -0,0 +1,51 @@
+from . import RegularizationMethod
+import torch
+
+
+class EWC(RegularizationMethod):
+    """Elastic Weight Consolidation (EWC).
+
+    The method penalizes changes to parameters that were important for
+    previous experiences, where importance is estimated via the Fisher
+    information (accumulated squared gradients).
+    """
+
+    def __init__(self, EWC_lambda=1, temperature=2):
+        """
+        :param EWC_lambda: weight of the EWC penalty added to the loss.
+        :param temperature: softmax temperature, kept for interface parity
+            with the distillation-based methods; EWC itself does not use it.
+        """
+        self.EWC_lambda = EWC_lambda
+        self.temperature = temperature
+        self.fisher = {}  # parameter name -> Fisher information estimate
+        self.optpar = {}  # parameter name -> parameter snapshot from the last experience
+
+    def adapt(self, output, model, **kwargs):
+        """Add the EWC penalty to the loss in `output`."""
+        ewc_loss = 0
+        for n, p in model.named_parameters():
+            if p.requires_grad:
+                dev = p.device
+                penalty = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+                ewc_loss += penalty.sum()
+        output['loss'] += ewc_loss
+        return output
+
+    def init_epoch(self, model):
+        """Snapshot the current parameters and reset the Fisher accumulator."""
+        optpar = {}
+        fisher = {}
+        for n, p in model.module.base_model.model.named_parameters():
+            if p.requires_grad:
+                fisher[n] = torch.zeros(p.data.shape)
+                optpar[n] = p.clone().cpu().data
+        # Store the snapshots so that adapt() can compute the penalty later.
+        self.fisher = fisher
+        self.optpar = optpar
+
+    def update_fisher(self, model):
+        """Accumulate squared gradients into the Fisher information estimate."""
+        for n, p in model.module.base_model.model.named_parameters():
+            if p.requires_grad:
+                self.fisher[n] = self.fisher[n] + p.grad.data.pow(2).cpu()
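A hedged sketch of how these hooks might be wired into a training loop follows; `experiences`, `model`, and `optimizer` are assumed to exist, the batch output is assumed to be a dict with a 'loss' entry, and none of this glue code is part of the commit:

from peft_library.regularizations.ewc import EWC  # import path assumes this repo layout

ewc = EWC(EWC_lambda=0.1)

for dataloader in experiences:                   # one dataloader per experience (assumed)
    ewc.init_epoch(model)                        # snapshot parameters, reset the Fisher accumulator
    for batch in dataloader:
        output = model(**batch)                  # assumed to return {'loss': ..., 'logits': ...}
        output = ewc.adapt(output, model=model)  # add the EWC penalty to the loss
        output["loss"].backward()
        ewc.update_fisher(model)                 # accumulate squared gradients into the Fisher estimate
        optimizer.step()
        optimizer.zero_grad()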
src/peft_library/regularizations/lwf.py (new file, 54 lines)

@@ -0,0 +1,54 @@
+from . import RegularizationMethod
+import torch
+
+
+class LWF(RegularizationMethod):
+    """Learning without Forgetting (LwF).
+
+    The method applies knowledge distillation to mitigate forgetting.
+    The teacher is the model checkpoint after the last experience.
+    """
+
+    def __init__(self, LWF_lambda=1, temperature=2):
+        """
+        :param LWF_lambda: weight of the distillation penalty added to the loss.
+        :param temperature: softmax temperature for distillation.
+        """
+        self.LWF_lambda = LWF_lambda
+        self.temperature = temperature
+        self.previous_logits = {}  # question id -> logits recorded after the last experience
+        # In Avalanche, targets of different experiences are not ordered. As a
+        # result, some units may be allocated even though their corresponding
+        # class has never been seen by the model. Knowledge distillation uses
+        # only units corresponding to old classes.
+
+    def adapt(self, output, **kwargs):
+        """Add the distillation penalty for samples whose previous logits are known."""
+
+        def modified_kl_div(old, new):
+            return -torch.mean(torch.sum(old * torch.log(new), 1))
+
+        def smooth(logits, temp, dim):
+            log = logits ** (1 / temp)
+            return log / torch.sum(log, dim).unsqueeze(1)
+
+        lwf_loss = []
+        soft = torch.nn.Softmax(dim=1)
+        previous_keys = self.previous_logits.keys()
+
+        for index, question_id in enumerate(kwargs['question_ids']):
+            if question_id in previous_keys:
+                previous_logits = self.previous_logits[question_id]
+                current_logits = output['logits'][index]
+                # Distill only over the common prefix of the two sequences.
+                short_index = min(len(previous_logits), len(current_logits))
+                previous_logits = previous_logits[:short_index]
+                current_logits = current_logits[:short_index]
+                lwf_loss.append(
+                    modified_kl_div(
+                        old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=2, dim=1),
+                        new=smooth(logits=soft(current_logits), temp=2, dim=1),
+                    )
+                )
+        if len(lwf_loss) > 0:
+            output['loss'] += self.LWF_lambda * torch.stack(lwf_loss, dim=0).sum(dim=0)
+        return output
+
+    def update_previous_logits(self, question_id, logits):
+        """Update the previous logits for the given question id."""
+        self.previous_logits[question_id] = logits
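For illustration, a self-contained hypothetical usage sketch (not part of this commit) with stand-in question ids and random logits:

import torch

from peft_library.regularizations.lwf import LWF  # import path assumes this repo layout

lwf = LWF(LWF_lambda=0.5, temperature=2)

# After an experience: cache per-sample logits keyed by question id.
question_ids = ["q1", "q2"]                        # hypothetical ids
old_logits = torch.randn(2, 8, 32)                 # (batch, seq_len, vocab) stand-in values
for qid, logits in zip(question_ids, old_logits):
    lwf.update_previous_logits(qid, logits)

# During the next experience: adapt() adds the distillation term to the loss.
new_logits = torch.randn(2, 8, 32, requires_grad=True)
output = {"loss": torch.zeros(()), "logits": new_logits}
output = lwf.adapt(output, question_ids=question_ids)
output["loss"].backward()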
@@ -47,7 +47,7 @@ if __name__ == "__main__":
     accelerator = create_accelerator_and_postprocess(training_args)

     if model_args.peft_type == "MMOELORA":
-        from peft_library.tuners import MMOELoraConfig
+        from peft.tuners import MMOELoraConfig

         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
@@ -18,3 +18,17 @@ class ContinualModelConfig(ModelConfig):
     """Model configuration for continual learning."""

     peft_type: Optional[str] = None
+
+
+@dataclass
+class ContiunalRegularizationArguments:
+    """Regularization arguments for continual learning."""
+
+    # EWC
+    ewc_lambda: float = 0.0
+    ewc_enable: bool = False
+
+    # LWF
+    lwf_lambda: float = 0.0
+    lwf_enable: bool = False
+
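A hypothetical piece of glue code (not part of this commit) shows how these flags could select a regularization method; the builder function and its name are assumptions:

from peft_library.regularizations.ewc import EWC
from peft_library.regularizations.lwf import LWF


def build_regularization(args):
    """Return the configured RegularizationMethod, or None if both are disabled.

    `args` is a ContiunalRegularizationArguments instance (name as spelled in the diff).
    """
    if args.ewc_enable:
        return EWC(EWC_lambda=args.ewc_lambda)
    if args.lwf_enable:
        return LWF(LWF_lambda=args.lwf_lambda)
    return None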
src/utils/trainer.py (1450 changed lines)

File diff suppressed because it is too large.