From 1d2e7b9dcd97cb0bb91d206c55214c7a6fdd66aa Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Mon, 19 May 2025 13:27:35 +0800 Subject: [PATCH] =?UTF-8?q?feat=E2=9C=A8:=20=E6=B7=BB=E5=8A=A0=E5=A4=9A?= =?UTF-8?q?=E4=B8=AA=E6=95=B0=E6=8D=AE=E9=9B=86=E6=94=AF=E6=8C=81=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE=E5=8A=A0=E8=BD=BD=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=95=B0=E6=8D=AE=E9=9B=86?= =?UTF-8?q?=E6=B3=A8=E5=86=8C=E6=9C=BA=E5=88=B6=EF=BC=8C=E8=B0=83=E6=95=B4?= =?UTF-8?q?VSCode=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 7 +- src/dataset_library/GigaspeechDataset.py | 27 ++- src/dataset_library/OCRVQA200KDataset.py | 40 +++- src/dataset_library/RefCOCODataset.py | 16 +- src/dataset_library/RefCOCOPlusDataset.py | 17 +- src/dataset_library/RefCOCOgDataset.py | 17 +- src/dataset_library/ScienceQADataset.py | 17 +- src/dataset_library/TextVQADataset.py | 45 ++++- src/dataset_library/VizWizDataset.py | 50 ++++- src/dataset_library/factory.py | 219 +++++++--------------- src/dataset_library/format.py | 2 +- src/dataset_library/test_dataset.py | 97 ++-------- src/utils/trainer.py | 2 +- 13 files changed, 296 insertions(+), 260 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0ddb8eb..543f834 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,8 +4,11 @@ "src/peft_repo/src/" ], "python.analysis.exclude": [ - "dataset/**/*" + "dataset/**/*", ], "python.languageServer": "Default", - "python.terminal.activateEnvInCurrentTerminal": true + "python.terminal.activateEnvInCurrentTerminal": true, + "python.analysis.include": [ + "src/**/*" + ] } \ No newline at end of file diff --git a/src/dataset_library/GigaspeechDataset.py b/src/dataset_library/GigaspeechDataset.py index ee32c38..a7b615c 100644 --- a/src/dataset_library/GigaspeechDataset.py +++ b/src/dataset_library/GigaspeechDataset.py @@ -18,9 +18,13 @@ class GigaspeechDataset(Dataset): self.audio_processor = audio_processor self.text_processor = text_processor + self.split = split + + def load_data(self): from .format import dataset_dir - gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore - self.data = gs[split] # type: ignore + + gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore + self.data = gs[self.split] # type: ignore def __len__(self): return len(self.data) @@ -55,7 +59,7 @@ class GigaspeechDataset(Dataset): audios=[(audio, sampling_rate)], chat=chat, original=sample, - ) # type: ignore + ) # type: ignore class GigaspeechDatasetForGeneration(GigaspeechDataset): @@ -88,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset): chat=chat, answer=text, original=sample, - ) # type: ignore + ) # type: ignore def test_gigaspeech(): @@ -105,5 +109,20 @@ def test_gigaspeech(): assert len(dataset) > 0 assert len(dataset[0]["chat"]) > 0 + +from .factory import register_dataset + +dataset = { + "train": GigaspeechDataset(split="train"), + "test": GigaspeechDataset(split="test"), + "generation": GigaspeechDatasetForGeneration(split="test"), +} + +register_dataset( + dataset_name="gigaspeech", + dataset=dataset, + tag=["audio", "text"], +) + if __name__ == "__main__": test_gigaspeech() diff --git a/src/dataset_library/OCRVQA200KDataset.py b/src/dataset_library/OCRVQA200KDataset.py index f95f254..625b986 100644 --- a/src/dataset_library/OCRVQA200KDataset.py +++ b/src/dataset_library/OCRVQA200KDataset.py @@ -28,10 +28,14 @@ class OCRVQADataset(Dataset): vis_processor if vis_processor is not None else size_processor ) self.text_processor = text_processor - if split == "train": - self.data = self.create_data(Path(ann_path, "dataset.json"), split=1) - elif split == "test": - self.data = self.create_data(Path(ann_path, "dataset.json"), split=3) + self.split = split + self.ann_path = ann_path + + def load_data(self): + if self.split == "train": + self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1) + elif self.split == "test": + self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3) # self.instruction_pool = [ # "[vqa] {}", @@ -101,7 +105,7 @@ class OCRVQADataset(Dataset): chat=chat, original=sample["original"], images=[image], - ) # type: ignore + ) # type: ignore class OCRVQADatasetForGeneration(OCRVQADataset): @@ -137,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset): chat=chat, answer=answer, original=sample["original"], - ) # type: ignore + ) # type: ignore + + +from .factory import register_dataset +from .format import dataset_dir as base_path + +dataset = { + "train": OCRVQADataset( + vis_root=Path(base_path, "OCR-VQA-200K", "images"), + ann_path=Path(base_path, "OCR-VQA-200K"), + split="train", + ), + "test": OCRVQADataset( + vis_root=Path(base_path, "OCR-VQA-200K", "images"), + ann_path=Path(base_path, "OCR-VQA-200K"), + split="test", + ), + "generation": OCRVQADatasetForGeneration( + vis_root=Path(base_path, "OCR-VQA-200K", "images"), + ann_path=Path(base_path, "OCR-VQA-200K"), + split="test", + ), +} + +register_dataset("ocrvqa200k", dataset, tag=["image", "text"]) \ No newline at end of file diff --git a/src/dataset_library/RefCOCODataset.py b/src/dataset_library/RefCOCODataset.py index c7578fc..a165ad5 100644 --- a/src/dataset_library/RefCOCODataset.py +++ b/src/dataset_library/RefCOCODataset.py @@ -24,9 +24,11 @@ class RefCOCODataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor + self.split = split + def load_data(self): from .format import dataset_dir ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore - self.data = ds[split] # type: ignore + self.data = ds[self.split] # type: ignore def __len__(self): return len(self.data) @@ -116,6 +118,18 @@ def test_RefCOCO(): assert len(dataset) > 0 assert len(dataset[0]["chat"]) > 0 +dataset = { + "train": RefCOCODataset(split="val"), + "test": RefCOCODataset(split="test"), + "generation": RefCOCODatasetForGeneration(split="test"), + } +from .factory import register_dataset +register_dataset( + dataset_name="refcoco", + dataset=dataset, + tag=["image", "text"], +) + if __name__ == "__main__": test_RefCOCO() diff --git a/src/dataset_library/RefCOCOPlusDataset.py b/src/dataset_library/RefCOCOPlusDataset.py index 8d574b3..95a6c86 100644 --- a/src/dataset_library/RefCOCOPlusDataset.py +++ b/src/dataset_library/RefCOCOPlusDataset.py @@ -24,9 +24,12 @@ class RefCOCOplusDataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor + self.split = split + + def load_data(self): from .format import dataset_dir ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore - self.data = ds[split] # type: ignore + self.data = ds[self.split] # type: ignore def __len__(self): return len(self.data) @@ -116,6 +119,18 @@ def test_RefCOCOplus(): assert len(dataset) > 0 assert len(dataset[0]["chat"]) > 0 +dataset = { + "train": RefCOCOplusDataset(split="val"), + "test": RefCOCOplusDataset(split="testA"), + "generation": RefCOCOplusDatasetForGeneration(split="testA"), + } +from .factory import register_dataset +register_dataset( + dataset_name="refcocoplus", + dataset=dataset, + tag=["image", "text"], +) + if __name__ == "__main__": test_RefCOCOplus() diff --git a/src/dataset_library/RefCOCOgDataset.py b/src/dataset_library/RefCOCOgDataset.py index 2121491..73fb52f 100644 --- a/src/dataset_library/RefCOCOgDataset.py +++ b/src/dataset_library/RefCOCOgDataset.py @@ -24,9 +24,12 @@ class RefCOCOgDataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor + self.split = split + + def load_data(self): from .format import dataset_dir ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore - self.data = ds[split] # type: ignore + self.data = ds[self.split] # type: ignore def __len__(self): return len(self.data) @@ -116,6 +119,18 @@ def test_RefCOCOg(): assert len(dataset) > 0 assert len(dataset[0]["chat"]) > 0 +dataset = { + "train": RefCOCOgDataset(split="val"), + "test": RefCOCOgDataset(split="test"), + "generation": RefCOCOgDatasetForGeneration(split="test"), + } +from .factory import register_dataset +register_dataset( + dataset_name="refcocog", + dataset=dataset, + tag=["image", "text"], +) + if __name__ == "__main__": test_RefCOCOg() diff --git a/src/dataset_library/ScienceQADataset.py b/src/dataset_library/ScienceQADataset.py index 8b4d279..d51f165 100644 --- a/src/dataset_library/ScienceQADataset.py +++ b/src/dataset_library/ScienceQADataset.py @@ -18,9 +18,12 @@ class ScienceQADataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor + self.split = split + + def load_data(self): from .format import dataset_dir ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore - self.data = ds[split] # type: ignore + self.data = ds[self.split] # type: ignore def __len__(self): return len(self.data) @@ -116,6 +119,18 @@ def test_scienceQA(): assert len(dataset) > 0 assert len(dataset[0]["chat"]) > 0 +dataset = { + "train": ScienceQADataset(split="train"), + "test": ScienceQADataset(split="test"), + "generation": ScienceQADatasetForGeneration(split="test"), +} +from .factory import register_dataset +register_dataset( + dataset_name="scienceqa", + dataset=dataset, + tag=["image", "text"], +) + if __name__ == "__main__": test_scienceQA() diff --git a/src/dataset_library/TextVQADataset.py b/src/dataset_library/TextVQADataset.py index 51fe8d8..62d0526 100644 --- a/src/dataset_library/TextVQADataset.py +++ b/src/dataset_library/TextVQADataset.py @@ -25,15 +25,20 @@ class TextVQADataset(Dataset): vis_processor if vis_processor is not None else self._vis_processor ) self.text_processor = text_processor - if split == "train": + self.split = split + self.ann_path = ann_path + self.vis_root = vis_root + + def load_data(self): + if self.split == "train": self.data = self.create_data( - Path(ann_path, "TextVQA_0.5.1_train.json"), - vis_root=Path(vis_root, "train_images"), + Path(self.ann_path, "TextVQA_0.5.1_train.json"), + vis_root=Path(self.vis_root, "train_images"), ) - elif split == "test": + elif self.split == "test": self.data = self.create_data( - Path(ann_path, "TextVQA_0.5.1_val.json"), - vis_root=Path(vis_root, "train_images"), + Path(self.ann_path, "TextVQA_0.5.1_val.json"), + vis_root=Path(self.vis_root, "train_images"), ) # self.instruction_pool = [ @@ -124,7 +129,7 @@ class TextVQADataset(Dataset): chat=chat, original=sample["original"], images=[image], - ) # type: ignore + ) # type: ignore class TextVQADatasetForGeneration(TextVQADataset): @@ -158,5 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset): chat=chat, answer=answer, original=sample["original"], - ) # type: ignore + ) # type: ignore + +from .format import dataset_dir + +dataset = { + "train": TextVQADataset( + vis_root=Path(dataset_dir, "TextVQA", "images"), + ann_path=Path(dataset_dir, "TextVQA"), + split="train", + ), + "test": TextVQADataset( + vis_root=Path(dataset_dir, "TextVQA", "images"), + ann_path=Path(dataset_dir, "TextVQA"), + split="test", + ), + "generation": TextVQADatasetForGeneration( + vis_root=Path(dataset_dir, "TextVQA", "images"), + ann_path=Path(dataset_dir, "TextVQA"), + split="test", + ), +} + +from .factory import register_dataset + +register_dataset("textvqa", dataset, tag=["text", "image"]) diff --git a/src/dataset_library/VizWizDataset.py b/src/dataset_library/VizWizDataset.py index fa6aa09..756734a 100644 --- a/src/dataset_library/VizWizDataset.py +++ b/src/dataset_library/VizWizDataset.py @@ -21,6 +21,7 @@ class VizWizDataset(Dataset): ann_root (string): directory to store the annotation file """ self.vis_root = vis_root + self.ann_path = ann_path from .vis_processor import size_processor @@ -28,10 +29,13 @@ class VizWizDataset(Dataset): vis_processor if vis_processor is not None else size_processor ) self.text_processor = text_processor - if split == "train": - self.data = self.create_data(Path(ann_path, "train.json")) - elif split == "test": - self.data = self.create_data(Path(ann_path, "val.json")) + self.split = split + + def load_data(self): + if self.split == "train": + self.data = self.create_data(Path(self.ann_path, "train.json")) + elif self.split == "test": + self.data = self.create_data(Path(self.ann_path, "val.json")) # self.instruction_pool = [ # "[vqa] {}", @@ -43,7 +47,10 @@ class VizWizDataset(Dataset): with open(ann_path, "r") as f: data = json.load(f) for i in range(len(data)): - if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]: + if ( + os.path.exists(os.path.join(self.vis_root, data[i]["image"])) + and data[i]["answerable"] + ): imageFile = data[i]["image"] processed_data.append( { @@ -93,7 +100,7 @@ class VizWizDataset(Dataset): chat=chat, original=sample["original"], images=[image], - ) # type: ignore + ) # type: ignore class VizWizDatasetForGeneration(VizWizDataset): @@ -129,4 +136,33 @@ class VizWizDatasetForGeneration(VizWizDataset): chat=chat, answer=answer, original=sample["original"], - ) # type: ignore + ) # type: ignore + + +from .format import dataset_dir + +dataset = { + "train": VizWizDataset( + vis_root=Path(dataset_dir, "vizwiz", "images", "train"), + ann_path=Path(dataset_dir, "vizwiz", "Annotations"), + split="train", + ), + "test": VizWizDataset( + vis_root=Path(dataset_dir, "vizwiz", "images", "val"), + ann_path=Path(dataset_dir, "vizwiz", "Annotations"), + split="test", + ), + "generation": VizWizDatasetForGeneration( + vis_root=Path(dataset_dir, "vizwiz", "images", "val"), + ann_path=Path(dataset_dir, "vizwiz", "Annotations"), + split="test", + ), +} + +from .factory import register_dataset + +register_dataset( + dataset_name="vizwiz", + dataset=dataset, + tag=["image", "text"], +) diff --git a/src/dataset_library/factory.py b/src/dataset_library/factory.py index 624c71f..490412b 100644 --- a/src/dataset_library/factory.py +++ b/src/dataset_library/factory.py @@ -1,156 +1,79 @@ from torch.utils.data import Dataset -from typing import Literal +from typing import Literal,List from pathlib import Path from dataset_library.format import dataset_dir +MAPPING_NAME_TO_DATASET: dict[ + str, dict[Literal["train", "test", "generation"], Dataset] +] = {} + + +def register_dataset( + dataset_name: str, + dataset: dict[Literal["train", "test", "generation"], Dataset], + tag: List[Literal["image", "text", "audio", "video"]] = [], + **kwargs, +) -> None: + """ + Register a dataset. + + Args: + dataset_name (`str`): + The name of the dataset. + dataset (`Dataset`): + The dataset to register. + """ + dataset_name = dataset_name.lower() + if dataset_name in MAPPING_NAME_TO_DATASET: + raise ValueError(f"Dataset {dataset_name} already registered.") + MAPPING_NAME_TO_DATASET[dataset_name] = dataset + + +from .GigaspeechDataset import ( + GigaspeechDataset, + GigaspeechDatasetForGeneration, +) +from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration +from .ChemDataset import ChemDataset, ChemDatasetForGeneration +from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration +from .ScienceQADataset import ( + ScienceQADataset, + ScienceQADatasetForGeneration, +) +from .RefCOCODataset import ( + RefCOCODataset, + RefCOCODatasetForGeneration, +) +from .RefCOCOgDataset import ( + RefCOCOgDataset, + RefCOCOgDatasetForGeneration, +) +from .RefCOCOPlusDataset import ( + RefCOCOplusDataset, + RefCOCOplusDatasetForGeneration, +) +from .VizWizDataset import ( + VizWizDataset, + VizWizDatasetForGeneration, +) + + def get_dataset( - dataset_name, base_path=dataset_dir + dataset_name: str, base_path=dataset_dir ) -> dict[Literal["train", "test", "generation"], Dataset]: - dataset: dict[Literal["train", "test", "generation"], Dataset] = {} - match dataset_name: - case "ocrvqa200k": - from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration + """ + Get a dataset by name. - dataset = { - "train": OCRVQADataset( - vis_root=Path(base_path, "OCR-VQA-200K", "images"), - ann_path=Path(base_path, "OCR-VQA-200K"), - split="train", - ), - "test": OCRVQADataset( - vis_root=Path(base_path, "OCR-VQA-200K", "images"), - ann_path=Path(base_path, "OCR-VQA-200K"), - split="test", - ), - "generation": OCRVQADatasetForGeneration( - vis_root=Path(base_path, "OCR-VQA-200K", "images"), - ann_path=Path(base_path, "OCR-VQA-200K"), - split="test", - ), - } - case "chem": - from .ChemDataset import ChemDataset, ChemDatasetForGeneration - - dataset = { - "train": ChemDataset( - vis_root=Path(base_path, "chem", "images"), - ann_path=Path(base_path, "chem"), - split="train", - ), - "test": ChemDataset( - vis_root=Path(base_path, "chem", "images"), - ann_path=Path(base_path, "chem"), - split="test", - ), - "generation": ChemDatasetForGeneration( - vis_root=Path(base_path, "chem", "images"), - ann_path=Path(base_path, "chem"), - split="test", - ), - } - case "gigaspeech": - from .GigaspeechDataset import ( - GigaspeechDataset, - GigaspeechDatasetForGeneration, - ) - - dataset = { - "train": GigaspeechDataset(split="train"), - "test": GigaspeechDataset(split="test"), - "generation": GigaspeechDatasetForGeneration(split="test"), - } - - case "textvqa": - from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration - - dataset = { - "train": TextVQADataset( - vis_root=Path(base_path, "TextVQA", "images"), - ann_path=Path(base_path, "TextVQA"), - split="train", - ), - "test": TextVQADataset( - vis_root=Path(base_path, "TextVQA", "images"), - ann_path=Path(base_path, "TextVQA"), - split="test", - ), - "generation": TextVQADatasetForGeneration( - vis_root=Path(base_path, "TextVQA", "images"), - ann_path=Path(base_path, "TextVQA"), - split="test", - ), - } - case "scienceqa": - from .ScienceQADataset import ( - ScienceQADataset, - ScienceQADatasetForGeneration, - ) - - dataset = { - "train": ScienceQADataset(split="train"), - "test": ScienceQADataset(split="test"), - "generation": ScienceQADatasetForGeneration(split="test"), - } - - case "refcoco": - from .RefCOCODataset import ( - RefCOCODataset, - RefCOCODatasetForGeneration, - ) - - dataset = { - "train": RefCOCODataset(split="val"), - "test": RefCOCODataset(split="test"), - "generation": RefCOCODatasetForGeneration(split="test"), - } - - case "refcocog": - from .RefCOCOgDataset import ( - RefCOCOgDataset, - RefCOCOgDatasetForGeneration, - ) - - dataset = { - "train": RefCOCOgDataset(split="val"), - "test": RefCOCOgDataset(split="test"), - "generation": RefCOCOgDatasetForGeneration(split="test"), - } - - case "refcocoplus": - from .RefCOCOPlusDataset import ( - RefCOCOplusDataset, - RefCOCOplusDatasetForGeneration, - ) - - dataset = { - "train": RefCOCOplusDataset(split="val"), - "test": RefCOCOplusDataset(split="testA"), - "generation": RefCOCOplusDatasetForGeneration(split="testA"), - } - - case "vizwiz": - from .VizWizDataset import ( - VizWizDataset, - VizWizDatasetForGeneration, - ) - - dataset = { - "train": VizWizDataset( - vis_root=Path(base_path, "vizwiz", "images", "train"), - ann_path=Path(base_path, "vizwiz", "Annotations"), - split="train", - ), - "test": VizWizDataset( - vis_root=Path(base_path, "vizwiz", "images", "val"), - ann_path=Path(base_path, "vizwiz", "Annotations"), - split="test", - ), - "generation": VizWizDatasetForGeneration( - vis_root=Path(base_path, "vizwiz", "images", "val"), - ann_path=Path(base_path, "vizwiz", "Annotations"), - split="test", - ), - } - - return dataset + Args: + dataset_name (`str`): + The name of the dataset. + base_path (`str`): + The base path to the dataset. + """ + dataset_name = dataset_name.lower() + if dataset_name not in MAPPING_NAME_TO_DATASET: + raise ValueError(f"Dataset {dataset_name} not registered.") + for key in MAPPING_NAME_TO_DATASET[dataset_name]: + MAPPING_NAME_TO_DATASET[dataset_name][key].load_data() + return MAPPING_NAME_TO_DATASET[dataset_name] diff --git a/src/dataset_library/format.py b/src/dataset_library/format.py index f983eec..680478a 100644 --- a/src/dataset_library/format.py +++ b/src/dataset_library/format.py @@ -28,8 +28,8 @@ class Conversation(TypedDict): class DatasetOutput(TypedDict): - audios: Optional[list[Tuple[np.ndarray, int]]] chat: list[Conversation] answer: Optional[str] original: Any images: Optional[list[Image.Image]] + audios: Optional[list[Tuple[np.ndarray, int]]] \ No newline at end of file diff --git a/src/dataset_library/test_dataset.py b/src/dataset_library/test_dataset.py index ad20107..34a7156 100644 --- a/src/dataset_library/test_dataset.py +++ b/src/dataset_library/test_dataset.py @@ -1,82 +1,21 @@ -from .factory import get_dataset +import pytest +from .factory import get_dataset, MAPPING_NAME_TO_DATASET +# Get all registered dataset names for parameterization +dataset_names = list(MAPPING_NAME_TO_DATASET.keys()) -def test_gigaspeech(): - dataset = get_dataset("gigaspeech") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 +@pytest.mark.parametrize("dataset_name", dataset_names) +def test_registered_datasets(dataset_name): + dataset = get_dataset(dataset_name) + + # Test train split + assert "train" in dataset, f"Train split not found in {dataset_name}" + assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore + assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore + assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_chem(): - dataset = get_dataset("chem") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_ocrvqa200k(): - dataset = get_dataset("ocrvqa200k") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_textvqa(): - dataset = get_dataset("textvqa") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_scienceqa(): - dataset = get_dataset("scienceqa") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_refcoco(): - dataset = get_dataset("refcoco") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_refcocog(): - dataset = get_dataset("refcocog") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_refcocoplus(): - dataset = get_dataset("refcocoplus") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 - - -def test_VizWiz(): - dataset = get_dataset("vizwiz") - assert len(dataset["train"]) > 0 # type: ignore - assert len(dataset["train"][0]["chat"]) > 0 - - assert len(dataset["test"]) > 0 # type: ignore - assert len(dataset["test"][0]["chat"]) > 0 + # Test test split + assert "test" in dataset, f"Test split not found in {dataset_name}" + assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore + assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore + assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore diff --git a/src/utils/trainer.py b/src/utils/trainer.py index 4ad0d75..b389314 100644 --- a/src/utils/trainer.py +++ b/src/utils/trainer.py @@ -36,7 +36,7 @@ class ContinualTrainer(Trainer): # fisher = t if regularization_args.lwf_enable: - pass + self.lwf_lambda = regularization_args.lwf_lambda def create_accelerator_and_postprocess(self):