from . import RegularizationMethod
import torch


class EWC(RegularizationMethod):
    """Elastic Weight Consolidation (EWC).

    The method mitigates forgetting by penalizing changes to parameters that
    were important for previous experiences, where importance is estimated
    with the diagonal of the Fisher information matrix.
    """

    def __init__(self, EWC_lambda=1, temperature=2):
        """
        :param EWC_lambda: weight of the EWC penalty added to the loss.
        :param temperature: kept for interface parity with distillation-based
            methods; it is not used by EWC.
        """
        self.EWC_lambda = EWC_lambda
        self.temperature = temperature
        self.fisher = {}
        self.optpar = {}

    def adapt(self, output, model, **kwargs):
        """Add the EWC penalty to the task loss in ``output['loss']``."""
        ewc_loss = 0
        for n, p in model.named_parameters():
            # Skip parameters without a stored Fisher estimate (e.g. added
            # after the last snapshot).
            if p.requires_grad and n in self.fisher:
                dev = p.device
                penalty = (
                    self.EWC_lambda
                    * self.fisher[n].to(dev)
                    * (p - self.optpar[n].to(dev)).pow(2)
                )
                ewc_loss += penalty.sum()
        output['loss'] += ewc_loss
        return output

    def init_epoch(self, model):
        """Snapshot the current parameters and reset the Fisher estimates."""
        optpar = {}
        fisher = {}
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                fisher[n] = torch.zeros(p.data.shape)
                optpar[n] = p.detach().clone().cpu()
        self.fisher = fisher
        self.optpar = optpar

    def update_fisher(self, model):
        """Accumulate squared gradients into the diagonal Fisher estimate."""
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad and p.grad is not None:
                self.fisher[n] += p.grad.data.pow(2).cpu()
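
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the method). It assumes a
# wrapper object exposing `.module.base_model.model`, as `init_epoch` and
# `update_fisher` expect (e.g. a DDP-wrapped PEFT model); here a
# `types.SimpleNamespace` mimics that nesting around a toy `nn.Linear`.
# The lambda value, batch sizes, and number of batches are arbitrary.
# Run via `python -m <package>.<this module>` so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    import torch.nn as nn

    net = nn.Linear(8, 4)
    wrapper = types.SimpleNamespace(
        module=types.SimpleNamespace(base_model=types.SimpleNamespace(model=net))
    )

    ewc = EWC(EWC_lambda=0.4)
    criterion = nn.CrossEntropyLoss()

    # End of a previous experience: snapshot parameters and accumulate the
    # Fisher estimate from a few batches of old data.
    ewc.init_epoch(wrapper)
    for _ in range(16):
        x, y = torch.randn(32, 8), torch.randint(0, 4, (32,))
        net.zero_grad()
        criterion(net(x), y).backward()
        ewc.update_fisher(wrapper)

    # Training on the new experience: the EWC penalty is added to the loss.
    x, y = torch.randn(32, 8), torch.randint(0, 4, (32,))
    output = ewc.adapt({'loss': criterion(net(x), y)}, net)
    output['loss'].backward()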