feat: add support for multiple datasets, streamline data-loading logic, update the dataset registration mechanism, and adjust VSCode settings

YunyaoZhou 2025-05-19 13:27:35 +08:00
parent c100a59f0e
commit 1d2e7b9dcd
13 changed files with 296 additions and 260 deletions

View File

@@ -4,8 +4,11 @@
"src/peft_repo/src/"
],
"python.analysis.exclude": [
"dataset/**/*"
"dataset/**/*",
],
"python.languageServer": "Default",
"python.terminal.activateEnvInCurrentTerminal": true
"python.terminal.activateEnvInCurrentTerminal": true,
"python.analysis.include": [
"src/**/*"
]
}

View File

@@ -18,9 +18,13 @@ class GigaspeechDataset(Dataset):
self.audio_processor = audio_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
self.data = gs[split] # type: ignore
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
self.data = gs[self.split] # type: ignore
def __len__(self):
return len(self.data)
@@ -55,7 +59,7 @@ class GigaspeechDataset(Dataset):
audios=[(audio, sampling_rate)],
chat=chat,
original=sample,
) # type: ignore
class GigaspeechDatasetForGeneration(GigaspeechDataset):
@@ -88,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
chat=chat,
answer=text,
original=sample,
) # type: ignore
def test_gigaspeech():
@@ -105,5 +109,20 @@ def test_gigaspeech():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
from .factory import register_dataset
dataset = {
"train": GigaspeechDataset(split="train"),
"test": GigaspeechDataset(split="test"),
"generation": GigaspeechDatasetForGeneration(split="test"),
}
register_dataset(
dataset_name="gigaspeech",
dataset=dataset,
tag=["audio", "text"],
)
if __name__ == "__main__":
test_gigaspeech()
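Note on the pattern above: the expensive load_dataset call moves out of __init__ and into load_data(), so a dataset object can be constructed and registered at import time without touching the network or disk; the same refactor is applied to every dataset file below. A minimal sketch of the idea (illustrative only, not code from this commit):

from torch.utils.data import Dataset


class LazyDataset(Dataset):
    """Illustrative sketch of the lazy-loading pattern used in this commit."""

    def __init__(self, split: str = "train"):
        self.split = split  # cheap: just remember which split is wanted
        self.data = None  # nothing is downloaded or parsed yet

    def load_data(self):
        # The expensive work (e.g. load_dataset(...)) happens only when
        # the factory's get_dataset() calls this method.
        self.data = [{"chat": [f"sample {i}"]} for i in range(10)]

    def __len__(self):
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, idx):
        return self.data[idx]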

View File

@@ -28,10 +28,14 @@ class OCRVQADataset(Dataset):
vis_processor if vis_processor is not None else size_processor
)
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
elif split == "test":
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
self.split = split
self.ann_path = ann_path
def load_data(self):
if self.split == "train":
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
elif self.split == "test":
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
# self.instruction_pool = [
# "[vqa] {}",
@@ -101,7 +105,7 @@ class OCRVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
) # type: ignore
class OCRVQADatasetForGeneration(OCRVQADataset):
@@ -137,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
) # type: ignore
from .factory import register_dataset
from .format import dataset_dir as base_path
dataset = {
"train": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="train",
),
"test": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
"generation": OCRVQADatasetForGeneration(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
}
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])

View File

@@ -24,9 +24,11 @@ class RefCOCODataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
@@ -116,6 +118,18 @@ def test_RefCOCO():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCODataset(split="val"),
"test": RefCOCODataset(split="test"),
"generation": RefCOCODatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcoco",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_RefCOCO()

View File

@@ -24,9 +24,12 @@ class RefCOCOplusDataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
@@ -116,6 +119,18 @@ def test_RefCOCOplus():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCOplusDataset(split="val"),
"test": RefCOCOplusDataset(split="testA"),
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcocoplus",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_RefCOCOplus()

View File

@@ -24,9 +24,12 @@ class RefCOCOgDataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
@@ -116,6 +119,18 @@ def test_RefCOCOg():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCOgDataset(split="val"),
"test": RefCOCOgDataset(split="test"),
"generation": RefCOCOgDatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcocog",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_RefCOCOg()

View File

@@ -18,9 +18,12 @@ class ScienceQADataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
@@ -116,6 +119,18 @@ def test_scienceQA():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="scienceqa",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_scienceQA()

View File

@@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
vis_processor if vis_processor is not None else self._vis_processor
)
self.text_processor = text_processor
if split == "train":
self.split = split
self.ann_path = ann_path
self.vis_root = vis_root
def load_data(self):
if self.split == "train":
self.data = self.create_data(
Path(ann_path, "TextVQA_0.5.1_train.json"),
vis_root=Path(vis_root, "train_images"),
Path(self.ann_path, "TextVQA_0.5.1_train.json"),
vis_root=Path(self.vis_root, "train_images"),
)
elif split == "test":
elif self.split == "test":
self.data = self.create_data(
Path(ann_path, "TextVQA_0.5.1_val.json"),
vis_root=Path(vis_root, "train_images"),
Path(self.ann_path, "TextVQA_0.5.1_val.json"),
vis_root=Path(self.vis_root, "train_images"),
)
# self.instruction_pool = [
@@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
) # type: ignore
class TextVQADatasetForGeneration(TextVQADataset):
@@ -158,5 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
) # type: ignore
from .format import dataset_dir
dataset = {
"train": TextVQADataset(
vis_root=Path(dataset_dir, "TextVQA", "images"),
ann_path=Path(dataset_dir, "TextVQA"),
split="train",
),
"test": TextVQADataset(
vis_root=Path(dataset_dir, "TextVQA", "images"),
ann_path=Path(dataset_dir, "TextVQA"),
split="test",
),
"generation": TextVQADatasetForGeneration(
vis_root=Path(dataset_dir, "TextVQA", "images"),
ann_path=Path(dataset_dir, "TextVQA"),
split="test",
),
}
from .factory import register_dataset
register_dataset("textvqa", dataset, tag=["text", "image"])

View File

@@ -21,6 +21,7 @@ class VizWizDataset(Dataset):
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.ann_path = ann_path
from .vis_processor import size_processor
@@ -28,10 +29,13 @@
vis_processor if vis_processor is not None else size_processor
)
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(Path(ann_path, "train.json"))
elif split == "test":
self.data = self.create_data(Path(ann_path, "val.json"))
self.split = split
def load_data(self):
if self.split == "train":
self.data = self.create_data(Path(self.ann_path, "train.json"))
elif self.split == "test":
self.data = self.create_data(Path(self.ann_path, "val.json"))
# self.instruction_pool = [
# "[vqa] {}",
@@ -43,7 +47,10 @@ class VizWizDataset(Dataset):
with open(ann_path, "r") as f:
data = json.load(f)
for i in range(len(data)):
if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]:
if (
os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
and data[i]["answerable"]
):
imageFile = data[i]["image"]
processed_data.append(
{
@@ -93,7 +100,7 @@ class VizWizDataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
) # type: ignore
class VizWizDatasetForGeneration(VizWizDataset):
@@ -129,4 +136,33 @@ class VizWizDatasetForGeneration(VizWizDataset):
chat=chat,
answer=answer,
original=sample["original"],
) # type: ignore
from .format import dataset_dir
dataset = {
"train": VizWizDataset(
vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
split="train",
),
"test": VizWizDataset(
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
split="test",
),
"generation": VizWizDatasetForGeneration(
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
split="test",
),
}
from .factory import register_dataset
register_dataset(
dataset_name="vizwiz",
dataset=dataset,
tag=["image", "text"],
)

View File

@@ -1,156 +1,79 @@
from torch.utils.data import Dataset
from typing import Literal
from typing import Literal, List
from pathlib import Path
from dataset_library.format import dataset_dir
MAPPING_NAME_TO_DATASET: dict[
str, dict[Literal["train", "test", "generation"], Dataset]
] = {}
def register_dataset(
dataset_name: str,
dataset: dict[Literal["train", "test", "generation"], Dataset],
tag: List[Literal["image", "text", "audio", "video"]] = [],
**kwargs,
) -> None:
"""
Register a dataset.
Args:
dataset_name (`str`):
The name of the dataset.
dataset (`Dataset`):
The dataset to register.
"""
dataset_name = dataset_name.lower()
if dataset_name in MAPPING_NAME_TO_DATASET:
raise ValueError(f"Dataset {dataset_name} already registered.")
MAPPING_NAME_TO_DATASET[dataset_name] = dataset
from .GigaspeechDataset import (
GigaspeechDataset,
GigaspeechDatasetForGeneration,
)
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
from .ScienceQADataset import (
ScienceQADataset,
ScienceQADatasetForGeneration,
)
from .RefCOCODataset import (
RefCOCODataset,
RefCOCODatasetForGeneration,
)
from .RefCOCOgDataset import (
RefCOCOgDataset,
RefCOCOgDatasetForGeneration,
)
from .RefCOCOPlusDataset import (
RefCOCOplusDataset,
RefCOCOplusDatasetForGeneration,
)
from .VizWizDataset import (
VizWizDataset,
VizWizDatasetForGeneration,
)
def get_dataset(
dataset_name, base_path=dataset_dir
dataset_name: str, base_path=dataset_dir
) -> dict[Literal["train", "test", "generation"], Dataset]:
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
match dataset_name:
case "ocrvqa200k":
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
"""
Get a dataset by name.
dataset = {
"train": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="train",
),
"test": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
"generation": OCRVQADatasetForGeneration(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
}
case "chem":
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
dataset = {
"train": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="train",
),
"test": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
"generation": ChemDatasetForGeneration(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
}
case "gigaspeech":
from .GigaspeechDataset import (
GigaspeechDataset,
GigaspeechDatasetForGeneration,
)
dataset = {
"train": GigaspeechDataset(split="train"),
"test": GigaspeechDataset(split="test"),
"generation": GigaspeechDatasetForGeneration(split="test"),
}
case "textvqa":
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
dataset = {
"train": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="train",
),
"test": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
"generation": TextVQADatasetForGeneration(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
}
case "scienceqa":
from .ScienceQADataset import (
ScienceQADataset,
ScienceQADatasetForGeneration,
)
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
case "refcoco":
from .RefCOCODataset import (
RefCOCODataset,
RefCOCODatasetForGeneration,
)
dataset = {
"train": RefCOCODataset(split="val"),
"test": RefCOCODataset(split="test"),
"generation": RefCOCODatasetForGeneration(split="test"),
}
case "refcocog":
from .RefCOCOgDataset import (
RefCOCOgDataset,
RefCOCOgDatasetForGeneration,
)
dataset = {
"train": RefCOCOgDataset(split="val"),
"test": RefCOCOgDataset(split="test"),
"generation": RefCOCOgDatasetForGeneration(split="test"),
}
case "refcocoplus":
from .RefCOCOPlusDataset import (
RefCOCOplusDataset,
RefCOCOplusDatasetForGeneration,
)
dataset = {
"train": RefCOCOplusDataset(split="val"),
"test": RefCOCOplusDataset(split="testA"),
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
}
case "vizwiz":
from .VizWizDataset import (
VizWizDataset,
VizWizDatasetForGeneration,
)
dataset = {
"train": VizWizDataset(
vis_root=Path(base_path, "vizwiz", "images", "train"),
ann_path=Path(base_path, "vizwiz", "Annotations"),
split="train",
),
"test": VizWizDataset(
vis_root=Path(base_path, "vizwiz", "images", "val"),
ann_path=Path(base_path, "vizwiz", "Annotations"),
split="test",
),
"generation": VizWizDatasetForGeneration(
vis_root=Path(base_path, "vizwiz", "images", "val"),
ann_path=Path(base_path, "vizwiz", "Annotations"),
split="test",
),
}
return dataset
"""
Get a dataset by name.
Args:
dataset_name (`str`):
The name of the dataset.
base_path (`str`):
The base path to the dataset.
"""
dataset_name = dataset_name.lower()
if dataset_name not in MAPPING_NAME_TO_DATASET:
raise ValueError(f"Dataset {dataset_name} not registered.")
for key in MAPPING_NAME_TO_DATASET[dataset_name]:
MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
return MAPPING_NAME_TO_DATASET[dataset_name]
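Note: taken together, the registry replaces the old match statement. Each dataset module calls register_dataset(...) at import time, the imports at the top of this file populate MAPPING_NAME_TO_DATASET, and get_dataset() now only validates the name and triggers load_data() on each split. A sketch of the resulting call flow (package path taken from the imports above):

from dataset_library.factory import MAPPING_NAME_TO_DATASET, get_dataset

# Importing the factory imports every dataset module; each module's
# top-level register_dataset(...) call fills MAPPING_NAME_TO_DATASET.
print(sorted(MAPPING_NAME_TO_DATASET.keys()))

# get_dataset() checks the name, calls load_data() on the "train",
# "test" and "generation" entries, then returns the dict of splits.
splits = get_dataset("refcoco")
print(len(splits["train"]), splits["train"][0]["chat"])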

View File

@@ -28,8 +28,8 @@ class Conversation(TypedDict):
class DatasetOutput(TypedDict):
audios: Optional[list[Tuple[np.ndarray, int]]]
chat: list[Conversation]
answer: Optional[str]
original: Any
images: Optional[list[Image.Image]]
audios: Optional[list[Tuple[np.ndarray, int]]]
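Note: for orientation, a hypothetical instance matching the reordered schema above; the keys inside Conversation are not shown in this hunk, so role/content is an assumption:

from PIL import Image

# Hypothetical sample; assumes Conversation carries role/content keys.
sample = {
    "chat": [
        {"role": "user", "content": "What is shown in the image?"},
        {"role": "assistant", "content": "A cat on a sofa."},
    ],
    "answer": "A cat on a sofa.",  # used by the *ForGeneration variants
    "original": {"question_id": 0},  # the raw record, passed through as-is
    "images": [Image.new("RGB", (32, 32))],
    "audios": None,  # only audio datasets such as Gigaspeech fill this
}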

View File

@@ -1,82 +1,21 @@
from .factory import get_dataset
import pytest
from .factory import get_dataset, MAPPING_NAME_TO_DATASET
# Get all registered dataset names for parameterization
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
def test_gigaspeech():
dataset = get_dataset("gigaspeech")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
@pytest.mark.parametrize("dataset_name", dataset_names)
def test_registered_datasets(dataset_name):
dataset = get_dataset(dataset_name)
# Test train split
assert "train" in dataset, f"Train split not found in {dataset_name}"
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
# Test test split
assert "test" in dataset, f"Test split not found in {dataset_name}"
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
def test_chem():
dataset = get_dataset("chem")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_ocrvqa200k():
dataset = get_dataset("ocrvqa200k")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_textvqa():
dataset = get_dataset("textvqa")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_scienceqa():
dataset = get_dataset("scienceqa")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_refcoco():
dataset = get_dataset("refcoco")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_refcocog():
dataset = get_dataset("refcocog")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_refcocoplus():
dataset = get_dataset("refcocoplus")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_VizWiz():
dataset = get_dataset("vizwiz")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0

View File

@@ -36,7 +36,7 @@ class ContinualTrainer(Trainer):
# fisher = t
if regularization_args.lwf_enable:
pass
self.lwf_lambda = regularization_args.lwf_lambda
def create_accelerator_and_postprocess(self):
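Note: the hunk only stores lwf_lambda on the trainer; how it is consumed is outside this diff. For context, lwf_lambda conventionally weights a Learning-without-Forgetting distillation term computed against the previous task's frozen model, along these lines (a generic sketch with hypothetical names, not this repo's implementation):

import torch.nn.functional as F


def lwf_regularized_loss(task_loss, new_logits, old_logits, lwf_lambda, T=2.0):
    # Classic LwF: temperature-scaled KL divergence between the current
    # model's predictions and the frozen previous model's predictions.
    distill = F.kl_div(
        F.log_softmax(new_logits / T, dim=-1),
        F.softmax(old_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    return task_loss + lwf_lambda * distill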