Compare commits
13 Commits
9ca588224d
...
42df13390d
Author | SHA1 | Date | |
---|---|---|---|
42df13390d | |||
266fcd57ad | |||
d686cbc254 | |||
b84ebb03c7 | |||
baccca420a | |||
70c446e548 | |||
0bc1034f35 | |||
3fe2c85f6b | |||
56e46f0e0c | |||
1d2e7b9dcd | |||
c100a59f0e | |||
188ea7df6e | |||
24a6c3c114 |
19
.vscode/settings.json
vendored
19
.vscode/settings.json
vendored
@ -1,6 +1,19 @@
|
||||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"src/transformers_repo/src/",
|
||||
"src/peft_repo/src/"
|
||||
]
|
||||
"./src/peft_repo/src/",
|
||||
"./src/transformers_repo/src/",
|
||||
],
|
||||
"python.analysis.exclude": [
|
||||
"dataset/**/*",
|
||||
],
|
||||
"python.languageServer": "Default",
|
||||
"python.terminal.activateEnvInCurrentTerminal": true,
|
||||
// "python.analysis.include": [
|
||||
// "src/**/*"
|
||||
// ],
|
||||
"python.analysis.languageServerMode": "default",
|
||||
"python.analysis.typeCheckingMode": "basic",
|
||||
"python.analysis.userFileIndexingLimit": 10000,
|
||||
"python.analysis.usePullDiagnostics": false,
|
||||
"python.analysis.importFormat": "relative",
|
||||
}
|
5
dataset/.gitignore
vendored
Normal file
5
dataset/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
derek-thomas*
|
||||
*.lock
|
||||
speechcolab*
|
||||
lmms-lab*
|
||||
downloads/*
|
2
dataset/OCR-VQA-200K/.gitignore
vendored
Normal file
2
dataset/OCR-VQA-200K/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
images/*
|
||||
dataset.json
|
51
dataset/OCR-VQA-200K/download.py
Normal file
51
dataset/OCR-VQA-200K/download.py
Normal file
@ -0,0 +1,51 @@
|
||||
import os
|
||||
import json
|
||||
import urllib.request as ureq
|
||||
import urllib.error
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
# Set the file paths for your Google Drive
|
||||
dataset_path = "./dataset.json"
|
||||
images_path = "./images"
|
||||
download = 1 # Set to 0 if images are already downloaded
|
||||
|
||||
# Load dataset json file
|
||||
with open(dataset_path, "r") as fp:
|
||||
data = json.load(fp)
|
||||
|
||||
# Initialize a counter and a lock for thread-safe counting
|
||||
downloaded_count = 0
|
||||
count_lock = threading.Lock()
|
||||
|
||||
|
||||
# Function to download an image
|
||||
def download_image(k):
|
||||
global downloaded_count
|
||||
imageURL = data[k]["imageURL"]
|
||||
ext = os.path.splitext(imageURL)[1]
|
||||
outputFile = os.path.join(images_path, f"{k}{ext}")
|
||||
|
||||
# Only download the image if it doesn't exist
|
||||
if not os.path.exists(outputFile):
|
||||
try:
|
||||
ureq.urlretrieve(imageURL, outputFile)
|
||||
|
||||
with count_lock:
|
||||
downloaded_count += 1
|
||||
if downloaded_count % 100 == 0:
|
||||
print(f"{downloaded_count} images downloaded.")
|
||||
except urllib.error.URLError as e:
|
||||
print(f"Error downloading {outputFile}: {e}")
|
||||
|
||||
|
||||
# Download images using multiple threads
|
||||
if download == 1:
|
||||
if not os.path.exists(images_path):
|
||||
os.makedirs(images_path)
|
||||
|
||||
# Create a thread pool and download the images in parallel
|
||||
# Increase max_workers to potentially speed up downloads for many small files.
|
||||
# The optimal number may vary based on your network and the server's capacity.
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=400) as executor:
|
||||
executor.map(download_image, data.keys())
|
5
dataset/TextVQA/.gitignore
vendored
Normal file
5
dataset/TextVQA/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
images/test_images/*
|
||||
images/train_images/*
|
||||
TextVQA_0.5.1_test.json
|
||||
TextVQA_0.5.1_train.json
|
||||
TextVQA_0.5.1_val.json
|
3
dataset/vizwiz/Annotations/.gitignore
vendored
Normal file
3
dataset/vizwiz/Annotations/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
train.json
|
||||
test.json
|
||||
val.json
|
3
dataset/vizwiz/images/.gitignore
vendored
Normal file
3
dataset/vizwiz/images/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
val/*
|
||||
train/*
|
||||
test/*
|
@ -2,9 +2,11 @@
|
||||
dependencies = [
|
||||
"absl-py>=2.1.0",
|
||||
"accelerate==1.2.1",
|
||||
"calflops>=0.3.2",
|
||||
"datasets==3.2.0",
|
||||
"deepspeed==0.16.2",
|
||||
"evaluate==0.4.3",
|
||||
"huggingface-hub==0.30.1",
|
||||
"librosa>=0.10.2.post1",
|
||||
"markupsafe==2.1.5",
|
||||
"ms-swift>=1.3.0",
|
||||
@ -19,9 +21,9 @@ dependencies = [
|
||||
"safetensors>=0.5.2",
|
||||
"setuptools>=70.0.0",
|
||||
"soundfile>=0.13.0",
|
||||
"torch==2.5.1+cu124",
|
||||
"torchaudio==2.5.1+cu124",
|
||||
"torchvision==0.20.1+cu124",
|
||||
"torch==2.6.0",
|
||||
"torchaudio==2.6.0",
|
||||
"torchvision==0.21.0",
|
||||
"transformers==4.48.0",
|
||||
"trl==0.13.0",
|
||||
"wandb>=0.19.4",
|
||||
|
5
src/.gitignore
vendored
5
src/.gitignore
vendored
@ -1 +1,4 @@
|
||||
checkpoint/*
|
||||
checkpoint/*
|
||||
wandb/*
|
||||
test.py
|
||||
results/
|
@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_accumulation_steps: 2
|
||||
zero3_init_flag: false
|
||||
zero_stage: 1
|
||||
distributed_type: DEEPSPEED
|
||||
@ -11,7 +11,7 @@ machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
num_processes: 3
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
@ -12,7 +12,7 @@ machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
@ -18,8 +18,13 @@ class GigaspeechDataset(Dataset):
|
||||
|
||||
self.audio_processor = audio_processor
|
||||
self.text_processor = text_processor
|
||||
gs = load_dataset("speechcolab/gigaspeech", "xs")
|
||||
self.data = gs[split]
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
gs = load_dataset("speechcolab/gigaspeech", "xs", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = gs[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -54,7 +59,7 @@ class GigaspeechDataset(Dataset):
|
||||
audios=[(audio, sampling_rate)],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||
@ -87,7 +92,7 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||
chat=chat,
|
||||
answer=text,
|
||||
original=sample,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_gigaspeech():
|
||||
@ -103,3 +108,21 @@ def test_gigaspeech():
|
||||
print(dataset[0])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
from .factory import register_dataset
|
||||
|
||||
dataset = {
|
||||
"train": GigaspeechDataset(split="train"),
|
||||
"test": GigaspeechDataset(split="test"),
|
||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
register_dataset(
|
||||
dataset_name="gigaspeech",
|
||||
dataset=dataset,
|
||||
tag=["audio", "text"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gigaspeech()
|
||||
|
@ -22,14 +22,20 @@ class OCRVQADataset(Dataset):
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
|
||||
from .vis_processor import size_processor
|
||||
|
||||
self.vis_processor = (
|
||||
vis_processor if vis_processor is not None else self._vis_processor
|
||||
vis_processor if vis_processor is not None else size_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
|
||||
elif split == "test":
|
||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
|
||||
self.split = split
|
||||
self.ann_path = ann_path
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=3)
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
@ -64,24 +70,6 @@ class OCRVQADataset(Dataset):
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def _vis_processor(self, image: Image.Image):
|
||||
width, height = image.size
|
||||
if width > 500 or height > 500:
|
||||
max_size = max(width, height)
|
||||
ratio = 500 / max_size
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
if width < 28 or height < 28:
|
||||
min_size = min(width, height)
|
||||
ratio = 28 / min_size + 1
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
return image
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image: Image.Image = Image.open(
|
||||
@ -117,7 +105,7 @@ class OCRVQADataset(Dataset):
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
@ -153,4 +141,28 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
from .factory import register_dataset
|
||||
from .format import dataset_dir as base_path
|
||||
|
||||
dataset = {
|
||||
"train": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="train",
|
||||
),
|
||||
"test": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
"generation": OCRVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
|
||||
|
139
src/dataset_library/RefCOCODataset.py
Normal file
139
src/dataset_library/RefCOCODataset.py
Normal file
@ -0,0 +1,139 @@
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class RefCOCODataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
vis_processor=None,
|
||||
text_processor=None,
|
||||
split: Literal["val", "test"] = "val",
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=question,
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant",
|
||||
content=[ConverstationText(type="text", text=answer)],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class RefCOCODatasetForGeneration(RefCOCODataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"{question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_RefCOCO():
|
||||
dataset = RefCOCODataset(
|
||||
split="val",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
dataset = RefCOCODatasetForGeneration(
|
||||
split="test",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCODataset(split="val"),
|
||||
"test": RefCOCODataset(split="test"),
|
||||
"generation": RefCOCODatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="refcoco",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCO()
|
139
src/dataset_library/RefCOCOPlusDataset.py
Normal file
139
src/dataset_library/RefCOCOPlusDataset.py
Normal file
@ -0,0 +1,139 @@
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class RefCOCOplusDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
vis_processor=None,
|
||||
text_processor=None,
|
||||
split: Literal["val", "testA"] = "val",
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=question,
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant",
|
||||
content=[ConverstationText(type="text", text=answer)],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class RefCOCOplusDatasetForGeneration(RefCOCOplusDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"{question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_RefCOCOplus():
|
||||
dataset = RefCOCOplusDataset(
|
||||
split="val",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
dataset = RefCOCOplusDatasetForGeneration(
|
||||
split="testA",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOplusDataset(split="val"),
|
||||
"test": RefCOCOplusDataset(split="testA"),
|
||||
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="refcocoplus",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCOplus()
|
139
src/dataset_library/RefCOCOgDataset.py
Normal file
139
src/dataset_library/RefCOCOgDataset.py
Normal file
@ -0,0 +1,139 @@
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class RefCOCOgDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
vis_processor=None,
|
||||
text_processor=None,
|
||||
split: Literal["val", "test"] = "val",
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=question,
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant",
|
||||
content=[ConverstationText(type="text", text=answer)],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class RefCOCOgDatasetForGeneration(RefCOCOgDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
# print(sample)
|
||||
images = sample["image"]
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
|
||||
if self.vis_processor is not None:
|
||||
images = self.vis_processor(images)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"{question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
images=[images],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample,
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_RefCOCOg():
|
||||
dataset = RefCOCOgDataset(
|
||||
split="val",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
dataset = RefCOCOgDatasetForGeneration(
|
||||
split="test",
|
||||
)
|
||||
print(dataset[3])
|
||||
assert len(dataset) > 0
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": RefCOCOgDataset(split="val"),
|
||||
"test": RefCOCOgDataset(split="test"),
|
||||
"generation": RefCOCOgDatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="refcocog",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_RefCOCOg()
|
@ -6,20 +6,25 @@ from .format import (
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
|
||||
class ScienceQADataset(Dataset):
|
||||
def __init__(self, audio_processor=None, text_processor=None, split="train"):
|
||||
def __init__(self, vis_processor=None, text_processor=None, split="train"):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.vis_processor = audio_processor
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
ds = load_dataset("derek-thomas/ScienceQA")
|
||||
self.data = ds[split]
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
from .format import dataset_dir
|
||||
|
||||
ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir) # type: ignore
|
||||
self.data = ds[self.split] # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -60,7 +65,7 @@ class ScienceQADataset(Dataset):
|
||||
images=[images],
|
||||
chat=chat,
|
||||
original=sample,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class ScienceQADatasetForGeneration(ScienceQADataset):
|
||||
@ -98,7 +103,7 @@ class ScienceQADatasetForGeneration(ScienceQADataset):
|
||||
chat=chat,
|
||||
answer=choices[answer],
|
||||
original=sample,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_scienceQA():
|
||||
@ -116,5 +121,19 @@ def test_scienceQA():
|
||||
assert len(dataset[0]["chat"]) > 0
|
||||
|
||||
|
||||
dataset = {
|
||||
"train": ScienceQADataset(split="train"),
|
||||
"test": ScienceQADataset(split="test"),
|
||||
"generation": ScienceQADatasetForGeneration(split="test"),
|
||||
}
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="scienceqa",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_scienceQA()
|
||||
|
@ -25,15 +25,20 @@ class TextVQADataset(Dataset):
|
||||
vis_processor if vis_processor is not None else self._vis_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.split = split
|
||||
self.ann_path = ann_path
|
||||
self.vis_root = vis_root
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(
|
||||
Path(ann_path, "TextVQA_0.5.1_train.json"),
|
||||
vis_root=Path(vis_root, "train_images"),
|
||||
Path(self.ann_path, "TextVQA_0.5.1_train.json"),
|
||||
vis_root=Path(self.vis_root, "train_images"),
|
||||
)
|
||||
elif split == "test":
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(
|
||||
Path(ann_path, "TextVQA_0.5.1_val.json"),
|
||||
vis_root=Path(vis_root, "train_images"),
|
||||
Path(self.ann_path, "TextVQA_0.5.1_val.json"),
|
||||
vis_root=Path(self.vis_root, "train_images"),
|
||||
)
|
||||
|
||||
# self.instruction_pool = [
|
||||
@ -124,7 +129,7 @@ class TextVQADataset(Dataset):
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class TextVQADatasetForGeneration(TextVQADataset):
|
||||
@ -158,16 +163,29 @@ class TextVQADatasetForGeneration(TextVQADataset):
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_dataset():
|
||||
vis_root = "/home/zyy/dataset/TextVQA/images"
|
||||
ann_path = "/home/zyy/dataset/TextVQA"
|
||||
dataset = TextVQADataset(vis_root, ann_path)
|
||||
for i in range(10):
|
||||
print(dataset[i])
|
||||
from .format import dataset_dir
|
||||
|
||||
dataset = {
|
||||
"train": TextVQADataset(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="train",
|
||||
),
|
||||
"test": TextVQADataset(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
"generation": TextVQADatasetForGeneration(
|
||||
vis_root=Path(dataset_dir, "TextVQA", "images"),
|
||||
ann_path=Path(dataset_dir, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dataset()
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset("textvqa", dataset, tag=["text", "image"])
|
||||
|
168
src/dataset_library/VizWizDataset.py
Normal file
168
src/dataset_library/VizWizDataset.py
Normal file
@ -0,0 +1,168 @@
|
||||
from PIL import Image
|
||||
from .format import (
|
||||
Conversation,
|
||||
ConverstationAudio,
|
||||
ConverstationImage,
|
||||
ConverstationText,
|
||||
DatasetOutput,
|
||||
)
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class VizWizDataset(Dataset):
|
||||
def __init__(
|
||||
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
self.ann_path = ann_path
|
||||
|
||||
from .vis_processor import size_processor
|
||||
|
||||
self.vis_processor = (
|
||||
vis_processor if vis_processor is not None else size_processor
|
||||
)
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
def load_data(self):
|
||||
if self.split == "train":
|
||||
self.data = self.create_data(Path(self.ann_path, "train.json"))
|
||||
elif self.split == "test":
|
||||
self.data = self.create_data(Path(self.ann_path, "val.json"))
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
||||
# ]
|
||||
|
||||
def create_data(self, ann_path):
|
||||
processed_data = []
|
||||
with open(ann_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for i in range(len(data)):
|
||||
if (
|
||||
os.path.exists(os.path.join(self.vis_root, data[i]["image"]))
|
||||
and data[i]["answerable"]
|
||||
):
|
||||
imageFile = data[i]["image"]
|
||||
processed_data.append(
|
||||
{
|
||||
"question": data[i]["question"],
|
||||
"answer": data[i]["answers"][0]["answer"],
|
||||
"image_path": imageFile,
|
||||
"original": data[i],
|
||||
}
|
||||
)
|
||||
return processed_data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image: Image.Image = Image.open(
|
||||
os.path.join(self.vis_root, sample["image_path"])
|
||||
).convert("RGB")
|
||||
# resize image
|
||||
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
if self.vis_processor is not None:
|
||||
image = self.vis_processor(image)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
||||
),
|
||||
]
|
||||
|
||||
return DatasetOutput(
|
||||
chat=chat,
|
||||
original=sample["original"],
|
||||
images=[image],
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class VizWizDatasetForGeneration(VizWizDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||
"RGB"
|
||||
)
|
||||
# resize image
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
if self.vis_processor is not None:
|
||||
image = self.vis_processor(image)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
Conversation(
|
||||
role="user",
|
||||
content=[
|
||||
ConverstationImage(type="image", image_url=""),
|
||||
ConverstationText(
|
||||
type="text",
|
||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
return DatasetOutput(
|
||||
images=[image],
|
||||
chat=chat,
|
||||
answer=answer,
|
||||
original=sample["original"],
|
||||
) # type: ignore
|
||||
|
||||
|
||||
from .format import dataset_dir
|
||||
|
||||
dataset = {
|
||||
"train": VizWizDataset(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "train"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="train",
|
||||
),
|
||||
"test": VizWizDataset(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
"generation": VizWizDatasetForGeneration(
|
||||
vis_root=Path(dataset_dir, "vizwiz", "images", "val"),
|
||||
ann_path=Path(dataset_dir, "vizwiz", "Annotations"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
from .factory import register_dataset
|
||||
|
||||
register_dataset(
|
||||
dataset_name="vizwiz",
|
||||
dataset=dataset,
|
||||
tag=["image", "text"],
|
||||
)
|
@ -1,95 +1,78 @@
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Literal
|
||||
from typing import Literal, List
|
||||
from pathlib import Path
|
||||
from dataset_library.format import dataset_dir
|
||||
|
||||
MAPPING_NAME_TO_DATASET: dict[
|
||||
str, dict[Literal["train", "test", "generation"], Dataset]
|
||||
] = {}
|
||||
|
||||
|
||||
def register_dataset(
|
||||
dataset_name: str,
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset],
|
||||
tag: List[Literal["image", "text", "audio", "video"]] = [],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Register a dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (`str`):
|
||||
The name of the dataset.
|
||||
dataset (`Dataset`):
|
||||
The dataset to register.
|
||||
"""
|
||||
dataset_name = dataset_name.lower()
|
||||
if dataset_name in MAPPING_NAME_TO_DATASET:
|
||||
raise ValueError(f"Dataset {dataset_name} already registered.")
|
||||
MAPPING_NAME_TO_DATASET[dataset_name] = dataset
|
||||
|
||||
|
||||
from .GigaspeechDataset import (
|
||||
GigaspeechDataset,
|
||||
GigaspeechDatasetForGeneration,
|
||||
)
|
||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||
from .ScienceQADataset import (
|
||||
ScienceQADataset,
|
||||
ScienceQADatasetForGeneration,
|
||||
)
|
||||
from .RefCOCODataset import (
|
||||
RefCOCODataset,
|
||||
RefCOCODatasetForGeneration,
|
||||
)
|
||||
from .RefCOCOgDataset import (
|
||||
RefCOCOgDataset,
|
||||
RefCOCOgDatasetForGeneration,
|
||||
)
|
||||
from .RefCOCOPlusDataset import (
|
||||
RefCOCOplusDataset,
|
||||
RefCOCOplusDatasetForGeneration,
|
||||
)
|
||||
from .VizWizDataset import (
|
||||
VizWizDataset,
|
||||
VizWizDatasetForGeneration,
|
||||
)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
dataset_name, base_path="/home/zyy/dataset"
|
||||
dataset_name: str, base_path=dataset_dir
|
||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||
match dataset_name:
|
||||
case "ocrvqa200k":
|
||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
"""
|
||||
Get a dataset by name.
|
||||
|
||||
dataset = {
|
||||
"train": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="train",
|
||||
),
|
||||
"test": OCRVQADataset(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
"generation": OCRVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
case "chem":
|
||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": ChemDataset(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="train",
|
||||
),
|
||||
"test": ChemDataset(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="test",
|
||||
),
|
||||
"generation": ChemDatasetForGeneration(
|
||||
vis_root=Path(base_path, "chem", "images"),
|
||||
ann_path=Path(base_path, "chem"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
case "gigaspeech":
|
||||
from .GigaspeechDataset import (
|
||||
GigaspeechDataset,
|
||||
GigaspeechDatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": GigaspeechDataset(split="train"),
|
||||
"test": GigaspeechDataset(split="test"),
|
||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
case "textvqa":
|
||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": TextVQADataset(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="train",
|
||||
),
|
||||
"test": TextVQADataset(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
"generation": TextVQADatasetForGeneration(
|
||||
vis_root=Path(base_path, "TextVQA", "images"),
|
||||
ann_path=Path(base_path, "TextVQA"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
case "scienceqa":
|
||||
from .ScienceQADataset import (
|
||||
ScienceQADataset,
|
||||
ScienceQADatasetForGeneration,
|
||||
)
|
||||
|
||||
dataset = {
|
||||
"train": ScienceQADataset(split="train"),
|
||||
"test": ScienceQADataset(split="test"),
|
||||
"generation": ScienceQADatasetForGeneration(split="test"),
|
||||
}
|
||||
|
||||
return dataset
|
||||
Args:
|
||||
dataset_name (`str`):
|
||||
The name of the dataset.
|
||||
base_path (`str`):
|
||||
The base path to the dataset.
|
||||
"""
|
||||
dataset_name = dataset_name.lower()
|
||||
if dataset_name not in MAPPING_NAME_TO_DATASET:
|
||||
raise ValueError(f"Dataset {dataset_name} not registered.")
|
||||
for key in MAPPING_NAME_TO_DATASET[dataset_name]:
|
||||
MAPPING_NAME_TO_DATASET[dataset_name][key].load_data()
|
||||
return MAPPING_NAME_TO_DATASET[dataset_name]
|
||||
|
@ -1,6 +1,9 @@
|
||||
from typing import Any, Tuple, TypedDict, Literal, Optional
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
dataset_dir = Path(__file__).resolve().parent.parent.parent / "dataset"
|
||||
|
||||
|
||||
class ConverstationText(TypedDict):
|
||||
@ -25,8 +28,8 @@ class Conversation(TypedDict):
|
||||
|
||||
|
||||
class DatasetOutput(TypedDict):
|
||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
||||
chat: list[Conversation]
|
||||
answer: Optional[str]
|
||||
original: Any
|
||||
images: Optional[list[Image.Image]]
|
||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
||||
|
@ -1,46 +1,22 @@
|
||||
from .factory import get_dataset
|
||||
import pytest
|
||||
from .factory import get_dataset, MAPPING_NAME_TO_DATASET
|
||||
|
||||
# Get all registered dataset names for parameterization
|
||||
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
|
||||
|
||||
|
||||
def test_gigaspeech():
|
||||
dataset = get_dataset("gigaspeech")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
@pytest.mark.parametrize("dataset_name", dataset_names)
|
||||
def test_registered_datasets(dataset_name):
|
||||
dataset = get_dataset(dataset_name)
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
# Test train split
|
||||
assert "train" in dataset, f"Train split not found in {dataset_name}"
|
||||
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
|
||||
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
|
||||
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
|
||||
|
||||
|
||||
def test_chem():
|
||||
dataset = get_dataset("chem")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_ocrvqa200k():
|
||||
dataset = get_dataset("ocrvqa200k")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_textvqa():
|
||||
dataset = get_dataset("textvqa")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
|
||||
|
||||
def test_scienceqa():
|
||||
dataset = get_dataset("scienceqa")
|
||||
assert len(dataset["train"]) > 0
|
||||
assert len(dataset["train"][0]["chat"]) > 0
|
||||
|
||||
assert len(dataset["test"]) > 0
|
||||
assert len(dataset["test"][0]["chat"]) > 0
|
||||
# Test test split
|
||||
assert "test" in dataset, f"Test split not found in {dataset_name}"
|
||||
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
|
||||
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
|
||||
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
|
||||
|
20
src/dataset_library/vis_processor.py
Normal file
20
src/dataset_library/vis_processor.py
Normal file
@ -0,0 +1,20 @@
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def size_processor(image: Image.Image):
|
||||
width, height = image.size
|
||||
if width > 500 or height > 500:
|
||||
max_size = max(width, height)
|
||||
ratio = 500 / max_size
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
if width < 28 or height < 28:
|
||||
min_size = min(width, height)
|
||||
ratio = 28 / min_size + 1
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
return image
|
@ -19,67 +19,39 @@ from trl import (
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
from utils.args import ContinualScriptArguments, ContinualModelConfig
|
||||
from utils.args import (
|
||||
ContinualScriptArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser(
|
||||
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
|
||||
(
|
||||
ContinualScriptArguments,
|
||||
TrainingArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
)
|
||||
)
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
|
||||
# for type hint
|
||||
if 0 == 1:
|
||||
if TYPE_CHECKING:
|
||||
script_args = ContinualScriptArguments()
|
||||
training_args = TrainingArguments()
|
||||
model_args = ModelConfig()
|
||||
model_args = ContinualModelConfig()
|
||||
reg_args = ContinualRegularizationArguments()
|
||||
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
from model_library.factory import get_model
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
from transformers import (
|
||||
Qwen2VLProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModel,
|
||||
)
|
||||
from peft.peft_model import PeftModelForCausalLM
|
||||
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# from peft_library import get_peft_model
|
||||
|
||||
processor = Qwen2VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
from model_library.qwen2vl import (
|
||||
collate_fn_for_train,
|
||||
collate_fn_for_evaluate,
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
||||
model_args=model_args, training_args=training_args
|
||||
)
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
@ -107,6 +79,6 @@ if __name__ == "__main__":
|
||||
collate_fn=collate_fn_for_evaluate,
|
||||
)
|
||||
val_dataloader = accelerator.prepare_data_loader(val_dataloader)
|
||||
from utils.evaluate_tool import evaluate_rouge, evalute_save
|
||||
from utils.evaluate_tool import evaluate_rouge, evaluate_save
|
||||
|
||||
evalute_save(model, val_dataloader, processor, accelerator)
|
||||
evaluate_save(model, val_dataloader, processor, accelerator)
|
||||
|
@ -5,9 +5,12 @@ from trl import (
|
||||
get_quantization_config,
|
||||
)
|
||||
from utils.args import ContinualModelConfig
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
def get_model(model_args: ContinualModelConfig):
|
||||
def get_model(
|
||||
model_args: ContinualModelConfig, training_args: TrainingArguments = None
|
||||
):
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
@ -26,12 +29,20 @@ def get_model(model_args: ContinualModelConfig):
|
||||
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
if training_args is not None:
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
processor = Qwen2VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
@ -49,11 +60,18 @@ def get_model(model_args: ContinualModelConfig):
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
|
||||
from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
|
||||
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
if training_args is not None:
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
processor = Qwen2AudioProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
@ -68,4 +86,63 @@ def get_model(model_args: ContinualModelConfig):
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
||||
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
if training_args is not None:
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
|
||||
from transformers.models.qwen2_5_omni import (
|
||||
Qwen2_5OmniThinkerForConditionalGeneration,
|
||||
Qwen2_5OmniProcessor,
|
||||
)
|
||||
|
||||
if training_args is not None:
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
||||
|
@ -1,7 +1,8 @@
|
||||
from transformers import Qwen2AudioProcessor
|
||||
from dataset_library.format import Conversation
|
||||
|
||||
|
||||
def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
|
||||
def collate_fn_for_train(examples:list[Conversation], processor: Qwen2AudioProcessor):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||
@ -60,7 +61,7 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
|
||||
return batch
|
||||
|
||||
|
||||
def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
||||
def collate_fn_for_evaluate(examples:list[Conversation], processor: Qwen2AudioProcessor):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
@ -91,3 +92,4 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
||||
# answers_ids torch.Size([3, 10])
|
||||
# answers_mask torch.Size([3, 10])
|
||||
return batch
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
||||
from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
# from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
__all__ = [
|
||||
"collate_fn_for_train",
|
||||
|
@ -1,9 +1,22 @@
|
||||
from transformers import Qwen2VLProcessor
|
||||
# from transformers import Qwen2VLProcessor
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "transformers_repo/src/")
|
||||
sys.path.insert(0, "peft_repo/src/")
|
||||
import transformers
|
||||
import peft
|
||||
from dataset_library.format import DatasetOutput
|
||||
|
||||
import torch
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Qwen2VLProcessor
|
||||
from transformers import Qwen2_5_VLProcessor
|
||||
|
||||
|
||||
def collate_fn_for_train(examples: list[DatasetOutput], processor: "Qwen2VLProcessor"):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||
@ -32,36 +45,36 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
|
||||
# print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
|
||||
# 151644 151645 8948 872 77091 None 151643
|
||||
|
||||
# for i, label in enumerate(labels):
|
||||
# now_index = 0
|
||||
# while now_index < len(label):
|
||||
# if label[now_index] == im_start_token_id:
|
||||
# label[now_index] = -100
|
||||
# now_index += 1
|
||||
# if (
|
||||
# label[now_index] == system_token_id
|
||||
# or label[now_index] == user_token_id
|
||||
# ):
|
||||
# while label[now_index] != im_end_token_id:
|
||||
# label[now_index] = -100
|
||||
# now_index += 1
|
||||
# label[now_index] = -100
|
||||
# elif label[now_index] == assistant_token_id:
|
||||
# label[now_index] = -100
|
||||
# label[now_index + 1] = -100
|
||||
# now_index += 2
|
||||
# while (
|
||||
# now_index < len(label) and label[now_index] != im_end_token_id
|
||||
# ):
|
||||
# now_index += 1
|
||||
# now_index += 1
|
||||
for i, label in enumerate(labels):
|
||||
now_index = 0
|
||||
while now_index < len(label):
|
||||
if label[now_index] == im_start_token_id:
|
||||
label[now_index] = -100
|
||||
now_index += 1
|
||||
if (
|
||||
label[now_index] == system_token_id
|
||||
or label[now_index] == user_token_id
|
||||
):
|
||||
while label[now_index] != im_end_token_id:
|
||||
label[now_index] = -100
|
||||
now_index += 1
|
||||
label[now_index] = -100
|
||||
elif label[now_index] == assistant_token_id:
|
||||
label[now_index] = -100
|
||||
label[now_index + 1] = -100
|
||||
now_index += 2
|
||||
while (
|
||||
now_index < len(label) and label[now_index] != im_end_token_id
|
||||
):
|
||||
now_index += 1
|
||||
now_index += 1
|
||||
batch["labels"] = labels
|
||||
# batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
||||
def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
@ -89,3 +102,68 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
||||
# answers_ids torch.Size([3, 10])
|
||||
# answers_mask torch.Size([3, 10])
|
||||
return batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
||||
from PIL import Image
|
||||
|
||||
# 随机生成一个图片
|
||||
import numpy as np
|
||||
|
||||
random_image = Image.fromarray(
|
||||
np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
)
|
||||
example = {
|
||||
"chat": [
|
||||
# {"role": "user", "content": "What is the capital of France?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "", # Assuming no image for this example
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "", # Assuming no image for this example
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "", # Assuming no image for this example
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is the capital of France?",
|
||||
},
|
||||
],
|
||||
}, # Assuming no image for this example
|
||||
{"role": "assistant", "content": "The capital of France is Paris."},
|
||||
],
|
||||
"images": [
|
||||
random_image,
|
||||
random_image,
|
||||
random_image,
|
||||
], # Assuming no images for this example
|
||||
}
|
||||
batch = collate_fn_for_train([example], processor)
|
||||
# print(batch)
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
print(f"{k}: {v.shape}")
|
||||
else:
|
||||
print(f"{k}: {v}")
|
||||
# input_ids: torch.Size([1, 101])
|
||||
# attention_mask: torch.Size([1, 101])
|
||||
# pixel_values: torch.Size([256, 1176])
|
||||
# image_grid_thw: torch.Size([1, 3])
|
||||
# labels: torch.Size([1, 101])
|
||||
# Load model directly
|
||||
from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerForConditionalGeneration, Qwen2_5OmniProcessor
|
||||
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-3B", torch_dtype="auto", device_map="auto")
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")
|
||||
|
||||
|
||||
|
@ -73,7 +73,6 @@ from peft.tuners import (
|
||||
from .tuners import MMOELoraModel, MMOELoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTuner
|
||||
from peft.utils import _prepare_prompt_learning_config
|
||||
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -46,7 +46,7 @@ from transformers.modeling_outputs import (
|
||||
)
|
||||
from transformers.utils import PushToHubMixin
|
||||
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG
|
||||
|
||||
from peft import __version__
|
||||
from peft.config import PeftConfig
|
||||
|
19
src/peft_library/regularizations/__init__.py
Normal file
19
src/peft_library/regularizations/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
class RegularizationMethod:
|
||||
"""RegularizationMethod implement regularization strategies.
|
||||
RegularizationMethod is a callable.
|
||||
The method `update` is called to update the loss, typically at the end
|
||||
of an experience.
|
||||
"""
|
||||
|
||||
def pre_adapt(self, agent, exp):
|
||||
pass # implementation may be empty if adapt is not needed
|
||||
|
||||
def post_adapt(self, agent, exp):
|
||||
pass # implementation may be empty if adapt is not needed
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
from .ewc import EWC
|
||||
from .lwf import LWF
|
58
src/peft_library/regularizations/ewc.py
Normal file
58
src/peft_library/regularizations/ewc.py
Normal file
@ -0,0 +1,58 @@
|
||||
from . import RegularizationMethod
|
||||
import torch
|
||||
|
||||
|
||||
class EWC(RegularizationMethod):
|
||||
"""Learning Without Forgetting.
|
||||
|
||||
The method applies knowledge distilllation to mitigate forgetting.
|
||||
The teacher is the model checkpoint after the last experience.
|
||||
"""
|
||||
|
||||
def __init__(self, EWC_lambda=1, temperature=2):
|
||||
"""
|
||||
:param alpha: distillation hyperparameter. It can be either a float
|
||||
number or a list containing alpha for each experience.
|
||||
:param temperature: softmax temperature for distillation
|
||||
"""
|
||||
self.EWC_lambda = EWC_lambda
|
||||
self.temperature = temperature
|
||||
self.fisher = {}
|
||||
self.optpar = {}
|
||||
""" In Avalanche, targets of different experiences are not ordered.
|
||||
As a result, some units may be allocated even though their
|
||||
corresponding class has never been seen by the model.
|
||||
Knowledge distillation uses only units corresponding
|
||||
to old classes.
|
||||
"""
|
||||
|
||||
def adapt(self, output, model, **kwargs):
|
||||
ewc_loss = 0
|
||||
for n, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
dev = p.device
|
||||
l = (
|
||||
self.EWC_lambda
|
||||
* self.fisher[n].to(dev)
|
||||
* (p.data - self.optpar[n].to(dev)).pow(2)
|
||||
)
|
||||
ewc_loss += l.sum()
|
||||
output["loss"] += ewc_loss
|
||||
return output
|
||||
|
||||
def init_epoch(self, model):
|
||||
"""Update the previous logits for the given question id."""
|
||||
optpar = {}
|
||||
fisher = {}
|
||||
for n, p in model.module.base_model.model.named_parameters():
|
||||
if p.requires_grad:
|
||||
fisher[n] = torch.zeros(p.data.shape)
|
||||
optpar[n] = p.clone().cpu().data
|
||||
|
||||
def update_fisher(self, model):
|
||||
"""Update the fisher information for the given question id."""
|
||||
for n, p in model.module.base_model.model.named_parameters():
|
||||
if p.requires_grad:
|
||||
fisher = self.fisher[n]
|
||||
fisher += p.grad.data.pow(2).cpu()
|
||||
self.fisher[n] = fisher
|
67
src/peft_library/regularizations/lwf.py
Normal file
67
src/peft_library/regularizations/lwf.py
Normal file
@ -0,0 +1,67 @@
|
||||
from . import RegularizationMethod
|
||||
import torch
|
||||
|
||||
|
||||
class LWF(RegularizationMethod):
|
||||
"""Learning Without Forgetting.
|
||||
|
||||
The method applies knowledge distilllation to mitigate forgetting.
|
||||
The teacher is the model checkpoint after the last experience.
|
||||
"""
|
||||
|
||||
def __init__(self, LWF_lambda=1, temperature=2):
|
||||
"""
|
||||
:param alpha: distillation hyperparameter. It can be either a float
|
||||
number or a list containing alpha for each experience.
|
||||
:param temperature: softmax temperature for distillation
|
||||
"""
|
||||
self.LWF_lambda = LWF_lambda
|
||||
self.temperature = temperature
|
||||
self.previous_logits = {}
|
||||
""" In Avalanche, targets of different experiences are not ordered.
|
||||
As a result, some units may be allocated even though their
|
||||
corresponding class has never been seen by the model.
|
||||
Knowledge distillation uses only units corresponding
|
||||
to old classes.
|
||||
"""
|
||||
|
||||
def adapt(self, output, **kwargs):
|
||||
def modified_kl_div(old, new):
|
||||
return -torch.mean(torch.sum(old * torch.log(new), 1))
|
||||
|
||||
def smooth(logits, temp, dim):
|
||||
log = logits ** (1 / temp)
|
||||
return log / torch.sum(log, dim).unsqueeze(1)
|
||||
|
||||
lwf_loss = []
|
||||
|
||||
soft = torch.nn.Softmax(dim=1)
|
||||
|
||||
previous_keys = self.previous_logits.keys()
|
||||
|
||||
for index, question_id in enumerate(iterable=kwargs["question_ids"]):
|
||||
if question_id in previous_keys:
|
||||
previous_logits = self.previous_logits[question_id]
|
||||
current_logits = output["logits"][index]
|
||||
short_index = min(len(previous_logits), len(current_logits))
|
||||
previous_logits = previous_logits[:short_index]
|
||||
current_logits = current_logits[:short_index]
|
||||
lwf_loss.append(
|
||||
modified_kl_div(
|
||||
old=smooth(
|
||||
logits=soft(previous_logits).to(current_logits.device),
|
||||
temp=2,
|
||||
dim=1,
|
||||
),
|
||||
new=smooth(logits=soft(current_logits), temp=2, dim=1),
|
||||
)
|
||||
)
|
||||
if len(lwf_loss) > 0:
|
||||
output["loss"] += self.LWF_lambda * torch.stack(
|
||||
tensors=lwf_loss, dim=0
|
||||
).sum(dim=0)
|
||||
return output
|
||||
|
||||
def update_previous_logits(self, question_id, logits):
|
||||
"""Update the previous logits for the given question id."""
|
||||
self.previous_logits[question_id] = logits
|
@ -1 +1 @@
|
||||
Subproject commit ebab283576fc8803314314b5e1e4331c424a2198
|
||||
Subproject commit a6e39fafb4e3ccce27671ce14afe62fb754c83c2
|
15
src/scripts/eval_omni.sh
Executable file
15
src/scripts/eval_omni.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml evaluation.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type MOELORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--output_dir ./checkpoint/qwen2_5omni_moelora/ \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16
|
||||
# --eval_strategy epoch \
|
25
src/scripts/train_omni_moelora.sh
Executable file
25
src/scripts/train_omni_moelora.sh
Executable file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type MOELORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 32 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir checkpoint/qwen2_5omni_moelora/ \
|
||||
--learning_rate 2e-4 \
|
||||
--warmup_ratio 0.03 \
|
||||
--lr_scheduler_type cosine \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--logging_steps 300 \
|
||||
--gradient_checkpointing \
|
||||
--weight_decay 0.1 \
|
||||
--eval_strategy steps \
|
||||
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500
|
25
src/scripts/train_omni_olora.sh
Executable file
25
src/scripts/train_omni_olora.sh
Executable file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type OLORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 32 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir checkpoint/qwen2_5omni_olora/ \
|
||||
--learning_rate 2e-4 \
|
||||
--warmup_ratio 0.03 \
|
||||
--lr_scheduler_type cosine \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--logging_steps 300 \
|
||||
--gradient_checkpointing \
|
||||
--weight_decay 0.1 \
|
||||
--eval_strategy steps \
|
||||
# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500
|
31
src/test_evalutae.py
Normal file
31
src/test_evalutae.py
Normal file
@ -0,0 +1,31 @@
|
||||
import evaluate
|
||||
|
||||
# Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, and CIDEr
|
||||
example = {
|
||||
"generated": "The cat sat on the mat.",
|
||||
"target": "The cat is sitting on the mat.",
|
||||
"original": "The cat is sitting on the mat.",
|
||||
}
|
||||
evaluate_bleu = evaluate.load("bleu")
|
||||
evaluate_rouge = evaluate.load("rouge")
|
||||
evaluate_meteor = evaluate.load("meteor")
|
||||
|
||||
evaluate_bleu.add_batch(
|
||||
predictions=[example["generated"]],
|
||||
references=[[example["target"]]],
|
||||
)
|
||||
evaluate_rouge.add_batch(
|
||||
predictions=[example["generated"]],
|
||||
references=[[example["target"]]],
|
||||
)
|
||||
evaluate_meteor.add_batch(
|
||||
predictions=[example["generated"]],
|
||||
references=[[example["target"]]],
|
||||
)
|
||||
|
||||
bleu = evaluate_bleu.compute()
|
||||
rouge = evaluate_rouge.compute()
|
||||
meteor = evaluate_meteor.compute()
|
||||
|
||||
comprehensive_results = sum(bleu['precisions']) + rouge['rougeL'] + meteor['meteor']
|
||||
print("Comprehensive Results:", comprehensive_results/6)
|
27
src/todo.md
27
src/todo.md
@ -17,6 +17,31 @@
|
||||
|
||||
[2025.01.19]
|
||||
|
||||
- [ ] 多个数据集引入
|
||||
- [x] 多个数据集引入
|
||||
- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用)
|
||||
- [ ] 引入EWC和LWF
|
||||
|
||||
[2025.05.15]
|
||||
|
||||
- [x] vizwiz处理
|
||||
|
||||
[2025.05.16]
|
||||
|
||||
- [ ] 处理不同的持续学习框架,使得整体框架能够兼容
|
||||
|
||||
[2025.05.28]
|
||||
|
||||
- [x] MoeLora
|
||||
- [ ] Coin Benchmark
|
||||
- [x] 确定保存什么,便于后期测试
|
||||
- [x] Olora (非实现问题,loss越来越高,感觉很难训练)
|
||||
- [ ] Hide-Llava(复写基类引入clip,不同的adapter做平均,loralinear根据不同的name做插入top layer或正常layer,模型要求接受传入task_id即clip计算的最大相似)
|
||||
- [ ] Hide-llava问题,前些层平均fusion很没有道理,后些层的moe处理,却整整引入了clip的计算量,(任务数确定task数量,使得一些方法没有扩展性)。现实场景要求:没法知道后面还有多少个数据集,然后减少遗忘,最好能够对后续未见数据集产生效果,moelora问题只能适当缓解,利用不同的参数承接不同的任务。 那这个benchmark,每次输入保留数据,baseline是进一个把之前所有的都训练一边,持续学习方法使用update的方式,比较不同数据集按批次输入的收益(找函数定义[How Efficient Are Today’s Continual Learning Algorithms?],[]),也就是准确度的积分。
|
||||
|
||||
[2025.05.30]
|
||||
|
||||
- [x] 评价指标
|
||||
|
||||
[2025.06.03]
|
||||
|
||||
- [ ] 预期算法,低计算成本,
|
||||
|
60
src/train.py
60
src/train.py
@ -13,30 +13,42 @@ from trl import (
|
||||
TrlParser,
|
||||
)
|
||||
from utils.trainer import ContinualTrainer
|
||||
from utils.args import ContinualScriptArguments, ContinualModelConfig
|
||||
from utils.args import (
|
||||
ContinualScriptArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
)
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser(
|
||||
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
|
||||
(
|
||||
ContinualScriptArguments,
|
||||
TrainingArguments,
|
||||
ContinualModelConfig,
|
||||
ContinualRegularizationArguments,
|
||||
) # type: ignore
|
||||
)
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
|
||||
# for type hint
|
||||
if 0 == 1:
|
||||
if TYPE_CHECKING:
|
||||
script_args = ContinualScriptArguments()
|
||||
training_args = TrainingArguments()
|
||||
model_args = ContinualModelConfig()
|
||||
reg_args = ContinualRegularizationArguments()
|
||||
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
from model_library.factory import get_model
|
||||
|
||||
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
||||
model_args
|
||||
model_args=model_args
|
||||
)
|
||||
################
|
||||
# Dataset
|
||||
@ -47,12 +59,19 @@ if __name__ == "__main__":
|
||||
accelerator = create_accelerator_and_postprocess(training_args)
|
||||
|
||||
if model_args.peft_type == "MMOELORA":
|
||||
from peft_library.tuners import MMOELoraConfig
|
||||
from peft.tuners import MMOELoraConfig
|
||||
|
||||
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
elif model_args.peft_type == "MOELORA":
|
||||
from peft.tuners import MOELoraConfig
|
||||
|
||||
peft_config = MOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
elif model_args.peft_type == "LORA":
|
||||
from peft.tuners.lora import LoraConfig
|
||||
|
||||
@ -64,10 +83,25 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
elif model_args.peft_type == "OLORA":
|
||||
from peft.tuners import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(
|
||||
target_modules=model_args.lora_target_modules,
|
||||
r=model_args.lora_r,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
init_lora_weights="olora"
|
||||
)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
else:
|
||||
peft_config = None
|
||||
|
||||
from peft import get_peft_model
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
print(model)
|
||||
|
||||
@ -79,19 +113,21 @@ if __name__ == "__main__":
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn_for_train,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
train_dataset=dataset[script_args.dataset_train_split], # type: ignore
|
||||
eval_dataset=(
|
||||
dataset[script_args.dataset_test_split]
|
||||
dataset[script_args.dataset_test_split] # type: ignore
|
||||
if training_args.eval_strategy != "no"
|
||||
else None
|
||||
),
|
||||
accelerator=accelerator,
|
||||
reg_args=reg_args,
|
||||
)
|
||||
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
print("Saving model")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
# trainer.save_model(training_args.output_dir)
|
||||
model.save_pretrained(training_args.output_dir)
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
print("Model saved")
|
||||
@ -109,6 +145,6 @@ if __name__ == "__main__":
|
||||
# )
|
||||
# val_dataloader = accelerator.prepare(val_dataloader)
|
||||
|
||||
# from utils.evaluate_tool import evaluate_rouge
|
||||
# from utils.evaluate_tool import evaluate_save
|
||||
|
||||
# evaluate_rouge(model, val_dataloader, processor)
|
||||
# evaluate_save(model, val_dataloader, processor, accelerator)
|
||||
|
15
src/train.sh
15
src/train.sh
@ -1,18 +1,19 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
||||
--dataset_name chem \
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||
--dataset_name textvqa \
|
||||
--use_peft \
|
||||
--peft_type LORA \
|
||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||
--peft_type MOELORA \
|
||||
--model_name_or_path Qwen/Qwen2.5-Omni-3B \
|
||||
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 32 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--per_device_train_batch_size 3 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir checkpoint/qwen2_alllinear/ \
|
||||
--learning_rate 1e-4 \
|
||||
--learning_rate 5e-5 \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--logging_steps 30 \
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit 8ee1a4eadda1d83cf65c024fe54364b5bd74e55f
|
||||
Subproject commit 42a8639e1e827d6f0ab07d87078ff048b20dab19
|
@ -18,3 +18,16 @@ class ContinualModelConfig(ModelConfig):
|
||||
"""Model configuration for continual learning."""
|
||||
|
||||
peft_type: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContinualRegularizationArguments:
|
||||
"""Regularization arguments for continual learning."""
|
||||
|
||||
# EWC
|
||||
ewc_lambda: float = 0.0
|
||||
ewc_enable: bool = False
|
||||
|
||||
# LWF
|
||||
lwf_lambda: float = 0.0
|
||||
lwf_enable: bool = False
|
||||
|
@ -1,5 +1,6 @@
|
||||
import evaluate
|
||||
from accelerate import Accelerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||
@ -7,10 +8,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
|
||||
|
||||
for batch in val_dataloader:
|
||||
completion = model.generate(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
pixel_values=batch["pixel_values"],
|
||||
image_grid_thw=batch["image_grid_thw"],
|
||||
**batch,
|
||||
max_length=1000,
|
||||
)
|
||||
target = batch["answers_ids"]
|
||||
@ -27,7 +25,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
|
||||
print(glue.compute())
|
||||
|
||||
|
||||
def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||
def evaluate_save(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||
import os
|
||||
|
||||
mtime = 0
|
||||
@ -53,6 +51,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
answers = []
|
||||
completion = model.generate(
|
||||
**batch,
|
||||
# max_new_tokens=30,
|
||||
max_length=1000,
|
||||
)
|
||||
generated_text = [
|
||||
@ -63,20 +62,17 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
generated_text, skip_special_tokens=True
|
||||
)
|
||||
target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
|
||||
for i in range(len(generated_text)):
|
||||
answers.append(
|
||||
{
|
||||
"generated": generated_text[i],
|
||||
"target": target_text[i],
|
||||
"original": str(origianl[i]),
|
||||
}
|
||||
)
|
||||
import json
|
||||
|
||||
world_size = accelerator.process_index
|
||||
|
||||
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
|
||||
for answer in answers:
|
||||
for i in range(len(generated_text)):
|
||||
answer = {
|
||||
"generated": generated_text[i],
|
||||
"target": target_text[i],
|
||||
"original": str(origianl[i]),
|
||||
}
|
||||
with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
|
||||
f.write(json.dumps(answer) + "\n")
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
@ -97,3 +93,71 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
# delete file
|
||||
for file in files:
|
||||
os.remove(f"results/{mtime}/{file}")
|
||||
|
||||
|
||||
def evaluate_from_jsonl_directory(directory_path):
|
||||
"""
|
||||
从指定目录读取所有jsonl文件并计算综合评估结果
|
||||
|
||||
Args:
|
||||
directory_path: 包含jsonl文件的目录路径
|
||||
|
||||
Returns:
|
||||
dict: 包含各项指标和综合结果的字典
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
|
||||
# 初始化评估器
|
||||
evaluate_bleu = evaluate.load("bleu")
|
||||
evaluate_rouge = evaluate.load("rouge")
|
||||
evaluate_meteor = evaluate.load("meteor")
|
||||
|
||||
# 读取目录下所有jsonl文件
|
||||
all_data = []
|
||||
for file in os.listdir(directory_path):
|
||||
if file.endswith(".jsonl"):
|
||||
file_path = os.path.join(directory_path, file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
data = json.loads(line)
|
||||
all_data.append(data)
|
||||
|
||||
if not all_data:
|
||||
print(f"未在目录 {directory_path} 中找到有效的jsonl数据")
|
||||
return None
|
||||
|
||||
# 准备数据
|
||||
predictions = [item["generated"] for item in all_data]
|
||||
references = [[item["target"]] for item in all_data]
|
||||
|
||||
# 批量添加数据
|
||||
evaluate_bleu.add_batch(predictions=predictions, references=references)
|
||||
evaluate_rouge.add_batch(predictions=predictions, references=references)
|
||||
evaluate_meteor.add_batch(predictions=predictions, references=references)
|
||||
|
||||
# 计算结果
|
||||
bleu = evaluate_bleu.compute()
|
||||
rouge = evaluate_rouge.compute()
|
||||
meteor = evaluate_meteor.compute()
|
||||
|
||||
# 计算综合结果
|
||||
comprehensive_score = (sum(bleu["precisions"]) + rouge["rougeL"] + meteor["meteor"]) / 6
|
||||
|
||||
results = {
|
||||
"bleu": bleu,
|
||||
"rouge": rouge,
|
||||
"meteor": meteor,
|
||||
"comprehensive_score": comprehensive_score,
|
||||
"total_samples": len(all_data),
|
||||
}
|
||||
|
||||
print(f"评估完成,共处理 {len(all_data)} 条数据")
|
||||
print(f"BLEU分数: {bleu}")
|
||||
print(f"ROUGE分数: {rouge}")
|
||||
print(f"METEOR分数: {meteor}")
|
||||
print(f"综合分数: {comprehensive_score}")
|
||||
|
||||
return results
|
||||
|
@ -1,26 +1,73 @@
|
||||
# _________________________________________________________
|
||||
|
||||
|
||||
from transformers.trainer import (
|
||||
Trainer,
|
||||
_is_peft_model,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
tpu_spmd_dataloader,
|
||||
logger,
|
||||
has_length,
|
||||
sys,
|
||||
)
|
||||
from transformers.trainer import *
|
||||
from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
from .args import ContinualRegularizationArguments
|
||||
from peft_library.regularizations import EWC, LWF
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
|
||||
def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
|
||||
logits = outputs.logits
|
||||
device = logits.device
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:].to(device)
|
||||
# Save memory
|
||||
masks = shift_labels != -100
|
||||
shift_logits = shift_logits[masks]
|
||||
shift_labels = shift_labels[masks]
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(reduction="none")
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if num_items_in_batch is None:
|
||||
loss = loss.mean()
|
||||
else:
|
||||
# compat transformers>=4.46
|
||||
loss = loss.sum() / num_items_in_batch
|
||||
return loss
|
||||
|
||||
|
||||
class ContinualTrainer(Trainer):
|
||||
def __init__(
|
||||
self, model, args, data_collator, train_dataset, eval_dataset, accelerator
|
||||
self,
|
||||
model,
|
||||
args: TrainingArguments,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
accelerator,
|
||||
reg_args: ContinualRegularizationArguments = None,
|
||||
):
|
||||
self.accelerator = accelerator
|
||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
# compute_loss_func=ce_loss_func,
|
||||
)
|
||||
|
||||
if reg_args.ewc_enable:
|
||||
self.ewc_lambda = reg_args.ewc_lambda
|
||||
from peft_library.regularizations.ewc import EWC
|
||||
|
||||
self.EWC = EWC()
|
||||
# fisher = t
|
||||
|
||||
if reg_args.lwf_enable:
|
||||
self.lwf_lambda = reg_args.lwf_lambda
|
||||
from peft_library.regularizations.lwf import LWF
|
||||
|
||||
self.LWF = LWF()
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
|
||||
if self.accelerator is not None:
|
||||
self.is_deepspeed_enabled = (
|
||||
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
@ -32,718 +79,79 @@ class ContinualTrainer(Trainer):
|
||||
return
|
||||
else:
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
|
||||
def create_optimizer(self):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
Setup the optimizer.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
||||
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
||||
"""
|
||||
if (
|
||||
self.label_smoother is not None or self.compute_loss_func is not None
|
||||
) and "labels" in inputs:
|
||||
labels = inputs.pop("labels")
|
||||
else:
|
||||
labels = None
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if labels is not None:
|
||||
unwrapped_model = self.accelerator.unwrap_model(model)
|
||||
if _is_peft_model(unwrapped_model):
|
||||
model_name = unwrapped_model.base_model.model._get_name()
|
||||
else:
|
||||
model_name = unwrapped_model._get_name()
|
||||
# User-defined compute_loss function
|
||||
if self.compute_loss_func is not None:
|
||||
loss = self.compute_loss_func(
|
||||
outputs, labels, num_items_in_batch=num_items_in_batch
|
||||
)
|
||||
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
||||
loss = self.label_smoother(outputs, labels, shift_labels=True)
|
||||
else:
|
||||
loss = self.label_smoother(outputs, labels)
|
||||
else:
|
||||
if isinstance(outputs, dict) and "loss" not in outputs:
|
||||
raise ValueError(
|
||||
"The model did not return a loss from the inputs, only the following keys: "
|
||||
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
||||
)
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
def _inner_training_loop(
|
||||
self,
|
||||
batch_size=None,
|
||||
args=None,
|
||||
resume_from_checkpoint=None,
|
||||
trial=None,
|
||||
ignore_keys_for_eval=None,
|
||||
):
|
||||
self.accelerator.free_memory()
|
||||
self._train_batch_size = batch_size
|
||||
if self.args.auto_find_batch_size:
|
||||
if self.state.train_batch_size != self._train_batch_size:
|
||||
from accelerate.utils import release_memory
|
||||
|
||||
(self.model_wrapped,) = release_memory(self.model_wrapped)
|
||||
self.model_wrapped = self.model
|
||||
|
||||
# Check for DeepSpeed *after* the intial pass and modify the config
|
||||
if self.is_deepspeed_enabled:
|
||||
# Temporarily unset `self.args.train_batch_size`
|
||||
original_bs = self.args.per_device_train_batch_size
|
||||
self.args.per_device_train_batch_size = (
|
||||
self._train_batch_size // max(1, self.args.n_gpu)
|
||||
)
|
||||
self.propagate_args_to_deepspeed(True)
|
||||
self.args.per_device_train_batch_size = original_bs
|
||||
self.state.train_batch_size = self._train_batch_size
|
||||
logger.debug(
|
||||
f"Currently training with a batch size of: {self._train_batch_size}"
|
||||
)
|
||||
# Data loader and number of training steps
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
if self.is_fsdp_xla_v2_enabled:
|
||||
train_dataloader = tpu_spmd_dataloader(train_dataloader)
|
||||
|
||||
# Setting up training control variables:
|
||||
# number of training epochs: num_train_epochs
|
||||
# number of training steps per epoch: num_update_steps_per_epoch
|
||||
# total number of training steps to execute: max_steps
|
||||
total_train_batch_size = (
|
||||
self._train_batch_size * args.gradient_accumulation_steps * args.world_size
|
||||
)
|
||||
|
||||
len_dataloader = None
|
||||
num_train_tokens = None
|
||||
if has_length(train_dataloader):
|
||||
len_dataloader = len(train_dataloader)
|
||||
num_update_steps_per_epoch = (
|
||||
len_dataloader // args.gradient_accumulation_steps
|
||||
)
|
||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||
num_examples = self.num_examples(train_dataloader)
|
||||
if args.max_steps > 0:
|
||||
max_steps = args.max_steps
|
||||
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
||||
args.max_steps % num_update_steps_per_epoch > 0
|
||||
)
|
||||
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
|
||||
# the best we can do.
|
||||
num_train_samples = args.max_steps * total_train_batch_size
|
||||
if args.include_tokens_per_second:
|
||||
num_train_tokens = (
|
||||
self.num_tokens(train_dataloader, args.max_steps)
|
||||
* args.gradient_accumulation_steps
|
||||
)
|
||||
else:
|
||||
max_steps = math.ceil(
|
||||
args.num_train_epochs * num_update_steps_per_epoch
|
||||
)
|
||||
num_train_epochs = math.ceil(args.num_train_epochs)
|
||||
num_train_samples = (
|
||||
self.num_examples(train_dataloader) * args.num_train_epochs
|
||||
)
|
||||
if args.include_tokens_per_second:
|
||||
num_train_tokens = (
|
||||
self.num_tokens(train_dataloader) * args.num_train_epochs
|
||||
)
|
||||
elif (
|
||||
args.max_steps > 0
|
||||
): # Rely on max_steps when dataloader does not have a working size
|
||||
max_steps = args.max_steps
|
||||
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
||||
num_train_epochs = sys.maxsize
|
||||
num_update_steps_per_epoch = max_steps
|
||||
num_examples = total_train_batch_size * args.max_steps
|
||||
num_train_samples = args.max_steps * total_train_batch_size
|
||||
if args.include_tokens_per_second:
|
||||
num_train_tokens = (
|
||||
self.num_tokens(train_dataloader, args.max_steps)
|
||||
* args.gradient_accumulation_steps
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
|
||||
f" {args.max_steps}"
|
||||
)
|
||||
|
||||
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
||||
if self.args.n_gpu > 1:
|
||||
# nn.DataParallel(model) replicates the model, creating new variables and module
|
||||
# references registered here no longer work on other gpus, breaking the module
|
||||
raise ValueError(
|
||||
"Currently --debug underflow_overflow is not supported under DP. Please use DDP"
|
||||
" (torchrun or torch.distributed.launch (deprecated))."
|
||||
)
|
||||
else:
|
||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||
|
||||
delay_optimizer_creation = (
|
||||
is_sagemaker_mp_enabled()
|
||||
or self.is_fsdp_xla_enabled
|
||||
or self.is_fsdp_enabled
|
||||
)
|
||||
|
||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
||||
if self._created_lr_scheduler:
|
||||
self.lr_scheduler = None
|
||||
self._created_lr_scheduler = False
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.optimizer, self.lr_scheduler = deepspeed_init(
|
||||
self, num_training_steps=max_steps
|
||||
)
|
||||
|
||||
if not delay_optimizer_creation:
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
|
||||
self.state = TrainerState(
|
||||
stateful_callbacks=[
|
||||
cb
|
||||
for cb in self.callback_handler.callbacks + [self.control]
|
||||
if isinstance(cb, ExportableState)
|
||||
if self.optimizer is None:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" not in n)
|
||||
],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": self.args.learning_rate,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and "merger" in n)
|
||||
],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": self.args.learning_rate / 10,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
"lr": self.args.learning_rate,
|
||||
},
|
||||
]
|
||||
)
|
||||
self.state.is_hyper_param_search = trial is not None
|
||||
self.state.train_batch_size = self._train_batch_size
|
||||
|
||||
# Compute absolute values for logging, eval, and save if given as ratio
|
||||
if args.logging_steps is not None:
|
||||
if args.logging_steps < 1:
|
||||
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
self.state.logging_steps = args.logging_steps
|
||||
if args.eval_steps is not None:
|
||||
if args.eval_steps < 1:
|
||||
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
|
||||
else:
|
||||
self.state.eval_steps = args.eval_steps
|
||||
if args.save_steps is not None:
|
||||
if args.save_steps < 1:
|
||||
self.state.save_steps = math.ceil(max_steps * args.save_steps)
|
||||
else:
|
||||
self.state.save_steps = args.save_steps
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
|
||||
|
||||
# Activate gradient checkpointing if needed
|
||||
if args.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
|
||||
)
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
model = self._wrap_model(self.model_wrapped)
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# as the model is wrapped, don't use `accelerator.prepare`
|
||||
# this is for unhandled cases such as
|
||||
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
||||
use_accelerator_prepare = True if model is self.model else False
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
|
||||
|
||||
if use_accelerator_prepare and self.is_fsdp_enabled:
|
||||
# In case of auto_find_batch_size=True
|
||||
# Remove FSDP wrapping from sub-models.
|
||||
self.model = unwrap_model(self.model, recursive=True)
|
||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
|
||||
if delay_optimizer_creation:
|
||||
if use_accelerator_prepare:
|
||||
# configure fsdp plugin for qlora if any
|
||||
self._fsdp_qlora_plugin_updates()
|
||||
if self.accelerator.mixed_precision != "fp8":
|
||||
self.model = self.accelerator.prepare(self.model)
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
|
||||
import bitsandbytes
|
||||
|
||||
# prepare using `accelerator` prepare
|
||||
if use_accelerator_prepare:
|
||||
self.model.train()
|
||||
if hasattr(self.lr_scheduler, "step"):
|
||||
if self.use_apex:
|
||||
model = self.accelerator.prepare(self.model)
|
||||
else:
|
||||
model, self.optimizer = self.accelerator.prepare(
|
||||
self.model, self.optimizer
|
||||
)
|
||||
else:
|
||||
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
|
||||
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||
# In this case we are in DDP + LOMO, which should be supported
|
||||
self.optimizer = self.accelerator.prepare(self.optimizer)
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
self.model = self.model_wrapped = model
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
||||
logger.info(f"skipped {module}: {skipped / 2**20}M params")
|
||||
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
||||
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
logger.info(f"skipped: {skipped / 2**20}M params")
|
||||
|
||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||
if model is not self.model:
|
||||
self.model_wrapped = model
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
||||
|
||||
# backward compatibility
|
||||
if self.is_deepspeed_enabled:
|
||||
self.deepspeed = self.model_wrapped
|
||||
|
||||
# ckpt loading
|
||||
if resume_from_checkpoint is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
deepspeed_load_checkpoint(
|
||||
self.model_wrapped,
|
||||
resume_from_checkpoint,
|
||||
load_module_strict=not _is_peft_model(self.model),
|
||||
)
|
||||
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
|
||||
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||
|
||||
# important: at this point:
|
||||
# self.model is the Transformers Model
|
||||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
|
||||
# FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {num_examples:,}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs:,}")
|
||||
logger.info(
|
||||
f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}"
|
||||
)
|
||||
if self.args.per_device_train_batch_size != self._train_batch_size:
|
||||
logger.info(
|
||||
f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}"
|
||||
)
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_steps:,}")
|
||||
logger.info(
|
||||
f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}"
|
||||
)
|
||||
|
||||
self.state.epoch = 0
|
||||
start_time = time.time()
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
steps_trained_progress_bar = None
|
||||
|
||||
# Check if continuing training from a checkpoint
|
||||
if resume_from_checkpoint is not None and os.path.isfile(
|
||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
||||
):
|
||||
self.state = TrainerState.load_from_json(
|
||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
||||
)
|
||||
self.compare_trainer_and_checkpoint_args(self.args, self.state)
|
||||
self._load_callback_state()
|
||||
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
|
||||
if not args.ignore_data_skip:
|
||||
steps_trained_in_current_epoch = self.state.global_step % (
|
||||
num_update_steps_per_epoch
|
||||
)
|
||||
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
|
||||
else:
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
logger.info(
|
||||
" Continuing training from checkpoint, will skip to saved global_step"
|
||||
)
|
||||
logger.info(f" Continuing training from epoch {epochs_trained}")
|
||||
logger.info(
|
||||
f" Continuing training from global step {self.state.global_step}"
|
||||
)
|
||||
if not args.ignore_data_skip:
|
||||
logger.info(
|
||||
f" Will skip the first {epochs_trained} epochs then the first"
|
||||
f" {steps_trained_in_current_epoch} batches in the first epoch."
|
||||
)
|
||||
|
||||
# Update the references
|
||||
self.callback_handler.model = self.model
|
||||
self.callback_handler.optimizer = self.optimizer
|
||||
self.callback_handler.lr_scheduler = self.lr_scheduler
|
||||
self.callback_handler.train_dataloader = train_dataloader
|
||||
if self.hp_name is not None and self._trial is not None:
|
||||
# use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
|
||||
# parameter to Train when using DDP.
|
||||
self.state.trial_name = self.hp_name(self._trial)
|
||||
if trial is not None:
|
||||
assignments = (
|
||||
trial.assignments
|
||||
if self.hp_search_backend == HPSearchBackend.SIGOPT
|
||||
else trial
|
||||
)
|
||||
self.state.trial_params = hp_params(assignments)
|
||||
else:
|
||||
self.state.trial_params = None
|
||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||
# to set this after the load.
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||
|
||||
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
|
||||
tr_loss = torch.tensor(0.0).to(args.device)
|
||||
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
|
||||
self._total_loss_scalar = 0.0
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
model.zero_grad()
|
||||
grad_norm: Optional[float] = None
|
||||
self.control = self.callback_handler.on_train_begin(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
if args.eval_on_start:
|
||||
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
|
||||
|
||||
for epoch in range(epochs_trained, num_train_epochs):
|
||||
epoch_dataloader = train_dataloader
|
||||
if hasattr(epoch_dataloader, "set_epoch"):
|
||||
epoch_dataloader.set_epoch(epoch)
|
||||
|
||||
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||
if args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
steps_in_epoch = (
|
||||
len(epoch_dataloader)
|
||||
if len_dataloader is not None
|
||||
else args.max_steps * args.gradient_accumulation_steps
|
||||
)
|
||||
self.control = self.callback_handler.on_epoch_begin(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
if (
|
||||
epoch == epochs_trained
|
||||
and resume_from_checkpoint is not None
|
||||
and steps_trained_in_current_epoch == 0
|
||||
):
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
|
||||
rng_to_sync = False
|
||||
steps_skipped = 0
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
epoch_dataloader = skip_first_batches(
|
||||
epoch_dataloader, steps_trained_in_current_epoch
|
||||
)
|
||||
steps_skipped = steps_trained_in_current_epoch
|
||||
steps_trained_in_current_epoch = 0
|
||||
rng_to_sync = True
|
||||
|
||||
step = -1
|
||||
epoch_iterator = iter(epoch_dataloader)
|
||||
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
|
||||
remainder = num_examples % args.gradient_accumulation_steps
|
||||
if remainder == 0:
|
||||
remainder = args.gradient_accumulation_steps
|
||||
update_step = -1
|
||||
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
|
||||
for _ in range(total_updates):
|
||||
update_step += 1
|
||||
num_batches = (
|
||||
args.gradient_accumulation_steps
|
||||
if update_step != (total_updates - 1)
|
||||
else remainder
|
||||
)
|
||||
batch_samples, num_items_in_batch = self.get_batch_samples(
|
||||
epoch_iterator, num_batches
|
||||
)
|
||||
for i, inputs in enumerate(batch_samples):
|
||||
step += 1
|
||||
do_sync_step = (
|
||||
step + 1
|
||||
) % args.gradient_accumulation_steps == 0 or (
|
||||
step + 1
|
||||
) == steps_in_epoch
|
||||
# Since we perform prefetching, we need to manually set sync_gradients
|
||||
if not do_sync_step:
|
||||
self.accelerator.gradient_state._set_sync_gradients(False)
|
||||
else:
|
||||
self.accelerator.gradient_state._set_sync_gradients(True)
|
||||
|
||||
if self.args.include_num_input_tokens_seen:
|
||||
main_input_name = getattr(
|
||||
self.model, "main_input_name", "input_ids"
|
||||
)
|
||||
if main_input_name not in inputs:
|
||||
logger.warning(
|
||||
"Tried to track the number of tokens seen, however the current model is "
|
||||
"not configured properly to know what item is the input. To fix this, add "
|
||||
"a `main_input_name` attribute to the model class you are using."
|
||||
)
|
||||
else:
|
||||
input_tokens = inputs[main_input_name].numel()
|
||||
input_tokens = torch.tensor(
|
||||
input_tokens, device=self.args.device, dtype=torch.int64
|
||||
)
|
||||
self.state.num_input_tokens_seen += (
|
||||
self.accelerator.gather(input_tokens).sum().cpu().item()
|
||||
)
|
||||
if rng_to_sync:
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
rng_to_sync = False
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
if steps_trained_progress_bar is not None:
|
||||
steps_trained_progress_bar.update(1)
|
||||
if steps_trained_in_current_epoch == 0:
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
continue
|
||||
elif steps_trained_progress_bar is not None:
|
||||
steps_trained_progress_bar.close()
|
||||
steps_trained_progress_bar = None
|
||||
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
self.control = self.callback_handler.on_step_begin(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i != len(batch_samples) - 1
|
||||
and self.accelerator.distributed_type
|
||||
!= DistributedType.DEEPSPEED
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
with context():
|
||||
tr_loss_step = self.training_step(
|
||||
model, inputs, num_items_in_batch
|
||||
)
|
||||
|
||||
if (
|
||||
args.logging_nan_inf_filter
|
||||
and not is_torch_xla_available()
|
||||
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
|
||||
):
|
||||
# if loss is nan or inf simply add the average of previous logged losses
|
||||
tr_loss = tr_loss + tr_loss / (
|
||||
1 + self.state.global_step - self._globalstep_last_logged
|
||||
)
|
||||
else:
|
||||
if tr_loss.device != tr_loss_step.device:
|
||||
raise ValueError(
|
||||
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
|
||||
)
|
||||
tr_loss = tr_loss + tr_loss_step
|
||||
|
||||
self.current_flos += float(self.floating_point_ops(inputs))
|
||||
|
||||
if do_sync_step:
|
||||
# Since we perform prefetching, we need to manually set sync_gradients to True
|
||||
self.accelerator.gradient_state._set_sync_gradients(True)
|
||||
|
||||
# Gradient clipping
|
||||
if args.max_grad_norm is not None and args.max_grad_norm > 0:
|
||||
# deepspeed does its own clipping
|
||||
|
||||
if is_sagemaker_mp_enabled() and args.fp16:
|
||||
_grad_norm = self.optimizer.clip_master_grads(
|
||||
args.max_grad_norm
|
||||
)
|
||||
elif self.use_apex:
|
||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||
_grad_norm = nn.utils.clip_grad_norm_(
|
||||
amp.master_params(self.optimizer),
|
||||
args.max_grad_norm,
|
||||
)
|
||||
else:
|
||||
_grad_norm = self.accelerator.clip_grad_norm_(
|
||||
model.parameters(),
|
||||
args.max_grad_norm,
|
||||
)
|
||||
|
||||
if (
|
||||
is_accelerate_available()
|
||||
and self.accelerator.distributed_type
|
||||
== DistributedType.DEEPSPEED
|
||||
):
|
||||
grad_norm = model.get_global_grad_norm()
|
||||
# In some cases the grad norm may not return a float
|
||||
if hasattr(grad_norm, "item"):
|
||||
grad_norm = grad_norm.item()
|
||||
else:
|
||||
grad_norm = _grad_norm
|
||||
|
||||
self.control = self.callback_handler.on_pre_optimizer_step(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
|
||||
self.control = self.callback_handler.on_optimizer_step(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
optimizer_was_run = (
|
||||
not self.accelerator.optimizer_step_was_skipped
|
||||
)
|
||||
if optimizer_was_run:
|
||||
# Delay optimizer scheduling until metrics are generated
|
||||
if not isinstance(
|
||||
self.lr_scheduler,
|
||||
torch.optim.lr_scheduler.ReduceLROnPlateau,
|
||||
):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
model.zero_grad()
|
||||
self.state.global_step += 1
|
||||
self.state.epoch = (
|
||||
epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
||||
)
|
||||
self.control = self.callback_handler.on_step_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
self._maybe_log_save_evaluate(
|
||||
tr_loss,
|
||||
grad_norm,
|
||||
model,
|
||||
trial,
|
||||
epoch,
|
||||
ignore_keys_for_eval,
|
||||
start_time,
|
||||
)
|
||||
else:
|
||||
self.control = self.callback_handler.on_substep_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
# PyTorch/XLA relies on the data loader to insert the mark_step for
|
||||
# each step. Since we are breaking the loop early, we need to manually
|
||||
# insert the mark_step here.
|
||||
if (
|
||||
self.control.should_epoch_stop
|
||||
or self.control.should_training_stop
|
||||
):
|
||||
if is_torch_xla_available():
|
||||
xm.mark_step()
|
||||
break
|
||||
# We also need to break out of the nested loop
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
if is_torch_xla_available():
|
||||
xm.mark_step()
|
||||
break
|
||||
if step < 0:
|
||||
logger.warning(
|
||||
"There seems not to be a single sample in your epoch_iterator, stopping training at step"
|
||||
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
|
||||
f" num_steps ({max_steps}) higher than the number of available samples."
|
||||
)
|
||||
self.control.should_training_stop = True
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
self._maybe_log_save_evaluate(
|
||||
tr_loss,
|
||||
grad_norm,
|
||||
model,
|
||||
trial,
|
||||
epoch,
|
||||
ignore_keys_for_eval,
|
||||
start_time,
|
||||
)
|
||||
|
||||
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
||||
if is_torch_xla_available():
|
||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||
xm.master_print(met.metrics_report())
|
||||
else:
|
||||
logger.warning(
|
||||
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
||||
"configured. Check your training configuration if this is unexpected."
|
||||
)
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of training
|
||||
delattr(self, "_past")
|
||||
|
||||
logger.info(
|
||||
"\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
|
||||
)
|
||||
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
||||
# Wait for everyone to get here so we are sure the model has been saved by process 0.
|
||||
if is_torch_xla_available():
|
||||
xm.rendezvous("load_best_model_at_end")
|
||||
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
dist.barrier()
|
||||
elif is_sagemaker_mp_enabled():
|
||||
smp.barrier()
|
||||
|
||||
self._load_best_model()
|
||||
|
||||
# add remaining tr_loss
|
||||
self._total_loss_scalar += tr_loss.item()
|
||||
effective_global_step = max(
|
||||
self.state.global_step, 0.001
|
||||
) # Avoid ZeroDivisionError
|
||||
train_loss = self._total_loss_scalar / effective_global_step
|
||||
|
||||
metrics = speed_metrics(
|
||||
"train",
|
||||
start_time,
|
||||
num_samples=num_train_samples,
|
||||
num_steps=self.state.max_steps,
|
||||
num_tokens=num_train_tokens,
|
||||
)
|
||||
self.store_flos()
|
||||
metrics["total_flos"] = self.state.total_flos
|
||||
metrics["train_loss"] = train_loss
|
||||
|
||||
self.is_in_train = False
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
|
||||
self.log(metrics)
|
||||
|
||||
run_dir = self._get_output_dir(trial)
|
||||
checkpoints_sorted = self._sorted_checkpoints(
|
||||
use_mtime=False, output_dir=run_dir
|
||||
)
|
||||
|
||||
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
|
||||
if (
|
||||
self.args.should_save
|
||||
and self.state.best_model_checkpoint is not None
|
||||
and self.args.save_total_limit == 1
|
||||
):
|
||||
for checkpoint in checkpoints_sorted:
|
||||
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
|
||||
logger.info(
|
||||
f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit"
|
||||
)
|
||||
shutil.rmtree(checkpoint, ignore_errors=True)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(
|
||||
args, self.state, self.control
|
||||
)
|
||||
|
||||
# Wait for the checkpoint to be uploaded.
|
||||
self._finish_current_push()
|
||||
|
||||
# After training we make sure to retrieve back the original forward pass method
|
||||
# for the embedding layer by removing the forward post hook.
|
||||
if self.neftune_noise_alpha is not None:
|
||||
self._deactivate_neftune(self.model)
|
||||
|
||||
return TrainOutput(self.state.global_step, train_loss, metrics)
|
||||
return self.optimizer
|
||||
|
Loading…
Reference in New Issue
Block a user