from PIL import Image
from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)
from torch.utils.data import Dataset
import json
import os
from pathlib import Path


class VizWizDataset(Dataset):
    """VizWiz visual-question-answering dataset (chat-format samples).

    NOTE: ``__init__`` only records configuration; ``load_data()`` must be
    called before ``__len__``/``__getitem__`` are used, otherwise ``self.data``
    does not exist. This laziness keeps module import cheap (the module-level
    ``dataset`` dict below instantiates three datasets at import time).
    """

    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        Args:
            vis_root (str | Path): Root directory of images (e.g. coco/images/).
            ann_path (str | Path): Directory containing the annotation files
                (``train.json`` / ``val.json``).
            vis_processor (callable | None): Applied to each PIL image; falls
                back to ``size_processor`` when None.
            text_processor (callable | None): Applied to question and answer
                strings when provided.
            split (str): ``"train"`` or ``"test"`` (the "test" split reads the
                validation annotations, see ``load_data``).
        """
        self.vis_root = vis_root
        self.ann_path = ann_path
        # Imported lazily — presumably to avoid a circular import at module
        # load time; keep it function-scoped.
        from .vis_processor import size_processor

        self.vis_processor = (
            vis_processor if vis_processor is not None else size_processor
        )
        self.text_processor = text_processor
        self.split = split

    def load_data(self):
        """Read the annotation file for the configured split into ``self.data``."""
        if self.split == "train":
            self.data = self.create_data(Path(self.ann_path, "train.json"))
        elif self.split == "test":
            # Evaluation ("test") runs on the validation annotations.
            self.data = self.create_data(Path(self.ann_path, "val.json"))

    def create_data(self, ann_path):
        """Parse one VizWiz annotation JSON file.

        Keeps only samples that are marked answerable AND whose image file
        actually exists under ``self.vis_root``.

        Args:
            ann_path (str | Path): Path to the annotation JSON file.

        Returns:
            list[dict]: One dict per kept sample with keys ``question``,
            ``answer`` (first crowd answer), ``image_path`` and ``original``
            (the untouched annotation record).
        """
        with open(ann_path, "r") as f:
            data = json.load(f)

        processed_data = []
        for item in data:
            # Drop unanswerable questions and annotations whose image is missing.
            if not item["answerable"]:
                continue
            if not os.path.exists(os.path.join(self.vis_root, item["image"])):
                continue
            processed_data.append(
                {
                    "question": item["question"],
                    # Use the first of the ten crowd answers as the target.
                    "answer": item["answers"][0]["answer"],
                    "image_path": item["image"],
                    "original": item,
                }
            )
        return processed_data

    def __len__(self):
        # Requires load_data() to have been called first.
        return len(self.data)

    def _prepare_sample(self, index):
        """Load one sample's image and text, applying the optional processors.

        Returns:
            tuple: (raw sample dict, processed image, question str, answer str).
        """
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        return sample, image, question, answer

    def __getitem__(self, index):
        """Return a training sample: user turn (image + prompt) and the
        assistant turn carrying the gold answer."""
        sample, image, question, answer = self._prepare_sample(index)
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer)],
            ),
        ]
        return DatasetOutput(
            chat=chat,
            original=sample["original"],
            images=[image],
        )  # type: ignore


class VizWizDatasetForGeneration(VizWizDataset):
    """Generation/evaluation variant: the chat contains only the user turn;
    the gold answer is returned separately so callers can score model output."""

    def __getitem__(self, index):
        sample, image, question, answer = self._prepare_sample(index)
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
        ]
        return DatasetOutput(
            images=[image],
            chat=chat,
            answer=answer,
            original=sample["original"],
        )  # type: ignore


from .format import dataset_dir

# Dataset registry entries: "test" and "generation" both evaluate on the
# validation images/annotations (see VizWizDataset.load_data).
dataset = {
    "train": VizWizDataset(
        vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="train",
    ),
    "test": VizWizDataset(
        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="test",
    ),
    "generation": VizWizDatasetForGeneration(
        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
        split="test",
    ),
}

from .factory import register_dataset

register_dataset(
    dataset_name="vizwiz",
    dataset=dataset,
    tag=["image", "text"],
)