Refactor code structure for improved readability and maintainability

YunyaoZhou 2025-05-27 16:09:48 +08:00
parent 3fe2c85f6b
commit 0bc1034f35
13 changed files with 2226 additions and 947 deletions

.vscode/settings.json

@@ -1,14 +1,19 @@
 {
     "python.analysis.extraPaths": [
-        "src/transformers_repo/src/",
-        "src/peft_repo/src/"
+        "./src/peft_repo/src/",
+        "./src/transformers_repo/src/",
     ],
     "python.analysis.exclude": [
         "dataset/**/*",
     ],
     "python.languageServer": "Default",
     "python.terminal.activateEnvInCurrentTerminal": true,
-    "python.analysis.include": [
-        "src/**/*"
-    ]
+    // "python.analysis.include": [
+    //     "src/**/*"
+    // ],
+    "python.analysis.languageServerMode": "default",
+    "python.analysis.typeCheckingMode": "basic",
+    "python.analysis.userFileIndexingLimit": -1,
+    "python.analysis.usePullDiagnostics": false,
+    "python.analysis.importFormat": "relative"
 }

View File

@@ -2,6 +2,7 @@
 dependencies = [
     "absl-py>=2.1.0",
     "accelerate==1.2.1",
+    "calflops>=0.3.2",
     "datasets==3.2.0",
     "deepspeed==0.16.2",
     "evaluate==0.4.3",
@@ -20,9 +21,9 @@ dependencies = [
     "safetensors>=0.5.2",
     "setuptools>=70.0.0",
     "soundfile>=0.13.0",
-    "torch==2.5.1+cu124",
-    "torchaudio==2.5.1+cu124",
-    "torchvision==0.20.1+cu124",
+    "torch==2.6.0",
+    "torchaudio==2.6.0",
+    "torchvision==0.21.0",
     "transformers==4.48.0",
     "trl==0.13.0",
     "wandb>=0.19.4",

View File

@@ -11,7 +11,7 @@ machine_rank: 0
 main_training_function: main
 mixed_precision: 'bf16'
 num_machines: 1
-num_processes: 4
+num_processes: 3
 rdzv_backend: static
 same_network: true
 tpu_env: []

View File

@@ -89,4 +89,28 @@ def get_model(model_args: ContinualModelConfig):
         collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
         collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
 
+    if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
+        from transformers.models.qwen2_5_omni import (
+            Qwen2_5OmniThinkerForConditionalGeneration,
+            Qwen2_5OmniProcessor,
+        )
+
+        model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
+            **model_kwargs,
+        )
+        processor = Qwen2_5OmniProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
+            padding_side="left",
+        )
+
+        from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
+        from functools import partial
+
+        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
+        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
+
     return model, processor, collate_fn_for_train, collate_fn_for_evaluate

View File

@@ -1,7 +1,8 @@
 from transformers import Qwen2AudioProcessor
+from dataset_library.format import Conversation
 
 
-def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
+def collate_fn_for_train(examples: list[Conversation], processor: Qwen2AudioProcessor):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(example["chat"], tokenize=False)
@@ -60,7 +61,7 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
     return batch
 
 
-def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
+def collate_fn_for_evaluate(examples: list[Conversation], processor: Qwen2AudioProcessor):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(
@@ -91,3 +92,4 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
     # answers_ids torch.Size([3, 10])
     # answers_mask torch.Size([3, 10])
     return batch
+

View File

@@ -1,9 +1,22 @@
-from transformers import Qwen2VLProcessor
+# from transformers import Qwen2VLProcessor
+import sys
+
+sys.path.insert(0, "transformers_repo/src/")
+sys.path.insert(0, "peft_repo/src/")
+
+import transformers
+import peft
+
 from dataset_library.format import DatasetOutput
 import torch
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from transformers import Qwen2VLProcessor
+    from transformers import Qwen2_5_VLProcessor
 
 
-def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
+def collate_fn_for_train(examples: list[DatasetOutput], processor: "Qwen2VLProcessor"):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(example["chat"], tokenize=False)
@@ -32,36 +45,36 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
     # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
     # 151644 151645 8948 872 77091 None 151643
 
-    # for i, label in enumerate(labels):
-    #     now_index = 0
-    #     while now_index < len(label):
-    #         if label[now_index] == im_start_token_id:
-    #             label[now_index] = -100
-    #             now_index += 1
-    #             if (
-    #                 label[now_index] == system_token_id
-    #                 or label[now_index] == user_token_id
-    #             ):
-    #                 while label[now_index] != im_end_token_id:
-    #                     label[now_index] = -100
-    #                     now_index += 1
-    #                 label[now_index] = -100
-    #             elif label[now_index] == assistant_token_id:
-    #                 label[now_index] = -100
-    #                 label[now_index + 1] = -100
-    #                 now_index += 2
-    #                 while (
-    #                     now_index < len(label) and label[now_index] != im_end_token_id
-    #                 ):
-    #                     now_index += 1
-    #         now_index += 1
+    for i, label in enumerate(labels):
+        now_index = 0
+        while now_index < len(label):
+            if label[now_index] == im_start_token_id:
+                label[now_index] = -100
+                now_index += 1
+                if (
+                    label[now_index] == system_token_id
+                    or label[now_index] == user_token_id
+                ):
+                    while label[now_index] != im_end_token_id:
+                        label[now_index] = -100
+                        now_index += 1
+                    label[now_index] = -100
+                elif label[now_index] == assistant_token_id:
+                    label[now_index] = -100
+                    label[now_index + 1] = -100
+                    now_index += 2
+                    while (
+                        now_index < len(label) and label[now_index] != im_end_token_id
+                    ):
+                        now_index += 1
+            now_index += 1
 
     batch["labels"] = labels
     # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
     return batch
 
 
-def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
+def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(
@@ -89,3 +102,68 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
     # answers_ids torch.Size([3, 10])
     # answers_mask torch.Size([3, 10])
     return batch
+
+
+if __name__ == "__main__":
+    from transformers import AutoProcessor
+
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+
+    from PIL import Image
+    import numpy as np
+
+    # Generate a random image
+    random_image = Image.fromarray(
+        np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    )
+    example = {
+        "chat": [
+            # {"role": "user", "content": "What is the capital of France?"},
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image_url": "",  # Assuming no image for this example
+                    },
+                    {
+                        "type": "image",
+                        "image_url": "",  # Assuming no image for this example
+                    },
+                    {
+                        "type": "image",
+                        "image_url": "",  # Assuming no image for this example
+                    },
+                    {
+                        "type": "text",
+                        "text": "What is the capital of France?",
+                    },
+                ],
+            },  # Assuming no image for this example
+            {"role": "assistant", "content": "The capital of France is Paris."},
+        ],
+        "images": [
+            random_image,
+            random_image,
+            random_image,
+        ],  # Assuming no images for this example
+    }
+
+    batch = collate_fn_for_train([example], processor)
+    # print(batch)
+    for k, v in batch.items():
+        if isinstance(v, torch.Tensor):
+            print(f"{k}: {v.shape}")
+        else:
+            print(f"{k}: {v}")
+    # input_ids: torch.Size([1, 101])
+    # attention_mask: torch.Size([1, 101])
+    # pixel_values: torch.Size([256, 1176])
+    # image_grid_thw: torch.Size([1, 3])
+    # labels: torch.Size([1, 101])
+
+    # Load model directly
+    from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniProcessor
+
+    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-3B", torch_dtype="auto", device_map="auto")
+    processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")
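
The loop enabled in the hunk above changes what the collator supervises: every <|im_start|>system ... <|im_end|> and <|im_start|>user ... <|im_end|> span in labels is overwritten with -100, so cross-entropy is computed only on the assistant reply and its closing <|im_end|>. Below is a self-contained sketch of the same masking logic, using made-up token ids rather than the real Qwen vocabulary (the actual special ids are the ones noted in the comment above, e.g. 151644/151645 for <|im_start|>/<|im_end|>):

import torch

# Illustrative ids; in the real collator these come from processor.tokenizer.
IM_START, IM_END, SYSTEM, USER, ASSISTANT, NEWLINE = 1, 2, 3, 4, 5, 9

labels = torch.tensor([[IM_START, SYSTEM, 10, 11, IM_END,
                        IM_START, USER, 12, 13, IM_END,
                        IM_START, ASSISTANT, NEWLINE, 20, 21, IM_END]])

for label in labels:
    i = 0
    while i < len(label):
        if label[i] == IM_START:
            label[i] = -100
            i += 1
            if label[i] == SYSTEM or label[i] == USER:
                # Mask the whole system/user turn, including its <|im_end|>.
                while label[i] != IM_END:
                    label[i] = -100
                    i += 1
                label[i] = -100
            elif label[i] == ASSISTANT:
                # Mask only the role header ("assistant" + newline); keep the reply.
                label[i] = -100
                label[i + 1] = -100
                i += 2
                while i < len(label) and label[i] != IM_END:
                    i += 1
        i += 1

print(labels)
# Everything becomes -100 except the assistant reply tokens (20, 21) and its <|im_end|> (2).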

src/test.py (new file)

@@ -0,0 +1,35 @@
+# import sys
+# sys.path.insert(0, "transformers_repo/src/")
+# sys.path.insert(0, "peft_repo/src/")
+
+# from calflops import calculate_flops_hf
+
+# batch_size, max_seq_length = 1, 128
+# model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="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")
+# print("%s FLOPs:%s MACs:%s Params:%s \n" % (model_name, flops, macs, params))
+
+# Transformers Model, such as bert.
+from calflops import calculate_flops
+from transformers import AutoModel
+from transformers import AutoTokenizer
+
+batch_size, max_seq_length = 1, 128
+
+# Load model directly
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+
+flops, macs, params = calculate_flops(
+    model=model,
+    input_shape=(batch_size, max_seq_length),
+    # the processor's tokenizer is used to build dummy text inputs for profiling
+    transformer_tokenizer=processor.tokenizer,
+)
+print(
+    "Qwen/Qwen2.5-VL-3B-Instruct FLOPs:%s MACs:%s Params:%s \n"
+    % (flops, macs, params)
+)
+# Sample output format from the calflops Bert example:
+# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M

View File

@@ -13,7 +13,7 @@ from trl import (
     TrlParser,
 )
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments, ContinualModelConfig
+from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments
 import logging
 from typing import TYPE_CHECKING
 
@@ -23,18 +23,19 @@ logging.basicConfig(level=logging.INFO)
 
 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
+        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments)
     )
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
-    training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+    script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
     # for type hint
     if 1 == 0:
         script_args = ContinualScriptArguments
         training_args = TrainingArguments
         model_args = ContinualModelConfig
+        reg_args = ContiunalRegularizationArguments
+
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_args.remove_unused_columns = False
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
     from model_library.factory import get_model
 
@@ -91,6 +92,7 @@ if __name__ == "__main__":
             else None
         ),
         accelerator=accelerator,
+        reg_args=reg_args,
     )
     trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
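
TrlParser treats the new ContiunalRegularizationArguments like the other entries: it is simply a fourth dataclass in the tuple, and parse_args_and_config() then returns a fourth namespace in the same order. A minimal sketch of that pattern with stand-in dataclasses (field names chosen to mirror the ewc_enable/lwf_enable flags used in utils/trainer.py; defaults and everything else are illustrative):

from dataclasses import dataclass, field
from trl import TrlParser

@dataclass
class ScriptArguments:
    dataset_name: str = field(default="textvqa")

@dataclass
class RegularizationArguments:
    ewc_enable: bool = field(default=False)
    ewc_lambda: float = field(default=0.1)
    lwf_enable: bool = field(default=False)
    lwf_lambda: float = field(default=0.1)

# Each dataclass in the tuple becomes one parsed namespace, returned in order.
parser = TrlParser((ScriptArguments, RegularizationArguments))
script_args, reg_args = parser.parse_args_and_config()
print(script_args.dataset_name, reg_args.ewc_enable)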

View File

@@ -1,18 +1,19 @@
 #!/bin/bash
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
-    --dataset_name refcoco \
+    --dataset_name textvqa \
     --use_peft \
     --peft_type LORA \
-    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
+    --model_name_or_path Qwen/Qwen2.5-Omni-3B \
     --lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
     --lora_r 8 \
     --lora_alpha 32 \
-    --per_device_train_batch_size 1 \
+    --per_device_train_batch_size 3 \
     --per_device_eval_batch_size 1 \
     --gradient_accumulation_steps 4 \
+    --num_train_epochs 1 \
     --output_dir checkpoint/qwen2_alllinear/ \
-    --learning_rate 1e-4 \
+    --learning_rate 5e-5 \
     --bf16 \
     --torch_dtype bfloat16 \
     --logging_steps 30 \

@@ -1 +1 @@
-Subproject commit 684f12be1c8f26c46b1eebad50ce21ce6e3378b3
+Subproject commit c8a4ee5b9daf9865b372a483fd04a984f0b265dc

View File

@@ -7,6 +7,29 @@ from transformers import (
 )
 from .args import ContiunalRegularizationArguments
 from peft_library.regularizations import EWC, LWF
+from torch.nn import CrossEntropyLoss
+
+
+def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
+    logits = outputs.logits
+    device = logits.device
+    # Shift so that tokens < n predict n
+    shift_logits = logits[..., :-1, :]
+    shift_labels = labels[..., 1:].to(device)
+    # Save memory
+    masks = shift_labels != -100
+    shift_logits = shift_logits[masks]
+    shift_labels = shift_labels[masks]
+    # Flatten the tokens
+    loss_fct = CrossEntropyLoss(reduction="none")
+    loss = loss_fct(shift_logits, shift_labels)
+    if num_items_in_batch is None:
+        loss = loss.mean()
+    else:
+        # compat transformers>=4.46
+        loss = loss.sum() / num_items_in_batch
+    return loss
+
 
 class ContinualTrainer(Trainer):
@@ -18,17 +41,30 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        regularization_args: ContiunalRegularizationArguments = None,
+        reg_args: ContiunalRegularizationArguments = None,
     ):
         self.accelerator = accelerator
-        super().__init__(model, args, data_collator, train_dataset, eval_dataset)
+        super().__init__(
+            model,
+            args,
+            data_collator,
+            train_dataset,
+            eval_dataset,
+            # compute_loss_func=ce_loss_func,
+        )
 
-        # if regularization_args.ewc_enable:
-        #     self.ewc_lambda = regularization_args.ewc_lambda
-        #     # fisher = t
-        # if regularization_args.lwf_enable:
-        #     self.lwf_lambda = regularization_args.lwf_lambda
+        if reg_args.ewc_enable:
+            self.ewc_lambda = reg_args.ewc_lambda
+            from peft_library.regularizations.ewc import EWC
+            self.EWC = EWC()
+            # fisher = t
+        if reg_args.lwf_enable:
+            self.lwf_lambda = reg_args.lwf_lambda
+            from peft_library.regularizations.lwf import LWF
+            self.LWF = LWF()
 
     def create_accelerator_and_postprocess(self):
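
For context, a standalone sketch of what ce_loss_func computes: logits and labels are shifted by one position so token t predicts token t+1, -100 positions are dropped, and the per-token losses are either averaged or summed and divided by num_items_in_batch (the contract transformers >= 4.46 uses for a custom compute_loss_func under gradient accumulation). Shapes and values below are illustrative only:

import torch
from torch.nn import CrossEntropyLoss

# Toy batch: 2 sequences, length 5, vocabulary of 11 tokens.
logits = torch.randn(2, 5, 11)
labels = torch.tensor([
    [-100, -100, 4, 7, 2],
    [-100, 3, 5, -100, 2],
])

# Shift so position t predicts token t+1, then keep only supervised positions.
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
mask = shift_labels != -100
per_token = CrossEntropyLoss(reduction="none")(shift_logits[mask], shift_labels[mask])

# Plain per-token mean (what happens when num_items_in_batch is None) ...
print(per_token.mean())
# ... versus sum / num_items_in_batch. The two coincide here because the whole
# batch is processed in one step; they differ once gradient accumulation splits
# the batch and num_items_in_batch counts label tokens across all micro-batches.
num_items_in_batch = mask.sum()
print(per_token.sum() / num_items_in_batch)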

uv.lock (generated)

File diff suppressed because it is too large.