feat✨: 更新数据集处理逻辑,添加Gigaspeech、TextVQA数据集支持,优化训练脚本,增加测试用例
This commit is contained in:
parent
da644a081d
commit
7b9349091e
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
**/.venv/*
|
**/.venv/*
|
||||||
**/__pycache__/*
|
**/__pycache__/*
|
||||||
rsync.sh
|
rsync.sh
|
||||||
|
.pytest_cache/
|
||||||
|
@ -12,6 +12,7 @@ dependencies = [
|
|||||||
"peft==0.14.0",
|
"peft==0.14.0",
|
||||||
"pip==24.3.1",
|
"pip==24.3.1",
|
||||||
"pre-commit>=4.0.1",
|
"pre-commit>=4.0.1",
|
||||||
|
"pytest>=8.3.4",
|
||||||
"requests==2.32.3",
|
"requests==2.32.3",
|
||||||
"rouge-score>=0.1.2",
|
"rouge-score>=0.1.2",
|
||||||
"safetensors>=0.5.2",
|
"safetensors>=0.5.2",
|
||||||
@ -55,3 +56,11 @@ url = "https://download.pytorch.org/whl/cu124"
|
|||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
exclude = "transformers_repo|peft_repo|.venv"
|
exclude = "transformers_repo|peft_repo|.venv"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
addopts = ["--color=yes", "--durations=0", "-v", "--capture=tee-sys"]
|
||||||
|
norecursedirs = [
|
||||||
|
"src/transformers_repo",
|
||||||
|
"src/peft_repo",
|
||||||
|
".venv",
|
||||||
|
]
|
||||||
|
@ -4,7 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
class CHEMDataset(Dataset):
|
class ChemDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||||
):
|
):
|
||||||
@ -112,7 +112,7 @@ class CHEMDataset(Dataset):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class CHEMDatasetForGeneration(CHEMDataset):
|
class ChemDatasetForGeneration(ChemDataset):
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = self.data[index]
|
sample = self.data[index]
|
||||||
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||||
@ -147,29 +147,11 @@ class CHEMDatasetForGeneration(CHEMDataset):
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
from .format import create_generate
|
from .format import DatasetOutput
|
||||||
|
|
||||||
return create_generate(
|
return DatasetOutput(
|
||||||
images=[image],
|
images=[image],
|
||||||
chat=chat,
|
chat=chat,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
dataset = CHEMDataset(
|
|
||||||
"/home/zyy/research/accelerate/dataset/chem/images",
|
|
||||||
"/home/zyy/research/accelerate/dataset/chem/qwen_data",
|
|
||||||
split="train",
|
|
||||||
)
|
|
||||||
print(len(dataset))
|
|
||||||
print(dataset[0])
|
|
||||||
dataset = CHEMDatasetForGeneration(
|
|
||||||
"/home/zyy/research/accelerate/dataset/chem/images",
|
|
||||||
"/home/zyy/research/accelerate/dataset/chem/qwen_data",
|
|
||||||
split="train",
|
|
||||||
)
|
|
||||||
print(len(dataset))
|
|
||||||
print(dataset[0])
|
|
||||||
pass
|
|
@ -1,7 +1,11 @@
|
|||||||
from PIL import Image
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
@ -32,22 +36,25 @@ class GigaspeechDataset(Dataset):
|
|||||||
text = self.text_processor(text)
|
text = self.text_processor(text)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
{
|
Conversation(
|
||||||
"role": "user",
|
role="user",
|
||||||
"content": [
|
content=[
|
||||||
{"type": "audio", "audio_url": ""},
|
ConverstationAudio(type="audio", audio_url=""),
|
||||||
{
|
ConverstationText(
|
||||||
"type": "text",
|
type="text", text="Please convert the audio to text"
|
||||||
"text": "Please convert the audio to text",
|
),
|
||||||
},
|
|
||||||
],
|
],
|
||||||
},
|
),
|
||||||
{"role": "assistant", "content": [{"type": "text", "text": text}]},
|
Conversation(
|
||||||
|
role="assistant", content=[ConverstationText(type="text", text=text)]
|
||||||
|
),
|
||||||
]
|
]
|
||||||
return {
|
|
||||||
"audio": (audio, sampling_rate),
|
return DatasetOutput(
|
||||||
"chat": chat,
|
audio=[(audio, sampling_rate)],
|
||||||
}
|
chat=chat,
|
||||||
|
original=sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||||
@ -64,20 +71,18 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
|||||||
text = self.text_processor(text)
|
text = self.text_processor(text)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
{
|
Conversation(
|
||||||
"role": "user",
|
role="user",
|
||||||
"content": [
|
content=[
|
||||||
{"type": "audio", "audio_url": ""},
|
ConverstationAudio(type="audio", audio_url=""),
|
||||||
{
|
ConverstationText(
|
||||||
"type": "text",
|
type="text", text="Please convert the audio to text"
|
||||||
"text": "Please convert the audio to text",
|
),
|
||||||
},
|
|
||||||
],
|
],
|
||||||
},
|
),
|
||||||
]
|
]
|
||||||
from .format import create_generate
|
|
||||||
|
|
||||||
return create_generate(
|
return DatasetOutput(
|
||||||
audio=[(audio, sampling_rate)],
|
audio=[(audio, sampling_rate)],
|
||||||
chat=chat,
|
chat=chat,
|
||||||
answer=text,
|
answer=text,
|
||||||
@ -85,15 +90,16 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test_gigaspeech():
|
||||||
dataset = GigaspeechDataset(
|
dataset = GigaspeechDataset(
|
||||||
split="train",
|
split="train",
|
||||||
)
|
)
|
||||||
print(len(dataset))
|
|
||||||
print(dataset[0])
|
print(dataset[0])
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
dataset = GigaspeechDatasetForGeneration(
|
dataset = GigaspeechDatasetForGeneration(
|
||||||
split="train",
|
split="train",
|
||||||
)
|
)
|
||||||
print(len(dataset))
|
|
||||||
print(dataset[0])
|
print(dataset[0])
|
||||||
pass
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
@ -1,7 +1,15 @@
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class OCRVQADataset(Dataset):
|
class OCRVQADataset(Dataset):
|
||||||
@ -19,9 +27,9 @@ class OCRVQADataset(Dataset):
|
|||||||
)
|
)
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
if split == "train":
|
if split == "train":
|
||||||
self.data = self.create_data(ann_path, split=1)
|
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
|
||||||
elif split == "test":
|
elif split == "test":
|
||||||
self.data = self.create_data(ann_path, split=3)
|
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
|
||||||
|
|
||||||
# self.instruction_pool = [
|
# self.instruction_pool = [
|
||||||
# "[vqa] {}",
|
# "[vqa] {}",
|
||||||
@ -48,6 +56,7 @@ class OCRVQADataset(Dataset):
|
|||||||
"image_id": k,
|
"image_id": k,
|
||||||
"title": data[k]["title"],
|
"title": data[k]["title"],
|
||||||
"genre": data[k]["genre"],
|
"genre": data[k]["genre"],
|
||||||
|
"original": data[k],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return processed_data
|
return processed_data
|
||||||
@ -89,22 +98,26 @@ class OCRVQADataset(Dataset):
|
|||||||
answer = self.text_processor(answer)
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
{
|
Conversation(
|
||||||
"role": "user",
|
role="user",
|
||||||
"content": [
|
content=[
|
||||||
{"type": "image"},
|
ConverstationImage(type="image", image_url=""),
|
||||||
{
|
ConverstationText(
|
||||||
"type": "text",
|
type="text",
|
||||||
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
},
|
),
|
||||||
],
|
],
|
||||||
},
|
),
|
||||||
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
|
Conversation(
|
||||||
|
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
||||||
|
),
|
||||||
]
|
]
|
||||||
return {
|
|
||||||
"image": image,
|
return DatasetOutput(
|
||||||
"chat": chat,
|
chat=chat,
|
||||||
}
|
original=sample["original"],
|
||||||
|
images=[image],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||||
@ -124,20 +137,20 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
|
|||||||
answer = self.text_processor(answer)
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
{
|
Conversation(
|
||||||
"role": "user",
|
role="user",
|
||||||
"content": [
|
content=[
|
||||||
{"type": "image"},
|
ConverstationImage(type="image", image_url=""),
|
||||||
{
|
ConverstationText(
|
||||||
"type": "text",
|
type="text",
|
||||||
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
},
|
),
|
||||||
],
|
],
|
||||||
}
|
),
|
||||||
# {"role": "assistant", "content": answer},
|
|
||||||
]
|
]
|
||||||
return {
|
return DatasetOutput(
|
||||||
"image": image,
|
images=[image],
|
||||||
"chat": chat,
|
chat=chat,
|
||||||
"answer": answer,
|
answer=answer,
|
||||||
}
|
original=sample["original"],
|
||||||
|
)
|
173
src/dataset_library/TextVQADataset.py
Normal file
173
src/dataset_library/TextVQADataset.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
from PIL import Image
|
||||||
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import json
|
||||||
|
import os.path as osp
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class TextVQADataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||||
|
ann_root (string): directory to store the annotation file
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.vis_processor = (
|
||||||
|
vis_processor if vis_processor is not None else self._vis_processor
|
||||||
|
)
|
||||||
|
self.text_processor = text_processor
|
||||||
|
if split == "train":
|
||||||
|
self.data = self.create_data(
|
||||||
|
Path(ann_path, "TextVQA_0.5.1_train.json"),
|
||||||
|
vis_root=Path(vis_root, "train_images"),
|
||||||
|
)
|
||||||
|
elif split == "test":
|
||||||
|
self.data = self.create_data(
|
||||||
|
Path(ann_path, "TextVQA_0.5.1_val.json"),
|
||||||
|
vis_root=Path(vis_root, "train_images"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.instruction_pool = [
|
||||||
|
# "[vqa] {}",
|
||||||
|
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
||||||
|
# ]
|
||||||
|
|
||||||
|
def create_data(self, ann_path, vis_root):
|
||||||
|
processed_data = []
|
||||||
|
with open(ann_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
data = data["data"]
|
||||||
|
for i in range(len(data)):
|
||||||
|
# print(data[0])
|
||||||
|
# {'question': 'what is the brand of phone?', 'image_id': '0054c91397f2fe05', 'image_classes': ['Belt', 'Headphones', 'Goggles', 'Scale', 'Bottle opener', 'Mobile phone', 'Mirror', 'Digital clock', 'Television', 'Telephone', 'Tool', 'Wheel', 'Camera', 'Watch', 'Glasses', 'Aircraft'], 'flickr_original_url': 'https://farm6.staticflickr.com/2891/9134076951_f65b421097_o.jpg', 'flickr_300k_url': 'https://c4.staticflickr.com/3/2891/9134076951_9db89d3e0f_z.jpg', 'image_width': 1024, 'image_height': 730, 'answers': ['nokia', 'nokia', 'nokia', 'nokia', 'toshiba', 'nokia', 'nokia', 'nokia', 'nokia', 'nokia'], 'question_tokens': ['what', 'is', 'the', 'brand', 'of', 'phone'], 'question_id': 0, 'set_name': 'train'}
|
||||||
|
try:
|
||||||
|
imageFile = data[i]["image_id"] + ".jpg"
|
||||||
|
question = data[i]["question"]
|
||||||
|
answer = data[i]["answers"][0]
|
||||||
|
processed_data.append(
|
||||||
|
{
|
||||||
|
"question": question,
|
||||||
|
"answer": answer,
|
||||||
|
"image_path": Path(vis_root, imageFile),
|
||||||
|
"image_id": data[i]["image_id"],
|
||||||
|
"title": data[i]["image_id"],
|
||||||
|
"genre": data[i]["image_classes"],
|
||||||
|
"original": data[i],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
print(data[i])
|
||||||
|
pass
|
||||||
|
|
||||||
|
return processed_data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def _vis_processor(self, image: Image.Image):
|
||||||
|
width, height = image.size
|
||||||
|
if width > 500 or height > 500:
|
||||||
|
max_size = max(width, height)
|
||||||
|
ratio = 500 / max_size
|
||||||
|
new_width = int(width * ratio)
|
||||||
|
new_height = int(height * ratio)
|
||||||
|
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||||
|
|
||||||
|
if width < 28 or height < 28:
|
||||||
|
min_size = min(width, height)
|
||||||
|
ratio = 28 / min_size + 1
|
||||||
|
new_width = int(width * ratio)
|
||||||
|
new_height = int(height * ratio)
|
||||||
|
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
image: Image.Image = Image.open(sample["image_path"]).convert("RGB")
|
||||||
|
# resize image
|
||||||
|
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
image = self.vis_processor(image)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Conversation(
|
||||||
|
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
chat=chat,
|
||||||
|
original=sample["original"],
|
||||||
|
images=[image],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextVQADatasetForGeneration(TextVQADataset):
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
image = Image.open(sample["image_path"]).convert("RGB")
|
||||||
|
# resize image
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
image = self.vis_processor(image)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[image],
|
||||||
|
chat=chat,
|
||||||
|
answer=answer,
|
||||||
|
original=sample["original"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset():
|
||||||
|
vis_root = "/home/zyy/dataset/TextVQA/images"
|
||||||
|
ann_path = "/home/zyy/dataset/TextVQA"
|
||||||
|
dataset = TextVQADataset(vis_root, ann_path)
|
||||||
|
for i in range(10):
|
||||||
|
print(dataset[i])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_dataset()
|
@ -1,50 +1,49 @@
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
dataset_name, base_path="/home/zyy/research/accelerate/dataset"
|
dataset_name, base_path="/home/zyy/dataset"
|
||||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||||
if dataset_name == "OCR-VQA-200K":
|
if dataset_name == "ocrvqa200k":
|
||||||
import os.path as osp
|
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||||
from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
"train": OCRVQADataset(
|
"train": OCRVQADataset(
|
||||||
osp.join(base_path, "OCR-VQA-200K/images"),
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
split="train",
|
split="train",
|
||||||
),
|
),
|
||||||
"test": OCRVQADataset(
|
"test": OCRVQADataset(
|
||||||
osp.join(base_path, "OCR-VQA-200K/images"),
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
"generation": OCRVQADatasetForGeneration(
|
"generation": OCRVQADatasetForGeneration(
|
||||||
osp.join(base_path, "OCR-VQA-200K/images"),
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if dataset_name == "CHEM":
|
if dataset_name == "chem":
|
||||||
import os.path as osp
|
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||||
from .CHEM import CHEMDataset, CHEMDatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
"train": CHEMDataset(
|
"train": ChemDataset(
|
||||||
osp.join(base_path, "chem/images"),
|
vis_root=Path(base_path, "chem", "images"),
|
||||||
osp.join(base_path, "chem"),
|
ann_path=Path(base_path, "chem"),
|
||||||
split="train",
|
split="train",
|
||||||
),
|
),
|
||||||
"test": CHEMDataset(
|
"test": ChemDataset(
|
||||||
osp.join(base_path, "chem/images"),
|
vis_root=Path(base_path, "chem", "images"),
|
||||||
osp.join(base_path, "chem"),
|
ann_path=Path(base_path, "chem"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
"generation": CHEMDatasetForGeneration(
|
"generation": ChemDatasetForGeneration(
|
||||||
osp.join(base_path, "chem/images"),
|
vis_root=Path(base_path, "chem", "images"),
|
||||||
osp.join(base_path, "chem"),
|
ann_path=Path(base_path, "chem"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@ -57,4 +56,26 @@ def get_dataset(
|
|||||||
"test": GigaspeechDataset(split="test"),
|
"test": GigaspeechDataset(split="test"),
|
||||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if dataset_name == "textvqa":
|
||||||
|
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": TextVQADataset(
|
||||||
|
vis_root=Path(base_path, "TextVQA", "images"),
|
||||||
|
ann_path=Path(base_path, "TextVQA"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": TextVQADataset(
|
||||||
|
vis_root=Path(base_path, "TextVQA", "images"),
|
||||||
|
ann_path=Path(base_path, "TextVQA"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": TextVQADatasetForGeneration(
|
||||||
|
vis_root=Path(base_path, "TextVQA", "images"),
|
||||||
|
ann_path=Path(base_path, "TextVQA"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
32
src/dataset_library/format.py
Normal file
32
src/dataset_library/format.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from typing import Any, Tuple, TypedDict, Literal, Optional
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class ConverstationText(TypedDict):
|
||||||
|
type: Literal["text"]
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ConverstationAudio(TypedDict):
|
||||||
|
type: Literal["audio"]
|
||||||
|
audio_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class ConverstationImage(TypedDict):
|
||||||
|
type: Literal["image"]
|
||||||
|
image_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(TypedDict):
|
||||||
|
|
||||||
|
role: Literal["user", "assistant", "system"]
|
||||||
|
content: list[ConverstationText | ConverstationAudio | ConverstationImage]
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetOutput(TypedDict):
|
||||||
|
audios: Optional[list[Tuple[np.ndarray, int]]]
|
||||||
|
chat: list[Conversation]
|
||||||
|
answer: Optional[str]
|
||||||
|
original: Any
|
||||||
|
images: Optional[list[Image.Image]]
|
37
src/dataset_library/test_dataset.py
Normal file
37
src/dataset_library/test_dataset.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from .factory import get_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_gigaspeech():
|
||||||
|
dataset = get_dataset("gigaspeech")
|
||||||
|
assert len(dataset["train"]) > 0
|
||||||
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
assert len(dataset["test"]) > 0
|
||||||
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_chem():
|
||||||
|
dataset = get_dataset("chem")
|
||||||
|
assert len(dataset["train"]) > 0
|
||||||
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
assert len(dataset["test"]) > 0
|
||||||
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_ocrvqa200k():
|
||||||
|
dataset = get_dataset("ocrvqa200k")
|
||||||
|
assert len(dataset["train"]) > 0
|
||||||
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
assert len(dataset["test"]) > 0
|
||||||
|
assert len(dataset["test"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_textvqa():
|
||||||
|
dataset = get_dataset("textvqa")
|
||||||
|
assert len(dataset["train"]) > 0
|
||||||
|
assert len(dataset["train"][0]["chat"]) > 0
|
||||||
|
|
||||||
|
assert len(dataset["test"]) > 0
|
||||||
|
assert len(dataset["test"][0]["chat"]) > 0
|
@ -5,8 +5,6 @@ from trl import (
|
|||||||
get_quantization_config,
|
get_quantization_config,
|
||||||
)
|
)
|
||||||
from utils.args import ContinualModelConfig
|
from utils.args import ContinualModelConfig
|
||||||
import transformers
|
|
||||||
print(transformers.__version__)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(model_args: ContinualModelConfig):
|
def get_model(model_args: ContinualModelConfig):
|
||||||
@ -26,7 +24,8 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
|
|
||||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||||
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
||||||
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
|
||||||
|
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||||
|
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
@ -38,7 +37,7 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
padding_side="left",
|
padding_side="left",
|
||||||
)
|
)
|
||||||
from model_library.qwen2vl import (
|
from .qwen2vl import (
|
||||||
collate_fn_for_train,
|
collate_fn_for_train,
|
||||||
collate_fn_for_evaluate,
|
collate_fn_for_evaluate,
|
||||||
)
|
)
|
||||||
@ -60,7 +59,7 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
padding_side="left",
|
padding_side="left",
|
||||||
)
|
)
|
||||||
from model_library.qwen2audio import (
|
from .qwen2audio import (
|
||||||
collate_fn_for_train,
|
collate_fn_for_train,
|
||||||
collate_fn_for_evaluate,
|
collate_fn_for_evaluate,
|
||||||
)
|
)
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
from transformers import Qwen2VLProcessor
|
from transformers import Qwen2VLProcessor
|
||||||
|
from dataset_library.format import DatasetOutput
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
|
def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||||
for example in examples
|
for example in examples
|
||||||
]
|
]
|
||||||
# print(texts)
|
# print(texts)
|
||||||
images = [example["image"] for example in examples]
|
images = [example["images"] for example in examples]
|
||||||
# Tokenize the texts and process the images
|
# Tokenize the texts and process the images
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
|||||||
for example in examples
|
for example in examples
|
||||||
]
|
]
|
||||||
# print(texts)
|
# print(texts)
|
||||||
images = [example["image"] for example in examples]
|
images = [example["images"] for example in examples]
|
||||||
|
|
||||||
# Tokenize the texts and process the images
|
# Tokenize the texts and process the images
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
||||||
--dataset_name CHEM \
|
--dataset_name gigaspeech \
|
||||||
--use_peft \
|
--use_peft \
|
||||||
--peft_type LORA \
|
--peft_type LORA \
|
||||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||||
|
35
uv.lock
generated
35
uv.lock
generated
@ -270,6 +270,7 @@ dependencies = [
|
|||||||
{ name = "peft" },
|
{ name = "peft" },
|
||||||
{ name = "pip" },
|
{ name = "pip" },
|
||||||
{ name = "pre-commit" },
|
{ name = "pre-commit" },
|
||||||
|
{ name = "pytest" },
|
||||||
{ name = "requests" },
|
{ name = "requests" },
|
||||||
{ name = "rouge-score" },
|
{ name = "rouge-score" },
|
||||||
{ name = "safetensors" },
|
{ name = "safetensors" },
|
||||||
@ -303,6 +304,7 @@ requires-dist = [
|
|||||||
{ name = "peft", specifier = "==0.14.0" },
|
{ name = "peft", specifier = "==0.14.0" },
|
||||||
{ name = "pip", specifier = "==24.3.1" },
|
{ name = "pip", specifier = "==24.3.1" },
|
||||||
{ name = "pre-commit", specifier = ">=4.0.1" },
|
{ name = "pre-commit", specifier = ">=4.0.1" },
|
||||||
|
{ name = "pytest", specifier = ">=8.3.4" },
|
||||||
{ name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" },
|
{ name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" },
|
||||||
{ name = "rouge-score", specifier = ">=0.1.2" },
|
{ name = "rouge-score", specifier = ">=0.1.2" },
|
||||||
{ name = "safetensors", specifier = ">=0.5.2" },
|
{ name = "safetensors", specifier = ">=0.5.2" },
|
||||||
@ -571,6 +573,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iniconfig"
|
||||||
|
version = "2.0.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jinja2"
|
name = "jinja2"
|
||||||
version = "3.1.5"
|
version = "3.1.5"
|
||||||
@ -1171,6 +1182,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 },
|
{ url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pluggy"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pooch"
|
name = "pooch"
|
||||||
version = "1.8.2"
|
version = "1.8.2"
|
||||||
@ -1402,6 +1422,21 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
|
{ url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest"
|
||||||
|
version = "8.3.4"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
{ name = "iniconfig" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
{ name = "pluggy" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.9.0.post0"
|
version = "2.9.0.post0"
|
||||||
|
Loading…
Reference in New Issue
Block a user