feat: update training script to support new lora_target_modules and output directory; optimize training configuration

YunyaoZhou 2025-05-30 19:14:34 +08:00
parent baccca420a
commit b84ebb03c7
2 changed files with 79 additions and 717 deletions

View File

@@ -5,14 +5,14 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml
--use_peft \
--peft_type MOELORA \
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
- --lora_target_modules .*model\.layers.*proj \
+ --lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
--lora_r 8 \
--lora_alpha 32 \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--num_train_epochs 1 \
- --output_dir checkpoint/qwen2_alllinear/ \
+ --output_dir checkpoint/qwen2_5omni_moelora/ \
--learning_rate 2e-4 \
--warmup_ratio 0.03 \
--lr_scheduler_type cosine \
@@ -21,4 +21,4 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml
--logging_steps 10 \
--gradient_checkpointing \
--weight_decay 0.1 \
- --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_alllinear/checkpoint-1000
+ # --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_alllinear/checkpoint-1000
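
The substantive change in the launch script is the wider --lora_target_modules pattern: besides the decoder projection layers matched by .*model\.layers.*proj, adapters are now also injected into modules whose names contain "merger" and end in 0 or 1. Below is a minimal sketch of how such a single-string pattern is typically resolved (PEFT matches a string-valued target_modules with re.fullmatch against each module name); the module names are hypothetical, not taken from Qwen2.5-Omni.

import re

# Hypothetical module names used only to illustrate the regex;
# real names would come from model.named_modules().
pattern = r".*model\.layers.*proj|.*merger.*0|.*merger.*1"
candidates = [
    "model.layers.0.self_attn.q_proj",   # decoder projection -> matched
    "model.layers.7.mlp.down_proj",      # decoder projection -> matched
    "visual.merger.mlp.0",               # merger linear      -> matched
    "visual.merger.mlp.1",               # merger linear      -> matched
    "model.layers.3.input_layernorm",    # not a target       -> skipped
]
for name in candidates:
    hit = re.fullmatch(pattern, name) is not None
    print(f"{name}: {'adapted' if hit else 'skipped'}")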

View File

@@ -80,717 +80,79 @@ class ContinualTrainer(Trainer):
else:
super().create_accelerator_and_postprocess()
# def compute_loss(
# self, model, inputs, return_outputs=False, num_items_in_batch=None
# ):
# """
# How the loss is computed by Trainer. By default, all models return the loss in the first element.
    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)
            optimizer_grouped_parameters = [
                # Decay parameters outside the multimodal "merger" modules: base learning rate.
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" not in n)
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate,
                },
                # "merger" parameters: trained at one tenth of the base learning rate.
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" in n)
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate / 10,
                },
                # Non-decay parameters (biases, norms): no weight decay, base learning rate.
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                    "lr": self.args.learning_rate,
                },
            ]

            if self.optimizer_cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid arguments conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped / 2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped / 2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
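
The override above follows the stock HF Trainer.create_optimizer closely; the only customization is that decay parameters are split into two groups so anything whose name contains "merger" trains at one tenth of the base learning rate. Here is a self-contained sketch of that grouping on a toy model; the module names and the ndim-based decay heuristic are stand-ins for illustration, not the Trainer's get_decay_parameter_names.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Toy stand-in: one "merger" module plus an ordinary projection and a norm.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.merger = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

model = ToyModel()
base_lr, wd = 2e-4, 0.1
# Rough decay heuristic for this sketch: matrices decay, vectors (biases, norms) do not.
decay = {n for n, p in model.named_parameters() if p.ndim > 1}

groups = [
    {"params": [p for n, p in model.named_parameters() if n in decay and "merger" not in n],
     "weight_decay": wd, "lr": base_lr},
    {"params": [p for n, p in model.named_parameters() if n in decay and "merger" in n],
     "weight_decay": wd, "lr": base_lr / 10},  # merger weights: 10x smaller LR
    {"params": [p for n, p in model.named_parameters() if n not in decay],
     "weight_decay": 0.0, "lr": base_lr},
]
optimizer = torch.optim.AdamW(groups)
for g in optimizer.param_groups:
    print(len(g["params"]), g["lr"], g["weight_decay"])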
# Subclass and override for custom behavior.
# """
# if (
# self.label_smoother is not None or self.compute_loss_func is not None
# ) and "labels" in inputs:
# labels = inputs.pop("labels")
# else:
# labels = None
# if self.model_accepts_loss_kwargs:
# loss_kwargs = {}
# if num_items_in_batch is not None:
# loss_kwargs["num_items_in_batch"] = num_items_in_batch
# inputs = {**inputs, **loss_kwargs}
# outputs = model(**inputs)
# # Save past state if it exists
# # TODO: this needs to be fixed and made cleaner later.
# if self.args.past_index >= 0:
# self._past = outputs[self.args.past_index]
# if labels is not None:
# unwrapped_model = self.accelerator.unwrap_model(model)
# if _is_peft_model(unwrapped_model):
# model_name = unwrapped_model.base_model.model._get_name()
# else:
# model_name = unwrapped_model._get_name()
# # User-defined compute_loss function
# if self.compute_loss_func is not None:
# loss = self.compute_loss_func(
# outputs, labels, num_items_in_batch=num_items_in_batch
# )
# elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
# loss = self.label_smoother(outputs, labels, shift_labels=True)
# else:
# loss = self.label_smoother(outputs, labels)
# else:
# if isinstance(outputs, dict) and "loss" not in outputs:
# raise ValueError(
# "The model did not return a loss from the inputs, only the following keys: "
# f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
# )
# # We don't use .loss here since the model may return tuples instead of ModelOutput.
# loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
# return (loss, outputs) if return_outputs else loss
# def _inner_training_loop(
# self,
# batch_size=None,
# args=None,
# resume_from_checkpoint=None,
# trial=None,
# ignore_keys_for_eval=None,
# ):
# self.accelerator.free_memory()
# self._train_batch_size = batch_size
# if self.args.auto_find_batch_size:
# if self.state.train_batch_size != self._train_batch_size:
# from accelerate.utils import release_memory
# (self.model_wrapped,) = release_memory(self.model_wrapped)
# self.model_wrapped = self.model
# # Check for DeepSpeed *after* the initial pass and modify the config
# if self.is_deepspeed_enabled:
# # Temporarily unset `self.args.train_batch_size`
# original_bs = self.args.per_device_train_batch_size
# self.args.per_device_train_batch_size = (
# self._train_batch_size // max(1, self.args.n_gpu)
# )
# self.propagate_args_to_deepspeed(True)
# self.args.per_device_train_batch_size = original_bs
# self.state.train_batch_size = self._train_batch_size
# logger.debug(
# f"Currently training with a batch size of: {self._train_batch_size}"
# )
# # Data loader and number of training steps
# train_dataloader = self.get_train_dataloader()
# if self.is_fsdp_xla_v2_enabled:
# train_dataloader = tpu_spmd_dataloader(train_dataloader)
# # Setting up training control variables:
# # number of training epochs: num_train_epochs
# # number of training steps per epoch: num_update_steps_per_epoch
# # total number of training steps to execute: max_steps
# total_train_batch_size = (
# self._train_batch_size * args.gradient_accumulation_steps * args.world_size
# )
# len_dataloader = None
# num_train_tokens = None
# if has_length(train_dataloader):
# len_dataloader = len(train_dataloader)
# num_update_steps_per_epoch = (
# len_dataloader // args.gradient_accumulation_steps
# )
# num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
# num_examples = self.num_examples(train_dataloader)
# if args.max_steps > 0:
# max_steps = args.max_steps
# num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
# args.max_steps % num_update_steps_per_epoch > 0
# )
# # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# # the best we can do.
# num_train_samples = args.max_steps * total_train_batch_size
# if args.include_tokens_per_second:
# num_train_tokens = (
# self.num_tokens(train_dataloader, args.max_steps)
# * args.gradient_accumulation_steps
# )
# else:
# max_steps = math.ceil(
# args.num_train_epochs * num_update_steps_per_epoch
# )
# num_train_epochs = math.ceil(args.num_train_epochs)
# num_train_samples = (
# self.num_examples(train_dataloader) * args.num_train_epochs
# )
# if args.include_tokens_per_second:
# num_train_tokens = (
# self.num_tokens(train_dataloader) * args.num_train_epochs
# )
# elif (
# args.max_steps > 0
# ): # Rely on max_steps when dataloader does not have a working size
# max_steps = args.max_steps
# # Setting a very large number of epochs so we go as many times as necessary over the iterator.
# num_train_epochs = sys.maxsize
# num_update_steps_per_epoch = max_steps
# num_examples = total_train_batch_size * args.max_steps
# num_train_samples = args.max_steps * total_train_batch_size
# if args.include_tokens_per_second:
# num_train_tokens = (
# self.num_tokens(train_dataloader, args.max_steps)
# * args.gradient_accumulation_steps
# )
# else:
# raise ValueError(
# "args.max_steps must be set to a positive value if dataloader does not have a length, was"
# f" {args.max_steps}"
# )
# if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
# if self.args.n_gpu > 1:
# # nn.DataParallel(model) replicates the model, creating new variables and module
# # references registered here no longer work on other gpus, breaking the module
# raise ValueError(
# "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
# " (torchrun or torch.distributed.launch (deprecated))."
# )
# else:
# debug_overflow = DebugUnderflowOverflow(self.model) # noqa
# delay_optimizer_creation = (
# is_sagemaker_mp_enabled()
# or self.is_fsdp_xla_enabled
# or self.is_fsdp_enabled
# )
# # We need to reset the scheduler, as its parameters may be different on subsequent calls
# if self._created_lr_scheduler:
# self.lr_scheduler = None
# self._created_lr_scheduler = False
# if self.is_deepspeed_enabled:
# self.optimizer, self.lr_scheduler = deepspeed_init(
# self, num_training_steps=max_steps
# )
# if not delay_optimizer_creation:
# self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# self.state = TrainerState(
# stateful_callbacks=[
# cb
# for cb in self.callback_handler.callbacks + [self.control]
# if isinstance(cb, ExportableState)
# ]
# )
# self.state.is_hyper_param_search = trial is not None
# self.state.train_batch_size = self._train_batch_size
# # Compute absolute values for logging, eval, and save if given as ratio
# if args.logging_steps is not None:
# if args.logging_steps < 1:
# self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
# else:
# self.state.logging_steps = args.logging_steps
# if args.eval_steps is not None:
# if args.eval_steps < 1:
# self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
# else:
# self.state.eval_steps = args.eval_steps
# if args.save_steps is not None:
# if args.save_steps < 1:
# self.state.save_steps = math.ceil(max_steps * args.save_steps)
# else:
# self.state.save_steps = args.save_steps
# # Activate gradient checkpointing if needed
# if args.gradient_checkpointing:
# self.model.gradient_checkpointing_enable(
# gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
# )
# model = self._wrap_model(self.model_wrapped)
# # as the model is wrapped, don't use `accelerator.prepare`
# # this is for unhandled cases such as
# # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
# use_accelerator_prepare = True if model is self.model else False
# if use_accelerator_prepare and self.is_fsdp_enabled:
# # In case of auto_find_batch_size=True
# # Remove FSDP wrapping from sub-models.
# self.model = unwrap_model(self.model, recursive=True)
# if delay_optimizer_creation:
# if use_accelerator_prepare:
# # configure fsdp plugin for qlora if any
# self._fsdp_qlora_plugin_updates()
# if self.accelerator.mixed_precision != "fp8":
# self.model = self.accelerator.prepare(self.model)
# self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# # prepare using `accelerator` prepare
# if use_accelerator_prepare:
# self.model.train()
# if hasattr(self.lr_scheduler, "step"):
# if self.use_apex:
# model = self.accelerator.prepare(self.model)
# else:
# model, self.optimizer = self.accelerator.prepare(
# self.model, self.optimizer
# )
# else:
# # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
# model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
# self.model, self.optimizer, self.lr_scheduler
# )
# elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
# # In this case we are in DDP + LOMO, which should be supported
# self.optimizer = self.accelerator.prepare(self.optimizer)
# if self.is_fsdp_enabled:
# self.model = self.model_wrapped = model
# # for the rest of this function `model` is the outside model, whether it was wrapped or not
# if model is not self.model:
# self.model_wrapped = model
# # backward compatibility
# if self.is_deepspeed_enabled:
# self.deepspeed = self.model_wrapped
# # ckpt loading
# if resume_from_checkpoint is not None:
# if self.is_deepspeed_enabled:
# deepspeed_load_checkpoint(
# self.model_wrapped,
# resume_from_checkpoint,
# load_module_strict=not _is_peft_model(self.model),
# )
# elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
# self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
# # Check if saved optimizer or scheduler states exist
# self._load_optimizer_and_scheduler(resume_from_checkpoint)
# # important: at this point:
# # self.model is the Transformers Model
# # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
# # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
# # Train!
# logger.info("***** Running training *****")
# logger.info(f" Num examples = {num_examples:,}")
# logger.info(f" Num Epochs = {num_train_epochs:,}")
# logger.info(
# f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}"
# )
# if self.args.per_device_train_batch_size != self._train_batch_size:
# logger.info(
# f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}"
# )
# logger.info(
# f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}"
# )
# logger.info(
# f" Gradient Accumulation steps = {args.gradient_accumulation_steps}"
# )
# logger.info(f" Total optimization steps = {max_steps:,}")
# logger.info(
# f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}"
# )
# self.state.epoch = 0
# start_time = time.time()
# epochs_trained = 0
# steps_trained_in_current_epoch = 0
# steps_trained_progress_bar = None
# # Check if continuing training from a checkpoint
# if resume_from_checkpoint is not None and os.path.isfile(
# os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
# ):
# self.state = TrainerState.load_from_json(
# os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
# )
# self.compare_trainer_and_checkpoint_args(self.args, self.state)
# self._load_callback_state()
# epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
# if not args.ignore_data_skip:
# steps_trained_in_current_epoch = self.state.global_step % (
# num_update_steps_per_epoch
# )
# steps_trained_in_current_epoch *= args.gradient_accumulation_steps
# else:
# steps_trained_in_current_epoch = 0
# logger.info(
# " Continuing training from checkpoint, will skip to saved global_step"
# )
# logger.info(f" Continuing training from epoch {epochs_trained}")
# logger.info(
# f" Continuing training from global step {self.state.global_step}"
# )
# if not args.ignore_data_skip:
# logger.info(
# f" Will skip the first {epochs_trained} epochs then the first"
# f" {steps_trained_in_current_epoch} batches in the first epoch."
# )
# # Update the references
# self.callback_handler.model = self.model
# self.callback_handler.optimizer = self.optimizer
# self.callback_handler.lr_scheduler = self.lr_scheduler
# self.callback_handler.train_dataloader = train_dataloader
# if self.hp_name is not None and self._trial is not None:
# # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
# # parameter to Train when using DDP.
# self.state.trial_name = self.hp_name(self._trial)
# if trial is not None:
# assignments = (
# trial.assignments
# if self.hp_search_backend == HPSearchBackend.SIGOPT
# else trial
# )
# self.state.trial_params = hp_params(assignments)
# else:
# self.state.trial_params = None
# # This should be the same if the state has been saved but in case the training arguments changed, it's safer
# # to set this after the load.
# self.state.max_steps = max_steps
# self.state.num_train_epochs = num_train_epochs
# self.state.is_local_process_zero = self.is_local_process_zero()
# self.state.is_world_process_zero = self.is_world_process_zero()
# # tr_loss is a tensor to avoid synchronization of TPUs through .item()
# tr_loss = torch.tensor(0.0).to(args.device)
# # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
# self._total_loss_scalar = 0.0
# self._globalstep_last_logged = self.state.global_step
# model.zero_grad()
# grad_norm: Optional[float] = None
# self.control = self.callback_handler.on_train_begin(
# args, self.state, self.control
# )
# if args.eval_on_start:
# self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
# for epoch in range(epochs_trained, num_train_epochs):
# epoch_dataloader = train_dataloader
# if hasattr(epoch_dataloader, "set_epoch"):
# epoch_dataloader.set_epoch(epoch)
# # Reset the past mems state at the beginning of each epoch if necessary.
# if args.past_index >= 0:
# self._past = None
# steps_in_epoch = (
# len(epoch_dataloader)
# if len_dataloader is not None
# else args.max_steps * args.gradient_accumulation_steps
# )
# self.control = self.callback_handler.on_epoch_begin(
# args, self.state, self.control
# )
# if (
# epoch == epochs_trained
# and resume_from_checkpoint is not None
# and steps_trained_in_current_epoch == 0
# ):
# self._load_rng_state(resume_from_checkpoint)
# rng_to_sync = False
# steps_skipped = 0
# if steps_trained_in_current_epoch > 0:
# epoch_dataloader = skip_first_batches(
# epoch_dataloader, steps_trained_in_current_epoch
# )
# steps_skipped = steps_trained_in_current_epoch
# steps_trained_in_current_epoch = 0
# rng_to_sync = True
# step = -1
# epoch_iterator = iter(epoch_dataloader)
# # We chunkify the epoch iterator into gradient accumulation steps `n` batches
# remainder = num_examples % args.gradient_accumulation_steps
# if remainder == 0:
# remainder = args.gradient_accumulation_steps
# update_step = -1
# total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
# for _ in range(total_updates):
# update_step += 1
# num_batches = (
# args.gradient_accumulation_steps
# if update_step != (total_updates - 1)
# else remainder
# )
# batch_samples, num_items_in_batch = self.get_batch_samples(
# epoch_iterator, num_batches
# )
# for i, inputs in enumerate(batch_samples):
# step += 1
# do_sync_step = (
# step + 1
# ) % args.gradient_accumulation_steps == 0 or (
# step + 1
# ) == steps_in_epoch
# # Since we perform prefetching, we need to manually set sync_gradients
# if not do_sync_step:
# self.accelerator.gradient_state._set_sync_gradients(False)
# else:
# self.accelerator.gradient_state._set_sync_gradients(True)
# if self.args.include_num_input_tokens_seen:
# main_input_name = getattr(
# self.model, "main_input_name", "input_ids"
# )
# if main_input_name not in inputs:
# logger.warning(
# "Tried to track the number of tokens seen, however the current model is "
# "not configured properly to know what item is the input. To fix this, add "
# "a `main_input_name` attribute to the model class you are using."
# )
# else:
# input_tokens = inputs[main_input_name].numel()
# input_tokens = torch.tensor(
# input_tokens, device=self.args.device, dtype=torch.int64
# )
# self.state.num_input_tokens_seen += (
# self.accelerator.gather(input_tokens).sum().cpu().item()
# )
# if rng_to_sync:
# self._load_rng_state(resume_from_checkpoint)
# rng_to_sync = False
# # Skip past any already trained steps if resuming training
# if steps_trained_in_current_epoch > 0:
# steps_trained_in_current_epoch -= 1
# if steps_trained_progress_bar is not None:
# steps_trained_progress_bar.update(1)
# if steps_trained_in_current_epoch == 0:
# self._load_rng_state(resume_from_checkpoint)
# continue
# elif steps_trained_progress_bar is not None:
# steps_trained_progress_bar.close()
# steps_trained_progress_bar = None
# if step % args.gradient_accumulation_steps == 0:
# self.control = self.callback_handler.on_step_begin(
# args, self.state, self.control
# )
# # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
# context = (
# functools.partial(self.accelerator.no_sync, model=model)
# if i != len(batch_samples) - 1
# and self.accelerator.distributed_type
# != DistributedType.DEEPSPEED
# else contextlib.nullcontext
# )
# with context():
# tr_loss_step = self.training_step(
# model, inputs, num_items_in_batch
# )
# if (
# args.logging_nan_inf_filter
# and not is_torch_xla_available()
# and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
# ):
# # if loss is nan or inf simply add the average of previous logged losses
# tr_loss = tr_loss + tr_loss / (
# 1 + self.state.global_step - self._globalstep_last_logged
# )
# else:
# if tr_loss.device != tr_loss_step.device:
# raise ValueError(
# f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
# )
# tr_loss = tr_loss + tr_loss_step
# self.current_flos += float(self.floating_point_ops(inputs))
# if do_sync_step:
# # Since we perform prefetching, we need to manually set sync_gradients to True
# self.accelerator.gradient_state._set_sync_gradients(True)
# # Gradient clipping
# if args.max_grad_norm is not None and args.max_grad_norm > 0:
# # deepspeed does its own clipping
# if is_sagemaker_mp_enabled() and args.fp16:
# _grad_norm = self.optimizer.clip_master_grads(
# args.max_grad_norm
# )
# elif self.use_apex:
# # Revert to normal clipping otherwise, handling Apex or full precision
# _grad_norm = nn.utils.clip_grad_norm_(
# amp.master_params(self.optimizer),
# args.max_grad_norm,
# )
# else:
# _grad_norm = self.accelerator.clip_grad_norm_(
# model.parameters(),
# args.max_grad_norm,
# )
# if (
# is_accelerate_available()
# and self.accelerator.distributed_type
# == DistributedType.DEEPSPEED
# ):
# grad_norm = model.get_global_grad_norm()
# # In some cases the grad norm may not return a float
# if hasattr(grad_norm, "item"):
# grad_norm = grad_norm.item()
# else:
# grad_norm = _grad_norm
# self.control = self.callback_handler.on_pre_optimizer_step(
# args, self.state, self.control
# )
# self.optimizer.step()
# self.control = self.callback_handler.on_optimizer_step(
# args, self.state, self.control
# )
# optimizer_was_run = (
# not self.accelerator.optimizer_step_was_skipped
# )
# if optimizer_was_run:
# # Delay optimizer scheduling until metrics are generated
# if not isinstance(
# self.lr_scheduler,
# torch.optim.lr_scheduler.ReduceLROnPlateau,
# ):
# self.lr_scheduler.step()
# model.zero_grad()
# self.state.global_step += 1
# self.state.epoch = (
# epoch + (step + 1 + steps_skipped) / steps_in_epoch
# )
# self.control = self.callback_handler.on_step_end(
# args, self.state, self.control
# )
# self._maybe_log_save_evaluate(
# tr_loss,
# grad_norm,
# model,
# trial,
# epoch,
# ignore_keys_for_eval,
# start_time,
# )
# else:
# self.control = self.callback_handler.on_substep_end(
# args, self.state, self.control
# )
# # PyTorch/XLA relies on the data loader to insert the mark_step for
# # each step. Since we are breaking the loop early, we need to manually
# # insert the mark_step here.
# if (
# self.control.should_epoch_stop
# or self.control.should_training_stop
# ):
# if is_torch_xla_available():
# xm.mark_step()
# break
# # We also need to break out of the nested loop
# if self.control.should_epoch_stop or self.control.should_training_stop:
# if is_torch_xla_available():
# xm.mark_step()
# break
# if step < 0:
# logger.warning(
# "There seems not to be a single sample in your epoch_iterator, stopping training at step"
# f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
# f" num_steps ({max_steps}) higher than the number of available samples."
# )
# self.control.should_training_stop = True
# self.control = self.callback_handler.on_epoch_end(
# args, self.state, self.control
# )
# self._maybe_log_save_evaluate(
# tr_loss,
# grad_norm,
# model,
# trial,
# epoch,
# ignore_keys_for_eval,
# start_time,
# )
# if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
# if is_torch_xla_available():
# # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
# xm.master_print(met.metrics_report())
# else:
# logger.warning(
# "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
# "configured. Check your training configuration if this is unexpected."
# )
# if self.control.should_training_stop:
# break
# if args.past_index and hasattr(self, "_past"):
# # Clean the state at the end of training
# delattr(self, "_past")
# logger.info(
# "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
# )
# if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# # Wait for everyone to get here so we are sure the model has been saved by process 0.
# if is_torch_xla_available():
# xm.rendezvous("load_best_model_at_end")
# elif args.parallel_mode == ParallelMode.DISTRIBUTED:
# dist.barrier()
# elif is_sagemaker_mp_enabled():
# smp.barrier()
# self._load_best_model()
# # add remaining tr_loss
# self._total_loss_scalar += tr_loss.item()
# effective_global_step = max(
# self.state.global_step, 0.001
# ) # Avoid ZeroDivisionError
# train_loss = self._total_loss_scalar / effective_global_step
# metrics = speed_metrics(
# "train",
# start_time,
# num_samples=num_train_samples,
# num_steps=self.state.max_steps,
# num_tokens=num_train_tokens,
# )
# self.store_flos()
# metrics["total_flos"] = self.state.total_flos
# metrics["train_loss"] = train_loss
# self.is_in_train = False
# self._memory_tracker.stop_and_update_metrics(metrics)
# self.log(metrics)
# run_dir = self._get_output_dir(trial)
# checkpoints_sorted = self._sorted_checkpoints(
# use_mtime=False, output_dir=run_dir
# )
# # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
# if (
# self.args.should_save
# and self.state.best_model_checkpoint is not None
# and self.args.save_total_limit == 1
# ):
# for checkpoint in checkpoints_sorted:
# if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
# logger.info(
# f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit"
# )
# shutil.rmtree(checkpoint, ignore_errors=True)
# self.control = self.callback_handler.on_train_end(
# args, self.state, self.control
# )
# # Wait for the checkpoint to be uploaded.
# self._finish_current_push()
# # After training we make sure to retrieve back the original forward pass method
# # for the embedding layer by removing the forward post hook.
# if self.neftune_noise_alpha is not None:
# self._deactivate_neftune(self.model)
# return TrainOutput(self.state.global_step, train_loss, metrics)