cl-lmm/src/dataset_library/VizWizDataset.py

import json
import os
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)

class VizWizDataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): root directory of images (e.g. vizwiz/images/train)
        ann_path (string): directory containing the annotation files (train.json / val.json)
        vis_processor (callable, optional): image transform; defaults to size_processor
        text_processor (callable, optional): transform applied to questions and answers
        split (string): "train" or "test"
        """
        self.vis_root = vis_root
        self.ann_path = ann_path

        # Imported lazily, presumably to avoid a circular import with .vis_processor.
        from .vis_processor import size_processor

        self.vis_processor = (
            vis_processor if vis_processor is not None else size_processor
        )
        self.text_processor = text_processor
        self.split = split

    def load_data(self):
        # Loading is deferred: load_data() must be called before len()/indexing,
        # which keeps construction of the module-level `dataset` dict cheap at import time.
        if self.split == "train":
            self.data = self.create_data(Path(self.ann_path, "train.json"))
        elif self.split == "test":
            # The "test" split is evaluated on the public val annotations.
            self.data = self.create_data(Path(self.ann_path, "val.json"))

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for record in data:
            # Keep only answerable questions whose image exists on disk;
            # the first annotated answer is used as the target.
            if (
                os.path.exists(os.path.join(self.vis_root, record["image"]))
                and record["answerable"]
            ):
                processed_data.append(
                    {
                        "question": record["question"],
                        "answer": record["answers"][0]["answer"],
                        "image_path": record["image"],
                        "original": record,
                    }
                )
        return processed_data
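
    # Shape of one annotation record, inferred from the fields accessed in
    # create_data(); anything beyond those four fields (e.g. "answer_confidence",
    # the filename pattern) is an assumption based on the public VizWiz VQA format:
    #
    #     {
    #         "image": "VizWiz_train_00000000.jpg",
    #         "question": "...",
    #         "answerable": 1,
    #         "answers": [{"answer": "...", "answer_confidence": "yes"}, ...],
    #     }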

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            # The default size_processor resizes the image here.
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
            Conversation(
                role="assistant", content=[ConverstationText(type="text", text=answer)]
            ),
        ]
        return DatasetOutput(
            chat=chat,
            original=sample["original"],
            images=[image],
        )  # type: ignore
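

# For one sample, the chat built above looks roughly like this (a sketch; the
# exact reprs depend on how Conversation/Converstation* are defined in .format):
#
#     [
#         Conversation(role="user", content=[
#             ConverstationImage(type="image", image_url=""),
#             ConverstationText(type="text", text="[vqa] Based on the image, "
#                               "respond to this question with a short answer: <question>"),
#         ]),
#         Conversation(role="assistant", content=[
#             ConverstationText(type="text", text="<answer>"),
#         ]),
#     ]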


class VizWizDatasetForGeneration(VizWizDataset):
    """Same samples as VizWizDataset, but the chat ends after the user turn;
    the reference answer is returned separately for generation-time evaluation."""

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
        ]
        return DatasetOutput(
            images=[image],
            chat=chat,
            answer=answer,
            original=sample["original"],
        )  # type: ignore


# These imports sit below the class definitions, presumably to avoid circular
# imports among this module, .format, and .factory.
from .format import dataset_dir

dataset = {
    "train": VizWizDataset(
        vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="train",
    ),
    "test": VizWizDataset(
        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="test",
    ),
    "generation": VizWizDatasetForGeneration(
        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="test",
    ),
}

from .factory import register_dataset

register_dataset(
    dataset_name="vizwiz",
    dataset=dataset,
    tag=["image", "text"],
)
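
# Usage sketch (assumes the VizWiz images and Annotations/{train,val}.json exist
# under dataset_dir; note that load_data() must be called before indexing):
#
#     ds = dataset["train"]
#     ds.load_data()
#     out = ds[0]          # DatasetOutput with chat, images=[PIL.Image], original
#
#     gen = dataset["generation"]
#     gen.load_data()
#     prompt = gen[0]      # user turn only; reference answer under `answer`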