feat: 更新数据集名称为OCR-VQA-200K,优化训练脚本和损失计算逻辑

This commit is contained in:
2025-01-11 23:59:03 +08:00
parent 90b3181f3f
commit e4c4a7b0a0
9 changed files with 224 additions and 113 deletions
+1 -1
View File
@@ -107,7 +107,7 @@ class OCRVQADataset(Dataset):
}
class OCRVQADatasetForGeneration(Dataset):
class OCRVQADatasetForGeneration(OCRVQADataset):
def __getitem__(self, index):
sample = self.data[index]
+1 -1
View File
@@ -6,7 +6,7 @@ def get_dataset(
dataset_name, base_path="/home/zyy/research/accelerate/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
if dataset_name == "OCR_VQA_200K":
if dataset_name == "OCR-VQA-200K":
import os.path as osp
from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
+18 -1
View File
@@ -1,5 +1,22 @@
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
from transformers.models.qwen2_vl.modeling_qwen2_vl import *
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
is_flash_attn_2_available,
Qwen2VLAttention,
Qwen2VLFlashAttention2,
Qwen2VLDecoderLayer,
Qwen2VLModel,
Qwen2VLForConditionalGeneration,
Qwen2VLConfig,
logger,
apply_rotary_pos_emb_vision,
Cache,
apply_multimodal_rotary_pos_emb,
repeat_kv,
is_flash_attn_greater_or_equal_2_10,
math,
Qwen2VLCausalLMOutputWithPast,
Qwen2VisionTransformerPretrainedModel,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
+19
View File
@@ -72,6 +72,24 @@ if __name__ == "__main__":
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
if model_args.model_name_or_path == "VITA-MLLM/VITA-1.5":
# from transformers import
# from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
print(model)
from model_library.qwen2vl import (
collate_fn_for_train,
collate_fn_for_evaluate,
)
from functools import partial
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
################
# Dataset
################
@@ -104,6 +122,7 @@ if __name__ == "__main__":
for dataset_name in script_args.dataset_name:
dataset = get_dataset(dataset_name)
print(dataset)
model.train()
trainer = ContinualTrainer(
+1 -1
View File
@@ -1,7 +1,7 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
--dataset_name CHEM \
--dataset_name OCR-VQA-200K \
--use_peft \
--peft_type MMOELORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
+2 -5
View File
@@ -62,15 +62,12 @@ def create_accelerator_and_postprocess(args):
"when using FSDP."
)
def propagate_args_to_deepspeed(auto_find_batch_size=False):
if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
ds_plugin = accelerator.state.deepspeed_plugin
ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size)
if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
propagate_args_to_deepspeed()
ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size=False)
return accelerator
+59 -3
View File
@@ -1,6 +1,11 @@
# _________________________________________________________
from transformers import Trainer
from transformers.trainer import (
Trainer,
_is_peft_model,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
class ContinualTrainer(Trainer):
@@ -13,8 +18,7 @@ class ContinualTrainer(Trainer):
def create_accelerator_and_postprocess(self):
if self.accelerator is not None:
self.is_deepspeed_enabled = (
getattr(self.accelerator.state, "deepspeed_plugin", None)
is not None
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
)
self.is_fsdp_enabled = (
getattr(self.accelerator.state, "fsdp_plugin", None) is not None
@@ -22,4 +26,56 @@ class ContinualTrainer(Trainer):
self.gather_function = self.accelerator.gather_for_metrics
return
else:
super().create_accelerator_and_postprocess()
super().create_accelerator_and_postprocess()
def compute_loss(
    self, model, inputs, return_outputs=False, num_items_in_batch=None
):
    """
    Compute the loss for one batch of ``inputs`` using ``model``.

    When a label smoother or a user-defined ``compute_loss_func`` is
    configured and ``inputs`` contains ``"labels"``, the labels are popped
    from ``inputs`` (mutating the mapping passed in) and the loss is
    computed here from the model outputs. Otherwise the loss is read from
    the model outputs themselves.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Mapping of keyword arguments for ``model``.
        return_outputs: If true, return ``(loss, outputs)`` instead of
            only ``loss``.
        num_items_in_batch: Optional value forwarded to the model (when
            ``self.model_accepts_loss_kwargs`` is truthy) and to
            ``self.compute_loss_func``.

    Returns:
        The loss, or a ``(loss, outputs)`` tuple when ``return_outputs``
        is true.

    Raises:
        ValueError: If the loss must come from the model but the model
            returned a dict without a ``"loss"`` key.
    """
    # Only strip the labels when the loss is going to be computed here;
    # otherwise they stay in `inputs` and are passed on to the model.
    if (
        self.label_smoother is not None or self.compute_loss_func is not None
    ) and "labels" in inputs:
        labels = inputs.pop("labels")
    else:
        labels = None

    if self.model_accepts_loss_kwargs:
        loss_kwargs = {}
        if num_items_in_batch is not None:
            loss_kwargs["num_items_in_batch"] = num_items_in_batch
        # Build a new dict rather than mutating the caller's mapping.
        inputs = {**inputs, **loss_kwargs}
    outputs = model(**inputs)

    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    # NOTE: a debug `print(labels)` used to live here; it wrote to stdout on
    # every training step (usually just "None") and has been removed.
    if labels is not None:
        unwrapped_model = self.accelerator.unwrap_model(model)
        # For PEFT-wrapped models the name is taken from the wrapped base model.
        if _is_peft_model(unwrapped_model):
            model_name = unwrapped_model.base_model.model._get_name()
        else:
            model_name = unwrapped_model._get_name()
        # User-defined compute_loss function
        if self.compute_loss_func is not None:
            loss = self.compute_loss_func(
                outputs, labels, num_items_in_batch=num_items_in_batch
            )
        elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            # Models listed in the causal-LM mapping get shift_labels=True.
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)
    else:
        if isinstance(outputs, dict) and "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    return (loss, outputs) if return_outputs else loss