Compare commits

...

13 Commits

Author SHA1 Message Date
42df13390d feat: add OLORA support, add a new training script, and update existing code for the new configuration 2025-06-04 14:00:49 +08:00
266fcd57ad feat: update .gitignore to ignore the results/ directory and keep the project tidy 2025-06-03 20:25:31 +08:00
d686cbc254 feat: add MOELORA support, improve the training and evaluation scripts, fix typos, and improve code readability 2025-06-03 20:25:20 +08:00
b84ebb03c7 feat: update the training scripts to support new lora_target_modules and output directories, and tune the training configuration 2025-05-30 19:14:34 +08:00
baccca420a feat: update the training scripts to support MOELORA, adjust the gradient accumulation steps, and tune the config files 2025-05-30 16:33:36 +08:00
70c446e548 refactor: restructure the code for readability and consistency, and remove unneeded files 2025-05-28 19:32:13 +08:00
0bc1034f35 Refactor code structure for improved readability and maintainability 2025-05-27 16:09:48 +08:00
3fe2c85f6b Refactor code structure for improved readability and maintainability 2025-05-20 18:26:25 +08:00
56e46f0e0c fix: fix the dataset registration and loading logic, and tidy the code formatting for consistency 2025-05-20 13:33:17 +08:00
1d2e7b9dcd feat: add support for several datasets, improve the data loading logic, update the dataset registration mechanism, and adjust the VSCode settings 2025-05-19 13:27:35 +08:00
c100a59f0e Refactor code structure for improved readability and maintainability 2025-05-16 19:50:41 +08:00
188ea7df6e feat: add VizWiz dataset support, update the dataset factory, improve the test cases, and adjust the VSCode settings 2025-05-16 13:44:19 +08:00
24a6c3c114 feat: add support for several datasets, including Gigaspeech, TextVQA, OCR-VQA-200K, and the RefCOCO series; update the dataset factory and processing logic; improve image handling 2025-05-15 20:33:29 +08:00
46 changed files with 4073 additions and 2342 deletions

19
.vscode/settings.json vendored
View File

@ -1,6 +1,19 @@
{
"python.analysis.extraPaths": [
"src/transformers_repo/src/",
"src/peft_repo/src/"
]
"./src/peft_repo/src/",
"./src/transformers_repo/src/",
],
"python.analysis.exclude": [
"dataset/**/*",
],
"python.languageServer": "Default",
"python.terminal.activateEnvInCurrentTerminal": true,
// "python.analysis.include": [
// "src/**/*"
// ],
"python.analysis.languageServerMode": "default",
"python.analysis.typeCheckingMode": "basic",
"python.analysis.userFileIndexingLimit": 10000,
"python.analysis.usePullDiagnostics": false,
"python.analysis.importFormat": "relative",
}

5
dataset/.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
derek-thomas*
*.lock
speechcolab*
lmms-lab*
downloads/*

2
dataset/OCR-VQA-200K/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
images/*
dataset.json

View File

@ -0,0 +1,51 @@
import os
import json
import urllib.request as ureq
import urllib.error
import concurrent.futures
import threading
# Set the file paths for your Google Drive
dataset_path = "./dataset.json"
images_path = "./images"
download = 1 # Set to 0 if images are already downloaded
# Load dataset json file
with open(dataset_path, "r") as fp:
data = json.load(fp)
# Initialize a counter and a lock for thread-safe counting
downloaded_count = 0
count_lock = threading.Lock()
# Function to download an image
def download_image(k):
global downloaded_count
imageURL = data[k]["imageURL"]
ext = os.path.splitext(imageURL)[1]
outputFile = os.path.join(images_path, f"{k}{ext}")
# Only download the image if it doesn't exist
if not os.path.exists(outputFile):
try:
ureq.urlretrieve(imageURL, outputFile)
with count_lock:
downloaded_count += 1
if downloaded_count % 100 == 0:
print(f"{downloaded_count} images downloaded.")
except urllib.error.URLError as e:
print(f"Error downloading {outputFile}: {e}")
# Download images using multiple threads
if download == 1:
if not os.path.exists(images_path):
os.makedirs(images_path)
# Create a thread pool and download the images in parallel
# Increase max_workers to potentially speed up downloads for many small files.
# The optimal number may vary based on your network and the server's capacity.
with concurrent.futures.ThreadPoolExecutor(max_workers=400) as executor:
executor.map(download_image, data.keys())

5
dataset/TextVQA/.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
images/test_images/*
images/train_images/*
TextVQA_0.5.1_test.json
TextVQA_0.5.1_train.json
TextVQA_0.5.1_val.json

3
dataset/vizwiz/Annotations/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
train.json
test.json
val.json

3
dataset/vizwiz/images/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
val/*
train/*
test/*

View File

@ -2,9 +2,11 @@
dependencies = [
"absl-py>=2.1.0",
"accelerate==1.2.1",
"calflops>=0.3.2",
"datasets==3.2.0",
"deepspeed==0.16.2",
"evaluate==0.4.3",
"huggingface-hub==0.30.1",
"librosa>=0.10.2.post1",
"markupsafe==2.1.5",
"ms-swift>=1.3.0",
@ -19,9 +21,9 @@ dependencies = [
"safetensors>=0.5.2",
"setuptools>=70.0.0",
"soundfile>=0.13.0",
"torch==2.5.1+cu124",
"torchaudio==2.5.1+cu124",
"torchvision==0.20.1+cu124",
"torch==2.6.0",
"torchaudio==2.6.0",
"torchvision==0.21.0",
"transformers==4.48.0",
"trl==0.13.0",
"wandb>=0.19.4",

5
src/.gitignore vendored
View File

@ -1 +1,4 @@
checkpoint/*
checkpoint/*
wandb/*
test.py
results/

View File

@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
gradient_accumulation_steps: 2
zero3_init_flag: false
zero_stage: 1
distributed_type: DEEPSPEED
@ -11,7 +11,7 @@ machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
num_processes: 3
rdzv_backend: static
same_network: true
tpu_env: []

View File

@ -12,7 +12,7 @@ machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []

View File

@ -18,8 +18,13 @@ class GigaspeechDataset(Dataset):
self.audio_processor = audio_processor
self.text_processor = text_processor
gs = load_dataset("speechcolab/gigaspeech", "xs")
self.data = gs[split]
self.split = split
def load_data(self):
from .format import dataset_dir
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
self.data = gs[self.split] # type: ignore
def __len__(self):
return len(self.data)
@ -54,7 +59,7 @@ class GigaspeechDataset(Dataset):
audios=[(audio, sampling_rate)],
chat=chat,
original=sample,
)
) # type: ignore
class GigaspeechDatasetForGeneration(GigaspeechDataset):
@ -87,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
chat=chat,
answer=text,
original=sample,
)
) # type: ignore
def test_gigaspeech():
@ -103,3 +108,21 @@ def test_gigaspeech():
print(dataset[0])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
from .factory import register_dataset
dataset = {
"train": GigaspeechDataset(split="train"),
"test": GigaspeechDataset(split="test"),
"generation": GigaspeechDatasetForGeneration(split="test"),
}
register_dataset(
dataset_name="gigaspeech",
dataset=dataset,
tag=["audio", "text"],
)
if __name__ == "__main__":
test_gigaspeech()

View File

@ -22,14 +22,20 @@ class OCRVQADataset(Dataset):
"""
self.vis_root = vis_root
from .vis_processor import size_processor
self.vis_processor = (
vis_processor if vis_processor is not None else self._vis_processor
vis_processor if vis_processor is not None else size_processor
)
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
elif split == "test":
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
self.split = split
self.ann_path = ann_path
def load_data(self):
if self.split == "train":
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
elif self.split == "test":
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
# self.instruction_pool = [
# "[vqa] {}",
@ -64,24 +70,6 @@ class OCRVQADataset(Dataset):
def __len__(self):
return len(self.data)
def _vis_processor(self, image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
if width < 28 or height < 28:
min_size = min(width, height)
ratio = 28 / min_size + 1
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image
def __getitem__(self, index):
sample = self.data[index]
image: Image.Image = Image.open(
@ -117,7 +105,7 @@ class OCRVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
)
) # type: ignore
class OCRVQADatasetForGeneration(OCRVQADataset):
@ -153,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
)
) # type: ignore
from .factory import register_dataset
from .format import dataset_dir as base_path
dataset = {
"train": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="train",
),
"test": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
"generation": OCRVQADatasetForGeneration(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
}
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])

View File

@ -0,0 +1,139 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCODataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "test"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCODatasetForGeneration(RefCOCODataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCO():
dataset = RefCOCODataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCODatasetForGeneration(
split="test",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCODataset(split="val"),
"test": RefCOCODataset(split="test"),
"generation": RefCOCODatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcoco",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_RefCOCO()

View File

@ -0,0 +1,139 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCOplusDataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "testA"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCOplusDatasetForGeneration(RefCOCOplusDataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCOplus():
dataset = RefCOCOplusDataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCOplusDatasetForGeneration(
split="testA",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCOplusDataset(split="val"),
"test": RefCOCOplusDataset(split="testA"),
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcocoplus",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_RefCOCOplus()

View File

@ -0,0 +1,139 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset, DatasetDict
from typing import Literal
class RefCOCOgDataset(Dataset):
def __init__(
self,
vis_processor=None,
text_processor=None,
split: Literal["val", "test"] = "val",
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=question,
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=answer)],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
) # type: ignore
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
def __getitem__(self, index):
sample = self.data[index]
# print(sample)
images = sample["image"]
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"{question}",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=answer,
original=sample,
) # type: ignore
def test_RefCOCOg():
dataset = RefCOCOgDataset(
split="val",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = RefCOCOgDatasetForGeneration(
split="test",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCOgDataset(split="val"),
"test": RefCOCOgDataset(split="test"),
"generation": RefCOCOgDatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcocog",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_RefCOCOg()

View File

@ -6,20 +6,25 @@ from .format import (
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset
from datasets import load_dataset, DatasetDict
class ScienceQADataset(Dataset):
def __init__(self, audio_processor=None, text_processor=None, split="train"):
def __init__(self, vis_processor=None, text_processor=None, split="train"):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_processor = audio_processor
self.vis_processor = vis_processor
self.text_processor = text_processor
ds = load_dataset("derek-thomas/ScienceQA")
self.data = ds[split]
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
return len(self.data)
@ -60,7 +65,7 @@ class ScienceQADataset(Dataset):
images=[images],
chat=chat,
original=sample,
)
) # type: ignore
class ScienceQADatasetForGeneration(ScienceQADataset):
@ -98,7 +103,7 @@ class ScienceQADatasetForGeneration(ScienceQADataset):
chat=chat,
answer=choices[answer],
original=sample,
)
) # type: ignore
def test_scienceQA():
@ -116,5 +121,19 @@ def test_scienceQA():
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="scienceqa",
dataset=dataset,
tag=["image", "text"],
)
if __name__ == "__main__":
test_scienceQA()

View File

@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
vis_processor if vis_processor is not None else self._vis_processor
)
self.text_processor = text_processor
if split == "train":
self.split = split
self.ann_path = ann_path
self.vis_root = vis_root
def load_data(self):
if self.split == "train":
self.data = self.create_data(
Path(ann_path, "TextVQA_0.5.1_train.json"),
vis_root=Path(vis_root, "train_images"),
Path(self.ann_path, "TextVQA_0.5.1_train.json"),
vis_root=Path(self.vis_root, "train_images"),
)
elif split == "test":
elif self.split == "test":
self.data = self.create_data(
Path(ann_path, "TextVQA_0.5.1_val.json"),
vis_root=Path(vis_root, "train_images"),
Path(self.ann_path, "TextVQA_0.5.1_val.json"),
vis_root=Path(self.vis_root, "train_images"),
)
# self.instruction_pool = [
@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
chat=chat,
original=sample["original"],
images=[image],
)
) # type: ignore
class TextVQADatasetForGeneration(TextVQADataset):
@ -158,16 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
chat=chat,
answer=answer,
original=sample["original"],
)
) # type: ignore
def test_dataset():
vis_root = "/home/zyy/dataset/TextVQA/images"
ann_path = "/home/zyy/dataset/TextVQA"
dataset = TextVQADataset(vis_root, ann_path)
for i in range(10):
print(dataset[i])
from .format import dataset_dir
dataset = {
"train": TextVQADataset(
vis_root=Path(dataset_dir, "TextVQA", "images"),
ann_path=Path(dataset_dir, "TextVQA"),
split="train",
),
"test": TextVQADataset(
vis_root=Path(dataset_dir, "TextVQA", "images"),
ann_path=Path(dataset_dir, "TextVQA"),
split="test",
),
"generation": TextVQADatasetForGeneration(
vis_root=Path(dataset_dir, "TextVQA", "images"),
ann_path=Path(dataset_dir, "TextVQA"),
split="test",
),
}
if __name__ == "__main__":
test_dataset()
from .factory import register_dataset
register_dataset("textvqa", dataset, tag=["text", "image"])

View File

@ -0,0 +1,168 @@
from PIL import Image
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
import json
import os
from pathlib import Path
class VizWizDataset(Dataset):
def __init__(
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.ann_path = ann_path
from .vis_processor import size_processor
self.vis_processor = (
vis_processor if vis_processor is not None else size_processor
)
self.text_processor = text_processor
self.split = split
def load_data(self):
if self.split == "train":
self.data = self.create_data(Path(self.ann_path, "train.json"))
elif self.split == "test":
self.data = self.create_data(Path(self.ann_path, "val.json"))
# self.instruction_pool = [
# "[vqa] {}",
# "[vqa] Based on the image, respond to this question with a short answer: {}",
# ]
def create_data(self, ann_path):
processed_data = []
with open(ann_path, "r") as f:
data = json.load(f)
for i in range(len(data)):
if (
os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
and data[i]["answerable"]
):
imageFile = data[i]["image"]
processed_data.append(
{
"question": data[i]["question"],
"answer": data[i]["answers"][0]["answer"],
"image_path": imageFile,
"original": data[i],
}
)
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image: Image.Image = Image.open(
os.path.join(self.vis_root, sample["image_path"])
).convert("RGB")
# resize image
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
),
Conversation(
role="assistant", content=[ConverstationText(type="text", text=answer)]
),
]
return DatasetOutput(
chat=chat,
original=sample["original"],
images=[image],
) # type: ignore
class VizWizDatasetForGeneration(VizWizDataset):
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
"RGB"
)
# resize image
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
),
],
),
]
return DatasetOutput(
images=[image],
chat=chat,
answer=answer,
original=sample["original"],
) # type: ignore
from .format import dataset_dir
dataset = {
"train": VizWizDataset(
vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
split="train",
),
"test": VizWizDataset(
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
split="test",
),
"generation": VizWizDatasetForGeneration(
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
split="test",
),
}
from .factory import register_dataset
register_dataset(
dataset_name="vizwiz",
dataset=dataset,
tag=["image", "text"],
)

View File

@ -1,95 +1,78 @@
from torch.utils.data import Dataset
from typing import Literal
from typing import Literal, List
from pathlib import Path
from dataset_library.format import dataset_dir
MAPPING_NAME_TO_DATASET: dict[
str, dict[Literal["train", "test", "generation"], Dataset]
] = {}
def register_dataset(
dataset_name: str,
dataset: dict[Literal["train", "test", "generation"], Dataset],
tag: List[Literal["image", "text", "audio", "video"]] = [],
**kwargs,
) -> None:
"""
Register a dataset.
Args:
dataset_name (`str`):
The name of the dataset.
dataset (`Dataset`):
The dataset to register.
"""
dataset_name = dataset_name.lower()
if dataset_name in MAPPING_NAME_TO_DATASET:
raise ValueError(f"Dataset {dataset_name} already registered.")
MAPPING_NAME_TO_DATASET[dataset_name] = dataset
from .GigaspeechDataset import (
GigaspeechDataset,
GigaspeechDatasetForGeneration,
)
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
from .ScienceQADataset import (
ScienceQADataset,
ScienceQADatasetForGeneration,
)
from .RefCOCODataset import (
RefCOCODataset,
RefCOCODatasetForGeneration,
)
from .RefCOCOgDataset import (
RefCOCOgDataset,
RefCOCOgDatasetForGeneration,
)
from .RefCOCOPlusDataset import (
RefCOCOplusDataset,
RefCOCOplusDatasetForGeneration,
)
from .VizWizDataset import (
VizWizDataset,
VizWizDatasetForGeneration,
)
def get_dataset(
dataset_name, base_path="/home/zyy/dataset"
dataset_name: str, base_path=dataset_dir
) -> dict[Literal["train", "test", "generation"], Dataset]:
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
match dataset_name:
case "ocrvqa200k":
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
"""
Get a dataset by name.
dataset = {
"train": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="train",
),
"test": OCRVQADataset(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
"generation": OCRVQADatasetForGeneration(
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
ann_path=Path(base_path, "OCR-VQA-200K"),
split="test",
),
}
case "chem":
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
dataset = {
"train": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="train",
),
"test": ChemDataset(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
"generation": ChemDatasetForGeneration(
vis_root=Path(base_path, "chem", "images"),
ann_path=Path(base_path, "chem"),
split="test",
),
}
case "gigaspeech":
from .GigaspeechDataset import (
GigaspeechDataset,
GigaspeechDatasetForGeneration,
)
dataset = {
"train": GigaspeechDataset(split="train"),
"test": GigaspeechDataset(split="test"),
"generation": GigaspeechDatasetForGeneration(split="test"),
}
case "textvqa":
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
dataset = {
"train": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="train",
),
"test": TextVQADataset(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
"generation": TextVQADatasetForGeneration(
vis_root=Path(base_path, "TextVQA", "images"),
ann_path=Path(base_path, "TextVQA"),
split="test",
),
}
case "scienceqa":
from .ScienceQADataset import (
ScienceQADataset,
ScienceQADatasetForGeneration,
)
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
return dataset
Args:
dataset_name (`str`):
The name of the dataset.
base_path (`str`):
The base path to the dataset.
"""
dataset_name = dataset_name.lower()
if dataset_name not in MAPPING_NAME_TO_DATASET:
raise ValueError(f"Dataset {dataset_name} not registered.")
for key in MAPPING_NAME_TO_DATASET[dataset_name]:
MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
return MAPPING_NAME_TO_DATASET[dataset_name]
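The refactor above replaces the long match statement with a registry: each dataset module calls register_dataset at import time, and get_dataset only looks the name up and triggers the lazy load_data calls. A minimal usage sketch, assuming the dataset files referenced above are present on disk under dataset/:

from dataset_library.factory import get_dataset

# importing the factory pulls in every dataset module, and each module registers itself
datasets = get_dataset("textvqa")      # load_data() runs here for each split
train = datasets["train"]
print(len(train))                      # number of training samples
print(train[0]["chat"][0]["role"])     # "user"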

View File

@ -1,6 +1,9 @@
from typing import Any, Tuple, TypedDict, Literal, Optional
import numpy as np
from PIL import Image
from pathlib import Path
dataset_dir = Path(__file__).resolve().parent.parent.parent / "dataset"
class ConverstationText(TypedDict):
@ -25,8 +28,8 @@ class Conversation(TypedDict):
class DatasetOutput(TypedDict):
audios: Optional[list[Tuple[np.ndarray, int]]]
chat: list[Conversation]
answer: Optional[str]
original: Any
images: Optional[list[Image.Image]]
audios: Optional[list[Tuple[np.ndarray, int]]]

View File

@ -1,46 +1,22 @@
from .factory import get_dataset
import pytest
from .factory import get_dataset, MAPPING_NAME_TO_DATASET
# Get all registered dataset names for parameterization
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
def test_gigaspeech():
dataset = get_dataset("gigaspeech")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
@pytest.mark.parametrize("dataset_name", dataset_names)
def test_registered_datasets(dataset_name):
dataset = get_dataset(dataset_name)
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
# Test train split
assert "train" in dataset, f"Train split not found in {dataset_name}"
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
def test_chem():
dataset = get_dataset("chem")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_ocrvqa200k():
dataset = get_dataset("ocrvqa200k")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_textvqa():
dataset = get_dataset("textvqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_scienceqa():
dataset = get_dataset("scienceqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
# Test test split
assert "test" in dataset, f"Test split not found in {dataset_name}"
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore

View File

@ -0,0 +1,20 @@
from PIL import Image
def size_processor(image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
if width < 28 or height < 28:
min_size = min(width, height)
ratio = 28 / min_size + 1
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image
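For reference, a short sketch of what the shared size_processor above does; the module path is assumed from the `from .vis_processor import size_processor` imports elsewhere in this diff:

from PIL import Image
from dataset_library.vis_processor import size_processor  # assumed path

big = Image.new("RGB", (1200, 800))
small = Image.new("RGB", (20, 60))
print(size_processor(big).size)    # (500, 333): longest side capped at 500
print(size_processor(small).size)  # (48, 144): shortest side pushed above 28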

View File

@ -19,67 +19,39 @@ from trl import (
get_quantization_config,
)
from utils.args import ContinualScriptArguments, ContinualModelConfig
from utils.args import (
ContinualScriptArguments,
ContinualModelConfig,
ContinualRegularizationArguments,
)
from typing import TYPE_CHECKING
if __name__ == "__main__":
parser = TrlParser(
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
(
ContinualScriptArguments,
TrainingArguments,
ContinualModelConfig,
ContinualRegularizationArguments,
)
)
script_args, training_args, model_args = parser.parse_args_and_config()
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
# for type hint
if 0 == 1:
if TYPE_CHECKING:
script_args = ContinualScriptArguments()
training_args = TrainingArguments()
model_args = ModelConfig()
model_args = ContinualModelConfig()
reg_args = ContinualRegularizationArguments()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
from model_library.factory import get_model
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
quantization_config=quantization_config,
)
from transformers import (
Qwen2VLProcessor,
Qwen2VLForConditionalGeneration,
AutoModelForVision2Seq,
AutoModel,
)
from peft.peft_model import PeftModelForCausalLM
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
model = Qwen2VLForConditionalGeneration.from_pretrained(
training_args.output_dir,
**model_kwargs,
)
# from peft_library import get_peft_model
processor = Qwen2VLProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
padding_side="left",
)
from model_library.qwen2vl import (
collate_fn_for_train,
collate_fn_for_evaluate,
)
from functools import partial
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
model_args=model_args, training_args=training_args
)
################
# Dataset
################
@ -107,6 +79,6 @@ if __name__ == "__main__":
collate_fn=collate_fn_for_evaluate,
)
val_dataloader = accelerator.prepare_data_loader(val_dataloader)
from utils.evaluate_tool import evaluate_rouge, evalute_save
from utils.evaluate_tool import evaluate_rouge, evaluate_save
evalute_save(model, val_dataloader, processor, accelerator)
evaluate_save(model, val_dataloader, processor, accelerator)

View File

@ -5,9 +5,12 @@ from trl import (
get_quantization_config,
)
from utils.args import ContinualModelConfig
from transformers import TrainingArguments
def get_model(model_args: ContinualModelConfig):
def get_model(
model_args: ContinualModelConfig, training_args: TrainingArguments = None
):
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
@ -26,12 +29,20 @@ def get_model(model_args: ContinualModelConfig):
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
if training_args is not None:
model = Qwen2VLForConditionalGeneration.from_pretrained(
training_args.output_dir,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
else:
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
processor = Qwen2VLProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
@ -49,11 +60,18 @@ def get_model(model_args: ContinualModelConfig):
if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
model = Qwen2AudioForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
if training_args is not None:
model = Qwen2AudioForConditionalGeneration.from_pretrained(
training_args.output_dir,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
else:
model = Qwen2AudioForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
processor = Qwen2AudioProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
@ -68,4 +86,63 @@ def get_model(model_args: ContinualModelConfig):
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
if training_args is not None:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
training_args.output_dir,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
else:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
processor = Qwen2_5_VLProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
padding_side="left",
)
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
from functools import partial
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
from transformers.models.qwen2_5_omni import (
Qwen2_5OmniThinkerForConditionalGeneration,
Qwen2_5OmniProcessor,
)
if training_args is not None:
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
training_args.output_dir,
**model_kwargs,
)
else:
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
processor = Qwen2_5OmniProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
padding_side="left",
)
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
from functools import partial
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
return model, processor, collate_fn_for_train, collate_fn_for_evaluate

View File

@ -1,7 +1,8 @@
from transformers import Qwen2AudioProcessor
from dataset_library.format import Conversation
def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
def collate_fn_for_train(examples: list[Conversation], processor: Qwen2AudioProcessor):
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(example["chat"], tokenize=False)
@ -60,7 +61,7 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
return batch
def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
def collate_fn_for_evaluate(examples: list[Conversation], processor: Qwen2AudioProcessor):
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
@ -91,3 +92,4 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
# answers_ids torch.Size([3, 10])
# answers_mask torch.Size([3, 10])
return batch

View File

@ -1,5 +1,6 @@
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
from .model import Qwen2VLForConditionalGeneration_modified
# from .model import Qwen2VLForConditionalGeneration_modified
__all__ = [
"collate_fn_for_train",

View File

@ -1,9 +1,22 @@
from transformers import Qwen2VLProcessor
# from transformers import Qwen2VLProcessor
import sys
sys.path.insert(0, "transformers_repo/src/")
sys.path.insert(0, "peft_repo/src/")
import transformers
import peft
from dataset_library.format import DatasetOutput
import torch
from typing import TYPE_CHECKING
def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
if TYPE_CHECKING:
from transformers import Qwen2VLProcessor
from transformers import Qwen2_5_VLProcessor
def collate_fn_for_train(examples: list[DatasetOutput], processor: "Qwen2VLProcessor"):
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(example["chat"], tokenize=False)
@ -32,36 +45,36 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
# print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
# 151644 151645 8948 872 77091 None 151643
# for i, label in enumerate(labels):
# now_index = 0
# while now_index < len(label):
# if label[now_index] == im_start_token_id:
# label[now_index] = -100
# now_index += 1
# if (
# label[now_index] == system_token_id
# or label[now_index] == user_token_id
# ):
# while label[now_index] != im_end_token_id:
# label[now_index] = -100
# now_index += 1
# label[now_index] = -100
# elif label[now_index] == assistant_token_id:
# label[now_index] = -100
# label[now_index + 1] = -100
# now_index += 2
# while (
# now_index < len(label) and label[now_index] != im_end_token_id
# ):
# now_index += 1
# now_index += 1
for i, label in enumerate(labels):
now_index = 0
while now_index < len(label):
if label[now_index] == im_start_token_id:
label[now_index] = -100
now_index += 1
if (
label[now_index] == system_token_id
or label[now_index] == user_token_id
):
while label[now_index] != im_end_token_id:
label[now_index] = -100
now_index += 1
label[now_index] = -100
elif label[now_index] == assistant_token_id:
label[now_index] = -100
label[now_index + 1] = -100
now_index += 2
while (
now_index < len(label) and label[now_index] != im_end_token_id
):
now_index += 1
now_index += 1
batch["labels"] = labels
# batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
return batch
def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
@ -89,3 +102,68 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
# answers_ids torch.Size([3, 10])
# answers_mask torch.Size([3, 10])
return batch
if __name__ == "__main__":
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
from PIL import Image
# Randomly generate an image
import numpy as np
random_image = Image.fromarray(
np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
)
example = {
"chat": [
# {"role": "user", "content": "What is the capital of France?"},
{
"role": "user",
"content": [
{
"type": "image",
"image_url": "", # Assuming no image for this example
},
{
"type": "image",
"image_url": "", # Assuming no image for this example
},
{
"type": "image",
"image_url": "", # Assuming no image for this example
},
{
"type": "text",
"text": "What is the capital of France?",
},
],
}, # Assuming no image for this example
{"role": "assistant", "content": "The capital of France is Paris."},
],
"images": [
random_image,
random_image,
random_image,
], # Assuming no images for this example
}
batch = collate_fn_for_train([example], processor)
# print(batch)
for k, v in batch.items():
if isinstance(v, torch.Tensor):
print(f"{k}: {v.shape}")
else:
print(f"{k}: {v}")
# input_ids: torch.Size([1, 101])
# attention_mask: torch.Size([1, 101])
# pixel_values: torch.Size([256, 1176])
# image_grid_thw: torch.Size([1, 3])
# labels: torch.Size([1, 101])
# Load model directly
from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniProcessor
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-3B", torch_dtype="auto", device_map="auto")
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")

View File

@ -73,7 +73,6 @@ from peft.tuners import (
from .tuners import MMOELoraModel, MMOELoraConfig
from peft.tuners.tuners_utils import BaseTuner
from peft.utils import _prepare_prompt_learning_config
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
if TYPE_CHECKING:

View File

@ -46,7 +46,7 @@ from transformers.modeling_outputs import (
)
from transformers.utils import PushToHubMixin
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft import __version__
from peft.config import PeftConfig

View File

@ -0,0 +1,19 @@
class RegularizationMethod:
"""RegularizationMethod implement regularization strategies.
RegularizationMethod is a callable.
The method `update` is called to update the loss, typically at the end
of an experience.
"""
def pre_adapt(self, agent, exp):
pass # implementation may be empty if adapt is not needed
def post_adapt(self, agent, exp):
pass # implementation may be empty if adapt is not needed
def __call__(self, *args, **kwargs):
raise NotImplementedError()
from .ewc import EWC
from .lwf import LWF

View File

@ -0,0 +1,58 @@
from . import RegularizationMethod
import torch
class EWC(RegularizationMethod):
"""Learning Without Forgetting.
The method applies knowledge distilllation to mitigate forgetting.
The teacher is the model checkpoint after the last experience.
"""
def __init__(self, EWC_lambda=1, temperature=2):
"""
:param EWC_lambda: weight of the quadratic penalty on parameter drift.
:param temperature: kept for interface parity with LWF; not used by EWC itself.
"""
self.EWC_lambda = EWC_lambda
self.temperature = temperature
self.fisher = {}
self.optpar = {}
""" In Avalanche, targets of different experiences are not ordered.
As a result, some units may be allocated even though their
corresponding class has never been seen by the model.
Knowledge distillation uses only units corresponding
to old classes.
"""
def adapt(self, output, model, **kwargs):
ewc_loss = 0
for n, p in model.named_parameters():
if p.requires_grad:
dev = p.device
l = (
self.EWC_lambda
* self.fisher[n].to(dev)
* (p.data - self.optpar[n].to(dev)).pow(2)
)
ewc_loss += l.sum()
output["loss"] += ewc_loss
return output
def init_epoch(self, model):
"""Snapshot the current parameters and reset the Fisher estimate to zero."""
optpar = {}
fisher = {}
for n, p in model.module.base_model.model.named_parameters():
if p.requires_grad:
fisher[n] = torch.zeros(p.data.shape)
optpar[n] = p.clone().cpu().data
# keep the snapshot on the instance; otherwise the dicts built above are discarded
self.fisher = fisher
self.optpar = optpar
def update_fisher(self, model):
"""Update the fisher information for the given question id."""
for n, p in model.module.base_model.model.named_parameters():
if p.requires_grad:
fisher = self.fisher[n]
fisher += p.grad.data.pow(2).cpu()
self.fisher[n] = fisher
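The trainer-side wiring for EWC is not part of this diff, so the following is only a sketch of how the class could be driven. It assumes hypothetical task_loaders, optimizer, and a model that is the DDP-wrapped PEFT model expected by init_epoch/update_fisher, and that the parameter names seen by adapt match those recorded by init_epoch:

ewc = EWC(EWC_lambda=0.4)  # hypothetical penalty weight

for task_id, loader in enumerate(task_loaders):
    # train on the current task; from the second task on, add the EWC penalty
    for batch in loader:
        output = model(**batch)                 # dict-like with "loss"
        if task_id > 0:
            output = ewc.adapt(output, model=model)
        output["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()

    # after the task: snapshot the parameters, then estimate the Fisher information
    ewc.init_epoch(model)
    for batch in loader:
        model(**batch)["loss"].backward()
        ewc.update_fisher(model)
        optimizer.zero_grad()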

View File

@ -0,0 +1,67 @@
from . import RegularizationMethod
import torch
class LWF(RegularizationMethod):
"""Learning Without Forgetting.
The method applies knowledge distillation to mitigate forgetting.
The teacher is the model checkpoint after the last experience.
"""
def __init__(self, LWF_lambda=1, temperature=2):
"""
:param LWF_lambda: weight of the distillation term added to the loss.
:param temperature: softmax temperature for distillation
"""
self.LWF_lambda = LWF_lambda
self.temperature = temperature
self.previous_logits = {}
""" In Avalanche, targets of different experiences are not ordered.
As a result, some units may be allocated even though their
corresponding class has never been seen by the model.
Knowledge distillation uses only units corresponding
to old classes.
"""
def adapt(self, output, **kwargs):
def modified_kl_div(old, new):
return -torch.mean(torch.sum(old * torch.log(new), 1))
def smooth(logits, temp, dim):
log = logits ** (1 / temp)
return log / torch.sum(log, dim).unsqueeze(1)
lwf_loss = []
soft = torch.nn.Softmax(dim=1)
previous_keys = self.previous_logits.keys()
for index, question_id in enumerate(iterable=kwargs["question_ids"]):
if question_id in previous_keys:
previous_logits = self.previous_logits[question_id]
current_logits = output["logits"][index]
short_index = min(len(previous_logits), len(current_logits))
previous_logits = previous_logits[:short_index]
current_logits = current_logits[:short_index]
lwf_loss.append(
modified_kl_div(
old=smooth(
logits=soft(previous_logits).to(current_logits.device),
temp=2,
dim=1,
),
new=smooth(logits=soft(current_logits), temp=2, dim=1),
)
)
if len(lwf_loss) > 0:
output["loss"] += self.LWF_lambda * torch.stack(
tensors=lwf_loss, dim=0
).sum(dim=0)
return output
def update_previous_logits(self, question_id, logits):
"""Update the previous logits for the given question id."""
self.previous_logits[question_id] = logits
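As with EWC, the integration of LWF into the trainer is not shown in this diff; below is a hedged sketch of the intended call pattern, assuming the collate function exposes per-sample question_ids in the batch and that previous_task_loader is a hypothetical dataloader for the already-trained task:

lwf = LWF(LWF_lambda=0.2)  # hypothetical distillation weight

# 1) after finishing a task, cache the teacher logits for each sample
for batch in previous_task_loader:
    with torch.no_grad():
        logits = model(**batch)["logits"]
    for qid, sample_logits in zip(batch["question_ids"], logits):
        lwf.update_previous_logits(qid, sample_logits.cpu())

# 2) while training the next task, add the distillation term to the loss
output = model(**batch)
output = lwf.adapt(output, question_ids=batch["question_ids"])
output["loss"].backward()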

@ -1 +1 @@
Subproject commit ebab283576fc8803314314b5e1e4331c424a2198
Subproject commit a6e39fafb4e3ccce27671ce14afe62fb754c83c2

15
src/scripts/eval_omni.sh Executable file
View File

@ -0,0 +1,15 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml evaluation.py \
--dataset_name textvqa \
--use_peft \
--peft_type MOELORA \
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--output_dir ./checkpoint/qwen2_5omni_moelora/ \
--bf16 \
--torch_dtype bfloat16
# --eval_strategy epoch \

View File

@ -0,0 +1,25 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
--dataset_name textvqa \
--use_peft \
--peft_type MOELORA \
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
--lora_r 8 \
--lora_alpha 32 \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--num_train_epochs 1 \
--output_dir checkpoint/qwen2_5omni_moelora/ \
--learning_rate 2e-4 \
--warmup_ratio 0.03 \
--lr_scheduler_type cosine \
--bf16 \
--torch_dtype bfloat16 \
--logging_steps 300 \
--gradient_checkpointing \
--weight_decay 0.1 \
--eval_strategy steps \
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500

25
src/scripts/train_omni_olora.sh Executable file
View File

@ -0,0 +1,25 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
--dataset_name textvqa \
--use_peft \
--peft_type OLORA \
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
--lora_r 8 \
--lora_alpha 32 \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--num_train_epochs 1 \
--output_dir checkpoint/qwen2_5omni_olora/ \
--learning_rate 2e-4 \
--warmup_ratio 0.03 \
--lr_scheduler_type cosine \
--bf16 \
--torch_dtype bfloat16 \
--logging_steps 300 \
--gradient_checkpointing \
--weight_decay 0.1 \
--eval_strategy steps \
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500

31
src/test_evalutae.py Normal file
View File

@ -0,0 +1,31 @@
import evaluate
# Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, and CIDEr
example = {
"generated": "The cat sat on the mat.",
"target": "The cat is sitting on the mat.",
"original": "The cat is sitting on the mat.",
}
evaluate_bleu = evaluate.load("bleu")
evaluate_rouge = evaluate.load("rouge")
evaluate_meteor = evaluate.load("meteor")
evaluate_bleu.add_batch(
predictions=[example["generated"]],
references=[[example["target"]]],
)
evaluate_rouge.add_batch(
predictions=[example["generated"]],
references=[[example["target"]]],
)
evaluate_meteor.add_batch(
predictions=[example["generated"]],
references=[[example["target"]]],
)
bleu = evaluate_bleu.compute()
rouge = evaluate_rouge.compute()
meteor = evaluate_meteor.compute()
comprehensive_results = sum(bleu['precisions']) + rouge['rougeL'] + meteor['meteor']
print("Comprehensive Results:", comprehensive_results/6)

View File

@ -17,6 +17,31 @@
[2025.01.19]
- [ ] Introduce multiple datasets
- [x] Introduce multiple datasets
- [ ] For mixed-modality data the batch size can only be 1, which is far too slow; the model code needs adjusting (may not even help)
- [ ] Introduce EWC and LWF
[2025.05.15]
- [x] VizWiz processing
[2025.05.16]
- [ ] Handle the different continual-learning frameworks so that the overall framework stays compatible
[2025.05.28]
- [x] MoeLora
- [ ] Coin Benchmark
- [x] Decide what to save, so that later evaluation is easy
- [x] Olora: not an implementation problem, but the loss keeps rising and it seems hard to train
- [ ] Hide-Llava (override the base class, bring in CLIP, average the different adapters, insert LoraLinear by module name into top layers or normal layers; the model must accept a task_id, i.e. the maximum similarity computed by CLIP)
- [ ] Hide-llava concerns: averaging/fusing the early layers is hard to justify, while the MoE handling of the later layers adds the full cost of CLIP; fixing the number of tasks in advance makes some methods non-extensible. In realistic settings we cannot know how many datasets are still to come, so the goal is to reduce forgetting and, ideally, help on unseen datasets; MoeLora only partially mitigates this by routing different tasks to different parameters. For this benchmark, keep the data at every step: the baseline retrains on everything seen so far, while continual-learning methods update incrementally; compare the benefit of feeding datasets in batches and define a metric for it ([How Efficient Are Todays Continual Learning Algorithms?], []), i.e. the integral of accuracy over time.
[2025.05.30]
- [x] Evaluation metrics
[2025.06.03]
- [ ] Planned algorithm: low compute cost,

View File

@ -13,30 +13,42 @@ from trl import (
TrlParser,
)
from utils.trainer import ContinualTrainer
from utils.args import ContinualScriptArguments, ContinualModelConfig
from utils.args import (
ContinualScriptArguments,
ContinualModelConfig,
ContinualRegularizationArguments,
)
import logging
from typing import TYPE_CHECKING
logging.basicConfig(level=logging.INFO)
if __name__ == "__main__":
parser = TrlParser(
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
(
ContinualScriptArguments,
TrainingArguments,
ContinualModelConfig,
ContinualRegularizationArguments,
) # type: ignore
)
script_args, training_args, model_args = parser.parse_args_and_config()
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
# for type hint
if 0 == 1:
if TYPE_CHECKING:
script_args = ContinualScriptArguments()
training_args = TrainingArguments()
model_args = ContinualModelConfig()
reg_args = ContinualRegularizationArguments()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
from model_library.factory import get_model
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
model_args
model_args=model_args
)
################
# Dataset
@ -47,12 +59,19 @@ if __name__ == "__main__":
accelerator = create_accelerator_and_postprocess(training_args)
if model_args.peft_type == "MMOELORA":
from peft_library.tuners import MMOELoraConfig
from peft.tuners import MMOELoraConfig
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
model.add_adapter(peft_config)
elif model_args.peft_type == "MOELORA":
from peft.tuners import MOELoraConfig
peft_config = MOELoraConfig(target_modules=model_args.lora_target_modules)
model.add_adapter(peft_config)
elif model_args.peft_type == "LORA":
from peft.tuners.lora import LoraConfig
@ -64,10 +83,25 @@ if __name__ == "__main__":
)
model.add_adapter(peft_config)
elif model_args.peft_type == "OLORA":
from peft.tuners import LoraConfig
peft_config = LoraConfig(
target_modules=model_args.lora_target_modules,
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
init_lora_weights="olora"
)
model.add_adapter(peft_config)
else:
peft_config = None
from peft import get_peft_model
if accelerator.is_local_main_process:
print(model)
@ -79,19 +113,21 @@ if __name__ == "__main__":
model=model,
args=training_args,
data_collator=collate_fn_for_train,
train_dataset=dataset[script_args.dataset_train_split],
train_dataset=dataset[script_args.dataset_train_split], # type: ignore
eval_dataset=(
dataset[script_args.dataset_test_split]
dataset[script_args.dataset_test_split] # type: ignore
if training_args.eval_strategy != "no"
else None
),
accelerator=accelerator,
reg_args=reg_args,
)
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
if accelerator.is_local_main_process:
print("Saving model")
trainer.save_model(training_args.output_dir)
# trainer.save_model(training_args.output_dir)
model.save_pretrained(training_args.output_dir)
if accelerator.is_local_main_process:
print("Model saved")
@ -109,6 +145,6 @@ if __name__ == "__main__":
# )
# val_dataloader = accelerator.prepare(val_dataloader)
# from utils.evaluate_tool import evaluate_rouge
# from utils.evaluate_tool import evaluate_save
# evaluate_rouge(model, val_dataloader, processor)
# evaluate_save(model, val_dataloader, processor, accelerator)

View File

@ -1,18 +1,19 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
--dataset_name chem \
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
--dataset_name textvqa \
--use_peft \
--peft_type LORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
--peft_type MOELORA \
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
--lora_r 8 \
--lora_alpha 32 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--per_device_train_batch_size 3 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 1 \
--num_train_epochs 1 \
--output_dir checkpoint/qwen2_alllinear/ \
--learning_rate 1e-4 \
--learning_rate 5e-5 \
--bf16 \
--torch_dtype bfloat16 \
--logging_steps 30 \

@ -1 +1 @@
Subproject commit 8ee1a4eadda1d83cf65c024fe54364b5bd74e55f
Subproject commit 42a8639e1e827d6f0ab07d87078ff048b20dab19

View File

@ -18,3 +18,16 @@ class ContinualModelConfig(ModelConfig):
"""Model configuration for continual learning."""
peft_type: Optional[str] = None
@dataclass
class ContinualRegularizationArguments:
"""Regularization arguments for continual learning."""
# EWC
ewc_lambda: float = 0.0
ewc_enable: bool = False
# LWF
lwf_lambda: float = 0.0
lwf_enable: bool = False
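# A hypothetical usage sketch (not part of this diff): the dataclass can also be built
# directly instead of going through TrlParser; 0.4 is an arbitrary example value.
#
# from utils.args import ContinualRegularizationArguments
# reg_args = ContinualRegularizationArguments(ewc_enable=True, ewc_lambda=0.4)
# assert reg_args.lwf_enable is False  # LWF stays disabled by default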

View File

@ -1,5 +1,6 @@
import evaluate
from accelerate import Accelerator
from typing import TYPE_CHECKING
def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
@ -7,10 +8,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
for batch in val_dataloader:
completion = model.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
pixel_values=batch["pixel_values"],
image_grid_thw=batch["image_grid_thw"],
**batch,
max_length=1000,
)
target = batch["answers_ids"]
@ -27,7 +25,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
print(glue.compute())
def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
def evaluate_save(model, val_dataloader, processor, accelerator: Accelerator = None):
import os
mtime = 0
@ -53,6 +51,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
answers = []
completion = model.generate(
**batch,
# max_new_tokens=30,
max_length=1000,
)
generated_text = [
@ -63,20 +62,17 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
generated_text, skip_special_tokens=True
)
target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
for i in range(len(generated_text)):
answers.append(
{
"generated": generated_text[i],
"target": target_text[i],
"original": str(origianl[i]),
}
)
import json
world_size = accelerator.process_index
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
for answer in answers:
for i in range(len(generated_text)):
answer = {
"generated": generated_text[i],
"target": target_text[i],
"original": str(origianl[i]),
}
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
f.write(json.dumps(answer) + "\n")
if accelerator.is_local_main_process:
@ -97,3 +93,71 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
# delete file
for file in files:
os.remove(f"results/{mtime}/{file}")
def evaluate_from_jsonl_directory(directory_path):
"""
Read every jsonl file in the given directory and compute aggregate evaluation results
Args:
directory_path: path to the directory containing the jsonl files
Returns:
dict: a dictionary with each metric and the composite score
"""
import os
import json
# Initialize the metric evaluators
evaluate_bleu = evaluate.load("bleu")
evaluate_rouge = evaluate.load("rouge")
evaluate_meteor = evaluate.load("meteor")
# Read every jsonl file in the directory
all_data = []
for file in os.listdir(directory_path):
if file.endswith(".jsonl"):
file_path = os.path.join(directory_path, file)
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
data = json.loads(line)
all_data.append(data)
if not all_data:
print(f"未在目录 {directory_path} 中找到有效的jsonl数据")
return None
# Prepare predictions and references
predictions = [item["generated"] for item in all_data]
references = [[item["target"]] for item in all_data]
# Add the data in a single batch
evaluate_bleu.add_batch(predictions=predictions, references=references)
evaluate_rouge.add_batch(predictions=predictions, references=references)
evaluate_meteor.add_batch(predictions=predictions, references=references)
# Compute each metric
bleu = evaluate_bleu.compute()
rouge = evaluate_rouge.compute()
meteor = evaluate_meteor.compute()
# Composite score: average of the 4 BLEU n-gram precisions, ROUGE-L, and METEOR (6 values)
comprehensive_score = (sum(bleu["precisions"]) + rouge["rougeL"] + meteor["meteor"]) / 6
results = {
"bleu": bleu,
"rouge": rouge,
"meteor": meteor,
"comprehensive_score": comprehensive_score,
"total_samples": len(all_data),
}
print(f"评估完成,共处理 {len(all_data)} 条数据")
print(f"BLEU分数: {bleu}")
print(f"ROUGE分数: {rouge}")
print(f"METEOR分数: {meteor}")
print(f"综合分数: {comprehensive_score}")
return results
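# A hypothetical usage sketch (not part of this diff); the directory name is an
# assumption — evaluate_save appends to results/<mtime>/answers_<process_index>.jsonl,
# so any such folder can be passed in.
#
# results = evaluate_from_jsonl_directory("results/1717480849")
# if results is not None:
#     print(results["comprehensive_score"], results["total_samples"])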

View File

@ -1,26 +1,73 @@
# _________________________________________________________
from transformers.trainer import (
Trainer,
_is_peft_model,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
tpu_spmd_dataloader,
logger,
has_length,
sys,
)
from transformers.trainer import *
from transformers import (
TrainingArguments,
)
from .args import ContinualRegularizationArguments
from peft_library.regularizations import EWC, LWF
from torch.nn import CrossEntropyLoss
def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
logits = outputs.logits
device = logits.device
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:].to(device)
# Save memory
masks = shift_labels != -100
shift_logits = shift_logits[masks]
shift_labels = shift_labels[masks]
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction="none")
loss = loss_fct(shift_logits, shift_labels)
if num_items_in_batch is None:
loss = loss.mean()
else:
# compat transformers>=4.46
loss = loss.sum() / num_items_in_batch
return loss
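# A small self-contained sanity check (not part of this diff) of the shift-and-mask
# logic above: position t's logits predict token t+1, and -100 labels are dropped
# before the loss, so the toy batch below contributes exactly two tokens.
#
# import torch
# from types import SimpleNamespace
# logits = torch.randn(1, 4, 5)                      # (batch, seq, vocab)
# labels = torch.tensor([[2, 3, -100, 1]])           # shifted labels: [3, -100, 1]
# outputs = SimpleNamespace(logits=logits)
# mean_loss = ce_loss_func(outputs, labels)          # mean over the 2 unmasked tokens
# norm_loss = ce_loss_func(outputs, labels, num_items_in_batch=2)  # sum / 2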
class ContinualTrainer(Trainer):
def __init__(
self, model, args, data_collator, train_dataset, eval_dataset, accelerator
self,
model,
args: TrainingArguments,
data_collator,
train_dataset,
eval_dataset,
accelerator,
reg_args: ContinualRegularizationArguments = None,
):
self.accelerator = accelerator
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
super().__init__(
model,
args,
data_collator,
train_dataset,
eval_dataset,
# compute_loss_func=ce_loss_func,
)
if reg_args.ewc_enable:
self.ewc_lambda = reg_args.ewc_lambda
from peft_library.regularizations.ewc import EWC
self.EWC = EWC()
# fisher = t
if reg_args.lwf_enable:
self.lwf_lambda = reg_args.lwf_lambda
from peft_library.regularizations.lwf import LWF
self.LWF = LWF()
def create_accelerator_and_postprocess(self):
if self.accelerator is not None:
self.is_deepspeed_enabled = (
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
@ -32,718 +79,79 @@ class ContinualTrainer(Trainer):
return
else:
super().create_accelerator_and_postprocess()
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
def create_optimizer(self):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Setup the optimizer.
Subclass and override for custom behavior.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if (
self.label_smoother is not None or self.compute_loss_func is not None
) and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if labels is not None:
unwrapped_model = self.accelerator.unwrap_model(model)
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
# User-defined compute_loss function
if self.compute_loss_func is not None:
loss = self.compute_loss_func(
outputs, labels, num_items_in_batch=num_items_in_batch
)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
def _inner_training_loop(
self,
batch_size=None,
args=None,
resume_from_checkpoint=None,
trial=None,
ignore_keys_for_eval=None,
):
self.accelerator.free_memory()
self._train_batch_size = batch_size
if self.args.auto_find_batch_size:
if self.state.train_batch_size != self._train_batch_size:
from accelerate.utils import release_memory
(self.model_wrapped,) = release_memory(self.model_wrapped)
self.model_wrapped = self.model
# Check for DeepSpeed *after* the initial pass and modify the config
if self.is_deepspeed_enabled:
# Temporarily unset `self.args.train_batch_size`
original_bs = self.args.per_device_train_batch_size
self.args.per_device_train_batch_size = (
self._train_batch_size // max(1, self.args.n_gpu)
)
self.propagate_args_to_deepspeed(True)
self.args.per_device_train_batch_size = original_bs
self.state.train_batch_size = self._train_batch_size
logger.debug(
f"Currently training with a batch size of: {self._train_batch_size}"
)
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
if self.is_fsdp_xla_v2_enabled:
train_dataloader = tpu_spmd_dataloader(train_dataloader)
# Setting up training control variables:
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = (
self._train_batch_size * args.gradient_accumulation_steps * args.world_size
)
len_dataloader = None
num_train_tokens = None
if has_length(train_dataloader):
len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = (
len_dataloader // args.gradient_accumulation_steps
)
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
num_examples = self.num_examples(train_dataloader)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = (
self.num_tokens(train_dataloader, args.max_steps)
* args.gradient_accumulation_steps
)
else:
max_steps = math.ceil(
args.num_train_epochs * num_update_steps_per_epoch
)
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = (
self.num_examples(train_dataloader) * args.num_train_epochs
)
if args.include_tokens_per_second:
num_train_tokens = (
self.num_tokens(train_dataloader) * args.num_train_epochs
)
elif (
args.max_steps > 0
): # Rely on max_steps when dataloader does not have a working size
max_steps = args.max_steps
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = (
self.num_tokens(train_dataloader, args.max_steps)
* args.gradient_accumulation_steps
)
else:
raise ValueError(
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
f" {args.max_steps}"
)
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
# nn.DataParallel(model) replicates the model, creating new variables and module
# references registered here no longer work on other gpus, breaking the module
raise ValueError(
"Currently --debug underflow_overflow is not supported under DP. Please use DDP"
" (torchrun or torch.distributed.launch (deprecated))."
)
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = (
is_sagemaker_mp_enabled()
or self.is_fsdp_xla_enabled
or self.is_fsdp_enabled
)
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
self.lr_scheduler = None
self._created_lr_scheduler = False
if self.is_deepspeed_enabled:
self.optimizer, self.lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps
)
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState(
stateful_callbacks=[
cb
for cb in self.callback_handler.callbacks + [self.control]
if isinstance(cb, ExportableState)
if self.optimizer is None:
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" not in n)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.learning_rate,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" in n)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.learning_rate / 10,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
"lr": self.args.learning_rate,
},
]
)
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
)
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
model = self._wrap_model(self.model_wrapped)
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
if use_accelerator_prepare and self.is_fsdp_enabled:
# In case of auto_find_batch_size=True
# Remove FSDP wrapping from sub-models.
self.model = unwrap_model(self.model, recursive=True)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if delay_optimizer_creation:
if use_accelerator_prepare:
# configure fsdp plugin for qlora if any
self._fsdp_qlora_plugin_updates()
if self.accelerator.mixed_precision != "fp8":
self.model = self.accelerator.prepare(self.model)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
import bitsandbytes
# prepare using `accelerator` prepare
if use_accelerator_prepare:
self.model.train()
if hasattr(self.lr_scheduler, "step"):
if self.use_apex:
model = self.accelerator.prepare(self.model)
else:
model, self.optimizer = self.accelerator.prepare(
self.model, self.optimizer
)
else:
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
# In this case we are in DDP + LOMO, which should be supported
self.optimizer = self.accelerator.prepare(self.optimizer)
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
if self.is_fsdp_enabled:
self.model = self.model_wrapped = model
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped / 2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped / 2**20}M params")
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# ckpt loading
if resume_from_checkpoint is not None:
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(
self.model_wrapped,
resume_from_checkpoint,
load_module_strict=not _is_peft_model(self.model),
)
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
# FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs:,}")
logger.info(
f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}"
)
if self.args.per_device_train_batch_size != self._train_batch_size:
logger.info(
f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}"
)
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}"
)
logger.info(
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(
f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}"
)
self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
self.state = TrainerState.load_from_json(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
)
self.compare_trainer_and_checkpoint_args(self.args, self.state)
self._load_callback_state()
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (
num_update_steps_per_epoch
)
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0
logger.info(
" Continuing training from checkpoint, will skip to saved global_step"
)
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(
f" Continuing training from global step {self.state.global_step}"
)
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch."
)
# Update the references
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
if self.hp_name is not None and self._trial is not None:
# use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
# parameter to Train when using DDP.
self.state.trial_name = self.hp_name(self._trial)
if trial is not None:
assignments = (
trial.assignments
if self.hp_search_backend == HPSearchBackend.SIGOPT
else trial
)
self.state.trial_params = hp_params(assignments)
else:
self.state.trial_params = None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)
# _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
grad_norm: Optional[float] = None
self.control = self.callback_handler.on_train_begin(
args, self.state, self.control
)
if args.eval_on_start:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
for epoch in range(epochs_trained, num_train_epochs):
epoch_dataloader = train_dataloader
if hasattr(epoch_dataloader, "set_epoch"):
epoch_dataloader.set_epoch(epoch)
# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
self._past = None
steps_in_epoch = (
len(epoch_dataloader)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(
args, self.state, self.control
)
if (
epoch == epochs_trained
and resume_from_checkpoint is not None
and steps_trained_in_current_epoch == 0
):
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_dataloader = skip_first_batches(
epoch_dataloader, steps_trained_in_current_epoch
)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True
step = -1
epoch_iterator = iter(epoch_dataloader)
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
remainder = num_examples % args.gradient_accumulation_steps
if remainder == 0:
remainder = args.gradient_accumulation_steps
update_step = -1
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
for _ in range(total_updates):
update_step += 1
num_batches = (
args.gradient_accumulation_steps
if update_step != (total_updates - 1)
else remainder
)
batch_samples, num_items_in_batch = self.get_batch_samples(
epoch_iterator, num_batches
)
for i, inputs in enumerate(batch_samples):
step += 1
do_sync_step = (
step + 1
) % args.gradient_accumulation_steps == 0 or (
step + 1
) == steps_in_epoch
# Since we perform prefetching, we need to manually set sync_gradients
if not do_sync_step:
self.accelerator.gradient_state._set_sync_gradients(False)
else:
self.accelerator.gradient_state._set_sync_gradients(True)
if self.args.include_num_input_tokens_seen:
main_input_name = getattr(
self.model, "main_input_name", "input_ids"
)
if main_input_name not in inputs:
logger.warning(
"Tried to track the number of tokens seen, however the current model is "
"not configured properly to know what item is the input. To fix this, add "
"a `main_input_name` attribute to the model class you are using."
)
else:
input_tokens = inputs[main_input_name].numel()
input_tokens = torch.tensor(
input_tokens, device=self.args.device, dtype=torch.int64
)
self.state.num_input_tokens_seen += (
self.accelerator.gather(input_tokens).sum().cpu().item()
)
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(
args, self.state, self.control
)
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
and self.accelerator.distributed_type
!= DistributedType.DEEPSPEED
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(
model, inputs, num_items_in_batch
)
if (
args.logging_nan_inf_filter
and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss = tr_loss + tr_loss / (
1 + self.state.global_step - self._globalstep_last_logged
)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
)
tr_loss = tr_loss + tr_loss_step
self.current_flos += float(self.floating_point_ops(inputs))
if do_sync_step:
# Since we perform prefetching, we need to manually set sync_gradients to True
self.accelerator.gradient_state._set_sync_gradients(True)
# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping
if is_sagemaker_mp_enabled() and args.fp16:
_grad_norm = self.optimizer.clip_master_grads(
args.max_grad_norm
)
elif self.use_apex:
# Revert to normal clipping otherwise, handling Apex or full precision
_grad_norm = nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer),
args.max_grad_norm,
)
else:
_grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
if (
is_accelerate_available()
and self.accelerator.distributed_type
== DistributedType.DEEPSPEED
):
grad_norm = model.get_global_grad_norm()
# In some cases the grad norm may not return a float
if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item()
else:
grad_norm = _grad_norm
self.control = self.callback_handler.on_pre_optimizer_step(
args, self.state, self.control
)
self.optimizer.step()
self.control = self.callback_handler.on_optimizer_step(
args, self.state, self.control
)
optimizer_was_run = (
not self.accelerator.optimizer_step_was_skipped
)
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(
self.lr_scheduler,
torch.optim.lr_scheduler.ReduceLROnPlateau,
):
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = (
epoch + (step + 1 + steps_skipped) / steps_in_epoch
)
self.control = self.callback_handler.on_step_end(
args, self.state, self.control
)
self._maybe_log_save_evaluate(
tr_loss,
grad_norm,
model,
trial,
epoch,
ignore_keys_for_eval,
start_time,
)
else:
self.control = self.callback_handler.on_substep_end(
args, self.state, self.control
)
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if (
self.control.should_epoch_stop
or self.control.should_training_stop
):
if is_torch_xla_available():
xm.mark_step()
break
# We also need to break out of the nested loop
if self.control.should_epoch_stop or self.control.should_training_stop:
if is_torch_xla_available():
xm.mark_step()
break
if step < 0:
logger.warning(
"There seems not to be a single sample in your epoch_iterator, stopping training at step"
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
f" num_steps ({max_steps}) higher than the number of available samples."
)
self.control.should_training_stop = True
self.control = self.callback_handler.on_epoch_end(
args, self.state, self.control
)
self._maybe_log_save_evaluate(
tr_loss,
grad_norm,
model,
trial,
epoch,
ignore_keys_for_eval,
start_time,
)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_xla_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
else:
logger.warning(
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
if self.control.should_training_stop:
break
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info(
"\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
)
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_xla_available():
xm.rendezvous("load_best_model_at_end")
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
dist.barrier()
elif is_sagemaker_mp_enabled():
smp.barrier()
self._load_best_model()
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
effective_global_step = max(
self.state.global_step, 0.001
) # Avoid ZeroDivisionError
train_loss = self._total_loss_scalar / effective_global_step
metrics = speed_metrics(
"train",
start_time,
num_samples=num_train_samples,
num_steps=self.state.max_steps,
num_tokens=num_train_tokens,
)
self.store_flos()
metrics["total_flos"] = self.state.total_flos
metrics["train_loss"] = train_loss
self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics)
self.log(metrics)
run_dir = self._get_output_dir(trial)
checkpoints_sorted = self._sorted_checkpoints(
use_mtime=False, output_dir=run_dir
)
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
if (
self.args.should_save
and self.state.best_model_checkpoint is not None
and self.args.save_total_limit == 1
):
for checkpoint in checkpoints_sorted:
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
logger.info(
f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit"
)
shutil.rmtree(checkpoint, ignore_errors=True)
self.control = self.callback_handler.on_train_end(
args, self.state, self.control
)
# Wait for the checkpoint to be uploaded.
self._finish_current_push()
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
self._deactivate_neftune(self.model)
return TrainOutput(self.state.global_step, train_loss, metrics)
return self.optimizer

3709 uv.lock generated

File diff suppressed because it is too large