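"""TextVQA dataset wrappers that yield chat-formatted samples for VQA training and generation."""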
import json
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

from .format import (
    Conversation,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)


class TextVQADataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. TextVQA/images/)
        ann_path (string): Directory containing the TextVQA annotation files
        split (string): "train" loads the train annotations, "test" loads the val annotations
        """
        self.vis_processor = (
            vis_processor if vis_processor is not None else self._vis_processor
        )
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(
                Path(ann_path, "TextVQA_0.5.1_train.json"),
                vis_root=Path(vis_root, "train_images"),
            )
        elif split == "test":
            self.data = self.create_data(
                Path(ann_path, "TextVQA_0.5.1_val.json"),
                # TextVQA ships its val images alongside the train images,
                # so the same "train_images" directory is reused here.
                vis_root=Path(vis_root, "train_images"),
            )

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, vis_root):
        """Load a TextVQA annotation file and return a list of per-question samples."""
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        data = data["data"]
        for entry in data:
            # Each annotation entry looks like:
            # {'question': 'what is the brand of phone?', 'image_id': '0054c91397f2fe05',
            #  'image_classes': [...], 'flickr_original_url': '...', 'flickr_300k_url': '...',
            #  'image_width': 1024, 'image_height': 730, 'answers': ['nokia', ...],
            #  'question_tokens': [...], 'question_id': 0, 'set_name': 'train'}
            try:
                image_file = entry["image_id"] + ".jpg"
                question = entry["question"]
                answer = entry["answers"][0]
                processed_data.append(
                    {
                        "question": question,
                        "answer": answer,
                        "image_path": Path(vis_root, image_file),
                        "image_id": entry["image_id"],
                        "title": entry["image_id"],
                        "genre": entry["image_classes"],
                        "original": entry,
                    }
                )
            except (KeyError, IndexError):
                # Skip malformed entries (e.g. missing or empty answers) but log them.
                print(entry)

        return processed_data

    def __len__(self):
        return len(self.data)

    def _vis_processor(self, image: Image.Image):
        """Default image processor: cap the longer side at 500 px, then make sure
        the shorter side is at least 28 px."""
        width, height = image.size
        if width > 500 or height > 500:
            max_size = max(width, height)
            ratio = 500 / max_size
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
            # Refresh the dimensions so the minimum-size check below operates on
            # the resized image rather than the original one.
            width, height = image.size

        if width < 28 or height < 28:
            min_size = min(width, height)
            # "+ 1" overscales slightly so integer truncation cannot drop the
            # shorter side back below 28 px.
            ratio = 28 / min_size + 1
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

        return image

    def __getitem__(self, index):
        sample = self.data[index]
        image: Image.Image = Image.open(sample["image_path"]).convert("RGB")

        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
            Conversation(
                role="assistant", content=[ConverstationText(type="text", text=answer)]
            ),
        ]

        return DatasetOutput(
            chat=chat,
            original=sample["original"],
            images=[image],
        )


class TextVQADatasetForGeneration(TextVQADataset):
    """Variant used for generation/evaluation: the chat contains only the user
    turn, and the ground-truth answer is returned separately."""

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(sample["image_path"]).convert("RGB")
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
        ]
        return DatasetOutput(
            images=[image],
            chat=chat,
            answer=answer,
            original=sample["original"],
        )

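
# A minimal usage sketch (not part of the original file): wrapping the dataset
# in a torch DataLoader. The paths and batch size below are placeholders, and
# the identity collate_fn keeps each batch as a plain list, since samples hold
# PIL images and variable-length chat structures that cannot be stacked into
# tensors directly.
def example_dataloader(vis_root="/path/to/TextVQA/images", ann_path="/path/to/TextVQA"):
    from torch.utils.data import DataLoader

    dataset = TextVQADataset(vis_root, ann_path, split="train")
    return DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda batch: batch)
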

def test_dataset():
    vis_root = "/home/zyy/dataset/TextVQA/images"
    ann_path = "/home/zyy/dataset/TextVQA"
    dataset = TextVQADataset(vis_root, ann_path)
    for i in range(10):
        print(dataset[i])


if __name__ == "__main__":
    test_dataset()