feat: add support for multiple datasets, optimize data-loading logic, update the dataset registration mechanism, and adjust VSCode settings

YunyaoZhou 2025-05-19 13:27:35 +08:00
parent c100a59f0e
commit 1d2e7b9dcd
13 changed files with 296 additions and 260 deletions
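
In outline, the commit replaces factory.py's hard-coded match/case with a registry: each dataset module now builds its train/test/generation splits at import time and calls register_dataset, __init__ defers all I/O, and get_dataset resolves a name and calls load_data() on each split. A minimal sketch of the resulting call flow (module paths taken from the diffs below):

    from dataset_library.factory import get_dataset

    # Importing the factory imports every dataset module; their module-level
    # register_dataset(...) calls populate MAPPING_NAME_TO_DATASET.
    splits = get_dataset("gigaspeech")  # lazily runs load_data() per split
    print(len(splits["train"]), splits["train"][0]["chat"])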


@@ -4,8 +4,11 @@
         "src/peft_repo/src/"
     ],
     "python.analysis.exclude": [
-        "dataset/**/*"
+        "dataset/**/*",
     ],
     "python.languageServer": "Default",
-    "python.terminal.activateEnvInCurrentTerminal": true
+    "python.terminal.activateEnvInCurrentTerminal": true,
+    "python.analysis.include": [
+        "src/**/*"
+    ]
 }
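
For context: python.analysis.include and python.analysis.exclude are Pylance settings; excluding dataset/**/* keeps the language server from indexing the large data directory, while the new include list scopes analysis to src. The trailing commas are valid here because VSCode parses settings.json as JSONC.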


@@ -18,9 +18,13 @@ class GigaspeechDataset(Dataset):
         self.audio_processor = audio_processor
         self.text_processor = text_processor
+        self.split = split
 
+    def load_data(self):
         from .format import dataset_dir
 
         gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir)  # type: ignore
-        self.data = gs[split]  # type: ignore
+        self.data = gs[self.split]  # type: ignore
 
     def __len__(self):
         return len(self.data)
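
The same pattern repeats in every dataset file below: __init__ now only records the split (and any paths), while the actual download or file read moves into load_data(). This keeps the module-level registration calls added at the bottom of each file cheap, since constructing a dataset object no longer touches the data.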
@@ -55,7 +59,7 @@ class GigaspeechDataset(Dataset):
             audios=[(audio, sampling_rate)],
             chat=chat,
             original=sample,
         )  # type: ignore
 
 
 class GigaspeechDatasetForGeneration(GigaspeechDataset):
@@ -88,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
             chat=chat,
             answer=text,
             original=sample,
         )  # type: ignore
 
 
 def test_gigaspeech():
@@ -105,5 +109,20 @@ def test_gigaspeech():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
+
+from .factory import register_dataset
+
+dataset = {
+    "train": GigaspeechDataset(split="train"),
+    "test": GigaspeechDataset(split="test"),
+    "generation": GigaspeechDatasetForGeneration(split="test"),
+}
+register_dataset(
+    dataset_name="gigaspeech",
+    dataset=dataset,
+    tag=["audio", "text"],
+)
+
+
 if __name__ == "__main__":
     test_gigaspeech()


@@ -28,10 +28,14 @@ class OCRVQADataset(Dataset):
             vis_processor if vis_processor is not None else size_processor
         )
         self.text_processor = text_processor
-        if split == "train":
-            self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
-        elif split == "test":
-            self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
+        self.split = split
+        self.ann_path = ann_path
+
+    def load_data(self):
+        if self.split == "train":
+            self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
+        elif self.split == "test":
+            self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
 
         # self.instruction_pool = [
         #     "[vqa] {}",
@@ -101,7 +105,7 @@ class OCRVQADataset(Dataset):
             chat=chat,
             original=sample["original"],
             images=[image],
         )  # type: ignore
 
 
 class OCRVQADatasetForGeneration(OCRVQADataset):
@@ -137,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
             chat=chat,
             answer=answer,
             original=sample["original"],
         )  # type: ignore
+
+
+from .factory import register_dataset
+from .format import dataset_dir as base_path
+
+dataset = {
+    "train": OCRVQADataset(
+        vis_root=Path(base_path, "OCR-VQA-200K", "images"),
+        ann_path=Path(base_path, "OCR-VQA-200K"),
+        split="train",
+    ),
+    "test": OCRVQADataset(
+        vis_root=Path(base_path, "OCR-VQA-200K", "images"),
+        ann_path=Path(base_path, "OCR-VQA-200K"),
+        split="test",
+    ),
+    "generation": OCRVQADatasetForGeneration(
+        vis_root=Path(base_path, "OCR-VQA-200K", "images"),
+        ann_path=Path(base_path, "OCR-VQA-200K"),
+        split="test",
+    ),
+}
+
+register_dataset("ocrvqa200k", dataset, tag=["image", "text"])


@@ -24,9 +24,11 @@ class RefCOCODataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
+        self.split = split
 
+    def load_data(self):
         from .format import dataset_dir
 
         ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir)  # type: ignore
-        self.data = ds[split]  # type: ignore
+        self.data = ds[self.split]  # type: ignore
 
     def __len__(self):
         return len(self.data)
@@ -116,6 +118,18 @@ def test_RefCOCO():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
+dataset = {
+    "train": RefCOCODataset(split="val"),
+    "test": RefCOCODataset(split="test"),
+    "generation": RefCOCODatasetForGeneration(split="test"),
+}
+from .factory import register_dataset
+
+register_dataset(
+    dataset_name="refcoco",
+    dataset=dataset,
+    tag=["image", "text"],
+)
 if __name__ == "__main__":
     test_RefCOCO()


@@ -24,9 +24,12 @@ class RefCOCOplusDataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
+        self.split = split
+
+    def load_data(self):
         from .format import dataset_dir
 
         ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir)  # type: ignore
-        self.data = ds[split]  # type: ignore
+        self.data = ds[self.split]  # type: ignore
 
     def __len__(self):
         return len(self.data)
@@ -116,6 +119,18 @@ def test_RefCOCOplus():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
+dataset = {
+    "train": RefCOCOplusDataset(split="val"),
+    "test": RefCOCOplusDataset(split="testA"),
+    "generation": RefCOCOplusDatasetForGeneration(split="testA"),
+}
+from .factory import register_dataset
+
+register_dataset(
+    dataset_name="refcocoplus",
+    dataset=dataset,
+    tag=["image", "text"],
+)
 if __name__ == "__main__":
     test_RefCOCOplus()


@@ -24,9 +24,12 @@ class RefCOCOgDataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
+        self.split = split
+
+    def load_data(self):
         from .format import dataset_dir
 
         ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
-        self.data = ds[split]  # type: ignore
+        self.data = ds[self.split]  # type: ignore
 
     def __len__(self):
         return len(self.data)
@@ -116,6 +119,18 @@ def test_RefCOCOg():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
+dataset = {
+    "train": RefCOCOgDataset(split="val"),
+    "test": RefCOCOgDataset(split="test"),
+    "generation": RefCOCOgDatasetForGeneration(split="test"),
+}
+from .factory import register_dataset
+
+register_dataset(
+    dataset_name="refcocog",
+    dataset=dataset,
+    tag=["image", "text"],
+)
 if __name__ == "__main__":
     test_RefCOCOg()


@@ -18,9 +18,12 @@ class ScienceQADataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
+        self.split = split
+
+    def load_data(self):
         from .format import dataset_dir
 
         ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir)  # type: ignore
-        self.data = ds[split]  # type: ignore
+        self.data = ds[self.split]  # type: ignore
 
     def __len__(self):
         return len(self.data)
@@ -116,6 +119,18 @@ def test_scienceQA():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
+dataset = {
+    "train": ScienceQADataset(split="train"),
+    "test": ScienceQADataset(split="test"),
+    "generation": ScienceQADatasetForGeneration(split="test"),
+}
+from .factory import register_dataset
+
+register_dataset(
+    dataset_name="scienceqa",
+    dataset=dataset,
+    tag=["image", "text"],
+)
 if __name__ == "__main__":
     test_scienceQA()


@@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
             vis_processor if vis_processor is not None else self._vis_processor
         )
         self.text_processor = text_processor
-        if split == "train":
+        self.split = split
+        self.ann_path = ann_path
+        self.vis_root = vis_root
+
+    def load_data(self):
+        if self.split == "train":
             self.data = self.create_data(
-                Path(ann_path, "TextVQA_0.5.1_train.json"),
-                vis_root=Path(vis_root, "train_images"),
+                Path(self.ann_path, "TextVQA_0.5.1_train.json"),
+                vis_root=Path(self.vis_root, "train_images"),
             )
-        elif split == "test":
+        elif self.split == "test":
             self.data = self.create_data(
-                Path(ann_path, "TextVQA_0.5.1_val.json"),
-                vis_root=Path(vis_root, "train_images"),
+                Path(self.ann_path, "TextVQA_0.5.1_val.json"),
+                vis_root=Path(self.vis_root, "train_images"),
             )
 
         # self.instruction_pool = [
@@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
             chat=chat,
             original=sample["original"],
             images=[image],
         )  # type: ignore
 
 
 class TextVQADatasetForGeneration(TextVQADataset):
@@ -158,5 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
             chat=chat,
             answer=answer,
             original=sample["original"],
         )  # type: ignore
 
+
+from .format import dataset_dir
+
+dataset = {
+    "train": TextVQADataset(
+        vis_root=Path(dataset_dir, "TextVQA", "images"),
+        ann_path=Path(dataset_dir, "TextVQA"),
+        split="train",
+    ),
+    "test": TextVQADataset(
+        vis_root=Path(dataset_dir, "TextVQA", "images"),
+        ann_path=Path(dataset_dir, "TextVQA"),
+        split="test",
+    ),
+    "generation": TextVQADatasetForGeneration(
+        vis_root=Path(dataset_dir, "TextVQA", "images"),
+        ann_path=Path(dataset_dir, "TextVQA"),
+        split="test",
+    ),
+}
+from .factory import register_dataset
+
+register_dataset("textvqa", dataset, tag=["text", "image"])


@@ -21,6 +21,7 @@ class VizWizDataset(Dataset):
             ann_root (string): directory to store the annotation file
         """
         self.vis_root = vis_root
+        self.ann_path = ann_path
 
         from .vis_processor import size_processor
@@ -28,10 +29,13 @@ class VizWizDataset(Dataset):
             vis_processor if vis_processor is not None else size_processor
         )
         self.text_processor = text_processor
-        if split == "train":
-            self.data = self.create_data(Path(ann_path, "train.json"))
-        elif split == "test":
-            self.data = self.create_data(Path(ann_path, "val.json"))
+        self.split = split
+
+    def load_data(self):
+        if self.split == "train":
+            self.data = self.create_data(Path(self.ann_path, "train.json"))
+        elif self.split == "test":
+            self.data = self.create_data(Path(self.ann_path, "val.json"))
 
         # self.instruction_pool = [
         #     "[vqa] {}",
@@ -43,7 +47,10 @@ class VizWizDataset(Dataset):
         with open(ann_path, "r") as f:
             data = json.load(f)
         for i in range(len(data)):
-            if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]:
+            if (
+                os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
+                and data[i]["answerable"]
+            ):
                 imageFile = data[i]["image"]
                 processed_data.append(
                     {
@@ -93,7 +100,7 @@ class VizWizDataset(Dataset):
             chat=chat,
             original=sample["original"],
             images=[image],
         )  # type: ignore
 
 
 class VizWizDatasetForGeneration(VizWizDataset):
@@ -129,4 +136,33 @@ class VizWizDatasetForGeneration(VizWizDataset):
             chat=chat,
             answer=answer,
             original=sample["original"],
         )  # type: ignore
+
+
+from .format import dataset_dir
+
+dataset = {
+    "train": VizWizDataset(
+        vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
+        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
+        split="train",
+    ),
+    "test": VizWizDataset(
+        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
+        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
+        split="test",
+    ),
+    "generation": VizWizDatasetForGeneration(
+        vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
+        ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
+        split="test",
+    ),
+}
+from .factory import register_dataset
+
+register_dataset(
+    dataset_name="vizwiz",
+    dataset=dataset,
+    tag=["image", "text"],
+)


@@ -1,156 +1,79 @@
 from torch.utils.data import Dataset
-from typing import Literal
+from typing import Literal,List
 from pathlib import Path
 from dataset_library.format import dataset_dir
 
+MAPPING_NAME_TO_DATASET: dict[
+    str, dict[Literal["train", "test", "generation"], Dataset]
+] = {}
+
+
+def register_dataset(
+    dataset_name: str,
+    dataset: dict[Literal["train", "test", "generation"], Dataset],
+    tag: List[Literal["image", "text", "audio", "video"]] = [],
+    **kwargs,
+) -> None:
+    """
+    Register a dataset.
+
+    Args:
+        dataset_name (`str`):
+            The name of the dataset.
+        dataset (`Dataset`):
+            The dataset to register.
+    """
+    dataset_name = dataset_name.lower()
+    if dataset_name in MAPPING_NAME_TO_DATASET:
+        raise ValueError(f"Dataset {dataset_name} already registered.")
+    MAPPING_NAME_TO_DATASET[dataset_name] = dataset
+
+
+from .GigaspeechDataset import (
+    GigaspeechDataset,
+    GigaspeechDatasetForGeneration,
+)
+from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
+from .ChemDataset import ChemDataset, ChemDatasetForGeneration
+from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
+from .ScienceQADataset import (
+    ScienceQADataset,
+    ScienceQADatasetForGeneration,
+)
+from .RefCOCODataset import (
+    RefCOCODataset,
+    RefCOCODatasetForGeneration,
+)
+from .RefCOCOgDataset import (
+    RefCOCOgDataset,
+    RefCOCOgDatasetForGeneration,
+)
+from .RefCOCOPlusDataset import (
+    RefCOCOplusDataset,
+    RefCOCOplusDatasetForGeneration,
+)
+from .VizWizDataset import (
+    VizWizDataset,
+    VizWizDatasetForGeneration,
+)
+
+
 def get_dataset(
-    dataset_name, base_path=dataset_dir
+    dataset_name: str, base_path=dataset_dir
 ) -> dict[Literal["train", "test", "generation"], Dataset]:
-    dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
-    match dataset_name:
-        case "ocrvqa200k":
-            from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
-
-            dataset = {
-                "train": OCRVQADataset(
-                    vis_root=Path(base_path, "OCR-VQA-200K", "images"),
-                    ann_path=Path(base_path, "OCR-VQA-200K"),
-                    split="train",
-                ),
-                "test": OCRVQADataset(
-                    vis_root=Path(base_path, "OCR-VQA-200K", "images"),
-                    ann_path=Path(base_path, "OCR-VQA-200K"),
-                    split="test",
-                ),
-                "generation": OCRVQADatasetForGeneration(
-                    vis_root=Path(base_path, "OCR-VQA-200K", "images"),
-                    ann_path=Path(base_path, "OCR-VQA-200K"),
-                    split="test",
-                ),
-            }
-        case "chem":
-            from .ChemDataset import ChemDataset, ChemDatasetForGeneration
-
-            dataset = {
-                "train": ChemDataset(
-                    vis_root=Path(base_path, "chem", "images"),
-                    ann_path=Path(base_path, "chem"),
-                    split="train",
-                ),
-                "test": ChemDataset(
-                    vis_root=Path(base_path, "chem", "images"),
-                    ann_path=Path(base_path, "chem"),
-                    split="test",
-                ),
-                "generation": ChemDatasetForGeneration(
-                    vis_root=Path(base_path, "chem", "images"),
-                    ann_path=Path(base_path, "chem"),
-                    split="test",
-                ),
-            }
-        case "gigaspeech":
-            from .GigaspeechDataset import (
-                GigaspeechDataset,
-                GigaspeechDatasetForGeneration,
-            )
-
-            dataset = {
-                "train": GigaspeechDataset(split="train"),
-                "test": GigaspeechDataset(split="test"),
-                "generation": GigaspeechDatasetForGeneration(split="test"),
-            }
-        case "textvqa":
-            from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
-
-            dataset = {
-                "train": TextVQADataset(
-                    vis_root=Path(base_path, "TextVQA", "images"),
-                    ann_path=Path(base_path, "TextVQA"),
-                    split="train",
-                ),
-                "test": TextVQADataset(
-                    vis_root=Path(base_path, "TextVQA", "images"),
-                    ann_path=Path(base_path, "TextVQA"),
-                    split="test",
-                ),
-                "generation": TextVQADatasetForGeneration(
-                    vis_root=Path(base_path, "TextVQA", "images"),
-                    ann_path=Path(base_path, "TextVQA"),
-                    split="test",
-                ),
-            }
-        case "scienceqa":
-            from .ScienceQADataset import (
-                ScienceQADataset,
-                ScienceQADatasetForGeneration,
-            )
-
-            dataset = {
-                "train": ScienceQADataset(split="train"),
-                "test": ScienceQADataset(split="test"),
-                "generation": ScienceQADatasetForGeneration(split="test"),
-            }
-        case "refcoco":
-            from .RefCOCODataset import (
-                RefCOCODataset,
-                RefCOCODatasetForGeneration,
-            )
-
-            dataset = {
-                "train": RefCOCODataset(split="val"),
-                "test": RefCOCODataset(split="test"),
-                "generation": RefCOCODatasetForGeneration(split="test"),
-            }
-        case "refcocog":
-            from .RefCOCOgDataset import (
-                RefCOCOgDataset,
-                RefCOCOgDatasetForGeneration,
-            )
-
-            dataset = {
-                "train": RefCOCOgDataset(split="val"),
-                "test": RefCOCOgDataset(split="test"),
-                "generation": RefCOCOgDatasetForGeneration(split="test"),
-            }
-        case "refcocoplus":
-            from .RefCOCOPlusDataset import (
-                RefCOCOplusDataset,
-                RefCOCOplusDatasetForGeneration,
-            )
-
-            dataset = {
-                "train": RefCOCOplusDataset(split="val"),
-                "test": RefCOCOplusDataset(split="testA"),
-                "generation": RefCOCOplusDatasetForGeneration(split="testA"),
-            }
-        case "vizwiz":
-            from .VizWizDataset import (
-                VizWizDataset,
-                VizWizDatasetForGeneration,
-            )
-
-            dataset = {
-                "train": VizWizDataset(
-                    vis_root=Path(base_path, "vizwiz", "images", "train"),
-                    ann_path=Path(base_path, "vizwiz", "Annotations"),
-                    split="train",
-                ),
-                "test": VizWizDataset(
-                    vis_root=Path(base_path, "vizwiz", "images", "val"),
-                    ann_path=Path(base_path, "vizwiz", "Annotations"),
-                    split="test",
-                ),
-                "generation": VizWizDatasetForGeneration(
-                    vis_root=Path(base_path, "vizwiz", "images", "val"),
-                    ann_path=Path(base_path, "vizwiz", "Annotations"),
-                    split="test",
-                ),
-            }
-    return dataset
+    """
+    Get a dataset by name.
+
+    Args:
+        dataset_name (`str`):
+            The name of the dataset.
+        base_path (`str`):
+            The base path to the dataset.
+    """
+    dataset_name = dataset_name.lower()
+    if dataset_name not in MAPPING_NAME_TO_DATASET:
+        raise ValueError(f"Dataset {dataset_name} not registered.")
+    for key in MAPPING_NAME_TO_DATASET[dataset_name]:
+        MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
+    return MAPPING_NAME_TO_DATASET[dataset_name]
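
With the registry in place, adding a dataset no longer means editing get_dataset; a module registers itself when factory.py imports it. A hedged sketch of the contract a new dataset module would follow (MyDataset and its contents are hypothetical, not part of this commit):

    from torch.utils.data import Dataset

    from .factory import register_dataset


    class MyDataset(Dataset):  # hypothetical example dataset
        def __init__(self, split: str):
            self.split = split  # defer all I/O; nothing is read here

        def load_data(self):  # invoked by get_dataset(), not by __init__
            self.data = [{"chat": [{"role": "user", "content": "hello"}]}]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]


    register_dataset(
        dataset_name="mydataset",
        dataset={
            "train": MyDataset("train"),
            "test": MyDataset("test"),
            "generation": MyDataset("test"),
        },
        tag=["text"],
    )

One consequence of this design: registration is an import side effect, so get_dataset only sees datasets whose modules factory.py imports, and register_dataset raises if a name is registered twice.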


@@ -28,8 +28,8 @@ class Conversation(TypedDict):
 
 
 class DatasetOutput(TypedDict):
-    audios: Optional[list[Tuple[np.ndarray, int]]]
     chat: list[Conversation]
     answer: Optional[str]
    original: Any
     images: Optional[list[Image.Image]]
+    audios: Optional[list[Tuple[np.ndarray, int]]]


@@ -1,82 +1,21 @@
-from .factory import get_dataset
-
-
-def test_gigaspeech():
-    dataset = get_dataset("gigaspeech")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_chem():
-    dataset = get_dataset("chem")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_ocrvqa200k():
-    dataset = get_dataset("ocrvqa200k")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_textvqa():
-    dataset = get_dataset("textvqa")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_scienceqa():
-    dataset = get_dataset("scienceqa")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_refcoco():
-    dataset = get_dataset("refcoco")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_refcocog():
-    dataset = get_dataset("refcocog")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_refcocoplus():
-    dataset = get_dataset("refcocoplus")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
-
-
-def test_VizWiz():
-    dataset = get_dataset("vizwiz")
-    assert len(dataset["train"]) > 0  # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0
-
-    assert len(dataset["test"]) > 0  # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0
+import pytest
+from .factory import get_dataset, MAPPING_NAME_TO_DATASET
+
+# Get all registered dataset names for parameterization
+dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
+
+
+@pytest.mark.parametrize("dataset_name", dataset_names)
+def test_registered_datasets(dataset_name):
+    dataset = get_dataset(dataset_name)
+
+    # Test train split
+    assert "train" in dataset, f"Train split not found in {dataset_name}"
+    assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}"  # type: ignore
+    assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}"  # type: ignore
+    assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}"  # type: ignore
+
+    # Test test split
+    assert "test" in dataset, f"Test split not found in {dataset_name}"
+    assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}"  # type: ignore
+    assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}"  # type: ignore
+    assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}"  # type: ignore
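
Because the test is parametrized over MAPPING_NAME_TO_DATASET, a newly registered dataset is picked up with no test changes. A single dataset can still be run in isolation with pytest's -k filter (the test file's name and location are assumed here):

    pytest dataset_library/test_dataset.py -k "refcoco"

Note that -k matches substrings, so "refcoco" selects the refcoco, refcocog, and refcocoplus parameters together.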


@@ -36,7 +36,7 @@ class ContinualTrainer(Trainer):
             # fisher = t
 
         if regularization_args.lwf_enable:
-            pass
+            self.lwf_lambda = regularization_args.lwf_lambda
 
     def create_accelerator_and_postprocess(self):
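
This hunk only stores the coefficient; the loss-side use of lwf_lambda is not part of the diff. For context, Learning without Forgetting (LwF) typically adds a distillation term that keeps the current model's logits close to those of a frozen copy from the previous task, weighted by this lambda. A minimal sketch under that standard formulation (not the repository's implementation):

    import torch
    import torch.nn.functional as F


    def lwf_distillation_loss(
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        lwf_lambda: float,
        temperature: float = 2.0,
    ) -> torch.Tensor:
        # Soften both distributions, then penalize drift away from the
        # previous-task model so earlier behavior is retained.
        log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
        p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
        kd = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
        return lwf_lambda * (temperature**2) * kd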