Compare commits
16 Commits
Author | SHA1 | Date | |
---|---|---|---|
42df13390d | |||
266fcd57ad | |||
d686cbc254 | |||
b84ebb03c7 | |||
baccca420a | |||
70c446e548 | |||
0bc1034f35 | |||
3fe2c85f6b | |||
56e46f0e0c | |||
1d2e7b9dcd | |||
c100a59f0e | |||
188ea7df6e | |||
24a6c3c114 | |||
9ca588224d | |||
da99ec4564 | |||
bcb0494f52 |
19
.vscode/settings.json
vendored
Normal file
19
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"./src/peft_repo/src/",
|
||||
"./src/transformers_repo/src/",
|
||||
],
|
||||
"python.analysis.exclude": [
|
||||
"dataset/**/*",
|
||||
],
|
||||
"python.languageServer": "Default",
|
||||
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||
// "python.analysis.include": [
|
||||
// "src/**/*"
|
||||
// ],
|
||||
"python.analysis.languageServerMode": "default",
|
||||
"python.analysis.typeCheckingMode": "basic",
|
||||
"python.analysis.userFileIndexingLimit": 10000,
|
||||
"python.analysis.usePullDiagnostics": false,
|
||||
"python.analysis.importFormat": "relative",
|
||||
}
|
5
dataset/.gitignore
vendored
Normal file
5
dataset/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
derek-thomas*
|
||||
*.lock
|
||||
speechcolab*
|
||||
lmms-lab*
|
||||
downloads/*
|
2
dataset/OCR-VQA-200K/.gitignore
vendored
Normal file
2
dataset/OCR-VQA-200K/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
images/*
|
||||
dataset.json
|
51
dataset/OCR-VQA-200K/download.py
Normal file
51
dataset/OCR-VQA-200K/download.py
Normal file
@ -0,0 +1,51 @@
|
||||
import os
|
||||
import json
|
||||
import urllib.request as ureq
|
||||
import urllib.error
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
# Set the file paths for your Google Drive
|
||||
dataset_path = "./dataset.json"
|
||||
images_path = "./images"
|
||||
download = 1 # Set to 0 if images are already downloaded
|
||||
|
||||
# Load dataset json file
|
||||
with open(dataset_path, "r") as fp:
|
||||
data = json.load(fp)
|
||||
|
||||
# Initialize a counter and a lock for thread-safe counting
|
||||
downloaded_count = 0
|
||||
count_lock = threading.Lock()
|
||||
|
||||
|
||||
# Function to download an image
|
||||
def download_image(k):
|
||||
global downloaded_count
|
||||
imageURL = data[k]["imageURL"]
|
||||
ext = os.path.splitext(imageURL)[1]
|
||||
outputFile = os.path.join(images_path, f"{k}{ext}")
|
||||
|
||||
# Only download the image if it doesn't exist
|
||||
if not os.path.exists(outputFile):
|
||||
try:
|
||||
ureq.urlretrieve(imageURL, outputFile)
|
||||
|
||||
with count_lock:
|
||||
downloaded_count += 1
|
||||
if downloaded_count % 100 == 0:
|
||||
print(f"{downloaded_count} images downloaded.")
|
||||
except urllib.error.URLError as e:
|
||||
print(f"Error downloading {outputFile}: {e}")
|
||||
|
||||
|
||||
# Download images using multiple threads
|
||||
if download == 1:
|
||||
if not os.path.exists(images_path):
|
||||
os.makedirs(images_path)
|
||||
|
||||
# Create a thread pool and download the images in parallel
|
||||
# Increase max_workers to potentially speed up downloads for many small files.
|
||||
# The optimal number may vary based on your network and the server's capacity.
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=400) as executor:
|
||||
executor.map(download_image, data.keys())
|
5
dataset/TextVQA/.gitignore
vendored
Normal file
5
dataset/TextVQA/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
images/test_images/*
|
||||
images/train_images/*
|
||||
TextVQA_0.5.1_test.json
|
||||
TextVQA_0.5.1_train.json
|
||||
TextVQA_0.5.1_val.json
|
3
dataset/vizwiz/Annotations/.gitignore
vendored
Normal file
3
dataset/vizwiz/Annotations/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
train.json
|
||||
test.json
|
||||
val.json
|
3
dataset/vizwiz/images/.gitignore
vendored
Normal file
3
dataset/vizwiz/images/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
val/*
|
||||
train/*
|
||||
test/*
|
@ -2,11 +2,14 @@
|
||||
dependencies = [
|
||||
"absl-py>=2.1.0",
|
||||
"accelerate==1.2.1",
|
||||
"calflops>=0.3.2",
|
||||
"datasets==3.2.0",
|
||||
"deepspeed==0.16.2",
|
||||
"evaluate==0.4.3",
|
||||
"huggingface-hub==0.30.1",
|
||||
"librosa>=0.10.2.post1",
|
||||
"markupsafe==2.1.5",
|
||||
"ms-swift>=1.3.0",
|
||||
"nltk>=3.9.1",
|
||||
"numba>=0.60.0",
|
||||
"peft==0.14.0",
|
||||
@ -18,11 +21,12 @@ dependencies = [
|
||||
"safetensors>=0.5.2",
|
||||
"setuptools>=70.0.0",
|
||||
"soundfile>=0.13.0",
|
||||
"torch==2.5.1+cu124",
|
||||
"torchaudio==2.5.1+cu124",
|
||||
"torchvision==0.20.1+cu124",
|
||||
"torch==2.6.0",
|
||||
"torchaudio==2.6.0",
|
||||
"torchvision==0.21.0",
|
||||
"transformers==4.48.0",
|
||||
"trl==0.13.0",
|
||||
"wandb>=0.19.4",
|
||||
"wheel>=0.45.1",
|
||||
]
|
||||
description = "Add your description here"
|
||||
@ -53,14 +57,14 @@ concurrent-builds = 4
|
||||
name = "pytorch"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["ms-swift"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
exclude = "transformers_repo|peft_repo|.venv"
|
||||
target-version = ["py311"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = ["--color=yes", "--durations=0", "-v", "--capture=tee-sys"]
|
||||
norecursedirs = [
|
||||
"src/transformers_repo",
|
||||
"src/peft_repo",
|
||||
".venv",
|
||||
]
|
||||
norecursedirs = ["src/transformers_repo", "src/peft_repo", ".venv"]
|
||||
|
5
src/.gitignore
vendored
5
src/.gitignore
vendored
@ -1 +1,4 @@
|
||||
checkpoint/*
|
||||
checkpoint/*
|
||||
wandb/*
|
||||
test.py
|
||||
results/
|
@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_accumulation_steps: 2
|
||||
zero3_init_flag: false
|
||||
zero_stage: 1
|
||||
distributed_type: DEEPSPEED
|
||||
@ -11,7 +11,7 @@ machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
num_processes: 3
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
@ -12,7 +12,7 @@ machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
@ -1,5 +1,6 @@
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from .format import Conversation, ConverstationText, ConverstationImage, DatasetOutput
|
||||
import json
|
||||
import os
|
||||
|
||||
@ -25,9 +26,9 @@ class ChemDataset(Dataset):
|
||||
|
||||
def _vis_processor(self, image: Image.Image):
|
||||
width, height = image.size
|
||||
if width > 800 or height > 800:
|
||||
if width > 500 or height > 500:
|
||||
max_size = max(width, height)
|
||||
ratio = 800 / max_size
|
||||
ratio = 500 / max_size
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
@ -85,31 +86,30 @@ class ChemDataset(Dataset):
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sample["system"],
|
||||
},
|
||||
Conversation(
|
||||
type="system",
|
||||
content=[ConverstationText(type="text", text=sample["system"])],
|
||||
),
|
||||
Conversation(
|
||||
type="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text", text=f"[vqa] {question.replace('<image>','')}"
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[vqa] {question.replace('<image>','')}",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
|
||||
),
|
||||
Conversation(
|
||||
type="assistant", content=[ConverstationText(type="text", text=answer)]
|
||||
),
|
||||
]
|
||||
return {
|
||||
"image": image,
|
||||
"chat": chat,
|
||||
}
|
||||
|
||||
return DatasetOutput(
|
||||
images=[image],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
)
|
||||
|
||||
|
||||
class ChemDatasetForGeneration(ChemDataset):
|
||||
@ -127,27 +127,20 @@ class ChemDatasetForGeneration(ChemDataset):
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sample["system"],
|
||||
},
|
||||
Conversation(
|
||||
type="system",
|
||||
content=[ConverstationText(type="text", text=sample["system"])],
|
||||
),
|
||||
Conversation(
|
||||
type="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text", text=f"[vqa] {question.replace('<image>','')}"
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[vqa] {question.replace('<image>','')}",
|
||||
},
|
||||
],
|
||||
},
|
||||
),
|
||||
]
|
||||
from .format import DatasetOutput
|
||||
|
||||
return DatasetOutput(
|
||||
images=[image],
|
||||
|
@ -18,8 +18,13 @@ class GigaspeechDataset(Dataset):
|
||||
|
||||
self.audio_processor = audio_processor
|
||||
self.text_processor = text_processor
|
||||
gs = load_dataset("speechcolab/gigaspeech", "xs")
|
||||
self.data = gs[split]
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = gs[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -51,10 +56,10 @@ class GigaspeechDataset(Dataset):
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
audio=[(audio, sampling_rate)],
|
||||
audios=[(audio, sampling_rate)],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||
@ -83,11 +88,11 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
audio=[(audio, sampling_rate)],
|
||||
audios=[(audio, sampling_rate)],
|
||||
chat=chat,
|
||||
answer=text,
|
||||
original=sample,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_gigaspeech():
|
||||
@ -103,3 +108,21 @@ def test_gigaspeech():
|
||||
print(dataset[0])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
from .factory import register_dataset
|
||||
|
||||
dataset = {
|
||||
"train": GigaspeechDataset(split="train"),
|
||||
"test": GigaspeechDataset(split="test"),
|
||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
register_dataset(
|
||||
dataset_name="gigaspeech",
|
||||
dataset=dataset,
|
||||
tag=["audio", "text"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gigaspeech()
|
||||
|
@ -22,14 +22,20 @@ class OCRVQADataset(Dataset):
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
|
||||
from .vis_processor import size_processor
|
||||
|
||||
self.vis_processor = (
|
||||
vis_processor if vis_processor is not None else self._vis_processor
|
||||
vis_processor if vis_processor is not None else size_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
|
||||
elif split == "test":
|
||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
|
||||
self.split = split
|
||||
self.ann_path = ann_path
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
@ -64,24 +70,6 @@ class OCRVQADataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def _vis_processor(self, image: Image.Image):
|
||||
width, height = image.size
|
||||
if width > 500 or height > 500:
|
||||
max_size = max(width, height)
|
||||
ratio = 500 / max_size
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
if width < 28 or height < 28:
|
||||
min_size = min(width, height)
|
||||
ratio = 28 / min_size + 1
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
return image
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image: Image.Image = Image.open(
|
||||
@ -117,7 +105,7 @@ class OCRVQADataset(Dataset):
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
@ -153,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
from .factory import register_dataset
|
||||
from .format import dataset_dir as base_path
|
||||
|
||||
dataset = {
|
||||
"train": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="train",
|
||||
),
|
||||
"test": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
"generation": OCRVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
|
||||
|
139
src/dataset_library/RefCOCODataset.py
Normal file
139
src/dataset_library/RefCOCODataset.py
Normal file
@ -0,0 +1,139 @@
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class RefCOCODataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
vis_processor=None,
|
||||
text_processor=None,
|
||||
split: Literal["val", "test"] = "val",
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=question,
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant",
|
||||
content=[ConverstationText(type="text", text=answer)],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class RefCOCODatasetForGeneration(RefCOCODataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"{question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_RefCOCO():
|
||||
dataset = RefCOCODataset(
|
||||
split="val",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
dataset = RefCOCODatasetForGeneration(
|
||||
split="test",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCODataset(split="val"),
|
||||
"test": RefCOCODataset(split="test"),
|
||||
"generation": RefCOCODatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="refcoco",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCO()
|
139
src/dataset_library/RefCOCOPlusDataset.py
Normal file
139
src/dataset_library/RefCOCOPlusDataset.py
Normal file
@ -0,0 +1,139 @@
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class RefCOCOplusDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
vis_processor=None,
|
||||
text_processor=None,
|
||||
split: Literal["val", "testA"] = "val",
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=question,
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant",
|
||||
content=[ConverstationText(type="text", text=answer)],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class RefCOCOplusDatasetForGeneration(RefCOCOplusDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"{question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_RefCOCOplus():
|
||||
dataset = RefCOCOplusDataset(
|
||||
split="val",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
dataset = RefCOCOplusDatasetForGeneration(
|
||||
split="testA",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOplusDataset(split="val"),
|
||||
"test": RefCOCOplusDataset(split="testA"),
|
||||
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="refcocoplus",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCOplus()
|
139
src/dataset_library/RefCOCOgDataset.py
Normal file
139
src/dataset_library/RefCOCOgDataset.py
Normal file
@ -0,0 +1,139 @@
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class RefCOCOgDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
vis_processor=None,
|
||||
text_processor=None,
|
||||
split: Literal["val", "test"] = "val",
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=question,
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant",
|
||||
content=[ConverstationText(type="text", text=answer)],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"{question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_RefCOCOg():
|
||||
dataset = RefCOCOgDataset(
|
||||
split="val",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
dataset = RefCOCOgDatasetForGeneration(
|
||||
split="test",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOgDataset(split="val"),
|
||||
"test": RefCOCOgDataset(split="test"),
|
||||
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="refcocog",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCOg()
|
139
src/dataset_library/ScienceQADataset.py
Normal file
139
src/dataset_library/ScienceQADataset.py
Normal file
@ -0,0 +1,139 @@
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
|
||||
class ScienceQADataset(Dataset):
|
||||
def __init__(self, vis_processor=None, text_processor=None, split="train"):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
# {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=750x429 at 0x71B9ACD6EF50>, 'question': 'Which of these states is farthest north?', 'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'], 'answer': 0, 'hint': '', 'task': 'closed choice', 'grade': 'grade2', 'subject': 'social science', 'topic': 'geography', 'category': 'Geography', 'skill': 'Read a map: cardinal directions', 'lecture': 'Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.\nA compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.\nThe north arrow points to the North Pole. On most maps, north is at the top of the map.', 'solution': 'To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.'}
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
choices = sample["choices"]
|
||||
task = sample["task"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"[{task}] '{question}' choose from '{choices}'",
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant",
|
||||
content=[ConverstationText(type="text", text=choices[answer])],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class ScienceQADatasetForGeneration(ScienceQADataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
# {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=750x429 at 0x71B9ACD6EF50>, 'question': 'Which of these states is farthest north?', 'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'], 'answer': 0, 'hint': '', 'task': 'closed choice', 'grade': 'grade2', 'subject': 'social science', 'topic': 'geography', 'category': 'Geography', 'skill': 'Read a map: cardinal directions', 'lecture': 'Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.\nA compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.\nThe north arrow points to the North Pole. On most maps, north is at the top of the map.', 'solution': 'To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.'}
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
choices = sample["choices"]
|
||||
task = sample["task"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"[{task}] '{question}' choose from '{choices}'",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
answer=choices[answer],
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_scienceQA():
|
||||
dataset = ScienceQADataset(
|
||||
split="train",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
dataset = ScienceQADatasetForGeneration(
|
||||
split="train",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": ScienceQADataset(split="train"),
|
||||
"test": ScienceQADataset(split="test"),
|
||||
"generation": ScienceQADatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="scienceqa",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_scienceQA()
|
@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
|
||||
vis_processor if vis_processor is not None else self._vis_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.split = split
|
||||
self.ann_path = ann_path
|
||||
self.vis_root = vis_root
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(
|
||||
Path(ann_path, "TextVQA_0.5.1_train.json"),
|
||||
vis_root=Path(vis_root, "train_images"),
|
||||
Path(self.ann_path, "TextVQA_0.5.1_train.json"),
|
||||
vis_root=Path(self.vis_root, "train_images"),
|
||||
)
|
||||
elif split == "test":
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(
|
||||
Path(ann_path, "TextVQA_0.5.1_val.json"),
|
||||
vis_root=Path(vis_root, "train_images"),
|
||||
Path(self.ann_path, "TextVQA_0.5.1_val.json"),
|
||||
vis_root=Path(self.vis_root, "train_images"),
|
||||
)
|
||||
|
||||
# self.instruction_pool = [
|
||||
@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class TextVQADatasetForGeneration(TextVQADataset):
|
||||
@ -158,16 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_dataset():
|
||||
vis_root = "/home/zyy/dataset/TextVQA/images"
|
||||
ann_path = "/home/zyy/dataset/TextVQA"
|
||||
dataset = TextVQADataset(vis_root, ann_path)
|
||||
for i in range(10):
|
||||
print(dataset[i])
|
||||
from .format import dataset_dir
|
||||
|
||||
dataset = {
|
||||
"train": TextVQADataset(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="train",
|
||||
),
|
||||
"test": TextVQADataset(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
"generation": TextVQADatasetForGeneration(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dataset()
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset("textvqa", dataset, tag=["text", "image"])
|
||||
|
168
src/dataset_library/VizWizDataset.py
Normal file
168
src/dataset_library/VizWizDataset.py
Normal file
@ -0,0 +1,168 @@
|
||||
from PIL import Image
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class VizWizDataset(Dataset):
|
||||
def __init__(
|
||||
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
self.ann_path = ann_path
|
||||
|
||||
from .vis_processor import size_processor
|
||||
|
||||
self.vis_processor = (
|
||||
vis_processor if vis_processor is not None else size_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(Path(self.ann_path, "train.json"))
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(Path(self.ann_path, "val.json"))
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
||||
# ]
|
||||
|
||||
def create_data(self, ann_path):
|
||||
processed_data = []
|
||||
with open(ann_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for i in range(len(data)):
|
||||
if (
|
||||
os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
|
||||
and data[i]["answerable"]
|
||||
):
|
||||
imageFile = data[i]["image"]
|
||||
processed_data.append(
|
||||
{
|
||||
"question": data[i]["question"],
|
||||
"answer": data[i]["answers"][0]["answer"],
|
||||
"image_path": imageFile,
|
||||
"original": data[i],
|
||||
}
|
||||
)
|
||||
return processed_data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image: Image.Image = Image.open(
|
||||
os.path.join(self.vis_root, sample["image_path"])
|
||||
).convert("RGB")
|
||||
# resize image
|
||||
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
if self.vis_processor is not None:
|
||||
image = self.vis_processor(image)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class VizWizDatasetForGeneration(VizWizDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||
"RGB"
|
||||
)
|
||||
# resize image
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
if self.vis_processor is not None:
|
||||
image = self.vis_processor(image)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
return DatasetOutput(
|
||||
images=[image],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
) # type: ignore
|
||||
|
||||
|
||||
from .format import dataset_dir
|
||||
|
||||
dataset = {
|
||||
"train": VizWizDataset(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="train",
|
||||
),
|
||||
"test": VizWizDataset(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
"generation": VizWizDatasetForGeneration(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="vizwiz",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
@ -0,0 +1 @@
|
||||
from .factory import get_dataset
|
@ -1,81 +1,78 @@
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Literal
|
||||
from typing import Literal, List
|
||||
from pathlib import Path
|
||||
from dataset_library.format import dataset_dir
|
||||
|
||||
MAPPING_NAME_TO_DATASET: dict[
|
||||
str, dict[Literal["train", "test", "generation"], Dataset]
|
||||
] = {}
|
||||
|
||||
|
||||
def register_dataset(
|
||||
dataset_name: str,
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset],
|
||||
tag: List[Literal["image", "text", "audio", "video"]] = [],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Register a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (`str`):
|
||||
The name of the dataset.
|
||||
dataset (`Dataset`):
|
||||
The dataset to register.
|
||||
"""
|
||||
dataset_name = dataset_name.lower()
|
||||
if dataset_name in MAPPING_NAME_TO_DATASET:
|
||||
raise ValueError(f"Dataset {dataset_name} already registered.")
|
||||
MAPPING_NAME_TO_DATASET[dataset_name] = dataset
|
||||
|
||||
|
||||
from .GigaspeechDataset import (
|
||||
GigaspeechDataset,
|
||||
GigaspeechDatasetForGeneration,
|
||||
)
|
||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||
from .ScienceQADataset import (
|
||||
ScienceQADataset,
|
||||
ScienceQADatasetForGeneration,
|
||||
)
|
||||
from .RefCOCODataset import (
|
||||
RefCOCODataset,
|
||||
RefCOCODatasetForGeneration,
|
||||
)
|
||||
from .RefCOCOgDataset import (
|
||||
RefCOCOgDataset,
|
||||
RefCOCOgDatasetForGeneration,
|
||||
)
|
||||
from .RefCOCOPlusDataset import (
|
||||
RefCOCOplusDataset,
|
||||
RefCOCOplusDatasetForGeneration,
|
||||
)
|
||||
from .VizWizDataset import (
|
||||
VizWizDataset,
|
||||
VizWizDatasetForGeneration,
|
||||
)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
dataset_name, base_path="/home/zyy/dataset"
|
||||
dataset_name: str, base_path=dataset_dir
|
||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||
if dataset_name == "ocrvqa200k":
|
||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
"""
|
||||
Get a dataset by name.
|
||||
|
||||
dataset = {
|
||||
"train": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="train",
|
||||
),
|
||||
"test": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
"generation": OCRVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
if dataset_name == "chem":
|
||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": ChemDataset(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="train",
|
||||
),
|
||||
"test": ChemDataset(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="test",
|
||||
),
|
||||
"generation": ChemDatasetForGeneration(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
if dataset_name == "gigaspeech":
|
||||
from .GigaspeechDataset import GigaspeechDataset, GigaspeechDatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": GigaspeechDataset(split="train"),
|
||||
"test": GigaspeechDataset(split="test"),
|
||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
if dataset_name == "textvqa":
|
||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": TextVQADataset(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="train",
|
||||
),
|
||||
"test": TextVQADataset(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
"generation": TextVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
return dataset
|
||||
Args:
|
||||
dataset_name (`str`):
|
||||
The name of the dataset.
|
||||
base_path (`str`):
|
||||
The base path to the dataset.
|
||||
"""
|
||||
dataset_name = dataset_name.lower()
|
||||
if dataset_name not in MAPPING_NAME_TO_DATASET:
|
||||
raise ValueError(f"Dataset {dataset_name} not registered.")
|
||||
for key in MAPPING_NAME_TO_DATASET[dataset_name]:
|
||||
MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
|
||||
return MAPPING_NAME_TO_DATASET[dataset_name]
|
||||
|
@ -1,6 +1,9 @@
|
||||
from typing import Any, Tuple, TypedDict, Literal, Optional
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
dataset_dir = Path(__file__).resolve().parent.parent.parent / "dataset"
|
||||
|
||||
|
||||
class ConverstationText(TypedDict):
|
||||
@ -25,8 +28,8 @@ class Conversation(TypedDict):
|
||||
|
||||
|
||||
class DatasetOutput(TypedDict):
|
||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
||||
chat: list[Conversation]
|
||||
answer: Optional[str]
|
||||
original: Any
|
||||
images: Optional[list[Image.Image]]
|
||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
||||
|
@ -1,37 +1,22 @@
|
||||
from .factory import get_dataset
|
||||
import pytest
|
||||
from .factory import get_dataset, MAPPING_NAME_TO_DATASET
|
||||
|
||||
# Get all registered dataset names for parameterization
|
||||
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
|
||||
|
||||
|
||||
def test_gigaspeech():
|
||||
dataset = get_dataset("gigaspeech")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
@pytest.mark.parametrize("dataset_name", dataset_names)
|
||||
def test_registered_datasets(dataset_name):
|
||||
dataset = get_dataset(dataset_name)
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
# Test train split
|
||||
assert "train" in dataset, f"Train split not found in {dataset_name}"
|
||||
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
|
||||
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
|
||||
|
||||
|
||||
def test_chem():
|
||||
dataset = get_dataset("chem")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_ocrvqa200k():
|
||||
dataset = get_dataset("ocrvqa200k")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_textvqa():
|
||||
dataset = get_dataset("textvqa")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
# Test test split
|
||||
assert "test" in dataset, f"Test split not found in {dataset_name}"
|
||||
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
|
||||
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
|
||||
|
20
src/dataset_library/vis_processor.py
Normal file
20
src/dataset_library/vis_processor.py
Normal file
@ -0,0 +1,20 @@
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def size_processor(image: Image.Image):
|
||||
width, height = image.size
|
||||
if width > 500 or height > 500:
|
||||
max_size = max(width, height)
|
||||
ratio = 500 / max_size
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
if width < 28 or height < 28:
|
||||
min_size = min(width, height)
|
||||
ratio = 28 / min_size + 1
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
return image
|
@ -19,67 +19,39 @@ from trl import (
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
from utils.args import ContinualScriptArguments, ContinualModelConfig
|
||||
from utils.args import (
|
||||
ContinualScriptArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser(
|
||||
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
|
||||
(
|
||||
ContinualScriptArguments,
|
||||
TrainingArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
)
|
||||
)
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
|
||||
# for type hint
|
||||
if 0 == 1:
|
||||
if TYPE_CHECKING:
|
||||
script_args = ContinualScriptArguments()
|
||||
training_args = TrainingArguments()
|
||||
model_args = ModelConfig()
|
||||
model_args = ContinualModelConfig()
|
||||
reg_args = ContinualRegularizationArguments()
|
||||
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
from model_library.factory import get_model
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
from transformers import (
|
||||
Qwen2VLProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModel,
|
||||
)
|
||||
from peft.peft_model import PeftModelForCausalLM
|
||||
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# from peft_library import get_peft_model
|
||||
|
||||
processor = Qwen2VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
from model_library.qwen2vl import (
|
||||
collate_fn_for_train,
|
||||
collate_fn_for_evaluate,
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
||||
model_args=model_args, training_args=training_args
|
||||
)
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
@ -88,8 +60,7 @@ if __name__ == "__main__":
|
||||
|
||||
accelerator = create_accelerator_and_postprocess(training_args)
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
print(model)
|
||||
accelerator.print(model)
|
||||
|
||||
for dataset_name in script_args.dataset_name:
|
||||
dataset = get_dataset(dataset_name)
|
||||
@ -99,12 +70,15 @@ if __name__ == "__main__":
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
bs = 3 if dataset_name not in ["scienceqa"] else 1
|
||||
accelerator.print("Batch size:", bs)
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
dataset[script_args.dataset_generation_split],
|
||||
batch_size=3,
|
||||
batch_size=bs,
|
||||
collate_fn=collate_fn_for_evaluate,
|
||||
)
|
||||
val_dataloader = accelerator.prepare_data_loader(val_dataloader)
|
||||
from utils.evaluate_tool import evaluate_rouge, evalute_save
|
||||
from utils.evaluate_tool import evaluate_rouge, evaluate_save
|
||||
|
||||
evalute_save(model, val_dataloader, processor, accelerator)
|
||||
evaluate_save(model, val_dataloader, processor, accelerator)
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
|
||||
--dataset_name CHEM \
|
||||
--dataset_name chem \
|
||||
--use_peft \
|
||||
--peft_type MMOELORA \
|
||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||
|
@ -5,9 +5,12 @@ from trl import (
|
||||
get_quantization_config,
|
||||
)
|
||||
from utils.args import ContinualModelConfig
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
def get_model(model_args: ContinualModelConfig):
|
||||
def get_model(
|
||||
model_args: ContinualModelConfig, training_args: TrainingArguments = None
|
||||
):
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
@ -26,12 +29,20 @@ def get_model(model_args: ContinualModelConfig):
|
||||
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
if training_args is not None:
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
processor = Qwen2VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
@ -49,11 +60,18 @@ def get_model(model_args: ContinualModelConfig):
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
|
||||
from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
|
||||
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
if training_args is not None:
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
processor = Qwen2AudioProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
@ -68,4 +86,63 @@ def get_model(model_args: ContinualModelConfig):
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
||||
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
if training_args is not None:
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
|
||||
from transformers.models.qwen2_5_omni import (
|
||||
Qwen2_5OmniThinkerForConditionalGeneration,
|
||||
Qwen2_5OmniProcessor,
|
||||
)
|
||||
|
||||
if training_args is not None:
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
||||
|
@ -1,4 +1,8 @@
|
||||
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
||||
from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]
|
||||
__all__ = [
|
||||
"collate_fn_for_train",
|
||||
"collate_fn_for_evaluate",
|
||||
"Qwen2VLForConditionalGeneration_modified",
|
||||
]
|
||||
|
@ -1,20 +1,21 @@
|
||||
from transformers import Qwen2AudioProcessor
|
||||
from dataset_library.format import Conversation
|
||||
|
||||
|
||||
def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
|
||||
def collate_fn_for_train(examples:list[Conversation], processor: Qwen2AudioProcessor):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||
for example in examples
|
||||
]
|
||||
audios = [example["audio"][0] for example in examples]
|
||||
audios = [example["audios"][0] for example in examples]
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(
|
||||
text=texts,
|
||||
audios=audios,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
sampling_rate=examples[0]["audio"][1],
|
||||
sampling_rate=examples[0]["audios"][1],
|
||||
)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
@ -60,7 +61,7 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
|
||||
return batch
|
||||
|
||||
|
||||
def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
||||
def collate_fn_for_evaluate(examples:list[Conversation], processor: Qwen2AudioProcessor):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
@ -69,10 +70,16 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
||||
for example in examples
|
||||
]
|
||||
# print(texts)
|
||||
audios = [example["audio"] for example in examples]
|
||||
audios = [example["audios"] for example in examples]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, audios=audios, return_tensors="pt", padding=True)
|
||||
batch = processor(
|
||||
text=texts,
|
||||
audios=audios,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
sampling_rate=examples[0]["audios"][1],
|
||||
)
|
||||
|
||||
answers = [example["answer"] for example in examples]
|
||||
answers = processor(text=answers, return_tensors="pt", padding=True)
|
||||
@ -85,3 +92,4 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
||||
# answers_ids torch.Size([3, 10])
|
||||
# answers_mask torch.Size([3, 10])
|
||||
return batch
|
||||
|
||||
|
@ -1,4 +1,9 @@
|
||||
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
||||
from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]
|
||||
# from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
__all__ = [
|
||||
"collate_fn_for_train",
|
||||
"collate_fn_for_evaluate",
|
||||
"Qwen2VLForConditionalGeneration_modified",
|
||||
]
|
||||
|
@ -1,16 +1,33 @@
|
||||
from transformers import Qwen2VLProcessor
|
||||
# from transformers import Qwen2VLProcessor
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "transformers_repo/src/")
|
||||
sys.path.insert(0, "peft_repo/src/")
|
||||
import transformers
|
||||
import peft
|
||||
from dataset_library.format import DatasetOutput
|
||||
|
||||
import torch
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Qwen2VLProcessor
|
||||
from transformers import Qwen2_5_VLProcessor
|
||||
|
||||
|
||||
def collate_fn_for_train(examples: list[DatasetOutput], processor: "Qwen2VLProcessor"):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||
for example in examples
|
||||
]
|
||||
# print(texts)
|
||||
images = [example["images"] for example in examples]
|
||||
images = [
|
||||
example["images"] for example in examples if example["images"][0] is not None
|
||||
]
|
||||
images = images if len(images) > 0 else None
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
|
||||
@ -57,7 +74,7 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
|
||||
return batch
|
||||
|
||||
|
||||
def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
||||
def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
@ -66,7 +83,10 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
||||
for example in examples
|
||||
]
|
||||
# print(texts)
|
||||
images = [example["images"] for example in examples]
|
||||
images = [
|
||||
example["images"] for example in examples if example["images"][0] is not None
|
||||
]
|
||||
images = images if len(images) > 0 else None
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
@ -74,7 +94,6 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
||||
answers = [example["answer"] for example in examples]
|
||||
answers = processor(text=answers, return_tensors="pt", padding=True)
|
||||
batch["answers_ids"] = answers["input_ids"]
|
||||
batch["answers_mask"] = answers["attention_mask"]
|
||||
batch["original_data"] = [example["original"] for example in examples]
|
||||
# input_ids torch.Size([3, 370])
|
||||
# attention_mask torch.Size([3, 370])
|
||||
@ -83,3 +102,68 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
||||
# answers_ids torch.Size([3, 10])
|
||||
# answers_mask torch.Size([3, 10])
|
||||
return batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
||||
from PIL import Image
|
||||
|
||||
# 随机生成一个图片
|
||||
import numpy as np
|
||||
|
||||
random_image = Image.fromarray(
|
||||
np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
)
|
||||
example = {
|
||||
"chat": [
|
||||
# {"role": "user", "content": "What is the capital of France?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "", # Assuming no image for this example
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "", # Assuming no image for this example
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "", # Assuming no image for this example
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is the capital of France?",
|
||||
},
|
||||
],
|
||||
}, # Assuming no image for this example
|
||||
{"role": "assistant", "content": "The capital of France is Paris."},
|
||||
],
|
||||
"images": [
|
||||
random_image,
|
||||
random_image,
|
||||
random_image,
|
||||
], # Assuming no images for this example
|
||||
}
|
||||
batch = collate_fn_for_train([example], processor)
|
||||
# print(batch)
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
print(f"{k}: {v.shape}")
|
||||
else:
|
||||
print(f"{k}: {v}")
|
||||
# input_ids: torch.Size([1, 101])
|
||||
# attention_mask: torch.Size([1, 101])
|
||||
# pixel_values: torch.Size([256, 1176])
|
||||
# image_grid_thw: torch.Size([1, 3])
|
||||
# labels: torch.Size([1, 101])
|
||||
# Load model directly
|
||||
from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniProcessor
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-3B", torch_dtype="auto", device_map="auto")
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")
|
||||
|
||||
|
||||
|
@ -73,7 +73,6 @@ from peft.tuners import (
|
||||
from .tuners import MMOELoraModel, MMOELoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTuner
|
||||
from peft.utils import _prepare_prompt_learning_config
|
||||
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -46,7 +46,7 @@ from transformers.modeling_outputs import (
|
||||
)
|
||||
from transformers.utils import PushToHubMixin
|
||||
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG
|
||||
|
||||
from peft import __version__
|
||||
from peft.config import PeftConfig
|
||||
|
19
src/peft_library/regularizations/__init__.py
Normal file
19
src/peft_library/regularizations/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
class RegularizationMethod:
|
||||
"""RegularizationMethod implement regularization strategies.
|
||||
RegularizationMethod is a callable.
|
||||
The method `update` is called to update the loss, typically at the end
|
||||
of an experience.
|
||||
"""
|
||||
|
||||
def pre_adapt(self, agent, exp):
|
||||
pass # implementation may be empty if adapt is not needed
|
||||
|
||||
def post_adapt(self, agent, exp):
|
||||
pass # implementation may be empty if adapt is not needed
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
from .ewc import EWC
|
||||
from .lwf import LWF
|
58
src/peft_library/regularizations/ewc.py
Normal file
58
src/peft_library/regularizations/ewc.py
Normal file
@ -0,0 +1,58 @@
|
||||
from . import RegularizationMethod
|
||||
import torch
|
||||
|
||||
|
||||
class EWC(RegularizationMethod):
|
||||
"""Learning Without Forgetting.
|
||||
|
||||
The method applies knowledge distilllation to mitigate forgetting.
|
||||
The teacher is the model checkpoint after the last experience.
|
||||
"""
|
||||
|
||||
def __init__(self, EWC_lambda=1, temperature=2):
|
||||
"""
|
||||
:param alpha: distillation hyperparameter. It can be either a float
|
||||
number or a list containing alpha for each experience.
|
||||
:param temperature: softmax temperature for distillation
|
||||
"""
|
||||
self.EWC_lambda = EWC_lambda
|
||||
self.temperature = temperature
|
||||
self.fisher = {}
|
||||
self.optpar = {}
|
||||
""" In Avalanche, targets of different experiences are not ordered.
|
||||
As a result, some units may be allocated even though their
|
||||
corresponding class has never been seen by the model.
|
||||
Knowledge distillation uses only units corresponding
|
||||
to old classes.
|
||||
"""
|
||||
|
||||
def adapt(self, output, model, **kwargs):
|
||||
ewc_loss = 0
|
||||
for n, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
dev = p.device
|
||||
l = (
|
||||
self.EWC_lambda
|
||||
* self.fisher[n].to(dev)
|
||||
* (p.data - self.optpar[n].to(dev)).pow(2)
|
||||
)
|
||||
ewc_loss += l.sum()
|
||||
output["loss"] += ewc_loss
|
||||
return output
|
||||
|
||||
def init_epoch(self, model):
|
||||
"""Update the previous logits for the given question id."""
|
||||
optpar = {}
|
||||
fisher = {}
|
||||
for n, p in model.module.base_model.model.named_parameters():
|
||||
if p.requires_grad:
|
||||
fisher[n] = torch.zeros(p.data.shape)
|
||||
optpar[n] = p.clone().cpu().data
|
||||
|
||||
def update_fisher(self, model):
|
||||
"""Update the fisher information for the given question id."""
|
||||
for n, p in model.module.base_model.model.named_parameters():
|
||||
if p.requires_grad:
|
||||
fisher = self.fisher[n]
|
||||
fisher += p.grad.data.pow(2).cpu()
|
||||
self.fisher[n] = fisher
|
67
src/peft_library/regularizations/lwf.py
Normal file
67
src/peft_library/regularizations/lwf.py
Normal file
@ -0,0 +1,67 @@
|
||||
from . import RegularizationMethod
|
||||
import torch
|
||||
|
||||
|
||||
class LWF(RegularizationMethod):
|
||||
"""Learning Without Forgetting.
|
||||
|
||||
The method applies knowledge distilllation to mitigate forgetting.
|
||||
The teacher is the model checkpoint after the last experience.
|
||||
"""
|
||||
|
||||
def __init__(self, LWF_lambda=1, temperature=2):
|
||||
"""
|
||||
:param alpha: distillation hyperparameter. It can be either a float
|
||||
number or a list containing alpha for each experience.
|
||||
:param temperature: softmax temperature for distillation
|
||||
"""
|
||||
self.LWF_lambda = LWF_lambda
|
||||
self.temperature = temperature
|
||||
self.previous_logits = {}
|
||||
""" In Avalanche, targets of different experiences are not ordered.
|
||||
As a result, some units may be allocated even though their
|
||||
corresponding class has never been seen by the model.
|
||||
Knowledge distillation uses only units corresponding
|
||||
to old classes.
|
||||
"""
|
||||
|
||||
def adapt(self, output, **kwargs):
|
||||
def modified_kl_div(old, new):
|
||||
return -torch.mean(torch.sum(old * torch.log(new), 1))
|
||||
|
||||
def smooth(logits, temp, dim):
|
||||
log = logits ** (1 / temp)
|
||||
return log / torch.sum(log, dim).unsqueeze(1)
|
||||
|
||||
lwf_loss = []
|
||||
|
||||
soft = torch.nn.Softmax(dim=1)
|
||||
|
||||
previous_keys = self.previous_logits.keys()
|
||||
|
||||
for index, question_id in enumerate(iterable=kwargs["question_ids"]):
|
||||
if question_id in previous_keys:
|
||||
previous_logits = self.previous_logits[question_id]
|
||||
current_logits = output["logits"][index]
|
||||
short_index = min(len(previous_logits), len(current_logits))
|
||||
previous_logits = previous_logits[:short_index]
|
||||
current_logits = current_logits[:short_index]
|
||||
lwf_loss.append(
|
||||
modified_kl_div(
|
||||
old=smooth(
|
||||
logits=soft(previous_logits).to(current_logits.device),
|
||||
temp=2,
|
||||
dim=1,
|
||||
),
|
||||
new=smooth(logits=soft(current_logits), temp=2, dim=1),
|
||||
)
|
||||
)
|
||||
if len(lwf_loss) > 0:
|
||||
output["loss"] += self.LWF_lambda * torch.stack(
|
||||
tensors=lwf_loss, dim=0
|
||||
).sum(dim=0)
|
||||
return output
|
||||
|
||||
def update_previous_logits(self, question_id, logits):
|
||||
"""Update the previous logits for the given question id."""
|
||||
self.previous_logits[question_id] = logits
|
@ -1 +1 @@
|
||||
Subproject commit 1a92c6fe39143de2ad4247d2ad8beeed2d3500e8
|
||||
Subproject commit a6e39fafb4e3ccce27671ce14afe62fb754c83c2
|
15
src/scripts/eval_omni.sh
Executable file
15
src/scripts/eval_omni.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml evaluation.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type MOELORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--output_dir ./checkpoint/qwen2_5omni_moelora/ \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16
|
||||
# --eval_strategy epoch \
|
25
src/scripts/train_omni_moelora.sh
Executable file
25
src/scripts/train_omni_moelora.sh
Executable file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type MOELORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 32 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir checkpoint/qwen2_5omni_moelora/ \
|
||||
--learning_rate 2e-4 \
|
||||
--warmup_ratio 0.03 \
|
||||
--lr_scheduler_type cosine \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--logging_steps 300 \
|
||||
--gradient_checkpointing \
|
||||
--weight_decay 0.1 \
|
||||
--eval_strategy steps \
|
||||
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500
|
25
src/scripts/train_omni_olora.sh
Executable file
25
src/scripts/train_omni_olora.sh
Executable file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type OLORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 32 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir checkpoint/qwen2_5omni_olora/ \
|
||||
--learning_rate 2e-4 \
|
||||
--warmup_ratio 0.03 \
|
||||
--lr_scheduler_type cosine \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--logging_steps 300 \
|
||||
--gradient_checkpointing \
|
||||
--weight_decay 0.1 \
|
||||
--eval_strategy steps \
|
||||
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500
|
31
src/test_evalutae.py
Normal file
31
src/test_evalutae.py
Normal file
@ -0,0 +1,31 @@
|
||||
import evaluate
|
||||
|
||||
# Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, and CIDEr
|
||||
example = {
|
||||
"generated": "The cat sat on the mat.",
|
||||
"target": "The cat is sitting on the mat.",
|
||||
"original": "The cat is sitting on the mat.",
|
||||
}
|
||||
evaluate_bleu = evaluate.load("bleu")
|
||||
evaluate_rouge = evaluate.load("rouge")
|
||||
evaluate_meteor = evaluate.load("meteor")
|
||||
|
||||
evaluate_bleu.add_batch(
|
||||
predictions=[example["generated"]],
|
||||
references=[[example["target"]]],
|
||||
)
|
||||
evaluate_rouge.add_batch(
|
||||
predictions=[example["generated"]],
|
||||
references=[[example["target"]]],
|
||||
)
|
||||
evaluate_meteor.add_batch(
|
||||
predictions=[example["generated"]],
|
||||
references=[[example["target"]]],
|
||||
)
|
||||
|
||||
bleu = evaluate_bleu.compute()
|
||||
rouge = evaluate_rouge.compute()
|
||||
meteor = evaluate_meteor.compute()
|
||||
|
||||
comprehensive_results = sum(bleu['precisions']) + rouge['rougeL'] + meteor['meteor']
|
||||
print("Comprehensive Results:", comprehensive_results/6)
|
35
src/todo.md
35
src/todo.md
@ -12,5 +12,36 @@
|
||||
[2025.01.03]
|
||||
|
||||
- [ ] 处理量化逻辑
|
||||
- [ ] 严查moelora的原始代码,太粗糙了😡
|
||||
- [ ] 未知原因trainer后处理时间长
|
||||
- [X] 严查moelora的原始代码,太粗糙了😡
|
||||
- [X] 未知原因trainer后处理时间长
|
||||
|
||||
[2025.01.19]
|
||||
|
||||
- [x] 多个数据集引入
|
||||
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
|
||||
- [ ] 引入EWC和LWF
|
||||
|
||||
[2025.05.15]
|
||||
|
||||
- [x] vizwiz处理
|
||||
|
||||
[2025.05.16]
|
||||
|
||||
- [ ] 处理不同的持续学习框架,使得整体框架能够兼容
|
||||
|
||||
[2025.05.28]
|
||||
|
||||
- [x] MoeLora
|
||||
- [ ] Coin Benchmark
|
||||
- [x] 确定保存什么,便于后期测试
|
||||
- [x] Olora (非实现问题,loss越来越高,感觉很难训练)
|
||||
- [ ] Hide-Llava(复写基类引入clip,不同的adapter做平均,loralinear根据不同的name做插入top layer或正常layer,模型要求接受传入task_id即clip计算的最大相似)
|
||||
- [ ] Hide-llava问题,前些层平均fusion很没有道理,后些层的moe处理,却整整引入了clip的计算量,(任务数确定task数量,使得一些方法没有扩展性)。现实场景要求:没法知道后面还有多少个数据集,然后减少遗忘,最好能够对后续未见数据集产生效果,moelora问题只能适当缓解,利用不同的参数承接不同的任务。 那这个benchmark,每次输入保留数据,baseline是进一个把之前所有的都训练一边,持续学习方法使用update的方式,比较不同数据集按批次输入的收益(找函数定义[How Efficient Are Today’s Continual Learning Algorithms?],[]),也就是准确度的积分。
|
||||
|
||||
[2025.05.30]
|
||||
|
||||
- [x] 评价指标
|
||||
|
||||
[2025.06.03]
|
||||
|
||||
- [ ] 预期算法,低计算成本,
|
||||
|
85
src/train.py
85
src/train.py
@ -13,30 +13,42 @@ from trl import (
|
||||
TrlParser,
|
||||
)
|
||||
from utils.trainer import ContinualTrainer
|
||||
from utils.args import ContinualScriptArguments, ContinualModelConfig
|
||||
from utils.args import (
|
||||
ContinualScriptArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
)
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser(
|
||||
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
|
||||
(
|
||||
ContinualScriptArguments,
|
||||
TrainingArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
) # type: ignore
|
||||
)
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
|
||||
# for type hint
|
||||
if 0 == 1:
|
||||
if TYPE_CHECKING:
|
||||
script_args = ContinualScriptArguments()
|
||||
training_args = TrainingArguments()
|
||||
model_args = ContinualModelConfig()
|
||||
reg_args = ContinualRegularizationArguments()
|
||||
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
from model_library.factory import get_model
|
||||
|
||||
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
||||
model_args
|
||||
model_args=model_args
|
||||
)
|
||||
################
|
||||
# Dataset
|
||||
@ -47,22 +59,49 @@ if __name__ == "__main__":
|
||||
accelerator = create_accelerator_and_postprocess(training_args)
|
||||
|
||||
if model_args.peft_type == "MMOELORA":
|
||||
from peft_library.tuners import MMOELoraConfig
|
||||
from peft.tuners import MMOELoraConfig
|
||||
|
||||
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
elif model_args.peft_type == "MOELORA":
|
||||
from peft.tuners import MOELoraConfig
|
||||
|
||||
peft_config = MOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
elif model_args.peft_type == "LORA":
|
||||
from peft.tuners.lora import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
|
||||
peft_config = LoraConfig(
|
||||
target_modules=model_args.lora_target_modules,
|
||||
r=model_args.lora_r,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
elif model_args.peft_type == "OLORA":
|
||||
from peft.tuners import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(
|
||||
target_modules=model_args.lora_target_modules,
|
||||
r=model_args.lora_r,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
init_lora_weights="olora"
|
||||
)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
else:
|
||||
peft_config = None
|
||||
|
||||
from peft import get_peft_model
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
print(model)
|
||||
|
||||
@ -74,36 +113,38 @@ if __name__ == "__main__":
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn_for_train,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
train_dataset=dataset[script_args.dataset_train_split], # type: ignore
|
||||
eval_dataset=(
|
||||
dataset[script_args.dataset_test_split]
|
||||
dataset[script_args.dataset_test_split] # type: ignore
|
||||
if training_args.eval_strategy != "no"
|
||||
else None
|
||||
),
|
||||
accelerator=accelerator,
|
||||
reg_args=reg_args,
|
||||
)
|
||||
trainer.train()
|
||||
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
print("Saving model")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
# trainer.save_model(training_args.output_dir)
|
||||
model.save_pretrained(training_args.output_dir)
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
print("Model saved")
|
||||
# 同步 accelerator
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
model.eval()
|
||||
# model.eval()
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
# from torch.utils.data import DataLoader
|
||||
|
||||
val_dataloader = DataLoader(
|
||||
dataset[script_args.dataset_generation_split],
|
||||
batch_size=3,
|
||||
collate_fn=collate_fn_for_evaluate,
|
||||
)
|
||||
val_dataloader = accelerator.prepare(val_dataloader)
|
||||
# val_dataloader = DataLoader(
|
||||
# dataset[script_args.dataset_generation_split],
|
||||
# batch_size=3,
|
||||
# collate_fn=collate_fn_for_evaluate,
|
||||
# )
|
||||
# val_dataloader = accelerator.prepare(val_dataloader)
|
||||
|
||||
from utils.evaluate_tool import evaluate_rouge
|
||||
# from utils.evaluate_tool import evaluate_save
|
||||
|
||||
evaluate_rouge(model, val_dataloader, processor)
|
||||
# evaluate_save(model, val_dataloader, processor, accelerator)
|
||||
|
24
src/train.sh
24
src/train.sh
@ -1,15 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
||||
--dataset_name gigaspeech \
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type LORA \
|
||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||
--lora_target_modules q_proj v_proj \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--peft_type MOELORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 32 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--output_dir checkpoint/qwen2mmoe/ \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir checkpoint/qwen2_alllinear/ \
|
||||
--learning_rate 5e-5 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--logging_steps 30
|
||||
--logging_steps 30 \
|
||||
--gradient_checkpointing \
|
||||
--weight_decay 0.1
|
@ -1 +1 @@
|
||||
Subproject commit 6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec
|
||||
Subproject commit 42a8639e1e827d6f0ab07d87078ff048b20dab19
|
@ -1,7 +1,10 @@
|
||||
from accelerate import Accelerator, DataLoaderConfiguration
|
||||
from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
|
||||
def create_accelerator_and_postprocess(args):
|
||||
def create_accelerator_and_postprocess(args: TrainingArguments):
|
||||
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
|
||||
grad_acc_kwargs = {}
|
||||
if args.accelerator_config.gradient_accumulation_kwargs is not None:
|
||||
@ -33,9 +36,7 @@ def create_accelerator_and_postprocess(args):
|
||||
# this would have been updated above, no need for it anymore
|
||||
accelerator_config.pop("gradient_accumulation_kwargs")
|
||||
|
||||
accelerator_args = {
|
||||
"deepspeed_plugin": args.deepspeed_plugin,
|
||||
}
|
||||
accelerator_args = {"deepspeed_plugin": args.deepspeed_plugin, "log_with": "wandb"}
|
||||
accelerator_args["dataloader_config"] = dataloader_config
|
||||
# create accelerator object
|
||||
accelerator = Accelerator(**accelerator_args)
|
||||
|
@ -18,3 +18,16 @@ class ContinualModelConfig(ModelConfig):
|
||||
"""Model configuration for continual learning."""
|
||||
|
||||
peft_type: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContinualRegularizationArguments:
|
||||
"""Regularization arguments for continual learning."""
|
||||
|
||||
# EWC
|
||||
ewc_lambda: float = 0.0
|
||||
ewc_enable: bool = False
|
||||
|
||||
# LWF
|
||||
lwf_lambda: float = 0.0
|
||||
lwf_enable: bool = False
|
||||
|
@ -1,5 +1,6 @@
|
||||
import evaluate
|
||||
from accelerate import Accelerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||
@ -7,10 +8,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
|
||||
|
||||
for batch in val_dataloader:
|
||||
completion = model.generate(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
pixel_values=batch["pixel_values"],
|
||||
image_grid_thw=batch["image_grid_thw"],
|
||||
**batch,
|
||||
max_length=1000,
|
||||
)
|
||||
target = batch["answers_ids"]
|
||||
@ -27,7 +25,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
|
||||
print(glue.compute())
|
||||
|
||||
|
||||
def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||
def evaluate_save(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||
import os
|
||||
|
||||
mtime = 0
|
||||
@ -38,8 +36,9 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
mtime = time
|
||||
|
||||
# 获取目录最后修改时间
|
||||
if not os.path.exists(f"results/{mtime}"):
|
||||
os.makedirs(f"results/{mtime}")
|
||||
if accelerator.is_local_main_process:
|
||||
if not os.path.exists(f"results/{mtime}"):
|
||||
os.makedirs(f"results/{mtime}")
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -47,15 +46,14 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
bar = tqdm(total=len(val_dataloader))
|
||||
|
||||
for batch in val_dataloader:
|
||||
target = batch.pop("answers_ids")
|
||||
origianl = batch.pop("original_data")
|
||||
answers = []
|
||||
completion = model.generate(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
pixel_values=batch["pixel_values"],
|
||||
image_grid_thw=batch["image_grid_thw"],
|
||||
**batch,
|
||||
# max_new_tokens=30,
|
||||
max_length=1000,
|
||||
)
|
||||
target = batch["answers_ids"]
|
||||
generated_text = [
|
||||
out_ids[len(in_ids) :]
|
||||
for out_ids, in_ids in zip(completion, batch["input_ids"])
|
||||
@ -64,21 +62,19 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
generated_text, skip_special_tokens=True
|
||||
)
|
||||
target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
|
||||
for i in range(len(generated_text)):
|
||||
answers.append(
|
||||
{
|
||||
"generated": generated_text[i],
|
||||
"target": target_text[i],
|
||||
"original": batch["original_data"][i],
|
||||
}
|
||||
)
|
||||
import json
|
||||
|
||||
world_size = accelerator.process_index
|
||||
|
||||
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
|
||||
for answer in answers:
|
||||
for i in range(len(generated_text)):
|
||||
answer = {
|
||||
"generated": generated_text[i],
|
||||
"target": target_text[i],
|
||||
"original": str(origianl[i]),
|
||||
}
|
||||
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
|
||||
f.write(json.dumps(answer) + "\n")
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
bar.update(1)
|
||||
accelerator.wait_for_everyone()
|
||||
@ -97,3 +93,71 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
# delete file
|
||||
for file in files:
|
||||
os.remove(f"results/{mtime}/{file}")
|
||||
|
||||
|
||||
def evaluate_from_jsonl_directory(directory_path):
|
||||
"""
|
||||
从指定目录读取所有jsonl文件并计算综合评估结果
|
||||
|
||||
Args:
|
||||
directory_path: 包含jsonl文件的目录路径
|
||||
|
||||
Returns:
|
||||
dict: 包含各项指标和综合结果的字典
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
|
||||
# 初始化评估器
|
||||
evaluate_bleu = evaluate.load("bleu")
|
||||
evaluate_rouge = evaluate.load("rouge")
|
||||
evaluate_meteor = evaluate.load("meteor")
|
||||
|
||||
# 读取目录下所有jsonl文件
|
||||
all_data = []
|
||||
for file in os.listdir(directory_path):
|
||||
if file.endswith(".jsonl"):
|
||||
file_path = os.path.join(directory_path, file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
data = json.loads(line)
|
||||
all_data.append(data)
|
||||
|
||||
if not all_data:
|
||||
print(f"未在目录 {directory_path} 中找到有效的jsonl数据")
|
||||
return None
|
||||
|
||||
# 准备数据
|
||||
predictions = [item["generated"] for item in all_data]
|
||||
references = [[item["target"]] for item in all_data]
|
||||
|
||||
# 批量添加数据
|
||||
evaluate_bleu.add_batch(predictions=predictions, references=references)
|
||||
evaluate_rouge.add_batch(predictions=predictions, references=references)
|
||||
evaluate_meteor.add_batch(predictions=predictions, references=references)
|
||||
|
||||
# 计算结果
|
||||
bleu = evaluate_bleu.compute()
|
||||
rouge = evaluate_rouge.compute()
|
||||
meteor = evaluate_meteor.compute()
|
||||
|
||||
# 计算综合结果
|
||||
comprehensive_score = (sum(bleu["precisions"]) + rouge["rougeL"] + meteor["meteor"]) / 6
|
||||
|
||||
results = {
|
||||
"bleu": bleu,
|
||||
"rouge": rouge,
|
||||
"meteor": meteor,
|
||||
"comprehensive_score": comprehensive_score,
|
||||
"total_samples": len(all_data),
|
||||
}
|
||||
|
||||
print(f"评估完成,共处理 {len(all_data)} 条数据")
|
||||
print(f"BLEU分数: {bleu}")
|
||||
print(f"ROUGE分数: {rouge}")
|
||||
print(f"METEOR分数: {meteor}")
|
||||
print(f"综合分数: {comprehensive_score}")
|
||||
|
||||
return results
|
||||
|
@ -2,16 +2,72 @@
|
||||
|
||||
|
||||
from transformers.trainer import *
|
||||
from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
from .args import ContinualRegularizationArguments
|
||||
from peft_library.regularizations import EWC, LWF
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
|
||||
def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
|
||||
logits = outputs.logits
|
||||
device = logits.device
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:].to(device)
|
||||
# Save memory
|
||||
masks = shift_labels != -100
|
||||
shift_logits = shift_logits[masks]
|
||||
shift_labels = shift_labels[masks]
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(reduction="none")
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if num_items_in_batch is None:
|
||||
loss = loss.mean()
|
||||
else:
|
||||
# compat transformers>=4.46
|
||||
loss = loss.sum() / num_items_in_batch
|
||||
return loss
|
||||
|
||||
|
||||
class ContinualTrainer(Trainer):
|
||||
def __init__(
|
||||
self, model, args, data_collator, train_dataset, eval_dataset, accelerator
|
||||
self,
|
||||
model,
|
||||
args: TrainingArguments,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
accelerator,
|
||||
reg_args: ContinualRegularizationArguments = None,
|
||||
):
|
||||
self.accelerator = accelerator
|
||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
# compute_loss_func=ce_loss_func,
|
||||
)
|
||||
|
||||
if reg_args.ewc_enable:
|
||||
self.ewc_lambda = reg_args.ewc_lambda
|
||||
from peft_library.regularizations.ewc import EWC
|
||||
|
||||
self.EWC = EWC()
|
||||
# fisher = t
|
||||
|
||||
if reg_args.lwf_enable:
|
||||
self.lwf_lambda = reg_args.lwf_lambda
|
||||
from peft_library.regularizations.lwf import LWF
|
||||
|
||||
self.LWF = LWF()
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
|
||||
if self.accelerator is not None:
|
||||
self.is_deepspeed_enabled = (
|
||||
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
@ -23,718 +79,79 @@ class ContinualTrainer(Trainer):
|
||||
return
|
||||
else:
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
|
||||
def create_optimizer(self):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
Setup the optimizer.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
||||
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
||||
"""
|
||||
if (
|
||||
self.label_smoother is not None or self.compute_loss_func is not None
|
||||
) and "labels" in inputs:
|
||||
labels = inputs.pop("labels")
|
||||
else:
|
||||
labels = None
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if labels is not None:
|
||||
unwrapped_model = self.accelerator.unwrap_model(model)
|
||||
if _is_peft_model(unwrapped_model):
|
||||
model_name = unwrapped_model.base_model.model._get_name()
|
||||
else:
|
||||
model_name = unwrapped_model._get_name()
|
||||
# User-defined compute_loss function
|
||||
if self.compute_loss_func is not None:
|
||||
loss = self.compute_loss_func(
|
||||
outputs, labels, num_items_in_batch=num_items_in_batch
|
||||
)
|
||||
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
||||
loss = self.label_smoother(outputs, labels, shift_labels=True)
|
||||
else:
|
||||
loss = self.label_smoother(outputs, labels)
|
||||
else:
|
||||
if isinstance(outputs, dict) and "loss" not in outputs:
|
||||
raise ValueError(
|
||||
"The model did not return a loss from the inputs, only the following keys: "
|
||||
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
||||
)
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _inner_training_loop(
|
||||
self,
|
||||
batch_size=None,
|
||||
args=None,
|
||||
resume_from_checkpoint=None,
|
||||
trial=None,
|
||||
ignore_keys_for_eval=None,
|
||||
):
|
||||
self.accelerator.free_memory()
|
||||
self._train_batch_size = batch_size
|
||||
if self.args.auto_find_batch_size:
|
||||
if self.state.train_batch_size != self._train_batch_size:
|
||||
from accelerate.utils import release_memory
|
||||
|
||||
(self.model_wrapped,) = release_memory(self.model_wrapped)
|
||||
self.model_wrapped = self.model
|
||||
|
||||
# Check for DeepSpeed *after* the intial pass and modify the config
|
||||
if self.is_deepspeed_enabled:
|
||||
# Temporarily unset `self.args.train_batch_size`
|
||||
original_bs = self.args.per_device_train_batch_size
|
||||
self.args.per_device_train_batch_size = (
|
||||
self._train_batch_size // max(1, self.args.n_gpu)
|
||||
)
|
||||
self.propagate_args_to_deepspeed(True)
|
||||
self.args.per_device_train_batch_size = original_bs
|
||||
self.state.train_batch_size = self._train_batch_size
|
||||
logger.debug(
|
||||
f"Currently training with a batch size of: {self._train_batch_size}"
|
||||
)
|
||||
# Data loader and number of training steps
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
if self.is_fsdp_xla_v2_enabled:
|
||||
train_dataloader = tpu_spmd_dataloader(train_dataloader)
|
||||
|
||||
# Setting up training control variables:
|
||||
# number of training epochs: num_train_epochs
|
||||
# number of training steps per epoch: num_update_steps_per_epoch
|
||||
# total number of training steps to execute: max_steps
|
||||
total_train_batch_size = (
|
||||
self._train_batch_size * args.gradient_accumulation_steps * args.world_size
|
||||
)
|
||||
|
||||
len_dataloader = None
|
||||
num_train_tokens = None
|
||||
if has_length(train_dataloader):
|
||||
len_dataloader = len(train_dataloader)
|
||||
num_update_steps_per_epoch = (
|
||||
len_dataloader // args.gradient_accumulation_steps
|
||||
)
|
||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||
num_examples = self.num_examples(train_dataloader)
|
||||
if args.max_steps > 0:
|
||||
max_steps = args.max_steps
|
||||
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
||||
args.max_steps % num_update_steps_per_epoch > 0
|
||||
)
|
||||
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
|
||||
# the best we can do.
|
||||
num_train_samples = args.max_steps * total_train_batch_size
|
||||
if args.include_tokens_per_second:
|
||||
num_train_tokens = (
|
||||
self.num_tokens(train_dataloader, args.max_steps)
|
||||
* args.gradient_accumulation_steps
|
||||
)
|
||||
else:
|
||||
max_steps = math.ceil(
|
||||
args.num_train_epochs * num_update_steps_per_epoch
|
||||
)
|
||||
num_train_epochs = math.ceil(args.num_train_epochs)
|
||||
num_train_samples = (
|
||||
self.num_examples(train_dataloader) * args.num_train_epochs
|
||||
)
|
||||
if args.include_tokens_per_second:
|
||||
num_train_tokens = (
|
||||
self.num_tokens(train_dataloader) * args.num_train_epochs
|
||||
)
|
||||
elif (
|
||||
args.max_steps > 0
|
||||
): # Rely on max_steps when dataloader does not have a working size
|
||||
max_steps = args.max_steps
|
||||
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
||||
num_train_epochs = sys.maxsize
|
||||
num_update_steps_per_epoch = max_steps
|
||||
num_examples = total_train_batch_size * args.max_steps
|
||||
num_train_samples = args.max_steps * total_train_batch_size
|
||||
if args.include_tokens_per_second:
|
||||
num_train_tokens = (
|
||||
self.num_tokens(train_dataloader, args.max_steps)
|
||||
* args.gradient_accumulation_steps
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
|
||||
f" {args.max_steps}"
|
||||
)
|
||||
|
||||
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
||||
if self.args.n_gpu > 1:
|
||||
# nn.DataParallel(model) replicates the model, creating new variables and module
|
||||
# references registered here no longer work on other gpus, breaking the module
|
||||
raise ValueError(
|
||||
"Currently --debug underflow_overflow is not supported under DP. Please use DDP"
|
||||
" (torchrun or torch.distributed.launch (deprecated))."
|
||||
)
|
||||
else:
|
||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||
|
||||
delay_optimizer_creation = (
|
||||
is_sagemaker_mp_enabled()
|
||||
or self.is_fsdp_xla_enabled
|
||||
or self.is_fsdp_enabled
|
||||
)
|
||||
|
||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
||||
if self._created_lr_scheduler:
|
||||
self.lr_scheduler = None
|
||||
self._created_lr_scheduler = False
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.optimizer, self.lr_scheduler = deepspeed_init(
|
||||
self, num_training_steps=max_steps
|
||||
)
|
||||
|
||||
if not delay_optimizer_creation:
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
|
||||
self.state = TrainerState(
|
||||
stateful_callbacks=[
|
||||
cb
|
||||
for cb in self.callback_handler.callbacks + [self.control]
|
||||
if isinstance(cb, ExportableState)
|
||||
if self.optimizer is None:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" not in n)
|
||||
],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": self.args.learning_rate,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" in n)
|
||||
],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": self.args.learning_rate / 10,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
"lr": self.args.learning_rate,
|
||||
},
|
||||
]
|
||||
)
|
||||
self.state.is_hyper_param_search = trial is not None
|
||||
self.state.train_batch_size = self._train_batch_size
|
||||
|
||||
# Compute absolute values for logging, eval, and save if given as ratio
|
||||
if args.logging_steps is not None:
|
||||
if args.logging_steps < 1:
|
||||
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
self.state.logging_steps = args.logging_steps
|
||||
if args.eval_steps is not None:
|
||||
if args.eval_steps < 1:
|
||||
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
|
||||
else:
|
||||
self.state.eval_steps = args.eval_steps
|
||||
if args.save_steps is not None:
|
||||
if args.save_steps < 1:
|
||||
self.state.save_steps = math.ceil(max_steps * args.save_steps)
|
||||
else:
|
||||
self.state.save_steps = args.save_steps
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
|
||||
|
||||
# Activate gradient checkpointing if needed
|
||||
if args.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
|
||||
)
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
model = self._wrap_model(self.model_wrapped)
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# as the model is wrapped, don't use `accelerator.prepare`
|
||||
# this is for unhandled cases such as
|
||||
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
||||
use_accelerator_prepare = True if model is self.model else False
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
|
||||
|
||||
if use_accelerator_prepare and self.is_fsdp_enabled:
|
||||
# In case of auto_find_batch_size=True
|
||||
# Remove FSDP wrapping from sub-models.
|
||||
self.model = unwrap_model(self.model, recursive=True)
|
||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
|
||||
if delay_optimizer_creation:
|
||||
if use_accelerator_prepare:
|
||||
# configure fsdp plugin for qlora if any
|
||||
self._fsdp_qlora_plugin_updates()
|
||||
if self.accelerator.mixed_precision != "fp8":
|
||||
self.model = self.accelerator.prepare(self.model)
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
|
||||
import bitsandbytes
|
||||
|
||||
# prepare using `accelerator` prepare
|
||||
if use_accelerator_prepare:
|
||||
self.model.train()
|
||||
if hasattr(self.lr_scheduler, "step"):
|
||||
if self.use_apex:
|
||||
model = self.accelerator.prepare(self.model)
|
||||
else:
|
||||
model, self.optimizer = self.accelerator.prepare(
|
||||
self.model, self.optimizer
|
||||
)
|
||||
else:
|
||||
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
|
||||
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||
# In this case we are in DDP + LOMO, which should be supported
|
||||
self.optimizer = self.accelerator.prepare(self.optimizer)
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
self.model = self.model_wrapped = model
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
||||
logger.info(f"skipped {module}: {skipped / 2**20}M params")
|
||||
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
||||
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
logger.info(f"skipped: {skipped / 2**20}M params")
|
||||
|
||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||
if model is not self.model:
|
||||
self.model_wrapped = model
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
||||
|
||||
# backward compatibility
|
||||
if self.is_deepspeed_enabled:
|
||||
self.deepspeed = self.model_wrapped
|
||||
|
||||
# ckpt loading
|
||||
if resume_from_checkpoint is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
deepspeed_load_checkpoint(
|
||||
self.model_wrapped,
|
||||
resume_from_checkpoint,
|
||||
load_module_strict=not _is_peft_model(self.model),
|
||||
)
|
||||
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
|
||||
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||
|
||||
# important: at this point:
|
||||
# self.model is the Transformers Model
|
||||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
|
||||
# FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {num_examples:,}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs:,}")
|
||||
logger.info(
|
||||
f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}"
|
||||
)
|
||||
if self.args.per_device_train_batch_size != self._train_batch_size:
|
||||
logger.info(
|
||||
f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}"
|
||||
)
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_steps:,}")
|
||||
logger.info(
|
||||
f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}"
|
||||
)
|
||||
|
||||
self.state.epoch = 0
|
||||
start_time = time.time()
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
steps_trained_progress_bar = None
|
||||
|
||||
# Check if continuing training from a checkpoint
|
||||
if resume_from_checkpoint is not None and os.path.isfile(
|
||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
||||
):
|
||||
self.state = TrainerState.load_from_json(
|
||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
||||
)
|
||||
self.compare_trainer_and_checkpoint_args(self.args, self.state)
|
||||
self._load_callback_state()
|
||||
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
|
||||
if not args.ignore_data_skip:
|
||||
steps_trained_in_current_epoch = self.state.global_step % (
|
||||
num_update_steps_per_epoch
|
||||
)
|
||||
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
|
||||
else:
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
logger.info(
|
||||
" Continuing training from checkpoint, will skip to saved global_step"
|
||||
)
|
||||
logger.info(f" Continuing training from epoch {epochs_trained}")
|
||||
logger.info(
|
||||
f" Continuing training from global step {self.state.global_step}"
|
||||
)
|
||||
if not args.ignore_data_skip:
|
||||
logger.info(
|
||||
f" Will skip the first {epochs_trained} epochs then the first"
|
||||
f" {steps_trained_in_current_epoch} batches in the first epoch."
|
||||
)
|
||||
|
||||
# Update the references
|
||||
self.callback_handler.model = self.model
|
||||
self.callback_handler.optimizer = self.optimizer
|
||||
self.callback_handler.lr_scheduler = self.lr_scheduler
|
||||
self.callback_handler.train_dataloader = train_dataloader
|
||||
if self.hp_name is not None and self._trial is not None:
|
||||
# use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
|
||||
# parameter to Train when using DDP.
|
||||
self.state.trial_name = self.hp_name(self._trial)
|
||||
if trial is not None:
|
||||
assignments = (
|
||||
trial.assignments
|
||||
if self.hp_search_backend == HPSearchBackend.SIGOPT
|
||||
else trial
|
||||
)
|
||||
self.state.trial_params = hp_params(assignments)
|
||||
else:
|
||||
self.state.trial_params = None
|
||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||
# to set this after the load.
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||
|
||||
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
|
||||
tr_loss = torch.tensor(0.0).to(args.device)
|
||||
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
|
||||
self._total_loss_scalar = 0.0
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
model.zero_grad()
|
||||
grad_norm: Optional[float] = None
|
||||
self.control = self.callback_handler.on_train_begin(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
if args.eval_on_start:
|
||||
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
|
||||
|
||||
for epoch in range(epochs_trained, num_train_epochs):
|
||||
epoch_dataloader = train_dataloader
|
||||
if hasattr(epoch_dataloader, "set_epoch"):
|
||||
epoch_dataloader.set_epoch(epoch)
|
||||
|
||||
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||
if args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
steps_in_epoch = (
|
||||
len(epoch_dataloader)
|
||||
if len_dataloader is not None
|
||||
else args.max_steps * args.gradient_accumulation_steps
|
||||
)
|
||||
self.control = self.callback_handler.on_epoch_begin(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
if (
|
||||
epoch == epochs_trained
|
||||
and resume_from_checkpoint is not None
|
||||
and steps_trained_in_current_epoch == 0
|
||||
):
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
|
||||
rng_to_sync = False
|
||||
steps_skipped = 0
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
epoch_dataloader = skip_first_batches(
|
||||
epoch_dataloader, steps_trained_in_current_epoch
|
||||
)
|
||||
steps_skipped = steps_trained_in_current_epoch
|
||||
steps_trained_in_current_epoch = 0
|
||||
rng_to_sync = True
|
||||
|
||||
step = -1
|
||||
epoch_iterator = iter(epoch_dataloader)
|
||||
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
|
||||
remainder = num_examples % args.gradient_accumulation_steps
|
||||
if remainder == 0:
|
||||
remainder = args.gradient_accumulation_steps
|
||||
update_step = -1
|
||||
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
|
||||
for _ in range(total_updates):
|
||||
update_step += 1
|
||||
num_batches = (
|
||||
args.gradient_accumulation_steps
|
||||
if update_step != (total_updates - 1)
|
||||
else remainder
|
||||
)
|
||||
batch_samples, num_items_in_batch = self.get_batch_samples(
|
||||
epoch_iterator, num_batches
|
||||
)
|
||||
for i, inputs in enumerate(batch_samples):
|
||||
step += 1
|
||||
do_sync_step = (
|
||||
step + 1
|
||||
) % args.gradient_accumulation_steps == 0 or (
|
||||
step + 1
|
||||
) == steps_in_epoch
|
||||
# Since we perform prefetching, we need to manually set sync_gradients
|
||||
if not do_sync_step:
|
||||
self.accelerator.gradient_state._set_sync_gradients(False)
|
||||
else:
|
||||
self.accelerator.gradient_state._set_sync_gradients(True)
|
||||
|
||||
if self.args.include_num_input_tokens_seen:
|
||||
main_input_name = getattr(
|
||||
self.model, "main_input_name", "input_ids"
|
||||
)
|
||||
if main_input_name not in inputs:
|
||||
logger.warning(
|
||||
"Tried to track the number of tokens seen, however the current model is "
|
||||
"not configured properly to know what item is the input. To fix this, add "
|
||||
"a `main_input_name` attribute to the model class you are using."
|
||||
)
|
||||
else:
|
||||
input_tokens = inputs[main_input_name].numel()
|
||||
input_tokens = torch.tensor(
|
||||
input_tokens, device=self.args.device, dtype=torch.int64
|
||||
)
|
||||
self.state.num_input_tokens_seen += (
|
||||
self.accelerator.gather(input_tokens).sum().cpu().item()
|
||||
)
|
||||
if rng_to_sync:
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
rng_to_sync = False
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
if steps_trained_progress_bar is not None:
|
||||
steps_trained_progress_bar.update(1)
|
||||
if steps_trained_in_current_epoch == 0:
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
continue
|
||||
elif steps_trained_progress_bar is not None:
|
||||
steps_trained_progress_bar.close()
|
||||
steps_trained_progress_bar = None
|
||||
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
self.control = self.callback_handler.on_step_begin(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i != len(batch_samples) - 1
|
||||
and self.accelerator.distributed_type
|
||||
!= DistributedType.DEEPSPEED
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
with context():
|
||||
tr_loss_step = self.training_step(
|
||||
model, inputs, num_items_in_batch
|
||||
)
|
||||
|
||||
if (
|
||||
args.logging_nan_inf_filter
|
||||
and not is_torch_xla_available()
|
||||
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
|
||||
):
|
||||
# if loss is nan or inf simply add the average of previous logged losses
|
||||
tr_loss = tr_loss + tr_loss / (
|
||||
1 + self.state.global_step - self._globalstep_last_logged
|
||||
)
|
||||
else:
|
||||
if tr_loss.device != tr_loss_step.device:
|
||||
raise ValueError(
|
||||
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
|
||||
)
|
||||
tr_loss = tr_loss + tr_loss_step
|
||||
|
||||
self.current_flos += float(self.floating_point_ops(inputs))
|
||||
|
||||
if do_sync_step:
|
||||
# Since we perform prefetching, we need to manually set sync_gradients to True
|
||||
self.accelerator.gradient_state._set_sync_gradients(True)
|
||||
|
||||
# Gradient clipping
|
||||
if args.max_grad_norm is not None and args.max_grad_norm > 0:
|
||||
# deepspeed does its own clipping
|
||||
|
||||
if is_sagemaker_mp_enabled() and args.fp16:
|
||||
_grad_norm = self.optimizer.clip_master_grads(
|
||||
args.max_grad_norm
|
||||
)
|
||||
elif self.use_apex:
|
||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||
_grad_norm = nn.utils.clip_grad_norm_(
|
||||
amp.master_params(self.optimizer),
|
||||
args.max_grad_norm,
|
||||
)
|
||||
else:
|
||||
_grad_norm = self.accelerator.clip_grad_norm_(
|
||||
model.parameters(),
|
||||
args.max_grad_norm,
|
||||
)
|
||||
|
||||
if (
|
||||
is_accelerate_available()
|
||||
and self.accelerator.distributed_type
|
||||
== DistributedType.DEEPSPEED
|
||||
):
|
||||
grad_norm = model.get_global_grad_norm()
|
||||
# In some cases the grad norm may not return a float
|
||||
if hasattr(grad_norm, "item"):
|
||||
grad_norm = grad_norm.item()
|
||||
else:
|
||||
grad_norm = _grad_norm
|
||||
|
||||
self.control = self.callback_handler.on_pre_optimizer_step(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
|
||||
self.control = self.callback_handler.on_optimizer_step(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
optimizer_was_run = (
|
||||
not self.accelerator.optimizer_step_was_skipped
|
||||
)
|
||||
if optimizer_was_run:
|
||||
# Delay optimizer scheduling until metrics are generated
|
||||
if not isinstance(
|
||||
self.lr_scheduler,
|
||||
torch.optim.lr_scheduler.ReduceLROnPlateau,
|
||||
):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
model.zero_grad()
|
||||
self.state.global_step += 1
|
||||
self.state.epoch = (
|
||||
epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
||||
)
|
||||
self.control = self.callback_handler.on_step_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
self._maybe_log_save_evaluate(
|
||||
tr_loss,
|
||||
grad_norm,
|
||||
model,
|
||||
trial,
|
||||
epoch,
|
||||
ignore_keys_for_eval,
|
||||
start_time,
|
||||
)
|
||||
else:
|
||||
self.control = self.callback_handler.on_substep_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
# PyTorch/XLA relies on the data loader to insert the mark_step for
|
||||
# each step. Since we are breaking the loop early, we need to manually
|
||||
# insert the mark_step here.
|
||||
if (
|
||||
self.control.should_epoch_stop
|
||||
or self.control.should_training_stop
|
||||
):
|
||||
if is_torch_xla_available():
|
||||
xm.mark_step()
|
||||
break
|
||||
# We also need to break out of the nested loop
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
if is_torch_xla_available():
|
||||
xm.mark_step()
|
||||
break
|
||||
if step < 0:
|
||||
logger.warning(
|
||||
"There seems not to be a single sample in your epoch_iterator, stopping training at step"
|
||||
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
|
||||
f" num_steps ({max_steps}) higher than the number of available samples."
|
||||
)
|
||||
self.control.should_training_stop = True
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
self._maybe_log_save_evaluate(
|
||||
tr_loss,
|
||||
grad_norm,
|
||||
model,
|
||||
trial,
|
||||
epoch,
|
||||
ignore_keys_for_eval,
|
||||
start_time,
|
||||
)
|
||||
|
||||
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
||||
if is_torch_xla_available():
|
||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||
xm.master_print(met.metrics_report())
|
||||
else:
|
||||
logger.warning(
|
||||
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
||||
"configured. Check your training configuration if this is unexpected."
|
||||
)
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of training
|
||||
delattr(self, "_past")
|
||||
|
||||
logger.info(
|
||||
"\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
|
||||
)
|
||||
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
||||
# Wait for everyone to get here so we are sure the model has been saved by process 0.
|
||||
if is_torch_xla_available():
|
||||
xm.rendezvous("load_best_model_at_end")
|
||||
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
dist.barrier()
|
||||
elif is_sagemaker_mp_enabled():
|
||||
smp.barrier()
|
||||
|
||||
self._load_best_model()
|
||||
|
||||
# add remaining tr_loss
|
||||
self._total_loss_scalar += tr_loss.item()
|
||||
effective_global_step = max(
|
||||
self.state.global_step, 0.001
|
||||
) # Avoid ZeroDivisionError
|
||||
train_loss = self._total_loss_scalar / effective_global_step
|
||||
|
||||
metrics = speed_metrics(
|
||||
"train",
|
||||
start_time,
|
||||
num_samples=num_train_samples,
|
||||
num_steps=self.state.max_steps,
|
||||
num_tokens=num_train_tokens,
|
||||
)
|
||||
self.store_flos()
|
||||
metrics["total_flos"] = self.state.total_flos
|
||||
metrics["train_loss"] = train_loss
|
||||
|
||||
self.is_in_train = False
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
|
||||
self.log(metrics)
|
||||
|
||||
run_dir = self._get_output_dir(trial)
|
||||
checkpoints_sorted = self._sorted_checkpoints(
|
||||
use_mtime=False, output_dir=run_dir
|
||||
)
|
||||
|
||||
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
|
||||
if (
|
||||
self.args.should_save
|
||||
and self.state.best_model_checkpoint is not None
|
||||
and self.args.save_total_limit == 1
|
||||
):
|
||||
for checkpoint in checkpoints_sorted:
|
||||
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
|
||||
logger.info(
|
||||
f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit"
|
||||
)
|
||||
shutil.rmtree(checkpoint, ignore_errors=True)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
# Wait for the checkpoint to be uploaded.
|
||||
self._finish_current_push()
|
||||
|
||||
# After training we make sure to retrieve back the original forward pass method
|
||||
# for the embedding layer by removing the forward post hook.
|
||||
if self.neftune_noise_alpha is not None:
|
||||
self._deactivate_neftune(self.model)
|
||||
|
||||
return TrainOutput(self.state.global_step, train_loss, metrics)
|
||||
return self.optimizer
|
||||
|
Loading…
Reference in New Issue
Block a user