feat✨: 添加VizWiz数据集支持,更新数据集工厂,优化测试用例,调整VSCode设置
This commit is contained in:
parent
24a6c3c114
commit
188ea7df6e
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -2,5 +2,8 @@
|
||||
"python.analysis.extraPaths": [
|
||||
"src/transformers_repo/src/",
|
||||
"src/peft_repo/src/"
|
||||
],
|
||||
"python.analysis.exclude": [
|
||||
"dataset/**"
|
||||
]
|
||||
}
|
@ -19,7 +19,7 @@ class ScienceQADataset(Dataset):
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
from .format import dataset_dir
|
||||
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir)
|
||||
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
|
132
src/dataset_library/VizWizDataset.py
Normal file
132
src/dataset_library/VizWizDataset.py
Normal file
@ -0,0 +1,132 @@
|
||||
from PIL import Image
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class VizWizDataset(Dataset):
    """VizWiz visual-question-answering dataset producing chat-style samples.

    Every item pairs one image with a question (user turn) and the first
    answer listed in the annotation entry (assistant turn). Annotation
    entries are kept only when they are flagged ``answerable`` and their
    image file exists under ``vis_root``.
    """

    # Public split name -> annotation file inside ``ann_path``.
    # The "test" split is read from ``val.json``.
    _SPLIT_FILES = {"train": "train.json", "test": "val.json"}

    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        Args:
            vis_root (string): Root directory of the images for this split.
            ann_path (string): Directory holding ``train.json`` and ``val.json``.
            vis_processor: Callable applied to each loaded RGB image. When
                ``None``, ``size_processor`` from ``.vis_processor`` is used.
            text_processor: Optional callable applied to the question and
                to the answer.
            split (string): ``"train"`` or ``"test"``.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        # Fail early: previously an unknown split left ``self.data`` unset
        # and the error only appeared later as an AttributeError.
        if split not in self._SPLIT_FILES:
            raise ValueError(
                f"Unsupported split {split!r}; expected one of "
                f"{sorted(self._SPLIT_FILES)}"
            )

        self.vis_root = vis_root

        if vis_processor is None:
            # Imported lazily so the default is only resolved when needed.
            from .vis_processor import size_processor

            vis_processor = size_processor
        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self.data = self.create_data(Path(ann_path, self._SPLIT_FILES[split]))

    def create_data(self, ann_path):
        """Load the annotation file and keep the usable entries.

        Args:
            ann_path: Path of the JSON annotation file (a list of entries).

        Returns:
            list[dict]: One dict per kept entry with the keys ``question``,
            ``answer`` (the first listed answer), ``image_path`` (relative
            to ``vis_root``) and ``original`` (the untouched entry).
        """
        with open(ann_path, "r", encoding="utf-8") as f:
            annotations = json.load(f)

        processed_data = []
        for entry in annotations:
            # Check the cheap flag first so unanswerable entries never
            # trigger a filesystem lookup.
            if not entry["answerable"]:
                continue
            image_file = entry["image"]
            if not os.path.exists(os.path.join(self.vis_root, image_file)):
                continue
            processed_data.append(
                {
                    "question": entry["question"],
                    "answer": entry["answers"][0]["answer"],
                    "image_path": image_file,
                    "original": entry,
                }
            )
        return processed_data

    def __len__(self):
        """Return the number of kept annotation entries."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the training sample (user question + assistant answer).

        Args:
            index: Position in ``self.data``.

        Returns:
            DatasetOutput with ``chat`` (user and assistant turns),
            ``original`` (the raw annotation entry) and ``images`` (a
            one-element list with the processed image).
        """
        sample = self.data[index]

        # The context manager releases the file handle as soon as the
        # pixel data has been converted into a standalone RGB image.
        with Image.open(os.path.join(self.vis_root, sample["image_path"])) as raw:
            image: Image.Image = raw.convert("RGB")

        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            Conversation(
                role="user",
                content=[
                    # The image itself travels in ``images``; the URL stays empty.
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    ),
                ],
            ),
            Conversation(
                role="assistant", content=[ConverstationText(type="text", text=answer)]
            ),
        ]

        return DatasetOutput(
            chat=chat,
            original=sample["original"],
            images=[image],
        )  # type: ignore
|
||||
|
||||
|
||||
class VizWizDatasetForGeneration(VizWizDataset):
    """Generation-time variant of the VizWiz dataset.

    The chat contains only the user turn (image + question prompt); the
    reference answer is returned in the separate ``answer`` field instead
    of as an assistant turn.
    """

    def __getitem__(self, index):
        """Return the prompt-only sample stored at ``index``.

        Returns:
            DatasetOutput with ``images`` (one processed image), ``chat``
            (a single user turn), ``answer`` and ``original`` (the raw
            annotation entry).
        """
        record = self.data[index]
        image_file = os.path.join(self.vis_root, record["image_path"])
        picture = Image.open(image_file).convert("RGB")

        question, answer = record["question"], record["answer"]

        # Optional preprocessing of the image and of both text fields.
        if self.vis_processor is not None:
            picture = self.vis_processor(picture)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        # The image is delivered through ``images``; its URL is left empty.
        image_part = ConverstationImage(type="image", image_url="")
        text_part = ConverstationText(
            type="text",
            text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
        )
        user_turn = Conversation(role="user", content=[image_part, text_part])

        return DatasetOutput(
            images=[picture],
            chat=[user_turn],
            answer=answer,
            original=record["original"],
        )  # type: ignore
|
@ -128,5 +128,29 @@ def get_dataset(
|
||||
"test": RefCOCOplusDataset(split="testA"),
|
||||
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||
}
|
||||
|
||||
case "vizwiz":
|
||||
from .VizWizDataset import (
|
||||
VizWizDataset,
|
||||
VizWizDatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": VizWizDataset(
|
||||
vis_root=Path(base_path, "vizwiz", "images", "train"),
|
||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||
split="train",
|
||||
),
|
||||
"test": VizWizDataset(
|
||||
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
"generation": VizWizDatasetForGeneration(
|
||||
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
||||
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
return dataset
|
||||
|
@ -1,70 +1,82 @@
|
||||
from .factory import get_dataset
|
||||
|
||||
|
||||
# def test_gigaspeech():
|
||||
# dataset = get_dataset("gigaspeech")
|
||||
# assert len(dataset["train"]) > 0 # type: ignore
|
||||
# assert len(dataset["train"][0]["chat"]) > 0
|
||||
def test_gigaspeech():
|
||||
dataset = get_dataset("gigaspeech")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
# assert len(dataset["test"]) > 0 # type: ignore
|
||||
# assert len(dataset["test"][0]["chat"]) > 0
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
# def test_chem():
|
||||
# dataset = get_dataset("chem")
|
||||
# assert len(dataset["train"]) > 0 # type: ignore
|
||||
# assert len(dataset["train"][0]["chat"]) > 0
|
||||
def test_chem():
|
||||
dataset = get_dataset("chem")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
# assert len(dataset["test"]) > 0 # type: ignore
|
||||
# assert len(dataset["test"][0]["chat"]) > 0
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
# def test_ocrvqa200k():
|
||||
# dataset = get_dataset("ocrvqa200k")
|
||||
# assert len(dataset["train"]) > 0 # type: ignore
|
||||
# assert len(dataset["train"][0]["chat"]) > 0
|
||||
def test_ocrvqa200k():
|
||||
dataset = get_dataset("ocrvqa200k")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
# assert len(dataset["test"]) > 0 # type: ignore
|
||||
# assert len(dataset["test"][0]["chat"]) > 0
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
# def test_textvqa():
|
||||
# dataset = get_dataset("textvqa")
|
||||
# assert len(dataset["train"]) > 0 # type: ignore
|
||||
# assert len(dataset["train"][0]["chat"]) > 0
|
||||
def test_textvqa():
|
||||
dataset = get_dataset("textvqa")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
# assert len(dataset["test"]) > 0 # type: ignore
|
||||
# assert len(dataset["test"][0]["chat"]) > 0
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
# def test_scienceqa():
|
||||
# dataset = get_dataset("scienceqa")
|
||||
# assert len(dataset["train"]) > 0 # type: ignore
|
||||
# assert len(dataset["train"][0]["chat"]) > 0
|
||||
def test_scienceqa():
|
||||
dataset = get_dataset("scienceqa")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
# assert len(dataset["test"]) > 0 # type: ignore
|
||||
# assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
def test_refcoco():
|
||||
dataset = get_dataset("refcoco")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_refcocog():
|
||||
dataset = get_dataset("refcocog")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_refcocoplus():
|
||||
dataset = get_dataset("refcocoplus")
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"]) > 0 # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"]) > 0 # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_VizWiz():
    """Smoke-test the VizWiz entry of the dataset factory.

    For the train and test splits, checks that the split is non-empty and
    that its first sample carries a non-empty chat.
    """
    dataset = get_dataset("vizwiz")

    for split in ("train", "test"):
        assert len(dataset[split]) > 0  # type: ignore
        assert len(dataset[split][0]["chat"]) > 0
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit ebab283576fc8803314314b5e1e4331c424a2198
|
||||
Subproject commit 65c3c43cd195bd90b8cb339c1ba883b4c6c66b43
|
@ -20,3 +20,11 @@
|
||||
- [ ] 多个数据集引入
|
||||
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
|
||||
- [ ] 引入EWC和LWF
|
||||
|
||||
[2025.05.15]
|
||||
|
||||
- [x] vizwiz处理
|
||||
|
||||
[2025.05.16]
|
||||
|
||||
- [ ] 处理不同的持续学习框架,使得整体框架能够兼容
|
@ -1 +1 @@
|
||||
Subproject commit 8ee1a4eadda1d83cf65c024fe54364b5bd74e55f
|
||||
Subproject commit 7961d291b338d568fa2160f7deac85baa21c49dc
|
Loading…
Reference in New Issue
Block a user