from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict


class ScienceQADataset(Dataset):
    """ScienceQA ("derek-thomas/ScienceQA") exposed as a chat-style dataset.

    Each item is a two-turn conversation:
      * the user turn holds an image placeholder plus a prompt built from the
        sample's task, question and choices;
      * the assistant turn holds the text of the correct choice.

    Raw samples are read with the keys ``image``, ``question``, ``choices``
    (indexed by ``answer``), ``answer`` and ``task``; the whole raw record is
    passed through unchanged as ``original``.

    The split is fetched by :meth:`load_data`. If that has not been called,
    ``__len__`` / ``__getitem__`` call it on first use.
    """

    def __init__(self, vis_processor=None, text_processor=None, split="train"):
        """
        vis_processor (callable, optional): applied to the sample's image.
        text_processor (callable, optional): applied to the question text.
        split (string): dataset split to select, e.g. "train" or "test".
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        # Set by load_data(); None means the split has not been loaded yet.
        self.data = None

    def load_data(self):
        """Load ScienceQA into ``dataset_dir`` (cache) and select ``self.split``."""
        from .format import dataset_dir

        ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir)  # type: ignore
        self.data = ds[self.split]  # type: ignore

    def _ensure_loaded(self):
        """Call load_data() if the split has not been loaded yet."""
        # getattr keeps this safe for instances that bypassed __init__.
        if getattr(self, "data", None) is None:
            self.load_data()

    def __len__(self):
        self._ensure_loaded()
        return len(self.data)

    def _prepare_sample(self, index):
        """Fetch sample ``index`` and build the pieces shared by all variants.

        Returns a tuple ``(sample, images, user_turn, answer_text)``:
          * ``sample``      - the raw record,
          * ``images``      - ``sample["image"]``, passed through
                              ``vis_processor`` when one is set,
          * ``user_turn``   - the user Conversation (image placeholder + prompt),
          * ``answer_text`` - ``choices[answer]``, the correct choice.
        """
        self._ensure_loaded()
        sample = self.data[index]

        images = sample["image"]
        question = sample["question"]
        choices = sample["choices"]
        task = sample["task"]
        answer = sample["answer"]

        # NOTE(review): if a sample has no image (sample["image"] is None), it is
        # still handed to vis_processor and an image placeholder is still
        # emitted - confirm the processor/collator handle that case.
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        user_turn = Conversation(
            role="user",
            content=[
                ConverstationImage(type="image", image_url=""),
                ConverstationText(
                    type="text",
                    text=f"[{task}] '{question}' choose from '{choices}'",
                ),
            ],
        )
        return sample, images, user_turn, choices[answer]

    def __getitem__(self, index):
        """Return a (user question, assistant answer) training example."""
        sample, images, user_turn, answer_text = self._prepare_sample(index)
        chat = [
            user_turn,
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer_text)],
            ),
        ]
        return DatasetOutput(
            images=[images],
            chat=chat,
            original=sample,
        )  # type: ignore


class ScienceQADatasetForGeneration(ScienceQADataset):
    """ScienceQA variant for generation / evaluation.

    The chat contains only the user turn; the correct choice text is returned
    separately in the ``answer`` field instead of as an assistant turn.
    """

    def __getitem__(self, index):
        sample, images, user_turn, answer_text = self._prepare_sample(index)
        return DatasetOutput(
            images=[images],
            chat=[user_turn],
            answer=answer_text,
            original=sample,
        )  # type: ignore


def test_scienceQA():
    """Smoke test: both variants load and yield non-empty chats."""
    dataset = ScienceQADataset(
        split="train",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0

    dataset = ScienceQADatasetForGeneration(
        split="train",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0


# Register the dataset variants with the project factory at import time.
# Construction is cheap: no data is downloaded until load_data() runs.
dataset = {
    "train": ScienceQADataset(split="train"),
    "test": ScienceQADataset(split="test"),
    "generation": ScienceQADatasetForGeneration(split="test"),
}
from .factory import register_dataset

register_dataset(
    dataset_name="scienceqa",
    dataset=dataset,
    tag=["image", "text"],
)

if __name__ == "__main__":
    test_scienceQA()