From f2f921113ec59c30ea4cab253bfa6d6fa41e824f Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Tue, 31 Dec 2024 17:53:16 +0000 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0.gitignore=E4=BB=A5=E6=8E=92?= =?UTF-8?q?=E9=99=A4=E8=99=9A=E6=8B=9F=E7=8E=AF=E5=A2=83=E5=92=8C=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E6=96=87=E4=BB=B6=EF=BC=8C=E4=BF=AE=E6=94=B9TODO?= =?UTF-8?q?=E5=88=97=E8=A1=A8=EF=BC=8C=E9=87=8D=E5=91=BD=E5=90=8D=E8=AF=84?= =?UTF-8?q?=E4=BC=B0=E8=84=9A=E6=9C=AC=EF=BC=8C=E6=B7=BB=E5=8A=A0=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E5=92=8C=E8=AF=84=E4=BC=B0=E8=84=9A=E6=9C=AC=EF=BC=8C?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=95=B0=E6=8D=AE=E9=9B=86=E5=B7=A5=E5=8E=82?= =?UTF-8?q?=E5=92=8C=E8=AF=84=E4=BC=B0=E5=B7=A5=E5=85=B7=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 +- .../__init__.py | 0 .../qwen2.py | 0 src/datasets_library/factory.py | 30 ++++ src/evaluate.py | 155 ++++++++++++++++++ src/{evaluation.sh => evaluate.sh} | 2 +- src/evaluation.py | 153 ----------------- src/sft_vlm.py | 2 - src/todo.md | 6 +- src/train.py | 155 ++++++++++++++++++ src/train.sh | 15 ++ .../__init__.py | 0 src/utils/evaluate_tool.py | 26 +++ src/utils/trainer.py | 24 +++ 14 files changed, 411 insertions(+), 161 deletions(-) rename src/{collate_fn_library => collatefn_library}/__init__.py (100%) rename src/{collate_fn_library => collatefn_library}/qwen2.py (100%) create mode 100644 src/datasets_library/factory.py create mode 100644 src/evaluate.py rename src/{evaluation.sh => evaluate.sh} (95%) delete mode 100644 src/evaluation.py create mode 100644 src/train.py create mode 100755 src/train.sh rename src/{evaluations_library => utils}/__init__.py (100%) create mode 100644 src/utils/evaluate_tool.py create mode 100644 src/utils/trainer.py diff --git a/.gitignore b/.gitignore index 827d37b..c4cf06a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -**/.venv/ -**/__pycache__/ +**/.venv/* +**/__pycache__/* diff --git a/src/collate_fn_library/__init__.py b/src/collatefn_library/__init__.py similarity index 100% rename from src/collate_fn_library/__init__.py rename to src/collatefn_library/__init__.py diff --git a/src/collate_fn_library/qwen2.py b/src/collatefn_library/qwen2.py similarity index 100% rename from src/collate_fn_library/qwen2.py rename to src/collatefn_library/qwen2.py diff --git a/src/datasets_library/factory.py b/src/datasets_library/factory.py new file mode 100644 index 0000000..8e032c6 --- /dev/null +++ b/src/datasets_library/factory.py @@ -0,0 +1,30 @@ +from torch.utils.data import Dataset +from typing import Literal + + +def get_dataset( + script_args, base_path="/home/zyy/research/accelerate/dataset" +) -> dict[Literal["train", "test", "generation"], Dataset]: + dataset: dict[Literal["train", "test", "generation"], Dataset] = {} + if script_args.dataset_name == "OCR_VQA_200K": + import os.path as osp + from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration + + dataset = { + "train": OCRVQADataset( + osp.join(base_path, "OCR-VQA-200K/images"), + osp.join(base_path, "OCR-VQA-200K/dataset.json"), + split="train", + ), + "test": OCRVQADataset( + osp.join(base_path, "OCR-VQA-200K/images"), + osp.join(base_path, "OCR-VQA-200K/dataset.json"), + split="test", + ), + "generation": OCRVQADatasetForGeneration( + osp.join(base_path, "OCR-VQA-200K/images"), + osp.join(base_path, "OCR-VQA-200K/dataset.json"), + split="test", + ), + } + return dataset diff --git a/src/evaluate.py b/src/evaluate.py new file mode 100644 index 0000000..ed8cfb3 --- /dev/null +++ b/src/evaluate.py @@ -0,0 +1,155 @@ +import torch +from datasets_library.factory import get_dataset +from transformers import ( + AutoModelForVision2Seq, + AutoProcessor, +) + +from trl import ( + ModelConfig, + SFTScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from peft import get_peft_model + +from utils.trainer import ContinualTrainer + + +if __name__ == "__main__": + parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + script_args: SFTScriptArguments = script_args + training_args: SFTConfig = training_args + model_args: ModelConfig = model_args + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + training_args.remove_unused_columns = False + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) + + model = AutoModelForVision2Seq.from_pretrained( + training_args.output_dir, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + + if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct": + from collatefn_library.qwen2 import ( + collate_fn_for_train, + collate_fn_for_evaluate, + ) + from functools import partial + + collate_fn_for_train = partial(collate_fn_for_train, processor=processor) + collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor) + + ################ + # Dataset + ################ + + dataset = get_dataset(script_args) + + # peft_config = get_peft_config(model_args) + # model = get_peft_model(model, peft_config) + # 仅在rank1 rank2 rank3时打印 + if torch.distributed.get_rank() in [1]: + print(model) + + # _________________________________________________________ + model.train() + import copy + + training_args_init = copy.copy(training_args) + training_args_init.do_train = False + training_args_init.do_eval = False + training_args_init.do_predict = False + training_args_init.num_train_epochs = 0 + trainer = SFTTrainer( + model=model, + args=training_args_init, + data_collator=collate_fn_for_train, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=( + dataset[script_args.dataset_test_split] + if training_args.eval_strategy != "no" + else None + ), + ) + trainer.train() + + model.eval() + accelerator = trainer.accelerator + + from torch.utils.data import DataLoader + + val_dataloader = DataLoader( + dataset["generation"], + batch_size=3, + collate_fn=collate_fn_for_evaluate, + ) + val_dataloader = accelerator.prepare(val_dataloader) + from utils.evaluate_tool import evaluate_rouge + + evaluate_rouge(model, val_dataloader, processor) + + model.train() + trainer = ContinualTrainer( + model=model, + args=training_args, + data_collator=collate_fn_for_train, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=( + dataset[script_args.dataset_test_split] + if training_args.eval_strategy != "no" + else None + ), + accelerator=accelerator, + ) + trainer.train() + + trainer.save_model(training_args.output_dir) + + # 清理cache + torch.cuda.empty_cache() + + # load_model + from transformers import AutoModelForVision2Seq + model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir) + model = accelerator.prepare(model) + + model.eval() + accelerator = trainer.accelerator + + from torch.utils.data import DataLoader + + val_dataloader = DataLoader( + dataset["generation"], + batch_size=3, + collate_fn=collate_fn_for_evaluate, + ) + val_dataloader = accelerator.prepare(val_dataloader) + from utils.evaluate_tool import evaluate_rouge + + evaluate_rouge(model, val_dataloader, processor) diff --git a/src/evaluation.sh b/src/evaluate.sh similarity index 95% rename from src/evaluation.sh rename to src/evaluate.sh index 87f8290..b65fe7a 100755 --- a/src/evaluation.sh +++ b/src/evaluate.sh @@ -1,6 +1,6 @@ #!/bin/bash -accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \ +accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluate.py \ --dataset_name OCR_VQA_200K \ --use_peft \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ diff --git a/src/evaluation.py b/src/evaluation.py deleted file mode 100644 index 375257d..0000000 --- a/src/evaluation.py +++ /dev/null @@ -1,153 +0,0 @@ -import torch -from datasets import load_dataset -from transformers import ( - AutoModelForVision2Seq, - AutoProcessor, - LlavaForConditionalGeneration, -) - -from trl import ( - ModelConfig, - SFTScriptArguments, - SFTConfig, - SFTTrainer, - TrlParser, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) - - -if __name__ == "__main__": - parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_args = parser.parse_args_and_config() - script_args: SFTScriptArguments - training_args: SFTConfig - model_args: ModelConfig - training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) - training_args.remove_unused_columns = False - training_args.dataset_kwargs = {"skip_prepare_dataset": True} - - ################ - # Model, Tokenizer & Processor - ################ - torch_dtype = ( - model_args.torch_dtype - if model_args.torch_dtype in ["auto", None] - else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - model_kwargs = dict( - revision=model_args.model_revision, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - processor = AutoProcessor.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - padding_side="left", - ) - - model = AutoModelForVision2Seq.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - **model_kwargs, - ) - - if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct": - from collate_fn_library.qwen2 import collate_fn_for_train - from functools import partial - - collate_fn_for_train = partial(collate_fn_for_train, processor=processor) - - ################ - # Dataset - ################ - base_path = "/home/zyy/research/accelerate/dataset" - if script_args.dataset_name == "OCR_VQA_200K": - import os.path as osp - from datasets_library.OCRVQADataset import OCRVQADataset - - dataset = { - "train": OCRVQADataset( - osp.join(base_path, "OCR-VQA-200K/images"), - osp.join(base_path, "OCR-VQA-200K/dataset.json"), - split="train", - ), - "test": OCRVQADataset( - osp.join(base_path, "OCR-VQA-200K/images"), - osp.join(base_path, "OCR-VQA-200K/dataset.json"), - split="test", - ), - } - - else: - dataset = load_dataset(script_args.dataset_name, name=script_args.config) - - ################ - # Training - ################ - trainer = SFTTrainer( - model=model, - args=training_args, - data_collator=collate_fn_for_train, - train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=( - dataset[script_args.dataset_test_split] - if training_args.eval_strategy != "no" - else None - ), - tokenizer=processor.tokenizer, - # processing_class=processor.tokenizer, - peft_config=get_peft_config(model_args), - ) - trainer.train() - - model = trainer.model - # model进gpu - accelerator = trainer.accelerator - # model = accelerator.prepare(model) - - from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration - from collate_fn_library.qwen2 import collate_fn_for_evaluate - - dataset = OCRVQADatasetForGeneration( - vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images", - ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json", - split="train", - ) - examples = [dataset[i] for i in range(3)] - # print(collate_fn_for_evaluate(examples, processor)) - - from torch.utils.data import DataLoader - - val_dataloader = DataLoader( - dataset, - batch_size=3, - collate_fn=lambda x: collate_fn_for_evaluate(x, processor), - ) - val_dataloader = accelerator.prepare(val_dataloader) - - import evaluate - glue = evaluate.load("rouge") - - for batch in val_dataloader: - completion = model.generate( - input_ids=batch["input_ids"], - attention_mask=batch["attention_mask"], - pixel_values=batch["pixel_values"], - image_grid_thw=batch["image_grid_thw"], - max_length=1000, - ) - target = batch["answers_ids"] - generated_text = [ - out_ids[len(in_ids) :] - for out_ids, in_ids in zip(completion, batch["input_ids"]) - ] - generated_text = processor.tokenizer.batch_decode(generated_text, skip_special_tokens=True) - target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True) - glue.add_batch(predictions=generated_text, references=target_text) - - print(glue.compute()) diff --git a/src/sft_vlm.py b/src/sft_vlm.py index ae36e00..1148665 100644 --- a/src/sft_vlm.py +++ b/src/sft_vlm.py @@ -99,8 +99,6 @@ if __name__ == "__main__": if training_args.eval_strategy != "no" else None ), - tokenizer=processor.tokenizer, - # processing_class=processor.tokenizer, peft_config=get_peft_config(model_args), ) diff --git a/src/todo.md b/src/todo.md index f99af50..d7fe7b8 100644 --- a/src/todo.md +++ b/src/todo.md @@ -1,6 +1,6 @@ ## TODO: -[2024.12.31] +[2024.12.31] -- [ ] 采用数据集多次训练 -- [ ] 整理evaluate的代码 +- [X] 采用数据集多次训练 +- [X] 整理evaluate的代码 diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..e829cd7 --- /dev/null +++ b/src/train.py @@ -0,0 +1,155 @@ +import torch +from datasets_library.factory import get_dataset +from transformers import ( + AutoModelForVision2Seq, + AutoProcessor, +) + +from trl import ( + ModelConfig, + SFTScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from peft import get_peft_model + +from utils.trainer import ContinualTrainer + + +if __name__ == "__main__": + parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + script_args: SFTScriptArguments + training_args: SFTConfig + model_args: ModelConfig + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + training_args.remove_unused_columns = False + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) + + model = AutoModelForVision2Seq.from_pretrained( + training_args.output_dir, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + + if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct": + from collatefn_library.qwen2 import ( + collate_fn_for_train, + collate_fn_for_evaluate, + ) + from functools import partial + + collate_fn_for_train = partial(collate_fn_for_train, processor=processor) + collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor) + + ################ + # Dataset + ################ + + dataset = get_dataset(script_args) + + peft_config = get_peft_config(model_args) + model = get_peft_model(model, peft_config) + # 仅在rank1 rank2 rank3时打印 + if torch.distributed.get_rank() in [1]: + print(model.print_trainable_parameters) + + # _________________________________________________________ + model.train() + import copy + + training_args_init = copy.copy(training_args) + training_args_init.do_train = False + training_args_init.do_eval = False + training_args_init.do_predict = False + training_args_init.num_train_epochs = 0 + trainer = SFTTrainer( + model=model, + args=training_args_init, + data_collator=collate_fn_for_train, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=( + dataset[script_args.dataset_test_split] + if training_args.eval_strategy != "no" + else None + ), + ) + trainer.train() + + model.eval() + accelerator = trainer.accelerator + + from torch.utils.data import DataLoader + + val_dataloader = DataLoader( + dataset["generation"], + batch_size=3, + collate_fn=collate_fn_for_evaluate, + ) + val_dataloader = accelerator.prepare(val_dataloader) + from utils.evaluate_tool import evaluate_rouge + + evaluate_rouge(model, val_dataloader, processor) + + model.train() + trainer = ContinualTrainer( + model=model, + args=training_args, + data_collator=collate_fn_for_train, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=( + dataset[script_args.dataset_test_split] + if training_args.eval_strategy != "no" + else None + ), + accelerator=accelerator, + ) + trainer.train() + + trainer.save_model(training_args.output_dir) + + # 清理cache + torch.cuda.empty_cache() + + # load_model + from transformers import AutoModelForVision2Seq + model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir) + model = accelerator.prepare(model) + + model.eval() + accelerator = trainer.accelerator + + from torch.utils.data import DataLoader + + val_dataloader = DataLoader( + dataset["generation"], + batch_size=3, + collate_fn=collate_fn_for_evaluate, + ) + val_dataloader = accelerator.prepare(val_dataloader) + from utils.evaluate_tool import evaluate_rouge + + evaluate_rouge(model, val_dataloader, processor) diff --git a/src/train.sh b/src/train.sh new file mode 100755 index 0000000..0836cb0 --- /dev/null +++ b/src/train.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py \ + --dataset_name OCR_VQA_200K \ + --use_peft \ + --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ + --lora_target_modules q_proj v_proj \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --max_seq_length 1024 \ + --output_dir checkpoint/sft-llava-1.5-7b-hf \ + --bf16 \ + --torch_dtype bfloat16 +# --eval_strategy epoch \ diff --git a/src/evaluations_library/__init__.py b/src/utils/__init__.py similarity index 100% rename from src/evaluations_library/__init__.py rename to src/utils/__init__.py diff --git a/src/utils/evaluate_tool.py b/src/utils/evaluate_tool.py new file mode 100644 index 0000000..18ab1f7 --- /dev/null +++ b/src/utils/evaluate_tool.py @@ -0,0 +1,26 @@ +import evaluate + + +def evaluate_rouge(model, val_dataloader, processor): + glue = evaluate.load("rouge") + + for batch in val_dataloader: + completion = model.generate( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + pixel_values=batch["pixel_values"], + image_grid_thw=batch["image_grid_thw"], + max_length=1000, + ) + target = batch["answers_ids"] + generated_text = [ + out_ids[len(in_ids) :] + for out_ids, in_ids in zip(completion, batch["input_ids"]) + ] + generated_text = processor.tokenizer.batch_decode( + generated_text, skip_special_tokens=True + ) + target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True) + glue.add_batch(predictions=generated_text, references=target_text) + + print(glue.compute()) \ No newline at end of file diff --git a/src/utils/trainer.py b/src/utils/trainer.py new file mode 100644 index 0000000..645a1f7 --- /dev/null +++ b/src/utils/trainer.py @@ -0,0 +1,24 @@ +# _________________________________________________________ + +from trl import SFTTrainer + + +class ContinualTrainer(SFTTrainer): + def __init__( + self, model, args, data_collator, train_dataset, eval_dataset, accelerator + ): + self.accelerator = accelerator + super().__init__(model, args, data_collator, train_dataset, eval_dataset) + + def create_accelerator_and_postprocess(self): + if self.accelerator is not None: + self.is_deepspeed_enabled = ( + getattr(self.accelerator.state, "deepspeed_plugin", None) + is not None + ) + self.is_fsdp_enabled = ( + getattr(self.accelerator.state, "fsdp_plugin", None) is not None + ) + return + else: + super().create_accelerator_and_postprocess() \ No newline at end of file