from pathlib import Path
import json

from PIL import Image
from torch.utils.data import Dataset

from .format import (
    Conversation,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)


class TextVQADataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. TextVQA/images/).
        ann_path (string): Directory containing the TextVQA annotation JSON files.
        """
        self.vis_processor = (
            vis_processor if vis_processor is not None else self._vis_processor
        )
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(
                Path(ann_path, "TextVQA_0.5.1_train.json"),
                vis_root=Path(vis_root, "train_images"),
            )
        elif split == "test":
            # TextVQA ships its validation images inside the train_images folder.
            self.data = self.create_data(
                Path(ann_path, "TextVQA_0.5.1_val.json"),
                vis_root=Path(vis_root, "train_images"),
            )
        else:
            raise ValueError(f"Unknown split: {split}")

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, vis_root):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        data = data["data"]
        for i in range(len(data)):
            # Each annotation entry looks like:
            # {'question': 'what is the brand of phone?', 'image_id': '0054c91397f2fe05',
            #  'image_classes': ['Belt', 'Headphones', 'Goggles', 'Scale', 'Bottle opener',
            #   'Mobile phone', 'Mirror', 'Digital clock', 'Television', 'Telephone', 'Tool',
            #   'Wheel', 'Camera', 'Watch', 'Glasses', 'Aircraft'],
            #  'flickr_original_url': 'https://farm6.staticflickr.com/2891/9134076951_f65b421097_o.jpg',
            #  'flickr_300k_url': 'https://c4.staticflickr.com/3/2891/9134076951_9db89d3e0f_z.jpg',
            #  'image_width': 1024, 'image_height': 730,
            #  'answers': ['nokia', 'nokia', 'nokia', 'nokia', 'toshiba', 'nokia', 'nokia',
            #   'nokia', 'nokia', 'nokia'],
            #  'question_tokens': ['what', 'is', 'the', 'brand', 'of', 'phone'],
            #  'question_id': 0, 'set_name': 'train'}
            try:
                image_file = data[i]["image_id"] + ".jpg"
                question = data[i]["question"]
                answer = data[i]["answers"][0]
                processed_data.append(
                    {
                        "question": question,
                        "answer": answer,
                        "image_path": Path(vis_root, image_file),
                        "image_id": data[i]["image_id"],
                        "title": data[i]["image_id"],
                        "genre": data[i]["image_classes"],
                        "original": data[i],
                    }
                )
            except (KeyError, IndexError):
                # Skip malformed entries (e.g. missing question or answers) but log them.
                print(data[i])
        return processed_data

    def __len__(self):
        return len(self.data)

    def _vis_processor(self, image: Image.Image):
        """Default resize: cap the longest side at 500 px, then ensure the shortest side is at least 28 px."""
        width, height = image.size
        if width > 500 or height > 500:
            max_size = max(width, height)
            ratio = 500 / max_size
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
            # Refresh the dimensions so the minimum-size check below uses the resized image.
            width, height = image.size

        if width < 28 or height < 28:
            min_size = min(width, height)
            # Overshoot slightly (+1) so the shortest side ends up safely above 28 px.
            ratio = 28 / min_size + 1
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

        return image

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(sample["image_path"]).convert("RGB")
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer)],
            ),
        ]

        return DatasetOutput(
            chat=chat,
            original=sample["original"],
            images=[image],
        )


class TextVQADatasetForGeneration(TextVQADataset):
    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(sample["image_path"]).convert("RGB")
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        # Only the user turn is included; the reference answer is returned
        # separately so it can be compared against the generated text.
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
        ]

        return DatasetOutput(
            images=[image],
            chat=chat,
            answer=answer,
            original=sample["original"],
        )


def test_dataset():
    vis_root = "/home/zyy/dataset/TextVQA/images"
    ann_path = "/home/zyy/dataset/TextVQA"
    dataset = TextVQADataset(vis_root, ann_path)
    for i in range(10):
        print(dataset[i])


if __name__ == "__main__":
    test_dataset()
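# A minimal DataLoader sketch (assumptions: DatasetOutput exposes `chat` and `answer`
# as attributes, and an identity collate_fn is used because the default PyTorch collate
# cannot batch PIL images or nested Conversation objects):
#
#     from torch.utils.data import DataLoader
#
#     dataset = TextVQADatasetForGeneration(
#         "/home/zyy/dataset/TextVQA/images", "/home/zyy/dataset/TextVQA", split="test"
#     )
#     loader = DataLoader(dataset, batch_size=4, collate_fn=lambda batch: batch)
#     for batch in loader:
#         for sample in batch:
#             print(sample.chat, sample.answer)
#         break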