Refactor code structure for improved readability and maintainability
commit 0bc1034f35 (parent 3fe2c85f6b)
.vscode/settings.json (vendored, 15 lines changed)
@@ -1,14 +1,19 @@
 {
     "python.analysis.extraPaths": [
-        "src/transformers_repo/src/",
-        "src/peft_repo/src/"
+        "./src/peft_repo/src/",
+        "./src/transformers_repo/src/",
     ],
     "python.analysis.exclude": [
         "dataset/**/*",
     ],
     "python.languageServer": "Default",
     "python.terminal.activateEnvInCurrentTerminal": true,
-    "python.analysis.include": [
-        "src/**/*"
-    ]
+    // "python.analysis.include": [
+    //     "src/**/*"
+    // ],
+    "python.analysis.languageServerMode": "default",
+    "python.analysis.typeCheckingMode": "basic",
+    "python.analysis.userFileIndexingLimit": -1,
+    "python.analysis.usePullDiagnostics": false,
+    "python.analysis.importFormat": "relative"
 }
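
Note: the analyzer paths above mirror what the training code now does at runtime (see the qwen2vl.py hunk below). A minimal sketch, assuming the vendored submodules are checked out:

    # Make the vendored transformers/peft checkouts importable,
    # matching "python.analysis.extraPaths" above.
    import sys

    sys.path.insert(0, "transformers_repo/src/")
    sys.path.insert(0, "peft_repo/src/")
    import transformers  # now resolved from the vendored submodule
    import peft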
pyproject.toml (file name inferred from the dependency table)
@@ -2,6 +2,7 @@
 dependencies = [
     "absl-py>=2.1.0",
     "accelerate==1.2.1",
+    "calflops>=0.3.2",
     "datasets==3.2.0",
     "deepspeed==0.16.2",
     "evaluate==0.4.3",
@@ -20,9 +21,9 @@ dependencies = [
     "safetensors>=0.5.2",
     "setuptools>=70.0.0",
     "soundfile>=0.13.0",
-    "torch==2.5.1+cu124",
-    "torchaudio==2.5.1+cu124",
-    "torchvision==0.20.1+cu124",
+    "torch==2.6.0",
+    "torchaudio==2.6.0",
+    "torchvision==0.21.0",
     "transformers==4.48.0",
     "trl==0.13.0",
     "wandb>=0.19.4",
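The torch pins move from CUDA-tagged local wheels (2.5.1+cu124) to the plain 2.6.0 releases. A quick sanity check after reinstalling (sketch):

    import torch

    print(torch.__version__)         # expect 2.6.0
    print(torch.version.cuda)        # CUDA toolkit the wheel was built against
    print(torch.cuda.is_available())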
(accelerate launcher config, YAML)
@@ -11,7 +11,7 @@ machine_rank: 0
 main_training_function: main
 mixed_precision: 'bf16'
 num_machines: 1
-num_processes: 4
+num_processes: 3
 rdzv_backend: static
 same_network: true
 tpu_env: []
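num_processes drops from 4 to 3, so the run now uses three worker processes (typically one per GPU). A quick way to confirm what a launched run actually sees (sketch):

    from accelerate import Accelerator

    accelerator = Accelerator()
    print(accelerator.num_processes)  # expect 3 under this config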
src/model_library/factory.py (path inferred from the import in train.py)
@@ -89,4 +89,28 @@ def get_model(model_args: ContinualModelConfig):
         collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
         collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

+    if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
+        from transformers.models.qwen2_5_omni import (
+            Qwen2_5OmniThinkerForConditionalGeneration,
+            Qwen2_5OmniProcessor,
+        )
+
+        model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
+            **model_kwargs,
+        )
+
+        processor = Qwen2_5OmniProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
+            padding_side="left",
+        )
+
+        from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
+        from functools import partial
+
+        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
+        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
+
     return model, processor, collate_fn_for_train, collate_fn_for_evaluate
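A minimal sketch of how the factory is consumed, mirroring the call in src/train.py; it assumes ContinualModelConfig is a dataclass exposing model_name_or_path, as this diff suggests:

    from model_library.factory import get_model
    from utils.args import ContinualModelConfig

    model_args = ContinualModelConfig(model_name_or_path="Qwen/Qwen2.5-Omni-3B")
    model, processor, collate_train, collate_eval = get_model(model_args)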
src/model_library/qwen2audio.py (path inferred from the processor type)
@@ -1,7 +1,8 @@
 from transformers import Qwen2AudioProcessor
+from dataset_library.format import Conversation


-def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
+def collate_fn_for_train(examples: list[Conversation], processor: Qwen2AudioProcessor):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(example["chat"], tokenize=False)
@@ -60,7 +61,7 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
     return batch


-def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
+def collate_fn_for_evaluate(examples: list[Conversation], processor: Qwen2AudioProcessor):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(
@@ -91,3 +92,4 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
     # answers_ids torch.Size([3, 10])
     # answers_mask torch.Size([3, 10])
     return batch
+
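The new annotations only change the signatures; call sites keep binding the processor once via functools.partial, as factory.py does above. Sketch:

    from functools import partial

    collate = partial(collate_fn_for_train, processor=processor)
    batch = collate(examples)  # examples: list[Conversation]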
src/model_library/qwen2vl.py (path inferred from the import in factory.py)
@@ -1,9 +1,22 @@
-from transformers import Qwen2VLProcessor
+# from transformers import Qwen2VLProcessor
+import sys
+
+sys.path.insert(0, "transformers_repo/src/")
+sys.path.insert(0, "peft_repo/src/")
+import transformers
+import peft
 from dataset_library.format import DatasetOutput

 import torch

+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from transformers import Qwen2VLProcessor
+    from transformers import Qwen2_5_VLProcessor
+

-def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
+def collate_fn_for_train(examples: list[DatasetOutput], processor: "Qwen2VLProcessor"):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(example["chat"], tokenize=False)
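The quoted annotation plus the TYPE_CHECKING import is the standard way to keep a precise type hint without importing the stock class at runtime (here, the vendored checkout is imported instead). A generic sketch of the pattern:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by static analyzers, never at runtime
        from transformers import Qwen2VLProcessor

    def collate(processor: "Qwen2VLProcessor") -> None:
        ...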
@@ -32,36 +45,36 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
     # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
     # 151644 151645 8948 872 77091 None 151643

-    # for i, label in enumerate(labels):
-    #     now_index = 0
-    #     while now_index < len(label):
-    #         if label[now_index] == im_start_token_id:
-    #             label[now_index] = -100
-    #             now_index += 1
-    #             if (
-    #                 label[now_index] == system_token_id
-    #                 or label[now_index] == user_token_id
-    #             ):
-    #                 while label[now_index] != im_end_token_id:
-    #                     label[now_index] = -100
-    #                     now_index += 1
-    #                 label[now_index] = -100
-    #             elif label[now_index] == assistant_token_id:
-    #                 label[now_index] = -100
-    #                 label[now_index + 1] = -100
-    #                 now_index += 2
-    #                 while (
-    #                     now_index < len(label) and label[now_index] != im_end_token_id
-    #                 ):
-    #                     now_index += 1
-    #         now_index += 1
+    for i, label in enumerate(labels):
+        now_index = 0
+        while now_index < len(label):
+            if label[now_index] == im_start_token_id:
+                label[now_index] = -100
+                now_index += 1
+                if (
+                    label[now_index] == system_token_id
+                    or label[now_index] == user_token_id
+                ):
+                    while label[now_index] != im_end_token_id:
+                        label[now_index] = -100
+                        now_index += 1
+                    label[now_index] = -100
+                elif label[now_index] == assistant_token_id:
+                    label[now_index] = -100
+                    label[now_index + 1] = -100
+                    now_index += 2
+                    while (
+                        now_index < len(label) and label[now_index] != im_end_token_id
+                    ):
+                        now_index += 1
+            now_index += 1
     batch["labels"] = labels
     # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)

     return batch


-def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
+def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(
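Net effect of the now-active loop: in labels, every prompt token (system/user turns, role headers, and their <|im_start|> wrappers) becomes -100, so CrossEntropyLoss only scores the assistant's reply tokens and the <|im_end|> that closes the reply. A toy trace using the ids from the comment above (151644 = <|im_start|>, 151645 = <|im_end|>, 872 = user, 77091 = assistant; 198 stands in for the newline after the role header, and the text-token ids are made up):

    # before: <|im_start|> user "Hi" <|im_end|> <|im_start|> assistant \n "Paris" <|im_end|>
    label = [151644, 872, 13048, 151645, 151644, 77091, 198, 59604, 151645]
    # after the loop:
    # label = [-100, -100, -100, -100, -100, -100, -100, 59604, 151645]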
@@ -89,3 +102,68 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
     # answers_ids torch.Size([3, 10])
     # answers_mask torch.Size([3, 10])
     return batch
+
+
+if __name__ == "__main__":
+
+    from transformers import AutoProcessor
+
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+    from PIL import Image
+
+    # Generate a random image
+    import numpy as np
+
+    random_image = Image.fromarray(
+        np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    )
+    example = {
+        "chat": [
+            # {"role": "user", "content": "What is the capital of France?"},
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image_url": "",  # Assuming no image for this example
+                    },
+                    {
+                        "type": "image",
+                        "image_url": "",  # Assuming no image for this example
+                    },
+                    {
+                        "type": "image",
+                        "image_url": "",  # Assuming no image for this example
+                    },
+                    {
+                        "type": "text",
+                        "text": "What is the capital of France?",
+                    },
+                ],
+            },
+            {"role": "assistant", "content": "The capital of France is Paris."},
+        ],
+        "images": [
+            random_image,
+            random_image,
+            random_image,
+        ],
+    }
+    batch = collate_fn_for_train([example], processor)
+    # print(batch)
+    for k, v in batch.items():
+        if isinstance(v, torch.Tensor):
+            print(f"{k}: {v.shape}")
+        else:
+            print(f"{k}: {v}")
+    # input_ids: torch.Size([1, 101])
+    # attention_mask: torch.Size([1, 101])
+    # pixel_values: torch.Size([256, 1176])
+    # image_grid_thw: torch.Size([1, 3])
+    # labels: torch.Size([1, 101])
+    # Load model directly
+    from transformers.models.qwen2_5_omni import (
+        Qwen2_5OmniThinkerForConditionalGeneration,
+        Qwen2_5OmniProcessor,
+    )
+
+    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2.5-Omni-3B", torch_dtype="auto", device_map="auto"
+    )
+    processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")
src/test.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+# import sys
+
+# sys.path.insert(0, "transformers_repo/src/")
+# sys.path.insert(0, "peft_repo/src/")
+
+# from calflops import calculate_flops_hf
+
+# batch_size, max_seq_length = 1, 128
+# model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="hf_***")
+# print("%s FLOPs:%s MACs:%s Params:%s \n" % (model_name, flops, macs, params))
+
+# Transformers model FLOPs measurement via calflops.
+from calflops import calculate_flops
+from transformers import AutoTokenizer
+
+batch_size, max_seq_length = 1, 128
+# Load model directly
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+# calculate_flops needs a tokenizer to build its dummy text inputs
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+
+flops, macs, params = calculate_flops(
+    model=model,
+    input_shape=(batch_size, max_seq_length),
+    transformer_tokenizer=tokenizer,
+)
+print(
+    "Qwen/Qwen2.5-VL-3B-Instruct FLOPs:%s MACs:%s Params:%s \n"
+    % (flops, macs, params)
+)
+# calflops README reference output for BERT:
+# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M
src/train.py (16 lines changed)
@@ -13,7 +13,7 @@ from trl import (
     TrlParser,
 )
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments, ContinualModelConfig
+from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments
 import logging
 from typing import TYPE_CHECKING

@@ -23,18 +23,19 @@ logging.basicConfig(level=logging.INFO)

 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
+        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments)
     )
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
-    training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-
+    script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
+    # for type hint
+    if 1 == 0:
+        script_args = ContinualScriptArguments
+        training_args = TrainingArguments
+        model_args = ContinualModelConfig
+        reg_args = ContiunalRegularizationArguments
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_args.remove_unused_columns = False
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

     from model_library.factory import get_model

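The "if 1 == 0:" block exists only to give static analyzers concrete types for the values TrlParser returns; the branch never runs. A more conventional equivalent (sketch) is typing.cast, which is likewise a no-op at runtime:

    from typing import cast

    script_args = cast(ContinualScriptArguments, script_args)
    training_args = cast(TrainingArguments, training_args)
    model_args = cast(ContinualModelConfig, model_args)
    reg_args = cast(ContiunalRegularizationArguments, reg_args)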
@@ -91,6 +92,7 @@ if __name__ == "__main__":
             else None
         ),
         accelerator=accelerator,
+        reg_args=reg_args,
     )
     trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

(training launch script, bash)
@@ -1,18 +1,19 @@
 #!/bin/bash

 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
-    --dataset_name refcoco \
+    --dataset_name textvqa \
     --use_peft \
     --peft_type LORA \
-    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
+    --model_name_or_path Qwen/Qwen2.5-Omni-3B \
     --lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
     --lora_r 8 \
     --lora_alpha 32 \
-    --per_device_train_batch_size 1 \
+    --per_device_train_batch_size 3 \
     --per_device_eval_batch_size 1 \
     --gradient_accumulation_steps 4 \
     --num_train_epochs 1 \
     --output_dir checkpoint/qwen2_alllinear/ \
-    --learning_rate 1e-4 \
+    --learning_rate 5e-5 \
     --bf16 \
     --torch_dtype bfloat16 \
     --logging_steps 30 \
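After shell unescaping, --lora_target_modules reaches the parser roughly as ".*proj.*|.*fc.*|.*mlp.0|.*mlp.2" (the backslashes only protect the characters from the shell). PEFT treats a string target_modules as a regex checked with re.fullmatch against each module path. A quick check of which names it selects (sketch; module names are assumed examples):

    import re

    pattern = r".*proj.*|.*fc.*|.*mlp.0|.*mlp.2"
    for name in ["q_proj", "up_proj", "fc1", "mlp.0", "mlp.2", "layer_norm"]:
        print(name, bool(re.fullmatch(pattern, name)))
    # q_proj/up_proj/fc1/mlp.0/mlp.2 match; layer_norm does not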
(git submodule pointer update)
@@ -1 +1 @@
-Subproject commit 684f12be1c8f26c46b1eebad50ce21ce6e3378b3
+Subproject commit c8a4ee5b9daf9865b372a483fd04a984f0b265dc
src/utils/trainer.py (path inferred from the import in train.py)
@@ -7,6 +7,29 @@ from transformers import (
 )
 from .args import ContiunalRegularizationArguments
+from peft_library.regularizations import EWC, LWF
+from torch.nn import CrossEntropyLoss
+
+
+def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
+    logits = outputs.logits
+    device = logits.device
+    # Shift so that tokens < n predict n
+    shift_logits = logits[..., :-1, :]
+    shift_labels = labels[..., 1:].to(device)
+    # Save memory
+    masks = shift_labels != -100
+    shift_logits = shift_logits[masks]
+    shift_labels = shift_labels[masks]
+    # Flatten the tokens
+    loss_fct = CrossEntropyLoss(reduction="none")
+    loss = loss_fct(shift_logits, shift_labels)
+
+    if num_items_in_batch is None:
+        loss = loss.mean()
+    else:
+        # compat transformers>=4.46
+        loss = loss.sum() / num_items_in_batch
+    return loss


 class ContinualTrainer(Trainer):
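ce_loss_func is the usual causal-LM shift-and-mask loss, but computed with reduction="none" so it can divide by num_items_in_batch, the token-count convention transformers uses since 4.46. It is meant to plug into Trainer's compute_loss_func hook, which this commit leaves commented out in the super().__init__ call below; enabling it on a plain Trainer would look like (sketch):

    from transformers import Trainer

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        compute_loss_func=ce_loss_func,  # transformers >= 4.46
    )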
@@ -18,17 +41,30 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        regularization_args: ContiunalRegularizationArguments = None,
+        reg_args: ContiunalRegularizationArguments = None,
     ):
         self.accelerator = accelerator
-        super().__init__(model, args, data_collator, train_dataset, eval_dataset)
+        super().__init__(
+            model,
+            args,
+            data_collator,
+            train_dataset,
+            eval_dataset,
+            # compute_loss_func=ce_loss_func,
+        )

-        # if regularization_args.ewc_enable:
-        #     self.ewc_lambda = regularization_args.ewc_lambda
-        #     # fisher = t
+        if reg_args.ewc_enable:
+            self.ewc_lambda = reg_args.ewc_lambda
+            from peft_library.regularizations.ewc import EWC
+
+            self.EWC = EWC()
+            # fisher = t

-        # if regularization_args.lwf_enable:
-        #     self.lwf_lambda = regularization_args.lwf_lambda
+        if reg_args.lwf_enable:
+            self.lwf_lambda = reg_args.lwf_lambda
+            from peft_library.regularizations.lwf import LWF
+
+            self.LWF = LWF()

     def create_accelerator_and_postprocess(self):
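With this change the EWC/LWF regularizers become opt-in through ContiunalRegularizationArguments. A construction sketch, using the field names visible above (defaults and remaining fields are not shown in this diff and are assumed):

    reg_args = ContiunalRegularizationArguments(
        ewc_enable=True,
        ewc_lambda=0.1,  # assumed value for illustration
        lwf_enable=False,
    )
    trainer = ContinualTrainer(
        model, training_args, data_collator,
        train_dataset, eval_dataset,
        accelerator=accelerator, reg_args=reg_args,
    )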