from accelerate import Accelerator, DataLoaderConfiguration
from transformers import (
    TrainingArguments,
)


def create_accelerator_and_postprocess(args: TrainingArguments):
    """Create an `Accelerator` from a `TrainingArguments` instance and apply
    Trainer-style post-processing for the DeepSpeed and FSDP plugins."""
    # We explicitly don't rely on the `Accelerator` to do gradient accumulation
    grad_acc_kwargs = {}
    if args.accelerator_config.gradient_accumulation_kwargs is not None:
        grad_acc_kwargs = args.accelerator_config.gradient_accumulation_kwargs

    # Check whether `num_steps` was passed in `gradient_accumulation_kwargs`.
    if "num_steps" in grad_acc_kwargs:
        if args.gradient_accumulation_steps > 1:
            # Raise because we do not know which setting is intended.
            raise ValueError(
                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
            )
        else:
            args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

    accelerator_config = args.accelerator_config.to_dict()

    # Consume the dataloader-related keys; the remaining keys are handled below.
    dataloader_config = DataLoaderConfiguration(
        split_batches=accelerator_config.pop("split_batches"),
        dispatch_batches=accelerator_config.pop("dispatch_batches"),
        even_batches=accelerator_config.pop("even_batches"),
        use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
    )
    dataloader_config.data_seed = args.data_seed

    non_blocking = accelerator_config.pop("non_blocking")
    dataloader_config.non_blocking = non_blocking
    # The gradient accumulation kwargs were folded into `args` above, so drop them here.
    accelerator_config.pop("gradient_accumulation_kwargs")

    accelerator_args = {"deepspeed_plugin": args.deepspeed_plugin, "log_with": "wandb"}
    accelerator_args["dataloader_config"] = dataloader_config

    # Create the accelerator object.
    accelerator = Accelerator(**accelerator_args)

    # DeepSpeed and FSDP flags covering both trainer args and the accelerate launcher.
    is_deepspeed_enabled = (
        getattr(accelerator.state, "deepspeed_plugin", None) is not None
    )
    is_fsdp_enabled = getattr(accelerator.state, "fsdp_plugin", None) is not None

    # Post-accelerator-creation setup.
    if is_fsdp_enabled:
        fsdp_plugin = accelerator.state.fsdp_plugin
        fsdp_plugin.limit_all_gathers = args.fsdp_config.get(
            "limit_all_gathers", fsdp_plugin.limit_all_gathers
        )
        fsdp_plugin.activation_checkpointing = args.fsdp_config.get(
            "activation_checkpointing", fsdp_plugin.activation_checkpointing
        )
        if fsdp_plugin.activation_checkpointing and args.gradient_checkpointing:
            raise ValueError(
                "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                "when using FSDP."
            )

    if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        ds_plugin = accelerator.state.deepspeed_plugin

        # Rebuild the HF-aware DeepSpeed config so trainer arguments are reflected in it.
        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
        ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size=False)

    return accelerator
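

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original snippet): drives the helper purely
# from `TrainingArguments`, assuming a transformers version whose
# `AcceleratorConfig` exposes all fields used above (`non_blocking`,
# `gradient_accumulation_kwargs`, etc.) and an environment where the
# `log_with="wandb"` setting above is acceptable. The output directory name
# and the accelerator-config values below are illustrative assumptions.
if __name__ == "__main__":
    example_args = TrainingArguments(
        output_dir="tmp_accelerator_demo",  # hypothetical scratch directory
        per_device_train_batch_size=8,
        # Passed as a plain dict, which `TrainingArguments` accepts for
        # `accelerator_config` and converts to an `AcceleratorConfig`.
        accelerator_config={
            "split_batches": False,
            "dispatch_batches": None,
            "even_batches": True,
            "use_seedable_sampler": True,
            "non_blocking": False,
            "gradient_accumulation_kwargs": {"num_steps": 4},
        },
    )

    example_accelerator = create_accelerator_and_postprocess(example_args)

    # `num_steps` from the accelerator config is copied back onto the args.
    print(example_args.gradient_accumulation_steps)  # expected: 4
    print(example_accelerator.state)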