Add the OCRVQADataset dataset class and related configuration files, including training and test data processing logic
parent 5b5fcda5e5
commit 09734720b0
src/.gitignore
vendored
Normal file
1
src/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
checkpoint/*
|
src/accelerate_configs/deepspeed_zero1.yaml (new file, 20 lines)
@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  zero3_init_flag: false
  zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
src/accelerate_configs/deepspeed_zero2.yaml (new file, 21 lines)
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
src/accelerate_configs/deepspeed_zero3.yaml (new file, 22 lines)
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
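Note: the same ZeRO-3 settings can also be expressed programmatically instead of via `accelerate launch --config_file ...`; a minimal sketch, assuming a recent accelerate release with DeepSpeed installed (the plugin arguments below simply mirror the YAML keys and are not part of this commit):

# Minimal sketch (not part of this commit): the ZeRO-3 YAML above expressed
# through accelerate's Python API.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,                     # zero_stage: 3
    zero3_init_flag=True,             # zero3_init_flag: true
    zero3_save_16bit_model=True,      # zero3_save_16bit_model: true
    offload_optimizer_device="none",  # offload_optimizer_device: none
    offload_param_device="none",      # offload_param_device: none
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)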
src/accelerate_configs/fsdp_qlora.yaml (new file, 25 lines)
@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
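Note: this FSDP file only configures sharding; for a QLoRA run the 4-bit quantization side is produced by get_quantization_config(model_args) in sft_vlm.py. A minimal sketch of an equivalent BitsAndBytesConfig, with values that are an illustrative assumption rather than something read from this commit:

# Illustrative 4-bit setup for QLoRA-style training (assumed values; in sft_vlm.py
# the equivalent object is built by trl's get_quantization_config from ModelConfig flags).
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize base weights to 4 bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # matches mixed_precision: 'bf16'
    bnb_4bit_use_double_quant=True,         # nested quantization to save memory
)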
src/accelerate_configs/multi_gpu.yaml (new file, 16 lines)
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
src/accelerate_configs/single_gpu.yaml (new file, 16 lines)
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
src/datasets_library/OCRVQADataset.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os


class OCRVQADataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): Path to the annotation file (e.g. dataset.json)
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "content": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            },
            {"role": "assistant", "content": answer},
        ]
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }
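Note: a minimal usage sketch for this class; the paths below are placeholders, the real ones are derived from base_path in sft_vlm.py:

# Standalone usage sketch for OCRVQADataset (placeholder paths, not the repo's real ones).
from datasets_library.OCRVQADataset import OCRVQADataset

train_ds = OCRVQADataset(
    vis_root="/path/to/OCR-VQA-200K/images",        # directory with the downloaded images
    ann_path="/path/to/OCR-VQA-200K/dataset.json",  # OCR-VQA annotation file
    split="train",                                  # keeps split==1 entries, first 200 samples
)
print(len(train_ds))
sample = train_ds[0]
print(sample["image_id"], sample["chat"][0]["content"][1]["content"])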
src/datasets_library/__init__.py (new file, empty)
src/run.sh (new file, 12 lines)
@@ -0,0 +1,12 @@
#!/bin/bash

accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml sft_vlm.py \
    --dataset_name OCR_VQA_200K \
    --use_peft \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --output_dir sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16
src/sft_vlm.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    LlavaForConditionalGeneration,
)

from trl import (
    ModelConfig,
    SFTScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    script_args: SFTScriptArguments
    training_args: SFTConfig
    model_args: ModelConfig
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="right",
    )

    model = AutoModelForVision2Seq.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    ################
    # Create a data collator to encode text and image pairs
    ################
    def collate_fn_qwenv2(examples):
        # Get the texts and images, and apply the chat template
        texts = [
            processor.apply_chat_template(
                example["chat"], tokenize=False, add_generation_prompt=True
            )
            for example in examples
        ]
        # print(texts)
        images = [example["image"] for example in examples]
        if isinstance(model, LlavaForConditionalGeneration):
            # LLaVA 1.5 does not support multiple images
            images = [image[0] for image in images]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        # Ignore the image token index in the loss computation (model specific)
        # Mask the <|im_start|>system ... <|im_end|>\n prompt spans so they are ignored in the loss
        im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
        im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
        system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
        user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
        assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
        # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
        # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
        # 151644 151645 8948 872 77091 None 151643

        for i, label in enumerate(labels):
            now_index = 0
            while now_index < len(label):
                if label[now_index] == im_start_token_id:
                    label[now_index] = -100
                    now_index += 1
                    if (
                        label[now_index] == system_token_id
                        or label[now_index] == user_token_id
                    ):
                        while label[now_index] != im_end_token_id:
                            label[now_index] = -100
                            now_index += 1
                        label[now_index] = -100
                    elif label[now_index] == assistant_token_id:
                        label[now_index] = -100
                        label[now_index + 1] = -100
                        now_index += 2
                        while (
                            now_index < len(label)
                            and label[now_index] != im_end_token_id
                        ):
                            now_index += 1
                now_index += 1
        # print(labels[0], batch["input_ids"][0])

        # image_token_id = processor.tokenizer.convert_tokens_to_ids(
        #     processor.image_token
        # )
        # labels[labels == image_token_id] = -100
        batch["labels"] = labels

        return batch

    ################
    # Dataset
    ################
    base_path = "/home/zyy/research/accelerate/dataset"
    if script_args.dataset_name == "OCR_VQA_200K":
        import os.path as osp
        from datasets_library.OCRVQADataset import OCRVQADataset

        dataset = {
            "train": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="train",
            ),
            "test": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="test",
            ),
        }

    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.config)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_qwenv2,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        tokenizer=processor.tokenizer,
        # processing_class=processor.tokenizer,
        peft_config=get_peft_config(model_args),
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)
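Note: the masking loop in collate_fn_qwenv2 is the least obvious part of this file. The standalone sketch below replays the same control flow on invented placeholder token ids (they are not real Qwen2-VL vocabulary ids) to show which positions end up contributing to the loss:

# Toy replay of the label-masking loop in collate_fn_qwenv2.  The ids below are
# invented placeholders, not real tokenizer ids; only the control flow mirrors
# the training script.
IM_START, IM_END, SYSTEM, USER, ASSISTANT, NL = 1, 2, 3, 4, 5, 6
tokens = [
    IM_START, SYSTEM, 10, 11, IM_END, NL,         # system turn  -> fully masked
    IM_START, USER, 20, 21, 22, IM_END, NL,       # user turn    -> fully masked
    IM_START, ASSISTANT, NL, 30, 31, IM_END, NL,  # assistant    -> answer kept
]
labels = list(tokens)

i = 0
while i < len(labels):
    if labels[i] == IM_START:
        labels[i] = -100
        i += 1
        if labels[i] in (SYSTEM, USER):
            while labels[i] != IM_END:
                labels[i] = -100
                i += 1
            labels[i] = -100        # also mask the closing <|im_end|>
        elif labels[i] == ASSISTANT:
            labels[i] = -100        # mask "assistant"
            labels[i + 1] = -100    # mask the newline after it
            i += 2
            while i < len(labels) and labels[i] != IM_END:
                i += 1              # leave the answer tokens untouched
    i += 1

print(labels)
# All prompt-side tokens become -100; only the answer tokens (30, 31), the
# assistant turn's <|im_end|>, and the "\n" after each <|im_end|> keep their ids.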