# cl-lmm/src/utils/accelerator.py

from accelerate import Accelerator, DataLoaderConfiguration
from transformers import (
    TrainingArguments,
)


def create_accelerator_and_postprocess(args: TrainingArguments):
    # We explicitly don't rely on the `Accelerator` to do gradient accumulation
    grad_acc_kwargs = {}
    if args.accelerator_config.gradient_accumulation_kwargs is not None:
        grad_acc_kwargs = args.accelerator_config.gradient_accumulation_kwargs

    # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
    if "num_steps" in grad_acc_kwargs:
        if args.gradient_accumulation_steps > 1:
            # raise because we do not know which setting is intended.
            raise ValueError(
                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
            )
        else:
            args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

    accelerator_config = args.accelerator_config.to_dict()

    dataloader_config = DataLoaderConfiguration(
        split_batches=accelerator_config.pop("split_batches"),
        dispatch_batches=accelerator_config.pop("dispatch_batches"),
        even_batches=accelerator_config.pop("even_batches"),
        use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
    )
    dataloader_config.data_seed = args.data_seed

    non_blocking = accelerator_config.pop("non_blocking")
    dataloader_config.non_blocking = non_blocking

    # this would have been updated above, no need for it anymore
    accelerator_config.pop("gradient_accumulation_kwargs")

    accelerator_args = {"deepspeed_plugin": args.deepspeed_plugin, "log_with": "wandb"}
    accelerator_args["dataloader_config"] = dataloader_config

    # create accelerator object
    accelerator = Accelerator(**accelerator_args)

    # deepspeed and accelerate flags covering both trainer args and accelerate launcher
    is_deepspeed_enabled = (
        getattr(accelerator.state, "deepspeed_plugin", None) is not None
    )
    is_fsdp_enabled = getattr(accelerator.state, "fsdp_plugin", None) is not None

    # post accelerator creation setup
    if is_fsdp_enabled:
        fsdp_plugin = accelerator.state.fsdp_plugin
        fsdp_plugin.limit_all_gathers = args.fsdp_config.get(
            "limit_all_gathers", fsdp_plugin.limit_all_gathers
        )
        fsdp_plugin.activation_checkpointing = args.fsdp_config.get(
            "activation_checkpointing", fsdp_plugin.activation_checkpointing
        )
        if fsdp_plugin.activation_checkpointing and args.gradient_checkpointing:
            raise ValueError(
                "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                "when using FSDP."
            )

    if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        ds_plugin = accelerator.state.deepspeed_plugin
        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
        ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size=False)

    return accelerator
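

# A minimal usage sketch, not part of the original module: it assumes a recent
# transformers/accelerate install and only exercises
# create_accelerator_and_postprocess on a single process. The output directory
# and batch/accumulation values below are illustrative assumptions.
if __name__ == "__main__":
    example_args = TrainingArguments(
        output_dir="./tmp_outputs",  # hypothetical path, used only for this sketch
        per_device_train_batch_size=8,  # illustrative value
        gradient_accumulation_steps=2,  # illustrative value
    )
    example_accelerator = create_accelerator_and_postprocess(example_args)
    print(
        f"distributed_type={example_accelerator.distributed_type}, "
        f"gradient_accumulation_steps={example_args.gradient_accumulation_steps}"
    )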