Add an OCRVQADataset class and an evaluation script; update the training scripts to support the new dataset and evaluation strategy.

commit d6b4ec79ad (parent 09734720b0)
src/collate_fn_library/__init__.py (new file, 0 lines)

src/collate_fn_library/qwen2.py (new file, 125 lines)
@@ -0,0 +1,125 @@
from transformers import Qwen2VLProcessor


def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
    # Get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(example["chat"], tokenize=False)
        for example in examples
    ]
    # print(texts)
    images = [example["image"] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  #
    # Ignore the image token index in the loss computation (model specific)
    # Mask the <|im_start|>system ... <|im_end|>\n spans (and likewise the user turns)
    im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
    system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
    user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
    assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
    # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
    # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
    # 151644 151645 8948 872 77091 None 151643

    for i, label in enumerate(labels):
        now_index = 0
        while now_index < len(label):
            if label[now_index] == im_start_token_id:
                label[now_index] = -100
                now_index += 1
                if (
                    label[now_index] == system_token_id
                    or label[now_index] == user_token_id
                ):
                    while label[now_index] != im_end_token_id:
                        label[now_index] = -100
                        now_index += 1
                    label[now_index] = -100
                elif label[now_index] == assistant_token_id:
                    label[now_index] = -100
                    label[now_index + 1] = -100
                    now_index += 2
                    while (
                        now_index < len(label) and label[now_index] != im_end_token_id
                    ):
                        now_index += 1
            now_index += 1
    batch["labels"] = labels

    return batch


def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
    # Get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(
            example["chat"], tokenize=False, add_generation_prompt=True
        )
        for example in examples
    ]
    # print(texts)
    images = [example["image"] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    answers = [example["answer"] for example in examples]
    answers = processor(text=answers, return_tensors="pt", padding=True)
    batch["answers_ids"] = answers["input_ids"]
    batch["answers_mask"] = answers["attention_mask"]
    # input_ids torch.Size([3, 370])
    # attention_mask torch.Size([3, 370])
    # pixel_values torch.Size([3888, 1176])
    # image_grid_thw torch.Size([3, 3])
    # answers_ids torch.Size([3, 10])
    # answers_mask torch.Size([3, 10])
    return batch


if __name__ == "__main__":
    from transformers import Qwen2VLProcessor
    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration

    processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    dataset = OCRVQADatasetForGeneration(
        vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
        ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
        split="train",
    )
    examples = [dataset[i] for i in range(3)]
    # print(collate_fn_for_evaluate(examples, processor))
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    from accelerate import Accelerator
    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset,
        batch_size=3,
        collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
    )
    accelerator = Accelerator()
    model = accelerator.prepare(model)
    val_dataloader = accelerator.prepare(val_dataloader)

    for batch in val_dataloader:
        completion = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
            max_length=100,
        )
        target = batch["answers_ids"]
        generated_text = [out_ids[len(in_ids) :] for out_ids, in_ids in zip(completion, batch["input_ids"])]
        generated_text = processor.tokenizer.batch_decode(generated_text)
        target_text = processor.tokenizer.batch_decode(target)
        print(generated_text, target_text)
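
The masking loop in collate_fn_for_train above keeps only the assistant replies as supervision targets. A minimal, self-contained sketch of that behaviour (not part of the commit), using the token ids recorded in the debug comment above (151644, 151645, 8948, 872, 77091) and an assumed stand-in id of 198 for the newline token:

# Standalone illustration of the label-masking rule above (not shipped with the commit).
# The role/control token ids come from the debug comment in collate_fn_for_train;
# NL = 198 is an assumed stand-in for the "\n" token.
IM_START, IM_END = 151644, 151645
SYSTEM, USER, ASSISTANT = 8948, 872, 77091
NL = 198

tokens = [
    IM_START, SYSTEM, 11, 12, IM_END, NL,         # system turn
    IM_START, USER, 21, 22, IM_END, NL,           # user turn (image + question)
    IM_START, ASSISTANT, NL, 31, 32, IM_END, NL,  # assistant turn (the answer)
]
labels = list(tokens)

now_index = 0
while now_index < len(labels):
    if labels[now_index] == IM_START:
        labels[now_index] = -100
        now_index += 1
        if labels[now_index] in (SYSTEM, USER):
            # mask the whole system/user turn up to and including <|im_end|>
            while labels[now_index] != IM_END:
                labels[now_index] = -100
                now_index += 1
            labels[now_index] = -100
        elif labels[now_index] == ASSISTANT:
            # mask only the "assistant\n" header, keep the reply tokens
            labels[now_index] = -100
            labels[now_index + 1] = -100
            now_index += 2
            while now_index < len(labels) and labels[now_index] != IM_END:
                now_index += 1
    now_index += 1

print(labels)
# [-100, -100, -100, -100, -100, 198,
#  -100, -100, -100, -100, -100, 198,
#  -100, -100, -100, 31, 32, 151645, 198]

As the output shows, everything inside the system and user turns is ignored, while the assistant's reply tokens, its closing <|im_end|>, and the "\n" after each <|im_end|> still contribute to the loss.
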
@@ -5,6 +5,88 @@ import os
 
 
 class OCRVQADataset(Dataset):
+    def __init__(
+        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
+    ):
+        """
+        vis_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        """
+        self.vis_root = vis_root
+
+        self.vis_processor = vis_processor
+        self.text_processor = text_processor
+        if split == "train":
+            self.data = self.create_data(ann_path, split=1)[:1]
+        elif split == "test":
+            self.data = self.create_data(ann_path, split=3)[:1]
+
+        # self.instruction_pool = [
+        #     "[vqa] {}",
+        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
+        # ]
+
+    def create_data(self, ann_path, split=1):
+        processed_data = []
+        with open(ann_path, "r") as f:
+            data = json.load(f)
+            for k in data.keys():
+                if data[k]["split"] != split:
+                    continue  # 1 for training, 2 for validation, 3 for test
+                ext = os.path.splitext(data[k]["imageURL"])[1]
+                imageFile = k + ext
+                assert len(data[k]["questions"]) == len(data[k]["answers"])
+                for q, a in zip(data[k]["questions"], data[k]["answers"]):
+                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
+                        processed_data.append(
+                            {
+                                "question": q,
+                                "answer": a,
+                                "image_path": imageFile,
+                                "image_id": k,
+                                "title": data[k]["title"],
+                                "genre": data[k]["genre"],
+                            }
+                        )
+        return processed_data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
+            "RGB"
+        )
+        question = sample["question"]
+        answer = sample["answer"]
+        if self.vis_processor is not None:
+            image = self.vis_processor(image)
+        if self.text_processor is not None:
+            question = self.text_processor(question)
+            answer = self.text_processor(answer)
+
+        chat = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {
+                        "type": "text",
+                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
+                    },
+                ],
+            },
+            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
+        ]
+        return {
+            "image": image,
+            "chat": chat,
+            "image_id": sample["image_id"],
+        }
+
+
+class OCRVQADatasetForGeneration(Dataset):
     def __init__(
         self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
     ):

@@ -73,14 +155,15 @@ class OCRVQADataset(Dataset):
                     {"type": "image"},
                     {
                         "type": "text",
-                        "content": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
+                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                     },
                 ],
-            },
-            {"role": "assistant", "content": answer},
+            }
+            # {"role": "assistant", "content": answer},
         ]
         return {
             "image": image,
             "chat": chat,
+            "answer": answer,
             "image_id": sample["image_id"],
         }
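
create_data in the OCRVQADataset class added above walks an OCR-VQA-style dataset.json keyed by image id and reads the fields split, imageURL, questions, answers, title and genre. An illustrative, made-up entry showing the assumed shape (the real annotation file is not part of this diff):

# Hypothetical OCR-VQA annotation entry; the field names match what create_data reads,
# the values are invented for illustration only.
example_annotation = {
    "9780123456789": {
        "imageURL": "http://example.com/covers/9780123456789.jpg",  # only the extension is reused locally
        "split": 1,  # 1 = train, 2 = validation, 3 = test
        "questions": ["Who wrote this book?"],
        "answers": ["Jane Doe"],
        "title": "An Example Title",
        "genre": "fiction",
    }
}

Images are expected on disk at vis_root/<image_id><ext>, with the extension taken from imageURL, e.g. images/9780123456789.jpg.
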
src/evaluation.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    LlavaForConditionalGeneration,
)

from trl import (
    ModelConfig,
    SFTScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    script_args: SFTScriptArguments
    training_args: SFTConfig
    model_args: ModelConfig
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )

    model = AutoModelForVision2Seq.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collate_fn_library.qwen2 import collate_fn_for_train
        from functools import partial

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)

    ################
    # Dataset
    ################
    base_path = "/home/zyy/research/accelerate/dataset"
    if script_args.dataset_name == "OCR_VQA_200K":
        import os.path as osp
        from datasets_library.OCRVQADataset import OCRVQADataset

        dataset = {
            "train": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="train",
            ),
            "test": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="test",
            ),
        }

    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.config)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        tokenizer=processor.tokenizer,
        # processing_class=processor.tokenizer,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    model = trainer.model
    # move the model onto the GPU(s)
    accelerator = trainer.accelerator
    # model = accelerator.prepare(model)

    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
    from collate_fn_library.qwen2 import collate_fn_for_evaluate

    dataset = OCRVQADatasetForGeneration(
        vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
        ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
        split="train",
    )
    examples = [dataset[i] for i in range(3)]
    # print(collate_fn_for_evaluate(examples, processor))

    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset,
        batch_size=3,
        collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
    )
    val_dataloader = accelerator.prepare(val_dataloader)

    import evaluate

    glue = evaluate.load("rouge")

    for batch in val_dataloader:
        completion = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
            max_length=1000,
        )
        target = batch["answers_ids"]
        generated_text = [
            out_ids[len(in_ids) :]
            for out_ids, in_ids in zip(completion, batch["input_ids"])
        ]
        generated_text = processor.tokenizer.batch_decode(generated_text, skip_special_tokens=True)
        target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
        glue.add_batch(predictions=generated_text, references=target_text)

    print(glue.compute())
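
The evaluation loop in evaluation.py above scores generations with the Hugging Face evaluate library; the variable is named glue but it loads the ROUGE metric. A toy, standalone sketch of that metric API (assumes the evaluate and rouge_score packages are installed):

# Toy illustration of the metric calls used at the end of evaluation.py.
import evaluate

rouge = evaluate.load("rouge")
rouge.add_batch(
    predictions=["the cat sat on the mat"],
    references=["the cat is on the mat"],
)
print(rouge.compute())  # dict of ROUGE scores: rouge1, rouge2, rougeL, rougeLsum
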
src/evaluation.sh (new executable file, 15 lines)
@@ -0,0 +1,15 @@
#!/bin/bash

accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
    --dataset_name OCR_VQA_200K \
    --use_peft \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --max_seq_length 1024 \
    --output_dir checkpoint/sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16
    # --eval_strategy epoch \
src/evaluations_library/__init__.py (new file, 0 lines)

src/run.sh (Normal file → Executable file, 8 changed lines)
@@ -6,7 +6,11 @@ accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml sft_vlm.
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 8 \
-    --output_dir sft-llava-1.5-7b-hf \
+    --max_seq_length 1024 \
+    --output_dir checkpoint/sft-llava-1.5-7b-hf \
     --bf16 \
-    --torch_dtype bfloat16
+    --torch_dtype bfloat16 \
+    --eval_strategy epoch \
+
@@ -56,73 +56,11 @@ if __name__ == "__main__":
         **model_kwargs,
     )
 
-    ################
-    # Create a data collator to encode text and image pairs
-    ################
-    def collate_fn_qwenv2(examples):
-        # Get the texts and images, and apply the chat template
-        texts = [
-            processor.apply_chat_template(
-                example["chat"], tokenize=False, add_generation_prompt=True
-            )
-            for example in examples
-        ]
-        # print(texts)
-        images = [example["image"] for example in examples]
-        if isinstance(model, LlavaForConditionalGeneration):
-            # LLava1.5 does not support multiple images
-            images = [image[0] for image in images]
-
-        # Tokenize the texts and process the images
-        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
-
-        # The labels are the input_ids, and we mask the padding tokens in the loss computation
-        labels = batch["input_ids"].clone()
-        labels[labels == processor.tokenizer.pad_token_id] = -100  #
-        # Ignore the image token index in the loss computation (model specific)
-        # Mask the <|im_start|>system ... <|im_end|>\n spans (and likewise the user turns)
-        im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
-        im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
-        user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
-        assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
-        # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
-        # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
-        # 151644 151645 8948 872 77091 None 151643
-
-        for i, label in enumerate(labels):
-            now_index = 0
-            while now_index < len(label):
-                if label[now_index] == im_start_token_id:
-                    label[now_index] = -100
-                    now_index += 1
-                    if (
-                        label[now_index] == system_token_id
-                        or label[now_index] == user_token_id
-                    ):
-                        while label[now_index] != im_end_token_id:
-                            label[now_index] = -100
-                            now_index += 1
-                        label[now_index] = -100
-                    elif label[now_index] == assistant_token_id:
-                        label[now_index] = -100
-                        label[now_index + 1] = -100
-                        now_index += 2
-                        while (
-                            now_index < len(label)
-                            and label[now_index] != im_end_token_id
-                        ):
-                            now_index += 1
-                now_index += 1
-        # print(labels[0], batch["input_ids"][0])
-
-        # image_token_id = processor.tokenizer.convert_tokens_to_ids(
-        #     processor.image_token
-        # )
-        # labels[labels == image_token_id] = -100
-        batch["labels"] = labels
-
-        return batch
-
+    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
+        from collate_fn_library.qwen2 import collate_fn_for_train
+        from functools import partial
+
+        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
 
     ################
     # Dataset

@@ -141,7 +79,7 @@ if __name__ == "__main__":
         "test": OCRVQADataset(
             osp.join(base_path, "OCR-VQA-200K/images"),
             osp.join(base_path, "OCR-VQA-200K/dataset.json"),
-            split="val",
+            split="test",
         ),
     }
 

@@ -154,7 +92,7 @@ if __name__ == "__main__":
    trainer = SFTTrainer(
        model=model,
        args=training_args,
-        data_collator=collate_fn_qwenv2,
+        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]

@@ -168,6 +106,11 @@ if __name__ == "__main__":
 
     trainer.train()
 
+    model = trainer.model
+
+    # trainer.evaluate()
+    # evaluate in parallel by generating completions
+
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
     if training_args.push_to_hub:
|
6
src/todo.md
Normal file
6
src/todo.md
Normal file
@ -0,0 +1,6 @@
## TODO:

[2024.12.31]

- [ ] Train on the dataset multiple times
- [ ] Tidy up the evaluation code