feat✨: 添加多个数据集支持,优化数据加载逻辑,更新数据集注册机制,调整VSCode设置
This commit is contained in:
parent
c100a59f0e
commit
1d2e7b9dcd
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
@ -4,8 +4,11 @@
|
||||
"src/peft_repo/src/"
|
||||
],
|
||||
"python.analysis.exclude": [
|
||||
"dataset/**/*"
|
||||
"dataset/**/*",
|
||||
],
|
||||
"python.languageServer": "Default",
|
||||
"python.terminal.activateEnvInCurrentTerminal": true
|
||||
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||
"python.analysis.include": [
|
||||
"src/**/*"
|
||||
]
|
||||
}
|
@ -18,9 +18,13 @@ class GigaspeechDataset(Dataset):
|
||||
|
||||
self.audio_processor = audio_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = gs[split] # type: ignore
|
||||
|
||||
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = gs[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -55,7 +59,7 @@ class GigaspeechDataset(Dataset):
|
||||
audios=[(audio, sampling_rate)],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||
@ -88,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||
chat=chat,
|
||||
answer=text,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_gigaspeech():
|
||||
@ -105,5 +109,20 @@ def test_gigaspeech():
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
from .factory import register_dataset
|
||||
|
||||
dataset = {
|
||||
"train": GigaspeechDataset(split="train"),
|
||||
"test": GigaspeechDataset(split="test"),
|
||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
register_dataset(
|
||||
dataset_name="gigaspeech",
|
||||
dataset=dataset,
|
||||
tag=["audio", "text"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gigaspeech()
|
||||
|
@ -28,10 +28,14 @@ class OCRVQADataset(Dataset):
|
||||
vis_processor if vis_processor is not None else size_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
|
||||
elif split == "test":
|
||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
|
||||
self.split = split
|
||||
self.ann_path = ann_path
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
@ -101,7 +105,7 @@ class OCRVQADataset(Dataset):
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
@ -137,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
from .factory import register_dataset
|
||||
from .format import dataset_dir as base_path
|
||||
|
||||
dataset = {
|
||||
"train": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="train",
|
||||
),
|
||||
"test": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
"generation": OCRVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
|
@ -24,9 +24,11 @@ class RefCOCODataset(Dataset):
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[split] # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -116,6 +118,18 @@ def test_RefCOCO():
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCODataset(split="val"),
|
||||
"test": RefCOCODataset(split="test"),
|
||||
"generation": RefCOCODatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
register_dataset(
|
||||
dataset_name="refcoco",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCO()
|
||||
|
@ -24,9 +24,12 @@ class RefCOCOplusDataset(Dataset):
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[split] # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -116,6 +119,18 @@ def test_RefCOCOplus():
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOplusDataset(split="val"),
|
||||
"test": RefCOCOplusDataset(split="testA"),
|
||||
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
register_dataset(
|
||||
dataset_name="refcocoplus",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCOplus()
|
||||
|
@ -24,9 +24,12 @@ class RefCOCOgDataset(Dataset):
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[split] # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -116,6 +119,18 @@ def test_RefCOCOg():
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOgDataset(split="val"),
|
||||
"test": RefCOCOgDataset(split="test"),
|
||||
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
register_dataset(
|
||||
dataset_name="refcocog",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCOg()
|
||||
|
@ -18,9 +18,12 @@ class ScienceQADataset(Dataset):
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[split] # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -116,6 +119,18 @@ def test_scienceQA():
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
dataset = {
|
||||
"train": ScienceQADataset(split="train"),
|
||||
"test": ScienceQADataset(split="test"),
|
||||
"generation": ScienceQADatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
register_dataset(
|
||||
dataset_name="scienceqa",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_scienceQA()
|
||||
|
@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
|
||||
vis_processor if vis_processor is not None else self._vis_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.split = split
|
||||
self.ann_path = ann_path
|
||||
self.vis_root = vis_root
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(
|
||||
Path(ann_path, "TextVQA_0.5.1_train.json"),
|
||||
vis_root=Path(vis_root, "train_images"),
|
||||
Path(self.ann_path, "TextVQA_0.5.1_train.json"),
|
||||
vis_root=Path(self.vis_root, "train_images"),
|
||||
)
|
||||
elif split == "test":
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(
|
||||
Path(ann_path, "TextVQA_0.5.1_val.json"),
|
||||
vis_root=Path(vis_root, "train_images"),
|
||||
Path(self.ann_path, "TextVQA_0.5.1_val.json"),
|
||||
vis_root=Path(self.vis_root, "train_images"),
|
||||
)
|
||||
|
||||
# self.instruction_pool = [
|
||||
@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class TextVQADatasetForGeneration(TextVQADataset):
|
||||
@ -158,5 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
from .format import dataset_dir
|
||||
|
||||
dataset = {
|
||||
"train": TextVQADataset(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="train",
|
||||
),
|
||||
"test": TextVQADataset(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
"generation": TextVQADatasetForGeneration(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset("textvqa", dataset, tag=["text", "image"])
|
||||
|
@ -21,6 +21,7 @@ class VizWizDataset(Dataset):
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
self.ann_path = ann_path
|
||||
|
||||
from .vis_processor import size_processor
|
||||
|
||||
@ -28,10 +29,13 @@ class VizWizDataset(Dataset):
|
||||
vis_processor if vis_processor is not None else size_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.data = self.create_data(Path(ann_path, "train.json"))
|
||||
elif split == "test":
|
||||
self.data = self.create_data(Path(ann_path, "val.json"))
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(Path(self.ann_path, "train.json"))
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(Path(self.ann_path, "val.json"))
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
@ -43,7 +47,10 @@ class VizWizDataset(Dataset):
|
||||
with open(ann_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for i in range(len(data)):
|
||||
if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]:
|
||||
if (
|
||||
os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
|
||||
and data[i]["answerable"]
|
||||
):
|
||||
imageFile = data[i]["image"]
|
||||
processed_data.append(
|
||||
{
|
||||
@ -93,7 +100,7 @@ class VizWizDataset(Dataset):
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class VizWizDatasetForGeneration(VizWizDataset):
|
||||
@ -129,4 +136,33 @@ class VizWizDatasetForGeneration(VizWizDataset):
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
) # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
from .format import dataset_dir
|
||||
|
||||
dataset = {
|
||||
"train": VizWizDataset(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="train",
|
||||
),
|
||||
"test": VizWizDataset(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
"generation": VizWizDatasetForGeneration(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="vizwiz",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
@ -1,156 +1,79 @@
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Literal
|
||||
from typing import Literal,List
|
||||
from pathlib import Path
|
||||
from dataset_library.format import dataset_dir
|
||||
|
||||
MAPPING_NAME_TO_DATASET: dict[
|
||||
str, dict[Literal["train", "test", "generation"], Dataset]
|
||||
] = {}
|
||||
|
||||
|
||||
def register_dataset(
|
||||
dataset_name: str,
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset],
|
||||
tag: List[Literal["image", "text", "audio", "video"]] = [],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Register a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (`str`):
|
||||
The name of the dataset.
|
||||
dataset (`Dataset`):
|
||||
The dataset to register.
|
||||
"""
|
||||
dataset_name = dataset_name.lower()
|
||||
if dataset_name in MAPPING_NAME_TO_DATASET:
|
||||
raise ValueError(f"Dataset {dataset_name} already registered.")
|
||||
MAPPING_NAME_TO_DATASET[dataset_name] = dataset
|
||||
|
||||
|
||||
from .GigaspeechDataset import (
|
||||
GigaspeechDataset,
|
||||
GigaspeechDatasetForGeneration,
|
||||
)
|
||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||
from .ScienceQADataset import (
|
||||
ScienceQADataset,
|
||||
ScienceQADatasetForGeneration,
|
||||
)
|
||||
from .RefCOCODataset import (
|
||||
RefCOCODataset,
|
||||
RefCOCODatasetForGeneration,
|
||||
)
|
||||
from .RefCOCOgDataset import (
|
||||
RefCOCOgDataset,
|
||||
RefCOCOgDatasetForGeneration,
|
||||
)
|
||||
from .RefCOCOPlusDataset import (
|
||||
RefCOCOplusDataset,
|
||||
RefCOCOplusDatasetForGeneration,
|
||||
)
|
||||
from .VizWizDataset import (
|
||||
VizWizDataset,
|
||||
VizWizDatasetForGeneration,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def get_dataset(
|
||||
dataset_name, base_path=dataset_dir
|
||||
dataset_name: str, base_path=dataset_dir
|
||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||
match dataset_name:
|
||||
case "ocrvqa200k":
|
||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
"""
|
||||
Get a dataset by name.
|
||||
|
||||
dataset = {
|
||||
"train": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="train",
|
||||
),
|
||||
"test": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
"generation": OCRVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
case "chem":
|
||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": ChemDataset(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="train",
|
||||
),
|
||||
"test": ChemDataset(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="test",
|
||||
),
|
||||
"generation": ChemDatasetForGeneration(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
case "gigaspeech":
|
||||
from .GigaspeechDataset import (
|
||||
GigaspeechDataset,
|
||||
GigaspeechDatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": GigaspeechDataset(split="train"),
|
||||
"test": GigaspeechDataset(split="test"),
|
||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
case "textvqa":
|
||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": TextVQADataset(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="train",
|
||||
),
|
||||
"test": TextVQADataset(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
"generation": TextVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
case "scienceqa":
|
||||
from .ScienceQADataset import (
|
||||
ScienceQADataset,
|
||||
ScienceQADatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": ScienceQADataset(split="train"),
|
||||
"test": ScienceQADataset(split="test"),
|
||||
"generation": ScienceQADatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
case "refcoco":
|
||||
from .RefCOCODataset import (
|
||||
RefCOCODataset,
|
||||
RefCOCODatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCODataset(split="val"),
|
||||
"test": RefCOCODataset(split="test"),
|
||||
"generation": RefCOCODatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
case "refcocog":
|
||||
from .RefCOCOgDataset import (
|
||||
RefCOCOgDataset,
|
||||
RefCOCOgDatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOgDataset(split="val"),
|
||||
"test": RefCOCOgDataset(split="test"),
|
||||
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
case "refcocoplus":
|
||||
from .RefCOCOPlusDataset import (
|
||||
RefCOCOplusDataset,
|
||||
RefCOCOplusDatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOplusDataset(split="val"),
|
||||
"test": RefCOCOplusDataset(split="testA"),
|
||||
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||
}
|
||||
|
||||
case "vizwiz":
|
||||
from .VizWizDataset import (
|
||||
VizWizDataset,
|
||||
VizWizDatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": VizWizDataset(
|
||||
vis_root=Path(base_path, "vizwiz", "images", "train"),
|
||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||
split="train",
|
||||
),
|
||||
"test": VizWizDataset(
|
||||
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
"generation": VizWizDatasetForGeneration(
|
||||
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
return dataset
|
||||
Args:
|
||||
dataset_name (`str`):
|
||||
The name of the dataset.
|
||||
base_path (`str`):
|
||||
The base path to the dataset.
|
||||
"""
|
||||
dataset_name = dataset_name.lower()
|
||||
if dataset_name not in MAPPING_NAME_TO_DATASET:
|
||||
raise ValueError(f"Dataset {dataset_name} not registered.")
|
||||
for key in MAPPING_NAME_TO_DATASET[dataset_name]:
|
||||
MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
|
||||
return MAPPING_NAME_TO_DATASET[dataset_name]
|
||||
|
@ -28,8 +28,8 @@ class Conversation(TypedDict):
|
||||
|
||||
|
||||
class DatasetOutput(TypedDict):
|
||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
||||
chat: list[Conversation]
|
||||
answer: Optional[str]
|
||||
original: Any
|
||||
images: Optional[list[Image.Image]]
|
||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
@ -1,82 +1,21 @@
|
||||
from .factory import get_dataset
|
||||
import pytest
|
||||
from .factory import get_dataset, MAPPING_NAME_TO_DATASET
|
||||
|
||||
# Get all registered dataset names for parameterization
|
||||
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
|
||||
|
||||
def test_gigaspeech():
|
||||
dataset = get_dataset("gigaspeech")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
@pytest.mark.parametrize("dataset_name", dataset_names)
|
||||
def test_registered_datasets(dataset_name):
|
||||
dataset = get_dataset(dataset_name)
|
||||
|
||||
# Test train split
|
||||
assert "train" in dataset, f"Train split not found in {dataset_name}"
|
||||
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
|
||||
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_chem():
|
||||
dataset = get_dataset("chem")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_ocrvqa200k():
|
||||
dataset = get_dataset("ocrvqa200k")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_textvqa():
|
||||
dataset = get_dataset("textvqa")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_scienceqa():
|
||||
dataset = get_dataset("scienceqa")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_refcoco():
|
||||
dataset = get_dataset("refcoco")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_refcocog():
|
||||
dataset = get_dataset("refcocog")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_refcocoplus():
|
||||
dataset = get_dataset("refcocoplus")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_VizWiz():
|
||||
dataset = get_dataset("vizwiz")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
# Test test split
|
||||
assert "test" in dataset, f"Test split not found in {dataset_name}"
|
||||
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
|
||||
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
|
||||
|
@ -36,7 +36,7 @@ class ContinualTrainer(Trainer):
|
||||
# fisher = t
|
||||
|
||||
if regularization_args.lwf_enable:
|
||||
pass
|
||||
self.lwf_lambda = regularization_args.lwf_lambda
|
||||
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
|
Loading…
Reference in New Issue
Block a user