feat✨: 添加VizWiz数据集支持,更新数据集工厂,优化测试用例,调整VSCode设置
This commit is contained in:
parent
24a6c3c114
commit
188ea7df6e
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -2,5 +2,8 @@
|
|||||||
"python.analysis.extraPaths": [
|
"python.analysis.extraPaths": [
|
||||||
"src/transformers_repo/src/",
|
"src/transformers_repo/src/",
|
||||||
"src/peft_repo/src/"
|
"src/peft_repo/src/"
|
||||||
|
],
|
||||||
|
"python.analysis.exclude": [
|
||||||
|
"dataset/**"
|
||||||
]
|
]
|
||||||
}
|
}
|
@ -19,7 +19,7 @@ class ScienceQADataset(Dataset):
|
|||||||
self.vis_processor = vis_processor
|
self.vis_processor = vis_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
from .format import dataset_dir
|
from .format import dataset_dir
|
||||||
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir)
|
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
|
||||||
self.data = ds[split] # type: ignore
|
self.data = ds[split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
132
src/dataset_library/VizWizDataset.py
Normal file
132
src/dataset_library/VizWizDataset.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
from PIL import Image
|
||||||
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class VizWizDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||||
|
ann_root (string): directory to store the annotation file
|
||||||
|
"""
|
||||||
|
self.vis_root = vis_root
|
||||||
|
|
||||||
|
from .vis_processor import size_processor
|
||||||
|
|
||||||
|
self.vis_processor = (
|
||||||
|
vis_processor if vis_processor is not None else size_processor
|
||||||
|
)
|
||||||
|
self.text_processor = text_processor
|
||||||
|
if split == "train":
|
||||||
|
self.data = self.create_data(Path(ann_path, "train.json"))
|
||||||
|
elif split == "test":
|
||||||
|
self.data = self.create_data(Path(ann_path, "val.json"))
|
||||||
|
|
||||||
|
# self.instruction_pool = [
|
||||||
|
# "[vqa] {}",
|
||||||
|
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
||||||
|
# ]
|
||||||
|
|
||||||
|
def create_data(self, ann_path):
|
||||||
|
processed_data = []
|
||||||
|
with open(ann_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
for i in range(len(data)):
|
||||||
|
if os.path.exists(os.path.join(self.vis_root, data[i]["image"])) and data[i]["answerable"]:
|
||||||
|
imageFile = data[i]["image"]
|
||||||
|
processed_data.append(
|
||||||
|
{
|
||||||
|
"question": data[i]["question"],
|
||||||
|
"answer": data[i]["answers"][0]["answer"],
|
||||||
|
"image_path": imageFile,
|
||||||
|
"original": data[i],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return processed_data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
image: Image.Image = Image.open(
|
||||||
|
os.path.join(self.vis_root, sample["image_path"])
|
||||||
|
).convert("RGB")
|
||||||
|
# resize image
|
||||||
|
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
image = self.vis_processor(image)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Conversation(
|
||||||
|
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
chat=chat,
|
||||||
|
original=sample["original"],
|
||||||
|
images=[image],
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class VizWizDatasetForGeneration(VizWizDataset):
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||||
|
"RGB"
|
||||||
|
)
|
||||||
|
# resize image
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
image = self.vis_processor(image)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[image],
|
||||||
|
chat=chat,
|
||||||
|
answer=answer,
|
||||||
|
original=sample["original"],
|
||||||
|
) # type: ignore
|
@ -128,5 +128,29 @@ def get_dataset(
|
|||||||
"test": RefCOCOplusDataset(split="testA"),
|
"test": RefCOCOplusDataset(split="testA"),
|
||||||
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "vizwiz":
|
||||||
|
from .VizWizDataset import (
|
||||||
|
VizWizDataset,
|
||||||
|
VizWizDatasetForGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": VizWizDataset(
|
||||||
|
vis_root=Path(base_path, "vizwiz", "images", "train"),
|
||||||
|
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": VizWizDataset(
|
||||||
|
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
||||||
|
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": VizWizDatasetForGeneration(
|
||||||
|
vis_root=Path(base_path, "vizwiz", "images", "val"),
|
||||||
|
ann_path=Path(base_path, "vizwiz", "Annotations"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
@ -1,70 +1,82 @@
|
|||||||
from .factory import get_dataset
|
from .factory import get_dataset
|
||||||
|
|
||||||
|
|
||||||
# def test_gigaspeech():
|
def test_gigaspeech():
|
||||||
# dataset = get_dataset("gigaspeech")
|
dataset = get_dataset("gigaspeech")
|
||||||
# assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
# assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
# assert len(dataset["test"]) > 0 # type: ignore
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
# assert len(dataset["test"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
# def test_chem():
|
def test_chem():
|
||||||
# dataset = get_dataset("chem")
|
dataset = get_dataset("chem")
|
||||||
# assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
# assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
# assert len(dataset["test"]) > 0 # type: ignore
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
# assert len(dataset["test"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
# def test_ocrvqa200k():
|
def test_ocrvqa200k():
|
||||||
# dataset = get_dataset("ocrvqa200k")
|
dataset = get_dataset("ocrvqa200k")
|
||||||
# assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
# assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
# assert len(dataset["test"]) > 0 # type: ignore
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
# assert len(dataset["test"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
# def test_textvqa():
|
def test_textvqa():
|
||||||
# dataset = get_dataset("textvqa")
|
dataset = get_dataset("textvqa")
|
||||||
# assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
# assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
# assert len(dataset["test"]) > 0 # type: ignore
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
# assert len(dataset["test"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
# def test_scienceqa():
|
def test_scienceqa():
|
||||||
# dataset = get_dataset("scienceqa")
|
dataset = get_dataset("scienceqa")
|
||||||
# assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
# assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
# assert len(dataset["test"]) > 0 # type: ignore
|
|
||||||
# assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
def test_refcoco():
|
def test_refcoco():
|
||||||
dataset = get_dataset("refcoco")
|
dataset = get_dataset("refcoco")
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_refcocog():
|
def test_refcocog():
|
||||||
dataset = get_dataset("refcocog")
|
dataset = get_dataset("refcocog")
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_refcocoplus():
|
def test_refcocoplus():
|
||||||
dataset = get_dataset("refcocoplus")
|
dataset = get_dataset("refcocoplus")
|
||||||
assert len(dataset["train"]) > 0 # type: ignore
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0 # type: ignore
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_VizWiz():
|
||||||
|
dataset = get_dataset("vizwiz")
|
||||||
|
assert len(dataset["train"]) > 0 # type: ignore
|
||||||
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
assert len(dataset["test"]) > 0 # type: ignore
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit ebab283576fc8803314314b5e1e4331c424a2198
|
Subproject commit 65c3c43cd195bd90b8cb339c1ba883b4c6c66b43
|
@ -20,3 +20,11 @@
|
|||||||
- [ ] 多个数据集引入
|
- [ ] 多个数据集引入
|
||||||
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
|
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
|
||||||
- [ ] 引入EWC和LWF
|
- [ ] 引入EWC和LWF
|
||||||
|
|
||||||
|
[2025.05.15]
|
||||||
|
|
||||||
|
- [x] vizwiz处理
|
||||||
|
|
||||||
|
[2025.05.16]
|
||||||
|
|
||||||
|
- [ ] 处理不同的持续学习框架,使得整体框架能够兼容
|
@ -1 +1 @@
|
|||||||
Subproject commit 8ee1a4eadda1d83cf65c024fe54364b5bd74e55f
|
Subproject commit 7961d291b338d568fa2160f7deac85baa21c49dc
|
Loading…
Reference in New Issue
Block a user