# _________________________________________________________
from transformers.trainer import *
from transformers import (
    TrainingArguments,
)

from .args import ContiunalRegularizationArguments
from peft_library.regularizations import EWC, LWF
from torch.nn import CrossEntropyLoss


def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
    logits = outputs.logits
    device = logits.device
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:].to(device)
    # Drop ignored positions (-100) before the loss to save memory
    masks = shift_labels != -100
    shift_logits = shift_logits[masks]
    shift_labels = shift_labels[masks]
    # Per-token loss so the reduction can be controlled explicitly below
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits, shift_labels)
    if num_items_in_batch is None:
        loss = loss.mean()
    else:
        # compat transformers>=4.46: normalize by the global token count
        loss = loss.sum() / num_items_in_batch
    return loss


class ContinualTrainer(Trainer):
    """Trainer that reuses an externally constructed Accelerator and optionally
    attaches EWC/LWF regularizers for continual learning."""

    def __init__(
        self,
        model,
        args: TrainingArguments,
        data_collator,
        train_dataset,
        eval_dataset,
        accelerator,
        reg_args: ContiunalRegularizationArguments = None,
    ):
        self.accelerator = accelerator
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            # compute_loss_func=ce_loss_func,
        )

        if reg_args is not None and reg_args.ewc_enable:
            self.ewc_lambda = reg_args.ewc_lambda
            from peft_library.regularizations.ewc import EWC

            self.EWC = EWC()
            # fisher = t

        if reg_args is not None and reg_args.lwf_enable:
            self.lwf_lambda = reg_args.lwf_lambda
            from peft_library.regularizations.lwf import LWF

            self.LWF = LWF()

    def create_accelerator_and_postprocess(self):
        # Reuse the accelerator handed in at construction time instead of
        # letting the base Trainer build its own.
        if self.accelerator is not None:
            self.is_deepspeed_enabled = (
                getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
            )
            self.is_fsdp_enabled = (
                getattr(self.accelerator.state, "fsdp_plugin", None) is not None
            )
            self.gather_function = self.accelerator.gather_for_metrics
            return
        else:
            super().create_accelerator_and_postprocess()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else,
        you can pass a tuple in the Trainer's init through `optimizers`, or subclass and
        override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)
            # Three groups: decayed non-"merger" params at the base LR, decayed
            # "merger" params at one tenth of the base LR, and undecayed params.
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad and "merger" not in n)
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad and "merger" in n)
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate / 10,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                    "lr": self.args.learning_rate,
                },
            ]

            if self.optimizer_cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`,
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`,
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters
            # with `optimizer_dict` to avoid argument conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped / 2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped / 2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer