diff --git a/src/.gitignore b/src/.gitignore
index e49c2cd..5efd1d5 100644
--- a/src/.gitignore
+++ b/src/.gitignore
@@ -1,2 +1,3 @@
 checkpoint/*
-wandb/*
\ No newline at end of file
+wandb/*
+test.py
\ No newline at end of file
diff --git a/src/test.py b/src/test.py
deleted file mode 100644
index 404eba0..0000000
--- a/src/test.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# import sys
-
-# sys.path.insert(0, "transformers_repo/src/")
-# sys.path.insert(0, "peft_repo/src/")
-
-# from calflops import calculate_flops_hf
-
-# batch_size, max_seq_length = 1, 128
-# model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
-
-# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="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")
-# print("%s FLOPs:%s MACs:%s Params:%s \n" %(model_name, flops, macs, params))
-
-# Transformers Model, such as bert.
-from calflops import calculate_flops
-from transformers import AutoModel
-from transformers import AutoTokenizer
-
-batch_size, max_seq_length = 1, 128
-# Load model directly
-from transformers import AutoProcessor, AutoModelForImageTextToText
-
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-
-flops, macs, params = calculate_flops(
-    model=model,
-    input_shape=(batch_size, max_seq_length),
-    transformer_tokenizer=tokenizer,
-)
-print(
-    "Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s MACs:%s Params:%s \n"
-    % (flops, macs, params)
-)
-# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M
diff --git a/src/train.py b/src/train.py
index e171691..aee47d1 100644
--- a/src/train.py
+++ b/src/train.py
@@ -13,7 +13,11 @@ from trl import (
     TrlParser,
 )
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments
+from utils.args import (
+    ContinualScriptArguments,
+    ContinualModelConfig,
+    ContiunalRegularizationArguments,
+)
 import logging
 
 from typing import TYPE_CHECKING
@@ -23,19 +27,23 @@ logging.basicConfig(level=logging.INFO)
 
 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments)
+        (
+            ContinualScriptArguments,
+            TrainingArguments,
+            ContinualModelConfig,
+            ContiunalRegularizationArguments,
+        )  # type: ignore
    )
    script_args, training_args, model_args, reg_args = parser.parse_args_and_config()

    # for type hint
-    if 1 == 0:
-        script_args = ContinualScriptArguments
-        training_args = TrainingArguments
-        model_args = ContinualModelConfig
-        reg_args = ContiunalRegularizationArguments
+    if TYPE_CHECKING:
+        script_args = ContinualScriptArguments()
+        training_args = TrainingArguments()
+        model_args = ContinualModelConfig()
+        reg_args = ContiunalRegularizationArguments()
+

    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-
    from model_library.factory import get_model
@@ -85,9 +93,9 @@ if __name__ == "__main__":
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
+        train_dataset=dataset[script_args.dataset_train_split],  # type: ignore
        eval_dataset=(
-            dataset[script_args.dataset_test_split]
+            dataset[script_args.dataset_test_split]  # type: ignore
            if training_args.eval_strategy != "no"
            else None
        ),
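
Reviewer note on the deleted src/test.py: as committed it would have raised a NameError, since it passed transformer_tokenizer=tokenizer without ever defining tokenizer (only an AutoProcessor was created), and the print label still referred to the BERT model from the calflops example it was adapted from. If the FLOPs measurement is still wanted, a minimal corrected sketch follows. It is a hypothetical standalone script, not part of this diff; it reuses the calculate_flops call shape already present in the deleted file, and whether calflops can drive Qwen2.5-VL from text-only dummy inputs is an untested assumption here.

# flops_check.py -- hypothetical sketch, assumes calflops and transformers are installed
from calflops import calculate_flops
from transformers import AutoModelForImageTextToText, AutoTokenizer

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
batch_size, max_seq_length = 1, 128

# The deleted script loaded an AutoProcessor but referenced an undefined
# `tokenizer`; for a text-only FLOPs estimate the tokenizer alone suffices.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(model_name)

# calflops builds dummy text inputs of the given shape from the tokenizer,
# as in the deleted file's original calculate_flops call.
flops, macs, params = calculate_flops(
    model=model,
    input_shape=(batch_size, max_seq_length),
    transformer_tokenizer=tokenizer,
)
print(f"{model_name} FLOPs:{flops} MACs:{macs} Params:{params}")

This measures text-only cost; the vision tower's FLOPs for image inputs are not exercised. Separately, the if TYPE_CHECKING: rewrite in src/train.py keeps the old runtime behaviour of if 1 == 0: (typing.TYPE_CHECKING is False outside static analysis, so the block never executes) while making the dead block's purpose, narrowing the parsed argument types for the type checker, explicit.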