from typing import Literal

from torch.utils.data import Dataset


def get_dataset(
    dataset_name, base_path="/home/zyy/research/accelerate/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
    """Build the train / test / generation datasets for a named dataset family.

    Args:
        dataset_name: Identifier of the dataset family. The comparison is
            exact; the recognised values are ``"OCR_VQA_200K"`` and
            ``"CHEM"``.
        base_path: Root directory under which each family keeps its files.

    Returns:
        A dict with the keys ``"train"``, ``"test"`` and ``"generation"``
        when ``dataset_name`` is recognised. Both ``"test"`` and
        ``"generation"`` are constructed with ``split="test"``; they differ
        only in the dataset class used. An unrecognised ``dataset_name``
        yields an empty dict rather than raising.
    """
    # Dataset modules are imported lazily inside the matching branch so that
    # selecting one family never imports the other family's module.
    if dataset_name == "OCR_VQA_200K":
        import os.path as osp

        from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration

        image_root = osp.join(base_path, "OCR-VQA-200K/images")
        annotation_path = osp.join(base_path, "OCR-VQA-200K/dataset.json")
        return {
            "train": OCRVQADataset(image_root, annotation_path, split="train"),
            "test": OCRVQADataset(image_root, annotation_path, split="test"),
            "generation": OCRVQADatasetForGeneration(
                image_root, annotation_path, split="test"
            ),
        }

    if dataset_name == "CHEM":
        import os.path as osp

        from .CHEM import CHEMDataset, CHEMDatasetForGeneration

        image_root = osp.join(base_path, "chem/images")
        data_path = osp.join(base_path, "chem/qwen_data")
        return {
            "train": CHEMDataset(image_root, data_path, split="train"),
            "test": CHEMDataset(image_root, data_path, split="test"),
            "generation": CHEMDatasetForGeneration(
                image_root, data_path, split="test"
            ),
        }

    # No family matched: hand back an empty mapping.
    return {}