from PIL import Image
from torch.utils.data import Dataset
import json
import os


class OCRVQADataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): Path to the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                # keep only question/answer pairs whose image actually exists on disk
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(
            os.path.join(self.vis_root, sample["image_path"])
        ).convert("RGB")
        # downscale large images so the longest side is at most 500 px
        width, height = image.size
        if width > 500 or height > 500:
            max_size = max(width, height)
            ratio = 500 / max_size
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }
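

# Hypothetical usage sketch (not part of the original dataset code): one way the
# "image"/"chat" samples above could be batched for supervised fine-tuning with a
# Hugging Face multimodal processor whose chat template understands the
# {"type": "image"} placeholder (LLaVA/Idefics-style). The `processor` argument,
# its apply_chat_template behaviour, and the naive label construction are
# assumptions; adapt them to the model actually being trained.
def ocrvqa_train_collate(batch, processor):
    # render each chat (user question + assistant answer) into the model's prompt format
    texts = [
        processor.apply_chat_template(sample["chat"], tokenize=False)
        for sample in batch
    ]
    images = [sample["image"] for sample in batch]
    inputs = processor(text=texts, images=images, padding=True, return_tensors="pt")
    # naive labels: supervise the full sequence (mask prompt tokens if desired)
    inputs["labels"] = inputs["input_ids"].clone()
    return inputs


# Example wiring (placeholder names, shown as a comment only):
#   loader = DataLoader(OCRVQADataset(vis_root, ann_path, split="train"),
#                       batch_size=4,
#                       collate_fn=lambda b: ocrvqa_train_collate(b, processor))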


class OCRVQADatasetForGeneration(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): Path to the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        # only the first 200 samples of each split are used for generation
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        # only the user turn is included; the answer is returned separately so it
        # can be compared against the generated text
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            }
            # {"role": "assistant", "content": answer},
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
            "image_id": sample["image_id"],
        }
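

# Hypothetical inference sketch (assumed API, not part of the original dataset
# code): OCRVQADatasetForGeneration returns only the user turn, so a generation
# prompt is built with add_generation_prompt=True and the ground-truth "answer"
# is kept aside for scoring. `processor` and `model` stand in for whichever
# Hugging Face vision-language model and processor are being evaluated.
def ocrvqa_generate_answer(sample, processor, model, max_new_tokens=32):
    prompt = processor.apply_chat_template(
        sample["chat"], add_generation_prompt=True, tokenize=False
    )
    inputs = processor(text=prompt, images=sample["image"], return_tensors="pt").to(
        model.device
    )
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    prediction = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
    # return both the generated text and the reference answer for metric computation
    return prediction, sample["answer"]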