133 lines
4.7 KiB
Python
133 lines
4.7 KiB
Python
from torch.utils.data import Dataset
|
|
from typing import Literal
|
|
from pathlib import Path
|
|
from dataset_library.format import dataset_dir
|
|
|
|
|
|
def get_dataset(
|
|
dataset_name, base_path=dataset_dir
|
|
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
|
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
|
match dataset_name:
|
|
case "ocrvqa200k":
|
|
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
|
|
|
dataset = {
|
|
"train": OCRVQADataset(
|
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
|
split="train",
|
|
),
|
|
"test": OCRVQADataset(
|
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
|
split="test",
|
|
),
|
|
"generation": OCRVQADatasetForGeneration(
|
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
|
split="test",
|
|
),
|
|
}
|
|
case "chem":
|
|
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
|
|
|
dataset = {
|
|
"train": ChemDataset(
|
|
vis_root=Path(base_path, "chem", "images"),
|
|
ann_path=Path(base_path, "chem"),
|
|
split="train",
|
|
),
|
|
"test": ChemDataset(
|
|
vis_root=Path(base_path, "chem", "images"),
|
|
ann_path=Path(base_path, "chem"),
|
|
split="test",
|
|
),
|
|
"generation": ChemDatasetForGeneration(
|
|
vis_root=Path(base_path, "chem", "images"),
|
|
ann_path=Path(base_path, "chem"),
|
|
split="test",
|
|
),
|
|
}
|
|
case "gigaspeech":
|
|
from .GigaspeechDataset import (
|
|
GigaspeechDataset,
|
|
GigaspeechDatasetForGeneration,
|
|
)
|
|
|
|
dataset = {
|
|
"train": GigaspeechDataset(split="train"),
|
|
"test": GigaspeechDataset(split="test"),
|
|
"generation": GigaspeechDatasetForGeneration(split="test"),
|
|
}
|
|
|
|
case "textvqa":
|
|
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
|
|
|
dataset = {
|
|
"train": TextVQADataset(
|
|
vis_root=Path(base_path, "TextVQA", "images"),
|
|
ann_path=Path(base_path, "TextVQA"),
|
|
split="train",
|
|
),
|
|
"test": TextVQADataset(
|
|
vis_root=Path(base_path, "TextVQA", "images"),
|
|
ann_path=Path(base_path, "TextVQA"),
|
|
split="test",
|
|
),
|
|
"generation": TextVQADatasetForGeneration(
|
|
vis_root=Path(base_path, "TextVQA", "images"),
|
|
ann_path=Path(base_path, "TextVQA"),
|
|
split="test",
|
|
),
|
|
}
|
|
case "scienceqa":
|
|
from .ScienceQADataset import (
|
|
ScienceQADataset,
|
|
ScienceQADatasetForGeneration,
|
|
)
|
|
|
|
dataset = {
|
|
"train": ScienceQADataset(split="train"),
|
|
"test": ScienceQADataset(split="test"),
|
|
"generation": ScienceQADatasetForGeneration(split="test"),
|
|
}
|
|
|
|
case "refcoco":
|
|
from .RefCOCODataset import (
|
|
RefCOCODataset,
|
|
RefCOCODatasetForGeneration,
|
|
)
|
|
|
|
dataset = {
|
|
"train": RefCOCODataset(split="val"),
|
|
"test": RefCOCODataset(split="test"),
|
|
"generation": RefCOCODatasetForGeneration(split="test"),
|
|
}
|
|
|
|
case "refcocog":
|
|
from .RefCOCOgDataset import (
|
|
RefCOCOgDataset,
|
|
RefCOCOgDatasetForGeneration,
|
|
)
|
|
|
|
dataset = {
|
|
"train": RefCOCOgDataset(split="val"),
|
|
"test": RefCOCOgDataset(split="test"),
|
|
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
|
}
|
|
|
|
case "refcocoplus":
|
|
from .RefCOCOPlusDataset import (
|
|
RefCOCOplusDataset,
|
|
RefCOCOplusDatasetForGeneration,
|
|
)
|
|
|
|
dataset = {
|
|
"train": RefCOCOplusDataset(split="val"),
|
|
"test": RefCOCOplusDataset(split="testA"),
|
|
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
|
}
|
|
|
|
return dataset
|