# cl-lmm/src/dataset_library/RefCOCOgDataset.py

from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCOgDataset(Dataset):
    """Chat-formatted view over one split of the ``lmms-lab/RefCOCOg`` dataset.

    Each item becomes a two-turn conversation:

    * a user turn holding an image placeholder followed by the sample's
      ``"question"`` text, and
    * an assistant turn holding the sample's ``"answer"``.

    The image itself is returned separately in ``DatasetOutput.images``.
    """

    def __init__(
        self,
        vis_processor=None,
        text_processor=None,
        split: Literal["val", "test"] = "val",
    ):
        """Load ``lmms-lab/RefCOCOg`` and keep the requested split.

        Args:
            vis_processor: Optional callable applied to each sample's
                ``"image"`` field in ``__getitem__``. ``None`` leaves the
                image unchanged.
            text_processor: Optional callable applied to each sample's
                ``"question"`` field in ``__getitem__``. The answer is not
                passed through it. ``None`` leaves the question unchanged.
            split: Which split of the loaded dataset to index: ``"val"`` or
                ``"test"``.
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        # Kept as a function-local import. The reason isn't visible here.
        # Possibly it reads ``dataset_dir`` at construction time or avoids
        # an import cycle with ``.format`` (TODO confirm).
        from .format import dataset_dir

        ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
        self.data = ds[split]  # type: ignore

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return sample ``index`` as a ``DatasetOutput``.

        Returns:
            A ``DatasetOutput`` with:

            * ``images``: a one-element list holding the (optionally
              processed) image;
            * ``chat``: the user/assistant turns;
            * ``original``: the raw row.
        """
        sample = self.data[index]
        images = sample["image"]
        question = sample["question"]
        # NOTE(review): passed directly as the assistant turn's ``text``.
        # If the hub schema stores several referring expressions per row,
        # this may be a list rather than a str. Confirm against the
        # dataset card.
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)
        chat = [
            Conversation(
                role="user",
                content=[
                    # Empty URL: the image data is carried in ``images`` below.
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=question,
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer)],
            ),
        ]
        return DatasetOutput(
            images=[images],
            chat=chat,
            original=sample,
        )  # type: ignore
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
    """RefCOCOg variant whose chat holds only the user prompt.

    The reference answer is returned under ``answer`` instead of as an
    assistant turn. This keeps it out of the prompt that is handed to a
    model.
    """

    def __getitem__(self, index):
        """Return sample ``index`` as a prompt-only ``DatasetOutput``.

        Returns:
            A ``DatasetOutput`` with:

            * ``images``: a one-element list;
            * ``chat``: a single user turn (image placeholder plus question);
            * ``answer``: the raw ``"answer"`` field;
            * ``original``: the unmodified row.
        """
        record = self.data[index]
        picture, prompt, target = (
            record["image"],
            record["question"],
            record["answer"],
        )

        if self.vis_processor is not None:
            picture = self.vis_processor(picture)
        if self.text_processor is not None:
            prompt = self.text_processor(prompt)

        user_turn = Conversation(
            role="user",
            content=[
                # Empty URL: the image data travels in ``images`` below.
                ConverstationImage(type="image", image_url=""),
                ConverstationText(type="text", text=f"{prompt}"),
            ],
        )

        return DatasetOutput(
            images=[picture],
            chat=[user_turn],
            answer=target,
            original=record,
        )  # type: ignore
def test_RefCOCOg():
    """Smoke-test both RefCOCOg dataset classes against the hub data.

    The chat dataset is built on ``val`` and the generation dataset on
    ``test``. For each one, sample 3 is printed, then the test checks that
    the dataset is non-empty and that its first item has a non-empty
    ``chat``.
    """
    cases = (
        (RefCOCOgDataset, "val"),
        (RefCOCOgDatasetForGeneration, "test"),
    )
    for dataset_cls, split_name in cases:
        ds = dataset_cls(split=split_name)
        print(ds[3])
        assert len(ds) > 0
        assert len(ds[0]["chat"]) > 0


if __name__ == "__main__":
    # Running the module directly executes the smoke test.
    test_RefCOCOg()