import json
import os
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)


class VizWizDataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. vizwiz/images/train/)
        ann_path (string): Directory containing the annotation files
        vis_processor (callable, optional): Image transform; defaults to size_processor
        text_processor (callable, optional): Transform applied to questions and answers
        split (string): "train" or "test"
        """
        self.vis_root = vis_root
        self.ann_path = ann_path

        # Deferred import, presumably to avoid a circular import at module
        # load time.
        from .vis_processor import size_processor

        self.vis_processor = (
            vis_processor if vis_processor is not None else size_processor
        )
        self.text_processor = text_processor
        self.split = split
        # Populate self.data up front; __len__ and __getitem__ rely on it.
        self.load_data()

    def load_data(self):
        if self.split == "train":
            self.data = self.create_data(Path(self.ann_path, "train.json"))
        elif self.split == "test":
            # Note: split="test" loads the validation annotations (val.json).
            self.data = self.create_data(Path(self.ann_path, "val.json"))
        else:
            raise ValueError(f"Unsupported split: {self.split!r}")
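
    # Expected files under ann_path (inferred from the calls above; verify
    # against your local VizWiz download):
    #
    #   <ann_path>/
    #       train.json
    #       val.json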

    # self.instruction_pool = [
    #     "[vqa] {}",
    #     "[vqa] Based on the image, respond to this question with a short answer: {}",
    # ]
    def create_data(self, ann_path):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for entry in data:
            # Keep only answerable questions whose image file is actually
            # present on disk.
            if (
                os.path.exists(os.path.join(self.vis_root, entry["image"]))
                and entry["answerable"]
            ):
                processed_data.append(
                    {
                        "question": entry["question"],
                        # VizWiz provides multiple crowdsourced answers per
                        # question; the first one is taken as the target.
                        "answer": entry["answers"][0]["answer"],
                        "image_path": entry["image"],
                        "original": entry,
                    }
                )
        return processed_data
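
    # A sketch of one raw annotation entry, inferred from the keys accessed
    # above (the example values are assumptions — check the official VizWiz
    # annotation files):
    #
    #   {
    #       "image": "VizWiz_train_00000000.jpg",
    #       "question": "What is in this can?",
    #       "answerable": 1,
    #       "answers": [{"answer": "soda"}, ...],
    #   }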

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")

        question = sample["question"]
        answer = sample["answer"]
        # Apply the optional processors (the default vis_processor resizes
        # the image).
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        # One user turn (image placeholder + prompt) and one assistant turn
        # (the target answer). The image itself is passed via `images`, so
        # image_url is left empty.
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
            Conversation(
                role="assistant", content=[ConverstationText(type="text", text=answer)]
            ),
        ]

        return DatasetOutput(
            chat=chat,
            original=sample["original"],
            images=[image],
        )  # type: ignore


class VizWizDatasetForGeneration(VizWizDataset):
    """Variant for generation-style evaluation: the chat ends after the
    user turn, and the reference answer is returned separately so the
    model's generated answer can be compared against it."""

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )

        question = sample["question"]
        answer = sample["answer"]
        # Apply the optional processors (the default vis_processor resizes
        # the image).
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        # Only the user turn is provided; the assistant turn is left for the
        # model to generate.
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
        ]
        return DatasetOutput(
            images=[image],
            chat=chat,
            answer=answer,
            original=sample["original"],
        )  # type: ignore


from .format import dataset_dir

# "train" uses the train images/annotations; "test" and "generation"
# evaluate on the validation images (split="test" maps to val.json).
dataset = {
    "train": VizWizDataset(
        vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="train",
    ),
    "test": VizWizDataset(
        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="test",
    ),
    "generation": VizWizDatasetForGeneration(
        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="test",
    ),
}
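
# Minimal usage sketch (comment only — attribute access on DatasetOutput
# and downstream batching/collation are assumptions about the rest of
# this package):
#
#   sample = dataset["train"][0]          # a DatasetOutput
#   chat, images = sample.chat, sample.images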

from .factory import register_dataset

register_dataset(
    dataset_name="vizwiz",
    dataset=dataset,
    tag=["image", "text"],
)