cl-lmm/src/dataset_library/OCRVQADataset.py

import json
import os

from PIL import Image
from torch.utils.data import Dataset


class OCRVQADataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): Path to the OCR-VQA annotation JSON file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        """Flatten the OCR-VQA annotation file into one record per (question, answer) pair."""
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                # Keep only samples whose image has actually been downloaded.
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")
        # Downscale so the longer side is at most 500 px, preserving aspect ratio.
        width, height = image.size
        if width > 500 or height > 500:
            max_size = max(width, height)
            ratio = 500 / max_size
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }
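

# A minimal collate sketch (not part of the original file): the "chat" entries
# above follow the Hugging Face multimodal chat format ({"type": "image"} /
# {"type": "text"}), so a batch can be turned into model inputs with a
# processor that supports `apply_chat_template`. `processor` here is an assumed
# transformers AutoProcessor-style object, and the sketch assumes the dataset
# was built with vis_processor=None so "image" is still a PIL image.
def example_collate_fn(batch, processor):
    """Collate OCRVQADataset samples into model-ready tensors (illustrative sketch)."""
    texts = [
        processor.apply_chat_template(sample["chat"], tokenize=False)
        for sample in batch
    ]
    images = [sample["image"] for sample in batch]
    # Tokenize the rendered conversations and preprocess the images together.
    return processor(text=texts, images=images, return_tensors="pt", padding=True)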


class OCRVQADatasetForGeneration(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): Path to the OCR-VQA annotation JSON file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        # Cap each split to its first 200 samples.
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        """Flatten the OCR-VQA annotation file into one record per (question, answer) pair."""
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                # Keep only samples whose image has actually been downloaded.
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        # For generation the chat contains only the user turn; the ground-truth
        # answer is returned separately for comparison against the model output.
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            }
            # {"role": "assistant", "content": answer},
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
            "image_id": sample["image_id"],
        }
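

# Usage sketch (not part of the original file): builds both datasets and
# inspects one sample. The paths below are placeholders for the OCR-VQA image
# folder and annotation JSON used by this repository; adjust them before running.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    train_set = OCRVQADataset(
        vis_root="path/to/ocrvqa/images",        # placeholder
        ann_path="path/to/ocrvqa/dataset.json",  # placeholder
        split="train",
    )
    gen_set = OCRVQADatasetForGeneration(
        vis_root="path/to/ocrvqa/images",        # placeholder
        ann_path="path/to/ocrvqa/dataset.json",  # placeholder
        split="test",
    )
    print(f"train samples: {len(train_set)}, generation samples: {len(gen_set)}")

    # An identity collate keeps PIL images and chat dicts untouched; a real
    # training loop would plug in a processor-aware collate_fn such as
    # example_collate_fn above.
    loader = DataLoader(train_set, batch_size=2, collate_fn=lambda batch: batch)
    first_batch = next(iter(loader))
    print(first_batch[0]["chat"])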