feat: add support for multiple datasets, including Gigaspeech, TextVQA, OCR-VQA-200K, and the RefCOCO series; update the dataset factory and processing logic; improve image preprocessing

YunyaoZhou 2025-05-15 20:33:29 +08:00
parent 9ca588224d
commit 24a6c3c114
17 changed files with 568 additions and 78 deletions

dataset/.gitignore vendored Normal file

@@ -0,0 +1,5 @@
derek-thomas*
*.lock
speechcolab*
lmms-lab*
downloads/*

dataset/OCR-VQA-200K/.gitignore vendored Normal file

@@ -0,0 +1,2 @@
images/*
dataset.json


@@ -0,0 +1,49 @@
import os
import json
import urllib.request as ureq
import urllib.error
import concurrent.futures
import threading

# Set the file paths for your Google Drive
dataset_path = './dataset.json'
images_path = './images'
download = 1  # Set to 0 if images are already downloaded

# Load dataset json file
with open(dataset_path, 'r') as fp:
    data = json.load(fp)

# Initialize a counter and a lock for thread-safe counting
downloaded_count = 0
count_lock = threading.Lock()


# Function to download an image
def download_image(k):
    global downloaded_count
    imageURL = data[k]['imageURL']
    ext = os.path.splitext(imageURL)[1]
    outputFile = os.path.join(images_path, f'{k}{ext}')

    # Only download the image if it doesn't exist
    if not os.path.exists(outputFile):
        try:
            ureq.urlretrieve(imageURL, outputFile)
            with count_lock:
                downloaded_count += 1
                if downloaded_count % 100 == 0:
                    print(f'{downloaded_count} images downloaded.')
        except urllib.error.URLError as e:
            print(f'Error downloading {outputFile}: {e}')


# Download images using multiple threads
if download == 1:
    if not os.path.exists(images_path):
        os.makedirs(images_path)

    # Create a thread pool and download the images in parallel
    # Increase max_workers to potentially speed up downloads for many small files.
    # The optimal number may vary based on your network and the server's capacity.
    with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
        executor.map(download_image, data.keys())
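
For orientation, download_image() only reads each record's imageURL field. A minimal sketch of the dataset.json shape the script assumes; the "questions"/"answers" keys are assumptions about the rest of the OCR-VQA schema and are not touched by this script:

# Hypothetical excerpt of dataset.json; top-level keys are record ids.
# Only "imageURL" is accessed by download_image().
{
    "1442126371": {
        "imageURL": "http://images.example.com/1442126371.jpg",
        "questions": ["Who is the author of this book?"],
        "answers": ["Some Author"],
    }
}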

dataset/TextVQA/.gitignore vendored Normal file

@@ -0,0 +1,5 @@
images/test_images/*
images/train_images/*
TextVQA_0.5.1_test.json
TextVQA_0.5.1_train.json
TextVQA_0.5.1_val.json

dataset/vizwiz/Annotations/.gitignore vendored Normal file

@@ -0,0 +1,3 @@
train.json
test.json
val.json

dataset/vizwiz/images/.gitignore vendored Normal file

@@ -0,0 +1,3 @@
val/*
train/*
test/*


@@ -18,8 +18,9 @@ class GigaspeechDataset(Dataset):
         self.audio_processor = audio_processor
         self.text_processor = text_processor
-        gs = load_dataset("speechcolab/gigaspeech", "xs")
-        self.data = gs[split]
+        from .format import dataset_dir
+
+        gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir)  # type: ignore
+        self.data = gs[split]  # type: ignore

     def __len__(self):
         return len(self.data)

@@ -54,7 +55,7 @@ class GigaspeechDataset(Dataset):
             audios=[(audio, sampling_rate)],
             chat=chat,
             original=sample,
-        )
+        )  # type: ignore


 class GigaspeechDatasetForGeneration(GigaspeechDataset):

@@ -87,7 +88,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
             chat=chat,
             answer=text,
             original=sample,
-        )
+        )  # type: ignore


 def test_gigaspeech():

@@ -103,3 +104,6 @@ def test_gigaspeech():
     print(dataset[0])
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
+
+if __name__ == "__main__":
+    test_gigaspeech()


@@ -22,8 +22,10 @@ class OCRVQADataset(Dataset):
         """
         self.vis_root = vis_root
+        from .vis_processor import size_processor
+
         self.vis_processor = (
-            vis_processor if vis_processor is not None else self._vis_processor
+            vis_processor if vis_processor is not None else size_processor
         )
         self.text_processor = text_processor
         if split == "train":

@@ -64,24 +66,6 @@ class OCRVQADataset(Dataset):
     def __len__(self):
         return len(self.data)

-    def _vis_processor(self, image: Image.Image):
-        width, height = image.size
-        if width > 500 or height > 500:
-            max_size = max(width, height)
-            ratio = 500 / max_size
-            new_width = int(width * ratio)
-            new_height = int(height * ratio)
-            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
-
-        if width < 28 or height < 28:
-            min_size = min(width, height)
-            ratio = 28 / min_size + 1
-            new_width = int(width * ratio)
-            new_height = int(height * ratio)
-            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
-
-        return image
-
     def __getitem__(self, index):
         sample = self.data[index]
         image: Image.Image = Image.open(

@@ -117,7 +101,7 @@ class OCRVQADataset(Dataset):
             chat=chat,
             original=sample["original"],
             images=[image],
-        )
+        )  # type: ignore


 class OCRVQADatasetForGeneration(OCRVQADataset):

@@ -153,4 +137,4 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
             chat=chat,
             answer=answer,
             original=sample["original"],
-        )
+        )  # type: ignore

dataset_library/RefCOCODataset.py Normal file

@@ -0,0 +1,121 @@
from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal


class RefCOCODataset(Dataset):
    def __init__(
        self,
        vis_processor=None,
        text_processor=None,
        split: Literal["val", "test"] = "val",
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        from .format import dataset_dir

        ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir)  # type: ignore
        self.data = ds[split]  # type: ignore

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        images = sample["image"]
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=question,
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer)],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            original=sample,
        )  # type: ignore


class RefCOCODatasetForGeneration(RefCOCODataset):
    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        images = sample["image"]
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"{question}",
                    ),
                ],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            answer=answer,
            original=sample,
        )  # type: ignore


def test_RefCOCO():
    dataset = RefCOCODataset(
        split="val",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0
    dataset = RefCOCODatasetForGeneration(
        split="test",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0


if __name__ == "__main__":
    test_RefCOCO()
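
A quick interface check for the loader above (a sketch, assuming the package is importable as dataset_library and that lmms-lab/RefCOCO can be fetched into dataset_dir):

from dataset_library.RefCOCODataset import RefCOCODataset

ds = RefCOCODataset(split="val")
sample = ds[0]
# chat[0] is the user turn: an image placeholder plus the referring expression.
print(sample["chat"][0]["content"][1]["text"])
# chat[1] is the assistant turn carrying the grounded answer.
print(sample["chat"][1]["content"][0]["text"])
# images holds the single image for this sample.
print(len(sample["images"]))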

dataset_library/RefCOCOPlusDataset.py Normal file

@@ -0,0 +1,121 @@
from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal


class RefCOCOplusDataset(Dataset):
    def __init__(
        self,
        vis_processor=None,
        text_processor=None,
        split: Literal["val", "testA"] = "val",
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        from .format import dataset_dir

        ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir)  # type: ignore
        self.data = ds[split]  # type: ignore

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        images = sample["image"]
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=question,
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer)],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            original=sample,
        )  # type: ignore


class RefCOCOplusDatasetForGeneration(RefCOCOplusDataset):
    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        images = sample["image"]
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"{question}",
                    ),
                ],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            answer=answer,
            original=sample,
        )  # type: ignore


def test_RefCOCOplus():
    dataset = RefCOCOplusDataset(
        split="val",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0
    dataset = RefCOCOplusDatasetForGeneration(
        split="testA",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0


if __name__ == "__main__":
    test_RefCOCOplus()

dataset_library/RefCOCOgDataset.py Normal file

@@ -0,0 +1,121 @@
from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal


class RefCOCOgDataset(Dataset):
    def __init__(
        self,
        vis_processor=None,
        text_processor=None,
        split: Literal["val", "test"] = "val",
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        from .format import dataset_dir

        ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
        self.data = ds[split]  # type: ignore

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        images = sample["image"]
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=question,
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=answer)],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            original=sample,
        )  # type: ignore


class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        images = sample["image"]
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"{question}",
                    ),
                ],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            answer=answer,
            original=sample,
        )  # type: ignore


def test_RefCOCOg():
    dataset = RefCOCOgDataset(
        split="val",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0
    dataset = RefCOCOgDatasetForGeneration(
        split="test",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0


if __name__ == "__main__":
    test_RefCOCOg()


@@ -6,20 +6,21 @@ from .format import (
     DatasetOutput,
 )
 from torch.utils.data import Dataset
-from datasets import load_dataset
+from datasets import load_dataset, DatasetDict


 class ScienceQADataset(Dataset):
-    def __init__(self, audio_processor=None, text_processor=None, split="train"):
+    def __init__(self, vis_processor=None, text_processor=None, split="train"):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         """
-        self.vis_processor = audio_processor
+        self.vis_processor = vis_processor
         self.text_processor = text_processor
-        ds = load_dataset("derek-thomas/ScienceQA")
-        self.data = ds[split]
+        from .format import dataset_dir
+
+        ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir)
+        self.data = ds[split]  # type: ignore

     def __len__(self):
         return len(self.data)

@@ -60,7 +61,7 @@ class ScienceQADataset(Dataset):
             images=[images],
             chat=chat,
             original=sample,
-        )
+        )  # type: ignore


 class ScienceQADatasetForGeneration(ScienceQADataset):

@@ -98,7 +99,7 @@ class ScienceQADatasetForGeneration(ScienceQADataset):
             chat=chat,
             answer=choices[answer],
             original=sample,
-        )
+        )  # type: ignore


 def test_scienceQA():


@@ -124,7 +124,7 @@ class TextVQADataset(Dataset):
             chat=chat,
             original=sample["original"],
             images=[image],
-        )
+        )  # type: ignore


 class TextVQADatasetForGeneration(TextVQADataset):

@@ -158,16 +158,5 @@ class TextVQADatasetForGeneration(TextVQADataset):
             chat=chat,
             answer=answer,
             original=sample["original"],
-        )
+        )  # type: ignore
-
-
-def test_dataset():
-    vis_root = "/home/zyy/dataset/TextVQA/images"
-    ann_path = "/home/zyy/dataset/TextVQA"
-    dataset = TextVQADataset(vis_root, ann_path)
-    for i in range(10):
-        print(dataset[i])
-
-
-if __name__ == "__main__":
-    test_dataset()

dataset_library/factory.py

@@ -1,10 +1,11 @@
 from torch.utils.data import Dataset
 from typing import Literal
 from pathlib import Path
+from dataset_library.format import dataset_dir


 def get_dataset(
-    dataset_name, base_path="/home/zyy/dataset"
+    dataset_name, base_path=dataset_dir
 ) -> dict[Literal["train", "test", "generation"], Dataset]:
     dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
     match dataset_name:

@@ -92,4 +93,40 @@ def get_dataset(
                 "generation": ScienceQADatasetForGeneration(split="test"),
             }
+        case "refcoco":
+            from .RefCOCODataset import (
+                RefCOCODataset,
+                RefCOCODatasetForGeneration,
+            )
+
+            dataset = {
+                "train": RefCOCODataset(split="val"),
+                "test": RefCOCODataset(split="test"),
+                "generation": RefCOCODatasetForGeneration(split="test"),
+            }
+        case "refcocog":
+            from .RefCOCOgDataset import (
+                RefCOCOgDataset,
+                RefCOCOgDatasetForGeneration,
+            )
+
+            dataset = {
+                "train": RefCOCOgDataset(split="val"),
+                "test": RefCOCOgDataset(split="test"),
+                "generation": RefCOCOgDatasetForGeneration(split="test"),
+            }
+        case "refcocoplus":
+            from .RefCOCOPlusDataset import (
+                RefCOCOplusDataset,
+                RefCOCOplusDatasetForGeneration,
+            )
+
+            dataset = {
+                "train": RefCOCOplusDataset(split="val"),
+                "test": RefCOCOplusDataset(split="testA"),
+                "generation": RefCOCOplusDatasetForGeneration(split="testA"),
+            }
+
     return dataset
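
With the three new match arms, a caller gets the train/test/generation splits from one lookup; a minimal usage sketch (assuming dataset_library is on the import path):

from dataset_library.factory import get_dataset

datasets = get_dataset("refcocoplus")
# "train" is backed by the RefCOCOplus "val" split; "test" and "generation" by "testA".
print(len(datasets["train"]))
# The generation variant keeps the answer outside the chat for scoring.
print(datasets["generation"][0]["answer"])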

dataset_library/format.py

@@ -1,6 +1,9 @@
 from typing import Any, Tuple, TypedDict, Literal, Optional
 import numpy as np
 from PIL import Image
+from pathlib import Path
+
+dataset_dir = Path(__file__).resolve().parent.parent.parent / "dataset"


 class ConverstationText(TypedDict):


@@ -1,46 +1,70 @@
 from .factory import get_dataset


-def test_gigaspeech():
-    dataset = get_dataset("gigaspeech")
-    assert len(dataset["train"]) > 0
-    assert len(dataset["train"][0]["chat"]) > 0
-    assert len(dataset["test"]) > 0
-    assert len(dataset["test"][0]["chat"]) > 0
+# def test_gigaspeech():
+#     dataset = get_dataset("gigaspeech")
+#     assert len(dataset["train"]) > 0  # type: ignore
+#     assert len(dataset["train"][0]["chat"]) > 0
+#     assert len(dataset["test"]) > 0  # type: ignore
+#     assert len(dataset["test"][0]["chat"]) > 0


-def test_chem():
-    dataset = get_dataset("chem")
-    assert len(dataset["train"]) > 0
-    assert len(dataset["train"][0]["chat"]) > 0
-    assert len(dataset["test"]) > 0
-    assert len(dataset["test"][0]["chat"]) > 0
+# def test_chem():
+#     dataset = get_dataset("chem")
+#     assert len(dataset["train"]) > 0  # type: ignore
+#     assert len(dataset["train"][0]["chat"]) > 0
+#     assert len(dataset["test"]) > 0  # type: ignore
+#     assert len(dataset["test"][0]["chat"]) > 0


-def test_ocrvqa200k():
-    dataset = get_dataset("ocrvqa200k")
-    assert len(dataset["train"]) > 0
-    assert len(dataset["train"][0]["chat"]) > 0
-    assert len(dataset["test"]) > 0
-    assert len(dataset["test"][0]["chat"]) > 0
+# def test_ocrvqa200k():
+#     dataset = get_dataset("ocrvqa200k")
+#     assert len(dataset["train"]) > 0  # type: ignore
+#     assert len(dataset["train"][0]["chat"]) > 0
+#     assert len(dataset["test"]) > 0  # type: ignore
+#     assert len(dataset["test"][0]["chat"]) > 0


-def test_textvqa():
-    dataset = get_dataset("textvqa")
-    assert len(dataset["train"]) > 0
-    assert len(dataset["train"][0]["chat"]) > 0
-    assert len(dataset["test"]) > 0
-    assert len(dataset["test"][0]["chat"]) > 0
+# def test_textvqa():
+#     dataset = get_dataset("textvqa")
+#     assert len(dataset["train"]) > 0  # type: ignore
+#     assert len(dataset["train"][0]["chat"]) > 0
+#     assert len(dataset["test"]) > 0  # type: ignore
+#     assert len(dataset["test"][0]["chat"]) > 0


-def test_scienceqa():
-    dataset = get_dataset("scienceqa")
-    assert len(dataset["train"]) > 0
-    assert len(dataset["train"][0]["chat"]) > 0
-    assert len(dataset["test"]) > 0
-    assert len(dataset["test"][0]["chat"]) > 0
+# def test_scienceqa():
+#     dataset = get_dataset("scienceqa")
+#     assert len(dataset["train"]) > 0  # type: ignore
+#     assert len(dataset["train"][0]["chat"]) > 0
+#     assert len(dataset["test"]) > 0  # type: ignore
+#     assert len(dataset["test"][0]["chat"]) > 0
+
+
+def test_refcoco():
+    dataset = get_dataset("refcoco")
+    assert len(dataset["train"]) > 0  # type: ignore
+    assert len(dataset["train"][0]["chat"]) > 0
+    assert len(dataset["test"]) > 0  # type: ignore
+    assert len(dataset["test"][0]["chat"]) > 0
+
+
+def test_refcocog():
+    dataset = get_dataset("refcocog")
+    assert len(dataset["train"]) > 0  # type: ignore
+    assert len(dataset["train"][0]["chat"]) > 0
+    assert len(dataset["test"]) > 0  # type: ignore
+    assert len(dataset["test"][0]["chat"]) > 0
+
+
+def test_refcocoplus():
+    dataset = get_dataset("refcocoplus")
+    assert len(dataset["train"]) > 0  # type: ignore
+    assert len(dataset["train"][0]["chat"]) > 0
+    assert len(dataset["test"]) > 0  # type: ignore
+    assert len(dataset["test"][0]["chat"]) > 0

dataset_library/vis_processor.py Normal file

@@ -0,0 +1,18 @@
from PIL import Image


def size_processor(image: Image.Image):
    width, height = image.size
    if width > 500 or height > 500:
        max_size = max(width, height)
        ratio = 500 / max_size
        new_width = int(width * ratio)
        new_height = int(height * ratio)
        image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

    if width < 28 or height < 28:
        min_size = min(width, height)
        ratio = 28 / min_size + 1
        new_width = int(width * ratio)
        new_height = int(height * ratio)
        image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
    return image
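
To illustrate both branches of size_processor (my own example, not part of the commit): the first clause scales the long side down to 500 px, the second scales the short side up past 28 px.

from PIL import Image
# Assuming the module path used elsewhere in this commit:
from dataset_library.vis_processor import size_processor

# Hypothetical inputs exercising each branch.
big = Image.new("RGB", (1200, 800))
small = Image.new("RGB", (20, 40))

print(size_processor(big).size)    # (500, 333): long side capped at 500
print(size_processor(small).size)  # (48, 96): ratio = 28/20 + 1 = 2.4

Note that both tests read the original width and height, so an image that is both oversized and very narrow would trigger the second resize with ratios computed from the pre-resize dimensions.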