feat✨: 添加多个数据集支持,优化数据加载逻辑,更新数据集注册机制,调整VSCode设置
This commit is contained in:
parent
c100a59f0e
commit
1d2e7b9dcd
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
@ -4,8 +4,11 @@
|
|||||||
"src/peft_repo/src/"
|
"src/peft_repo/src/"
|
||||||
],
|
],
|
||||||
"python.analysis.exclude": [
|
"python.analysis.exclude": [
|
||||||
"dataset/**/*"
|
"dataset/**/*",
|
||||||
],
|
],
|
||||||
"python.languageServer": "Default",
|
"python.languageServer": "Default",
|
||||||
"python.terminal.activateEnvInCurrentTerminal": true
|
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||||
|
"python.analysis.include": [
|
||||||
|
"src/**/*"
|
||||||
|
]
|
||||||
}
|
}
|
@ -18,9 +18,13 @@ class GigaspeechDataset(Dataset):
|
|||||||
|
|
||||||
self.audio_processor = audio_processor
|
self.audio_processor = audio_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
from .format import dataset_dir
|
from .format import dataset_dir
|
||||||
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
|
|
||||||
self.data = gs[split] # type: ignore
|
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
|
||||||
|
self.data = gs[self.split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -55,7 +59,7 @@ class GigaspeechDataset(Dataset):
|
|||||||
audios=[(audio, sampling_rate)],
|
audios=[(audio, sampling_rate)],
|
||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample,
|
original=sample,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||||
@ -88,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=text,
|
answer=text,
|
||||||
original=sample,
|
original=sample,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_gigaspeech():
|
def test_gigaspeech():
|
||||||
@ -105,5 +109,20 @@ def test_gigaspeech():
|
|||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
assert len(dataset[0]["chat"]) > 0
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": GigaspeechDataset(split="train"),
|
||||||
|
"test": GigaspeechDataset(split="test"),
|
||||||
|
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="gigaspeech",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["audio", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_gigaspeech()
|
test_gigaspeech()
|
||||||
|
@ -28,10 +28,14 @@ class OCRVQADataset(Dataset):
|
|||||||
vis_processor if vis_processor is not None else size_processor
|
vis_processor if vis_processor is not None else size_processor
|
||||||
)
|
)
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
if split == "train":
|
self.split = split
|
||||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
|
self.ann_path = ann_path
|
||||||
elif split == "test":
|
|
||||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
|
def load_data(self):
|
||||||
|
if self.split == "train":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
|
||||||
|
elif self.split == "test":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
|
||||||
|
|
||||||
# self.instruction_pool = [
|
# self.instruction_pool = [
|
||||||
# "[vqa] {}",
|
# "[vqa] {}",
|
||||||
@ -101,7 +105,7 @@ class OCRVQADataset(Dataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
images=[image],
|
images=[image],
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||||
@ -137,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
from .factory import register_dataset
|
||||||
|
from .format import dataset_dir as base_path
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": OCRVQADataset(
|
||||||
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": OCRVQADataset(
|
||||||
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": OCRVQADatasetForGeneration(
|
||||||
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
|
@ -24,9 +24,11 @@ class RefCOCODataset(Dataset):
|
|||||||
|
|
||||||
self.vis_processor = vis_processor
|
self.vis_processor = vis_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
def load_data(self):
|
||||||
from .format import dataset_dir
|
from .format import dataset_dir
|
||||||
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
|
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
|
||||||
self.data = ds[split] # type: ignore
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -116,6 +118,18 @@ def test_RefCOCO():
|
|||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
assert len(dataset[0]["chat"]) > 0
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": RefCOCODataset(split="val"),
|
||||||
|
"test": RefCOCODataset(split="test"),
|
||||||
|
"generation": RefCOCODatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="refcoco",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_RefCOCO()
|
test_RefCOCO()
|
||||||
|
@ -24,9 +24,12 @@ class RefCOCOplusDataset(Dataset):
|
|||||||
|
|
||||||
self.vis_processor = vis_processor
|
self.vis_processor = vis_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
from .format import dataset_dir
|
from .format import dataset_dir
|
||||||
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
|
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
|
||||||
self.data = ds[split] # type: ignore
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -116,6 +119,18 @@ def test_RefCOCOplus():
|
|||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
assert len(dataset[0]["chat"]) > 0
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": RefCOCOplusDataset(split="val"),
|
||||||
|
"test": RefCOCOplusDataset(split="testA"),
|
||||||
|
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="refcocoplus",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_RefCOCOplus()
|
test_RefCOCOplus()
|
||||||
|
@ -24,9 +24,12 @@ class RefCOCOgDataset(Dataset):
|
|||||||
|
|
||||||
self.vis_processor = vis_processor
|
self.vis_processor = vis_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
from .format import dataset_dir
|
from .format import dataset_dir
|
||||||
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
|
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
|
||||||
self.data = ds[split] # type: ignore
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -116,6 +119,18 @@ def test_RefCOCOg():
|
|||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
assert len(dataset[0]["chat"]) > 0
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": RefCOCOgDataset(split="val"),
|
||||||
|
"test": RefCOCOgDataset(split="test"),
|
||||||
|
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="refcocog",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_RefCOCOg()
|
test_RefCOCOg()
|
||||||
|
@ -18,9 +18,12 @@ class ScienceQADataset(Dataset):
|
|||||||
|
|
||||||
self.vis_processor = vis_processor
|
self.vis_processor = vis_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
from .format import dataset_dir
|
from .format import dataset_dir
|
||||||
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
|
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
|
||||||
self.data = ds[split] # type: ignore
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -116,6 +119,18 @@ def test_scienceQA():
|
|||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
assert len(dataset[0]["chat"]) > 0
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": ScienceQADataset(split="train"),
|
||||||
|
"test": ScienceQADataset(split="test"),
|
||||||
|
"generation": ScienceQADatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="scienceqa",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_scienceQA()
|
test_scienceQA()
|
||||||
|
@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
|
|||||||
vis_processor if vis_processor is not None else self._vis_processor
|
vis_processor if vis_processor is not None else self._vis_processor
|
||||||
)
|
)
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
if split == "train":
|
self.split = split
|
||||||
|
self.ann_path = ann_path
|
||||||
|
self.vis_root = vis_root
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
if self.split == "train":
|
||||||
self.data = self.create_data(
|
self.data = self.create_data(
|
||||||
Path(ann_path, "TextVQA_0.5.1_train.json"),
|
Path(self.ann_path, "TextVQA_0.5.1_train.json"),
|
||||||
vis_root=Path(vis_root, "train_images"),
|
vis_root=Path(self.vis_root, "train_images"),
|
||||||
)
|
)
|
||||||
elif split == "test":
|
elif self.split == "test":
|
||||||
self.data = self.create_data(
|
self.data = self.create_data(
|
||||||
Path(ann_path, "TextVQA_0.5.1_val.json"),
|
Path(self.ann_path, "TextVQA_0.5.1_val.json"),
|
||||||
vis_root=Path(vis_root, "train_images"),
|
vis_root=Path(self.vis_root, "train_images"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.instruction_pool = [
|
# self.instruction_pool = [
|
||||||
@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
images=[image],
|
images=[image],
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class TextVQADatasetForGeneration(TextVQADataset):
|
class TextVQADatasetForGeneration(TextVQADataset):
|
||||||
@ -158,5 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": TextVQADataset(
|
||||||
|
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||||
|
ann_path=Path(dataset_dir, "TextVQA"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": TextVQADataset(
|
||||||
|
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||||
|
ann_path=Path(dataset_dir, "TextVQA"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": TextVQADatasetForGeneration(
|
||||||
|
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||||
|
ann_path=Path(dataset_dir, "TextVQA"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
register_dataset("textvqa", dataset, tag=["text", "image"])
|
||||||
|
@ -21,6 +21,7 @@ class VizWizDataset(Dataset):
|
|||||||
ann_root (string): directory to store the annotation file
|
ann_root (string): directory to store the annotation file
|
||||||
"""
|
"""
|
||||||
self.vis_root = vis_root
|
self.vis_root = vis_root
|
||||||
|
self.ann_path = ann_path
|
||||||
|
|
||||||
from .vis_processor import size_processor
|
from .vis_processor import size_processor
|
||||||
|
|
||||||
@ -28,10 +29,13 @@ class VizWizDataset(Dataset):
|
|||||||
vis_processor if vis_processor is not None else size_processor
|
vis_processor if vis_processor is not None else size_processor
|
||||||
)
|
)
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
if split == "train":
|
self.split = split
|
||||||
self.data = self.create_data(Path(ann_path, "train.json"))
|
|
||||||
elif split == "test":
|
def load_data(self):
|
||||||
self.data = self.create_data(Path(ann_path, "val.json"))
|
if self.split == "train":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "train.json"))
|
||||||
|
elif self.split == "test":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "val.json"))
|
||||||
|
|
||||||
# self.instruction_pool = [
|
# self.instruction_pool = [
|
||||||
# "[vqa] {}",
|
# "[vqa] {}",
|
||||||
@ -43,7 +47,10 @@ class VizWizDataset(Dataset):
|
|||||||
with open(ann_path, "r") as f:
|
with open(ann_path, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
for i in range(len(data)):
|
for i in range(len(data)):
|
||||||
if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]:
|
if (
|
||||||
|
os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
|
||||||
|
and data[i]["answerable"]
|
||||||
|
):
|
||||||
imageFile = data[i]["image"]
|
imageFile = data[i]["image"]
|
||||||
processed_data.append(
|
processed_data.append(
|
||||||
{
|
{
|
||||||
@ -93,7 +100,7 @@ class VizWizDataset(Dataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
images=[image],
|
images=[image],
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class VizWizDatasetForGeneration(VizWizDataset):
|
class VizWizDatasetForGeneration(VizWizDataset):
|
||||||
@ -129,4 +136,33 @@ class VizWizDatasetForGeneration(VizWizDataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": VizWizDataset(
|
||||||
|
vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
|
||||||
|
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": VizWizDataset(
|
||||||
|
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||||
|
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": VizWizDatasetForGeneration(
|
||||||
|
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||||
|
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="vizwiz",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
@ -1,156 +1,79 @@
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Literal
|
from typing import Literal,List
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dataset_library.format import dataset_dir
|
from dataset_library.format import dataset_dir
|
||||||
|
|
||||||
|
MAPPING_NAME_TO_DATASET: dict[
|
||||||
|
str, dict[Literal["train", "test", "generation"], Dataset]
|
||||||
|
] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_dataset(
|
||||||
|
dataset_name: str,
|
||||||
|
dataset: dict[Literal["train", "test", "generation"], Dataset],
|
||||||
|
tag: List[Literal["image", "text", "audio", "video"]] = [],
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_name (`str`):
|
||||||
|
The name of the dataset.
|
||||||
|
dataset (`Dataset`):
|
||||||
|
The dataset to register.
|
||||||
|
"""
|
||||||
|
dataset_name = dataset_name.lower()
|
||||||
|
if dataset_name in MAPPING_NAME_TO_DATASET:
|
||||||
|
raise ValueError(f"Dataset {dataset_name} already registered.")
|
||||||
|
MAPPING_NAME_TO_DATASET[dataset_name] = dataset
|
||||||
|
|
||||||
|
|
||||||
|
from .GigaspeechDataset import (
|
||||||
|
GigaspeechDataset,
|
||||||
|
GigaspeechDatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||||
|
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||||
|
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||||
|
from .ScienceQADataset import (
|
||||||
|
ScienceQADataset,
|
||||||
|
ScienceQADatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .RefCOCODataset import (
|
||||||
|
RefCOCODataset,
|
||||||
|
RefCOCODatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .RefCOCOgDataset import (
|
||||||
|
RefCOCOgDataset,
|
||||||
|
RefCOCOgDatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .RefCOCOPlusDataset import (
|
||||||
|
RefCOCOplusDataset,
|
||||||
|
RefCOCOplusDatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .VizWizDataset import (
|
||||||
|
VizWizDataset,
|
||||||
|
VizWizDatasetForGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
dataset_name, base_path=dataset_dir
|
dataset_name: str, base_path=dataset_dir
|
||||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
"""
|
||||||
match dataset_name:
|
Get a dataset by name.
|
||||||
case "ocrvqa200k":
|
|
||||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
Args:
|
||||||
"train": OCRVQADataset(
|
dataset_name (`str`):
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
The name of the dataset.
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
base_path (`str`):
|
||||||
split="train",
|
The base path to the dataset.
|
||||||
),
|
"""
|
||||||
"test": OCRVQADataset(
|
dataset_name = dataset_name.lower()
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
if dataset_name not in MAPPING_NAME_TO_DATASET:
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
raise ValueError(f"Dataset {dataset_name} not registered.")
|
||||||
split="test",
|
for key in MAPPING_NAME_TO_DATASET[dataset_name]:
|
||||||
),
|
MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
|
||||||
"generation": OCRVQADatasetForGeneration(
|
return MAPPING_NAME_TO_DATASET[dataset_name]
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
case "chem":
|
|
||||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": ChemDataset(
|
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
|
||||||
ann_path=Path(base_path, "chem"),
|
|
||||||
split="train",
|
|
||||||
),
|
|
||||||
"test": ChemDataset(
|
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
|
||||||
ann_path=Path(base_path, "chem"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
"generation": ChemDatasetForGeneration(
|
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
|
||||||
ann_path=Path(base_path, "chem"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
case "gigaspeech":
|
|
||||||
from .GigaspeechDataset import (
|
|
||||||
GigaspeechDataset,
|
|
||||||
GigaspeechDatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": GigaspeechDataset(split="train"),
|
|
||||||
"test": GigaspeechDataset(split="test"),
|
|
||||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
|
||||||
}
|
|
||||||
|
|
||||||
case "textvqa":
|
|
||||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": TextVQADataset(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="train",
|
|
||||||
),
|
|
||||||
"test": TextVQADataset(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
"generation": TextVQADatasetForGeneration(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
case "scienceqa":
|
|
||||||
from .ScienceQADataset import (
|
|
||||||
ScienceQADataset,
|
|
||||||
ScienceQADatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": ScienceQADataset(split="train"),
|
|
||||||
"test": ScienceQADataset(split="test"),
|
|
||||||
"generation": ScienceQADatasetForGeneration(split="test"),
|
|
||||||
}
|
|
||||||
|
|
||||||
case "refcoco":
|
|
||||||
from .RefCOCODataset import (
|
|
||||||
RefCOCODataset,
|
|
||||||
RefCOCODatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": RefCOCODataset(split="val"),
|
|
||||||
"test": RefCOCODataset(split="test"),
|
|
||||||
"generation": RefCOCODatasetForGeneration(split="test"),
|
|
||||||
}
|
|
||||||
|
|
||||||
case "refcocog":
|
|
||||||
from .RefCOCOgDataset import (
|
|
||||||
RefCOCOgDataset,
|
|
||||||
RefCOCOgDatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": RefCOCOgDataset(split="val"),
|
|
||||||
"test": RefCOCOgDataset(split="test"),
|
|
||||||
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
|
||||||
}
|
|
||||||
|
|
||||||
case "refcocoplus":
|
|
||||||
from .RefCOCOPlusDataset import (
|
|
||||||
RefCOCOplusDataset,
|
|
||||||
RefCOCOplusDatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": RefCOCOplusDataset(split="val"),
|
|
||||||
"test": RefCOCOplusDataset(split="testA"),
|
|
||||||
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
|
||||||
}
|
|
||||||
|
|
||||||
case "vizwiz":
|
|
||||||
from .VizWizDataset import (
|
|
||||||
VizWizDataset,
|
|
||||||
VizWizDatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": VizWizDataset(
|
|
||||||
vis_root=Path(base_path, "vizwiz", "images", "train"),
|
|
||||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
|
||||||
split="train",
|
|
||||||
),
|
|
||||||
"test": VizWizDataset(
|
|
||||||
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
|
||||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
"generation": VizWizDatasetForGeneration(
|
|
||||||
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
|
||||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
@ -28,8 +28,8 @@ class Conversation(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class DatasetOutput(TypedDict):
|
class DatasetOutput(TypedDict):
|
||||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
|
||||||
chat: list[Conversation]
|
chat: list[Conversation]
|
||||||
answer: Optional[str]
|
answer: Optional[str]
|
||||||
original: Any
|
original: Any
|
||||||
images: Optional[list[Image.Image]]
|
images: Optional[list[Image.Image]]
|
||||||
|
audios: Optional[list[Tuple[np.ndarray, int]]]
|
@ -1,82 +1,21 @@
|
|||||||
from .factory import get_dataset
|
import pytest
|
||||||
|
from .factory import get_dataset, MAPPING_NAME_TO_DATASET
|
||||||
|
|
||||||
|
# Get all registered dataset names for parameterization
|
||||||
|
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
|
||||||
|
|
||||||
def test_gigaspeech():
|
@pytest.mark.parametrize("dataset_name", dataset_names)
|
||||||
dataset = get_dataset("gigaspeech")
|
def test_registered_datasets(dataset_name):
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
dataset = get_dataset(dataset_name)
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
# Test train split
|
||||||
|
assert "train" in dataset, f"Train split not found in {dataset_name}"
|
||||||
|
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
|
||||||
|
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
|
||||||
|
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
# Test test split
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
assert "test" in dataset, f"Test split not found in {dataset_name}"
|
||||||
|
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
|
||||||
|
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
|
||||||
def test_chem():
|
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
|
||||||
dataset = get_dataset("chem")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_ocrvqa200k():
|
|
||||||
dataset = get_dataset("ocrvqa200k")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_textvqa():
|
|
||||||
dataset = get_dataset("textvqa")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_scienceqa():
|
|
||||||
dataset = get_dataset("scienceqa")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_refcoco():
|
|
||||||
dataset = get_dataset("refcoco")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_refcocog():
|
|
||||||
dataset = get_dataset("refcocog")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_refcocoplus():
|
|
||||||
dataset = get_dataset("refcocoplus")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_VizWiz():
|
|
||||||
dataset = get_dataset("vizwiz")
|
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
@ -36,7 +36,7 @@ class ContinualTrainer(Trainer):
|
|||||||
# fisher = t
|
# fisher = t
|
||||||
|
|
||||||
if regularization_args.lwf_enable:
|
if regularization_args.lwf_enable:
|
||||||
pass
|
self.lwf_lambda = regularization_args.lwf_lambda
|
||||||
|
|
||||||
|
|
||||||
def create_accelerator_and_postprocess(self):
|
def create_accelerator_and_postprocess(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user