import torch

from . import RegularizationMethod


class EWC(RegularizationMethod):
    """Elastic Weight Consolidation (EWC).

    The method mitigates forgetting with a quadratic penalty,
    lambda * sum_i F_i * (theta_i - theta_i*)^2, that anchors parameters
    that were important for previous experiences. F_i is the diagonal of
    the Fisher information and theta_i* are the parameter values recorded
    after the last experience.
    """

    def __init__(self, EWC_lambda=1, temperature=2):
        """
        :param EWC_lambda: weight of the EWC penalty added to the loss.
        :param temperature: softmax temperature; stored but not used by
            the EWC penalty itself.
        """
        self.EWC_lambda = EWC_lambda
        self.temperature = temperature
        # Diagonal Fisher information estimated on the previous experience
        # and the corresponding parameter values used as anchors.
        self.fisher = {}
        self.optpar = {}

    def adapt(self, output, model, **kwargs):
        """Add the EWC penalty to the current loss."""
        ewc_loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                dev = p.device
                # Quadratic penalty weighted by the Fisher information.
                # Use `p` rather than `p.data` so the penalty contributes
                # to the gradients of the current parameters.
                penalty = (
                    self.EWC_lambda
                    * self.fisher[n].to(dev)
                    * (p - self.optpar[n].to(dev)).pow(2)
                )
                ewc_loss += penalty.sum()
        output["loss"] += ewc_loss
        return output

    def init_epoch(self, model):
        """Reset the Fisher estimate and snapshot the current parameters."""
        optpar = {}
        fisher = {}
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                fisher[n] = torch.zeros(p.data.shape)
                optpar[n] = p.data.clone().cpu()
        self.fisher = fisher
        self.optpar = optpar

    def update_fisher(self, model):
        """Accumulate squared gradients into the diagonal Fisher estimate."""
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad:
                # Empirical Fisher: running sum of squared gradients.
                self.fisher[n] += p.grad.data.pow(2).cpu()
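

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API).
# `_ToyWrapper` is a hypothetical stand-in for the wrapped model this class
# expects, i.e. an object exposing `model.module.base_model.model` for the
# underlying network; the real wrapper in this codebase may differ. Intended
# call order: `init_epoch` at the start of an experience, `update_fisher`
# after each backward pass on that experience, and `adapt` to add the penalty
# while training on the next experience. Because of the relative import
# above, this module must be run from within its package for the block below
# to execute.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _ToyWrapper:
        """Hypothetical wrapper mimicking the expected attribute chain."""

        def __init__(self, net):
            self._net = net
            self.module = SimpleNamespace(
                base_model=SimpleNamespace(model=net)
            )

        def named_parameters(self):
            # Delegate so parameter names match those seen by
            # init_epoch/update_fisher.
            return self._net.named_parameters()

    net = torch.nn.Linear(4, 2)
    model = _ToyWrapper(net)
    ewc = EWC(EWC_lambda=0.1)

    # Previous experience: snapshot parameters and accumulate Fisher info.
    ewc.init_epoch(model)
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    ewc.update_fisher(model)

    # New experience: add the EWC penalty to the task loss.
    output = {"loss": net(torch.randn(8, 4)).pow(2).mean()}
    output = ewc.adapt(output, model)
    print(float(output["loss"]))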