refactor: optimize code structure, improve readability and consistency, remove unnecessary files
This commit is contained in:
parent 0bc1034f35 · commit 70c446e548
src/.gitignore (vendored, 3 changes)
@@ -1,2 +1,3 @@
 checkpoint/*
-wandb/*
+wandb/*
+test.py
src/test.py (file deleted, 35 lines)
@@ -1,35 +0,0 @@
-# import sys
-
-# sys.path.insert(0, "transformers_repo/src/")
-# sys.path.insert(0, "peft_repo/src/")
-
-# from calflops import calculate_flops_hf
-
-# batch_size, max_seq_length = 1, 128
-# model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
-
-# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="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")
-# print("%s FLOPs:%s MACs:%s Params:%s \n" %(model_name, flops, macs, params))
-
-# Transformers Model, such as bert.
-from calflops import calculate_flops
-from transformers import AutoModel
-from transformers import AutoTokenizer
-
-batch_size, max_seq_length = 1, 128
-# Load model directly
-from transformers import AutoProcessor, AutoModelForImageTextToText
-
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-
-flops, macs, params = calculate_flops(
-    model=model,
-    input_shape=(batch_size, max_seq_length),
-    transformer_tokenizer=tokenizer,
-)
-print(
-    "Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s MACs:%s Params:%s \n"
-    % (flops, macs, params)
-)
-# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M
src/train.py (30 changes)
@@ -13,7 +13,11 @@ from trl import (
     TrlParser,
 )
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments
+from utils.args import (
+    ContinualScriptArguments,
+    ContinualModelConfig,
+    ContiunalRegularizationArguments,
+)
 import logging
 from typing import TYPE_CHECKING
 
@@ -23,19 +27,23 @@ logging.basicConfig(level=logging.INFO)
 
 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments)
+        (
+            ContinualScriptArguments,
+            TrainingArguments,
+            ContinualModelConfig,
+            ContiunalRegularizationArguments,
+        ) # type: ignore
     )
     script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
     # for type hint
-    if 1 == 0:
-        script_args = ContinualScriptArguments
-        training_args = TrainingArguments
-        model_args = ContinualModelConfig
-        reg_args = ContiunalRegularizationArguments
+    if TYPE_CHECKING:
+        script_args = ContinualScriptArguments()
+        training_args = TrainingArguments()
+        model_args = ContinualModelConfig()
+        reg_args = ContiunalRegularizationArguments()
 
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
 
     from model_library.factory import get_model
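The hunk above swaps the `if 1 == 0:` dead-code trick for typing.TYPE_CHECKING, which is False at runtime but treated as true by static type checkers, so the assignments give the loosely typed results of parser.parse_args_and_config() concrete types without ever executing (and the instances, rather than the classes the old code assigned, are what a checker needs). A self-contained sketch of the pattern, with illustrative names that are not from this repo:

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

@dataclass
class ScriptArguments:
    dataset_name: str = "demo"

def parse_config() -> Any:  # stand-in for parser.parse_args_and_config()
    return ScriptArguments()

args = parse_config()  # checker sees Any here
if TYPE_CHECKING:
    args = ScriptArguments()  # never runs, but narrows args for the checker
print(args.dataset_name)  # attribute access now type-checks and autocompletes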
@@ -85,9 +93,9 @@ if __name__ == "__main__":
         model=model,
         args=training_args,
         data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
+        train_dataset=dataset[script_args.dataset_train_split], # type: ignore
         eval_dataset=(
-            dataset[script_args.dataset_test_split]
+            dataset[script_args.dataset_test_split] # type: ignore
             if training_args.eval_strategy != "no"
             else None
         ),
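The two `# type: ignore` comments address a common typing wrinkle: assuming `dataset` comes from datasets.load_dataset (or similar), its static type is a broad union (Dataset | DatasetDict | IterableDataset | IterableDatasetDict), so indexing it with a split name fails type checks even when the runtime object really is a DatasetDict. A tiny sketch of the situation, with an in-memory dataset standing in for whatever the script loads:

from datasets import Dataset, DatasetDict

dataset = DatasetDict({"train": Dataset.from_dict({"text": ["a", "b"]})})
train_split = dataset["train"]  # fine at runtime; under load_dataset's union
                                # return type a checker flags it, hence the ignores
print(len(train_split))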