# cl-lmm/src/utils/trainer.py
# _________________________________________________________
from transformers.trainer import *
from transformers import (
    TrainingArguments,
)
from .args import ContiunalRegularizationArguments
from peft_library.regularizations import EWC, LWF
from torch.nn import CrossEntropyLoss


def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
    logits = outputs.logits
    device = logits.device
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:].to(device)
    # Drop ignored (-100) positions before the loss to save memory
    masks = shift_labels != -100
    shift_logits = shift_logits[masks]
    shift_labels = shift_labels[masks]
    # Per-token loss; the reduction is applied manually below
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits, shift_labels)
    if num_items_in_batch is None:
        loss = loss.mean()
    else:
        # transformers>=4.46 passes num_items_in_batch so the loss can be
        # normalized over the full (gradient-accumulated) batch
        loss = loss.sum() / num_items_in_batch
    return loss
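

# Hedged usage note (not part of the original file): with transformers>=4.46 the
# Trainer accepts a `compute_loss_func` callable, which is how ce_loss_func is
# intended to be plugged in (see the commented-out argument in ContinualTrainer
# below). A minimal sanity check of the function on dummy tensors, assuming only
# torch is available:
#
#     import torch
#     from types import SimpleNamespace
#
#     vocab_size, seq_len = 8, 5
#     logits = torch.randn(2, seq_len, vocab_size)
#     labels = torch.randint(0, vocab_size, (2, seq_len))
#     labels[:, 2] = -100  # positions marked -100 are dropped by the mask
#     outputs = SimpleNamespace(logits=logits)
#     loss_mean = ce_loss_func(outputs, labels)
#     n_items = (labels[..., 1:] != -100).sum()
#     loss_norm = ce_loss_func(outputs, labels, num_items_in_batch=n_items)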


class ContinualTrainer(Trainer):
    def __init__(
        self,
        model,
        args: TrainingArguments,
        data_collator,
        train_dataset,
        eval_dataset,
        accelerator,
        reg_args: ContiunalRegularizationArguments = None,
    ):
        self.accelerator = accelerator
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            # compute_loss_func=ce_loss_func,
        )

        # Guard against the default reg_args=None before reading its flags.
        if reg_args is not None and reg_args.ewc_enable:
            self.ewc_lambda = reg_args.ewc_lambda
            from peft_library.regularizations.ewc import EWC

            self.EWC = EWC()
            # fisher = t
        if reg_args is not None and reg_args.lwf_enable:
            self.lwf_lambda = reg_args.lwf_lambda
            from peft_library.regularizations.lwf import LWF

            self.LWF = LWF()
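
    # Hedged note (assumption, not from the original file): the EWC/LWF objects
    # and their lambda weights are only stored here; applying them would
    # typically mean overriding compute_loss and adding a penalty term such as
    # `loss + self.ewc_lambda * penalty`, where `penalty` comes from the
    # peft_library EWC/LWF implementation (its exact API is not shown in this
    # file, so that call is hypothetical).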

    def create_accelerator_and_postprocess(self):
        # Reuse the externally constructed Accelerator instead of letting the
        # base Trainer build its own; mirror the flags the Trainer expects.
        if self.accelerator is not None:
            self.is_deepspeed_enabled = (
                getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
            )
            self.is_fsdp_enabled = (
                getattr(self.accelerator.state, "fsdp_plugin", None) is not None
            )
            self.gather_function = self.accelerator.gather_for_metrics
            return
        else:
            super().create_accelerator_and_postprocess()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)
            optimizer_grouped_parameters = [
                # Decay parameters outside "merger" modules: full learning rate.
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad and "merger" not in n)
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate,
                },
                # Decay parameters inside "merger" modules: 10x smaller learning rate.
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad and "merger" in n)
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate / 10,
                },
                # Non-decay parameters (biases, norms): no weight decay.
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                    "lr": self.args.learning_rate,
                },
            ]
            if self.optimizer_cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
                    self.args, opt_model
                )

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`,
            # e.g. for the GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`,
            # e.g. for the LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters
            # with `optimizer_dict` to avoid argument conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum(
                            {p.data_ptr(): p.numel() for p in module.parameters()}.values()
                        )
                        logger.info(f"skipped {module}: {skipped / 2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped / 2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
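

# Hedged usage sketch (assumptions, not part of the original file): the model,
# datasets, collator, and ContiunalRegularizationArguments are built elsewhere
# in cl-lmm, so the names below are illustrative placeholders only.
#
#     from accelerate import Accelerator
#     from transformers import TrainingArguments
#
#     accelerator = Accelerator()
#     training_args = TrainingArguments(output_dir="outputs", learning_rate=1e-4)
#     trainer = ContinualTrainer(
#         model=model,                   # hypothetical: built by the caller
#         args=training_args,
#         data_collator=data_collator,   # hypothetical
#         train_dataset=train_dataset,   # hypothetical
#         eval_dataset=eval_dataset,     # hypothetical
#         accelerator=accelerator,
#         reg_args=reg_args,             # ContiunalRegularizationArguments instance
#     )
#     trainer.train()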