From c100a59f0e6c2661d737ff8cece7a5154c52bd31 Mon Sep 17 00:00:00 2001
From: YunyaoZhou
Date: Fri, 16 May 2025 19:50:41 +0800
Subject: [PATCH] Refactor code structure for improved readability and maintainability

---
 .vscode/settings.json                        |    6 +-
 src/peft_library/regularizations/__init__.py |   15 +
 src/peft_library/regularizations/ewc.py      |   51 +
 src/peft_library/regularizations/lwf.py      |   54 +
 src/train.py                                 |    2 +-
 src/utils/args.py                            |   14 +
 src/utils/trainer.py                         | 1450 +++++++++---------
 7 files changed, 874 insertions(+), 718 deletions(-)
 create mode 100644 src/peft_library/regularizations/__init__.py
 create mode 100644 src/peft_library/regularizations/ewc.py
 create mode 100644 src/peft_library/regularizations/lwf.py

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 15e2ee1..0ddb8eb 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -4,6 +4,8 @@
         "src/peft_repo/src/"
     ],
     "python.analysis.exclude": [
-        "dataset/**"
-    ]
+        "dataset/**/*"
+    ],
+    "python.languageServer": "Default",
+    "python.terminal.activateEnvInCurrentTerminal": true
 }
\ No newline at end of file
diff --git a/src/peft_library/regularizations/__init__.py b/src/peft_library/regularizations/__init__.py
new file mode 100644
index 0000000..f1b5bba
--- /dev/null
+++ b/src/peft_library/regularizations/__init__.py
@@ -0,0 +1,15 @@
+
+class RegularizationMethod:
+    """RegularizationMethod implements a regularization strategy.
+    A RegularizationMethod is a callable.
+    It is invoked to update the loss, typically at the end
+    of an experience.
+    """
+    def pre_adapt(self, agent, exp):
+        pass  # implementation may be empty if adapt is not needed
+
+    def post_adapt(self, agent, exp):
+        pass  # implementation may be empty if adapt is not needed
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError()
\ No newline at end of file
diff --git a/src/peft_library/regularizations/ewc.py b/src/peft_library/regularizations/ewc.py
new file mode 100644
index 0000000..530fdd5
--- /dev/null
+++ b/src/peft_library/regularizations/ewc.py
@@ -0,0 +1,51 @@
+from . import RegularizationMethod
+import torch
+
+class EWC(RegularizationMethod):
+    """Elastic Weight Consolidation (EWC).
+
+    The method mitigates forgetting by penalizing changes to parameters that
+    were important for previous experiences, as measured by the Fisher information.
+    """
+
+    def __init__(self, EWC_lambda=1, temperature=2):
+        """
+        :param EWC_lambda: weight of the quadratic EWC penalty added to the
+        training loss.
+        :param temperature: kept for interface compatibility with the
+        distillation-based methods; it is not used by EWC.
+        """
+        self.EWC_lambda = EWC_lambda
+        self.temperature = temperature
+        self.fisher = {}
+        self.optpar = {}
+        """ `fisher` holds the accumulated Fisher information and `optpar`
+        the parameter snapshot from the previous experience. Both are
+        initialized by `init_epoch`, the Fisher information is accumulated
+        by `update_fisher`, and `adapt` uses them to add the quadratic
+        penalty to the loss.
+        """
+    def adapt(self, output, model, **kwargs):
+        ewc_loss = 0
+        for n, p in model.named_parameters():
+            if p.requires_grad:
+                dev = p.device
+                penalty = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+                ewc_loss += penalty.sum()
+        output['loss'] += ewc_loss
+        return output
+
+    def init_epoch(self, model):
+        """Reset the Fisher information and snapshot the current parameters."""
+        for n, p in model.module.base_model.model.named_parameters():
+            if p.requires_grad:
+                self.fisher[n] = torch.zeros(p.data.shape)
+                self.optpar[n] = p.clone().cpu().data
+
+    def update_fisher(self, model):
+        """Accumulate the squared gradients into the Fisher information."""
+        for n, p in model.module.base_model.model.named_parameters():
+            if p.requires_grad:
+                fisher = self.fisher[n]
+                fisher += p.grad.data.pow(2).cpu()
+                self.fisher[n] = fisher
\ No newline at end of file
diff --git a/src/peft_library/regularizations/lwf.py b/src/peft_library/regularizations/lwf.py
new file mode 100644
index 0000000..e93b348
--- /dev/null
+++ b/src/peft_library/regularizations/lwf.py
@@ -0,0 +1,54 @@
+from . import RegularizationMethod
+import torch
+
+class LWF(RegularizationMethod):
+    """Learning Without Forgetting (LwF).
+
+    The method applies knowledge distillation to mitigate forgetting.
+    The teacher is the model checkpoint after the last experience.
+    """
+
+    def __init__(self, LWF_lambda=1, temperature=2):
+        """
+        :param LWF_lambda: distillation hyperparameter. It can be either a float
+        number or a list containing a value for each experience.
+        :param temperature: softmax temperature for distillation
+        """
+        self.LWF_lambda = LWF_lambda
+        self.temperature = temperature
+        self.previous_logits = {}
+        """ In Avalanche, targets of different experiences are not ordered.
+        As a result, some units may be allocated even though their
+        corresponding class has never been seen by the model.
+        Knowledge distillation uses only units corresponding
+        to old classes.
+        """
+    def adapt(self, output, **kwargs):
+        def modified_kl_div(old, new):
+            return -torch.mean(torch.sum(old * torch.log(new), 1))
+
+        def smooth(logits, temp, dim):
+            log = logits ** (1 / temp)
+            return log / torch.sum(log, dim).unsqueeze(1)
+
+        lwf_loss = []
+
+        soft = torch.nn.Softmax(dim=1)
+
+        previous_keys = self.previous_logits.keys()
+
+        for index, question_id in enumerate(kwargs['question_ids']):
+            if question_id in previous_keys:
+                previous_logits = self.previous_logits[question_id]
+                current_logits = output['logits'][index]
+                short_index = min(len(previous_logits), len(current_logits))
+                previous_logits = previous_logits[:short_index]
+                current_logits = current_logits[:short_index]
+                lwf_loss.append(
+                    modified_kl_div(
+                        old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=2, dim=1),
+                        new=smooth(logits=soft(current_logits), temp=2, dim=1),
+                    )
+                )
+        if len(lwf_loss) > 0:
+            output['loss'] += self.LWF_lambda * torch.stack(tensors=lwf_loss, dim=0).sum(dim=0)
+        return output
+
+    def update_previous_logits(self, question_id, logits):
+        """Update the previous logits for the given question id."""
+        self.previous_logits[question_id] = logits
diff --git a/src/train.py b/src/train.py
index e7acb0c..5499094 100644
--- a/src/train.py
+++ b/src/train.py
@@ -47,7 +47,7 @@ if __name__ == "__main__":
     accelerator = create_accelerator_and_postprocess(training_args)
 
     if model_args.peft_type == "MMOELORA":
-        from peft_library.tuners import MMOELoraConfig
+        from peft.tuners import MMOELoraConfig
 
         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
diff --git a/src/utils/args.py b/src/utils/args.py
index e0b1fcf..d7a81dc 100644
--- a/src/utils/args.py
+++ b/src/utils/args.py
@@ -18,3 +18,17 @@ class ContinualModelConfig(ModelConfig):
     """Model configuration for continual learning."""
 
     peft_type: Optional[str] = None
+
+
+@dataclass
+class ContinualRegularizationArguments:
+    """Regularization arguments for continual learning."""
+
+    # EWC
+    ewc_lambda: float = 0.0
+    ewc_enable: bool = False
+
+    # LWF
+    lwf_lambda: float = 0.0
+    lwf_enable: bool = False
+
\ No newline at end of file
diff --git a/src/utils/trainer.py b/src/utils/trainer.py
index 192cd36..4ad0d75 100644
--- a/src/utils/trainer.py
+++ b/src/utils/trainer.py
@@ -11,16 +11,36 @@ from transformers.trainer import (
     sys,
 )
 from transformers.trainer import *
+from transformers import (
+    TrainingArguments,
+)
+from .args import ContinualRegularizationArguments
 
 
 class ContinualTrainer(Trainer):
     def __init__(
-        self, model, args, data_collator, train_dataset, eval_dataset, accelerator
+        self,
+        model,
+        args: TrainingArguments,
+        data_collator,
+        train_dataset,
+        eval_dataset,
+        accelerator,
+        regularization_args: ContinualRegularizationArguments = None,
     ):
         self.accelerator = accelerator
         super().__init__(model, args, data_collator, train_dataset, eval_dataset)
 
+        if regularization_args is not None and regularization_args.ewc_enable:
+            self.ewc_lambda = regularization_args.ewc_lambda
+            # fisher = t
+
+        if regularization_args is not None and regularization_args.lwf_enable:
+            pass
+
     def create_accelerator_and_postprocess(self):
+
         if self.accelerator is not None:
             self.is_deepspeed_enabled = (
                 getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
             )
@@ -33,717 +53,717 @@ class ContinualTrainer(Trainer):
         else:
             super().create_accelerator_and_postprocess()
 
-    def compute_loss(
-        self, model, inputs, return_outputs=False, num_items_in_batch=None
-    ):
-        """
-        How the loss is computed by Trainer. By default, all models return the loss in the first element.
- - Subclass and override for custom behavior. - """ - if ( - self.label_smoother is not None or self.compute_loss_func is not None - ) and "labels" in inputs: - labels = inputs.pop("labels") - else: - labels = None - if self.model_accepts_loss_kwargs: - loss_kwargs = {} - if num_items_in_batch is not None: - loss_kwargs["num_items_in_batch"] = num_items_in_batch - inputs = {**inputs, **loss_kwargs} - outputs = model(**inputs) - # Save past state if it exists - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index] - - if labels is not None: - unwrapped_model = self.accelerator.unwrap_model(model) - if _is_peft_model(unwrapped_model): - model_name = unwrapped_model.base_model.model._get_name() - else: - model_name = unwrapped_model._get_name() - # User-defined compute_loss function - if self.compute_loss_func is not None: - loss = self.compute_loss_func( - outputs, labels, num_items_in_batch=num_items_in_batch - ) - elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): - loss = self.label_smoother(outputs, labels, shift_labels=True) - else: - loss = self.label_smoother(outputs, labels) - else: - if isinstance(outputs, dict) and "loss" not in outputs: - raise ValueError( - "The model did not return a loss from the inputs, only the following keys: " - f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." - ) - # We don't use .loss here since the model may return tuples instead of ModelOutput. - loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] - - return (loss, outputs) if return_outputs else loss - - def _inner_training_loop( - self, - batch_size=None, - args=None, - resume_from_checkpoint=None, - trial=None, - ignore_keys_for_eval=None, - ): - self.accelerator.free_memory() - self._train_batch_size = batch_size - if self.args.auto_find_batch_size: - if self.state.train_batch_size != self._train_batch_size: - from accelerate.utils import release_memory - - (self.model_wrapped,) = release_memory(self.model_wrapped) - self.model_wrapped = self.model - - # Check for DeepSpeed *after* the intial pass and modify the config - if self.is_deepspeed_enabled: - # Temporarily unset `self.args.train_batch_size` - original_bs = self.args.per_device_train_batch_size - self.args.per_device_train_batch_size = ( - self._train_batch_size // max(1, self.args.n_gpu) - ) - self.propagate_args_to_deepspeed(True) - self.args.per_device_train_batch_size = original_bs - self.state.train_batch_size = self._train_batch_size - logger.debug( - f"Currently training with a batch size of: {self._train_batch_size}" - ) - # Data loader and number of training steps - train_dataloader = self.get_train_dataloader() - if self.is_fsdp_xla_v2_enabled: - train_dataloader = tpu_spmd_dataloader(train_dataloader) - - # Setting up training control variables: - # number of training epochs: num_train_epochs - # number of training steps per epoch: num_update_steps_per_epoch - # total number of training steps to execute: max_steps - total_train_batch_size = ( - self._train_batch_size * args.gradient_accumulation_steps * args.world_size - ) - - len_dataloader = None - num_train_tokens = None - if has_length(train_dataloader): - len_dataloader = len(train_dataloader) - num_update_steps_per_epoch = ( - len_dataloader // args.gradient_accumulation_steps - ) - num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) - num_examples = self.num_examples(train_dataloader) - if args.max_steps > 
0: - max_steps = args.max_steps - num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( - args.max_steps % num_update_steps_per_epoch > 0 - ) - # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's - # the best we can do. - num_train_samples = args.max_steps * total_train_batch_size - if args.include_tokens_per_second: - num_train_tokens = ( - self.num_tokens(train_dataloader, args.max_steps) - * args.gradient_accumulation_steps - ) - else: - max_steps = math.ceil( - args.num_train_epochs * num_update_steps_per_epoch - ) - num_train_epochs = math.ceil(args.num_train_epochs) - num_train_samples = ( - self.num_examples(train_dataloader) * args.num_train_epochs - ) - if args.include_tokens_per_second: - num_train_tokens = ( - self.num_tokens(train_dataloader) * args.num_train_epochs - ) - elif ( - args.max_steps > 0 - ): # Rely on max_steps when dataloader does not have a working size - max_steps = args.max_steps - # Setting a very large number of epochs so we go as many times as necessary over the iterator. - num_train_epochs = sys.maxsize - num_update_steps_per_epoch = max_steps - num_examples = total_train_batch_size * args.max_steps - num_train_samples = args.max_steps * total_train_batch_size - if args.include_tokens_per_second: - num_train_tokens = ( - self.num_tokens(train_dataloader, args.max_steps) - * args.gradient_accumulation_steps - ) - else: - raise ValueError( - "args.max_steps must be set to a positive value if dataloader does not have a length, was" - f" {args.max_steps}" - ) - - if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: - if self.args.n_gpu > 1: - # nn.DataParallel(model) replicates the model, creating new variables and module - # references registered here no longer work on other gpus, breaking the module - raise ValueError( - "Currently --debug underflow_overflow is not supported under DP. Please use DDP" - " (torchrun or torch.distributed.launch (deprecated))." 
- ) - else: - debug_overflow = DebugUnderflowOverflow(self.model) # noqa - - delay_optimizer_creation = ( - is_sagemaker_mp_enabled() - or self.is_fsdp_xla_enabled - or self.is_fsdp_enabled - ) - - # We need to reset the scheduler, as its parameters may be different on subsequent calls - if self._created_lr_scheduler: - self.lr_scheduler = None - self._created_lr_scheduler = False - - if self.is_deepspeed_enabled: - self.optimizer, self.lr_scheduler = deepspeed_init( - self, num_training_steps=max_steps - ) - - if not delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - - self.state = TrainerState( - stateful_callbacks=[ - cb - for cb in self.callback_handler.callbacks + [self.control] - if isinstance(cb, ExportableState) - ] - ) - self.state.is_hyper_param_search = trial is not None - self.state.train_batch_size = self._train_batch_size - - # Compute absolute values for logging, eval, and save if given as ratio - if args.logging_steps is not None: - if args.logging_steps < 1: - self.state.logging_steps = math.ceil(max_steps * args.logging_steps) - else: - self.state.logging_steps = args.logging_steps - if args.eval_steps is not None: - if args.eval_steps < 1: - self.state.eval_steps = math.ceil(max_steps * args.eval_steps) - else: - self.state.eval_steps = args.eval_steps - if args.save_steps is not None: - if args.save_steps < 1: - self.state.save_steps = math.ceil(max_steps * args.save_steps) - else: - self.state.save_steps = args.save_steps - - # Activate gradient checkpointing if needed - if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs - ) - - model = self._wrap_model(self.model_wrapped) - - # as the model is wrapped, don't use `accelerator.prepare` - # this is for unhandled cases such as - # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX - use_accelerator_prepare = True if model is self.model else False - - if use_accelerator_prepare and self.is_fsdp_enabled: - # In case of auto_find_batch_size=True - # Remove FSDP wrapping from sub-models. - self.model = unwrap_model(self.model, recursive=True) - - if delay_optimizer_creation: - if use_accelerator_prepare: - # configure fsdp plugin for qlora if any - self._fsdp_qlora_plugin_updates() - if self.accelerator.mixed_precision != "fp8": - self.model = self.accelerator.prepare(self.model) - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - - # prepare using `accelerator` prepare - if use_accelerator_prepare: - self.model.train() - if hasattr(self.lr_scheduler, "step"): - if self.use_apex: - model = self.accelerator.prepare(self.model) - else: - model, self.optimizer = self.accelerator.prepare( - self.model, self.optimizer - ) - else: - # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. 
- model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( - self.model, self.optimizer, self.lr_scheduler - ) - elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: - # In this case we are in DDP + LOMO, which should be supported - self.optimizer = self.accelerator.prepare(self.optimizer) - - if self.is_fsdp_enabled: - self.model = self.model_wrapped = model - - # for the rest of this function `model` is the outside model, whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model - - # backward compatibility - if self.is_deepspeed_enabled: - self.deepspeed = self.model_wrapped - - # ckpt loading - if resume_from_checkpoint is not None: - if self.is_deepspeed_enabled: - deepspeed_load_checkpoint( - self.model_wrapped, - resume_from_checkpoint, - load_module_strict=not _is_peft_model(self.model), - ) - elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: - self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) - - # Check if saved optimizer or scheduler states exist - self._load_optimizer_and_scheduler(resume_from_checkpoint) - - # important: at this point: - # self.model is the Transformers Model - # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), - # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. - - # Train! - logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples:,}") - logger.info(f" Num Epochs = {num_train_epochs:,}") - logger.info( - f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}" - ) - if self.args.per_device_train_batch_size != self._train_batch_size: - logger.info( - f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}" - ) - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}" - ) - logger.info( - f" Gradient Accumulation steps = {args.gradient_accumulation_steps}" - ) - logger.info(f" Total optimization steps = {max_steps:,}") - logger.info( - f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}" - ) - - self.state.epoch = 0 - start_time = time.time() - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - steps_trained_progress_bar = None - - # Check if continuing training from a checkpoint - if resume_from_checkpoint is not None and os.path.isfile( - os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) - ): - self.state = TrainerState.load_from_json( - os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) - ) - self.compare_trainer_and_checkpoint_args(self.args, self.state) - self._load_callback_state() - epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) - if not args.ignore_data_skip: - steps_trained_in_current_epoch = self.state.global_step % ( - num_update_steps_per_epoch - ) - steps_trained_in_current_epoch *= args.gradient_accumulation_steps - else: - steps_trained_in_current_epoch = 0 - - logger.info( - " Continuing training from checkpoint, will skip to saved global_step" - ) - logger.info(f" Continuing training from epoch {epochs_trained}") - logger.info( - f" Continuing training from global step {self.state.global_step}" - ) - if not args.ignore_data_skip: - logger.info( - f" Will skip the first {epochs_trained} epochs then the first" - f" {steps_trained_in_current_epoch} batches in the first epoch." 
- ) - - # Update the references - self.callback_handler.model = self.model - self.callback_handler.optimizer = self.optimizer - self.callback_handler.lr_scheduler = self.lr_scheduler - self.callback_handler.train_dataloader = train_dataloader - if self.hp_name is not None and self._trial is not None: - # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial - # parameter to Train when using DDP. - self.state.trial_name = self.hp_name(self._trial) - if trial is not None: - assignments = ( - trial.assignments - if self.hp_search_backend == HPSearchBackend.SIGOPT - else trial - ) - self.state.trial_params = hp_params(assignments) - else: - self.state.trial_params = None - # This should be the same if the state has been saved but in case the training arguments changed, it's safer - # to set this after the load. - self.state.max_steps = max_steps - self.state.num_train_epochs = num_train_epochs - self.state.is_local_process_zero = self.is_local_process_zero() - self.state.is_world_process_zero = self.is_world_process_zero() - - # tr_loss is a tensor to avoid synchronization of TPUs through .item() - tr_loss = torch.tensor(0.0).to(args.device) - # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses - self._total_loss_scalar = 0.0 - self._globalstep_last_logged = self.state.global_step - model.zero_grad() - grad_norm: Optional[float] = None - self.control = self.callback_handler.on_train_begin( - args, self.state, self.control - ) - - if args.eval_on_start: - self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) - - for epoch in range(epochs_trained, num_train_epochs): - epoch_dataloader = train_dataloader - if hasattr(epoch_dataloader, "set_epoch"): - epoch_dataloader.set_epoch(epoch) - - # Reset the past mems state at the beginning of each epoch if necessary. 
- if args.past_index >= 0: - self._past = None - - steps_in_epoch = ( - len(epoch_dataloader) - if len_dataloader is not None - else args.max_steps * args.gradient_accumulation_steps - ) - self.control = self.callback_handler.on_epoch_begin( - args, self.state, self.control - ) - - if ( - epoch == epochs_trained - and resume_from_checkpoint is not None - and steps_trained_in_current_epoch == 0 - ): - self._load_rng_state(resume_from_checkpoint) - - rng_to_sync = False - steps_skipped = 0 - if steps_trained_in_current_epoch > 0: - epoch_dataloader = skip_first_batches( - epoch_dataloader, steps_trained_in_current_epoch - ) - steps_skipped = steps_trained_in_current_epoch - steps_trained_in_current_epoch = 0 - rng_to_sync = True - - step = -1 - epoch_iterator = iter(epoch_dataloader) - # We chunkify the epoch iterator into gradient accumulation steps `n` batches - remainder = num_examples % args.gradient_accumulation_steps - if remainder == 0: - remainder = args.gradient_accumulation_steps - update_step = -1 - total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 - for _ in range(total_updates): - update_step += 1 - num_batches = ( - args.gradient_accumulation_steps - if update_step != (total_updates - 1) - else remainder - ) - batch_samples, num_items_in_batch = self.get_batch_samples( - epoch_iterator, num_batches - ) - for i, inputs in enumerate(batch_samples): - step += 1 - do_sync_step = ( - step + 1 - ) % args.gradient_accumulation_steps == 0 or ( - step + 1 - ) == steps_in_epoch - # Since we perform prefetching, we need to manually set sync_gradients - if not do_sync_step: - self.accelerator.gradient_state._set_sync_gradients(False) - else: - self.accelerator.gradient_state._set_sync_gradients(True) - - if self.args.include_num_input_tokens_seen: - main_input_name = getattr( - self.model, "main_input_name", "input_ids" - ) - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." 
- ) - else: - input_tokens = inputs[main_input_name].numel() - input_tokens = torch.tensor( - input_tokens, device=self.args.device, dtype=torch.int64 - ) - self.state.num_input_tokens_seen += ( - self.accelerator.gather(input_tokens).sum().cpu().item() - ) - if rng_to_sync: - self._load_rng_state(resume_from_checkpoint) - rng_to_sync = False - - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - if steps_trained_progress_bar is not None: - steps_trained_progress_bar.update(1) - if steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) - continue - elif steps_trained_progress_bar is not None: - steps_trained_progress_bar.close() - steps_trained_progress_bar = None - - if step % args.gradient_accumulation_steps == 0: - self.control = self.callback_handler.on_step_begin( - args, self.state, self.control - ) - - # We explicitly want to avoid relying on `accelerator.accumulate` for generation training - context = ( - functools.partial(self.accelerator.no_sync, model=model) - if i != len(batch_samples) - 1 - and self.accelerator.distributed_type - != DistributedType.DEEPSPEED - else contextlib.nullcontext - ) - with context(): - tr_loss_step = self.training_step( - model, inputs, num_items_in_batch - ) - - if ( - args.logging_nan_inf_filter - and not is_torch_xla_available() - and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) - ): - # if loss is nan or inf simply add the average of previous logged losses - tr_loss = tr_loss + tr_loss / ( - 1 + self.state.global_step - self._globalstep_last_logged - ) - else: - if tr_loss.device != tr_loss_step.device: - raise ValueError( - f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" - ) - tr_loss = tr_loss + tr_loss_step - - self.current_flos += float(self.floating_point_ops(inputs)) - - if do_sync_step: - # Since we perform prefetching, we need to manually set sync_gradients to True - self.accelerator.gradient_state._set_sync_gradients(True) - - # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0: - # deepspeed does its own clipping - - if is_sagemaker_mp_enabled() and args.fp16: - _grad_norm = self.optimizer.clip_master_grads( - args.max_grad_norm - ) - elif self.use_apex: - # Revert to normal clipping otherwise, handling Apex or full precision - _grad_norm = nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer), - args.max_grad_norm, - ) - else: - _grad_norm = self.accelerator.clip_grad_norm_( - model.parameters(), - args.max_grad_norm, - ) - - if ( - is_accelerate_available() - and self.accelerator.distributed_type - == DistributedType.DEEPSPEED - ): - grad_norm = model.get_global_grad_norm() - # In some cases the grad norm may not return a float - if hasattr(grad_norm, "item"): - grad_norm = grad_norm.item() - else: - grad_norm = _grad_norm - - self.control = self.callback_handler.on_pre_optimizer_step( - args, self.state, self.control - ) - - self.optimizer.step() - - self.control = self.callback_handler.on_optimizer_step( - args, self.state, self.control - ) - - optimizer_was_run = ( - not self.accelerator.optimizer_step_was_skipped - ) - if optimizer_was_run: - # Delay optimizer scheduling until metrics are generated - if not isinstance( - self.lr_scheduler, - torch.optim.lr_scheduler.ReduceLROnPlateau, - ): - self.lr_scheduler.step() - - model.zero_grad() - self.state.global_step += 1 - self.state.epoch = ( - epoch + 
(step + 1 + steps_skipped) / steps_in_epoch - ) - self.control = self.callback_handler.on_step_end( - args, self.state, self.control - ) - self._maybe_log_save_evaluate( - tr_loss, - grad_norm, - model, - trial, - epoch, - ignore_keys_for_eval, - start_time, - ) - else: - self.control = self.callback_handler.on_substep_end( - args, self.state, self.control - ) - - # PyTorch/XLA relies on the data loader to insert the mark_step for - # each step. Since we are breaking the loop early, we need to manually - # insert the mark_step here. - if ( - self.control.should_epoch_stop - or self.control.should_training_stop - ): - if is_torch_xla_available(): - xm.mark_step() - break - # We also need to break out of the nested loop - if self.control.should_epoch_stop or self.control.should_training_stop: - if is_torch_xla_available(): - xm.mark_step() - break - if step < 0: - logger.warning( - "There seems not to be a single sample in your epoch_iterator, stopping training at step" - f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" - f" num_steps ({max_steps}) higher than the number of available samples." - ) - self.control.should_training_stop = True - - self.control = self.callback_handler.on_epoch_end( - args, self.state, self.control - ) - self._maybe_log_save_evaluate( - tr_loss, - grad_norm, - model, - trial, - epoch, - ignore_keys_for_eval, - start_time, - ) - - if DebugOption.TPU_METRICS_DEBUG in self.args.debug: - if is_torch_xla_available(): - # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) - xm.master_print(met.metrics_report()) - else: - logger.warning( - "You enabled PyTorch/XLA debug metrics but you don't have a TPU " - "configured. Check your training configuration if this is unexpected." - ) - if self.control.should_training_stop: - break - - if args.past_index and hasattr(self, "_past"): - # Clean the state at the end of training - delattr(self, "_past") - - logger.info( - "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" - ) - if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: - # Wait for everyone to get here so we are sure the model has been saved by process 0. - if is_torch_xla_available(): - xm.rendezvous("load_best_model_at_end") - elif args.parallel_mode == ParallelMode.DISTRIBUTED: - dist.barrier() - elif is_sagemaker_mp_enabled(): - smp.barrier() - - self._load_best_model() - - # add remaining tr_loss - self._total_loss_scalar += tr_loss.item() - effective_global_step = max( - self.state.global_step, 0.001 - ) # Avoid ZeroDivisionError - train_loss = self._total_loss_scalar / effective_global_step - - metrics = speed_metrics( - "train", - start_time, - num_samples=num_train_samples, - num_steps=self.state.max_steps, - num_tokens=num_train_tokens, - ) - self.store_flos() - metrics["total_flos"] = self.state.total_flos - metrics["train_loss"] = train_loss - - self.is_in_train = False - - self._memory_tracker.stop_and_update_metrics(metrics) - - self.log(metrics) - - run_dir = self._get_output_dir(trial) - checkpoints_sorted = self._sorted_checkpoints( - use_mtime=False, output_dir=run_dir - ) - - # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. 
- if ( - self.args.should_save - and self.state.best_model_checkpoint is not None - and self.args.save_total_limit == 1 - ): - for checkpoint in checkpoints_sorted: - if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): - logger.info( - f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit" - ) - shutil.rmtree(checkpoint, ignore_errors=True) - - self.control = self.callback_handler.on_train_end( - args, self.state, self.control - ) - - # Wait for the checkpoint to be uploaded. - self._finish_current_push() - - # After training we make sure to retrieve back the original forward pass method - # for the embedding layer by removing the forward post hook. - if self.neftune_noise_alpha is not None: - self._deactivate_neftune(self.model) - - return TrainOutput(self.state.global_step, train_loss, metrics) + # def compute_loss( + # self, model, inputs, return_outputs=False, num_items_in_batch=None + # ): + # """ + # How the loss is computed by Trainer. By default, all models return the loss in the first element. + + # Subclass and override for custom behavior. + # """ + # if ( + # self.label_smoother is not None or self.compute_loss_func is not None + # ) and "labels" in inputs: + # labels = inputs.pop("labels") + # else: + # labels = None + # if self.model_accepts_loss_kwargs: + # loss_kwargs = {} + # if num_items_in_batch is not None: + # loss_kwargs["num_items_in_batch"] = num_items_in_batch + # inputs = {**inputs, **loss_kwargs} + # outputs = model(**inputs) + # # Save past state if it exists + # # TODO: this needs to be fixed and made cleaner later. + # if self.args.past_index >= 0: + # self._past = outputs[self.args.past_index] + + # if labels is not None: + # unwrapped_model = self.accelerator.unwrap_model(model) + # if _is_peft_model(unwrapped_model): + # model_name = unwrapped_model.base_model.model._get_name() + # else: + # model_name = unwrapped_model._get_name() + # # User-defined compute_loss function + # if self.compute_loss_func is not None: + # loss = self.compute_loss_func( + # outputs, labels, num_items_in_batch=num_items_in_batch + # ) + # elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + # loss = self.label_smoother(outputs, labels, shift_labels=True) + # else: + # loss = self.label_smoother(outputs, labels) + # else: + # if isinstance(outputs, dict) and "loss" not in outputs: + # raise ValueError( + # "The model did not return a loss from the inputs, only the following keys: " + # f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + # ) + # # We don't use .loss here since the model may return tuples instead of ModelOutput. 
+ # loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + # return (loss, outputs) if return_outputs else loss + + # def _inner_training_loop( + # self, + # batch_size=None, + # args=None, + # resume_from_checkpoint=None, + # trial=None, + # ignore_keys_for_eval=None, + # ): + # self.accelerator.free_memory() + # self._train_batch_size = batch_size + # if self.args.auto_find_batch_size: + # if self.state.train_batch_size != self._train_batch_size: + # from accelerate.utils import release_memory + + # (self.model_wrapped,) = release_memory(self.model_wrapped) + # self.model_wrapped = self.model + + # # Check for DeepSpeed *after* the intial pass and modify the config + # if self.is_deepspeed_enabled: + # # Temporarily unset `self.args.train_batch_size` + # original_bs = self.args.per_device_train_batch_size + # self.args.per_device_train_batch_size = ( + # self._train_batch_size // max(1, self.args.n_gpu) + # ) + # self.propagate_args_to_deepspeed(True) + # self.args.per_device_train_batch_size = original_bs + # self.state.train_batch_size = self._train_batch_size + # logger.debug( + # f"Currently training with a batch size of: {self._train_batch_size}" + # ) + # # Data loader and number of training steps + # train_dataloader = self.get_train_dataloader() + # if self.is_fsdp_xla_v2_enabled: + # train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # # Setting up training control variables: + # # number of training epochs: num_train_epochs + # # number of training steps per epoch: num_update_steps_per_epoch + # # total number of training steps to execute: max_steps + # total_train_batch_size = ( + # self._train_batch_size * args.gradient_accumulation_steps * args.world_size + # ) + + # len_dataloader = None + # num_train_tokens = None + # if has_length(train_dataloader): + # len_dataloader = len(train_dataloader) + # num_update_steps_per_epoch = ( + # len_dataloader // args.gradient_accumulation_steps + # ) + # num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + # num_examples = self.num_examples(train_dataloader) + # if args.max_steps > 0: + # max_steps = args.max_steps + # num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + # args.max_steps % num_update_steps_per_epoch > 0 + # ) + # # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # # the best we can do. + # num_train_samples = args.max_steps * total_train_batch_size + # if args.include_tokens_per_second: + # num_train_tokens = ( + # self.num_tokens(train_dataloader, args.max_steps) + # * args.gradient_accumulation_steps + # ) + # else: + # max_steps = math.ceil( + # args.num_train_epochs * num_update_steps_per_epoch + # ) + # num_train_epochs = math.ceil(args.num_train_epochs) + # num_train_samples = ( + # self.num_examples(train_dataloader) * args.num_train_epochs + # ) + # if args.include_tokens_per_second: + # num_train_tokens = ( + # self.num_tokens(train_dataloader) * args.num_train_epochs + # ) + # elif ( + # args.max_steps > 0 + # ): # Rely on max_steps when dataloader does not have a working size + # max_steps = args.max_steps + # # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
+ # num_train_epochs = sys.maxsize + # num_update_steps_per_epoch = max_steps + # num_examples = total_train_batch_size * args.max_steps + # num_train_samples = args.max_steps * total_train_batch_size + # if args.include_tokens_per_second: + # num_train_tokens = ( + # self.num_tokens(train_dataloader, args.max_steps) + # * args.gradient_accumulation_steps + # ) + # else: + # raise ValueError( + # "args.max_steps must be set to a positive value if dataloader does not have a length, was" + # f" {args.max_steps}" + # ) + + # if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + # if self.args.n_gpu > 1: + # # nn.DataParallel(model) replicates the model, creating new variables and module + # # references registered here no longer work on other gpus, breaking the module + # raise ValueError( + # "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + # " (torchrun or torch.distributed.launch (deprecated))." + # ) + # else: + # debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + # delay_optimizer_creation = ( + # is_sagemaker_mp_enabled() + # or self.is_fsdp_xla_enabled + # or self.is_fsdp_enabled + # ) + + # # We need to reset the scheduler, as its parameters may be different on subsequent calls + # if self._created_lr_scheduler: + # self.lr_scheduler = None + # self._created_lr_scheduler = False + + # if self.is_deepspeed_enabled: + # self.optimizer, self.lr_scheduler = deepspeed_init( + # self, num_training_steps=max_steps + # ) + + # if not delay_optimizer_creation: + # self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # self.state = TrainerState( + # stateful_callbacks=[ + # cb + # for cb in self.callback_handler.callbacks + [self.control] + # if isinstance(cb, ExportableState) + # ] + # ) + # self.state.is_hyper_param_search = trial is not None + # self.state.train_batch_size = self._train_batch_size + + # # Compute absolute values for logging, eval, and save if given as ratio + # if args.logging_steps is not None: + # if args.logging_steps < 1: + # self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + # else: + # self.state.logging_steps = args.logging_steps + # if args.eval_steps is not None: + # if args.eval_steps < 1: + # self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + # else: + # self.state.eval_steps = args.eval_steps + # if args.save_steps is not None: + # if args.save_steps < 1: + # self.state.save_steps = math.ceil(max_steps * args.save_steps) + # else: + # self.state.save_steps = args.save_steps + + # # Activate gradient checkpointing if needed + # if args.gradient_checkpointing: + # self.model.gradient_checkpointing_enable( + # gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs + # ) + + # model = self._wrap_model(self.model_wrapped) + + # # as the model is wrapped, don't use `accelerator.prepare` + # # this is for unhandled cases such as + # # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + # use_accelerator_prepare = True if model is self.model else False + + # if use_accelerator_prepare and self.is_fsdp_enabled: + # # In case of auto_find_batch_size=True + # # Remove FSDP wrapping from sub-models. 
+ # self.model = unwrap_model(self.model, recursive=True) + + # if delay_optimizer_creation: + # if use_accelerator_prepare: + # # configure fsdp plugin for qlora if any + # self._fsdp_qlora_plugin_updates() + # if self.accelerator.mixed_precision != "fp8": + # self.model = self.accelerator.prepare(self.model) + # self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # # prepare using `accelerator` prepare + # if use_accelerator_prepare: + # self.model.train() + # if hasattr(self.lr_scheduler, "step"): + # if self.use_apex: + # model = self.accelerator.prepare(self.model) + # else: + # model, self.optimizer = self.accelerator.prepare( + # self.model, self.optimizer + # ) + # else: + # # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + # model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + # self.model, self.optimizer, self.lr_scheduler + # ) + # elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # # In this case we are in DDP + LOMO, which should be supported + # self.optimizer = self.accelerator.prepare(self.optimizer) + + # if self.is_fsdp_enabled: + # self.model = self.model_wrapped = model + + # # for the rest of this function `model` is the outside model, whether it was wrapped or not + # if model is not self.model: + # self.model_wrapped = model + + # # backward compatibility + # if self.is_deepspeed_enabled: + # self.deepspeed = self.model_wrapped + + # # ckpt loading + # if resume_from_checkpoint is not None: + # if self.is_deepspeed_enabled: + # deepspeed_load_checkpoint( + # self.model_wrapped, + # resume_from_checkpoint, + # load_module_strict=not _is_peft_model(self.model), + # ) + # elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + # self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # # Check if saved optimizer or scheduler states exist + # self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # # important: at this point: + # # self.model is the Transformers Model + # # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # # Train! + # logger.info("***** Running training *****") + # logger.info(f" Num examples = {num_examples:,}") + # logger.info(f" Num Epochs = {num_train_epochs:,}") + # logger.info( + # f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}" + # ) + # if self.args.per_device_train_batch_size != self._train_batch_size: + # logger.info( + # f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}" + # ) + # logger.info( + # f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}" + # ) + # logger.info( + # f" Gradient Accumulation steps = {args.gradient_accumulation_steps}" + # ) + # logger.info(f" Total optimization steps = {max_steps:,}") + # logger.info( + # f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}" + # ) + + # self.state.epoch = 0 + # start_time = time.time() + # epochs_trained = 0 + # steps_trained_in_current_epoch = 0 + # steps_trained_progress_bar = None + + # # Check if continuing training from a checkpoint + # if resume_from_checkpoint is not None and os.path.isfile( + # os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + # ): + # self.state = TrainerState.load_from_json( + # os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + # ) + # self.compare_trainer_and_checkpoint_args(self.args, self.state) + # self._load_callback_state() + # epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + # if not args.ignore_data_skip: + # steps_trained_in_current_epoch = self.state.global_step % ( + # num_update_steps_per_epoch + # ) + # steps_trained_in_current_epoch *= args.gradient_accumulation_steps + # else: + # steps_trained_in_current_epoch = 0 + + # logger.info( + # " Continuing training from checkpoint, will skip to saved global_step" + # ) + # logger.info(f" Continuing training from epoch {epochs_trained}") + # logger.info( + # f" Continuing training from global step {self.state.global_step}" + # ) + # if not args.ignore_data_skip: + # logger.info( + # f" Will skip the first {epochs_trained} epochs then the first" + # f" {steps_trained_in_current_epoch} batches in the first epoch." + # ) + + # # Update the references + # self.callback_handler.model = self.model + # self.callback_handler.optimizer = self.optimizer + # self.callback_handler.lr_scheduler = self.lr_scheduler + # self.callback_handler.train_dataloader = train_dataloader + # if self.hp_name is not None and self._trial is not None: + # # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # # parameter to Train when using DDP. + # self.state.trial_name = self.hp_name(self._trial) + # if trial is not None: + # assignments = ( + # trial.assignments + # if self.hp_search_backend == HPSearchBackend.SIGOPT + # else trial + # ) + # self.state.trial_params = hp_params(assignments) + # else: + # self.state.trial_params = None + # # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # # to set this after the load. 
+ # self.state.max_steps = max_steps + # self.state.num_train_epochs = num_train_epochs + # self.state.is_local_process_zero = self.is_local_process_zero() + # self.state.is_world_process_zero = self.is_world_process_zero() + + # # tr_loss is a tensor to avoid synchronization of TPUs through .item() + # tr_loss = torch.tensor(0.0).to(args.device) + # # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + # self._total_loss_scalar = 0.0 + # self._globalstep_last_logged = self.state.global_step + # model.zero_grad() + # grad_norm: Optional[float] = None + # self.control = self.callback_handler.on_train_begin( + # args, self.state, self.control + # ) + + # if args.eval_on_start: + # self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + + # for epoch in range(epochs_trained, num_train_epochs): + # epoch_dataloader = train_dataloader + # if hasattr(epoch_dataloader, "set_epoch"): + # epoch_dataloader.set_epoch(epoch) + + # # Reset the past mems state at the beginning of each epoch if necessary. + # if args.past_index >= 0: + # self._past = None + + # steps_in_epoch = ( + # len(epoch_dataloader) + # if len_dataloader is not None + # else args.max_steps * args.gradient_accumulation_steps + # ) + # self.control = self.callback_handler.on_epoch_begin( + # args, self.state, self.control + # ) + + # if ( + # epoch == epochs_trained + # and resume_from_checkpoint is not None + # and steps_trained_in_current_epoch == 0 + # ): + # self._load_rng_state(resume_from_checkpoint) + + # rng_to_sync = False + # steps_skipped = 0 + # if steps_trained_in_current_epoch > 0: + # epoch_dataloader = skip_first_batches( + # epoch_dataloader, steps_trained_in_current_epoch + # ) + # steps_skipped = steps_trained_in_current_epoch + # steps_trained_in_current_epoch = 0 + # rng_to_sync = True + + # step = -1 + # epoch_iterator = iter(epoch_dataloader) + # # We chunkify the epoch iterator into gradient accumulation steps `n` batches + # remainder = num_examples % args.gradient_accumulation_steps + # if remainder == 0: + # remainder = args.gradient_accumulation_steps + # update_step = -1 + # total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 + # for _ in range(total_updates): + # update_step += 1 + # num_batches = ( + # args.gradient_accumulation_steps + # if update_step != (total_updates - 1) + # else remainder + # ) + # batch_samples, num_items_in_batch = self.get_batch_samples( + # epoch_iterator, num_batches + # ) + # for i, inputs in enumerate(batch_samples): + # step += 1 + # do_sync_step = ( + # step + 1 + # ) % args.gradient_accumulation_steps == 0 or ( + # step + 1 + # ) == steps_in_epoch + # # Since we perform prefetching, we need to manually set sync_gradients + # if not do_sync_step: + # self.accelerator.gradient_state._set_sync_gradients(False) + # else: + # self.accelerator.gradient_state._set_sync_gradients(True) + + # if self.args.include_num_input_tokens_seen: + # main_input_name = getattr( + # self.model, "main_input_name", "input_ids" + # ) + # if main_input_name not in inputs: + # logger.warning( + # "Tried to track the number of tokens seen, however the current model is " + # "not configured properly to know what item is the input. To fix this, add " + # "a `main_input_name` attribute to the model class you are using." 
+ # ) + # else: + # input_tokens = inputs[main_input_name].numel() + # input_tokens = torch.tensor( + # input_tokens, device=self.args.device, dtype=torch.int64 + # ) + # self.state.num_input_tokens_seen += ( + # self.accelerator.gather(input_tokens).sum().cpu().item() + # ) + # if rng_to_sync: + # self._load_rng_state(resume_from_checkpoint) + # rng_to_sync = False + + # # Skip past any already trained steps if resuming training + # if steps_trained_in_current_epoch > 0: + # steps_trained_in_current_epoch -= 1 + # if steps_trained_progress_bar is not None: + # steps_trained_progress_bar.update(1) + # if steps_trained_in_current_epoch == 0: + # self._load_rng_state(resume_from_checkpoint) + # continue + # elif steps_trained_progress_bar is not None: + # steps_trained_progress_bar.close() + # steps_trained_progress_bar = None + + # if step % args.gradient_accumulation_steps == 0: + # self.control = self.callback_handler.on_step_begin( + # args, self.state, self.control + # ) + + # # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + # context = ( + # functools.partial(self.accelerator.no_sync, model=model) + # if i != len(batch_samples) - 1 + # and self.accelerator.distributed_type + # != DistributedType.DEEPSPEED + # else contextlib.nullcontext + # ) + # with context(): + # tr_loss_step = self.training_step( + # model, inputs, num_items_in_batch + # ) + + # if ( + # args.logging_nan_inf_filter + # and not is_torch_xla_available() + # and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + # ): + # # if loss is nan or inf simply add the average of previous logged losses + # tr_loss = tr_loss + tr_loss / ( + # 1 + self.state.global_step - self._globalstep_last_logged + # ) + # else: + # if tr_loss.device != tr_loss_step.device: + # raise ValueError( + # f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + # ) + # tr_loss = tr_loss + tr_loss_step + + # self.current_flos += float(self.floating_point_ops(inputs)) + + # if do_sync_step: + # # Since we perform prefetching, we need to manually set sync_gradients to True + # self.accelerator.gradient_state._set_sync_gradients(True) + + # # Gradient clipping + # if args.max_grad_norm is not None and args.max_grad_norm > 0: + # # deepspeed does its own clipping + + # if is_sagemaker_mp_enabled() and args.fp16: + # _grad_norm = self.optimizer.clip_master_grads( + # args.max_grad_norm + # ) + # elif self.use_apex: + # # Revert to normal clipping otherwise, handling Apex or full precision + # _grad_norm = nn.utils.clip_grad_norm_( + # amp.master_params(self.optimizer), + # args.max_grad_norm, + # ) + # else: + # _grad_norm = self.accelerator.clip_grad_norm_( + # model.parameters(), + # args.max_grad_norm, + # ) + + # if ( + # is_accelerate_available() + # and self.accelerator.distributed_type + # == DistributedType.DEEPSPEED + # ): + # grad_norm = model.get_global_grad_norm() + # # In some cases the grad norm may not return a float + # if hasattr(grad_norm, "item"): + # grad_norm = grad_norm.item() + # else: + # grad_norm = _grad_norm + + # self.control = self.callback_handler.on_pre_optimizer_step( + # args, self.state, self.control + # ) + + # self.optimizer.step() + + # self.control = self.callback_handler.on_optimizer_step( + # args, self.state, self.control + # ) + + # optimizer_was_run = ( + # not self.accelerator.optimizer_step_was_skipped + # ) + # if optimizer_was_run: + # # Delay optimizer scheduling until metrics are generated + # if not 
isinstance( + # self.lr_scheduler, + # torch.optim.lr_scheduler.ReduceLROnPlateau, + # ): + # self.lr_scheduler.step() + + # model.zero_grad() + # self.state.global_step += 1 + # self.state.epoch = ( + # epoch + (step + 1 + steps_skipped) / steps_in_epoch + # ) + # self.control = self.callback_handler.on_step_end( + # args, self.state, self.control + # ) + # self._maybe_log_save_evaluate( + # tr_loss, + # grad_norm, + # model, + # trial, + # epoch, + # ignore_keys_for_eval, + # start_time, + # ) + # else: + # self.control = self.callback_handler.on_substep_end( + # args, self.state, self.control + # ) + + # # PyTorch/XLA relies on the data loader to insert the mark_step for + # # each step. Since we are breaking the loop early, we need to manually + # # insert the mark_step here. + # if ( + # self.control.should_epoch_stop + # or self.control.should_training_stop + # ): + # if is_torch_xla_available(): + # xm.mark_step() + # break + # # We also need to break out of the nested loop + # if self.control.should_epoch_stop or self.control.should_training_stop: + # if is_torch_xla_available(): + # xm.mark_step() + # break + # if step < 0: + # logger.warning( + # "There seems not to be a single sample in your epoch_iterator, stopping training at step" + # f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + # f" num_steps ({max_steps}) higher than the number of available samples." + # ) + # self.control.should_training_stop = True + + # self.control = self.callback_handler.on_epoch_end( + # args, self.state, self.control + # ) + # self._maybe_log_save_evaluate( + # tr_loss, + # grad_norm, + # model, + # trial, + # epoch, + # ignore_keys_for_eval, + # start_time, + # ) + + # if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # if is_torch_xla_available(): + # # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + # xm.master_print(met.metrics_report()) + # else: + # logger.warning( + # "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + # "configured. Check your training configuration if this is unexpected." + # ) + # if self.control.should_training_stop: + # break + + # if args.past_index and hasattr(self, "_past"): + # # Clean the state at the end of training + # delattr(self, "_past") + + # logger.info( + # "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" + # ) + # if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # # Wait for everyone to get here so we are sure the model has been saved by process 0. 
+ # if is_torch_xla_available(): + # xm.rendezvous("load_best_model_at_end") + # elif args.parallel_mode == ParallelMode.DISTRIBUTED: + # dist.barrier() + # elif is_sagemaker_mp_enabled(): + # smp.barrier() + + # self._load_best_model() + + # # add remaining tr_loss + # self._total_loss_scalar += tr_loss.item() + # effective_global_step = max( + # self.state.global_step, 0.001 + # ) # Avoid ZeroDivisionError + # train_loss = self._total_loss_scalar / effective_global_step + + # metrics = speed_metrics( + # "train", + # start_time, + # num_samples=num_train_samples, + # num_steps=self.state.max_steps, + # num_tokens=num_train_tokens, + # ) + # self.store_flos() + # metrics["total_flos"] = self.state.total_flos + # metrics["train_loss"] = train_loss + + # self.is_in_train = False + + # self._memory_tracker.stop_and_update_metrics(metrics) + + # self.log(metrics) + + # run_dir = self._get_output_dir(trial) + # checkpoints_sorted = self._sorted_checkpoints( + # use_mtime=False, output_dir=run_dir + # ) + + # # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + # if ( + # self.args.should_save + # and self.state.best_model_checkpoint is not None + # and self.args.save_total_limit == 1 + # ): + # for checkpoint in checkpoints_sorted: + # if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + # logger.info( + # f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit" + # ) + # shutil.rmtree(checkpoint, ignore_errors=True) + + # self.control = self.callback_handler.on_train_end( + # args, self.state, self.control + # ) + + # # Wait for the checkpoint to be uploaded. + # self._finish_current_push() + + # # After training we make sure to retrieve back the original forward pass method + # # for the embedding layer by removing the forward post hook. + # if self.neftune_noise_alpha is not None: + # self._deactivate_neftune(self.model) + + # return TrainOutput(self.state.global_step, train_loss, metrics)
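
Reviewer note (not part of the patch): `ContinualTrainer.__init__` currently only stores `ewc_lambda` and leaves the LWF branch as `pass`, so the `EWC`/`LWF` regularizers introduced above are never applied to the training loss. The following is a minimal sketch of how they could be wired in, assuming the trainer keeps instances of the two regularizers, that the model returns an output with `loss` and `logits`, and that each batch carries a `question_ids` field for LWF (a hypothetical key, mirroring the `kwargs['question_ids']` lookup in `LWF.adapt`). It is an illustration under those assumptions, not the author's implementation.

# Illustrative sketch only (reviewer suggestion, not part of the patch).
# Assumes the regularizers above are instantiated by the trainer and that
# batches carry the extra "question_ids" key consumed by LWF.adapt.
from transformers import Trainer

from peft_library.regularizations.ewc import EWC
from peft_library.regularizations.lwf import LWF


class RegularizedTrainer(Trainer):
    def __init__(self, *args, regularization_args=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.ewc = None
        self.lwf = None
        if regularization_args is not None and regularization_args.ewc_enable:
            self.ewc = EWC(EWC_lambda=regularization_args.ewc_lambda)
        if regularization_args is not None and regularization_args.lwf_enable:
            self.lwf = LWF(LWF_lambda=regularization_args.lwf_lambda)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # "question_ids" is a hypothetical batch field; strip it before the
        # forward pass and hand it to LWF separately.
        question_ids = inputs.pop("question_ids", None)
        outputs = model(**inputs)
        output = {"loss": outputs.loss, "logits": outputs.logits}
        # EWC: quadratic penalty towards the previous task's parameters,
        # applied only once the Fisher information has been populated.
        if self.ewc is not None and self.ewc.fisher:
            output = self.ewc.adapt(output, model=model)
        # LWF: distillation against logits stored for previously seen samples.
        if self.lwf is not None and question_ids is not None:
            output = self.lwf.adapt(output, question_ids=question_ids)
        loss = output["loss"]
        return (loss, outputs) if return_outputs else loss

The sketch uses a separate class name only to avoid implying that this is what the commit does; in the patch itself the equivalent hooks would live in `ContinualTrainer`.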