cl-lmm/src/peft_library/regularizations/ewc.py

import torch

from . import RegularizationMethod


class EWC(RegularizationMethod):
"""Learning Without Forgetting.
The method applies knowledge distilllation to mitigate forgetting.
The teacher is the model checkpoint after the last experience.
"""
    def __init__(self, EWC_lambda=1, temperature=2):
        """
        :param EWC_lambda: weight of the quadratic EWC penalty added to
            the task loss.
        :param temperature: softmax temperature, kept for interface
            compatibility with the distillation-based methods; it is not
            used by EWC itself.
        """
        self.EWC_lambda = EWC_lambda
        self.temperature = temperature
        self.fisher = {}
        self.optpar = {}

""" In Avalanche, targets of different experiences are not ordered.
As a result, some units may be allocated even though their
corresponding class has never been seen by the model.
Knowledge distillation uses only units corresponding
to old classes.
"""
    def adapt(self, output, model, **kwargs):
        """Add the EWC penalty to the task loss in ``output["loss"]``."""
        ewc_loss = 0
        for n, p in model.named_parameters():
            # Skip frozen parameters and parameters with no recorded
            # Fisher estimate (e.g. before the first consolidation).
            if p.requires_grad and n in self.fisher:
                dev = p.device
                penalty = (
                    self.EWC_lambda
                    * self.fisher[n].to(dev)
                    # use `p` (not `p.data`) so the penalty is differentiable
                    * (p - self.optpar[n].to(dev)).pow(2)
                )
                ewc_loss += penalty.sum()
        output["loss"] += ewc_loss
        return output

    def init_epoch(self, model):
        """Snapshot the current parameters and reset the Fisher estimate."""
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                # Stored on `self` so that `adapt` and `update_fisher`
                # can read them later.
                self.fisher[n] = torch.zeros(p.data.shape)
                self.optpar[n] = p.data.clone().cpu()

    def update_fisher(self, model):
        """Accumulate squared gradients into the diagonal Fisher estimate.

        Call after a backward pass: the squared gradient of each trainable
        parameter approximates its Fisher information.
        """
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad and p.grad is not None:
                self.fisher[n] += p.grad.data.pow(2).cpu()