from accelerate import Accelerator, DataLoaderConfiguration
from transformers import TrainingArguments


def create_accelerator_and_postprocess(args: TrainingArguments):
    """Create the `Accelerator` from `TrainingArguments`, then apply FSDP/DeepSpeed post-processing."""
    # We explicitly don't rely on the `Accelerator` to do gradient accumulation
    grad_acc_kwargs = {}
    if args.accelerator_config.gradient_accumulation_kwargs is not None:
        grad_acc_kwargs = args.accelerator_config.gradient_accumulation_kwargs

    # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
    if "num_steps" in grad_acc_kwargs:
        if args.gradient_accumulation_steps > 1:
            # raise because we do not know which setting is intended.
            raise ValueError(
                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
            )
        else:
            args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

    accelerator_config = args.accelerator_config.to_dict()

    # build the dataloader configuration from the remaining `AcceleratorConfig` fields
    dataloader_config = DataLoaderConfiguration(
        split_batches=accelerator_config.pop("split_batches"),
        dispatch_batches=accelerator_config.pop("dispatch_batches"),
        even_batches=accelerator_config.pop("even_batches"),
        use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
    )
    dataloader_config.data_seed = args.data_seed

    non_blocking = accelerator_config.pop("non_blocking")
    dataloader_config.non_blocking = non_blocking
    # this would have been updated above, no need for it anymore
    accelerator_config.pop("gradient_accumulation_kwargs")

    accelerator_args = {"deepspeed_plugin": args.deepspeed_plugin, "log_with": "wandb"}
    accelerator_args["dataloader_config"] = dataloader_config

    # create accelerator object
    accelerator = Accelerator(**accelerator_args)

    # deepspeed and accelerate flags covering both trainer args and accelerate launcher
    is_deepspeed_enabled = getattr(accelerator.state, "deepspeed_plugin", None) is not None
    is_fsdp_enabled = getattr(accelerator.state, "fsdp_plugin", None) is not None

    # post accelerator creation setup
    if is_fsdp_enabled:
        fsdp_plugin = accelerator.state.fsdp_plugin
        fsdp_plugin.limit_all_gathers = args.fsdp_config.get(
            "limit_all_gathers", fsdp_plugin.limit_all_gathers
        )
        fsdp_plugin.activation_checkpointing = args.fsdp_config.get(
            "activation_checkpointing", fsdp_plugin.activation_checkpointing
        )
        if fsdp_plugin.activation_checkpointing and args.gradient_checkpointing:
            raise ValueError(
                "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                "when using FSDP."
            )

    if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        ds_plugin = accelerator.state.deepspeed_plugin
        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
        ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size=False)

    return accelerator
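

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the helper above). It
# assumes a plain single-process run with no DeepSpeed or FSDP configured;
# `output_dir` is a placeholder, and the hard-coded `log_with="wandb"` above
# means wandb is the tracker used if it is installed. An `accelerator_config`
# (e.g. a dict such as {"split_batches": True}) could also be passed to
# `TrainingArguments` to exercise the dataloader options.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    training_args = TrainingArguments(
        output_dir="./out",  # placeholder output directory
        gradient_accumulation_steps=2,
    )
    accelerator = create_accelerator_and_postprocess(training_args)
    # gradient accumulation stays on the Trainer side; the Accelerator only
    # carries the dataloader and (optional) DeepSpeed configuration
    print(accelerator.state)
    print("gradient_accumulation_steps:", training_args.gradient_accumulation_steps)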