# _________________________________________________________

from transformers.trainer import *
from transformers import (
    TrainingArguments,
)

from .args import ContiunalRegularizationArguments
from peft_library.regularizations import EWC, LWF
from torch.nn import CrossEntropyLoss


def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
    """Token-level cross-entropy for causal LM outputs, ignoring positions labeled -100."""
    logits = outputs.logits
    device = logits.device
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:].to(device)
    # Save memory: drop ignored (-100) positions before computing the loss
    masks = shift_labels != -100
    shift_logits = shift_logits[masks]
    shift_labels = shift_labels[masks]
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits, shift_labels)

    if num_items_in_batch is None:
        loss = loss.mean()
    else:
        # compat transformers>=4.46: normalize by the number of label tokens in the batch
        loss = loss.sum() / num_items_in_batch
    return loss
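
# Illustrative usage sketch (an assumption, not exercised in this file): with
# transformers>=4.46 a loss function with this signature can be passed to the
# Trainer as `compute_loss_func`, which is also what the commented-out argument
# in `ContinualTrainer.__init__` below suggests, e.g.
#
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         compute_loss_func=ce_loss_func,
#     )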


class ContinualTrainer(Trainer):
    """Trainer that reuses an externally created Accelerator and optionally sets up EWC/LWF regularizers."""

    def __init__(
        self,
        model,
        args: TrainingArguments,
        data_collator,
        train_dataset,
        eval_dataset,
        accelerator,
        reg_args: ContiunalRegularizationArguments = None,
    ):
        self.accelerator = accelerator
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            # compute_loss_func=ce_loss_func,
        )

        if reg_args is not None and reg_args.ewc_enable:
            self.ewc_lambda = reg_args.ewc_lambda
            from peft_library.regularizations.ewc import EWC

            self.EWC = EWC()
            # fisher = t

        if reg_args is not None and reg_args.lwf_enable:
            self.lwf_lambda = reg_args.lwf_lambda
            from peft_library.regularizations.lwf import LWF

            self.LWF = LWF()

    def create_accelerator_and_postprocess(self):
        # Reuse the externally provided Accelerator instead of creating a new one;
        # fall back to the default Trainer behavior when none was passed.
        if self.accelerator is not None:
            self.is_deepspeed_enabled = (
                getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
            )
            self.is_fsdp_enabled = (
                getattr(self.accelerator.state, "fsdp_plugin", None) is not None
            )
            self.gather_function = self.accelerator.gather_for_metrics
            return
        super().create_accelerator_and_postprocess()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)
            # Three groups: decayed parameters, decayed "merger" parameters at a
            # 10x smaller learning rate, and non-decayed parameters.
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if n in decay_parameters and p.requires_grad and "merger" not in n
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if n in decay_parameters and p.requires_grad and "merger" in n
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate / 10,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if n not in decay_parameters and p.requires_grad
                    ],
                    "weight_decay": 0.0,
                    "lr": self.args.learning_rate,
                },
            ]

            if self.optimizer_cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid arguments conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
|
|
import bitsandbytes
|
|
|
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
|
|
skipped = 0
|
|
for module in opt_model.modules():
|
|
if isinstance(module, nn.Embedding):
|
|
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
|
logger.info(f"skipped {module}: {skipped / 2**20}M params")
|
|
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
|
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
logger.info(f"skipped: {skipped / 2**20}M params")
|
|
|
|
if is_sagemaker_mp_enabled():
|
|
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
|
|
|
return self.optimizer
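

# Illustrative usage sketch (an assumed setup, not exercised by this module): the
# trainer expects an externally created `accelerate.Accelerator` plus the flags
# read in `__init__` (`ewc_enable`/`ewc_lambda`, `lwf_enable`/`lwf_lambda`), e.g.
#
#     from accelerate import Accelerator
#
#     trainer = ContinualTrainer(
#         model=model,
#         args=training_args,
#         data_collator=data_collator,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         accelerator=Accelerator(),
#         reg_args=reg_args,
#     )
#     trainer.train()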