From 70c446e5485af25b73763a1b92bbe5cda651d347 Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Wed, 28 May 2025 19:32:13 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E7=BB=93=E6=9E=84=EF=BC=8C=E6=8F=90=E5=8D=87=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7=E5=92=8C=E4=B8=80=E8=87=B4=E6=80=A7=EF=BC=8C?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/.gitignore | 3 ++- src/test.py | 35 ----------------------------------- src/train.py | 30 +++++++++++++++++++----------- 3 files changed, 21 insertions(+), 47 deletions(-) delete mode 100644 src/test.py diff --git a/src/.gitignore b/src/.gitignore index e49c2cd..5efd1d5 100644 --- a/src/.gitignore +++ b/src/.gitignore @@ -1,2 +1,3 @@ checkpoint/* -wandb/* \ No newline at end of file +wandb/* +test.py \ No newline at end of file diff --git a/src/test.py b/src/test.py deleted file mode 100644 index 404eba0..0000000 --- a/src/test.py +++ /dev/null @@ -1,35 +0,0 @@ -# import sys - -# sys.path.insert(0, "transformers_repo/src/") -# sys.path.insert(0, "peft_repo/src/") - -# from calflops import calculate_flops_hf - -# batch_size, max_seq_length = 1, 128 -# model_name = "Qwen/Qwen2.5-VL-3B-Instruct" - -# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="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") -# print("%s FLOPs:%s MACs:%s Params:%s \n" %(model_name, flops, macs, params)) - -# Transformers Model, such as bert. -from calflops import calculate_flops -from transformers import AutoModel -from transformers import AutoTokenizer - -batch_size, max_seq_length = 1, 128 -# Load model directly -from transformers import AutoProcessor, AutoModelForImageTextToText - -processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct") -model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct") - -flops, macs, params = calculate_flops( - model=model, - input_shape=(batch_size, max_seq_length), - transformer_tokenizer=tokenizer, -) -print( - "Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s MACs:%s Params:%s \n" - % (flops, macs, params) -) -# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M diff --git a/src/train.py b/src/train.py index e171691..aee47d1 100644 --- a/src/train.py +++ b/src/train.py @@ -13,7 +13,11 @@ from trl import ( TrlParser, ) from utils.trainer import ContinualTrainer -from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments +from utils.args import ( + ContinualScriptArguments, + ContinualModelConfig, + ContiunalRegularizationArguments, +) import logging from typing import TYPE_CHECKING @@ -23,19 +27,23 @@ logging.basicConfig(level=logging.INFO) if __name__ == "__main__": parser = TrlParser( - (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments) + ( + ContinualScriptArguments, + TrainingArguments, + ContinualModelConfig, + ContiunalRegularizationArguments, + ) # type: ignore ) script_args, training_args, model_args, reg_args = parser.parse_args_and_config() # for type hint - if 1 == 0: - script_args = ContinualScriptArguments - training_args = TrainingArguments - model_args = ContinualModelConfig - reg_args = ContiunalRegularizationArguments + if TYPE_CHECKING: + script_args = ContinualScriptArguments() + training_args = TrainingArguments() + model_args = ContinualModelConfig() + reg_args = ContiunalRegularizationArguments() + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False - training_args.dataset_kwargs = {"skip_prepare_dataset": True} - from model_library.factory import get_model @@ -85,9 +93,9 @@ if __name__ == "__main__": model=model, args=training_args, data_collator=collate_fn_for_train, - train_dataset=dataset[script_args.dataset_train_split], + train_dataset=dataset[script_args.dataset_train_split], # type: ignore eval_dataset=( - dataset[script_args.dataset_test_split] + dataset[script_args.dataset_test_split] # type: ignore if training_args.eval_strategy != "no" else None ),