feat: 更新数据集处理逻辑,添加Gigaspeech、TextVQA数据集支持,优化训练脚本,增加测试用例

This commit is contained in:
2025-01-18 21:37:15 +08:00
parent da644a081d
commit 7b9349091e
13 changed files with 425 additions and 116 deletions
@@ -4,7 +4,7 @@ import json
import os
class CHEMDataset(Dataset):
class ChemDataset(Dataset):
def __init__(
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
):
@@ -112,7 +112,7 @@ class CHEMDataset(Dataset):
}
class CHEMDatasetForGeneration(CHEMDataset):
class ChemDatasetForGeneration(ChemDataset):
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
@@ -147,29 +147,11 @@ class CHEMDatasetForGeneration(CHEMDataset):
],
},
]
from .format import create_generate
from .format import DatasetOutput
return create_generate(
return DatasetOutput(
images=[image],
chat=chat,
answer=answer,
original=sample["original"],
)
if __name__ == "__main__":
dataset = CHEMDataset(
"/home/zyy/research/accelerate/dataset/chem/images",
"/home/zyy/research/accelerate/dataset/chem/qwen_data",
split="train",
)
print(len(dataset))
print(dataset[0])
dataset = CHEMDatasetForGeneration(
"/home/zyy/research/accelerate/dataset/chem/images",
"/home/zyy/research/accelerate/dataset/chem/qwen_data",
split="train",
)
print(len(dataset))
print(dataset[0])
pass
+38 -32
View File
@@ -1,7 +1,11 @@
from PIL import Image
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
import json
import os
from datasets import load_dataset
@@ -32,22 +36,25 @@ class GigaspeechDataset(Dataset):
text = self.text_processor(text)
chat = [
{
"role": "user",
"content": [
{"type": "audio", "audio_url": ""},
{
"type": "text",
"text": "Please convert the audio to text",
},
Conversation(
role="user",
content=[
ConverstationAudio(type="audio", audio_url=""),
ConverstationText(
type="text", text="Please convert the audio to text"
),
],
},
{"role": "assistant", "content": [{"type": "text", "text": text}]},
),
Conversation(
role="assistant", content=[ConverstationText(type="text", text=text)]
),
]
return {
"audio": (audio, sampling_rate),
"chat": chat,
}
return DatasetOutput(
audio=[(audio, sampling_rate)],
chat=chat,
original=sample,
)
class GigaspeechDatasetForGeneration(GigaspeechDataset):
@@ -64,20 +71,18 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
text = self.text_processor(text)
chat = [
{
"role": "user",
"content": [
{"type": "audio", "audio_url": ""},
{
"type": "text",
"text": "Please convert the audio to text",
},
Conversation(
role="user",
content=[
ConverstationAudio(type="audio", audio_url=""),
ConverstationText(
type="text", text="Please convert the audio to text"
),
],
},
),
]
from .format import create_generate
return create_generate(
return DatasetOutput(
audio=[(audio, sampling_rate)],
chat=chat,
answer=text,
@@ -85,15 +90,16 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
)
if __name__ == "__main__":
def test_gigaspeech():
dataset = GigaspeechDataset(
split="train",
)
print(len(dataset))
print(dataset[0])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = GigaspeechDatasetForGeneration(
split="train",
)
print(len(dataset))
print(dataset[0])
pass
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
@@ -1,7 +1,15 @@
from PIL import Image
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
import json
import os
from pathlib import Path
class OCRVQADataset(Dataset):
@@ -19,9 +27,9 @@ class OCRVQADataset(Dataset):
)
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(ann_path, split=1)
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
elif split == "test":
self.data = self.create_data(ann_path, split=3)
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
# self.instruction_pool = [
# "[vqa] {}",
@@ -48,6 +56,7 @@ class OCRVQADataset(Dataset):
"image_id": k,
"title": data[k]["title"],
"genre": data[k]["genre"],
"original": data[k],
}
)
return processed_data
@@ -89,22 +98,26 @@ class OCRVQADataset(Dataset):
answer = self.text_processor(answer)
chat = [
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
},
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
},
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
),
Conversation(
role="assistant", content=[ConverstationText(type="text", text=answer)]
),
]
return {
"image": image,
"chat": chat,
}
return DatasetOutput(
chat=chat,
original=sample["original"],
images=[image],
)
class OCRVQADatasetForGeneration(OCRVQADataset):
@@ -124,20 +137,20 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
answer = self.text_processor(answer)
chat = [
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
},
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
}
# {"role": "assistant", "content": answer},
),
]
return {
"image": image,
"chat": chat,
"answer": answer,
}
return DatasetOutput(
images=[image],
chat=chat,
answer=answer,
original=sample["original"],
)
+173
View File
@@ -0,0 +1,173 @@
from PIL import Image
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
import json
import os.path as osp
from pathlib import Path
class TextVQADataset(Dataset):
def __init__(
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = (
vis_processor if vis_processor is not None else self._vis_processor
)
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(
Path(ann_path, "TextVQA_0.5.1_train.json"),
vis_root=Path(vis_root, "train_images"),
)
elif split == "test":
self.data = self.create_data(
Path(ann_path, "TextVQA_0.5.1_val.json"),
vis_root=Path(vis_root, "train_images"),
)
# self.instruction_pool = [
# "[vqa] {}",
# "[vqa] Based on the image, respond to this question with a short answer: {}",
# ]
def create_data(self, ann_path, vis_root):
processed_data = []
with open(ann_path, "r") as f:
data = json.load(f)
data = data["data"]
for i in range(len(data)):
# print(data[0])
# {'question': 'what is the brand of phone?', 'image_id': '0054c91397f2fe05', 'image_classes': ['Belt', 'Headphones', 'Goggles', 'Scale', 'Bottle opener', 'Mobile phone', 'Mirror', 'Digital clock', 'Television', 'Telephone', 'Tool', 'Wheel', 'Camera', 'Watch', 'Glasses', 'Aircraft'], 'flickr_original_url': 'https://farm6.staticflickr.com/2891/9134076951_f65b421097_o.jpg', 'flickr_300k_url': 'https://c4.staticflickr.com/3/2891/9134076951_9db89d3e0f_z.jpg', 'image_width': 1024, 'image_height': 730, 'answers': ['nokia', 'nokia', 'nokia', 'nokia', 'toshiba', 'nokia', 'nokia', 'nokia', 'nokia', 'nokia'], 'question_tokens': ['what', 'is', 'the', 'brand', 'of', 'phone'], 'question_id': 0, 'set_name': 'train'}
try:
imageFile = data[i]["image_id"] + ".jpg"
question = data[i]["question"]
answer = data[i]["answers"][0]
processed_data.append(
{
"question": question,
"answer": answer,
"image_path": Path(vis_root, imageFile),
"image_id": data[i]["image_id"],
"title": data[i]["image_id"],
"genre": data[i]["image_classes"],
"original": data[i],
}
)
except:
print(data[i])
pass
return processed_data
def __len__(self):
return len(self.data)
def _vis_processor(self, image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
if width < 28 or height < 28:
min_size = min(width, height)
ratio = 28 / min_size + 1
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image
def __getitem__(self, index):
sample = self.data[index]
image: Image.Image = Image.open(sample["image_path"]).convert("RGB")
# resize image
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
),
Conversation(
role="assistant", content=[ConverstationText(type="text", text=answer)]
),
]
return DatasetOutput(
chat=chat,
original=sample["original"],
images=[image],
)
class TextVQADatasetForGeneration(TextVQADataset):
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(sample["image_path"]).convert("RGB")
# resize image
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
),
]
return DatasetOutput(
images=[image],
chat=chat,
answer=answer,
original=sample["original"],
)
def test_dataset():
vis_root = "/home/zyy/dataset/TextVQA/images"
ann_path = "/home/zyy/dataset/TextVQA"
dataset = TextVQADataset(vis_root, ann_path)
for i in range(10):
print(dataset[i])
if __name__ == "__main__":
test_dataset()
+43 -22
View File
@@ -1,50 +1,49 @@
from torch.utils.data import Dataset
from typing import Literal
from pathlib import Path
def get_dataset(
dataset_name, base_path="/home/zyy/research/accelerate/dataset"
dataset_name, base_path="/home/zyy/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
if dataset_name == "OCR-VQA-200K":
import os.path as osp
from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
if dataset_name == "ocrvqa200k":
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
dataset = {
"train": OCRVQADataset(
osp.join(base_path, "OCR-VQA-200K/images"),
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="train",
),
"test": OCRVQADataset(
osp.join(base_path, "OCR-VQA-200K/images"),
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
"generation": OCRVQADatasetForGeneration(
osp.join(base_path, "OCR-VQA-200K/images"),
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
}
if dataset_name == "CHEM":
import os.path as osp
from .CHEM import CHEMDataset, CHEMDatasetForGeneration
if dataset_name == "chem":
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
dataset = {
"train": CHEMDataset(
osp.join(base_path, "chem/images"),
osp.join(base_path, "chem"),
"train": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="train",
),
"test": CHEMDataset(
osp.join(base_path, "chem/images"),
osp.join(base_path, "chem"),
"test": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
"generation": CHEMDatasetForGeneration(
osp.join(base_path, "chem/images"),
osp.join(base_path, "chem"),
"generation": ChemDatasetForGeneration(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
}
@@ -57,4 +56,26 @@ def get_dataset(
"test": GigaspeechDataset(split="test"),
"generation": GigaspeechDatasetForGeneration(split="test"),
}
if dataset_name == "textvqa":
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
dataset = {
"train": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="train",
),
"test": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
"generation": TextVQADatasetForGeneration(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
}
return dataset
+32
View File
@@ -0,0 +1,32 @@
from typing import Any, Tuple, TypedDict, Literal, Optional
import numpy as np
from PIL import Image
class ConverstationText(TypedDict):
type: Literal["text"]
text: str
class ConverstationAudio(TypedDict):
type: Literal["audio"]
audio_url: str
class ConverstationImage(TypedDict):
type: Literal["image"]
image_url: str
class Conversation(TypedDict):
role: Literal["user", "assistant", "system"]
content: list[ConverstationText | ConverstationAudio | ConverstationImage]
class DatasetOutput(TypedDict):
audios: Optional[list[Tuple[np.ndarray, int]]]
chat: list[Conversation]
answer: Optional[str]
original: Any
images: Optional[list[Image.Image]]
+37
View File
@@ -0,0 +1,37 @@
from .factory import get_dataset
def test_gigaspeech():
dataset = get_dataset("gigaspeech")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_chem():
dataset = get_dataset("chem")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_ocrvqa200k():
dataset = get_dataset("ocrvqa200k")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_textvqa():
dataset = get_dataset("textvqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
+4 -5
View File
@@ -5,8 +5,6 @@ from trl import (
get_quantization_config,
)
from utils.args import ContinualModelConfig
import transformers
print(transformers.__version__)
def get_model(model_args: ContinualModelConfig):
@@ -26,7 +24,8 @@ def get_model(model_args: ContinualModelConfig):
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
@@ -38,7 +37,7 @@ def get_model(model_args: ContinualModelConfig):
trust_remote_code=model_args.trust_remote_code,
padding_side="left",
)
from model_library.qwen2vl import (
from .qwen2vl import (
collate_fn_for_train,
collate_fn_for_evaluate,
)
@@ -60,7 +59,7 @@ def get_model(model_args: ContinualModelConfig):
trust_remote_code=model_args.trust_remote_code,
padding_side="left",
)
from model_library.qwen2audio import (
from .qwen2audio import (
collate_fn_for_train,
collate_fn_for_evaluate,
)
+4 -3
View File
@@ -1,15 +1,16 @@
from transformers import Qwen2VLProcessor
from dataset_library.format import DatasetOutput
import torch
def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(example["chat"], tokenize=False)
for example in examples
]
# print(texts)
images = [example["image"] for example in examples]
images = [example["images"] for example in examples]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -65,7 +66,7 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
for example in examples
]
# print(texts)
images = [example["image"] for example in examples]
images = [example["images"] for example in examples]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+1 -1
View File
@@ -1,7 +1,7 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
--dataset_name CHEM \
--dataset_name gigaspeech \
--use_peft \
--peft_type LORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \