添加数据集类OCRVQADataset及相关配置文件,包含训练和测试数据处理逻辑

This commit is contained in:
YunyaoZhou 2024-12-30 13:30:50 +00:00
parent 5b5fcda5e5
commit 09734720b0
11 changed files with 395 additions and 0 deletions

1
src/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
checkpoint/*

View File

@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
zero3_init_flag: false
zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,86 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os
class OCRVQADataset(Dataset):
def __init__(
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(ann_path, split=1)[:200]
elif split == "test":
self.data = self.create_data(ann_path, split=3)[:200]
# self.instruction_pool = [
# "[vqa] {}",
# "[vqa] Based on the image, respond to this question with a short answer: {}",
# ]
def create_data(self, ann_path, split=1):
processed_data = []
with open(ann_path, "r") as f:
data = json.load(f)
for k in data.keys():
if data[k]["split"] != split:
continue # 1 for training, 2 for validation, 3 for test
ext = os.path.splitext(data[k]["imageURL"])[1]
imageFile = k + ext
assert len(data[k]["questions"]) == len(data[k]["answers"])
for q, a in zip(data[k]["questions"], data[k]["answers"]):
if os.path.exists(os.path.join(self.vis_root, imageFile)):
processed_data.append(
{
"question": q,
"answer": a,
"image_path": imageFile,
"image_id": k,
"title": data[k]["title"],
"genre": data[k]["genre"],
}
)
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
"RGB"
)
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"content": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
},
],
},
{"role": "assistant", "content": answer},
]
return {
"image": image,
"chat": chat,
"image_id": sample["image_id"],
}

View File

12
src/run.sh Normal file
View File

@ -0,0 +1,12 @@
#!/bin/bash
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml sft_vlm.py \
--dataset_name OCR_VQA_200K \
--use_peft \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
--lora_target_modules q_proj v_proj \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--output_dir sft-llava-1.5-7b-hf \
--bf16 \
--torch_dtype bfloat16

176
src/sft_vlm.py Normal file
View File

@ -0,0 +1,176 @@
import torch
from datasets import load_dataset
from transformers import (
AutoModelForVision2Seq,
AutoProcessor,
LlavaForConditionalGeneration,
)
from trl import (
ModelConfig,
SFTScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
if __name__ == "__main__":
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
script_args: SFTScriptArguments
training_args: SFTConfig
model_args: ModelConfig
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
################
# Model, Tokenizer & Processor
################
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
padding_side="right",
)
model = AutoModelForVision2Seq.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
################
# Create a data collator to encode text and image pairs
################
def collate_fn_qwenv2(examples):
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["chat"], tokenize=False, add_generation_prompt=True
)
for example in examples
]
# print(texts)
images = [example["image"] for example in examples]
if isinstance(model, LlavaForConditionalGeneration):
# LLava1.5 does not support multiple images
images = [image[0] for image in images]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
# 对<|im_start|>system *** <|im_end|>\n加掩码
im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
# enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
# print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
# 151644 151645 8948 872 77091 None 151643
for i, label in enumerate(labels):
now_index = 0
while now_index < len(label):
if label[now_index] == im_start_token_id:
label[now_index] = -100
now_index += 1
if (
label[now_index] == system_token_id
or label[now_index] == user_token_id
):
while label[now_index] != im_end_token_id:
label[now_index] = -100
now_index += 1
label[now_index] = -100
elif label[now_index] == assistant_token_id:
label[now_index] = -100
label[now_index + 1] = -100
now_index += 2
while (
now_index < len(label)
and label[now_index] != im_end_token_id
):
now_index += 1
now_index += 1
# print(labels[0], batch["input_ids"][0])
# image_token_id = processor.tokenizer.convert_tokens_to_ids(
# processor.image_token
# )
# labels[labels == image_token_id] = -100
batch["labels"] = labels
return batch
################
# Dataset
################
base_path = "/home/zyy/research/accelerate/dataset"
if script_args.dataset_name == "OCR_VQA_200K":
import os.path as osp
from datasets_library.OCRVQADataset import OCRVQADataset
dataset = {
"train": OCRVQADataset(
osp.join(base_path, "OCR-VQA-200K/images"),
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
split="train",
),
"test": OCRVQADataset(
osp.join(base_path, "OCR-VQA-200K/images"),
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
split="val",
),
}
else:
dataset = load_dataset(script_args.dataset_name, name=script_args.config)
################
# Training
################
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn_qwenv2,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=(
dataset[script_args.dataset_test_split]
if training_args.eval_strategy != "no"
else None
),
tokenizer=processor.tokenizer,
# processing_class=processor.tokenizer,
peft_config=get_peft_config(model_args),
)
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if trainer.accelerator.is_main_process:
processor.push_to_hub(training_args.hub_model_id)