feat: 添加多个数据集的支持,包括Gigaspeech、TextVQA、OCR-VQA-200K、RefCOCO系列,更新数据集工厂和处理逻辑,优化图像处理功能

This commit is contained in:
2025-05-15 20:33:29 +08:00
parent 9ca588224d
commit 24a6c3c114
17 changed files with 568 additions and 78 deletions
+8 -4
View File
@@ -18,8 +18,9 @@ class GigaspeechDataset(Dataset):
self.audio_processor = audio_processor
self.text_processor = text_processor
gs = load_dataset("speechcolab/gigaspeech", "xs")
self.data = gs[split]
from .format import dataset_dir
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
self.data = gs[split] # type: ignore
def __len__(self):
return len(self.data)
@@ -54,7 +55,7 @@ class GigaspeechDataset(Dataset):
audios=[(audio, sampling_rate)],
chat=chat,
original=sample,
)
) # type: ignore
class GigaspeechDatasetForGeneration(GigaspeechDataset):
@@ -87,7 +88,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
chat=chat,
answer=text,
original=sample,
)
) # type: ignore
def test_gigaspeech():
@@ -103,3 +104,6 @@ def test_gigaspeech():
print(dataset[0])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_gigaspeech()
+5 -21
View File
@@ -22,8 +22,10 @@ class OCRVQADataset(Dataset):
"""
self.vis_root = vis_root
from .vis_processor import size_processor
self.vis_processor = (
vis_processor if vis_processor is not None else self._vis_processor
vis_processor if vis_processor is not None else size_processor
)
self.text_processor = text_processor
if split == "train":
@@ -64,24 +66,6 @@ class OCRVQADataset(Dataset):
def __len__(self):
return len(self.data)
def _vis_processor(self, image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
if width < 28 or height < 28:
min_size = min(width, height)
ratio = 28 / min_size + 1
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image
def __getitem__(self, index):
sample = self.data[index]
image: Image.Image = Image.open(
@@ -117,7 +101,7 @@ class OCRVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
)
) # type: ignore
class OCRVQADatasetForGeneration(OCRVQADataset):
@@ -153,4 +137,4 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
)
) # type: ignore
+121
View File
@@ -0,0 +1,121 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCODataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "test"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCODatasetForGeneration(RefCOCODataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCO():
dataset = RefCOCODataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCODatasetForGeneration(
split="test",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_RefCOCO()
+121
View File
@@ -0,0 +1,121 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCOplusDataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "testA"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCOplusDatasetForGeneration(RefCOCOplusDataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCOplus():
dataset = RefCOCOplusDataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCOplusDatasetForGeneration(
split="testA",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_RefCOCOplus()
+121
View File
@@ -0,0 +1,121 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCOgDataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "test"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCOg():
dataset = RefCOCOgDataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCOgDatasetForGeneration(
split="test",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_RefCOCOg()
+8 -7
View File
@@ -6,20 +6,21 @@ from .format import (
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset
from datasets import load_dataset, DatasetDict
class ScienceQADataset(Dataset):
def __init__(self, audio_processor=None, text_processor=None, split="train"):
def __init__(self, vis_processor=None, text_processor=None, split="train"):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = audio_processor
self.vis_processor = vis_processor
self.text_processor = text_processor
ds = load_dataset("derek-thomas/ScienceQA")
self.data = ds[split]
from .format import dataset_dir
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir)
self.data = ds[split] # type: ignore
def __len__(self):
return len(self.data)
@@ -60,7 +61,7 @@ class ScienceQADataset(Dataset):
images=[images],
chat=chat,
original=sample,
)
) # type: ignore
class ScienceQADatasetForGeneration(ScienceQADataset):
@@ -98,7 +99,7 @@ class ScienceQADatasetForGeneration(ScienceQADataset):
chat=chat,
answer=choices[answer],
original=sample,
)
) # type: ignore
def test_scienceQA():
+2 -13
View File
@@ -124,7 +124,7 @@ class TextVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
)
) # type: ignore
class TextVQADatasetForGeneration(TextVQADataset):
@@ -158,16 +158,5 @@ class TextVQADatasetForGeneration(TextVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
)
) # type: ignore
def test_dataset():
vis_root = "/home/zyy/dataset/TextVQA/images"
ann_path = "/home/zyy/dataset/TextVQA"
dataset = TextVQADataset(vis_root, ann_path)
for i in range(10):
print(dataset[i])
if __name__ == "__main__":
test_dataset()
+38 -1
View File
@@ -1,10 +1,11 @@
from torch.utils.data import Dataset
from typing import Literal
from pathlib import Path
from dataset_library.format import dataset_dir
def get_dataset(
dataset_name, base_path="/home/zyy/dataset"
dataset_name, base_path=dataset_dir
) -> dict[Literal["train", "test", "generation"], Dataset]:
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
match dataset_name:
@@ -92,4 +93,40 @@ def get_dataset(
"generation": ScienceQADatasetForGeneration(split="test"),
}
case "refcoco":
from .RefCOCODataset import (
RefCOCODataset,
RefCOCODatasetForGeneration,
)
dataset = {
"train": RefCOCODataset(split="val"),
"test": RefCOCODataset(split="test"),
"generation": RefCOCODatasetForGeneration(split="test"),
}
case "refcocog":
from .RefCOCOgDataset import (
RefCOCOgDataset,
RefCOCOgDatasetForGeneration,
)
dataset = {
"train": RefCOCOgDataset(split="val"),
"test": RefCOCOgDataset(split="test"),
"generation": RefCOCOgDatasetForGeneration(split="test"),
}
case "refcocoplus":
from .RefCOCOPlusDataset import (
RefCOCOplusDataset,
RefCOCOplusDatasetForGeneration,
)
dataset = {
"train": RefCOCOplusDataset(split="val"),
"test": RefCOCOplusDataset(split="testA"),
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
}
return dataset
+3
View File
@@ -1,6 +1,9 @@
from typing import Any, Tuple, TypedDict, Literal, Optional
import numpy as np
from PIL import Image
from pathlib import Path
dataset_dir = Path(__file__).resolve().parent.parent.parent / "dataset"
class ConverstationText(TypedDict):
+56 -32
View File
@@ -1,46 +1,70 @@
from .factory import get_dataset
def test_gigaspeech():
dataset = get_dataset("gigaspeech")
assert len(dataset["train"]) > 0
# def test_gigaspeech():
# dataset = get_dataset("gigaspeech")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_chem():
# dataset = get_dataset("chem")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_ocrvqa200k():
# dataset = get_dataset("ocrvqa200k")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_textvqa():
# dataset = get_dataset("textvqa")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
# def test_scienceqa():
# dataset = get_dataset("scienceqa")
# assert len(dataset["train"]) > 0 # type: ignore
# assert len(dataset["train"][0]["chat"]) > 0
# assert len(dataset["test"]) > 0 # type: ignore
# assert len(dataset["test"][0]["chat"]) > 0
def test_refcoco():
dataset = get_dataset("refcoco")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_chem():
dataset = get_dataset("chem")
assert len(dataset["train"]) > 0
def test_refcocog():
dataset = get_dataset("refcocog")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
def test_ocrvqa200k():
dataset = get_dataset("ocrvqa200k")
assert len(dataset["train"]) > 0
def test_refcocoplus():
dataset = get_dataset("refcocoplus")
assert len(dataset["train"]) > 0 # type: ignore
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_textvqa():
dataset = get_dataset("textvqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_scienceqa():
dataset = get_dataset("scienceqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"]) > 0 # type: ignore
assert len(dataset["test"][0]["chat"]) > 0
+18
View File
@@ -0,0 +1,18 @@
from PIL import Image
def size_processor(image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
if width < 28 or height < 28:
min_size = min(width, height)
ratio = 28 / min_size + 1
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image