refactor: optimize code structure, improve readability and consistency, remove unnecessary files
This commit is contained in:
parent 0bc1034f35
commit 70c446e548
src/.gitignore (vendored)
@@ -1,2 +1,3 @@
 checkpoint/*
 wandb/*
+test.py
src/test.py (deleted)
@@ -1,35 +0,0 @@
-# import sys
-
-# sys.path.insert(0, "transformers_repo/src/")
-# sys.path.insert(0, "peft_repo/src/")
-
-# from calflops import calculate_flops_hf
-
-# batch_size, max_seq_length = 1, 128
-# model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
-
-# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="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")
-# print("%s FLOPs:%s MACs:%s Params:%s \n" %(model_name, flops, macs, params))
-
-# Transformers Model, such as bert.
-from calflops import calculate_flops
-from transformers import AutoModel
-from transformers import AutoTokenizer
-
-batch_size, max_seq_length = 1, 128
-# Load model directly
-from transformers import AutoProcessor, AutoModelForImageTextToText
-
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
-
-flops, macs, params = calculate_flops(
-    model=model,
-    input_shape=(batch_size, max_seq_length),
-    transformer_tokenizer=tokenizer,
-)
-print(
-    "Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s MACs:%s Params:%s \n"
-    % (flops, macs, params)
-)
-# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M
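Note that the removed script passes transformer_tokenizer=tokenizer to calculate_flops without ever creating tokenizer (it only imports AutoTokenizer and builds a processor), so it would fail with a NameError as written, and its print label still refers to the BERT example copied from the calflops README. The following is a minimal, self-contained sketch of what the measurement appears to have intended, keeping the calculate_flops call, the model name, and the shapes from the removed file; whether calflops can profile this vision-language model from a text-only dummy batch is an assumption, not something verified here.

# Hedged reconstruction of the deleted FLOPs probe, not part of the commit.
from calflops import calculate_flops
from transformers import AutoModelForImageTextToText, AutoTokenizer

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"  # taken from the removed script
batch_size, max_seq_length = 1, 128         # taken from the removed script

# The missing piece in the original: actually instantiate the tokenizer that
# calculate_flops uses to build its dummy (batch_size, max_seq_length) batch.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(model_name)

flops, macs, params = calculate_flops(
    model=model,
    input_shape=(batch_size, max_seq_length),
    transformer_tokenizer=tokenizer,
)
print("%s FLOPs:%s MACs:%s Params:%s" % (model_name, flops, macs, params))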
src/train.py
@@ -13,7 +13,11 @@ from trl import (
     TrlParser,
 )
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments
+from utils.args import (
+    ContinualScriptArguments,
+    ContinualModelConfig,
+    ContiunalRegularizationArguments,
+)
 import logging
 from typing import TYPE_CHECKING
 
@@ -23,19 +27,23 @@ logging.basicConfig(level=logging.INFO)
 
 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments)
+        (
+            ContinualScriptArguments,
+            TrainingArguments,
+            ContinualModelConfig,
+            ContiunalRegularizationArguments,
+        ) # type: ignore
     )
     script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
     # for type hint
-    if 1 == 0:
-        script_args = ContinualScriptArguments
-        training_args = TrainingArguments
-        model_args = ContinualModelConfig
-        reg_args = ContiunalRegularizationArguments
+    if TYPE_CHECKING:
+        script_args = ContinualScriptArguments()
+        training_args = TrainingArguments()
+        model_args = ContinualModelConfig()
+        reg_args = ContiunalRegularizationArguments()
 
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
 
     from model_library.factory import get_model
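The hunk above replaces the `if 1 == 0:` guard with `if TYPE_CHECKING:` and assigns instances (ContinualScriptArguments() and so on) instead of the classes themselves. typing.TYPE_CHECKING is False at runtime, so the block never executes, while static type checkers treat it as taken and use the assignments to give script_args, training_args, model_args, and reg_args concrete types. A minimal sketch of the same pattern, using hypothetical names (TrainCfg, parse_cfg) rather than anything from this repository:

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any


@dataclass
class TrainCfg:
    learning_rate: float = 1e-4


def parse_cfg() -> Any:
    # Stand-in for a parser that returns loosely typed objects,
    # similar in spirit to parser.parse_args_and_config() above.
    return TrainCfg()


cfg = parse_cfg()  # statically typed as Any at this point

if TYPE_CHECKING:
    # Never executed at runtime (TYPE_CHECKING is False), but type checkers
    # read this branch, so `cfg` is seen as a TrainCfg instance and attribute
    # access below is checked and autocompleted; clearer in intent than the
    # previous `if 1 == 0:` trick.
    cfg = TrainCfg()

print(cfg.learning_rate)  # prints 0.0001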
@@ -85,9 +93,9 @@ if __name__ == "__main__":
         model=model,
         args=training_args,
         data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
+        train_dataset=dataset[script_args.dataset_train_split], # type: ignore
         eval_dataset=(
-            dataset[script_args.dataset_test_split]
+            dataset[script_args.dataset_test_split] # type: ignore
             if training_args.eval_strategy != "no"
             else None
         ),