feat: Add CHEM dataset support and refine the image preprocessing logic

YunyaoZhou 2025-01-10 00:18:50 +08:00
parent 8d6e5d5416
commit b766c21c9b
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
7 changed files with 213 additions and 242 deletions


@@ -8,3 +8,4 @@ torchaudio==2.5.1+cu124
torchvision==0.20.1+cu124
transformers==4.46.1
trl==0.13.0
pillow==9.5.0
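The Pillow pin presumably backs the Image.Resampling API used by the new dataset code; a quick environment sanity check (just a sketch, not part of the commit):

import PIL
from PIL import Image

print(PIL.__version__)               # expected: 9.5.0, matching the pin above
print(hasattr(Image, "Resampling"))  # True on Pillow >= 9.1, needed for Image.Resampling.BILINEAR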

src/dataset_library/CHEM.py (new file, 171 lines)

@@ -0,0 +1,171 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os


class CHEMDataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = (
            vis_processor if vis_processor is not None else self._vis_processor
        )
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split="train")
        elif split == "test":
            self.data = self.create_data(ann_path, split="test")

    def _vis_processor(self, image: Image.Image):
        width, height = image.size
        if width > 800 or height > 800:
            max_size = max(width, height)
            ratio = 800 / max_size
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
        if width < 28 or height < 28:
            min_size = min(width, height)
            ratio = 28 / min_size + 1
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
        return image

    def create_data(self, ann_path, split=1):
        import os.path as osp

        if split == "train":
            json_path = osp.join(ann_path, "conversations_loc_train.jsonl")
        elif split == "test":
            json_path = osp.join(ann_path, "conversations_loc_val.jsonl")
        processed_data = []
        with open(json_path, "r") as f:
            for line in f:
                data = json.loads(line)
                image_file = osp.join(self.vis_root, data["images"][0].split("/")[-1])
                conversation = data["messages"]
                system = conversation[0]["content"]
                query = conversation[1]["content"]
                answer = conversation[2]["content"]
                processed_data.append(
                    {
                        "question": query,
                        "answer": answer,
                        "image_path": image_file,
                        "system": system,
                    }
                )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": sample["system"],
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] {question.replace('<image>','')}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
        }


class CHEMDatasetForGeneration(CHEMDataset):
    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": sample["system"],
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] {question.replace('<image>','')}",
                    },
                ],
            },
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
        }


if __name__ == "__main__":
    dataset = CHEMDataset(
        "/home/zyy/research/accelerate/dataset/chem/images",
        "/home/zyy/research/accelerate/dataset/chem/qwen_data",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    dataset = CHEMDatasetForGeneration(
        "/home/zyy/research/accelerate/dataset/chem/images",
        "/home/zyy/research/accelerate/dataset/chem/qwen_data",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    pass
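For reference, create_data above reads one JSON object per line from conversations_loc_train.jsonl / conversations_loc_val.jsonl; a minimal sketch of the record shape it assumes (field values here are invented, and the role keys are an assumption since the loader only indexes messages by position):

import json

record = {
    "images": ["chem/images/example_0001.png"],  # only the basename is joined onto vis_root
    "messages": [
        {"role": "system", "content": "You are a chemistry assistant."},          # conversation[0]
        {"role": "user", "content": "<image>Where is the carbonyl group?"},       # conversation[1]
        {"role": "assistant", "content": "At the position highlighted in red."},  # conversation[2]
    ],
}
print(json.dumps(record))  # one such object per line in the JSONL file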


@@ -14,7 +14,9 @@ class OCRVQADataset(Dataset):
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.vis_processor = (
            vis_processor if vis_processor is not None else self._vis_processor
        )
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)
@@ -53,12 +55,7 @@ class OCRVQADataset(Dataset):
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")
        # resize image
    def _vis_processor(self, image: Image.Image):
        width, height = image.size
        if width > 500 or height > 500:
            max_size = max(width, height)
@@ -74,6 +71,15 @@ class OCRVQADataset(Dataset):
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
        return image

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")
        # resize image
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
@@ -98,64 +104,17 @@ class OCRVQADataset(Dataset):
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }


class OCRVQADatasetForGeneration(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]
        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
            for k in data.keys():
                if data[k]["split"] != split:
                    continue  # 1 for training, 2 for validation, 3 for test
                ext = os.path.splitext(data[k]["imageURL"])[1]
                imageFile = k + ext
                assert len(data[k]["questions"]) == len(data[k]["answers"])
                for q, a in zip(data[k]["questions"], data[k]["answers"]):
                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
                        processed_data.append(
                            {
                                "question": q,
                                "answer": a,
                                "image_path": imageFile,
                                "image_id": k,
                                "title": data[k]["title"],
                                "genre": data[k]["genre"],
                            }
                        )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        # resize image
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
@@ -181,5 +140,4 @@ class OCRVQADatasetForGeneration(Dataset):
            "image": image,
            "chat": chat,
            "answer": answer,
            "image_id": sample["image_id"],
        }
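With this refactor both OCR-VQA classes fall back to the built-in _vis_processor when none is passed, and callers can still supply their own; a minimal sketch (the import path and the data paths are placeholders, not taken from this diff):

from dataset_library.OCRVQADataset import OCRVQADataset  # import path assumed

def fixed_size_processor(image):
    # Example override: always resize to a fixed square instead of the default 500-px cap.
    return image.resize((448, 448))

dataset = OCRVQADataset(
    "path/to/OCR-VQA-200K/images",        # placeholder
    "path/to/OCR-VQA-200K/dataset.json",  # placeholder
    vis_processor=fixed_size_processor,
    split="train",
)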


@@ -1,181 +0,0 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os


class ChemDataseet(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
            for k in data.keys():
                if data[k]["split"] != split:
                    continue  # 1 for training, 2 for validation, 3 for test
                ext = os.path.splitext(data[k]["imageURL"])[1]
                imageFile = k + ext
                assert len(data[k]["questions"]) == len(data[k]["answers"])
                for q, a in zip(data[k]["questions"], data[k]["answers"]):
                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
                        processed_data.append(
                            {
                                "question": q,
                                "answer": a,
                                "image_path": imageFile,
                                "image_id": k,
                                "title": data[k]["title"],
                                "genre": data[k]["genre"],
                            }
                        )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }


class OCRVQADatasetForGeneration(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]
        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
            for k in data.keys():
                if data[k]["split"] != split:
                    continue  # 1 for training, 2 for validation, 3 for test
                ext = os.path.splitext(data[k]["imageURL"])[1]
                imageFile = k + ext
                assert len(data[k]["questions"]) == len(data[k]["answers"])
                for q, a in zip(data[k]["questions"], data[k]["answers"]):
                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
                        processed_data.append(
                            {
                                "question": q,
                                "answer": a,
                                "image_path": imageFile,
                                "image_id": k,
                                "title": data[k]["title"],
                                "genre": data[k]["genre"],
                            }
                        )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            }
            # {"role": "assistant", "content": answer},
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
            "image_id": sample["image_id"],
        }


if __name__ == "__main__":
    dataset = ChemDataseet(
        "/home/zyy/research/accelerate/dataset/chem/images/",
        "/home/zyy/research/accelerate/dataset/chem/qwen_data/conversations_loc_train.jsonl",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    dataset = OCRVQADatasetForGeneration(
        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    pass


@@ -27,4 +27,25 @@ def get_dataset(
                split="test",
            ),
        }
    if dataset_name == "CHEM":
        import os.path as osp

        from .CHEM import CHEMDataset, CHEMDatasetForGeneration

        dataset = {
            "train": CHEMDataset(
                osp.join(base_path, "chem/images"),
                osp.join(base_path, "chem/qwen_data"),
                split="train",
            ),
            "test": CHEMDataset(
                osp.join(base_path, "chem/images"),
                osp.join(base_path, "chem/qwen_data"),
                split="test",
            ),
            "generation": CHEMDatasetForGeneration(
                osp.join(base_path, "chem/images"),
                osp.join(base_path, "chem/qwen_data"),
                split="test",
            ),
        }
    return dataset
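A usage sketch of the new branch (a hedged example: the import location is assumed, and base_path is taken to be the argument that the added code joins with chem/images and chem/qwen_data):

from dataset_library import get_dataset  # import location assumed

datasets = get_dataset("CHEM", base_path="dataset")  # expects dataset/chem/{images,qwen_data}
print(len(datasets["train"]), len(datasets["test"]))
sample = datasets["generation"][0]  # dict with "image", "chat", and "answer" keys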


@@ -13,7 +13,9 @@ from utils.args import ContinualScriptArguments, ContinualModelConfig
if __name__ == "__main__":
    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig))
    parser = TrlParser(
        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
    )
    script_args, training_args, model_args = parser.parse_args_and_config()
    # for type hint
    if 0 == 1:
@@ -48,10 +50,9 @@ if __name__ == "__main__":
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    print(model)
    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from model_library.qwen2 import (
        from model_library.qwen2vl import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
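The collate functions imported above live in model_library.qwen2vl and are not shown in this diff; for orientation, a rough sketch of the kind of work such a collator does with the dataset samples and the Qwen2-VL processor (everything except the model id and the sample keys is an assumption):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

def collate_sketch(batch):
    # batch: list of dicts with "image" and "chat" as returned by CHEMDataset.__getitem__
    texts = [
        processor.apply_chat_template(sample["chat"], tokenize=False) for sample in batch
    ]
    images = [sample["image"] for sample in batch]
    return processor(text=texts, images=images, padding=True, return_tensors="pt")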


@@ -1,7 +1,7 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
--dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \
--dataset_name CHEM \
--use_peft \
--peft_type MMOELORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \