from typing import Literal

from datasets import DatasetDict, load_dataset
from torch.utils.data import Dataset

from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)


class RefCOCOgDataset(Dataset):
    """RefCOCOg samples loaded with ``load_dataset("lmms-lab/RefCOCOg")``.

    Each item is a two-turn chat: the user turn holds an image placeholder
    plus the sample's ``question``; the assistant turn holds its ``answer``.
    """

    def __init__(
        self,
        vis_processor=None,
        text_processor=None,
        split: Literal["val", "test"] = "val",
    ):
        """Load one split of ``lmms-lab/RefCOCOg``.

        Args:
            vis_processor: Optional callable applied to each sample's ``image``.
            text_processor: Optional callable applied to each sample's ``question``.
            split: Name of the split to expose (``"val"`` or ``"test"``).
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # Deferred import kept from the original; dataset_dir is used as the cache dir.
        from .format import dataset_dir

        ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
        self.data = ds[split]  # type: ignore

    def __len__(self):
        return len(self.data)

    def _load_sample(self, index):
        """Fetch row *index* and run the optional image/text processors on it.

        Shared by every ``__getitem__`` variant so preprocessing stays in one place.

        Returns:
            ``(sample, images, question, answer)`` where ``sample`` is the raw row.
        """
        sample = self.data[index]
        images = sample["image"]
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)
        return sample, images, question, answer

    def __getitem__(self, index):
        """Return item *index* as a user/assistant chat with its image."""
        sample, images, question, answer = self._load_sample(index)
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=question,
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer)],
            ),
        ]
        return DatasetOutput(
            images=[images],
            chat=chat,
            original=sample,
        )  # type: ignore


class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
    """Generation variant: the chat holds only the user turn.

    The reference answer is returned separately via ``answer``.
    """

    def __getitem__(self, index):
        """Return item *index* as a single user turn plus the reference answer."""
        sample, images, question, answer = self._load_sample(index)
        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        # f-string kept from the original: formats the
                        # (possibly processed) question into a str.
                        text=f"{question}",
                    ),
                ],
            ),
        ]
        return DatasetOutput(
            images=[images],
            chat=chat,
            answer=answer,
            original=sample,
        )  # type: ignore


def test_RefCOCOg():
    """Smoke test: build both dataset variants and check items yield non-empty chats."""
    dataset = RefCOCOgDataset(
        split="val",
    )
    # Check size first so the dataset[3] probe below cannot raise IndexError.
    assert len(dataset) > 3
    print(dataset[3])
    assert len(dataset[0]["chat"]) > 0

    dataset = RefCOCOgDatasetForGeneration(
        split="test",
    )
    assert len(dataset) > 3
    print(dataset[3])
    assert len(dataset[0]["chat"]) > 0


if __name__ == "__main__":
    test_RefCOCOg()