feat: add support for multiple datasets, including Gigaspeech, TextVQA, OCR-VQA-200K, and the RefCOCO series; update the dataset factory and processing logic; optimize image processing

YunyaoZhou 2025-05-15 20:33:29 +08:00
parent 9ca588224d
commit 24a6c3c114
17 changed files with 568 additions and 78 deletions

dataset/.gitignore (new file, +5)

@@ -0,0 +1,5 @@
derek-thomas*
*.lock
speechcolab*
lmms-lab*
downloads/*

dataset/OCR-VQA-200K/.gitignore (new file, +2)

@@ -0,0 +1,2 @@
images/*
dataset.json

@@ -0,0 +1,49 @@
import os
import json
import urllib.request as ureq
import urllib.error
import concurrent.futures
import threading
# Set the file paths for your Google Drive
dataset_path = './dataset.json'
images_path = './images'
download = 1 # Set to 0 if images are already downloaded
# Load dataset json file
with open(dataset_path, 'r') as fp:
data = json.load(fp)
# Initialize a counter and a lock for thread-safe counting
downloaded_count = 0
count_lock = threading.Lock()
# Function to download an image
def download_image(k):
global downloaded_count
imageURL = data[k]['imageURL']
ext = os.path.splitext(imageURL)[1]
outputFile = os.path.join(images_path, f'{k}{ext}')
# Only download the image if it doesn't exist
if not os.path.exists(outputFile):
try:
ureq.urlretrieve(imageURL, outputFile)
with count_lock:
downloaded_count += 1
if downloaded_count % 100 == 0:
print(f'{downloaded_count} images downloaded.')
except urllib.error.URLError as e:
print(f'Error downloading {outputFile}: {e}')
# Download images using multiple threads
if download == 1:
if not os.path.exists(images_path):
os.makedirs(images_path)
# Create a thread pool and download the images in parallel
# Increase max_workers to potentially speed up downloads for many small files.
# The optimal number may vary based on your network and the server's capacity.
with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
executor.map(download_image, data.keys())
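
For reference, a minimal completeness check one might run after the downloader finishes; this is a sketch and not part of the commit, assuming the dataset.json layout used above (keys mapping to entries with an imageURL field) and the images/ output directory:

import os
import json

with open("./dataset.json", "r") as fp:
    data = json.load(fp)

# An image counts as downloaded if a file named <key><ext> exists under images/.
missing = [
    k for k, entry in data.items()
    if not os.path.exists(
        os.path.join("./images", f"{k}{os.path.splitext(entry['imageURL'])[1]}")
    )
]
print(f"{len(data) - len(missing)} of {len(data)} images present, {len(missing)} missing.")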

dataset/TextVQA/.gitignore (new file, +5)

@@ -0,0 +1,5 @@
images/test_images/*
images/train_images/*
TextVQA_0.5.1_test.json
TextVQA_0.5.1_train.json
TextVQA_0.5.1_val.json

dataset/vizwiz/Annotations/.gitignore (new file, +3)

@@ -0,0 +1,3 @@
train.json
test.json
val.json

dataset/vizwiz/images/.gitignore (new file, +3)

@@ -0,0 +1,3 @@
val/*
train/*
test/*

@@ -18,8 +18,9 @@ class GigaspeechDataset(Dataset):
self.audio_processor = audio_processor
self.text_processor = text_processor
gs = load_dataset("speechcolab/gigaspeech", "xs")
self.data = gs[split]
from .format import dataset_dir
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
self.data = gs[split] # type: ignore
def __len__(self):
return len(self.data)
@@ -54,7 +55,7 @@ class GigaspeechDataset(Dataset):
audios=[(audio, sampling_rate)],
chat=chat,
original=sample,
)
) # type: ignore
class GigaspeechDatasetForGeneration(GigaspeechDataset):
@@ -87,7 +88,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
chat=chat,
answer=text,
original=sample,
)
) # type: ignore
def test_gigaspeech():
@@ -103,3 +104,6 @@ def test_gigaspeech():
print(dataset[0])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_gigaspeech()
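
A rough usage sketch for the updated class (not part of the commit; the module path and the optional default processors are assumptions inferred from the assignments above):

from torch.utils.data import DataLoader
from dataset_library.GigaspeechDataset import GigaspeechDataset  # assumed module path

ds = GigaspeechDataset(split="train")  # audio/text processors assumed optional
loader = DataLoader(ds, batch_size=4, collate_fn=lambda batch: batch)  # keep raw sample dicts

for batch in loader:
    for sample in batch:
        audio, sampling_rate = sample["audios"][0]
        print(sampling_rate, sample["chat"][0]["role"])
    break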

@@ -22,8 +22,10 @@ class OCRVQADataset(Dataset):
"""
self.vis_root = vis_root
from .vis_processor import size_processor
self.vis_processor = (
vis_processor if vis_processor is not None else self._vis_processor
vis_processor if vis_processor is not None else size_processor
)
self.text_processor = text_processor
if split == "train":
@@ -64,24 +66,6 @@ class OCRVQADataset(Dataset):
def __len__(self):
return len(self.data)
def _vis_processor(self, image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
if width < 28 or height < 28:
min_size = min(width, height)
ratio = 28 / min_size + 1
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image
def __getitem__(self, index):
sample = self.data[index]
image: Image.Image = Image.open(
@@ -117,7 +101,7 @@ class OCRVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
)
) # type: ignore
class OCRVQADatasetForGeneration(OCRVQADataset):
@@ -153,4 +137,4 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
)
) # type: ignore

@@ -0,0 +1,121 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCODataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "test"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCODatasetForGeneration(RefCOCODataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCO():
dataset = RefCOCODataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCODatasetForGeneration(
split="test",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_RefCOCO()
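
As a quick illustration of what a RefCOCODataset sample contains (a sketch; the field names follow the DatasetOutput construction above, and the package path is assumed from the factory imports):

from dataset_library.RefCOCODataset import RefCOCODataset

ds = RefCOCODataset(split="val")
sample = ds[0]
# images: one-element list holding the PIL image from the HF dataset;
# chat: a user turn (image placeholder + question) followed by the assistant answer.
print(len(sample["images"]))
for turn in sample["chat"]:
    print(turn["role"], [c["type"] for c in turn["content"]])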

@@ -0,0 +1,121 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCOplusDataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "testA"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCOplusDatasetForGeneration(RefCOCOplusDataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCOplus():
dataset = RefCOCOplusDataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCOplusDatasetForGeneration(
split="testA",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_RefCOCOplus()

@@ -0,0 +1,121 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCOgDataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "test"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCOg():
dataset = RefCOCOgDataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCOgDatasetForGeneration(
split="test",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_RefCOCOg()

@@ -6,20 +6,21 @@ from .format import (
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset
from datasets import load_dataset, DatasetDict
class ScienceQADataset(Dataset):
def __init__(self, audio_processor=None, text_processor=None, split="train"):
def __init__(self, vis_processor=None, text_processor=None, split="train"):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = audio_processor
self.vis_processor = vis_processor
self.text_processor = text_processor
ds = load_dataset("derek-thomas/ScienceQA")
self.data = ds[split]
from .format import dataset_dir
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir)
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
@@ -60,7 +61,7 @@ class ScienceQADataset(Dataset):
images=[images],
chat=chat,
original=sample,
)
) # type: ignore
class ScienceQADatasetForGeneration(ScienceQADataset):
@@ -98,7 +99,7 @@ class ScienceQADatasetForGeneration(ScienceQADataset):
chat=chat,
answer=choices[answer],
original=sample,
)
) # type: ignore
def test_scienceQA():

@@ -124,7 +124,7 @@ class TextVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
)
) # type: ignore
class TextVQADatasetForGeneration(TextVQADataset):
@@ -158,16 +158,5 @@ class TextVQADatasetForGeneration(TextVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
)
) # type: ignore
def test_dataset():
vis_root = "/home/zyy/dataset/TextVQA/images"
ann_path = "/home/zyy/dataset/TextVQA"
dataset = TextVQADataset(vis_root, ann_path)
for i in range(10):
print(dataset[i])
if __name__ == "__main__":
test_dataset()

@@ -1,10 +1,11 @@
from torch.utils.data import Dataset
from typing import Literal
from pathlib import Path
from dataset_library.format import dataset_dir
def get_dataset(
dataset_name, base_path="/home/zyy/dataset"
dataset_name, base_path=dataset_dir
) -> dict[Literal["train", "test", "generation"], Dataset]:
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
match dataset_name:
@@ -92,4 +93,40 @@ def get_dataset(
"generation": ScienceQADatasetForGeneration(split="test"),
}
case "refcoco":
from .RefCOCODataset import (
RefCOCODataset,
RefCOCODatasetForGeneration,
)
dataset = {
"train": RefCOCODataset(split="val"),
"test": RefCOCODataset(split="test"),
"generation": RefCOCODatasetForGeneration(split="test"),
}
case "refcocog":
from .RefCOCOgDataset import (
RefCOCOgDataset,
RefCOCOgDatasetForGeneration,
)
dataset = {
"train": RefCOCOgDataset(split="val"),
"test": RefCOCOgDataset(split="test"),
"generation": RefCOCOgDatasetForGeneration(split="test"),
}
case "refcocoplus":
from .RefCOCOPlusDataset import (
RefCOCOplusDataset,
RefCOCOplusDatasetForGeneration,
)
dataset = {
"train": RefCOCOplusDataset(split="val"),
"test": RefCOCOplusDataset(split="testA"),
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
}
return dataset
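
A minimal sketch of how the new entries are consumed (the import path is assumed; note that for the RefCOCO family the "train" entry is backed by the val split and "test"/"generation" by test or testA, exactly as wired above):

from dataset_library.factory import get_dataset

datasets = get_dataset("refcocoplus")  # base_path now defaults to dataset_dir
train_ds = datasets["train"]           # RefCOCOplusDataset(split="val")
gen_ds = datasets["generation"]        # RefCOCOplusDatasetForGeneration(split="testA")
print(len(train_ds), len(gen_ds))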

@@ -1,6 +1,9 @@
from typing import Any, Tuple, TypedDict, Literal, Optional
import numpy as np
from PIL import Image
from pathlib import Path
dataset_dir = Path(__file__).resolve().parent.parent.parent / "dataset"
class ConverstationText(TypedDict):
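
dataset_dir resolves to a repository-level dataset/ directory three parents above format.py; a small sketch of the equivalent resolution with an illustrative file location (the actual source layout is not shown in this diff):

from pathlib import Path

format_py = Path("/repo/src/dataset_library/format.py")  # illustrative location
dataset_dir = format_py.parent.parent.parent / "dataset"
print(dataset_dir)  # /repo/dataset, the directory whose .gitignore files are added above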

@@ -1,46 +1,70 @@
from .factory import get_dataset
def test_gigaspeech():
dataset = get_dataset("gigaspeech")
assert len(dataset["train"]) > 0
# def test_gigaspeech():
# dataset = get_dataset("gigaspeech")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_chem():
# dataset = get_dataset("chem")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_ocrvqa200k():
# dataset = get_dataset("ocrvqa200k")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_textvqa():
# dataset = get_dataset("textvqa")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_scienceqa():
# dataset = get_dataset("scienceqa")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
def test_refcoco():
dataset = get_dataset("refcoco")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_chem():
dataset = get_dataset("chem")
assert len(dataset["train"]) > 0
def test_refcocog():
dataset = get_dataset("refcocog")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_ocrvqa200k():
dataset = get_dataset("ocrvqa200k")
assert len(dataset["train"]) > 0
def test_refcocoplus():
dataset = get_dataset("refcocoplus")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_textvqa():
dataset = get_dataset("textvqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_scienceqa():
dataset = get_dataset("scienceqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0

@@ -0,0 +1,18 @@
from PIL import Image
def size_processor(image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
if width < 28 or height < 28:
min_size = min(width, height)
ratio = 28 / min_size + 1
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image
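
A quick check of the resizing behaviour (a sketch, not part of the commit; the import path is inferred from the relative import in the OCRVQA dataset above):

from PIL import Image
from dataset_library.vis_processor import size_processor

# The long side is capped at 500 px; images with a short side under 28 px are upsampled.
# Both branches compare against the original size.
big = Image.new("RGB", (1000, 400))
small = Image.new("RGB", (20, 60))
print(size_processor(big).size)    # (500, 200)
print(size_processor(small).size)  # (48, 144): ratio = 28/20 + 1 = 2.4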