import torch

from . import RegularizationMethod


class EWC(RegularizationMethod):
    """Elastic Weight Consolidation (EWC).

    The method mitigates forgetting by penalizing changes to parameters
    that were important for previous experiences. Importance is estimated
    as the diagonal of the Fisher information matrix computed at the end
    of each experience, and the penalty anchors parameters to their values
    at that checkpoint.
    """

    def __init__(self, EWC_lambda=1, temperature=2):
        """
        :param EWC_lambda: weight of the EWC penalty added to the loss.
        :param temperature: softmax temperature; kept for interface
            compatibility with distillation-based methods, not used by the
            EWC penalty itself.
        """
        self.EWC_lambda = EWC_lambda
        self.temperature = temperature
        # Diagonal Fisher information and anchor parameters saved after
        # the previous experience.
        self.fisher = {}
        self.optpar = {}

    def adapt(self, output, model, **kwargs):
        """Add the quadratic EWC penalty to the current loss."""
        ewc_loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                dev = p.device
                # Penalize movement away from the anchor parameters,
                # weighted by the diagonal Fisher information. Using `p`
                # (not `p.data`) keeps the penalty differentiable so it
                # contributes gradients during backpropagation.
                penalty = (
                    self.EWC_lambda
                    * self.fisher[n].to(dev)
                    * (p - self.optpar[n].to(dev)).pow(2)
                )
                ewc_loss += penalty.sum()
        output["loss"] += ewc_loss
        return output

    def init_epoch(self, model):
        """Reset the Fisher accumulators and snapshot the current parameters
        as the anchor for the next experience."""
        optpar = {}
        fisher = {}
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                fisher[n] = torch.zeros(p.data.shape)
                optpar[n] = p.clone().cpu().data
        # Store the new accumulators so that `adapt` and `update_fisher`
        # see them.
        self.fisher = fisher
        self.optpar = optpar

    def update_fisher(self, model):
        """Accumulate squared gradients into the diagonal Fisher estimate."""
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                fisher = self.fisher[n]
                fisher += p.grad.data.pow(2).cpu()
                self.fisher[n] = fisher
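

# A minimal, self-contained usage sketch (not part of the original module).
# Assumptions: a toy linear model, random data, and a small wrapper that
# mimics the `model.module.base_model.model` attribute chain expected by
# `init_epoch`/`update_fisher` (e.g. a DDP-wrapped PEFT model); all names
# below (`base`, `wrapped`, `x`, `y`, ...) are hypothetical.
if __name__ == "__main__":
    import types

    import torch.nn.functional as F

    base = torch.nn.Linear(4, 2)
    # Mimic the wrapped-model layout expected by init_epoch/update_fisher,
    # while exposing named_parameters() directly for adapt().
    wrapped = types.SimpleNamespace(
        module=types.SimpleNamespace(
            base_model=types.SimpleNamespace(model=base)
        ),
        named_parameters=base.named_parameters,
    )

    ewc = EWC(EWC_lambda=0.4)
    optimizer = torch.optim.SGD(base.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

    # End of experience 0: snapshot the parameters and estimate the Fisher.
    ewc.init_epoch(wrapped)
    F.cross_entropy(base(x), y).backward()
    ewc.update_fisher(wrapped)
    optimizer.zero_grad()

    # Experience 1: the EWC penalty is added on top of the task loss.
    output = {"loss": F.cross_entropy(base(x), y)}
    output = ewc.adapt(output, wrapped)
    output["loss"].backward()
    optimizer.step()
    print("loss with EWC penalty:", float(output["loss"]))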