Compare commits
13 Commits
9ca588224d
...
42df13390d
Author | SHA1 | Date | |
---|---|---|---|
42df13390d | |||
266fcd57ad | |||
d686cbc254 | |||
b84ebb03c7 | |||
baccca420a | |||
70c446e548 | |||
0bc1034f35 | |||
3fe2c85f6b | |||
56e46f0e0c | |||
1d2e7b9dcd | |||
c100a59f0e | |||
188ea7df6e | |||
24a6c3c114 |
19
.vscode/settings.json
vendored
19
.vscode/settings.json
vendored
@ -1,6 +1,19 @@
|
|||||||
{
|
{
|
||||||
"python.analysis.extraPaths": [
|
"python.analysis.extraPaths": [
|
||||||
"src/transformers_repo/src/",
|
"./src/peft_repo/src/",
|
||||||
"src/peft_repo/src/"
|
"./src/transformers_repo/src/",
|
||||||
]
|
],
|
||||||
|
"python.analysis.exclude": [
|
||||||
|
"dataset/**/*",
|
||||||
|
],
|
||||||
|
"python.languageServer": "Default",
|
||||||
|
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||||
|
// "python.analysis.include": [
|
||||||
|
// "src/**/*"
|
||||||
|
// ],
|
||||||
|
"python.analysis.languageServerMode": "default",
|
||||||
|
"python.analysis.typeCheckingMode": "basic",
|
||||||
|
"python.analysis.userFileIndexingLimit": 10000,
|
||||||
|
"python.analysis.usePullDiagnostics": false,
|
||||||
|
"python.analysis.importFormat": "relative",
|
||||||
}
|
}
|
5
dataset/.gitignore
vendored
Normal file
5
dataset/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
derek-thomas*
|
||||||
|
*.lock
|
||||||
|
speechcolab*
|
||||||
|
lmms-lab*
|
||||||
|
downloads/*
|
2
dataset/OCR-VQA-200K/.gitignore
vendored
Normal file
2
dataset/OCR-VQA-200K/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
images/*
|
||||||
|
dataset.json
|
51
dataset/OCR-VQA-200K/download.py
Normal file
51
dataset/OCR-VQA-200K/download.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import urllib.request as ureq
|
||||||
|
import urllib.error
|
||||||
|
import concurrent.futures
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# Set the file paths for your Google Drive
|
||||||
|
dataset_path = "./dataset.json"
|
||||||
|
images_path = "./images"
|
||||||
|
download = 1 # Set to 0 if images are already downloaded
|
||||||
|
|
||||||
|
# Load dataset json file
|
||||||
|
with open(dataset_path, "r") as fp:
|
||||||
|
data = json.load(fp)
|
||||||
|
|
||||||
|
# Initialize a counter and a lock for thread-safe counting
|
||||||
|
downloaded_count = 0
|
||||||
|
count_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
# Function to download an image
|
||||||
|
def download_image(k):
|
||||||
|
global downloaded_count
|
||||||
|
imageURL = data[k]["imageURL"]
|
||||||
|
ext = os.path.splitext(imageURL)[1]
|
||||||
|
outputFile = os.path.join(images_path, f"{k}{ext}")
|
||||||
|
|
||||||
|
# Only download the image if it doesn't exist
|
||||||
|
if not os.path.exists(outputFile):
|
||||||
|
try:
|
||||||
|
ureq.urlretrieve(imageURL, outputFile)
|
||||||
|
|
||||||
|
with count_lock:
|
||||||
|
downloaded_count += 1
|
||||||
|
if downloaded_count % 100 == 0:
|
||||||
|
print(f"{downloaded_count} images downloaded.")
|
||||||
|
except urllib.error.URLError as e:
|
||||||
|
print(f"Error downloading {outputFile}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# Download images using multiple threads
|
||||||
|
if download == 1:
|
||||||
|
if not os.path.exists(images_path):
|
||||||
|
os.makedirs(images_path)
|
||||||
|
|
||||||
|
# Create a thread pool and download the images in parallel
|
||||||
|
# Increase max_workers to potentially speed up downloads for many small files.
|
||||||
|
# The optimal number may vary based on your network and the server's capacity.
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=400) as executor:
|
||||||
|
executor.map(download_image, data.keys())
|
5
dataset/TextVQA/.gitignore
vendored
Normal file
5
dataset/TextVQA/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
images/test_images/*
|
||||||
|
images/train_images/*
|
||||||
|
TextVQA_0.5.1_test.json
|
||||||
|
TextVQA_0.5.1_train.json
|
||||||
|
TextVQA_0.5.1_val.json
|
3
dataset/vizwiz/Annotations/.gitignore
vendored
Normal file
3
dataset/vizwiz/Annotations/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
train.json
|
||||||
|
test.json
|
||||||
|
val.json
|
3
dataset/vizwiz/images/.gitignore
vendored
Normal file
3
dataset/vizwiz/images/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
val/*
|
||||||
|
train/*
|
||||||
|
test/*
|
@ -2,9 +2,11 @@
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"absl-py>=2.1.0",
|
"absl-py>=2.1.0",
|
||||||
"accelerate==1.2.1",
|
"accelerate==1.2.1",
|
||||||
|
"calflops>=0.3.2",
|
||||||
"datasets==3.2.0",
|
"datasets==3.2.0",
|
||||||
"deepspeed==0.16.2",
|
"deepspeed==0.16.2",
|
||||||
"evaluate==0.4.3",
|
"evaluate==0.4.3",
|
||||||
|
"huggingface-hub==0.30.1",
|
||||||
"librosa>=0.10.2.post1",
|
"librosa>=0.10.2.post1",
|
||||||
"markupsafe==2.1.5",
|
"markupsafe==2.1.5",
|
||||||
"ms-swift>=1.3.0",
|
"ms-swift>=1.3.0",
|
||||||
@ -19,9 +21,9 @@ dependencies = [
|
|||||||
"safetensors>=0.5.2",
|
"safetensors>=0.5.2",
|
||||||
"setuptools>=70.0.0",
|
"setuptools>=70.0.0",
|
||||||
"soundfile>=0.13.0",
|
"soundfile>=0.13.0",
|
||||||
"torch==2.5.1+cu124",
|
"torch==2.6.0",
|
||||||
"torchaudio==2.5.1+cu124",
|
"torchaudio==2.6.0",
|
||||||
"torchvision==0.20.1+cu124",
|
"torchvision==0.21.0",
|
||||||
"transformers==4.48.0",
|
"transformers==4.48.0",
|
||||||
"trl==0.13.0",
|
"trl==0.13.0",
|
||||||
"wandb>=0.19.4",
|
"wandb>=0.19.4",
|
||||||
|
5
src/.gitignore
vendored
5
src/.gitignore
vendored
@ -1 +1,4 @@
|
|||||||
checkpoint/*
|
checkpoint/*
|
||||||
|
wandb/*
|
||||||
|
test.py
|
||||||
|
results/
|
@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE
|
|||||||
debug: false
|
debug: false
|
||||||
deepspeed_config:
|
deepspeed_config:
|
||||||
deepspeed_multinode_launcher: standard
|
deepspeed_multinode_launcher: standard
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 2
|
||||||
zero3_init_flag: false
|
zero3_init_flag: false
|
||||||
zero_stage: 1
|
zero_stage: 1
|
||||||
distributed_type: DEEPSPEED
|
distributed_type: DEEPSPEED
|
||||||
@ -11,7 +11,7 @@ machine_rank: 0
|
|||||||
main_training_function: main
|
main_training_function: main
|
||||||
mixed_precision: 'bf16'
|
mixed_precision: 'bf16'
|
||||||
num_machines: 1
|
num_machines: 1
|
||||||
num_processes: 8
|
num_processes: 3
|
||||||
rdzv_backend: static
|
rdzv_backend: static
|
||||||
same_network: true
|
same_network: true
|
||||||
tpu_env: []
|
tpu_env: []
|
||||||
|
@ -12,7 +12,7 @@ machine_rank: 0
|
|||||||
main_training_function: main
|
main_training_function: main
|
||||||
mixed_precision: 'bf16'
|
mixed_precision: 'bf16'
|
||||||
num_machines: 1
|
num_machines: 1
|
||||||
num_processes: 8
|
num_processes: 4
|
||||||
rdzv_backend: static
|
rdzv_backend: static
|
||||||
same_network: true
|
same_network: true
|
||||||
tpu_env: []
|
tpu_env: []
|
||||||
|
@ -18,8 +18,13 @@ class GigaspeechDataset(Dataset):
|
|||||||
|
|
||||||
self.audio_processor = audio_processor
|
self.audio_processor = audio_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
gs = load_dataset("speechcolab/gigaspeech", "xs")
|
self.split = split
|
||||||
self.data = gs[split]
|
|
||||||
|
def load_data(self):
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
|
||||||
|
self.data = gs[self.split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -54,7 +59,7 @@ class GigaspeechDataset(Dataset):
|
|||||||
audios=[(audio, sampling_rate)],
|
audios=[(audio, sampling_rate)],
|
||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample,
|
original=sample,
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||||
@ -87,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=text,
|
answer=text,
|
||||||
original=sample,
|
original=sample,
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_gigaspeech():
|
def test_gigaspeech():
|
||||||
@ -103,3 +108,21 @@ def test_gigaspeech():
|
|||||||
print(dataset[0])
|
print(dataset[0])
|
||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
assert len(dataset[0]["chat"]) > 0
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": GigaspeechDataset(split="train"),
|
||||||
|
"test": GigaspeechDataset(split="test"),
|
||||||
|
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="gigaspeech",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["audio", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_gigaspeech()
|
||||||
|
@ -22,14 +22,20 @@ class OCRVQADataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
self.vis_root = vis_root
|
self.vis_root = vis_root
|
||||||
|
|
||||||
|
from .vis_processor import size_processor
|
||||||
|
|
||||||
self.vis_processor = (
|
self.vis_processor = (
|
||||||
vis_processor if vis_processor is not None else self._vis_processor
|
vis_processor if vis_processor is not None else size_processor
|
||||||
)
|
)
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
if split == "train":
|
self.split = split
|
||||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
|
self.ann_path = ann_path
|
||||||
elif split == "test":
|
|
||||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
|
def load_data(self):
|
||||||
|
if self.split == "train":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
|
||||||
|
elif self.split == "test":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
|
||||||
|
|
||||||
# self.instruction_pool = [
|
# self.instruction_pool = [
|
||||||
# "[vqa] {}",
|
# "[vqa] {}",
|
||||||
@ -64,24 +70,6 @@ class OCRVQADataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def _vis_processor(self, image: Image.Image):
|
|
||||||
width, height = image.size
|
|
||||||
if width > 500 or height > 500:
|
|
||||||
max_size = max(width, height)
|
|
||||||
ratio = 500 / max_size
|
|
||||||
new_width = int(width * ratio)
|
|
||||||
new_height = int(height * ratio)
|
|
||||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
|
||||||
|
|
||||||
if width < 28 or height < 28:
|
|
||||||
min_size = min(width, height)
|
|
||||||
ratio = 28 / min_size + 1
|
|
||||||
new_width = int(width * ratio)
|
|
||||||
new_height = int(height * ratio)
|
|
||||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = self.data[index]
|
sample = self.data[index]
|
||||||
image: Image.Image = Image.open(
|
image: Image.Image = Image.open(
|
||||||
@ -117,7 +105,7 @@ class OCRVQADataset(Dataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
images=[image],
|
images=[image],
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||||
@ -153,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
from .factory import register_dataset
|
||||||
|
from .format import dataset_dir as base_path
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": OCRVQADataset(
|
||||||
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": OCRVQADataset(
|
||||||
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": OCRVQADatasetForGeneration(
|
||||||
|
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||||
|
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
|
||||||
|
139
src/dataset_library/RefCOCODataset.py
Normal file
139
src/dataset_library/RefCOCODataset.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from datasets import load_dataset, DatasetDict
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class RefCOCODataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vis_processor=None,
|
||||||
|
text_processor=None,
|
||||||
|
split: Literal["val", "test"] = "val",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||||
|
ann_root (string): directory to store the annotation file
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.vis_processor = vis_processor
|
||||||
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
|
||||||
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
# print(sample)
|
||||||
|
images = sample["image"]
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
images = self.vis_processor(images)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=question,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Conversation(
|
||||||
|
role="assistant",
|
||||||
|
content=[ConverstationText(type="text", text=answer)],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[images],
|
||||||
|
chat=chat,
|
||||||
|
original=sample,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class RefCOCODatasetForGeneration(RefCOCODataset):
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
# print(sample)
|
||||||
|
images = sample["image"]
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
images = self.vis_processor(images)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"{question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[images],
|
||||||
|
chat=chat,
|
||||||
|
answer=answer,
|
||||||
|
original=sample,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_RefCOCO():
|
||||||
|
dataset = RefCOCODataset(
|
||||||
|
split="val",
|
||||||
|
)
|
||||||
|
print(dataset[3])
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
dataset = RefCOCODatasetForGeneration(
|
||||||
|
split="test",
|
||||||
|
)
|
||||||
|
print(dataset[3])
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": RefCOCODataset(split="val"),
|
||||||
|
"test": RefCOCODataset(split="test"),
|
||||||
|
"generation": RefCOCODatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="refcoco",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_RefCOCO()
|
139
src/dataset_library/RefCOCOPlusDataset.py
Normal file
139
src/dataset_library/RefCOCOPlusDataset.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from datasets import load_dataset, DatasetDict
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class RefCOCOplusDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vis_processor=None,
|
||||||
|
text_processor=None,
|
||||||
|
split: Literal["val", "testA"] = "val",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||||
|
ann_root (string): directory to store the annotation file
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.vis_processor = vis_processor
|
||||||
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
|
||||||
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
# print(sample)
|
||||||
|
images = sample["image"]
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
images = self.vis_processor(images)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=question,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Conversation(
|
||||||
|
role="assistant",
|
||||||
|
content=[ConverstationText(type="text", text=answer)],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[images],
|
||||||
|
chat=chat,
|
||||||
|
original=sample,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class RefCOCOplusDatasetForGeneration(RefCOCOplusDataset):
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
# print(sample)
|
||||||
|
images = sample["image"]
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
images = self.vis_processor(images)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"{question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[images],
|
||||||
|
chat=chat,
|
||||||
|
answer=answer,
|
||||||
|
original=sample,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_RefCOCOplus():
|
||||||
|
dataset = RefCOCOplusDataset(
|
||||||
|
split="val",
|
||||||
|
)
|
||||||
|
print(dataset[3])
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
dataset = RefCOCOplusDatasetForGeneration(
|
||||||
|
split="testA",
|
||||||
|
)
|
||||||
|
print(dataset[3])
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": RefCOCOplusDataset(split="val"),
|
||||||
|
"test": RefCOCOplusDataset(split="testA"),
|
||||||
|
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="refcocoplus",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_RefCOCOplus()
|
139
src/dataset_library/RefCOCOgDataset.py
Normal file
139
src/dataset_library/RefCOCOgDataset.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from datasets import load_dataset, DatasetDict
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class RefCOCOgDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vis_processor=None,
|
||||||
|
text_processor=None,
|
||||||
|
split: Literal["val", "test"] = "val",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||||
|
ann_root (string): directory to store the annotation file
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.vis_processor = vis_processor
|
||||||
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
|
||||||
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
# print(sample)
|
||||||
|
images = sample["image"]
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
images = self.vis_processor(images)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=question,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Conversation(
|
||||||
|
role="assistant",
|
||||||
|
content=[ConverstationText(type="text", text=answer)],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[images],
|
||||||
|
chat=chat,
|
||||||
|
original=sample,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
# print(sample)
|
||||||
|
images = sample["image"]
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
images = self.vis_processor(images)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"{question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[images],
|
||||||
|
chat=chat,
|
||||||
|
answer=answer,
|
||||||
|
original=sample,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_RefCOCOg():
|
||||||
|
dataset = RefCOCOgDataset(
|
||||||
|
split="val",
|
||||||
|
)
|
||||||
|
print(dataset[3])
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
dataset = RefCOCOgDatasetForGeneration(
|
||||||
|
split="test",
|
||||||
|
)
|
||||||
|
print(dataset[3])
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": RefCOCOgDataset(split="val"),
|
||||||
|
"test": RefCOCOgDataset(split="test"),
|
||||||
|
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="refcocog",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_RefCOCOg()
|
@ -6,20 +6,25 @@ from .format import (
|
|||||||
DatasetOutput,
|
DatasetOutput,
|
||||||
)
|
)
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset, DatasetDict
|
||||||
|
|
||||||
|
|
||||||
class ScienceQADataset(Dataset):
|
class ScienceQADataset(Dataset):
|
||||||
def __init__(self, audio_processor=None, text_processor=None, split="train"):
|
def __init__(self, vis_processor=None, text_processor=None, split="train"):
|
||||||
"""
|
"""
|
||||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||||
ann_root (string): directory to store the annotation file
|
ann_root (string): directory to store the annotation file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.vis_processor = audio_processor
|
self.vis_processor = vis_processor
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
ds = load_dataset("derek-thomas/ScienceQA")
|
self.split = split
|
||||||
self.data = ds[split]
|
|
||||||
|
def load_data(self):
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir) # type: ignore
|
||||||
|
self.data = ds[self.split] # type: ignore
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -60,7 +65,7 @@ class ScienceQADataset(Dataset):
|
|||||||
images=[images],
|
images=[images],
|
||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample,
|
original=sample,
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ScienceQADatasetForGeneration(ScienceQADataset):
|
class ScienceQADatasetForGeneration(ScienceQADataset):
|
||||||
@ -98,7 +103,7 @@ class ScienceQADatasetForGeneration(ScienceQADataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=choices[answer],
|
answer=choices[answer],
|
||||||
original=sample,
|
original=sample,
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_scienceQA():
|
def test_scienceQA():
|
||||||
@ -116,5 +121,19 @@ def test_scienceQA():
|
|||||||
assert len(dataset[0]["chat"]) > 0
|
assert len(dataset[0]["chat"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": ScienceQADataset(split="train"),
|
||||||
|
"test": ScienceQADataset(split="test"),
|
||||||
|
"generation": ScienceQADatasetForGeneration(split="test"),
|
||||||
|
}
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="scienceqa",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_scienceQA()
|
test_scienceQA()
|
||||||
|
@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
|
|||||||
vis_processor if vis_processor is not None else self._vis_processor
|
vis_processor if vis_processor is not None else self._vis_processor
|
||||||
)
|
)
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
if split == "train":
|
self.split = split
|
||||||
|
self.ann_path = ann_path
|
||||||
|
self.vis_root = vis_root
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
if self.split == "train":
|
||||||
self.data = self.create_data(
|
self.data = self.create_data(
|
||||||
Path(ann_path, "TextVQA_0.5.1_train.json"),
|
Path(self.ann_path, "TextVQA_0.5.1_train.json"),
|
||||||
vis_root=Path(vis_root, "train_images"),
|
vis_root=Path(self.vis_root, "train_images"),
|
||||||
)
|
)
|
||||||
elif split == "test":
|
elif self.split == "test":
|
||||||
self.data = self.create_data(
|
self.data = self.create_data(
|
||||||
Path(ann_path, "TextVQA_0.5.1_val.json"),
|
Path(self.ann_path, "TextVQA_0.5.1_val.json"),
|
||||||
vis_root=Path(vis_root, "train_images"),
|
vis_root=Path(self.vis_root, "train_images"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.instruction_pool = [
|
# self.instruction_pool = [
|
||||||
@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
images=[image],
|
images=[image],
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class TextVQADatasetForGeneration(TextVQADataset):
|
class TextVQADatasetForGeneration(TextVQADataset):
|
||||||
@ -158,16 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
|
|||||||
chat=chat,
|
chat=chat,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
original=sample["original"],
|
original=sample["original"],
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_dataset():
|
from .format import dataset_dir
|
||||||
vis_root = "/home/zyy/dataset/TextVQA/images"
|
|
||||||
ann_path = "/home/zyy/dataset/TextVQA"
|
|
||||||
dataset = TextVQADataset(vis_root, ann_path)
|
|
||||||
for i in range(10):
|
|
||||||
print(dataset[i])
|
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": TextVQADataset(
|
||||||
|
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||||
|
ann_path=Path(dataset_dir, "TextVQA"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": TextVQADataset(
|
||||||
|
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||||
|
ann_path=Path(dataset_dir, "TextVQA"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": TextVQADatasetForGeneration(
|
||||||
|
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||||
|
ann_path=Path(dataset_dir, "TextVQA"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
from .factory import register_dataset
|
||||||
test_dataset()
|
|
||||||
|
register_dataset("textvqa", dataset, tag=["text", "image"])
|
||||||
|
168
src/dataset_library/VizWizDataset.py
Normal file
168
src/dataset_library/VizWizDataset.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
from PIL import Image
|
||||||
|
from .format import (
|
||||||
|
Conversation,
|
||||||
|
ConverstationAudio,
|
||||||
|
ConverstationImage,
|
||||||
|
ConverstationText,
|
||||||
|
DatasetOutput,
|
||||||
|
)
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class VizWizDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||||
|
ann_root (string): directory to store the annotation file
|
||||||
|
"""
|
||||||
|
self.vis_root = vis_root
|
||||||
|
self.ann_path = ann_path
|
||||||
|
|
||||||
|
from .vis_processor import size_processor
|
||||||
|
|
||||||
|
self.vis_processor = (
|
||||||
|
vis_processor if vis_processor is not None else size_processor
|
||||||
|
)
|
||||||
|
self.text_processor = text_processor
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
if self.split == "train":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "train.json"))
|
||||||
|
elif self.split == "test":
|
||||||
|
self.data = self.create_data(Path(self.ann_path, "val.json"))
|
||||||
|
|
||||||
|
# self.instruction_pool = [
|
||||||
|
# "[vqa] {}",
|
||||||
|
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
||||||
|
# ]
|
||||||
|
|
||||||
|
def create_data(self, ann_path):
|
||||||
|
processed_data = []
|
||||||
|
with open(ann_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
for i in range(len(data)):
|
||||||
|
if (
|
||||||
|
os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
|
||||||
|
and data[i]["answerable"]
|
||||||
|
):
|
||||||
|
imageFile = data[i]["image"]
|
||||||
|
processed_data.append(
|
||||||
|
{
|
||||||
|
"question": data[i]["question"],
|
||||||
|
"answer": data[i]["answers"][0]["answer"],
|
||||||
|
"image_path": imageFile,
|
||||||
|
"original": data[i],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return processed_data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
image: Image.Image = Image.open(
|
||||||
|
os.path.join(self.vis_root, sample["image_path"])
|
||||||
|
).convert("RGB")
|
||||||
|
# resize image
|
||||||
|
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
image = self.vis_processor(image)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Conversation(
|
||||||
|
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return DatasetOutput(
|
||||||
|
chat=chat,
|
||||||
|
original=sample["original"],
|
||||||
|
images=[image],
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class VizWizDatasetForGeneration(VizWizDataset):
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
sample = self.data[index]
|
||||||
|
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||||
|
"RGB"
|
||||||
|
)
|
||||||
|
# resize image
|
||||||
|
question = sample["question"]
|
||||||
|
answer = sample["answer"]
|
||||||
|
if self.vis_processor is not None:
|
||||||
|
image = self.vis_processor(image)
|
||||||
|
if self.text_processor is not None:
|
||||||
|
question = self.text_processor(question)
|
||||||
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
Conversation(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ConverstationImage(type="image", image_url=""),
|
||||||
|
ConverstationText(
|
||||||
|
type="text",
|
||||||
|
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return DatasetOutput(
|
||||||
|
images=[image],
|
||||||
|
chat=chat,
|
||||||
|
answer=answer,
|
||||||
|
original=sample["original"],
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
from .format import dataset_dir
|
||||||
|
|
||||||
|
dataset = {
|
||||||
|
"train": VizWizDataset(
|
||||||
|
vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
|
||||||
|
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||||
|
split="train",
|
||||||
|
),
|
||||||
|
"test": VizWizDataset(
|
||||||
|
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||||
|
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
"generation": VizWizDatasetForGeneration(
|
||||||
|
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||||
|
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||||
|
split="test",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
from .factory import register_dataset
|
||||||
|
|
||||||
|
register_dataset(
|
||||||
|
dataset_name="vizwiz",
|
||||||
|
dataset=dataset,
|
||||||
|
tag=["image", "text"],
|
||||||
|
)
|
@ -1,95 +1,78 @@
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Literal
|
from typing import Literal, List
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from dataset_library.format import dataset_dir
|
||||||
|
|
||||||
|
MAPPING_NAME_TO_DATASET: dict[
|
||||||
|
str, dict[Literal["train", "test", "generation"], Dataset]
|
||||||
|
] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_dataset(
|
||||||
|
dataset_name: str,
|
||||||
|
dataset: dict[Literal["train", "test", "generation"], Dataset],
|
||||||
|
tag: List[Literal["image", "text", "audio", "video"]] = [],
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_name (`str`):
|
||||||
|
The name of the dataset.
|
||||||
|
dataset (`Dataset`):
|
||||||
|
The dataset to register.
|
||||||
|
"""
|
||||||
|
dataset_name = dataset_name.lower()
|
||||||
|
if dataset_name in MAPPING_NAME_TO_DATASET:
|
||||||
|
raise ValueError(f"Dataset {dataset_name} already registered.")
|
||||||
|
MAPPING_NAME_TO_DATASET[dataset_name] = dataset
|
||||||
|
|
||||||
|
|
||||||
|
from .GigaspeechDataset import (
|
||||||
|
GigaspeechDataset,
|
||||||
|
GigaspeechDatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||||
|
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||||
|
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||||
|
from .ScienceQADataset import (
|
||||||
|
ScienceQADataset,
|
||||||
|
ScienceQADatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .RefCOCODataset import (
|
||||||
|
RefCOCODataset,
|
||||||
|
RefCOCODatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .RefCOCOgDataset import (
|
||||||
|
RefCOCOgDataset,
|
||||||
|
RefCOCOgDatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .RefCOCOPlusDataset import (
|
||||||
|
RefCOCOplusDataset,
|
||||||
|
RefCOCOplusDatasetForGeneration,
|
||||||
|
)
|
||||||
|
from .VizWizDataset import (
|
||||||
|
VizWizDataset,
|
||||||
|
VizWizDatasetForGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
dataset_name, base_path="/home/zyy/dataset"
|
dataset_name: str, base_path=dataset_dir
|
||||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
"""
|
||||||
match dataset_name:
|
Get a dataset by name.
|
||||||
case "ocrvqa200k":
|
|
||||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
Args:
|
||||||
"train": OCRVQADataset(
|
dataset_name (`str`):
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
The name of the dataset.
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
base_path (`str`):
|
||||||
split="train",
|
The base path to the dataset.
|
||||||
),
|
"""
|
||||||
"test": OCRVQADataset(
|
dataset_name = dataset_name.lower()
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
if dataset_name not in MAPPING_NAME_TO_DATASET:
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
raise ValueError(f"Dataset {dataset_name} not registered.")
|
||||||
split="test",
|
for key in MAPPING_NAME_TO_DATASET[dataset_name]:
|
||||||
),
|
MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
|
||||||
"generation": OCRVQADatasetForGeneration(
|
return MAPPING_NAME_TO_DATASET[dataset_name]
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
case "chem":
|
|
||||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": ChemDataset(
|
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
|
||||||
ann_path=Path(base_path, "chem"),
|
|
||||||
split="train",
|
|
||||||
),
|
|
||||||
"test": ChemDataset(
|
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
|
||||||
ann_path=Path(base_path, "chem"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
"generation": ChemDatasetForGeneration(
|
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
|
||||||
ann_path=Path(base_path, "chem"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
case "gigaspeech":
|
|
||||||
from .GigaspeechDataset import (
|
|
||||||
GigaspeechDataset,
|
|
||||||
GigaspeechDatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": GigaspeechDataset(split="train"),
|
|
||||||
"test": GigaspeechDataset(split="test"),
|
|
||||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
|
||||||
}
|
|
||||||
|
|
||||||
case "textvqa":
|
|
||||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": TextVQADataset(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="train",
|
|
||||||
),
|
|
||||||
"test": TextVQADataset(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
"generation": TextVQADatasetForGeneration(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
case "scienceqa":
|
|
||||||
from .ScienceQADataset import (
|
|
||||||
ScienceQADataset,
|
|
||||||
ScienceQADatasetForGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": ScienceQADataset(split="train"),
|
|
||||||
"test": ScienceQADataset(split="test"),
|
|
||||||
"generation": ScienceQADatasetForGeneration(split="test"),
|
|
||||||
}
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
from typing import Any, Tuple, TypedDict, Literal, Optional
|
from typing import Any, Tuple, TypedDict, Literal, Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
dataset_dir = Path(__file__).resolve().parent.parent.parent / "dataset"
|
||||||
|
|
||||||
|
|
||||||
class ConverstationText(TypedDict):
|
class ConverstationText(TypedDict):
|
||||||
@ -25,8 +28,8 @@ class Conversation(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class DatasetOutput(TypedDict):
|
class DatasetOutput(TypedDict):
|
||||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
|
||||||
chat: list[Conversation]
|
chat: list[Conversation]
|
||||||
answer: Optional[str]
|
answer: Optional[str]
|
||||||
original: Any
|
original: Any
|
||||||
images: Optional[list[Image.Image]]
|
images: Optional[list[Image.Image]]
|
||||||
|
audios: Optional[list[Tuple[np.ndarray, int]]]
|
||||||
|
@ -1,46 +1,22 @@
|
|||||||
from .factory import get_dataset
|
import pytest
|
||||||
|
from .factory import get_dataset, MAPPING_NAME_TO_DATASET
|
||||||
|
|
||||||
|
# Get all registered dataset names for parameterization
|
||||||
|
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
|
||||||
|
|
||||||
|
|
||||||
def test_gigaspeech():
|
@pytest.mark.parametrize("dataset_name", dataset_names)
|
||||||
dataset = get_dataset("gigaspeech")
|
def test_registered_datasets(dataset_name):
|
||||||
assert len(dataset["train"]) > 0
|
dataset = get_dataset(dataset_name)
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
# Test train split
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
assert "train" in dataset, f"Train split not found in {dataset_name}"
|
||||||
|
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
|
||||||
|
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
|
||||||
|
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
|
||||||
|
|
||||||
|
# Test test split
|
||||||
def test_chem():
|
assert "test" in dataset, f"Test split not found in {dataset_name}"
|
||||||
dataset = get_dataset("chem")
|
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
|
||||||
assert len(dataset["train"]) > 0
|
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_ocrvqa200k():
|
|
||||||
dataset = get_dataset("ocrvqa200k")
|
|
||||||
assert len(dataset["train"]) > 0
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_textvqa():
|
|
||||||
dataset = get_dataset("textvqa")
|
|
||||||
assert len(dataset["train"]) > 0
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_scienceqa():
|
|
||||||
dataset = get_dataset("scienceqa")
|
|
||||||
assert len(dataset["train"]) > 0
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
20
src/dataset_library/vis_processor.py
Normal file
20
src/dataset_library/vis_processor.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def size_processor(image: Image.Image):
|
||||||
|
width, height = image.size
|
||||||
|
if width > 500 or height > 500:
|
||||||
|
max_size = max(width, height)
|
||||||
|
ratio = 500 / max_size
|
||||||
|
new_width = int(width * ratio)
|
||||||
|
new_height = int(height * ratio)
|
||||||
|
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||||
|
|
||||||
|
if width < 28 or height < 28:
|
||||||
|
min_size = min(width, height)
|
||||||
|
ratio = 28 / min_size + 1
|
||||||
|
new_width = int(width * ratio)
|
||||||
|
new_height = int(height * ratio)
|
||||||
|
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||||
|
|
||||||
|
return image
|
@ -19,67 +19,39 @@ from trl import (
|
|||||||
get_quantization_config,
|
get_quantization_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.args import ContinualScriptArguments, ContinualModelConfig
|
from utils.args import (
|
||||||
|
ContinualScriptArguments,
|
||||||
|
ContinualModelConfig,
|
||||||
|
ContinualRegularizationArguments,
|
||||||
|
)
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = TrlParser(
|
parser = TrlParser(
|
||||||
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
|
(
|
||||||
|
ContinualScriptArguments,
|
||||||
|
TrainingArguments,
|
||||||
|
ContinualModelConfig,
|
||||||
|
ContinualRegularizationArguments,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
|
||||||
# for type hint
|
# for type hint
|
||||||
if 0 == 1:
|
if TYPE_CHECKING:
|
||||||
script_args = ContinualScriptArguments()
|
script_args = ContinualScriptArguments()
|
||||||
training_args = TrainingArguments()
|
training_args = TrainingArguments()
|
||||||
model_args = ModelConfig()
|
model_args = ContinualModelConfig()
|
||||||
|
reg_args = ContinualRegularizationArguments()
|
||||||
|
|
||||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||||
training_args.remove_unused_columns = False
|
training_args.remove_unused_columns = False
|
||||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
|
||||||
|
|
||||||
from model_library.factory import get_model
|
from model_library.factory import get_model
|
||||||
|
|
||||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
||||||
torch_dtype = (
|
model_args=model_args, training_args=training_args
|
||||||
model_args.torch_dtype
|
)
|
||||||
if model_args.torch_dtype in ["auto", None]
|
|
||||||
else getattr(torch, model_args.torch_dtype)
|
|
||||||
)
|
|
||||||
quantization_config = get_quantization_config(model_args)
|
|
||||||
model_kwargs = dict(
|
|
||||||
attn_implementation=model_args.attn_implementation,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
quantization_config=quantization_config,
|
|
||||||
)
|
|
||||||
from transformers import (
|
|
||||||
Qwen2VLProcessor,
|
|
||||||
Qwen2VLForConditionalGeneration,
|
|
||||||
AutoModelForVision2Seq,
|
|
||||||
AutoModel,
|
|
||||||
)
|
|
||||||
from peft.peft_model import PeftModelForCausalLM
|
|
||||||
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
|
||||||
|
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
||||||
training_args.output_dir,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# from peft_library import get_peft_model
|
|
||||||
|
|
||||||
processor = Qwen2VLProcessor.from_pretrained(
|
|
||||||
model_args.model_name_or_path,
|
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
|
||||||
padding_side="left",
|
|
||||||
)
|
|
||||||
from model_library.qwen2vl import (
|
|
||||||
collate_fn_for_train,
|
|
||||||
collate_fn_for_evaluate,
|
|
||||||
)
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
|
||||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
|
||||||
|
|
||||||
################
|
################
|
||||||
# Dataset
|
# Dataset
|
||||||
################
|
################
|
||||||
@ -107,6 +79,6 @@ if __name__ == "__main__":
|
|||||||
collate_fn=collate_fn_for_evaluate,
|
collate_fn=collate_fn_for_evaluate,
|
||||||
)
|
)
|
||||||
val_dataloader = accelerator.prepare_data_loader(val_dataloader)
|
val_dataloader = accelerator.prepare_data_loader(val_dataloader)
|
||||||
from utils.evaluate_tool import evaluate_rouge, evalute_save
|
from utils.evaluate_tool import evaluate_rouge, evaluate_save
|
||||||
|
|
||||||
evalute_save(model, val_dataloader, processor, accelerator)
|
evaluate_save(model, val_dataloader, processor, accelerator)
|
||||||
|
@ -5,9 +5,12 @@ from trl import (
|
|||||||
get_quantization_config,
|
get_quantization_config,
|
||||||
)
|
)
|
||||||
from utils.args import ContinualModelConfig
|
from utils.args import ContinualModelConfig
|
||||||
|
from transformers import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
def get_model(model_args: ContinualModelConfig):
|
def get_model(
|
||||||
|
model_args: ContinualModelConfig, training_args: TrainingArguments = None
|
||||||
|
):
|
||||||
torch_dtype = (
|
torch_dtype = (
|
||||||
model_args.torch_dtype
|
model_args.torch_dtype
|
||||||
if model_args.torch_dtype in ["auto", None]
|
if model_args.torch_dtype in ["auto", None]
|
||||||
@ -26,12 +29,20 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
|
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||||
|
if training_args is not None:
|
||||||
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
||||||
model_args.model_name_or_path,
|
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
processor = Qwen2VLProcessor.from_pretrained(
|
processor = Qwen2VLProcessor.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
@ -49,11 +60,18 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
|
if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
|
||||||
from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
|
from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
|
||||||
|
|
||||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
if training_args is not None:
|
||||||
model_args.model_name_or_path,
|
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
training_args.output_dir,
|
||||||
**model_kwargs,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
)
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
processor = Qwen2AudioProcessor.from_pretrained(
|
processor = Qwen2AudioProcessor.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
@ -68,4 +86,63 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||||
|
|
||||||
|
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
||||||
|
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
if training_args is not None:
|
||||||
|
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
padding_side="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||||
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||||
|
|
||||||
|
if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
|
||||||
|
from transformers.models.qwen2_5_omni import (
|
||||||
|
Qwen2_5OmniThinkerForConditionalGeneration,
|
||||||
|
Qwen2_5OmniProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_args is not None:
|
||||||
|
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
processor = Qwen2_5OmniProcessor.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
padding_side="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||||
|
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||||
|
|
||||||
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
from transformers import Qwen2AudioProcessor
|
from transformers import Qwen2AudioProcessor
|
||||||
|
from dataset_library.format import Conversation
|
||||||
|
|
||||||
|
|
||||||
def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
|
def collate_fn_for_train(examples:list[Conversation], processor: Qwen2AudioProcessor):
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||||
@ -60,7 +61,7 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
def collate_fn_for_evaluate(examples:list[Conversation], processor: Qwen2AudioProcessor):
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(
|
processor.apply_chat_template(
|
||||||
@ -91,3 +92,4 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
|||||||
# answers_ids torch.Size([3, 10])
|
# answers_ids torch.Size([3, 10])
|
||||||
# answers_mask torch.Size([3, 10])
|
# answers_mask torch.Size([3, 10])
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
||||||
from .model import Qwen2VLForConditionalGeneration_modified
|
|
||||||
|
# from .model import Qwen2VLForConditionalGeneration_modified
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"collate_fn_for_train",
|
"collate_fn_for_train",
|
||||||
|
@ -1,9 +1,22 @@
|
|||||||
from transformers import Qwen2VLProcessor
|
# from transformers import Qwen2VLProcessor
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(0, "transformers_repo/src/")
|
||||||
|
sys.path.insert(0, "peft_repo/src/")
|
||||||
|
import transformers
|
||||||
|
import peft
|
||||||
from dataset_library.format import DatasetOutput
|
from dataset_library.format import DatasetOutput
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
|
if TYPE_CHECKING:
|
||||||
|
from transformers import Qwen2VLProcessor
|
||||||
|
from transformers import Qwen2_5_VLProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn_for_train(examples: list[DatasetOutput], processor: "Qwen2VLProcessor"):
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||||
@ -32,36 +45,36 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
|
|||||||
# print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
|
# print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
|
||||||
# 151644 151645 8948 872 77091 None 151643
|
# 151644 151645 8948 872 77091 None 151643
|
||||||
|
|
||||||
# for i, label in enumerate(labels):
|
for i, label in enumerate(labels):
|
||||||
# now_index = 0
|
now_index = 0
|
||||||
# while now_index < len(label):
|
while now_index < len(label):
|
||||||
# if label[now_index] == im_start_token_id:
|
if label[now_index] == im_start_token_id:
|
||||||
# label[now_index] = -100
|
label[now_index] = -100
|
||||||
# now_index += 1
|
now_index += 1
|
||||||
# if (
|
if (
|
||||||
# label[now_index] == system_token_id
|
label[now_index] == system_token_id
|
||||||
# or label[now_index] == user_token_id
|
or label[now_index] == user_token_id
|
||||||
# ):
|
):
|
||||||
# while label[now_index] != im_end_token_id:
|
while label[now_index] != im_end_token_id:
|
||||||
# label[now_index] = -100
|
label[now_index] = -100
|
||||||
# now_index += 1
|
now_index += 1
|
||||||
# label[now_index] = -100
|
label[now_index] = -100
|
||||||
# elif label[now_index] == assistant_token_id:
|
elif label[now_index] == assistant_token_id:
|
||||||
# label[now_index] = -100
|
label[now_index] = -100
|
||||||
# label[now_index + 1] = -100
|
label[now_index + 1] = -100
|
||||||
# now_index += 2
|
now_index += 2
|
||||||
# while (
|
while (
|
||||||
# now_index < len(label) and label[now_index] != im_end_token_id
|
now_index < len(label) and label[now_index] != im_end_token_id
|
||||||
# ):
|
):
|
||||||
# now_index += 1
|
now_index += 1
|
||||||
# now_index += 1
|
now_index += 1
|
||||||
batch["labels"] = labels
|
batch["labels"] = labels
|
||||||
# batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
|
# batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(
|
processor.apply_chat_template(
|
||||||
@ -89,3 +102,68 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
|||||||
# answers_ids torch.Size([3, 10])
|
# answers_ids torch.Size([3, 10])
|
||||||
# answers_mask torch.Size([3, 10])
|
# answers_mask torch.Size([3, 10])
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# 随机生成一个图片
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
random_image = Image.fromarray(
|
||||||
|
np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||||
|
)
|
||||||
|
example = {
|
||||||
|
"chat": [
|
||||||
|
# {"role": "user", "content": "What is the capital of France?"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image_url": "", # Assuming no image for this example
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image_url": "", # Assuming no image for this example
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image_url": "", # Assuming no image for this example
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What is the capital of France?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}, # Assuming no image for this example
|
||||||
|
{"role": "assistant", "content": "The capital of France is Paris."},
|
||||||
|
],
|
||||||
|
"images": [
|
||||||
|
random_image,
|
||||||
|
random_image,
|
||||||
|
random_image,
|
||||||
|
], # Assuming no images for this example
|
||||||
|
}
|
||||||
|
batch = collate_fn_for_train([example], processor)
|
||||||
|
# print(batch)
|
||||||
|
for k, v in batch.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
print(f"{k}: {v.shape}")
|
||||||
|
else:
|
||||||
|
print(f"{k}: {v}")
|
||||||
|
# input_ids: torch.Size([1, 101])
|
||||||
|
# attention_mask: torch.Size([1, 101])
|
||||||
|
# pixel_values: torch.Size([256, 1176])
|
||||||
|
# image_grid_thw: torch.Size([1, 3])
|
||||||
|
# labels: torch.Size([1, 101])
|
||||||
|
# Load model directly
|
||||||
|
from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniProcessor
|
||||||
|
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-3B", torch_dtype="auto", device_map="auto")
|
||||||
|
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,7 +73,6 @@ from peft.tuners import (
|
|||||||
from .tuners import MMOELoraModel, MMOELoraConfig
|
from .tuners import MMOELoraModel, MMOELoraConfig
|
||||||
from peft.tuners.tuners_utils import BaseTuner
|
from peft.tuners.tuners_utils import BaseTuner
|
||||||
from peft.utils import _prepare_prompt_learning_config
|
from peft.utils import _prepare_prompt_learning_config
|
||||||
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -46,7 +46,7 @@ from transformers.modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from transformers.utils import PushToHubMixin
|
from transformers.utils import PushToHubMixin
|
||||||
|
|
||||||
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
|
from peft.utils.constants import DUMMY_MODEL_CONFIG
|
||||||
|
|
||||||
from peft import __version__
|
from peft import __version__
|
||||||
from peft.config import PeftConfig
|
from peft.config import PeftConfig
|
||||||
|
19
src/peft_library/regularizations/__init__.py
Normal file
19
src/peft_library/regularizations/__init__.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
class RegularizationMethod:
|
||||||
|
"""RegularizationMethod implement regularization strategies.
|
||||||
|
RegularizationMethod is a callable.
|
||||||
|
The method `update` is called to update the loss, typically at the end
|
||||||
|
of an experience.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def pre_adapt(self, agent, exp):
|
||||||
|
pass # implementation may be empty if adapt is not needed
|
||||||
|
|
||||||
|
def post_adapt(self, agent, exp):
|
||||||
|
pass # implementation may be empty if adapt is not needed
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
from .ewc import EWC
|
||||||
|
from .lwf import LWF
|
58
src/peft_library/regularizations/ewc.py
Normal file
58
src/peft_library/regularizations/ewc.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from . import RegularizationMethod
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class EWC(RegularizationMethod):
|
||||||
|
"""Learning Without Forgetting.
|
||||||
|
|
||||||
|
The method applies knowledge distilllation to mitigate forgetting.
|
||||||
|
The teacher is the model checkpoint after the last experience.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, EWC_lambda=1, temperature=2):
|
||||||
|
"""
|
||||||
|
:param alpha: distillation hyperparameter. It can be either a float
|
||||||
|
number or a list containing alpha for each experience.
|
||||||
|
:param temperature: softmax temperature for distillation
|
||||||
|
"""
|
||||||
|
self.EWC_lambda = EWC_lambda
|
||||||
|
self.temperature = temperature
|
||||||
|
self.fisher = {}
|
||||||
|
self.optpar = {}
|
||||||
|
""" In Avalanche, targets of different experiences are not ordered.
|
||||||
|
As a result, some units may be allocated even though their
|
||||||
|
corresponding class has never been seen by the model.
|
||||||
|
Knowledge distillation uses only units corresponding
|
||||||
|
to old classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def adapt(self, output, model, **kwargs):
|
||||||
|
ewc_loss = 0
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
dev = p.device
|
||||||
|
l = (
|
||||||
|
self.EWC_lambda
|
||||||
|
* self.fisher[n].to(dev)
|
||||||
|
* (p.data - self.optpar[n].to(dev)).pow(2)
|
||||||
|
)
|
||||||
|
ewc_loss += l.sum()
|
||||||
|
output["loss"] += ewc_loss
|
||||||
|
return output
|
||||||
|
|
||||||
|
def init_epoch(self, model):
|
||||||
|
"""Update the previous logits for the given question id."""
|
||||||
|
optpar = {}
|
||||||
|
fisher = {}
|
||||||
|
for n, p in model.module.base_model.model.named_parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
fisher[n] = torch.zeros(p.data.shape)
|
||||||
|
optpar[n] = p.clone().cpu().data
|
||||||
|
|
||||||
|
def update_fisher(self, model):
|
||||||
|
"""Update the fisher information for the given question id."""
|
||||||
|
for n, p in model.module.base_model.model.named_parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
fisher = self.fisher[n]
|
||||||
|
fisher += p.grad.data.pow(2).cpu()
|
||||||
|
self.fisher[n] = fisher
|
67
src/peft_library/regularizations/lwf.py
Normal file
67
src/peft_library/regularizations/lwf.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from . import RegularizationMethod
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LWF(RegularizationMethod):
|
||||||
|
"""Learning Without Forgetting.
|
||||||
|
|
||||||
|
The method applies knowledge distilllation to mitigate forgetting.
|
||||||
|
The teacher is the model checkpoint after the last experience.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, LWF_lambda=1, temperature=2):
|
||||||
|
"""
|
||||||
|
:param alpha: distillation hyperparameter. It can be either a float
|
||||||
|
number or a list containing alpha for each experience.
|
||||||
|
:param temperature: softmax temperature for distillation
|
||||||
|
"""
|
||||||
|
self.LWF_lambda = LWF_lambda
|
||||||
|
self.temperature = temperature
|
||||||
|
self.previous_logits = {}
|
||||||
|
""" In Avalanche, targets of different experiences are not ordered.
|
||||||
|
As a result, some units may be allocated even though their
|
||||||
|
corresponding class has never been seen by the model.
|
||||||
|
Knowledge distillation uses only units corresponding
|
||||||
|
to old classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def adapt(self, output, **kwargs):
|
||||||
|
def modified_kl_div(old, new):
|
||||||
|
return -torch.mean(torch.sum(old * torch.log(new), 1))
|
||||||
|
|
||||||
|
def smooth(logits, temp, dim):
|
||||||
|
log = logits ** (1 / temp)
|
||||||
|
return log / torch.sum(log, dim).unsqueeze(1)
|
||||||
|
|
||||||
|
lwf_loss = []
|
||||||
|
|
||||||
|
soft = torch.nn.Softmax(dim=1)
|
||||||
|
|
||||||
|
previous_keys = self.previous_logits.keys()
|
||||||
|
|
||||||
|
for index, question_id in enumerate(iterable=kwargs["question_ids"]):
|
||||||
|
if question_id in previous_keys:
|
||||||
|
previous_logits = self.previous_logits[question_id]
|
||||||
|
current_logits = output["logits"][index]
|
||||||
|
short_index = min(len(previous_logits), len(current_logits))
|
||||||
|
previous_logits = previous_logits[:short_index]
|
||||||
|
current_logits = current_logits[:short_index]
|
||||||
|
lwf_loss.append(
|
||||||
|
modified_kl_div(
|
||||||
|
old=smooth(
|
||||||
|
logits=soft(previous_logits).to(current_logits.device),
|
||||||
|
temp=2,
|
||||||
|
dim=1,
|
||||||
|
),
|
||||||
|
new=smooth(logits=soft(current_logits), temp=2, dim=1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(lwf_loss) > 0:
|
||||||
|
output["loss"] += self.LWF_lambda * torch.stack(
|
||||||
|
tensors=lwf_loss, dim=0
|
||||||
|
).sum(dim=0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def update_previous_logits(self, question_id, logits):
|
||||||
|
"""Update the previous logits for the given question id."""
|
||||||
|
self.previous_logits[question_id] = logits
|
@ -1 +1 @@
|
|||||||
Subproject commit ebab283576fc8803314314b5e1e4331c424a2198
|
Subproject commit a6e39fafb4e3ccce27671ce14afe62fb754c83c2
|
15
src/scripts/eval_omni.sh
Executable file
15
src/scripts/eval_omni.sh
Executable file
@ -0,0 +1,15 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml evaluation.py \
|
||||||
|
--dataset_name textvqa \
|
||||||
|
--use_peft \
|
||||||
|
--peft_type MOELORA \
|
||||||
|
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||||
|
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||||
|
--per_device_train_batch_size 3 \
|
||||||
|
--per_device_eval_batch_size 2 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--output_dir ./checkpoint/qwen2_5omni_moelora/ \
|
||||||
|
--bf16 \
|
||||||
|
--torch_dtype bfloat16
|
||||||
|
# --eval_strategy epoch \
|
25
src/scripts/train_omni_moelora.sh
Executable file
25
src/scripts/train_omni_moelora.sh
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||||
|
--dataset_name textvqa \
|
||||||
|
--use_peft \
|
||||||
|
--peft_type MOELORA \
|
||||||
|
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||||
|
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||||
|
--lora_r 8 \
|
||||||
|
--lora_alpha 32 \
|
||||||
|
--per_device_train_batch_size 3 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--num_train_epochs 1 \
|
||||||
|
--output_dir checkpoint/qwen2_5omni_moelora/ \
|
||||||
|
--learning_rate 2e-4 \
|
||||||
|
--warmup_ratio 0.03 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--bf16 \
|
||||||
|
--torch_dtype bfloat16 \
|
||||||
|
--logging_steps 300 \
|
||||||
|
--gradient_checkpointing \
|
||||||
|
--weight_decay 0.1 \
|
||||||
|
--eval_strategy steps \
|
||||||
|
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500
|
25
src/scripts/train_omni_olora.sh
Executable file
25
src/scripts/train_omni_olora.sh
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||||
|
--dataset_name textvqa \
|
||||||
|
--use_peft \
|
||||||
|
--peft_type OLORA \
|
||||||
|
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||||
|
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||||
|
--lora_r 8 \
|
||||||
|
--lora_alpha 32 \
|
||||||
|
--per_device_train_batch_size 3 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--num_train_epochs 1 \
|
||||||
|
--output_dir checkpoint/qwen2_5omni_olora/ \
|
||||||
|
--learning_rate 2e-4 \
|
||||||
|
--warmup_ratio 0.03 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--bf16 \
|
||||||
|
--torch_dtype bfloat16 \
|
||||||
|
--logging_steps 300 \
|
||||||
|
--gradient_checkpointing \
|
||||||
|
--weight_decay 0.1 \
|
||||||
|
--eval_strategy steps \
|
||||||
|
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500
|
31
src/test_evalutae.py
Normal file
31
src/test_evalutae.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import evaluate
|
||||||
|
|
||||||
|
# Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, and CIDEr
|
||||||
|
example = {
|
||||||
|
"generated": "The cat sat on the mat.",
|
||||||
|
"target": "The cat is sitting on the mat.",
|
||||||
|
"original": "The cat is sitting on the mat.",
|
||||||
|
}
|
||||||
|
evaluate_bleu = evaluate.load("bleu")
|
||||||
|
evaluate_rouge = evaluate.load("rouge")
|
||||||
|
evaluate_meteor = evaluate.load("meteor")
|
||||||
|
|
||||||
|
evaluate_bleu.add_batch(
|
||||||
|
predictions=[example["generated"]],
|
||||||
|
references=[[example["target"]]],
|
||||||
|
)
|
||||||
|
evaluate_rouge.add_batch(
|
||||||
|
predictions=[example["generated"]],
|
||||||
|
references=[[example["target"]]],
|
||||||
|
)
|
||||||
|
evaluate_meteor.add_batch(
|
||||||
|
predictions=[example["generated"]],
|
||||||
|
references=[[example["target"]]],
|
||||||
|
)
|
||||||
|
|
||||||
|
bleu = evaluate_bleu.compute()
|
||||||
|
rouge = evaluate_rouge.compute()
|
||||||
|
meteor = evaluate_meteor.compute()
|
||||||
|
|
||||||
|
comprehensive_results = sum(bleu['precisions']) + rouge['rougeL'] + meteor['meteor']
|
||||||
|
print("Comprehensive Results:", comprehensive_results/6)
|
27
src/todo.md
27
src/todo.md
@ -17,6 +17,31 @@
|
|||||||
|
|
||||||
[2025.01.19]
|
[2025.01.19]
|
||||||
|
|
||||||
- [ ] 多个数据集引入
|
- [x] 多个数据集引入
|
||||||
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
|
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
|
||||||
- [ ] 引入EWC和LWF
|
- [ ] 引入EWC和LWF
|
||||||
|
|
||||||
|
[2025.05.15]
|
||||||
|
|
||||||
|
- [x] vizwiz处理
|
||||||
|
|
||||||
|
[2025.05.16]
|
||||||
|
|
||||||
|
- [ ] 处理不同的持续学习框架,使得整体框架能够兼容
|
||||||
|
|
||||||
|
[2025.05.28]
|
||||||
|
|
||||||
|
- [x] MoeLora
|
||||||
|
- [ ] Coin Benchmark
|
||||||
|
- [x] 确定保存什么,便于后期测试
|
||||||
|
- [x] Olora (非实现问题,loss越来越高,感觉很难训练)
|
||||||
|
- [ ] Hide-Llava(复写基类引入clip,不同的adapter做平均,loralinear根据不同的name做插入top layer或正常layer,模型要求接受传入task_id即clip计算的最大相似)
|
||||||
|
- [ ] Hide-llava问题,前些层平均fusion很没有道理,后些层的moe处理,却整整引入了clip的计算量,(任务数确定task数量,使得一些方法没有扩展性)。现实场景要求:没法知道后面还有多少个数据集,然后减少遗忘,最好能够对后续未见数据集产生效果,moelora问题只能适当缓解,利用不同的参数承接不同的任务。 那这个benchmark,每次输入保留数据,baseline是进一个把之前所有的都训练一边,持续学习方法使用update的方式,比较不同数据集按批次输入的收益(找函数定义[How Efficient Are Today’s Continual Learning Algorithms?],[]),也就是准确度的积分。
|
||||||
|
|
||||||
|
[2025.05.30]
|
||||||
|
|
||||||
|
- [x] 评价指标
|
||||||
|
|
||||||
|
[2025.06.03]
|
||||||
|
|
||||||
|
- [ ] 预期算法,低计算成本,
|
||||||
|
60
src/train.py
60
src/train.py
@ -13,30 +13,42 @@ from trl import (
|
|||||||
TrlParser,
|
TrlParser,
|
||||||
)
|
)
|
||||||
from utils.trainer import ContinualTrainer
|
from utils.trainer import ContinualTrainer
|
||||||
from utils.args import ContinualScriptArguments, ContinualModelConfig
|
from utils.args import (
|
||||||
|
ContinualScriptArguments,
|
||||||
|
ContinualModelConfig,
|
||||||
|
ContinualRegularizationArguments,
|
||||||
|
)
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = TrlParser(
|
parser = TrlParser(
|
||||||
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
|
(
|
||||||
|
ContinualScriptArguments,
|
||||||
|
TrainingArguments,
|
||||||
|
ContinualModelConfig,
|
||||||
|
ContinualRegularizationArguments,
|
||||||
|
) # type: ignore
|
||||||
)
|
)
|
||||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
|
||||||
# for type hint
|
# for type hint
|
||||||
if 0 == 1:
|
if TYPE_CHECKING:
|
||||||
script_args = ContinualScriptArguments()
|
script_args = ContinualScriptArguments()
|
||||||
training_args = TrainingArguments()
|
training_args = TrainingArguments()
|
||||||
model_args = ContinualModelConfig()
|
model_args = ContinualModelConfig()
|
||||||
|
reg_args = ContinualRegularizationArguments()
|
||||||
|
|
||||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||||
training_args.remove_unused_columns = False
|
training_args.remove_unused_columns = False
|
||||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
|
||||||
|
|
||||||
from model_library.factory import get_model
|
from model_library.factory import get_model
|
||||||
|
|
||||||
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
||||||
model_args
|
model_args=model_args
|
||||||
)
|
)
|
||||||
################
|
################
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -47,12 +59,19 @@ if __name__ == "__main__":
|
|||||||
accelerator = create_accelerator_and_postprocess(training_args)
|
accelerator = create_accelerator_and_postprocess(training_args)
|
||||||
|
|
||||||
if model_args.peft_type == "MMOELORA":
|
if model_args.peft_type == "MMOELORA":
|
||||||
from peft_library.tuners import MMOELoraConfig
|
from peft.tuners import MMOELoraConfig
|
||||||
|
|
||||||
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
|
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||||
|
|
||||||
model.add_adapter(peft_config)
|
model.add_adapter(peft_config)
|
||||||
|
|
||||||
|
elif model_args.peft_type == "MOELORA":
|
||||||
|
from peft.tuners import MOELoraConfig
|
||||||
|
|
||||||
|
peft_config = MOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||||
|
|
||||||
|
model.add_adapter(peft_config)
|
||||||
|
|
||||||
elif model_args.peft_type == "LORA":
|
elif model_args.peft_type == "LORA":
|
||||||
from peft.tuners.lora import LoraConfig
|
from peft.tuners.lora import LoraConfig
|
||||||
|
|
||||||
@ -64,10 +83,25 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
model.add_adapter(peft_config)
|
model.add_adapter(peft_config)
|
||||||
|
|
||||||
|
elif model_args.peft_type == "OLORA":
|
||||||
|
from peft.tuners import LoraConfig
|
||||||
|
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
target_modules=model_args.lora_target_modules,
|
||||||
|
r=model_args.lora_r,
|
||||||
|
lora_alpha=model_args.lora_alpha,
|
||||||
|
lora_dropout=model_args.lora_dropout,
|
||||||
|
init_lora_weights="olora"
|
||||||
|
)
|
||||||
|
|
||||||
|
model.add_adapter(peft_config)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
peft_config = None
|
peft_config = None
|
||||||
|
|
||||||
|
from peft import get_peft_model
|
||||||
|
|
||||||
if accelerator.is_local_main_process:
|
if accelerator.is_local_main_process:
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
@ -79,19 +113,21 @@ if __name__ == "__main__":
|
|||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=collate_fn_for_train,
|
data_collator=collate_fn_for_train,
|
||||||
train_dataset=dataset[script_args.dataset_train_split],
|
train_dataset=dataset[script_args.dataset_train_split], # type: ignore
|
||||||
eval_dataset=(
|
eval_dataset=(
|
||||||
dataset[script_args.dataset_test_split]
|
dataset[script_args.dataset_test_split] # type: ignore
|
||||||
if training_args.eval_strategy != "no"
|
if training_args.eval_strategy != "no"
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
accelerator=accelerator,
|
accelerator=accelerator,
|
||||||
|
reg_args=reg_args,
|
||||||
)
|
)
|
||||||
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||||
|
|
||||||
if accelerator.is_local_main_process:
|
if accelerator.is_local_main_process:
|
||||||
print("Saving model")
|
print("Saving model")
|
||||||
trainer.save_model(training_args.output_dir)
|
# trainer.save_model(training_args.output_dir)
|
||||||
|
model.save_pretrained(training_args.output_dir)
|
||||||
|
|
||||||
if accelerator.is_local_main_process:
|
if accelerator.is_local_main_process:
|
||||||
print("Model saved")
|
print("Model saved")
|
||||||
@ -109,6 +145,6 @@ if __name__ == "__main__":
|
|||||||
# )
|
# )
|
||||||
# val_dataloader = accelerator.prepare(val_dataloader)
|
# val_dataloader = accelerator.prepare(val_dataloader)
|
||||||
|
|
||||||
# from utils.evaluate_tool import evaluate_rouge
|
# from utils.evaluate_tool import evaluate_save
|
||||||
|
|
||||||
# evaluate_rouge(model, val_dataloader, processor)
|
# evaluate_save(model, val_dataloader, processor, accelerator)
|
||||||
|
15
src/train.sh
15
src/train.sh
@ -1,18 +1,19 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||||
--dataset_name chem \
|
--dataset_name textvqa \
|
||||||
--use_peft \
|
--use_peft \
|
||||||
--peft_type LORA \
|
--peft_type MOELORA \
|
||||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||||
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
|
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
|
||||||
--lora_r 8 \
|
--lora_r 8 \
|
||||||
--lora_alpha 32 \
|
--lora_alpha 32 \
|
||||||
--per_device_train_batch_size 2 \
|
--per_device_train_batch_size 3 \
|
||||||
--per_device_eval_batch_size 2 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
--gradient_accumulation_steps 4 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--num_train_epochs 1 \
|
||||||
--output_dir checkpoint/qwen2_alllinear/ \
|
--output_dir checkpoint/qwen2_alllinear/ \
|
||||||
--learning_rate 1e-4 \
|
--learning_rate 5e-5 \
|
||||||
--bf16 \
|
--bf16 \
|
||||||
--torch_dtype bfloat16 \
|
--torch_dtype bfloat16 \
|
||||||
--logging_steps 30 \
|
--logging_steps 30 \
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit 8ee1a4eadda1d83cf65c024fe54364b5bd74e55f
|
Subproject commit 42a8639e1e827d6f0ab07d87078ff048b20dab19
|
@ -18,3 +18,16 @@ class ContinualModelConfig(ModelConfig):
|
|||||||
"""Model configuration for continual learning."""
|
"""Model configuration for continual learning."""
|
||||||
|
|
||||||
peft_type: Optional[str] = None
|
peft_type: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ContinualRegularizationArguments:
|
||||||
|
"""Regularization arguments for continual learning."""
|
||||||
|
|
||||||
|
# EWC
|
||||||
|
ewc_lambda: float = 0.0
|
||||||
|
ewc_enable: bool = False
|
||||||
|
|
||||||
|
# LWF
|
||||||
|
lwf_lambda: float = 0.0
|
||||||
|
lwf_enable: bool = False
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import evaluate
|
import evaluate
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
|
def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||||
@ -7,10 +8,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
|
|||||||
|
|
||||||
for batch in val_dataloader:
|
for batch in val_dataloader:
|
||||||
completion = model.generate(
|
completion = model.generate(
|
||||||
input_ids=batch["input_ids"],
|
**batch,
|
||||||
attention_mask=batch["attention_mask"],
|
|
||||||
pixel_values=batch["pixel_values"],
|
|
||||||
image_grid_thw=batch["image_grid_thw"],
|
|
||||||
max_length=1000,
|
max_length=1000,
|
||||||
)
|
)
|
||||||
target = batch["answers_ids"]
|
target = batch["answers_ids"]
|
||||||
@ -27,7 +25,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
|
|||||||
print(glue.compute())
|
print(glue.compute())
|
||||||
|
|
||||||
|
|
||||||
def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
|
def evaluate_save(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
mtime = 0
|
mtime = 0
|
||||||
@ -53,6 +51,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
|||||||
answers = []
|
answers = []
|
||||||
completion = model.generate(
|
completion = model.generate(
|
||||||
**batch,
|
**batch,
|
||||||
|
# max_new_tokens=30,
|
||||||
max_length=1000,
|
max_length=1000,
|
||||||
)
|
)
|
||||||
generated_text = [
|
generated_text = [
|
||||||
@ -63,20 +62,17 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
|||||||
generated_text, skip_special_tokens=True
|
generated_text, skip_special_tokens=True
|
||||||
)
|
)
|
||||||
target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
|
target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
|
||||||
for i in range(len(generated_text)):
|
|
||||||
answers.append(
|
|
||||||
{
|
|
||||||
"generated": generated_text[i],
|
|
||||||
"target": target_text[i],
|
|
||||||
"original": str(origianl[i]),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
world_size = accelerator.process_index
|
world_size = accelerator.process_index
|
||||||
|
|
||||||
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
|
for i in range(len(generated_text)):
|
||||||
for answer in answers:
|
answer = {
|
||||||
|
"generated": generated_text[i],
|
||||||
|
"target": target_text[i],
|
||||||
|
"original": str(origianl[i]),
|
||||||
|
}
|
||||||
|
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
|
||||||
f.write(json.dumps(answer) + "\n")
|
f.write(json.dumps(answer) + "\n")
|
||||||
|
|
||||||
if accelerator.is_local_main_process:
|
if accelerator.is_local_main_process:
|
||||||
@ -97,3 +93,71 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
|||||||
# delete file
|
# delete file
|
||||||
for file in files:
|
for file in files:
|
||||||
os.remove(f"results/{mtime}/{file}")
|
os.remove(f"results/{mtime}/{file}")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_from_jsonl_directory(directory_path):
|
||||||
|
"""
|
||||||
|
从指定目录读取所有jsonl文件并计算综合评估结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory_path: 包含jsonl文件的目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含各项指标和综合结果的字典
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 初始化评估器
|
||||||
|
evaluate_bleu = evaluate.load("bleu")
|
||||||
|
evaluate_rouge = evaluate.load("rouge")
|
||||||
|
evaluate_meteor = evaluate.load("meteor")
|
||||||
|
|
||||||
|
# 读取目录下所有jsonl文件
|
||||||
|
all_data = []
|
||||||
|
for file in os.listdir(directory_path):
|
||||||
|
if file.endswith(".jsonl"):
|
||||||
|
file_path = os.path.join(directory_path, file)
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
data = json.loads(line)
|
||||||
|
all_data.append(data)
|
||||||
|
|
||||||
|
if not all_data:
|
||||||
|
print(f"未在目录 {directory_path} 中找到有效的jsonl数据")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
predictions = [item["generated"] for item in all_data]
|
||||||
|
references = [[item["target"]] for item in all_data]
|
||||||
|
|
||||||
|
# 批量添加数据
|
||||||
|
evaluate_bleu.add_batch(predictions=predictions, references=references)
|
||||||
|
evaluate_rouge.add_batch(predictions=predictions, references=references)
|
||||||
|
evaluate_meteor.add_batch(predictions=predictions, references=references)
|
||||||
|
|
||||||
|
# 计算结果
|
||||||
|
bleu = evaluate_bleu.compute()
|
||||||
|
rouge = evaluate_rouge.compute()
|
||||||
|
meteor = evaluate_meteor.compute()
|
||||||
|
|
||||||
|
# 计算综合结果
|
||||||
|
comprehensive_score = (sum(bleu["precisions"]) + rouge["rougeL"] + meteor["meteor"]) / 6
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"bleu": bleu,
|
||||||
|
"rouge": rouge,
|
||||||
|
"meteor": meteor,
|
||||||
|
"comprehensive_score": comprehensive_score,
|
||||||
|
"total_samples": len(all_data),
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"评估完成,共处理 {len(all_data)} 条数据")
|
||||||
|
print(f"BLEU分数: {bleu}")
|
||||||
|
print(f"ROUGE分数: {rouge}")
|
||||||
|
print(f"METEOR分数: {meteor}")
|
||||||
|
print(f"综合分数: {comprehensive_score}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
@ -1,26 +1,73 @@
|
|||||||
# _________________________________________________________
|
# _________________________________________________________
|
||||||
|
|
||||||
|
|
||||||
from transformers.trainer import (
|
|
||||||
Trainer,
|
|
||||||
_is_peft_model,
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
|
||||||
tpu_spmd_dataloader,
|
|
||||||
logger,
|
|
||||||
has_length,
|
|
||||||
sys,
|
|
||||||
)
|
|
||||||
from transformers.trainer import *
|
from transformers.trainer import *
|
||||||
|
from transformers import (
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
from .args import ContinualRegularizationArguments
|
||||||
|
from peft_library.regularizations import EWC, LWF
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
|
||||||
|
def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
|
||||||
|
logits = outputs.logits
|
||||||
|
device = logits.device
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :]
|
||||||
|
shift_labels = labels[..., 1:].to(device)
|
||||||
|
# Save memory
|
||||||
|
masks = shift_labels != -100
|
||||||
|
shift_logits = shift_logits[masks]
|
||||||
|
shift_labels = shift_labels[masks]
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss(reduction="none")
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if num_items_in_batch is None:
|
||||||
|
loss = loss.mean()
|
||||||
|
else:
|
||||||
|
# compat transformers>=4.46
|
||||||
|
loss = loss.sum() / num_items_in_batch
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class ContinualTrainer(Trainer):
|
class ContinualTrainer(Trainer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model, args, data_collator, train_dataset, eval_dataset, accelerator
|
self,
|
||||||
|
model,
|
||||||
|
args: TrainingArguments,
|
||||||
|
data_collator,
|
||||||
|
train_dataset,
|
||||||
|
eval_dataset,
|
||||||
|
accelerator,
|
||||||
|
reg_args: ContinualRegularizationArguments = None,
|
||||||
):
|
):
|
||||||
self.accelerator = accelerator
|
self.accelerator = accelerator
|
||||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
|
super().__init__(
|
||||||
|
model,
|
||||||
|
args,
|
||||||
|
data_collator,
|
||||||
|
train_dataset,
|
||||||
|
eval_dataset,
|
||||||
|
# compute_loss_func=ce_loss_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
if reg_args.ewc_enable:
|
||||||
|
self.ewc_lambda = reg_args.ewc_lambda
|
||||||
|
from peft_library.regularizations.ewc import EWC
|
||||||
|
|
||||||
|
self.EWC = EWC()
|
||||||
|
# fisher = t
|
||||||
|
|
||||||
|
if reg_args.lwf_enable:
|
||||||
|
self.lwf_lambda = reg_args.lwf_lambda
|
||||||
|
from peft_library.regularizations.lwf import LWF
|
||||||
|
|
||||||
|
self.LWF = LWF()
|
||||||
|
|
||||||
def create_accelerator_and_postprocess(self):
|
def create_accelerator_and_postprocess(self):
|
||||||
|
|
||||||
if self.accelerator is not None:
|
if self.accelerator is not None:
|
||||||
self.is_deepspeed_enabled = (
|
self.is_deepspeed_enabled = (
|
||||||
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||||
@ -32,718 +79,79 @@ class ContinualTrainer(Trainer):
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
super().create_accelerator_and_postprocess()
|
super().create_accelerator_and_postprocess()
|
||||||
|
|
||||||
def compute_loss(
|
def create_optimizer(self):
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
Setup the optimizer.
|
||||||
|
|
||||||
Subclass and override for custom behavior.
|
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
||||||
|
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
||||||
"""
|
"""
|
||||||
if (
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
self.label_smoother is not None or self.compute_loss_func is not None
|
|
||||||
) and "labels" in inputs:
|
|
||||||
labels = inputs.pop("labels")
|
|
||||||
else:
|
|
||||||
labels = None
|
|
||||||
if self.model_accepts_loss_kwargs:
|
|
||||||
loss_kwargs = {}
|
|
||||||
if num_items_in_batch is not None:
|
|
||||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
|
||||||
inputs = {**inputs, **loss_kwargs}
|
|
||||||
outputs = model(**inputs)
|
|
||||||
# Save past state if it exists
|
|
||||||
# TODO: this needs to be fixed and made cleaner later.
|
|
||||||
if self.args.past_index >= 0:
|
|
||||||
self._past = outputs[self.args.past_index]
|
|
||||||
|
|
||||||
if labels is not None:
|
if self.optimizer is None:
|
||||||
unwrapped_model = self.accelerator.unwrap_model(model)
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
if _is_peft_model(unwrapped_model):
|
optimizer_grouped_parameters = [
|
||||||
model_name = unwrapped_model.base_model.model._get_name()
|
{
|
||||||
else:
|
"params": [
|
||||||
model_name = unwrapped_model._get_name()
|
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" not in n)
|
||||||
# User-defined compute_loss function
|
],
|
||||||
if self.compute_loss_func is not None:
|
"weight_decay": self.args.weight_decay,
|
||||||
loss = self.compute_loss_func(
|
"lr": self.args.learning_rate,
|
||||||
outputs, labels, num_items_in_batch=num_items_in_batch
|
},
|
||||||
)
|
{
|
||||||
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
"params": [
|
||||||
loss = self.label_smoother(outputs, labels, shift_labels=True)
|
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" in n)
|
||||||
else:
|
],
|
||||||
loss = self.label_smoother(outputs, labels)
|
"weight_decay": self.args.weight_decay,
|
||||||
else:
|
"lr": self.args.learning_rate / 10,
|
||||||
if isinstance(outputs, dict) and "loss" not in outputs:
|
},
|
||||||
raise ValueError(
|
{
|
||||||
"The model did not return a loss from the inputs, only the following keys: "
|
"params": [
|
||||||
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
||||||
)
|
],
|
||||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
"weight_decay": 0.0,
|
||||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
"lr": self.args.learning_rate,
|
||||||
|
},
|
||||||
return (loss, outputs) if return_outputs else loss
|
|
||||||
|
|
||||||
def _inner_training_loop(
|
|
||||||
self,
|
|
||||||
batch_size=None,
|
|
||||||
args=None,
|
|
||||||
resume_from_checkpoint=None,
|
|
||||||
trial=None,
|
|
||||||
ignore_keys_for_eval=None,
|
|
||||||
):
|
|
||||||
self.accelerator.free_memory()
|
|
||||||
self._train_batch_size = batch_size
|
|
||||||
if self.args.auto_find_batch_size:
|
|
||||||
if self.state.train_batch_size != self._train_batch_size:
|
|
||||||
from accelerate.utils import release_memory
|
|
||||||
|
|
||||||
(self.model_wrapped,) = release_memory(self.model_wrapped)
|
|
||||||
self.model_wrapped = self.model
|
|
||||||
|
|
||||||
# Check for DeepSpeed *after* the intial pass and modify the config
|
|
||||||
if self.is_deepspeed_enabled:
|
|
||||||
# Temporarily unset `self.args.train_batch_size`
|
|
||||||
original_bs = self.args.per_device_train_batch_size
|
|
||||||
self.args.per_device_train_batch_size = (
|
|
||||||
self._train_batch_size // max(1, self.args.n_gpu)
|
|
||||||
)
|
|
||||||
self.propagate_args_to_deepspeed(True)
|
|
||||||
self.args.per_device_train_batch_size = original_bs
|
|
||||||
self.state.train_batch_size = self._train_batch_size
|
|
||||||
logger.debug(
|
|
||||||
f"Currently training with a batch size of: {self._train_batch_size}"
|
|
||||||
)
|
|
||||||
# Data loader and number of training steps
|
|
||||||
train_dataloader = self.get_train_dataloader()
|
|
||||||
if self.is_fsdp_xla_v2_enabled:
|
|
||||||
train_dataloader = tpu_spmd_dataloader(train_dataloader)
|
|
||||||
|
|
||||||
# Setting up training control variables:
|
|
||||||
# number of training epochs: num_train_epochs
|
|
||||||
# number of training steps per epoch: num_update_steps_per_epoch
|
|
||||||
# total number of training steps to execute: max_steps
|
|
||||||
total_train_batch_size = (
|
|
||||||
self._train_batch_size * args.gradient_accumulation_steps * args.world_size
|
|
||||||
)
|
|
||||||
|
|
||||||
len_dataloader = None
|
|
||||||
num_train_tokens = None
|
|
||||||
if has_length(train_dataloader):
|
|
||||||
len_dataloader = len(train_dataloader)
|
|
||||||
num_update_steps_per_epoch = (
|
|
||||||
len_dataloader // args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
|
||||||
num_examples = self.num_examples(train_dataloader)
|
|
||||||
if args.max_steps > 0:
|
|
||||||
max_steps = args.max_steps
|
|
||||||
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
|
||||||
args.max_steps % num_update_steps_per_epoch > 0
|
|
||||||
)
|
|
||||||
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
|
|
||||||
# the best we can do.
|
|
||||||
num_train_samples = args.max_steps * total_train_batch_size
|
|
||||||
if args.include_tokens_per_second:
|
|
||||||
num_train_tokens = (
|
|
||||||
self.num_tokens(train_dataloader, args.max_steps)
|
|
||||||
* args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
max_steps = math.ceil(
|
|
||||||
args.num_train_epochs * num_update_steps_per_epoch
|
|
||||||
)
|
|
||||||
num_train_epochs = math.ceil(args.num_train_epochs)
|
|
||||||
num_train_samples = (
|
|
||||||
self.num_examples(train_dataloader) * args.num_train_epochs
|
|
||||||
)
|
|
||||||
if args.include_tokens_per_second:
|
|
||||||
num_train_tokens = (
|
|
||||||
self.num_tokens(train_dataloader) * args.num_train_epochs
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
args.max_steps > 0
|
|
||||||
): # Rely on max_steps when dataloader does not have a working size
|
|
||||||
max_steps = args.max_steps
|
|
||||||
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
|
||||||
num_train_epochs = sys.maxsize
|
|
||||||
num_update_steps_per_epoch = max_steps
|
|
||||||
num_examples = total_train_batch_size * args.max_steps
|
|
||||||
num_train_samples = args.max_steps * total_train_batch_size
|
|
||||||
if args.include_tokens_per_second:
|
|
||||||
num_train_tokens = (
|
|
||||||
self.num_tokens(train_dataloader, args.max_steps)
|
|
||||||
* args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
|
|
||||||
f" {args.max_steps}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
|
||||||
if self.args.n_gpu > 1:
|
|
||||||
# nn.DataParallel(model) replicates the model, creating new variables and module
|
|
||||||
# references registered here no longer work on other gpus, breaking the module
|
|
||||||
raise ValueError(
|
|
||||||
"Currently --debug underflow_overflow is not supported under DP. Please use DDP"
|
|
||||||
" (torchrun or torch.distributed.launch (deprecated))."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
|
||||||
|
|
||||||
delay_optimizer_creation = (
|
|
||||||
is_sagemaker_mp_enabled()
|
|
||||||
or self.is_fsdp_xla_enabled
|
|
||||||
or self.is_fsdp_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
|
||||||
if self._created_lr_scheduler:
|
|
||||||
self.lr_scheduler = None
|
|
||||||
self._created_lr_scheduler = False
|
|
||||||
|
|
||||||
if self.is_deepspeed_enabled:
|
|
||||||
self.optimizer, self.lr_scheduler = deepspeed_init(
|
|
||||||
self, num_training_steps=max_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
if not delay_optimizer_creation:
|
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
|
||||||
|
|
||||||
self.state = TrainerState(
|
|
||||||
stateful_callbacks=[
|
|
||||||
cb
|
|
||||||
for cb in self.callback_handler.callbacks + [self.control]
|
|
||||||
if isinstance(cb, ExportableState)
|
|
||||||
]
|
]
|
||||||
)
|
|
||||||
self.state.is_hyper_param_search = trial is not None
|
|
||||||
self.state.train_batch_size = self._train_batch_size
|
|
||||||
|
|
||||||
# Compute absolute values for logging, eval, and save if given as ratio
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
if args.logging_steps is not None:
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
if args.logging_steps < 1:
|
|
||||||
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
|
|
||||||
else:
|
else:
|
||||||
self.state.logging_steps = args.logging_steps
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
|
||||||
if args.eval_steps is not None:
|
|
||||||
if args.eval_steps < 1:
|
|
||||||
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
|
|
||||||
else:
|
|
||||||
self.state.eval_steps = args.eval_steps
|
|
||||||
if args.save_steps is not None:
|
|
||||||
if args.save_steps < 1:
|
|
||||||
self.state.save_steps = math.ceil(max_steps * args.save_steps)
|
|
||||||
else:
|
|
||||||
self.state.save_steps = args.save_steps
|
|
||||||
|
|
||||||
# Activate gradient checkpointing if needed
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
if args.gradient_checkpointing:
|
# e.g. for GaLore optimizer.
|
||||||
self.model.gradient_checkpointing_enable(
|
if "params" in optimizer_kwargs:
|
||||||
gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
)
|
|
||||||
|
|
||||||
model = self._wrap_model(self.model_wrapped)
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
# as the model is wrapped, don't use `accelerator.prepare`
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
# this is for unhandled cases such as
|
# to avoid arguments conflicts.
|
||||||
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
use_accelerator_prepare = True if model is self.model else False
|
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
|
||||||
|
|
||||||
if use_accelerator_prepare and self.is_fsdp_enabled:
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
# In case of auto_find_batch_size=True
|
|
||||||
# Remove FSDP wrapping from sub-models.
|
|
||||||
self.model = unwrap_model(self.model, recursive=True)
|
|
||||||
|
|
||||||
if delay_optimizer_creation:
|
if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
|
||||||
if use_accelerator_prepare:
|
import bitsandbytes
|
||||||
# configure fsdp plugin for qlora if any
|
|
||||||
self._fsdp_qlora_plugin_updates()
|
|
||||||
if self.accelerator.mixed_precision != "fp8":
|
|
||||||
self.model = self.accelerator.prepare(self.model)
|
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
|
||||||
|
|
||||||
# prepare using `accelerator` prepare
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
if use_accelerator_prepare:
|
|
||||||
self.model.train()
|
|
||||||
if hasattr(self.lr_scheduler, "step"):
|
|
||||||
if self.use_apex:
|
|
||||||
model = self.accelerator.prepare(self.model)
|
|
||||||
else:
|
|
||||||
model, self.optimizer = self.accelerator.prepare(
|
|
||||||
self.model, self.optimizer
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
|
|
||||||
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
|
||||||
self.model, self.optimizer, self.lr_scheduler
|
|
||||||
)
|
|
||||||
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
|
||||||
# In this case we are in DDP + LOMO, which should be supported
|
|
||||||
self.optimizer = self.accelerator.prepare(self.optimizer)
|
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
skipped = 0
|
||||||
self.model = self.model_wrapped = model
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
||||||
|
logger.info(f"skipped {module}: {skipped / 2**20}M params")
|
||||||
|
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
||||||
|
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
logger.info(f"skipped: {skipped / 2**20}M params")
|
||||||
|
|
||||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
if is_sagemaker_mp_enabled():
|
||||||
if model is not self.model:
|
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
||||||
self.model_wrapped = model
|
|
||||||
|
|
||||||
# backward compatibility
|
return self.optimizer
|
||||||
if self.is_deepspeed_enabled:
|
|
||||||
self.deepspeed = self.model_wrapped
|
|
||||||
|
|
||||||
# ckpt loading
|
|
||||||
if resume_from_checkpoint is not None:
|
|
||||||
if self.is_deepspeed_enabled:
|
|
||||||
deepspeed_load_checkpoint(
|
|
||||||
self.model_wrapped,
|
|
||||||
resume_from_checkpoint,
|
|
||||||
load_module_strict=not _is_peft_model(self.model),
|
|
||||||
)
|
|
||||||
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
|
|
||||||
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
|
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
|
||||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
|
||||||
|
|
||||||
# important: at this point:
|
|
||||||
# self.model is the Transformers Model
|
|
||||||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
|
|
||||||
# FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
|
|
||||||
|
|
||||||
# Train!
|
|
||||||
logger.info("***** Running training *****")
|
|
||||||
logger.info(f" Num examples = {num_examples:,}")
|
|
||||||
logger.info(f" Num Epochs = {num_train_epochs:,}")
|
|
||||||
logger.info(
|
|
||||||
f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}"
|
|
||||||
)
|
|
||||||
if self.args.per_device_train_batch_size != self._train_batch_size:
|
|
||||||
logger.info(
|
|
||||||
f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}"
|
|
||||||
)
|
|
||||||
logger.info(f" Total optimization steps = {max_steps:,}")
|
|
||||||
logger.info(
|
|
||||||
f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.state.epoch = 0
|
|
||||||
start_time = time.time()
|
|
||||||
epochs_trained = 0
|
|
||||||
steps_trained_in_current_epoch = 0
|
|
||||||
steps_trained_progress_bar = None
|
|
||||||
|
|
||||||
# Check if continuing training from a checkpoint
|
|
||||||
if resume_from_checkpoint is not None and os.path.isfile(
|
|
||||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
|
||||||
):
|
|
||||||
self.state = TrainerState.load_from_json(
|
|
||||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
|
||||||
)
|
|
||||||
self.compare_trainer_and_checkpoint_args(self.args, self.state)
|
|
||||||
self._load_callback_state()
|
|
||||||
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
|
|
||||||
if not args.ignore_data_skip:
|
|
||||||
steps_trained_in_current_epoch = self.state.global_step % (
|
|
||||||
num_update_steps_per_epoch
|
|
||||||
)
|
|
||||||
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
|
|
||||||
else:
|
|
||||||
steps_trained_in_current_epoch = 0
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
" Continuing training from checkpoint, will skip to saved global_step"
|
|
||||||
)
|
|
||||||
logger.info(f" Continuing training from epoch {epochs_trained}")
|
|
||||||
logger.info(
|
|
||||||
f" Continuing training from global step {self.state.global_step}"
|
|
||||||
)
|
|
||||||
if not args.ignore_data_skip:
|
|
||||||
logger.info(
|
|
||||||
f" Will skip the first {epochs_trained} epochs then the first"
|
|
||||||
f" {steps_trained_in_current_epoch} batches in the first epoch."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update the references
|
|
||||||
self.callback_handler.model = self.model
|
|
||||||
self.callback_handler.optimizer = self.optimizer
|
|
||||||
self.callback_handler.lr_scheduler = self.lr_scheduler
|
|
||||||
self.callback_handler.train_dataloader = train_dataloader
|
|
||||||
if self.hp_name is not None and self._trial is not None:
|
|
||||||
# use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
|
|
||||||
# parameter to Train when using DDP.
|
|
||||||
self.state.trial_name = self.hp_name(self._trial)
|
|
||||||
if trial is not None:
|
|
||||||
assignments = (
|
|
||||||
trial.assignments
|
|
||||||
if self.hp_search_backend == HPSearchBackend.SIGOPT
|
|
||||||
else trial
|
|
||||||
)
|
|
||||||
self.state.trial_params = hp_params(assignments)
|
|
||||||
else:
|
|
||||||
self.state.trial_params = None
|
|
||||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
|
||||||
# to set this after the load.
|
|
||||||
self.state.max_steps = max_steps
|
|
||||||
self.state.num_train_epochs = num_train_epochs
|
|
||||||
self.state.is_local_process_zero = self.is_local_process_zero()
|
|
||||||
self.state.is_world_process_zero = self.is_world_process_zero()
|
|
||||||
|
|
||||||
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
|
|
||||||
tr_loss = torch.tensor(0.0).to(args.device)
|
|
||||||
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
|
|
||||||
self._total_loss_scalar = 0.0
|
|
||||||
self._globalstep_last_logged = self.state.global_step
|
|
||||||
model.zero_grad()
|
|
||||||
grad_norm: Optional[float] = None
|
|
||||||
self.control = self.callback_handler.on_train_begin(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.eval_on_start:
|
|
||||||
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
|
|
||||||
|
|
||||||
for epoch in range(epochs_trained, num_train_epochs):
|
|
||||||
epoch_dataloader = train_dataloader
|
|
||||||
if hasattr(epoch_dataloader, "set_epoch"):
|
|
||||||
epoch_dataloader.set_epoch(epoch)
|
|
||||||
|
|
||||||
# Reset the past mems state at the beginning of each epoch if necessary.
|
|
||||||
if args.past_index >= 0:
|
|
||||||
self._past = None
|
|
||||||
|
|
||||||
steps_in_epoch = (
|
|
||||||
len(epoch_dataloader)
|
|
||||||
if len_dataloader is not None
|
|
||||||
else args.max_steps * args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
self.control = self.callback_handler.on_epoch_begin(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
epoch == epochs_trained
|
|
||||||
and resume_from_checkpoint is not None
|
|
||||||
and steps_trained_in_current_epoch == 0
|
|
||||||
):
|
|
||||||
self._load_rng_state(resume_from_checkpoint)
|
|
||||||
|
|
||||||
rng_to_sync = False
|
|
||||||
steps_skipped = 0
|
|
||||||
if steps_trained_in_current_epoch > 0:
|
|
||||||
epoch_dataloader = skip_first_batches(
|
|
||||||
epoch_dataloader, steps_trained_in_current_epoch
|
|
||||||
)
|
|
||||||
steps_skipped = steps_trained_in_current_epoch
|
|
||||||
steps_trained_in_current_epoch = 0
|
|
||||||
rng_to_sync = True
|
|
||||||
|
|
||||||
step = -1
|
|
||||||
epoch_iterator = iter(epoch_dataloader)
|
|
||||||
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
|
|
||||||
remainder = num_examples % args.gradient_accumulation_steps
|
|
||||||
if remainder == 0:
|
|
||||||
remainder = args.gradient_accumulation_steps
|
|
||||||
update_step = -1
|
|
||||||
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
|
|
||||||
for _ in range(total_updates):
|
|
||||||
update_step += 1
|
|
||||||
num_batches = (
|
|
||||||
args.gradient_accumulation_steps
|
|
||||||
if update_step != (total_updates - 1)
|
|
||||||
else remainder
|
|
||||||
)
|
|
||||||
batch_samples, num_items_in_batch = self.get_batch_samples(
|
|
||||||
epoch_iterator, num_batches
|
|
||||||
)
|
|
||||||
for i, inputs in enumerate(batch_samples):
|
|
||||||
step += 1
|
|
||||||
do_sync_step = (
|
|
||||||
step + 1
|
|
||||||
) % args.gradient_accumulation_steps == 0 or (
|
|
||||||
step + 1
|
|
||||||
) == steps_in_epoch
|
|
||||||
# Since we perform prefetching, we need to manually set sync_gradients
|
|
||||||
if not do_sync_step:
|
|
||||||
self.accelerator.gradient_state._set_sync_gradients(False)
|
|
||||||
else:
|
|
||||||
self.accelerator.gradient_state._set_sync_gradients(True)
|
|
||||||
|
|
||||||
if self.args.include_num_input_tokens_seen:
|
|
||||||
main_input_name = getattr(
|
|
||||||
self.model, "main_input_name", "input_ids"
|
|
||||||
)
|
|
||||||
if main_input_name not in inputs:
|
|
||||||
logger.warning(
|
|
||||||
"Tried to track the number of tokens seen, however the current model is "
|
|
||||||
"not configured properly to know what item is the input. To fix this, add "
|
|
||||||
"a `main_input_name` attribute to the model class you are using."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
input_tokens = inputs[main_input_name].numel()
|
|
||||||
input_tokens = torch.tensor(
|
|
||||||
input_tokens, device=self.args.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
self.state.num_input_tokens_seen += (
|
|
||||||
self.accelerator.gather(input_tokens).sum().cpu().item()
|
|
||||||
)
|
|
||||||
if rng_to_sync:
|
|
||||||
self._load_rng_state(resume_from_checkpoint)
|
|
||||||
rng_to_sync = False
|
|
||||||
|
|
||||||
# Skip past any already trained steps if resuming training
|
|
||||||
if steps_trained_in_current_epoch > 0:
|
|
||||||
steps_trained_in_current_epoch -= 1
|
|
||||||
if steps_trained_progress_bar is not None:
|
|
||||||
steps_trained_progress_bar.update(1)
|
|
||||||
if steps_trained_in_current_epoch == 0:
|
|
||||||
self._load_rng_state(resume_from_checkpoint)
|
|
||||||
continue
|
|
||||||
elif steps_trained_progress_bar is not None:
|
|
||||||
steps_trained_progress_bar.close()
|
|
||||||
steps_trained_progress_bar = None
|
|
||||||
|
|
||||||
if step % args.gradient_accumulation_steps == 0:
|
|
||||||
self.control = self.callback_handler.on_step_begin(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
|
|
||||||
context = (
|
|
||||||
functools.partial(self.accelerator.no_sync, model=model)
|
|
||||||
if i != len(batch_samples) - 1
|
|
||||||
and self.accelerator.distributed_type
|
|
||||||
!= DistributedType.DEEPSPEED
|
|
||||||
else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
with context():
|
|
||||||
tr_loss_step = self.training_step(
|
|
||||||
model, inputs, num_items_in_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
args.logging_nan_inf_filter
|
|
||||||
and not is_torch_xla_available()
|
|
||||||
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
|
|
||||||
):
|
|
||||||
# if loss is nan or inf simply add the average of previous logged losses
|
|
||||||
tr_loss = tr_loss + tr_loss / (
|
|
||||||
1 + self.state.global_step - self._globalstep_last_logged
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if tr_loss.device != tr_loss_step.device:
|
|
||||||
raise ValueError(
|
|
||||||
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
|
|
||||||
)
|
|
||||||
tr_loss = tr_loss + tr_loss_step
|
|
||||||
|
|
||||||
self.current_flos += float(self.floating_point_ops(inputs))
|
|
||||||
|
|
||||||
if do_sync_step:
|
|
||||||
# Since we perform prefetching, we need to manually set sync_gradients to True
|
|
||||||
self.accelerator.gradient_state._set_sync_gradients(True)
|
|
||||||
|
|
||||||
# Gradient clipping
|
|
||||||
if args.max_grad_norm is not None and args.max_grad_norm > 0:
|
|
||||||
# deepspeed does its own clipping
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled() and args.fp16:
|
|
||||||
_grad_norm = self.optimizer.clip_master_grads(
|
|
||||||
args.max_grad_norm
|
|
||||||
)
|
|
||||||
elif self.use_apex:
|
|
||||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
|
||||||
_grad_norm = nn.utils.clip_grad_norm_(
|
|
||||||
amp.master_params(self.optimizer),
|
|
||||||
args.max_grad_norm,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_grad_norm = self.accelerator.clip_grad_norm_(
|
|
||||||
model.parameters(),
|
|
||||||
args.max_grad_norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
is_accelerate_available()
|
|
||||||
and self.accelerator.distributed_type
|
|
||||||
== DistributedType.DEEPSPEED
|
|
||||||
):
|
|
||||||
grad_norm = model.get_global_grad_norm()
|
|
||||||
# In some cases the grad norm may not return a float
|
|
||||||
if hasattr(grad_norm, "item"):
|
|
||||||
grad_norm = grad_norm.item()
|
|
||||||
else:
|
|
||||||
grad_norm = _grad_norm
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_pre_optimizer_step(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_optimizer_step(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_was_run = (
|
|
||||||
not self.accelerator.optimizer_step_was_skipped
|
|
||||||
)
|
|
||||||
if optimizer_was_run:
|
|
||||||
# Delay optimizer scheduling until metrics are generated
|
|
||||||
if not isinstance(
|
|
||||||
self.lr_scheduler,
|
|
||||||
torch.optim.lr_scheduler.ReduceLROnPlateau,
|
|
||||||
):
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
|
|
||||||
model.zero_grad()
|
|
||||||
self.state.global_step += 1
|
|
||||||
self.state.epoch = (
|
|
||||||
epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
|
||||||
)
|
|
||||||
self.control = self.callback_handler.on_step_end(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
self._maybe_log_save_evaluate(
|
|
||||||
tr_loss,
|
|
||||||
grad_norm,
|
|
||||||
model,
|
|
||||||
trial,
|
|
||||||
epoch,
|
|
||||||
ignore_keys_for_eval,
|
|
||||||
start_time,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.control = self.callback_handler.on_substep_end(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
# PyTorch/XLA relies on the data loader to insert the mark_step for
|
|
||||||
# each step. Since we are breaking the loop early, we need to manually
|
|
||||||
# insert the mark_step here.
|
|
||||||
if (
|
|
||||||
self.control.should_epoch_stop
|
|
||||||
or self.control.should_training_stop
|
|
||||||
):
|
|
||||||
if is_torch_xla_available():
|
|
||||||
xm.mark_step()
|
|
||||||
break
|
|
||||||
# We also need to break out of the nested loop
|
|
||||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
|
||||||
if is_torch_xla_available():
|
|
||||||
xm.mark_step()
|
|
||||||
break
|
|
||||||
if step < 0:
|
|
||||||
logger.warning(
|
|
||||||
"There seems not to be a single sample in your epoch_iterator, stopping training at step"
|
|
||||||
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
|
|
||||||
f" num_steps ({max_steps}) higher than the number of available samples."
|
|
||||||
)
|
|
||||||
self.control.should_training_stop = True
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_epoch_end(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
self._maybe_log_save_evaluate(
|
|
||||||
tr_loss,
|
|
||||||
grad_norm,
|
|
||||||
model,
|
|
||||||
trial,
|
|
||||||
epoch,
|
|
||||||
ignore_keys_for_eval,
|
|
||||||
start_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
|
||||||
if is_torch_xla_available():
|
|
||||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
|
||||||
xm.master_print(met.metrics_report())
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
|
||||||
"configured. Check your training configuration if this is unexpected."
|
|
||||||
)
|
|
||||||
if self.control.should_training_stop:
|
|
||||||
break
|
|
||||||
|
|
||||||
if args.past_index and hasattr(self, "_past"):
|
|
||||||
# Clean the state at the end of training
|
|
||||||
delattr(self, "_past")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
|
|
||||||
)
|
|
||||||
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
|
||||||
# Wait for everyone to get here so we are sure the model has been saved by process 0.
|
|
||||||
if is_torch_xla_available():
|
|
||||||
xm.rendezvous("load_best_model_at_end")
|
|
||||||
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
|
|
||||||
dist.barrier()
|
|
||||||
elif is_sagemaker_mp_enabled():
|
|
||||||
smp.barrier()
|
|
||||||
|
|
||||||
self._load_best_model()
|
|
||||||
|
|
||||||
# add remaining tr_loss
|
|
||||||
self._total_loss_scalar += tr_loss.item()
|
|
||||||
effective_global_step = max(
|
|
||||||
self.state.global_step, 0.001
|
|
||||||
) # Avoid ZeroDivisionError
|
|
||||||
train_loss = self._total_loss_scalar / effective_global_step
|
|
||||||
|
|
||||||
metrics = speed_metrics(
|
|
||||||
"train",
|
|
||||||
start_time,
|
|
||||||
num_samples=num_train_samples,
|
|
||||||
num_steps=self.state.max_steps,
|
|
||||||
num_tokens=num_train_tokens,
|
|
||||||
)
|
|
||||||
self.store_flos()
|
|
||||||
metrics["total_flos"] = self.state.total_flos
|
|
||||||
metrics["train_loss"] = train_loss
|
|
||||||
|
|
||||||
self.is_in_train = False
|
|
||||||
|
|
||||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
|
||||||
|
|
||||||
self.log(metrics)
|
|
||||||
|
|
||||||
run_dir = self._get_output_dir(trial)
|
|
||||||
checkpoints_sorted = self._sorted_checkpoints(
|
|
||||||
use_mtime=False, output_dir=run_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
|
|
||||||
if (
|
|
||||||
self.args.should_save
|
|
||||||
and self.state.best_model_checkpoint is not None
|
|
||||||
and self.args.save_total_limit == 1
|
|
||||||
):
|
|
||||||
for checkpoint in checkpoints_sorted:
|
|
||||||
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
|
|
||||||
logger.info(
|
|
||||||
f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit"
|
|
||||||
)
|
|
||||||
shutil.rmtree(checkpoint, ignore_errors=True)
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_end(
|
|
||||||
args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait for the checkpoint to be uploaded.
|
|
||||||
self._finish_current_push()
|
|
||||||
|
|
||||||
# After training we make sure to retrieve back the original forward pass method
|
|
||||||
# for the embedding layer by removing the forward post hook.
|
|
||||||
if self.neftune_noise_alpha is not None:
|
|
||||||
self._deactivate_neftune(self.model)
|
|
||||||
|
|
||||||
return TrainOutput(self.state.global_step, train_loss, metrics)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user