import json
import os

from PIL import Image
from torch.utils.data import Dataset


class OCRVQADataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): Path to the OCR-VQA annotation JSON file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)
        else:
            raise ValueError(f"Unsupported split: {split!r} (expected 'train' or 'test')")

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        """Flatten the annotation file into one record per question/answer pair."""
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                # Keep only records whose image has actually been downloaded.
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")
        # Downscale large images so the longer side is at most 500 px.
        width, height = image.size
        if width > 500 or height > 500:
            max_size = max(width, height)
            ratio = 500 / max_size
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        # Full training conversation: a user turn (image + question) followed by
        # the assistant turn containing the ground-truth answer.
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }
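
# A minimal usage sketch (the paths below are hypothetical and only illustrate the
# expected layout). Each sample is a plain dict with "image", "chat", and
# "image_id", so a pass-through collate_fn keeps the chat structure intact for a
# downstream processor, e.g.:
#
#     from torch.utils.data import DataLoader
#
#     train_set = OCRVQADataset(
#         vis_root="ocrvqa/images", ann_path="ocrvqa/dataset.json", split="train"
#     )
#     loader = DataLoader(train_set, batch_size=4, collate_fn=lambda batch: batch)
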
class OCRVQADatasetForGeneration(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): Path to the OCR-VQA annotation JSON file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor
        # Only the first 200 matching samples are kept for this generation variant.
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]
        else:
            raise ValueError(f"Unsupported split: {split!r} (expected 'train' or 'test')")

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        """Flatten the annotation file into one record per question/answer pair."""
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                # Keep only records whose image has actually been downloaded.
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        # For generation the chat stops at the user turn; the ground-truth answer is
        # returned separately so it can be compared against the model output.
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            }
            # {"role": "assistant", "content": answer},
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
            "image_id": sample["image_id"],
        }
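

if __name__ == "__main__":
    # Minimal smoke test, assuming (hypothetically) that the OCR-VQA images and
    # annotation JSON have been downloaded to the paths below; adjust as needed.
    import sys

    vis_root = sys.argv[1] if len(sys.argv) > 1 else "ocrvqa/images"
    ann_path = sys.argv[2] if len(sys.argv) > 2 else "ocrvqa/dataset.json"

    eval_set = OCRVQADatasetForGeneration(
        vis_root=vis_root, ann_path=ann_path, split="test"
    )
    print(f"loaded {len(eval_set)} evaluation samples")
    if len(eval_set) > 0:
        sample = eval_set[0]
        # The generation variant returns the prompt-only chat plus the reference answer.
        print(sample["chat"][0]["content"][1]["text"])
        print("reference answer:", sample["answer"])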