cl-lmm/src/dataset_library/factory.py

52 lines
1.8 KiB
Python

from torch.utils.data import Dataset
from typing import Literal
def get_dataset(
    dataset_name, base_path="/home/zyy/research/accelerate/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
    """Build the train / test / generation datasets for a named benchmark.

    Args:
        dataset_name: Benchmark identifier. ``"OCR_VQA_200K"`` and ``"CHEM"``
            are the names handled here.
        base_path: Root directory under which each benchmark's files are
            looked up.

    Returns:
        For a handled name, a dict with ``"train"``, ``"test"`` and
        ``"generation"`` entries; the ``"generation"`` entry is built with
        ``split="test"``. For any other name, an empty dict.
    """
    import os.path as osp

    if dataset_name == "OCR_VQA_200K":
        # Imported inside the branch: the module is only loaded when this
        # dataset is actually requested.
        from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration

        image_dir = osp.join(base_path, "OCR-VQA-200K/images")
        data_path = osp.join(base_path, "OCR-VQA-200K/dataset.json")
        return {
            "train": OCRVQADataset(image_dir, data_path, split="train"),
            "test": OCRVQADataset(image_dir, data_path, split="test"),
            "generation": OCRVQADatasetForGeneration(
                image_dir, data_path, split="test"
            ),
        }

    if dataset_name == "CHEM":
        # Same lazy-import arrangement as the branch above.
        from .CHEM import CHEMDataset, CHEMDatasetForGeneration

        image_dir = osp.join(base_path, "chem/images")
        data_path = osp.join(base_path, "chem/qwen_data")
        return {
            "train": CHEMDataset(image_dir, data_path, split="train"),
            "test": CHEMDataset(image_dir, data_path, split="test"),
            "generation": CHEMDatasetForGeneration(
                image_dir, data_path, split="test"
            ),
        }

    # Unrecognised name: nothing is constructed.
    return {}