feat✨: 更新数据集名称为OCR-VQA-200K,优化训练脚本和损失计算逻辑
This commit is contained in:
@@ -107,7 +107,7 @@ class OCRVQADataset(Dataset):
|
||||
}
|
||||
|
||||
|
||||
class OCRVQADatasetForGeneration(Dataset):
|
||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
|
||||
@@ -6,7 +6,7 @@ def get_dataset(
|
||||
dataset_name, base_path="/home/zyy/research/accelerate/dataset"
|
||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||
if dataset_name == "OCR_VQA_200K":
|
||||
if dataset_name == "OCR-VQA-200K":
|
||||
import os.path as osp
|
||||
from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
|
||||
|
||||
@@ -1,5 +1,22 @@
|
||||
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import *
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
is_flash_attn_2_available,
|
||||
Qwen2VLAttention,
|
||||
Qwen2VLFlashAttention2,
|
||||
Qwen2VLDecoderLayer,
|
||||
Qwen2VLModel,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLConfig,
|
||||
logger,
|
||||
apply_rotary_pos_emb_vision,
|
||||
Cache,
|
||||
apply_multimodal_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
math,
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
Qwen2VisionTransformerPretrainedModel,
|
||||
)
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
@@ -72,6 +72,24 @@ if __name__ == "__main__":
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "VITA-MLLM/VITA-1.5":
|
||||
# from transformers import
|
||||
# from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
print(model)
|
||||
from model_library.qwen2vl import (
|
||||
collate_fn_for_train,
|
||||
collate_fn_for_evaluate,
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
@@ -104,6 +122,7 @@ if __name__ == "__main__":
|
||||
|
||||
for dataset_name in script_args.dataset_name:
|
||||
dataset = get_dataset(dataset_name)
|
||||
print(dataset)
|
||||
model.train()
|
||||
|
||||
trainer = ContinualTrainer(
|
||||
|
||||
+1
-1
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
||||
--dataset_name CHEM \
|
||||
--dataset_name OCR-VQA-200K \
|
||||
--use_peft \
|
||||
--peft_type MMOELORA \
|
||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||
|
||||
@@ -62,15 +62,12 @@ def create_accelerator_and_postprocess(args):
|
||||
"when using FSDP."
|
||||
)
|
||||
|
||||
def propagate_args_to_deepspeed(auto_find_batch_size=False):
|
||||
if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
|
||||
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
|
||||
|
||||
ds_plugin = accelerator.state.deepspeed_plugin
|
||||
|
||||
ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
|
||||
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
|
||||
ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size)
|
||||
|
||||
if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
|
||||
propagate_args_to_deepspeed()
|
||||
ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size=False)
|
||||
return accelerator
|
||||
|
||||
+59
-3
@@ -1,6 +1,11 @@
|
||||
# _________________________________________________________
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers.trainer import (
|
||||
Trainer,
|
||||
_is_peft_model,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
)
|
||||
|
||||
|
||||
class ContinualTrainer(Trainer):
|
||||
@@ -13,8 +18,7 @@ class ContinualTrainer(Trainer):
|
||||
def create_accelerator_and_postprocess(self):
|
||||
if self.accelerator is not None:
|
||||
self.is_deepspeed_enabled = (
|
||||
getattr(self.accelerator.state, "deepspeed_plugin", None)
|
||||
is not None
|
||||
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
)
|
||||
self.is_fsdp_enabled = (
|
||||
getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||
@@ -22,4 +26,56 @@ class ContinualTrainer(Trainer):
|
||||
self.gather_function = self.accelerator.gather_for_metrics
|
||||
return
|
||||
else:
|
||||
super().create_accelerator_and_postprocess()
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
def compute_loss(
    self, model, inputs, return_outputs=False, num_items_in_batch=None
):
    """Compute the training loss for one batch.

    The loss is taken from one of three places, in this order:

    1. ``self.compute_loss_func`` when it is set and the batch has labels.
    2. ``self.label_smoother`` when it is set and the batch has labels
       (labels are shifted for models listed in
       ``MODEL_FOR_CAUSAL_LM_MAPPING_NAMES``).
    3. The model's own output (``outputs["loss"]`` for dict-like outputs,
       ``outputs[0]`` otherwise).

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Mapping of keyword arguments for the model. When the loss
            is computed outside the model (cases 1 and 2), the ``"labels"``
            entry is popped from this mapping before the forward pass.
        return_outputs: When true, return ``(loss, outputs)`` instead of
            just ``loss``.
        num_items_in_batch: Optional value forwarded to the model as a
            keyword argument when ``self.model_accepts_loss_kwargs`` is
            truthy, and to ``self.compute_loss_func`` when that is used.

    Returns:
        The loss, or a ``(loss, outputs)`` tuple when ``return_outputs`` is
        true.

    Raises:
        ValueError: If the loss is expected from the model but its
            dict-like output has no ``"loss"`` key.
    """
    computes_loss_externally = (
        self.label_smoother is not None or self.compute_loss_func is not None
    )
    # Labels are withheld from the model only when the loss is computed here;
    # otherwise they stay in `inputs` so the model can produce the loss itself.
    if computes_loss_externally and "labels" in inputs:
        labels = inputs.pop("labels")
    else:
        labels = None

    if self.model_accepts_loss_kwargs and num_items_in_batch is not None:
        inputs = {**inputs, "num_items_in_batch": num_items_in_batch}

    outputs = model(**inputs)

    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    if labels is None:
        if isinstance(outputs, dict) and "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    elif self.compute_loss_func is not None:
        # User-defined compute_loss function
        loss = self.compute_loss_func(
            outputs, labels, num_items_in_batch=num_items_in_batch
        )
    else:
        # The model name decides whether the label smoother shifts the labels;
        # for a PEFT-wrapped model the name is read from the wrapped base model.
        unwrapped_model = self.accelerator.unwrap_model(model)
        if _is_peft_model(unwrapped_model):
            model_name = unwrapped_model.base_model.model._get_name()
        else:
            model_name = unwrapped_model._get_name()

        if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)

    return (loss, outputs) if return_outputs else loss
|
||||
|
||||
Reference in New Issue
Block a user