from torch.utils.data import Dataset
from typing import Literal
from pathlib import Path
from dataset_library.format import dataset_dir


def _image_splits(dataset_cls, generation_cls, base_path, folder):
    """Build the train/test/generation dict for a dataset stored on disk.

    The annotation path is ``base_path/folder`` and the image root is
    ``base_path/folder/images``. ``dataset_cls`` is instantiated for the
    "train" and "test" splits and ``generation_cls`` for the "test" split.
    """
    images = Path(base_path, folder, "images")
    annotations = Path(base_path, folder)
    return {
        "train": dataset_cls(vis_root=images, ann_path=annotations, split="train"),
        "test": dataset_cls(vis_root=images, ann_path=annotations, split="test"),
        "generation": generation_cls(
            vis_root=images, ann_path=annotations, split="test"
        ),
    }


def _split_only(dataset_cls, generation_cls, train_split, test_split):
    """Build the train/test/generation dict for classes that take only ``split``.

    The "generation" entry is built from ``generation_cls`` using the same
    split name as the "test" entry.
    """
    return {
        "train": dataset_cls(split=train_split),
        "test": dataset_cls(split=test_split),
        "generation": generation_cls(split=test_split),
    }


def get_dataset(
    dataset_name, base_path=dataset_dir
) -> dict[Literal["train", "test", "generation"], Dataset]:
    """Return the "train", "test" and "generation" datasets for ``dataset_name``.

    Args:
        dataset_name: One of "ocrvqa200k", "chem", "gigaspeech", "textvqa",
            "scienceqa", "refcoco", "refcocog" or "refcocoplus".
        base_path: Root directory holding the on-disk datasets. Only the
            "ocrvqa200k", "chem" and "textvqa" branches read it.

    Returns:
        A dict with the keys "train", "test" and "generation". If
        ``dataset_name`` matches none of the names above, an empty dict is
        returned instead.
    """
    # Each dataset module is imported only inside its own branch, so asking
    # for one dataset never imports the modules of the others.
    if dataset_name == "ocrvqa200k":
        from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration

        return _image_splits(
            OCRVQADataset, OCRVQADatasetForGeneration, base_path, "OCR-VQA-200K"
        )

    if dataset_name == "chem":
        from .ChemDataset import ChemDataset, ChemDatasetForGeneration

        return _image_splits(ChemDataset, ChemDatasetForGeneration, base_path, "chem")

    if dataset_name == "gigaspeech":
        from .GigaspeechDataset import (
            GigaspeechDataset,
            GigaspeechDatasetForGeneration,
        )

        return _split_only(
            GigaspeechDataset, GigaspeechDatasetForGeneration, "train", "test"
        )

    if dataset_name == "textvqa":
        from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration

        return _image_splits(
            TextVQADataset, TextVQADatasetForGeneration, base_path, "TextVQA"
        )

    if dataset_name == "scienceqa":
        from .ScienceQADataset import (
            ScienceQADataset,
            ScienceQADatasetForGeneration,
        )

        return _split_only(
            ScienceQADataset, ScienceQADatasetForGeneration, "train", "test"
        )

    # NOTE(review): the RefCOCO family builds its "train" entry from the
    # "val" split; this is kept as-is -- confirm it is intentional.
    if dataset_name == "refcoco":
        from .RefCOCODataset import (
            RefCOCODataset,
            RefCOCODatasetForGeneration,
        )

        return _split_only(RefCOCODataset, RefCOCODatasetForGeneration, "val", "test")

    if dataset_name == "refcocog":
        from .RefCOCOgDataset import (
            RefCOCOgDataset,
            RefCOCOgDatasetForGeneration,
        )

        return _split_only(
            RefCOCOgDataset, RefCOCOgDatasetForGeneration, "val", "test"
        )

    if dataset_name == "refcocoplus":
        from .RefCOCOPlusDataset import (
            RefCOCOplusDataset,
            RefCOCOplusDatasetForGeneration,
        )

        # Unlike the other RefCOCO variants, this one evaluates on "testA".
        return _split_only(
            RefCOCOplusDataset, RefCOCOplusDatasetForGeneration, "val", "testA"
        )

    # Unrecognised name: hand back an empty mapping rather than raising.
    return {}