122 lines
3.2 KiB
Python
122 lines
3.2 KiB
Python
from .format import (
|
|
Conversation,
|
|
ConverstationAudio,
|
|
ConverstationImage,
|
|
ConverstationText,
|
|
DatasetOutput,
|
|
)
|
|
from torch.utils.data import Dataset
|
|
from datasets import load_dataset, DatasetDict
|
|
from typing import Literal
|
|
|
|
|
|
class RefCOCOgDataset(Dataset):
    """RefCOCOg split exposed as user/assistant conversations.

    Each item pairs the sample's image with a two-turn chat: a user turn
    holding an image placeholder plus the question, and an assistant turn
    holding the answer.
    """

    def __init__(
        self,
        vis_processor=None,
        text_processor=None,
        split: Literal["val", "test"] = "val",
    ):
        """Load ``lmms-lab/RefCOCOg`` and keep one of its splits.

        Args:
            vis_processor: Optional callable applied to each sample's image.
            text_processor: Optional callable applied to each sample's
                question.
            split: Key used to select a split from the loaded dataset.
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # The cache location is defined by the sibling ``format`` module and
        # is imported at call time, exactly where the dataset is loaded.
        from .format import dataset_dir

        all_splits = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
        self.data = all_splits[split]  # type: ignore

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the conversation-style output for the sample at ``index``.

        Returns:
            A ``DatasetOutput`` built with ``images`` (a one-element list
            holding the possibly processed image), ``chat`` (user turn then
            assistant turn) and ``original`` (the raw sample).
        """
        record = self.data[index]
        picture, prompt, reply = (
            record["image"],
            record["question"],
            record["answer"],
        )

        # Processors are optional; without them the raw values pass through.
        if self.vis_processor is not None:
            picture = self.vis_processor(picture)
        if self.text_processor is not None:
            prompt = self.text_processor(prompt)

        # The image itself travels in ``images``; the chat only carries an
        # image entry with an empty URL.
        user_turn = Conversation(
            role="user",
            content=[
                ConverstationImage(type="image", image_url=""),
                ConverstationText(type="text", text=prompt),
            ],
        )
        # NOTE(review): ``answer`` is forwarded unmodified as the turn text --
        # confirm the dataset stores it as a single string.
        assistant_turn = Conversation(
            role="assistant",
            content=[ConverstationText(type="text", text=reply)],
        )

        return DatasetOutput(
            images=[picture],
            chat=[user_turn, assistant_turn],
            original=record,
        )  # type: ignore
|
|
|
|
|
|
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
    """RefCOCOg variant whose chat stops after the user turn.

    The answer is not placed in the conversation; it is handed back
    separately through the ``answer`` field of the output.
    """

    def __getitem__(self, index):
        """Build the prompt-only output for the sample at ``index``.

        Returns:
            A ``DatasetOutput`` built with ``images`` (a one-element list
            holding the possibly processed image), ``chat`` (a single user
            turn), ``answer`` (the sample's answer, unmodified) and
            ``original`` (the raw sample).
        """
        record = self.data[index]
        picture, prompt, reply = (
            record["image"],
            record["question"],
            record["answer"],
        )

        # Processors are optional; without them the raw values pass through.
        if self.vis_processor is not None:
            picture = self.vis_processor(picture)
        if self.text_processor is not None:
            prompt = self.text_processor(prompt)

        # The f-string renders the (possibly processed) question as text.
        user_turn = Conversation(
            role="user",
            content=[
                ConverstationImage(type="image", image_url=""),
                ConverstationText(type="text", text=f"{prompt}"),
            ],
        )

        return DatasetOutput(
            images=[picture],
            chat=[user_turn],
            answer=reply,
            original=record,
        )  # type: ignore
|
|
|
|
|
|
def test_RefCOCOg():
    """Smoke-test both dataset classes on one split each.

    For every (class, split) pair the dataset is constructed, item 3 is
    printed, and the dataset plus the ``chat`` of item 0 are checked to be
    non-empty.
    """
    cases = (
        (RefCOCOgDataset, "val"),
        (RefCOCOgDatasetForGeneration, "test"),
    )
    for dataset_cls, split_name in cases:
        dataset = dataset_cls(split=split_name)
        print(dataset[3])
        assert len(dataset) > 0
        assert len(dataset[0]["chat"]) > 0
|
|
|
|
|
|
# Run the smoke test only when this module is the program entry point.
if __name__ == "__main__":
    test_RefCOCOg()
|