52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
from torch.utils.data import Dataset
|
|
from typing import Literal
|
|
|
|
|
|
def get_dataset(
    dataset_name, base_path="/home/zyy/research/accelerate/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
    """Build the train/test/generation datasets for a named dataset.

    Args:
        dataset_name: Which dataset to load. Supported values are
            ``"OCR_VQA_200K"`` and ``"CHEM"``.
        base_path: Root directory containing the dataset folders. The default
            is a machine-specific absolute path; pass your own on other hosts.

    Returns:
        A dict with keys ``"train"``, ``"test"`` and ``"generation"``. The
        ``"test"`` and ``"generation"`` entries are both built from the
        ``"test"`` split; ``"generation"`` uses the dataset's
        ``...ForGeneration`` variant.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported values.
            (Previously an empty dict was returned, which only failed later
            with an opaque ``KeyError`` when a split was looked up.)
    """
    import os.path as osp

    # Dataset modules are imported lazily so that only the selected dataset's
    # module (and whatever it depends on) is loaded.
    if dataset_name == "OCR_VQA_200K":
        from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration

        dataset_cls, generation_cls = OCRVQADataset, OCRVQADatasetForGeneration
        image_root = osp.join(base_path, "OCR-VQA-200K/images")
        annotation_path = osp.join(base_path, "OCR-VQA-200K/dataset.json")
    elif dataset_name == "CHEM":
        from .CHEM import CHEMDataset, CHEMDatasetForGeneration

        dataset_cls, generation_cls = CHEMDataset, CHEMDatasetForGeneration
        image_root = osp.join(base_path, "chem/images")
        annotation_path = osp.join(base_path, "chem/qwen_data")
    else:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; "
            "expected one of 'OCR_VQA_200K', 'CHEM'"
        )

    # Built in the same order as before: train, test, generation.
    dataset: dict[Literal["train", "test", "generation"], Dataset] = {
        "train": dataset_cls(image_root, annotation_path, split="train"),
        "test": dataset_cls(image_root, annotation_path, split="test"),
        "generation": generation_cls(image_root, annotation_path, split="test"),
    }
    return dataset
|