cl-lmm/src/peft_library/regularizations/ewc.py

from . import RegularizationMethod
import torch

class EWC(RegularizationMethod):
    """Elastic Weight Consolidation (EWC).

    Mitigates forgetting by penalizing changes to parameters that were
    important for previous experiences: the diagonal Fisher information
    weights a quadratic anchor to the checkpoint after the last experience.
    A minimal usage sketch is given at the bottom of this file.
    """
    def __init__(self, EWC_lambda=1, temperature=2):
        """
        :param EWC_lambda: weight of the quadratic EWC penalty added to the
            task loss.
        :param temperature: not used by the EWC penalty; retained from the
            original signature.
        """
        self.EWC_lambda = EWC_lambda
        self.temperature = temperature
        self.fisher = {}   # per-parameter diagonal Fisher estimates
        self.optpar = {}   # parameter snapshot from the end of the last experience

    def adapt(self, output, model, **kwargs):
        """Add the EWC penalty to ``output['loss']`` and return ``output``."""
        ewc_loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                dev = p.device
                # EWC_lambda * F_i * (theta_i - theta_i*)^2; use ``p`` (not
                # ``p.data``) so the penalty contributes gradients.
                penalty = self.EWC_lambda * self.fisher[n].to(dev) * (p - self.optpar[n].to(dev)).pow(2)
                ewc_loss += penalty.sum()
        output['loss'] += ewc_loss
        return output

    def init_epoch(self, model):
        """Snapshot the current parameters and reset the Fisher estimates."""
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                self.fisher[n] = torch.zeros(p.data.shape)
                self.optpar[n] = p.clone().cpu().data

    def update_fisher(self, model):
        """Accumulate squared gradients into the diagonal Fisher estimates."""
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                # running sum of squared gradients approximates the Fisher diagonal
                self.fisher[n] += p.grad.data.pow(2).cpu()
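

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# Assumptions: a toy ``nn.Linear`` stands in for the PEFT-wrapped language
# model, and ``SimpleNamespace`` mimics the ``model.module.base_model.model``
# nesting expected by ``init_epoch`` / ``update_fisher`` (e.g. a DDP wrapper
# around a PEFT model). ``adapt`` receives the inner module so parameter
# names match the stored statistics. Run as a module (``python -m ...``) so
# the relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from types import SimpleNamespace

    net = nn.Linear(4, 2)
    wrapped = SimpleNamespace(
        module=SimpleNamespace(base_model=SimpleNamespace(model=net))
    )

    ewc = EWC(EWC_lambda=0.4)

    # End of the previous experience: snapshot parameters, then estimate the
    # Fisher diagonal from squared gradients of the old-task loss.
    ewc.init_epoch(wrapped)
    prev_loss = net(torch.randn(8, 4)).pow(2).mean()
    prev_loss.backward()
    ewc.update_fisher(wrapped)

    # New experience: add the EWC penalty to the task loss before backward().
    net.zero_grad()
    output = {"loss": net(torch.randn(8, 4)).pow(2).mean()}
    output = ewc.adapt(output, net)
    output["loss"].backward()
    print(float(output["loss"]))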