diff --git a/.gitignore b/.gitignore
index c4cf06a..5acf0bf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 **/.venv/*
 **/__pycache__/*
+rsync.sh
diff --git a/src/datasets_library/factory.py b/src/datasets_library/factory.py
index 8e032c6..eb5f037 100644
--- a/src/datasets_library/factory.py
+++ b/src/datasets_library/factory.py
@@ -3,10 +3,10 @@ from typing import Literal


 def get_dataset(
-    script_args, base_path="/home/zyy/research/accelerate/dataset"
+    dataset_name, base_path="/home/zyy/research/accelerate/dataset"
 ) -> dict[Literal["train", "test", "generation"], Dataset]:
     dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
-    if script_args.dataset_name == "OCR_VQA_200K":
+    if dataset_name == "OCR_VQA_200K":
         import os.path as osp

         from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
diff --git a/src/evaluate.sh b/src/evaluate.sh
index b65fe7a..bf01d6e 100755
--- a/src/evaluate.sh
+++ b/src/evaluate.sh
@@ -1,6 +1,6 @@
 #!/bin/bash

-accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluate.py \
+accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluate_1.py \
     --dataset_name OCR_VQA_200K \
     --use_peft \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
diff --git a/src/evaluate.py b/src/evaluate_temp.py
similarity index 99%
rename from src/evaluate.py
rename to src/evaluate_temp.py
index ed8cfb3..7abfbb5 100644
--- a/src/evaluate.py
+++ b/src/evaluate_temp.py
@@ -128,14 +128,15 @@ if __name__ == "__main__":
         accelerator=accelerator,
     )
     trainer.train()
-
+
     trainer.save_model(training_args.output_dir)
-
+
     # Clear the CUDA cache
     torch.cuda.empty_cache()
-
+
     # load_model
     from transformers import AutoModelForVision2Seq
+
     model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir)
     model = accelerator.prepare(model)

diff --git a/src/run.sh b/src/run.sh
deleted file mode 100755
index d46514d..0000000
--- a/src/run.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml sft_vlm.py \
-    --dataset_name OCR_VQA_200K \
-    --use_peft \
-    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
-    --lora_target_modules q_proj v_proj \
-    --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 2 \
-    --gradient_accumulation_steps 8 \
-    --max_seq_length 1024 \
-    --output_dir checkpoint/sft-llava-1.5-7b-hf \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --eval_strategy epoch \
-
diff --git a/src/sft_vlm.py b/src/sft_vlm.py
deleted file mode 100644
index 1148665..0000000
--- a/src/sft_vlm.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import torch
-from datasets import load_dataset
-from transformers import (
-    AutoModelForVision2Seq,
-    AutoProcessor,
-    LlavaForConditionalGeneration,
-)
-
-from trl import (
-    ModelConfig,
-    SFTScriptArguments,
-    SFTConfig,
-    SFTTrainer,
-    TrlParser,
-    get_kbit_device_map,
-    get_peft_config,
-    get_quantization_config,
-)
-
-
-if __name__ == "__main__":
-    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    script_args: SFTScriptArguments
-    training_args: SFTConfig
-    model_args: ModelConfig
-    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
-    training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-
-    ################
-    # Model, Tokenizer & Processor
-    ################
-    torch_dtype = (
-        model_args.torch_dtype
-        if model_args.torch_dtype in ["auto", None]
-        else getattr(torch, model_args.torch_dtype)
-    )
-    quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        torch_dtype=torch_dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
-    processor = AutoProcessor.from_pretrained(
-        model_args.model_name_or_path,
-        trust_remote_code=model_args.trust_remote_code,
-        padding_side="right",
-    )
-
-    model = AutoModelForVision2Seq.from_pretrained(
-        model_args.model_name_or_path,
-        trust_remote_code=model_args.trust_remote_code,
-        **model_kwargs,
-    )
-
-    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        from collate_fn_library.qwen2 import collate_fn_for_train
-        from functools import partial
-
-        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
-
-    ################
-    # Dataset
-    ################
-    base_path = "/home/zyy/research/accelerate/dataset"
-    if script_args.dataset_name == "OCR_VQA_200K":
-        import os.path as osp
-        from datasets_library.OCRVQADataset import OCRVQADataset
-
-        dataset = {
-            "train": OCRVQADataset(
-                osp.join(base_path, "OCR-VQA-200K/images"),
-                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
-                split="train",
-            ),
-            "test": OCRVQADataset(
-                osp.join(base_path, "OCR-VQA-200K/images"),
-                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
-                split="test",
-            ),
-        }
-
-    else:
-        dataset = load_dataset(script_args.dataset_name, name=script_args.config)
-
-    ################
-    # Training
-    ################
-    trainer = SFTTrainer(
-        model=model,
-        args=training_args,
-        data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=(
-            dataset[script_args.dataset_test_split]
-            if training_args.eval_strategy != "no"
-            else None
-        ),
-        peft_config=get_peft_config(model_args),
-    )
-
-    trainer.train()
-
-    model = trainer.model
-
-    # trainer.evaluate()
-    # To be completed with parallel evaluation
-
-    # Save and push to hub
-    trainer.save_model(training_args.output_dir)
-    if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name=script_args.dataset_name)
-        if trainer.accelerator.is_main_process:
-            processor.push_to_hub(training_args.hub_model_id)
diff --git a/src/todo.md b/src/todo.md
index d7fe7b8..6998d65 100644
--- a/src/todo.md
+++ b/src/todo.md
@@ -4,3 +4,7 @@

 - [X] Train on the datasets multiple times
 - [X] Clean up the evaluate code
+
+[2025.01.01]
+
+- [ ] Handle the peft logic
diff --git a/src/train.py b/src/train.py
index e829cd7..50a6a0b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,15 +1,9 @@
 import torch
 from datasets_library.factory import get_dataset
-from transformers import (
-    AutoModelForVision2Seq,
-    AutoProcessor,
-)
+from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments

 from trl import (
     ModelConfig,
-    SFTScriptArguments,
-    SFTConfig,
-    SFTTrainer,
     TrlParser,
     get_kbit_device_map,
     get_peft_config,
@@ -18,14 +12,17 @@ from trl import (

 from peft import get_peft_model
 from utils.trainer import ContinualTrainer
+from utils.args import ContinualScriptArguments


 if __name__ == "__main__":
-    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
+    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig))
     script_args, training_args, model_args = parser.parse_args_and_config()
-    script_args: SFTScriptArguments
-    training_args: SFTConfig
-    model_args: ModelConfig
+    # for type hint
+    if 0 == 1:
+        script_args = ContinualScriptArguments()
+        training_args = TrainingArguments()
+        model_args = ModelConfig()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}
@@ -65,91 +62,49 @@ if __name__ == "__main__":
     collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
     collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

+    peft_config = get_peft_config(model_args)
+    model = get_peft_model(model, peft_config)
+
     ################
     # Dataset
     ################
-    dataset = get_dataset(script_args)
+    from utils.accelerator import create_accelerator_and_postprocess

-    peft_config = get_peft_config(model_args)
-    model = get_peft_model(model, peft_config)
-    # Only print on rank 1, rank 2, rank 3
-    if torch.distributed.get_rank() in [1]:
-        print(model.print_trainable_parameters)
+    accelerator = create_accelerator_and_postprocess(training_args)

-    # _________________________________________________________
-    model.train()
-    import copy
+    if accelerator.is_local_main_process:
+        model.print_trainable_parameters()

-    training_args_init = copy.copy(training_args)
-    training_args_init.do_train = False
-    training_args_init.do_eval = False
-    training_args_init.do_predict = False
-    training_args_init.num_train_epochs = 0
-    trainer = SFTTrainer(
-        model=model,
-        args=training_args_init,
-        data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=(
-            dataset[script_args.dataset_test_split]
-            if training_args.eval_strategy != "no"
-            else None
-        ),
-    )
-    trainer.train()
+    for dataset_name in script_args.dataset_name:
+        dataset = get_dataset(dataset_name)
+        model.train()

-    model.eval()
-    accelerator = trainer.accelerator
+        trainer = ContinualTrainer(
+            model=model,
+            args=training_args,
+            data_collator=collate_fn_for_train,
+            train_dataset=dataset[script_args.dataset_train_split],
+            eval_dataset=(
+                dataset[script_args.dataset_test_split]
+                if training_args.eval_strategy != "no"
+                else None
+            ),
+            accelerator=accelerator,
+        )
+        trainer.train()

-    from torch.utils.data import DataLoader
+        model.eval()
+        accelerator = trainer.accelerator

-    val_dataloader = DataLoader(
-        dataset["generation"],
-        batch_size=3,
-        collate_fn=collate_fn_for_evaluate,
-    )
-    val_dataloader = accelerator.prepare(val_dataloader)
-    from utils.evaluate_tool import evaluate_rouge
+        from torch.utils.data import DataLoader

-    evaluate_rouge(model, val_dataloader, processor)
+        val_dataloader = DataLoader(
+            dataset[script_args.dataset_generation_split],
+            batch_size=3,
+            collate_fn=collate_fn_for_evaluate,
+        )
+        val_dataloader = accelerator.prepare(val_dataloader)
+        from utils.evaluate_tool import evaluate_rouge

-    model.train()
-    trainer = ContinualTrainer(
-        model=model,
-        args=training_args,
-        data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=(
-            dataset[script_args.dataset_test_split]
-            if training_args.eval_strategy != "no"
-            else None
-        ),
-        accelerator=accelerator,
-    )
-    trainer.train()
-
-    trainer.save_model(training_args.output_dir)
-
-    # Clear the CUDA cache
-    torch.cuda.empty_cache()
-
-    # load_model
-    from transformers import AutoModelForVision2Seq
-    model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir)
-    model = accelerator.prepare(model)
-
-    model.eval()
-    accelerator = trainer.accelerator
-
-    from torch.utils.data import DataLoader
-
-    val_dataloader = DataLoader(
-        dataset["generation"],
-        batch_size=3,
-        collate_fn=collate_fn_for_evaluate,
-    )
-    val_dataloader = accelerator.prepare(val_dataloader)
-    from utils.evaluate_tool import evaluate_rouge
-
-    evaluate_rouge(model, val_dataloader, processor)
+        evaluate_rouge(model, val_dataloader, processor)

diff --git a/src/train.sh b/src/train.sh
index 0836cb0..caa9c8a 100755
--- a/src/train.sh
+++ b/src/train.sh
@@ -8,7 +8,6 @@ accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 8 \
-    --max_seq_length 1024 \
     --output_dir checkpoint/sft-llava-1.5-7b-hf \
     --bf16 \
     --torch_dtype bfloat16
diff --git a/src/utils/accelerator.py b/src/utils/accelerator.py
new file mode 100644
index 0000000..44d6c33
--- /dev/null
+++ b/src/utils/accelerator.py
@@ -0,0 +1,76 @@
+from accelerate import Accelerator, DataLoaderConfiguration
+
+
+def create_accelerator_and_postprocess(args):
+    # We explicitly don't rely on the `Accelerator` to do gradient accumulation
+    grad_acc_kwargs = {}
+    if args.accelerator_config.gradient_accumulation_kwargs is not None:
+        grad_acc_kwargs = args.accelerator_config.gradient_accumulation_kwargs
+
+    # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
+    if "num_steps" in grad_acc_kwargs:
+        if args.gradient_accumulation_steps > 1:
+            # raise because we do not know which setting is intended.
+            raise ValueError(
+                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
+                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
+            )
+        else:
+            args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
+
+    accelerator_config = args.accelerator_config.to_dict()
+
+    dataloader_config = DataLoaderConfiguration(
+        split_batches=accelerator_config.pop("split_batches"),
+        dispatch_batches=accelerator_config.pop("dispatch_batches"),
+        even_batches=accelerator_config.pop("even_batches"),
+        use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
+    )
+    dataloader_config.data_seed = args.data_seed
+
+    non_blocking = accelerator_config.pop("non_blocking")
+    dataloader_config.non_blocking = non_blocking
+    # this would have been updated above, no need for it anymore
+    accelerator_config.pop("gradient_accumulation_kwargs")
+
+    accelerator_args = {
+        "deepspeed_plugin": args.deepspeed_plugin,
+    }
+    accelerator_args["dataloader_config"] = dataloader_config
+    # create accelerator object
+    accelerator = Accelerator(**accelerator_args)
+
+    # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+    is_deepspeed_enabled = (
+        getattr(accelerator.state, "deepspeed_plugin", None) is not None
+    )
+    is_fsdp_enabled = getattr(accelerator.state, "fsdp_plugin", None) is not None
+
+    # post accelerator creation setup
+    if is_fsdp_enabled:
+        fsdp_plugin = accelerator.state.fsdp_plugin
+        fsdp_plugin.limit_all_gathers = args.fsdp_config.get(
+            "limit_all_gathers", fsdp_plugin.limit_all_gathers
+        )
+        fsdp_plugin.activation_checkpointing = args.fsdp_config.get(
+            "activation_checkpointing", fsdp_plugin.activation_checkpointing
+        )
+        if fsdp_plugin.activation_checkpointing and args.gradient_checkpointing:
+            raise ValueError(
+                "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
+                "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
+                "when using FSDP."
+            )
+
+    def propagate_args_to_deepspeed(auto_find_batch_size=False):
+        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+        ds_plugin = accelerator.state.deepspeed_plugin
+
+        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+        ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size)
+
+    if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
+        propagate_args_to_deepspeed()
+    return accelerator
diff --git a/src/utils/args.py b/src/utils/args.py
new file mode 100644
index 0000000..e2298a6
--- /dev/null
+++ b/src/utils/args.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ContinualScriptArguments:
+    """Script arguments for continual learning."""
+
+    dataset_name: list[str] = field(
+        default_factory=lambda: ["cifar10", "cifar100", "imagenet2012"]
+    )
+    dataset_config: Optional[str] = None
+    dataset_train_split: str = "train"
+    dataset_test_split: str = "test"
+    dataset_generation_split: str = "generation"
+    gradient_checkpointing_use_reentrant: bool = False
+    ignore_bias_buffers: bool = False
diff --git a/src/utils/trainer.py b/src/utils/trainer.py
index 645a1f7..0e86bca 100644
--- a/src/utils/trainer.py
+++ b/src/utils/trainer.py
@@ -1,9 +1,9 @@
 # _________________________________________________________

-from trl import SFTTrainer
+from transformers import Trainer


-class ContinualTrainer(SFTTrainer):
+class ContinualTrainer(Trainer):
     def __init__(
         self, model, args, data_collator, train_dataset, eval_dataset, accelerator
     ):
@@ -19,6 +19,7 @@ class ContinualTrainer(SFTTrainer):
             self.is_fsdp_enabled = (
                 getattr(self.accelerator.state, "fsdp_plugin", None) is not None
             )
+            self.gather_function = self.accelerator.gather_for_metrics
             return
         else:
             super().create_accelerator_and_postprocess()
\ No newline at end of file