cl-lmm/src/dataset_library/TextVQADataset.py

174 lines
6.1 KiB
Python

from PIL import Image
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
import json
import os.path as osp
from pathlib import Path
class TextVQADataset(Dataset):
    """TextVQA dataset that yields one chat-formatted question/answer per item.

    Each item is a ``DatasetOutput`` holding a two-turn chat (a user turn with
    an image placeholder plus the prompted question, and an assistant turn
    with the answer), the raw annotation record, and the loaded RGB image.
    """

    # Bounds enforced by the default image processor (pixels per side).
    MAX_IMAGE_SIDE = 500
    MIN_IMAGE_SIDE = 28

    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """Load the annotations for ``split`` and remember the processors.

        Args:
            vis_root: Root directory of the images; files are looked up in its
                ``train_images`` subdirectory.
            ann_path: Directory containing ``TextVQA_0.5.1_train.json`` and
                ``TextVQA_0.5.1_val.json``.
            vis_processor: Optional callable applied to every loaded image.
                Defaults to the built-in size-clamping ``_vis_processor``.
            text_processor: Optional callable applied to both the question and
                the answer strings.
            split: ``"train"`` reads the train annotations, ``"test"`` reads
                the val annotations.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
        """
        self.vis_processor = (
            vis_processor if vis_processor is not None else self._vis_processor
        )
        self.text_processor = text_processor

        if split == "train":
            ann_file = "TextVQA_0.5.1_train.json"
        elif split == "test":
            ann_file = "TextVQA_0.5.1_val.json"
        else:
            # Fail here instead of leaving ``self.data`` undefined and hitting
            # an AttributeError on the first ``len()`` / indexing call.
            raise ValueError(
                f"Unsupported split {split!r}; expected 'train' or 'test'."
            )

        # Both splits read images from ``train_images``.
        # NOTE(review): presumably the val images ship in the same folder as
        # the train images -- confirm against the dataset layout on disk.
        self.data = self.create_data(
            Path(ann_path, ann_file),
            vis_root=Path(vis_root, "train_images"),
        )

    def create_data(self, ann_path, vis_root):
        """Parse an annotation file into a list of per-question sample dicts.

        The file is expected to be JSON with a top-level ``"data"`` list whose
        records carry ``question``, ``image_id``, ``image_classes`` and
        ``answers``. Only the first listed answer is kept as the target.

        Args:
            ann_path: Path of the annotation JSON file.
            vis_root: Directory in which ``<image_id>.jpg`` files live.

        Returns:
            A list of dicts with keys ``question``, ``answer``, ``image_path``,
            ``image_id``, ``title``, ``genre`` and ``original``. Records that
            are missing a required field or have no answers are printed and
            skipped.
        """
        processed_data = []
        with open(ann_path, "r", encoding="utf-8") as f:
            records = json.load(f)["data"]

        for record in records:
            try:
                image_id = record["image_id"]
                entry = {
                    "question": record["question"],
                    "answer": record["answers"][0],
                    "image_path": Path(vis_root, image_id + ".jpg"),
                    "image_id": image_id,
                    "title": image_id,
                    "genre": record["image_classes"],
                    "original": record,
                }
            except (KeyError, IndexError, TypeError):
                # Best-effort loading: report the malformed record and move
                # on, without masking unrelated errors or Ctrl-C.
                print(record)
                continue
            processed_data.append(entry)
        return processed_data

    def __len__(self):
        """Return the number of usable question/answer samples."""
        return len(self.data)

    def _vis_processor(self, image: "Image.Image"):
        """Clamp the image size: shrink large images, enlarge tiny ones.

        Images with a side above ``MAX_IMAGE_SIDE`` are scaled down so the
        longer side equals that limit. Images with a side below
        ``MIN_IMAGE_SIDE`` are then scaled up. Aspect ratio is preserved up to
        integer truncation.

        Args:
            image: The PIL image to resize.

        Returns:
            The resized image, or the input unchanged if it was within bounds.
        """
        width, height = image.size
        if width > self.MAX_IMAGE_SIDE or height > self.MAX_IMAGE_SIDE:
            ratio = self.MAX_IMAGE_SIDE / max(width, height)
            # Clamp to 1 px: truncation could otherwise yield a zero-sized
            # side for extreme aspect ratios, which resize() rejects.
            new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            image = image.resize(new_size, Image.Resampling.BILINEAR)
            # Re-read the size so the small-image check below sees the image
            # as it is now, not as it was before shrinking.
            width, height = image.size
        if width < self.MIN_IMAGE_SIDE or height < self.MIN_IMAGE_SIDE:
            # short * (MIN / short + 1) == MIN + short, so the shorter side
            # always lands at or above the minimum even after truncation.
            ratio = self.MIN_IMAGE_SIDE / min(width, height) + 1
            new_size = (int(width * ratio), int(height * ratio))
            image = image.resize(new_size, Image.Resampling.BILINEAR)
        return image

    def __getitem__(self, index):
        """Build the training sample at ``index``.

        Returns:
            A ``DatasetOutput`` whose ``chat`` contains the user prompt and
            the assistant answer, whose ``original`` is the raw annotation
            record, and whose ``images`` is a one-element list with the
            processed RGB image.
        """
        sample = self.data[index]
        # The context manager releases the file handle deterministically,
        # including when decoding fails part-way through.
        with Image.open(sample["image_path"]) as raw_image:
            image = raw_image.convert("RGB")

        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            Conversation(
                role="user",
                content=[
                    # The image itself travels in ``images``; the URL is left
                    # empty here.
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
            Conversation(
                role="assistant", content=[ConverstationText(type="text", text=answer)]
            ),
        ]
        return DatasetOutput(
            chat=chat,
            original=sample["original"],
            images=[image],
        )
class TextVQADatasetForGeneration(TextVQADataset):
    """Generation/evaluation variant of ``TextVQADataset``.

    The chat contains only the user turn (image placeholder plus prompted
    question); the reference answer is returned separately in the ``answer``
    field of the ``DatasetOutput`` instead of as an assistant turn.
    """

    def __getitem__(self, index):
        """Build the generation sample at ``index``.

        Returns:
            A ``DatasetOutput`` with ``images`` (one processed RGB image),
            ``chat`` (the single user turn), ``answer`` (the reference answer)
            and ``original`` (the raw annotation record).
        """
        sample = self.data[index]
        # The context manager releases the file handle deterministically,
        # including when decoding fails part-way through.
        with Image.open(sample["image_path"]) as raw_image:
            image = raw_image.convert("RGB")

        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            Conversation(
                role="user",
                content=[
                    # The image itself travels in ``images``; the URL is left
                    # empty here.
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
        ]
        return DatasetOutput(
            images=[image],
            chat=chat,
            answer=answer,
            original=sample["original"],
        )
def test_dataset(
    vis_root="/home/zyy/dataset/TextVQA/images",
    ann_path="/home/zyy/dataset/TextVQA",
    num_samples=10,
):
    """Smoke-test the dataset by printing its first few samples.

    Args:
        vis_root: Root directory of the images; defaults to the original
            developer's local path.
        ann_path: Directory holding the annotation JSON files; defaults to the
            original developer's local path.
        num_samples: Maximum number of samples to print.
    """
    dataset = TextVQADataset(vis_root, ann_path)
    # Bound by the dataset length so a small dataset does not raise IndexError.
    for i in range(min(num_samples, len(dataset))):
        print(dataset[i])
# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    test_dataset()