cl-lmm/src/dataset_library/factory.py

from torch.utils.data import Dataset
from typing import Literal
from pathlib import Path
from dataset_library.format import dataset_dir


def get_dataset(
    dataset_name: str, base_path=dataset_dir
) -> dict[Literal["train", "test", "generation"], Dataset]:
    """Return the "train", "test", and "generation" splits for ``dataset_name``."""
    dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
    match dataset_name:
        case "ocrvqa200k":
            from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration

            dataset = {
                "train": OCRVQADataset(
                    vis_root=Path(base_path, "OCR-VQA-200K", "images"),
                    ann_path=Path(base_path, "OCR-VQA-200K"),
                    split="train",
                ),
                "test": OCRVQADataset(
                    vis_root=Path(base_path, "OCR-VQA-200K", "images"),
                    ann_path=Path(base_path, "OCR-VQA-200K"),
                    split="test",
                ),
                "generation": OCRVQADatasetForGeneration(
                    vis_root=Path(base_path, "OCR-VQA-200K", "images"),
                    ann_path=Path(base_path, "OCR-VQA-200K"),
                    split="test",
                ),
            }
        case "chem":
            from .ChemDataset import ChemDataset, ChemDatasetForGeneration

            dataset = {
                "train": ChemDataset(
                    vis_root=Path(base_path, "chem", "images"),
                    ann_path=Path(base_path, "chem"),
                    split="train",
                ),
                "test": ChemDataset(
                    vis_root=Path(base_path, "chem", "images"),
                    ann_path=Path(base_path, "chem"),
                    split="test",
                ),
                "generation": ChemDatasetForGeneration(
                    vis_root=Path(base_path, "chem", "images"),
                    ann_path=Path(base_path, "chem"),
                    split="test",
                ),
            }
        case "gigaspeech":
            from .GigaspeechDataset import (
                GigaspeechDataset,
                GigaspeechDatasetForGeneration,
            )

            dataset = {
                "train": GigaspeechDataset(split="train"),
                "test": GigaspeechDataset(split="test"),
                "generation": GigaspeechDatasetForGeneration(split="test"),
            }
        case "textvqa":
            from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration

            dataset = {
                "train": TextVQADataset(
                    vis_root=Path(base_path, "TextVQA", "images"),
                    ann_path=Path(base_path, "TextVQA"),
                    split="train",
                ),
                "test": TextVQADataset(
                    vis_root=Path(base_path, "TextVQA", "images"),
                    ann_path=Path(base_path, "TextVQA"),
                    split="test",
                ),
                "generation": TextVQADatasetForGeneration(
                    vis_root=Path(base_path, "TextVQA", "images"),
                    ann_path=Path(base_path, "TextVQA"),
                    split="test",
                ),
            }
        case "scienceqa":
            from .ScienceQADataset import (
                ScienceQADataset,
                ScienceQADatasetForGeneration,
            )

            dataset = {
                "train": ScienceQADataset(split="train"),
                "test": ScienceQADataset(split="test"),
                "generation": ScienceQADatasetForGeneration(split="test"),
            }
        case "refcoco":
            from .RefCOCODataset import (
                RefCOCODataset,
                RefCOCODatasetForGeneration,
            )

            dataset = {
                "train": RefCOCODataset(split="val"),
                "test": RefCOCODataset(split="test"),
                "generation": RefCOCODatasetForGeneration(split="test"),
            }
        case "refcocog":
            from .RefCOCOgDataset import (
                RefCOCOgDataset,
                RefCOCOgDatasetForGeneration,
            )

            dataset = {
                "train": RefCOCOgDataset(split="val"),
                "test": RefCOCOgDataset(split="test"),
                "generation": RefCOCOgDatasetForGeneration(split="test"),
            }
        case "refcocoplus":
            from .RefCOCOPlusDataset import (
                RefCOCOplusDataset,
                RefCOCOplusDatasetForGeneration,
            )

            dataset = {
                "train": RefCOCOplusDataset(split="val"),
                "test": RefCOCOplusDataset(split="testA"),
                "generation": RefCOCOplusDatasetForGeneration(split="testA"),
            }
    return dataset
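

# Minimal usage sketch (not part of the original module): it assumes the
# OCR-VQA-200K data has been placed under `dataset_dir` in the layout the
# paths above expect, and that each split is a map-style Dataset with __len__.
if __name__ == "__main__":
    splits = get_dataset("ocrvqa200k")
    print({name: len(ds) for name, ds in splits.items()})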