refactor: optimize code structure, improve readability and consistency, and remove unnecessary files

YunyaoZhou 2025-05-28 19:32:13 +08:00
parent 0bc1034f35
commit 70c446e548
3 changed files with 21 additions and 47 deletions

src/.gitignore vendored

@@ -1,2 +1,3 @@
 checkpoint/*
 wandb/*
+test.py

deleted file
@@ -1,35 +0,0 @@
-# import sys
-# sys.path.insert(0, "transformers_repo/src/")
-# sys.path.insert(0, "peft_repo/src/")
-# from calflops import calculate_flops_hf
-# batch_size, max_seq_length = 1, 128
-# model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
-# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="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")
-# print("%s FLOPs:%s MACs:%s Params:%s \n" %(model_name, flops, macs, params))
-# Transformers Model, such as bert.
-from calflops import calculate_flops
-from transformers import AutoModel
-from transformers import AutoTokenizer
-batch_size, max_seq_length = 1, 128
-# Load model directly
-from transformers import AutoProcessor, AutoModelForImageTextToText
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-flops, macs, params = calculate_flops(
-    model=model,
-    input_shape=(batch_size, max_seq_length),
-    transformer_tokenizer=tokenizer,
-)
-print(
-    "Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s MACs:%s Params:%s \n"
-    % (flops, macs, params)
-)
-# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M
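Note: the deleted script referenced an undefined `tokenizer` in the `calculate_flops` call (only a processor was loaded), so it would have raised a NameError, and its print statement still carried the Bert label from the calflops example. A minimal working sketch of the same FLOPs measurement, assuming the tokenizer is loaded explicitly with AutoTokenizer:

from calflops import calculate_flops
from transformers import AutoModelForImageTextToText, AutoTokenizer

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
batch_size, max_seq_length = 1, 128

# Load the model and its tokenizer; calculate_flops uses the tokenizer
# to build dummy text inputs of the given shape.
model = AutoModelForImageTextToText.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

flops, macs, params = calculate_flops(
    model=model,
    input_shape=(batch_size, max_seq_length),
    transformer_tokenizer=tokenizer,
)
print("%s FLOPs:%s MACs:%s Params:%s" % (model_name, flops, macs, params))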

@@ -13,7 +13,11 @@ from trl import (
     TrlParser,
 )
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments
+from utils.args import (
+    ContinualScriptArguments,
+    ContinualModelConfig,
+    ContiunalRegularizationArguments,
+)
 import logging
 from typing import TYPE_CHECKING
@@ -23,19 +27,23 @@ logging.basicConfig(level=logging.INFO)
 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments)
+        (
+            ContinualScriptArguments,
+            TrainingArguments,
+            ContinualModelConfig,
+            ContiunalRegularizationArguments,
+        )  # type: ignore
     )
     script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
     # for type hint
-    if 1 == 0:
-        script_args = ContinualScriptArguments
-        training_args = TrainingArguments
-        model_args = ContinualModelConfig
-        reg_args = ContiunalRegularizationArguments
+    if TYPE_CHECKING:
+        script_args = ContinualScriptArguments()
+        training_args = TrainingArguments()
+        model_args = ContinualModelConfig()
+        reg_args = ContiunalRegularizationArguments()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}
     from model_library.factory import get_model
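Note: the old `if 1 == 0:` block was a dead branch whose only purpose was to give the IDE concrete types for the tuple returned by `parse_args_and_config()`; `typing.TYPE_CHECKING` is the idiomatic replacement, since type checkers treat it as true while it is always False at runtime. A self-contained sketch of the pattern, with a hypothetical `Args` dataclass standing in for the parsed configs:

from dataclasses import dataclass
from typing import TYPE_CHECKING

@dataclass
class Args:  # hypothetical stand-in for ContinualScriptArguments etc.
    learning_rate: float = 1e-4

def parse() -> tuple:  # loosely typed, like TrlParser.parse_args_and_config()
    return (Args(),)

(args,) = parse()
if TYPE_CHECKING:
    # Never executed; only read by static analyzers, which now know
    # `args` is an Args instance and can check `args.learning_rate`.
    args = Args()

print(args.learning_rate)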
@@ -85,9 +93,9 @@ if __name__ == "__main__":
         model=model,
         args=training_args,
         data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
+        train_dataset=dataset[script_args.dataset_train_split],  # type: ignore
         eval_dataset=(
-            dataset[script_args.dataset_test_split]
+            dataset[script_args.dataset_test_split]  # type: ignore
             if training_args.eval_strategy != "no"
             else None
         ),
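The trailing conditional follows the usual Trainer convention: when evaluation is disabled, pass eval_dataset=None rather than an unused split. A minimal sketch, assuming a plain dict of splits in place of the script's `dataset`:

from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="out", eval_strategy="no")

dataset = {"train": ["..."], "test": ["..."]}  # hypothetical splits

# Only hand the trainer an eval split when evaluation will actually run.
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
assert eval_dataset is None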