FIX: 调试 _maybe_include_all_linear_layers 函数,添加打印线性模块名称

This commit is contained in:
2025-05-15 15:35:11 +08:00
parent bcb0494f52
commit da99ec4564
16 changed files with 764 additions and 185 deletions
+37 -44
View File
@@ -1,5 +1,6 @@
from PIL import Image
from torch.utils.data import Dataset
from .format import Conversation, ConverstationText, ConverstationImage, DatasetOutput
import json
import os
@@ -25,9 +26,9 @@ class ChemDataset(Dataset):
def _vis_processor(self, image: Image.Image):
width, height = image.size
if width > 800 or height > 800:
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 800 / max_size
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
@@ -85,31 +86,30 @@ class ChemDataset(Dataset):
answer = self.text_processor(answer)
chat = [
{
"role": "system",
"content": [
{
"type": "text",
"text": sample["system"],
},
Conversation(
type="system",
content=[ConverstationText(type="text", text=sample["system"])],
),
Conversation(
type="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text", text=f"[vqa] {question.replace('<image>','')}"
),
],
},
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": f"[vqa] {question.replace('<image>','')}",
},
],
},
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
),
Conversation(
type="assistant", content=[ConverstationText(type="text", text=answer)]
),
]
return {
"image": image,
"chat": chat,
}
return DatasetOutput(
images=[image],
chat=chat,
answer=answer,
original=sample["original"],
)
class ChemDatasetForGeneration(ChemDataset):
@@ -127,27 +127,20 @@ class ChemDatasetForGeneration(ChemDataset):
answer = self.text_processor(answer)
chat = [
{
"role": "system",
"content": [
{
"type": "text",
"text": sample["system"],
},
Conversation(
type="system",
content=[ConverstationText(type="text", text=sample["system"])],
),
Conversation(
type="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text", text=f"[vqa] {question.replace('<image>','')}"
),
],
},
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": f"[vqa] {question.replace('<image>','')}",
},
],
},
),
]
from .format import DatasetOutput
return DatasetOutput(
images=[image],
+2 -2
View File
@@ -51,7 +51,7 @@ class GigaspeechDataset(Dataset):
]
return DatasetOutput(
audio=[(audio, sampling_rate)],
audios=[(audio, sampling_rate)],
chat=chat,
original=sample,
)
@@ -83,7 +83,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
]
return DatasetOutput(
audio=[(audio, sampling_rate)],
audios=[(audio, sampling_rate)],
chat=chat,
answer=text,
original=sample,
+1
View File
@@ -0,0 +1 @@
from .factory import get_dataset
+78 -73
View File
@@ -7,84 +7,89 @@ def get_dataset(
dataset_name, base_path="/home/zyy/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
if dataset_name == "ocrvqa200k":
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
match dataset_name:
case "ocrvqa200k":
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
dataset = {
"train": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="train",
),
"test": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
"generation": OCRVQADatasetForGeneration(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
}
if dataset_name == "chem":
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
dataset = {
"train": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="train",
),
"test": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
"generation": OCRVQADatasetForGeneration(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
}
case "chem":
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
dataset = {
"train": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="train",
),
"test": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
"generation": ChemDatasetForGeneration(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
}
dataset = {
"train": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="train",
),
"test": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
"generation": ChemDatasetForGeneration(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
}
case "gigaspeech":
from .GigaspeechDataset import (
GigaspeechDataset,
GigaspeechDatasetForGeneration,
)
if dataset_name == "gigaspeech":
from .GigaspeechDataset import GigaspeechDataset, GigaspeechDatasetForGeneration
dataset = {
"train": GigaspeechDataset(split="train"),
"test": GigaspeechDataset(split="test"),
"generation": GigaspeechDatasetForGeneration(split="test"),
}
dataset = {
"train": GigaspeechDataset(split="train"),
"test": GigaspeechDataset(split="test"),
"generation": GigaspeechDatasetForGeneration(split="test"),
}
case "textvqa":
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
if dataset_name == "textvqa":
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
dataset = {
"train": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="train",
),
"test": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
"generation": TextVQADatasetForGeneration(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
}
case "scienceqa":
from .ScienceQADataset import (
ScienceQADataset,
ScienceQADatasetForGeneration,
)
dataset = {
"train": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="train",
),
"test": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
"generation": TextVQADatasetForGeneration(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
}
if dataset_name == "scienceqa":
from .ScienceQADataset import ScienceQADataset, ScienceQADatasetForGeneration
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
return dataset
+3 -3
View File
@@ -88,8 +88,7 @@ if __name__ == "__main__":
accelerator = create_accelerator_and_postprocess(training_args)
if accelerator.is_local_main_process:
print(model)
accelerator.print(model)
for dataset_name in script_args.dataset_name:
dataset = get_dataset(dataset_name)
@@ -100,10 +99,11 @@ if __name__ == "__main__":
from torch.utils.data import DataLoader
bs = 3 if dataset_name not in ["scienceqa"] else 1
accelerator.print("Batch size:", bs)
val_dataloader = DataLoader(
dataset[script_args.dataset_generation_split],
batch_size=1,
batch_size=bs,
collate_fn=collate_fn_for_evaluate,
)
val_dataloader = accelerator.prepare_data_loader(val_dataloader)
+1 -1
View File
@@ -1,7 +1,7 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
--dataset_name scienceqa \
--dataset_name chem \
--use_peft \
--peft_type MMOELORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
+5 -1
View File
@@ -1,4 +1,8 @@
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
from .model import Qwen2VLForConditionalGeneration_modified
__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]
__all__ = [
"collate_fn_for_train",
"collate_fn_for_evaluate",
"Qwen2VLForConditionalGeneration_modified",
]
+5 -1
View File
@@ -1,4 +1,8 @@
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
from .model import Qwen2VLForConditionalGeneration_modified
__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]
__all__ = [
"collate_fn_for_train",
"collate_fn_for_evaluate",
"Qwen2VLForConditionalGeneration_modified",
]
+23 -23
View File
@@ -32,29 +32,29 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
# print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
# 151644 151645 8948 872 77091 None 151643
for i, label in enumerate(labels):
now_index = 0
while now_index < len(label):
if label[now_index] == im_start_token_id:
label[now_index] = -100
now_index += 1
if (
label[now_index] == system_token_id
or label[now_index] == user_token_id
):
while label[now_index] != im_end_token_id:
label[now_index] = -100
now_index += 1
label[now_index] = -100
elif label[now_index] == assistant_token_id:
label[now_index] = -100
label[now_index + 1] = -100
now_index += 2
while (
now_index < len(label) and label[now_index] != im_end_token_id
):
now_index += 1
now_index += 1
# for i, label in enumerate(labels):
# now_index = 0
# while now_index < len(label):
# if label[now_index] == im_start_token_id:
# label[now_index] = -100
# now_index += 1
# if (
# label[now_index] == system_token_id
# or label[now_index] == user_token_id
# ):
# while label[now_index] != im_end_token_id:
# label[now_index] = -100
# now_index += 1
# label[now_index] = -100
# elif label[now_index] == assistant_token_id:
# label[now_index] = -100
# label[now_index + 1] = -100
# now_index += 2
# while (
# now_index < len(label) and label[now_index] != im_end_token_id
# ):
# now_index += 1
# now_index += 1
batch["labels"] = labels
# batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
+17 -12
View File
@@ -56,7 +56,12 @@ if __name__ == "__main__":
elif model_args.peft_type == "LORA":
from peft.tuners.lora import LoraConfig
peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
peft_config = LoraConfig(
target_modules=model_args.lora_target_modules,
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
)
model.add_adapter(peft_config)
@@ -82,7 +87,7 @@ if __name__ == "__main__":
),
accelerator=accelerator,
)
trainer.train()
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
if accelerator.is_local_main_process:
print("Saving model")
@@ -93,17 +98,17 @@ if __name__ == "__main__":
# Synchronize all accelerator processes
accelerator.wait_for_everyone()
model.eval()
# model.eval()
from torch.utils.data import DataLoader
# from torch.utils.data import DataLoader
val_dataloader = DataLoader(
dataset[script_args.dataset_generation_split],
batch_size=3,
collate_fn=collate_fn_for_evaluate,
)
val_dataloader = accelerator.prepare(val_dataloader)
# val_dataloader = DataLoader(
# dataset[script_args.dataset_generation_split],
# batch_size=3,
# collate_fn=collate_fn_for_evaluate,
# )
# val_dataloader = accelerator.prepare(val_dataloader)
from utils.evaluate_tool import evaluate_rouge
# from utils.evaluate_tool import evaluate_rouge
evaluate_rouge(model, val_dataloader, processor)
# evaluate_rouge(model, val_dataloader, processor)
+10 -5
View File
@@ -1,15 +1,20 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
--dataset_name scienceqa \
--dataset_name chem \
--use_peft \
--peft_type LORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
--lora_target_modules q_proj v_proj \
--per_device_train_batch_size 1 \
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
--lora_r 8 \
--lora_alpha 32 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 4 \
--output_dir checkpoint/qwen2mmoe/ \
--output_dir checkpoint/qwen2_alllinear/ \
--learning_rate 1e-4 \
--bf16 \
--torch_dtype bfloat16 \
--logging_steps 30
--logging_steps 30 \
--gradient_checkpointing \
--weight_decay 0.1
+5 -4
View File
@@ -1,7 +1,10 @@
from accelerate import Accelerator, DataLoaderConfiguration
from transformers import (
TrainingArguments,
)
def create_accelerator_and_postprocess(args):
def create_accelerator_and_postprocess(args: TrainingArguments):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {}
if args.accelerator_config.gradient_accumulation_kwargs is not None:
@@ -33,9 +36,7 @@ def create_accelerator_and_postprocess(args):
# this would have been updated above, no need for it anymore
accelerator_config.pop("gradient_accumulation_kwargs")
accelerator_args = {
"deepspeed_plugin": args.deepspeed_plugin,
}
accelerator_args = {"deepspeed_plugin": args.deepspeed_plugin, "log_with": "wandb"}
accelerator_args["dataloader_config"] = dataloader_config
# create accelerator object
accelerator = Accelerator(**accelerator_args)
+4 -2
View File
@@ -38,8 +38,9 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
mtime = time
# Get the directory's last-modified time
if not os.path.exists(f"results/{mtime}"):
os.makedirs(f"results/{mtime}")
if accelerator.is_local_main_process:
if not os.path.exists(f"results/{mtime}"):
os.makedirs(f"results/{mtime}")
from tqdm import tqdm
@@ -77,6 +78,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
for answer in answers:
f.write(json.dumps(answer) + "\n")
if accelerator.is_local_main_process:
bar.update(1)
accelerator.wait_for_everyone()
+9
View File
@@ -1,6 +1,15 @@
# _________________________________________________________
from transformers.trainer import (
Trainer,
_is_peft_model,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
tpu_spmd_dataloader,
logger,
has_length,
sys,
)
from transformers.trainer import *