# Source: cl-lmm/src/dataset_library/ScienceQADataset.py (140 lines, 5.5 KiB, Python)

from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
class ScienceQADataset(Dataset):
    """Map-style dataset over the ``derek-thomas/ScienceQA`` Hugging Face dataset.

    Each item is a two-turn chat: a user turn holding an image placeholder
    plus the task, question and choices, and an assistant turn holding the
    text of the correct choice.
    """

    def __init__(self, vis_processor=None, text_processor=None, split="train"):
        """Store the processors and the split name; no data is loaded here.

        Args:
            vis_processor: Optional callable applied to each sample's image.
            text_processor: Optional callable applied to each question string.
            split: Name of the dataset split to read.
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split

    def load_data(self):
        """Load ScienceQA and keep the configured split in ``self.data``."""
        # Imported here rather than at module top, as in the original code.
        from .format import dataset_dir

        ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir)  # type: ignore
        self.data = ds[self.split]  # type: ignore

    def _ensure_loaded(self):
        """Call ``load_data`` if ``self.data`` has not been set yet."""
        # ``__init__`` deliberately leaves ``data`` unset, so its absence is
        # the marker for "not loaded yet".
        if not hasattr(self, "data"):
            self.load_data()

    def __len__(self):
        """Return the number of samples in the split, loading it if needed."""
        self._ensure_loaded()
        return len(self.data)

    def __getitem__(self, index):
        """Build the chat-formatted training sample at ``index``.

        Returns:
            A ``DatasetOutput`` with ``images`` (one-element list holding the
            optionally processed image), ``chat`` (user and assistant turns)
            and ``original`` (the raw sample).
        """
        self._ensure_loaded()
        sample = self.data[index]
        # Example sample (abridged):
        # {'image': <PIL image>, 'question': 'Which of these states is farthest north?',
        #  'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'],
        #  'answer': 0, 'task': 'closed choice', ...}
        image = sample["image"]
        question = sample["question"]
        choices = sample["choices"]
        task = sample["task"]
        answer = sample["answer"]  # index into ``choices``

        # NOTE(review): if some samples carry no image, ``image`` may be None
        # here and is still passed through — confirm downstream handling.
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[{task}] '{question}' choose from '{choices}'",
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=choices[answer])],
            ),
        ]
        return DatasetOutput(
            images=[image],
            chat=chat,
            original=sample,
        )  # type: ignore
class ScienceQADatasetForGeneration(ScienceQADataset):
    """ScienceQA variant whose items omit the assistant turn.

    The chat holds only the user prompt; the text of the correct choice is
    returned separately in the ``answer`` field of the output.
    """

    def __getitem__(self, index):
        """Build the prompt-only sample at ``index``.

        Returns:
            A ``DatasetOutput`` with ``images`` (one-element list holding the
            optionally processed image), ``chat`` (the user turn only),
            ``answer`` (text of the correct choice) and ``original`` (the raw
            sample).
        """
        # ``__init__`` does not set ``data``; load it on first access so that
        # indexing a freshly constructed instance works.
        if not hasattr(self, "data"):
            self.load_data()

        sample = self.data[index]
        # Example sample (abridged):
        # {'image': <PIL image>, 'question': 'Which of these states is farthest north?',
        #  'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'],
        #  'answer': 0, 'task': 'closed choice', ...}
        image = sample["image"]
        question = sample["question"]
        choices = sample["choices"]
        task = sample["task"]
        answer = sample["answer"]  # index into ``choices``

        # NOTE(review): if some samples carry no image, ``image`` may be None
        # here and is still passed through — confirm downstream handling.
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[{task}] '{question}' choose from '{choices}'",
                    ),
                ],
            ),
        ]
        return DatasetOutput(
            images=[image],
            chat=chat,
            answer=choices[answer],
            original=sample,
        )  # type: ignore
def test_scienceQA():
    """Smoke-test both dataset classes on the ``train`` split.

    For each class: load the split, print the sample at index 3, and check
    that the dataset is non-empty and that the first sample has a non-empty
    chat.
    """
    for dataset_cls in (ScienceQADataset, ScienceQADatasetForGeneration):
        dataset = dataset_cls(split="train")
        # ``__init__`` only stores configuration; populate ``self.data``
        # explicitly before indexing.
        dataset.load_data()
        print(dataset[3])
        assert len(dataset) > 0
        assert len(dataset[0]["chat"]) > 0
# Dataset instances exposed under the "scienceqa" name. Both "test" and
# "generation" read the test split; they differ only in the item format.
dataset = dict(
    train=ScienceQADataset(split="train"),
    test=ScienceQADataset(split="test"),
    generation=ScienceQADatasetForGeneration(split="test"),
)

# Imported after the class definitions, as in the original layout
# (presumably to avoid a circular import with the factory — confirm).
from .factory import register_dataset

register_dataset(dataset_name="scienceqa", dataset=dataset, tag=["image", "text"])

if __name__ == "__main__":
    test_scienceQA()