From 188ea7df6e9d715465d217f1040c438675a4305b Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Fri, 16 May 2025 13:44:19 +0800 Subject: [PATCH] =?UTF-8?q?feat=E2=9C=A8:=20=E6=B7=BB=E5=8A=A0VizWiz?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86=E6=94=AF=E6=8C=81=EF=BC=8C=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E6=95=B0=E6=8D=AE=E9=9B=86=E5=B7=A5=E5=8E=82=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B=EF=BC=8C?= =?UTF-8?q?=E8=B0=83=E6=95=B4VSCode=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 3 + src/dataset_library/ScienceQADataset.py | 2 +- src/dataset_library/VizWizDataset.py | 132 ++++++++++++++++++++++++ src/dataset_library/factory.py | 24 +++++ src/dataset_library/test_dataset.py | 84 ++++++++------- src/peft_repo | 2 +- src/todo.md | 8 ++ src/transformers_repo | 2 +- 8 files changed, 218 insertions(+), 39 deletions(-) create mode 100644 src/dataset_library/VizWizDataset.py diff --git a/.vscode/settings.json b/.vscode/settings.json index bc87c5c..15e2ee1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,5 +2,8 @@ "python.analysis.extraPaths": [ "src/transformers_repo/src/", "src/peft_repo/src/" + ], + "python.analysis.exclude": [ + "dataset/**" ] } \ No newline at end of file diff --git a/src/dataset_library/ScienceQADataset.py b/src/dataset_library/ScienceQADataset.py index 6f3c17e..8b4d279 100644 --- a/src/dataset_library/ScienceQADataset.py +++ b/src/dataset_library/ScienceQADataset.py @@ -19,7 +19,7 @@ class ScienceQADataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor from .format import dataset_dir - ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) + ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore self.data = ds[split] # type: ignore def __len__(self): diff --git a/src/dataset_library/VizWizDataset.py b/src/dataset_library/VizWizDataset.py new file mode 100644 index 0000000..fa6aa09 --- /dev/null +++ b/src/dataset_library/VizWizDataset.py @@ -0,0 +1,132 @@ +from PIL import Image +from .format import ( + Conversation, + ConverstationAudio, + ConverstationImage, + ConverstationText, + DatasetOutput, +) +from torch.utils.data import Dataset +import json +import os +from pathlib import Path + + +class VizWizDataset(Dataset): + def __init__( + self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train" + ): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + from .vis_processor import size_processor + + self.vis_processor = ( + vis_processor if vis_processor is not None else size_processor + ) + self.text_processor = text_processor + if split == "train": + self.data = self.create_data(Path(ann_path, "train.json")) + elif split == "test": + self.data = self.create_data(Path(ann_path, "val.json")) + + # self.instruction_pool = [ + # "[vqa] {}", + # "[vqa] Based on the image, respond to this question with a short answer: {}", + # ] + + def create_data(self, ann_path): + processed_data = [] + with open(ann_path, "r") as f: + data = json.load(f) + for i in range(len(data)): + if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]: + imageFile = data[i]["image"] + processed_data.append( + { + "question": data[i]["question"], + "answer": data[i]["answers"][0]["answer"], + "image_path": imageFile, + "original": data[i], + } + ) + return processed_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image: Image.Image = Image.open( + os.path.join(self.vis_root, sample["image_path"]) + ).convert("RGB") + # resize image + + question = sample["question"] + answer = sample["answer"] + if self.vis_processor is not None: + image = self.vis_processor(image) + if self.text_processor is not None: + question = self.text_processor(question) + answer = self.text_processor(answer) + + chat = [ + Conversation( + role="user", + content=[ + ConverstationImage(type="image", image_url=""), + ConverstationText( + type="text", + text=f"[vqa] Based on the image, respond to this question with a short answer: {question}", + ), + ], + ), + Conversation( + role="assistant", content=[ConverstationText(type="text", text=answer)] + ), + ] + + return DatasetOutput( + chat=chat, + original=sample["original"], + images=[image], + ) # type: ignore + + +class VizWizDatasetForGeneration(VizWizDataset): + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert( + "RGB" + ) + # resize image + question = sample["question"] + answer = sample["answer"] + if self.vis_processor is not None: + image = self.vis_processor(image) + if self.text_processor is not None: + question = self.text_processor(question) + answer = self.text_processor(answer) + + chat = [ + Conversation( + role="user", + content=[ + ConverstationImage(type="image", image_url=""), + ConverstationText( + type="text", + text=f"[vqa] Based on the image, respond to this question with a short answer: {question}", + ), + ], + ), + ] + return DatasetOutput( + images=[image], + chat=chat, + answer=answer, + original=sample["original"], + ) # type: ignore diff --git a/src/dataset_library/factory.py b/src/dataset_library/factory.py index 3f6e4b9..624c71f 100644 --- a/src/dataset_library/factory.py +++ b/src/dataset_library/factory.py @@ -128,5 +128,29 @@ def get_dataset( "test": RefCOCOplusDataset(split="testA"), "generation": RefCOCOplusDatasetForGeneration(split="testA"), } + + case "vizwiz": + from .VizWizDataset import ( + VizWizDataset, + VizWizDatasetForGeneration, + ) + + dataset = { + "train": VizWizDataset( + vis_root=Path(base_path, "vizwiz", "images", "train"), + ann_path=Path(base_path, "vizwiz", "Annotations"), + split="train", + ), + "test": VizWizDataset( + vis_root=Path(base_path, "vizwiz", "images", "val"), + ann_path=Path(base_path, "vizwiz", "Annotations"), + split="test", + ), + "generation": VizWizDatasetForGeneration( + vis_root=Path(base_path, "vizwiz", "images", "val"), + ann_path=Path(base_path, "vizwiz", "Annotations"), + split="test", + ), + } return dataset diff --git a/src/dataset_library/test_dataset.py b/src/dataset_library/test_dataset.py index 21690e9..ad20107 100644 --- a/src/dataset_library/test_dataset.py +++ b/src/dataset_library/test_dataset.py @@ -1,70 +1,82 @@ from .factory import get_dataset -# def test_gigaspeech(): -# dataset = get_dataset("gigaspeech") -# assert len(dataset["train"]) > 0 # type: ignore -# assert len(dataset["train"][0]["chat"]) > 0 +def test_gigaspeech(): + dataset = get_dataset("gigaspeech") + assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"][0]["chat"]) > 0 -# assert len(dataset["test"]) > 0 # type: ignore -# assert len(dataset["test"][0]["chat"]) > 0 + assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"][0]["chat"]) > 0 -# def test_chem(): -# dataset = get_dataset("chem") -# assert len(dataset["train"]) > 0 # type: ignore -# assert len(dataset["train"][0]["chat"]) > 0 +def test_chem(): + dataset = get_dataset("chem") + assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"][0]["chat"]) > 0 -# assert len(dataset["test"]) > 0 # type: ignore -# assert len(dataset["test"][0]["chat"]) > 0 + assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"][0]["chat"]) > 0 -# def test_ocrvqa200k(): -# dataset = get_dataset("ocrvqa200k") -# assert len(dataset["train"]) > 0 # type: ignore -# assert len(dataset["train"][0]["chat"]) > 0 +def test_ocrvqa200k(): + dataset = get_dataset("ocrvqa200k") + assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"][0]["chat"]) > 0 -# assert len(dataset["test"]) > 0 # type: ignore -# assert len(dataset["test"][0]["chat"]) > 0 + assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"][0]["chat"]) > 0 -# def test_textvqa(): -# dataset = get_dataset("textvqa") -# assert len(dataset["train"]) > 0 # type: ignore -# assert len(dataset["train"][0]["chat"]) > 0 +def test_textvqa(): + dataset = get_dataset("textvqa") + assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"][0]["chat"]) > 0 -# assert len(dataset["test"]) > 0 # type: ignore -# assert len(dataset["test"][0]["chat"]) > 0 + assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"][0]["chat"]) > 0 -# def test_scienceqa(): -# dataset = get_dataset("scienceqa") -# assert len(dataset["train"]) > 0 # type: ignore -# assert len(dataset["train"][0]["chat"]) > 0 +def test_scienceqa(): + dataset = get_dataset("scienceqa") + assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"][0]["chat"]) > 0 + + assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"][0]["chat"]) > 0 -# assert len(dataset["test"]) > 0 # type: ignore -# assert len(dataset["test"][0]["chat"]) > 0 def test_refcoco(): dataset = get_dataset("refcoco") - assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"][0]["chat"]) > 0 - assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"][0]["chat"]) > 0 + def test_refcocog(): dataset = get_dataset("refcocog") - assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"][0]["chat"]) > 0 - assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"][0]["chat"]) > 0 + def test_refcocoplus(): dataset = get_dataset("refcocoplus") - assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"][0]["chat"]) > 0 - assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"]) > 0 # type: ignore + assert len(dataset["test"][0]["chat"]) > 0 + + +def test_VizWiz(): + dataset = get_dataset("vizwiz") + assert len(dataset["train"]) > 0 # type: ignore + assert len(dataset["train"][0]["chat"]) > 0 + + assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"][0]["chat"]) > 0 diff --git a/src/peft_repo b/src/peft_repo index ebab283..65c3c43 160000 --- a/src/peft_repo +++ b/src/peft_repo @@ -1 +1 @@ -Subproject commit ebab283576fc8803314314b5e1e4331c424a2198 +Subproject commit 65c3c43cd195bd90b8cb339c1ba883b4c6c66b43 diff --git a/src/todo.md b/src/todo.md index 5a05c0f..5a018c2 100644 --- a/src/todo.md +++ b/src/todo.md @@ -20,3 +20,11 @@ - [ ] 多个数据集引入 - [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用) - [ ] 引入EWC和LWF + +[2025.05.15] + +- [x] vizwiz处理 + +[2025.05.16] + +- [ ] 处理不同的持续学习框架,使得整体框架能够兼容 \ No newline at end of file diff --git a/src/transformers_repo b/src/transformers_repo index 8ee1a4e..7961d29 160000 --- a/src/transformers_repo +++ b/src/transformers_repo @@ -1 +1 @@ -Subproject commit 8ee1a4eadda1d83cf65c024fe54364b5bd74e55f +Subproject commit 7961d291b338d568fa2160f7deac85baa21c49dc