feat: 添加VizWiz数据集支持,更新数据集工厂,优化测试用例,调整VSCode设置

This commit is contained in:
YunyaoZhou 2025-05-16 13:44:19 +08:00
parent 24a6c3c114
commit 188ea7df6e
8 changed files with 218 additions and 39 deletions

View File

@ -2,5 +2,8 @@
"python.analysis.extraPaths": [ "python.analysis.extraPaths": [
"src/transformers_repo/src/", "src/transformers_repo/src/",
"src/peft_repo/src/" "src/peft_repo/src/"
],
"python.analysis.exclude": [
"dataset/**"
] ]
} }

View File

@ -19,7 +19,7 @@ class ScienceQADataset(Dataset):
self.vis_processor = vis_processor self.vis_processor = vis_processor
self.text_processor = text_processor self.text_processor = text_processor
from .format import dataset_dir from .format import dataset_dir
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore self.data = ds[split] # type: ignore
def __len__(self): def __len__(self):

View File

@ -0,0 +1,132 @@
from PIL import Image
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
import json
import os
from pathlib import Path
class VizWizDataset(Dataset):
def __init__(
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
from .vis_processor import size_processor
self.vis_processor = (
vis_processor if vis_processor is not None else size_processor
)
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(Path(ann_path, "train.json"))
elif split == "test":
self.data = self.create_data(Path(ann_path, "val.json"))
# self.instruction_pool = [
# "[vqa] {}",
# "[vqa] Based on the image, respond to this question with a short answer: {}",
# ]
def create_data(self, ann_path):
processed_data = []
with open(ann_path, "r") as f:
data = json.load(f)
for i in range(len(data)):
if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]:
imageFile = data[i]["image"]
processed_data.append(
{
"question": data[i]["question"],
"answer": data[i]["answers"][0]["answer"],
"image_path": imageFile,
"original": data[i],
}
)
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image: Image.Image = Image.open(
os.path.join(self.vis_root, sample["image_path"])
).convert("RGB")
# resize image
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
),
Conversation(
role="assistant", content=[ConverstationText(type="text", text=answer)]
),
]
return DatasetOutput(
chat=chat,
original=sample["original"],
images=[image],
) # type: ignore
class VizWizDatasetForGeneration(VizWizDataset):
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
"RGB"
)
# resize image
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
),
]
return DatasetOutput(
images=[image],
chat=chat,
answer=answer,
original=sample["original"],
) # type: ignore

View File

@ -129,4 +129,28 @@ def get_dataset(
"generation": RefCOCOplusDatasetForGeneration(split="testA"), "generation": RefCOCOplusDatasetForGeneration(split="testA"),
} }
case "vizwiz":
from .VizWizDataset import (
VizWizDataset,
VizWizDatasetForGeneration,
)
dataset = {
"train": VizWizDataset(
vis_root=Path(base_path, "vizwiz", "images", "train"),
ann_path=Path(base_path, "vizwiz", "Annotations"),
split="train",
),
"test": VizWizDataset(
vis_root=Path(base_path, "vizwiz", "images", "val"),
ann_path=Path(base_path, "vizwiz", "Annotations"),
split="test",
),
"generation": VizWizDatasetForGeneration(
vis_root=Path(base_path, "vizwiz", "images", "val"),
ann_path=Path(base_path, "vizwiz", "Annotations"),
split="test",
),
}
return dataset return dataset

View File

@ -1,49 +1,50 @@
from .factory import get_dataset from .factory import get_dataset
# def test_gigaspeech(): def test_gigaspeech():
# dataset = get_dataset("gigaspeech") dataset = get_dataset("gigaspeech")
# assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0 assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0 assert len(dataset["test"][0]["chat"]) > 0
# def test_chem(): def test_chem():
# dataset = get_dataset("chem") dataset = get_dataset("chem")
# assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0 assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0 assert len(dataset["test"][0]["chat"]) > 0
# def test_ocrvqa200k(): def test_ocrvqa200k():
# dataset = get_dataset("ocrvqa200k") dataset = get_dataset("ocrvqa200k")
# assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0 assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0 assert len(dataset["test"][0]["chat"]) > 0
# def test_textvqa(): def test_textvqa():
# dataset = get_dataset("textvqa") dataset = get_dataset("textvqa")
# assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0 assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0 assert len(dataset["test"][0]["chat"]) > 0
# def test_scienceqa(): def test_scienceqa():
# dataset = get_dataset("scienceqa") dataset = get_dataset("scienceqa")
# assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0 assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
def test_refcoco(): def test_refcoco():
dataset = get_dataset("refcoco") dataset = get_dataset("refcoco")
@ -53,6 +54,7 @@ def test_refcoco():
assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0 assert len(dataset["test"][0]["chat"]) > 0
def test_refcocog(): def test_refcocog():
dataset = get_dataset("refcocog") dataset = get_dataset("refcocog")
assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"]) > 0 # type: ignore
@ -61,6 +63,7 @@ def test_refcocog():
assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0 assert len(dataset["test"][0]["chat"]) > 0
def test_refcocoplus(): def test_refcocoplus():
dataset = get_dataset("refcocoplus") dataset = get_dataset("refcocoplus")
assert len(dataset["train"]) > 0 # type: ignore assert len(dataset["train"]) > 0 # type: ignore
@ -68,3 +71,12 @@ def test_refcocoplus():
assert len(dataset["test"]) > 0 # type: ignore assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0 assert len(dataset["test"][0]["chat"]) > 0
def test_VizWiz():
dataset = get_dataset("vizwiz")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0

@ -1 +1 @@
Subproject commit ebab283576fc8803314314b5e1e4331c424a2198 Subproject commit 65c3c43cd195bd90b8cb339c1ba883b4c6c66b43

View File

@ -20,3 +20,11 @@
- [ ] 多个数据集引入 - [ ] 多个数据集引入
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用) - [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
- [ ] 引入EWC和LWF - [ ] 引入EWC和LWF
[2025.05.15]
- [x] vizwiz处理
[2025.05.16]
- [ ] 处理不同的持续学习框架,使得整体框架能够兼容

@ -1 +1 @@
Subproject commit 8ee1a4eadda1d83cf65c024fe54364b5bd74e55f Subproject commit 7961d291b338d568fa2160f7deac85baa21c49dc