Compare commits
No commits in common. "7b9349091e9ad45b9251c9533832bfb7ca15898b" and "a38ccf704231db6f8188416a4362d92daf4ed42e" have entirely different histories.
7b9349091e
...
a38ccf7042
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,3 @@
|
|||||||
**/.venv/*
|
**/.venv/*
|
||||||
**/__pycache__/*
|
**/__pycache__/*
|
||||||
rsync.sh
|
rsync.sh
|
||||||
.pytest_cache/
|
|
||||||
|
|||||||
@ -1,11 +0,0 @@
|
|||||||
repos:
|
|
||||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
|
||||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
|
||||||
rev: 24.8.0
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
# It is recommended to specify the latest version of Python
|
|
||||||
# supported by your project here, or alternatively use
|
|
||||||
# pre-commit's default_language_version, see
|
|
||||||
# https://pre-commit.com/#top_level-default_language_version
|
|
||||||
language_version: python3.11
|
|
||||||
@ -11,8 +11,6 @@ dependencies = [
|
|||||||
"numba>=0.60.0",
|
"numba>=0.60.0",
|
||||||
"peft==0.14.0",
|
"peft==0.14.0",
|
||||||
"pip==24.3.1",
|
"pip==24.3.1",
|
||||||
"pre-commit>=4.0.1",
|
|
||||||
"pytest>=8.3.4",
|
|
||||||
"requests==2.32.3",
|
"requests==2.32.3",
|
||||||
"rouge-score>=0.1.2",
|
"rouge-score>=0.1.2",
|
||||||
"safetensors>=0.5.2",
|
"safetensors>=0.5.2",
|
||||||
@ -32,14 +30,16 @@ requires-python = ">=3.11"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
compile = ["flash-attn>=2.7.2.post1"]
|
compile = [
|
||||||
|
"flash-attn>=2.7.2.post1",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
markupsafe = { index = "pytorch" }
|
markupsafe = {index = "pytorch"}
|
||||||
requests = { index = "pypi" }
|
requests = {index = "pypi"}
|
||||||
torch = { index = "pytorch" }
|
torch = {index = "pytorch"}
|
||||||
torchaudio = { index = "pytorch" }
|
torchaudio = {index = "pytorch"}
|
||||||
torchvision = { index = "pytorch" }
|
torchvision = {index = "pytorch"}
|
||||||
|
|
||||||
[[tool.uv.index]]
|
[[tool.uv.index]]
|
||||||
name = "pypi"
|
name = "pypi"
|
||||||
@ -52,15 +52,3 @@ concurrent-builds = 4
|
|||||||
[[tool.uv.index]]
|
[[tool.uv.index]]
|
||||||
name = "pytorch"
|
name = "pytorch"
|
||||||
url = "https://download.pytorch.org/whl/cu124"
|
url = "https://download.pytorch.org/whl/cu124"
|
||||||
|
|
||||||
[tool.black]
|
|
||||||
line-length = 88
|
|
||||||
exclude = "transformers_repo|peft_repo|.venv"
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
addopts = ["--color=yes", "--durations=0", "-v", "--capture=tee-sys"]
|
|
||||||
norecursedirs = [
|
|
||||||
"src/transformers_repo",
|
|
||||||
"src/peft_repo",
|
|
||||||
".venv",
|
|
||||||
]
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
class ChemDataset(Dataset):
|
class CHEMDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||||
):
|
):
|
||||||
@ -25,9 +25,9 @@ class ChemDataset(Dataset):
|
|||||||
|
|
||||||
def _vis_processor(self, image: Image.Image):
|
def _vis_processor(self, image: Image.Image):
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
if width > 800 or height > 800:
|
if width > 600 or height > 600:
|
||||||
max_size = max(width, height)
|
max_size = max(width, height)
|
||||||
ratio = 800 / max_size
|
ratio = 600 / max_size
|
||||||
new_width = int(width * ratio)
|
new_width = int(width * ratio)
|
||||||
new_height = int(height * ratio)
|
new_height = int(height * ratio)
|
||||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||||
@ -69,7 +69,7 @@ class ChemDataset(Dataset):
|
|||||||
return processed_data
|
return processed_data
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data) // 60
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = self.data[index]
|
sample = self.data[index]
|
||||||
@ -112,7 +112,7 @@ class ChemDataset(Dataset):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ChemDatasetForGeneration(ChemDataset):
|
class CHEMDatasetForGeneration(CHEMDataset):
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = self.data[index]
|
sample = self.data[index]
|
||||||
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||||
@ -147,11 +147,27 @@ class ChemDatasetForGeneration(ChemDataset):
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
from .format import DatasetOutput
|
return {
|
||||||
|
"image": image,
|
||||||
|
"chat": chat,
|
||||||
|
"answer": answer,
|
||||||
|
"original": sample["original"],
|
||||||
|
}
|
||||||
|
|
||||||
return DatasetOutput(
|
|
||||||
images=[image],
|
if __name__ == "__main__":
|
||||||
chat=chat,
|
dataset = CHEMDataset(
|
||||||
answer=answer,
|
"/home/zyy/research/accelerate/dataset/chem/images",
|
||||||
original=sample["original"],
|
"/home/zyy/research/accelerate/dataset/chem/qwen_data",
|
||||||
|
split="train",
|
||||||
)
|
)
|
||||||
|
print(len(dataset))
|
||||||
|
print(dataset[0])
|
||||||
|
dataset = CHEMDatasetForGeneration(
|
||||||
|
"/home/zyy/research/accelerate/dataset/chem/images",
|
||||||
|
"/home/zyy/research/accelerate/dataset/chem/qwen_data",
|
||||||
|
split="train",
|
||||||
|
)
|
||||||
|
print(len(dataset))
|
||||||
|
print(dataset[0])
|
||||||
|
pass
|
||||||
@ -1,11 +1,7 @@
|
|||||||
from .format import (
|
from PIL import Image
|
||||||
Conversation,
|
|
||||||
ConverstationAudio,
|
|
||||||
ConverstationImage,
|
|
||||||
ConverstationText,
|
|
||||||
DatasetOutput,
|
|
||||||
)
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
import json
|
||||||
|
import os
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
@ -36,25 +32,22 @@ class GigaspeechDataset(Dataset):
|
|||||||
text = self.text_processor(text)
|
text = self.text_processor(text)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
Conversation(
|
{
|
||||||
role="user",
|
"role": "user",
|
||||||
content=[
|
"content": [
|
||||||
ConverstationAudio(type="audio", audio_url=""),
|
{"type": "audio", "audio_url": ""},
|
||||||
ConverstationText(
|
{
|
||||||
type="text", text="Please convert the audio to text"
|
"type": "text",
|
||||||
),
|
"text": "Please convert the audio to text",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
),
|
},
|
||||||
Conversation(
|
{"role": "assistant", "content": [{"type": "text", "text": text}]},
|
||||||
role="assistant", content=[ConverstationText(type="text", text=text)]
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
return {
|
||||||
return DatasetOutput(
|
"audio": (audio, sampling_rate),
|
||||||
audio=[(audio, sampling_rate)],
|
"chat": chat,
|
||||||
chat=chat,
|
}
|
||||||
original=sample,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||||
@ -71,35 +64,34 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
|||||||
text = self.text_processor(text)
|
text = self.text_processor(text)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
Conversation(
|
{
|
||||||
role="user",
|
"role": "user",
|
||||||
content=[
|
"content": [
|
||||||
ConverstationAudio(type="audio", audio_url=""),
|
{"type": "audio", "audio_url": ""},
|
||||||
ConverstationText(
|
{
|
||||||
type="text", text="Please convert the audio to text"
|
"type": "text",
|
||||||
),
|
"text": "Please convert the audio to text",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
),
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
return DatasetOutput(
|
return {
|
||||||
audio=[(audio, sampling_rate)],
|
"audio": (audio, sampling_rate),
|
||||||
chat=chat,
|
"chat": chat,
|
||||||
answer=text,
|
"answer": text,
|
||||||
original=sample,
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gigaspeech():
|
if __name__ == "__main__":
|
||||||
dataset = GigaspeechDataset(
|
dataset = GigaspeechDataset(
|
||||||
split="train",
|
split="train",
|
||||||
)
|
)
|
||||||
|
print(len(dataset))
|
||||||
print(dataset[0])
|
print(dataset[0])
|
||||||
assert len(dataset) > 0
|
|
||||||
assert len(dataset[0]["chat"]) > 0
|
|
||||||
dataset = GigaspeechDatasetForGeneration(
|
dataset = GigaspeechDatasetForGeneration(
|
||||||
split="train",
|
split="train",
|
||||||
)
|
)
|
||||||
|
print(len(dataset))
|
||||||
print(dataset[0])
|
print(dataset[0])
|
||||||
assert len(dataset) > 0
|
pass
|
||||||
assert len(dataset[0]["chat"]) > 0
|
|
||||||
|
|||||||
@ -1,15 +1,7 @@
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from .format import (
|
|
||||||
Conversation,
|
|
||||||
ConverstationAudio,
|
|
||||||
ConverstationImage,
|
|
||||||
ConverstationText,
|
|
||||||
DatasetOutput,
|
|
||||||
)
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class OCRVQADataset(Dataset):
|
class OCRVQADataset(Dataset):
|
||||||
@ -27,9 +19,9 @@ class OCRVQADataset(Dataset):
|
|||||||
)
|
)
|
||||||
self.text_processor = text_processor
|
self.text_processor = text_processor
|
||||||
if split == "train":
|
if split == "train":
|
||||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=1)
|
self.data = self.create_data(ann_path, split=1)
|
||||||
elif split == "test":
|
elif split == "test":
|
||||||
self.data = self.create_data(Path(ann_path, "dataset.json"), split=3)
|
self.data = self.create_data(ann_path, split=3)
|
||||||
|
|
||||||
# self.instruction_pool = [
|
# self.instruction_pool = [
|
||||||
# "[vqa] {}",
|
# "[vqa] {}",
|
||||||
@ -56,7 +48,6 @@ class OCRVQADataset(Dataset):
|
|||||||
"image_id": k,
|
"image_id": k,
|
||||||
"title": data[k]["title"],
|
"title": data[k]["title"],
|
||||||
"genre": data[k]["genre"],
|
"genre": data[k]["genre"],
|
||||||
"original": data[k],
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return processed_data
|
return processed_data
|
||||||
@ -98,26 +89,22 @@ class OCRVQADataset(Dataset):
|
|||||||
answer = self.text_processor(answer)
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
Conversation(
|
{
|
||||||
role="user",
|
"role": "user",
|
||||||
content=[
|
"content": [
|
||||||
ConverstationImage(type="image", image_url=""),
|
{"type": "image"},
|
||||||
ConverstationText(
|
{
|
||||||
type="text",
|
"type": "text",
|
||||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
),
|
},
|
||||||
],
|
],
|
||||||
),
|
},
|
||||||
Conversation(
|
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
|
||||||
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
return {
|
||||||
return DatasetOutput(
|
"image": image,
|
||||||
chat=chat,
|
"chat": chat,
|
||||||
original=sample["original"],
|
}
|
||||||
images=[image],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||||
@ -137,20 +124,20 @@ class OCRVQADatasetForGeneration(OCRVQADataset):
|
|||||||
answer = self.text_processor(answer)
|
answer = self.text_processor(answer)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
Conversation(
|
{
|
||||||
role="user",
|
"role": "user",
|
||||||
content=[
|
"content": [
|
||||||
ConverstationImage(type="image", image_url=""),
|
{"type": "image"},
|
||||||
ConverstationText(
|
{
|
||||||
type="text",
|
"type": "text",
|
||||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||||
),
|
},
|
||||||
],
|
],
|
||||||
),
|
}
|
||||||
|
# {"role": "assistant", "content": answer},
|
||||||
]
|
]
|
||||||
return DatasetOutput(
|
return {
|
||||||
images=[image],
|
"image": image,
|
||||||
chat=chat,
|
"chat": chat,
|
||||||
answer=answer,
|
"answer": answer,
|
||||||
original=sample["original"],
|
}
|
||||||
)
|
|
||||||
@ -1,173 +0,0 @@
|
|||||||
from PIL import Image
|
|
||||||
from .format import (
|
|
||||||
Conversation,
|
|
||||||
ConverstationAudio,
|
|
||||||
ConverstationImage,
|
|
||||||
ConverstationText,
|
|
||||||
DatasetOutput,
|
|
||||||
)
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import json
|
|
||||||
import os.path as osp
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class TextVQADataset(Dataset):
|
|
||||||
def __init__(
|
|
||||||
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
|
||||||
ann_root (string): directory to store the annotation file
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.vis_processor = (
|
|
||||||
vis_processor if vis_processor is not None else self._vis_processor
|
|
||||||
)
|
|
||||||
self.text_processor = text_processor
|
|
||||||
if split == "train":
|
|
||||||
self.data = self.create_data(
|
|
||||||
Path(ann_path, "TextVQA_0.5.1_train.json"),
|
|
||||||
vis_root=Path(vis_root, "train_images"),
|
|
||||||
)
|
|
||||||
elif split == "test":
|
|
||||||
self.data = self.create_data(
|
|
||||||
Path(ann_path, "TextVQA_0.5.1_val.json"),
|
|
||||||
vis_root=Path(vis_root, "train_images"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# self.instruction_pool = [
|
|
||||||
# "[vqa] {}",
|
|
||||||
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
|
||||||
# ]
|
|
||||||
|
|
||||||
def create_data(self, ann_path, vis_root):
|
|
||||||
processed_data = []
|
|
||||||
with open(ann_path, "r") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
data = data["data"]
|
|
||||||
for i in range(len(data)):
|
|
||||||
# print(data[0])
|
|
||||||
# {'question': 'what is the brand of phone?', 'image_id': '0054c91397f2fe05', 'image_classes': ['Belt', 'Headphones', 'Goggles', 'Scale', 'Bottle opener', 'Mobile phone', 'Mirror', 'Digital clock', 'Television', 'Telephone', 'Tool', 'Wheel', 'Camera', 'Watch', 'Glasses', 'Aircraft'], 'flickr_original_url': 'https://farm6.staticflickr.com/2891/9134076951_f65b421097_o.jpg', 'flickr_300k_url': 'https://c4.staticflickr.com/3/2891/9134076951_9db89d3e0f_z.jpg', 'image_width': 1024, 'image_height': 730, 'answers': ['nokia', 'nokia', 'nokia', 'nokia', 'toshiba', 'nokia', 'nokia', 'nokia', 'nokia', 'nokia'], 'question_tokens': ['what', 'is', 'the', 'brand', 'of', 'phone'], 'question_id': 0, 'set_name': 'train'}
|
|
||||||
try:
|
|
||||||
imageFile = data[i]["image_id"] + ".jpg"
|
|
||||||
question = data[i]["question"]
|
|
||||||
answer = data[i]["answers"][0]
|
|
||||||
processed_data.append(
|
|
||||||
{
|
|
||||||
"question": question,
|
|
||||||
"answer": answer,
|
|
||||||
"image_path": Path(vis_root, imageFile),
|
|
||||||
"image_id": data[i]["image_id"],
|
|
||||||
"title": data[i]["image_id"],
|
|
||||||
"genre": data[i]["image_classes"],
|
|
||||||
"original": data[i],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
print(data[i])
|
|
||||||
pass
|
|
||||||
|
|
||||||
return processed_data
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
|
|
||||||
def _vis_processor(self, image: Image.Image):
|
|
||||||
width, height = image.size
|
|
||||||
if width > 500 or height > 500:
|
|
||||||
max_size = max(width, height)
|
|
||||||
ratio = 500 / max_size
|
|
||||||
new_width = int(width * ratio)
|
|
||||||
new_height = int(height * ratio)
|
|
||||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
|
||||||
|
|
||||||
if width < 28 or height < 28:
|
|
||||||
min_size = min(width, height)
|
|
||||||
ratio = 28 / min_size + 1
|
|
||||||
new_width = int(width * ratio)
|
|
||||||
new_height = int(height * ratio)
|
|
||||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
sample = self.data[index]
|
|
||||||
image: Image.Image = Image.open(sample["image_path"]).convert("RGB")
|
|
||||||
# resize image
|
|
||||||
|
|
||||||
question = sample["question"]
|
|
||||||
answer = sample["answer"]
|
|
||||||
if self.vis_processor is not None:
|
|
||||||
image = self.vis_processor(image)
|
|
||||||
if self.text_processor is not None:
|
|
||||||
question = self.text_processor(question)
|
|
||||||
answer = self.text_processor(answer)
|
|
||||||
|
|
||||||
chat = [
|
|
||||||
Conversation(
|
|
||||||
role="user",
|
|
||||||
content=[
|
|
||||||
ConverstationImage(type="image", image_url=""),
|
|
||||||
ConverstationText(
|
|
||||||
type="text",
|
|
||||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
Conversation(
|
|
||||||
role="assistant", content=[ConverstationText(type="text", text=answer)]
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
return DatasetOutput(
|
|
||||||
chat=chat,
|
|
||||||
original=sample["original"],
|
|
||||||
images=[image],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TextVQADatasetForGeneration(TextVQADataset):
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
sample = self.data[index]
|
|
||||||
image = Image.open(sample["image_path"]).convert("RGB")
|
|
||||||
# resize image
|
|
||||||
question = sample["question"]
|
|
||||||
answer = sample["answer"]
|
|
||||||
if self.vis_processor is not None:
|
|
||||||
image = self.vis_processor(image)
|
|
||||||
if self.text_processor is not None:
|
|
||||||
question = self.text_processor(question)
|
|
||||||
answer = self.text_processor(answer)
|
|
||||||
|
|
||||||
chat = [
|
|
||||||
Conversation(
|
|
||||||
role="user",
|
|
||||||
content=[
|
|
||||||
ConverstationImage(type="image", image_url=""),
|
|
||||||
ConverstationText(
|
|
||||||
type="text",
|
|
||||||
text=f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return DatasetOutput(
|
|
||||||
images=[image],
|
|
||||||
chat=chat,
|
|
||||||
answer=answer,
|
|
||||||
original=sample["original"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset():
|
|
||||||
vis_root = "/home/zyy/dataset/TextVQA/images"
|
|
||||||
ann_path = "/home/zyy/dataset/TextVQA"
|
|
||||||
dataset = TextVQADataset(vis_root, ann_path)
|
|
||||||
for i in range(10):
|
|
||||||
print(dataset[i])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_dataset()
|
|
||||||
@ -1,49 +1,50 @@
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
dataset_name, base_path="/home/zyy/dataset"
|
dataset_name, base_path="/home/zyy/research/accelerate/dataset"
|
||||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||||
if dataset_name == "ocrvqa200k":
|
if dataset_name == "OCR-VQA-200K":
|
||||||
from .OCRVQA200KDataset import OCRVQADataset, OCRVQADatasetForGeneration
|
import os.path as osp
|
||||||
|
from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
"train": OCRVQADataset(
|
"train": OCRVQADataset(
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
osp.join(base_path, "OCR-VQA-200K/images"),
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
||||||
split="train",
|
split="train",
|
||||||
),
|
),
|
||||||
"test": OCRVQADataset(
|
"test": OCRVQADataset(
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
osp.join(base_path, "OCR-VQA-200K/images"),
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
"generation": OCRVQADatasetForGeneration(
|
"generation": OCRVQADatasetForGeneration(
|
||||||
vis_root=Path(base_path, "OCR-VQA-200K", "images"),
|
osp.join(base_path, "OCR-VQA-200K/images"),
|
||||||
ann_path=Path(base_path, "OCR-VQA-200K"),
|
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if dataset_name == "chem":
|
if dataset_name == "CHEM":
|
||||||
from .ChemDataset import ChemDataset, ChemDatasetForGeneration
|
import os.path as osp
|
||||||
|
from .CHEM import CHEMDataset, CHEMDatasetForGeneration
|
||||||
|
|
||||||
dataset = {
|
dataset = {
|
||||||
"train": ChemDataset(
|
"train": CHEMDataset(
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
osp.join(base_path, "chem/images"),
|
||||||
ann_path=Path(base_path, "chem"),
|
osp.join(base_path, "chem"),
|
||||||
split="train",
|
split="train",
|
||||||
),
|
),
|
||||||
"test": ChemDataset(
|
"test": CHEMDataset(
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
osp.join(base_path, "chem/images"),
|
||||||
ann_path=Path(base_path, "chem"),
|
osp.join(base_path, "chem"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
"generation": ChemDatasetForGeneration(
|
"generation": CHEMDatasetForGeneration(
|
||||||
vis_root=Path(base_path, "chem", "images"),
|
osp.join(base_path, "chem/images"),
|
||||||
ann_path=Path(base_path, "chem"),
|
osp.join(base_path, "chem"),
|
||||||
split="test",
|
split="test",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@ -56,26 +57,4 @@ def get_dataset(
|
|||||||
"test": GigaspeechDataset(split="test"),
|
"test": GigaspeechDataset(split="test"),
|
||||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if dataset_name == "textvqa":
|
|
||||||
from .TextVQADataset import TextVQADataset, TextVQADatasetForGeneration
|
|
||||||
|
|
||||||
dataset = {
|
|
||||||
"train": TextVQADataset(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="train",
|
|
||||||
),
|
|
||||||
"test": TextVQADataset(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
"generation": TextVQADatasetForGeneration(
|
|
||||||
vis_root=Path(base_path, "TextVQA", "images"),
|
|
||||||
ann_path=Path(base_path, "TextVQA"),
|
|
||||||
split="test",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@ -1,32 +0,0 @@
|
|||||||
from typing import Any, Tuple, TypedDict, Literal, Optional
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
class ConverstationText(TypedDict):
|
|
||||||
type: Literal["text"]
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class ConverstationAudio(TypedDict):
|
|
||||||
type: Literal["audio"]
|
|
||||||
audio_url: str
|
|
||||||
|
|
||||||
|
|
||||||
class ConverstationImage(TypedDict):
|
|
||||||
type: Literal["image"]
|
|
||||||
image_url: str
|
|
||||||
|
|
||||||
|
|
||||||
class Conversation(TypedDict):
|
|
||||||
|
|
||||||
role: Literal["user", "assistant", "system"]
|
|
||||||
content: list[ConverstationText | ConverstationAudio | ConverstationImage]
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetOutput(TypedDict):
|
|
||||||
audios: Optional[list[Tuple[np.ndarray, int]]]
|
|
||||||
chat: list[Conversation]
|
|
||||||
answer: Optional[str]
|
|
||||||
original: Any
|
|
||||||
images: Optional[list[Image.Image]]
|
|
||||||
@ -1,37 +0,0 @@
|
|||||||
from .factory import get_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def test_gigaspeech():
|
|
||||||
dataset = get_dataset("gigaspeech")
|
|
||||||
assert len(dataset["train"]) > 0
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_chem():
|
|
||||||
dataset = get_dataset("chem")
|
|
||||||
assert len(dataset["train"]) > 0
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_ocrvqa200k():
|
|
||||||
dataset = get_dataset("ocrvqa200k")
|
|
||||||
assert len(dataset["train"]) > 0
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_textvqa():
|
|
||||||
dataset = get_dataset("textvqa")
|
|
||||||
assert len(dataset["train"]) > 0
|
|
||||||
assert len(dataset["train"][0]["chat"]) > 0
|
|
||||||
|
|
||||||
assert len(dataset["test"]) > 0
|
|
||||||
assert len(dataset["test"][0]["chat"]) > 0
|
|
||||||
@ -1,5 +1,4 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.insert(0, "./transformers_repo/src/")
|
sys.path.insert(0, "./transformers_repo/src/")
|
||||||
sys.path.insert(0, "./peft_repo/src/")
|
sys.path.insert(0, "./peft_repo/src/")
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from trl import (
|
|||||||
get_quantization_config,
|
get_quantization_config,
|
||||||
)
|
)
|
||||||
from utils.args import ContinualModelConfig
|
from utils.args import ContinualModelConfig
|
||||||
|
import transformers
|
||||||
|
print(transformers.__version__)
|
||||||
|
|
||||||
|
|
||||||
def get_model(model_args: ContinualModelConfig):
|
def get_model(model_args: ContinualModelConfig):
|
||||||
@ -24,8 +26,7 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
|
|
||||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||||
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||||
# from .qwen2vl import Qwen2VLForConditionalGeneration_modified
|
|
||||||
|
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
@ -37,7 +38,7 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
padding_side="left",
|
padding_side="left",
|
||||||
)
|
)
|
||||||
from .qwen2vl import (
|
from model_library.qwen2vl import (
|
||||||
collate_fn_for_train,
|
collate_fn_for_train,
|
||||||
collate_fn_for_evaluate,
|
collate_fn_for_evaluate,
|
||||||
)
|
)
|
||||||
@ -59,7 +60,7 @@ def get_model(model_args: ContinualModelConfig):
|
|||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
padding_side="left",
|
padding_side="left",
|
||||||
)
|
)
|
||||||
from .qwen2audio import (
|
from model_library.qwen2audio import (
|
||||||
collate_fn_for_train,
|
collate_fn_for_train,
|
||||||
collate_fn_for_evaluate,
|
collate_fn_for_evaluate,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,16 +1,15 @@
|
|||||||
from transformers import Qwen2VLProcessor
|
from transformers import Qwen2VLProcessor
|
||||||
from dataset_library.format import DatasetOutput
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
|
def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||||
for example in examples
|
for example in examples
|
||||||
]
|
]
|
||||||
# print(texts)
|
# print(texts)
|
||||||
images = [example["images"] for example in examples]
|
images = [example["image"] for example in examples]
|
||||||
# Tokenize the texts and process the images
|
# Tokenize the texts and process the images
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
@ -66,7 +65,7 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
|
|||||||
for example in examples
|
for example in examples
|
||||||
]
|
]
|
||||||
# print(texts)
|
# print(texts)
|
||||||
images = [example["images"] for example in examples]
|
images = [example["image"] for example in examples]
|
||||||
|
|
||||||
# Tokenize the texts and process the images
|
# Tokenize the texts and process the images
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
|||||||
@ -40,9 +40,7 @@ class AqlmLoraLinear(torch.nn.Module, LoraLayer):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_dora:
|
if use_dora:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
|
||||||
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
LoraLayer.__init__(self, base_layer)
|
LoraLayer.__init__(self, base_layer)
|
||||||
|
|||||||
@ -37,9 +37,7 @@ class AwqLoraLinear(torch.nn.Module, LoraLayer):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_dora:
|
if use_dora:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
|
||||||
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
LoraLayer.__init__(self, base_layer)
|
LoraLayer.__init__(self, base_layer)
|
||||||
@ -109,9 +107,7 @@ def dispatch_awq(
|
|||||||
if isinstance(target_base_layer, WQLinear_GEMM):
|
if isinstance(target_base_layer, WQLinear_GEMM):
|
||||||
# Raise the error only at the dispatch level
|
# Raise the error only at the dispatch level
|
||||||
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
|
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
|
||||||
version_autoawq = packaging.version.parse(
|
version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))
|
||||||
importlib_metadata.version("autoawq")
|
|
||||||
)
|
|
||||||
|
|
||||||
if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
|
if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
|||||||
@ -60,9 +60,7 @@ if is_bnb_available():
|
|||||||
lora_bias=lora_bias,
|
lora_bias=lora_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Merge the active adapter weights into the base weights
|
Merge the active adapter weights into the base weights
|
||||||
|
|
||||||
@ -111,9 +109,7 @@ if is_bnb_available():
|
|||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
|
||||||
)
|
|
||||||
w_data = dora_factor.view(-1, 1) * (output + lora_data)
|
w_data = dora_factor.view(-1, 1) * (output + lora_data)
|
||||||
|
|
||||||
if safe_merge and not torch.isfinite(w_data).all():
|
if safe_merge and not torch.isfinite(w_data).all():
|
||||||
@ -122,15 +118,10 @@ if is_bnb_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.get_base_layer().weight = bnb.nn.Int8Params(
|
self.get_base_layer().weight = bnb.nn.Int8Params(
|
||||||
w_data.to("cpu"),
|
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
|
||||||
requires_grad=False,
|
|
||||||
has_fp16_weights=weight.has_fp16_weights,
|
|
||||||
).to(weight.device)
|
).to(weight.device)
|
||||||
if self.lora_bias[active_adapter]:
|
if self.lora_bias[active_adapter]:
|
||||||
bias_data = (
|
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
|
||||||
self.get_base_layer().bias.data
|
|
||||||
+ self.lora_B[active_adapter].bias
|
|
||||||
)
|
|
||||||
if safe_merge and not torch.isfinite(bias_data):
|
if safe_merge and not torch.isfinite(bias_data):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||||
@ -167,15 +158,11 @@ if is_bnb_available():
|
|||||||
w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
|
w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
|
||||||
else:
|
else:
|
||||||
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
|
||||||
)
|
|
||||||
w_data = output.data / dora_factor.view(-1, 1) - lora_data
|
w_data = output.data / dora_factor.view(-1, 1) - lora_data
|
||||||
|
|
||||||
self.get_base_layer().weight = bnb.nn.Int8Params(
|
self.get_base_layer().weight = bnb.nn.Int8Params(
|
||||||
w_data.to("cpu"),
|
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
|
||||||
requires_grad=False,
|
|
||||||
has_fp16_weights=weight.has_fp16_weights,
|
|
||||||
).to(weight.device)
|
).to(weight.device)
|
||||||
|
|
||||||
if self.lora_bias[active_adapter]:
|
if self.lora_bias[active_adapter]:
|
||||||
@ -201,13 +188,7 @@ if is_bnb_available():
|
|||||||
unique_adapters = set(adapter_names)
|
unique_adapters = set(adapter_names)
|
||||||
sub_batch_indices_list = []
|
sub_batch_indices_list = []
|
||||||
for adapter in unique_adapters:
|
for adapter in unique_adapters:
|
||||||
sub_batch_indices_list.append(
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
|
||||||
[
|
|
||||||
index
|
|
||||||
for index, item in enumerate(adapter_names)
|
|
||||||
if item == adapter
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, active_adapter in enumerate(unique_adapters):
|
for i, active_adapter in enumerate(unique_adapters):
|
||||||
if active_adapter == "__base__":
|
if active_adapter == "__base__":
|
||||||
@ -246,9 +227,7 @@ if is_bnb_available():
|
|||||||
self.unmerge()
|
self.unmerge()
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
elif adapter_names is not None:
|
elif adapter_names is not None:
|
||||||
result = self._mixed_batch_forward(
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||||
x, *args, adapter_names=adapter_names, **kwargs
|
|
||||||
)
|
|
||||||
elif self.merged:
|
elif self.merged:
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -351,9 +330,7 @@ if is_bnb_4bit_available():
|
|||||||
lora_bias=lora_bias,
|
lora_bias=lora_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Merge the active adapter weights into the base weights
|
Merge the active adapter weights into the base weights
|
||||||
|
|
||||||
@ -398,9 +375,7 @@ if is_bnb_4bit_available():
|
|||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
|
||||||
)
|
|
||||||
w_data = dora_factor.view(-1, 1) * (output + lora_data)
|
w_data = dora_factor.view(-1, 1) * (output + lora_data)
|
||||||
|
|
||||||
if safe_merge and not torch.isfinite(w_data).all():
|
if safe_merge and not torch.isfinite(w_data).all():
|
||||||
@ -411,14 +386,9 @@ if is_bnb_4bit_available():
|
|||||||
kwargs["bnb_quantized"] = False
|
kwargs["bnb_quantized"] = False
|
||||||
kwargs["requires_grad"] = False
|
kwargs["requires_grad"] = False
|
||||||
kwargs.pop("data", None)
|
kwargs.pop("data", None)
|
||||||
self.get_base_layer().weight = bnb.nn.Params4bit(
|
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
|
||||||
w_data.to("cpu"), **kwargs
|
|
||||||
).to(weight.device)
|
|
||||||
if self.lora_bias[active_adapter]:
|
if self.lora_bias[active_adapter]:
|
||||||
bias_data = (
|
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
|
||||||
self.get_base_layer().bias.data
|
|
||||||
+ self.lora_B[active_adapter].bias
|
|
||||||
)
|
|
||||||
if safe_merge and not torch.isfinite(bias_data):
|
if safe_merge and not torch.isfinite(bias_data):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||||
@ -452,18 +422,14 @@ if is_bnb_4bit_available():
|
|||||||
w_data = output - lora_data
|
w_data = output - lora_data
|
||||||
else:
|
else:
|
||||||
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
|
||||||
)
|
|
||||||
w_data = output.data / dora_factor.view(-1, 1) - lora_data
|
w_data = output.data / dora_factor.view(-1, 1) - lora_data
|
||||||
|
|
||||||
if "bnb_quantized" in kwargs:
|
if "bnb_quantized" in kwargs:
|
||||||
kwargs["bnb_quantized"] = False
|
kwargs["bnb_quantized"] = False
|
||||||
kwargs["requires_grad"] = False
|
kwargs["requires_grad"] = False
|
||||||
kwargs.pop("data", None)
|
kwargs.pop("data", None)
|
||||||
self.get_base_layer().weight = bnb.nn.Params4bit(
|
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
|
||||||
w_data.to("cpu"), **kwargs
|
|
||||||
).to(weight.device)
|
|
||||||
if self.lora_bias[active_adapter]:
|
if self.lora_bias[active_adapter]:
|
||||||
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias
|
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias
|
||||||
|
|
||||||
@ -486,13 +452,7 @@ if is_bnb_4bit_available():
|
|||||||
unique_adapters = set(adapter_names)
|
unique_adapters = set(adapter_names)
|
||||||
sub_batch_indices_list = []
|
sub_batch_indices_list = []
|
||||||
for adapter in unique_adapters:
|
for adapter in unique_adapters:
|
||||||
sub_batch_indices_list.append(
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
|
||||||
[
|
|
||||||
index
|
|
||||||
for index, item in enumerate(adapter_names)
|
|
||||||
if item == adapter
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, active_adapter in enumerate(unique_adapters):
|
for i, active_adapter in enumerate(unique_adapters):
|
||||||
if active_adapter == "__base__":
|
if active_adapter == "__base__":
|
||||||
@ -529,9 +489,7 @@ if is_bnb_4bit_available():
|
|||||||
self.unmerge()
|
self.unmerge()
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
elif adapter_names is not None:
|
elif adapter_names is not None:
|
||||||
result = self._mixed_batch_forward(
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||||
x, *args, adapter_names=adapter_names, **kwargs
|
|
||||||
)
|
|
||||||
elif self.merged:
|
elif self.merged:
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -592,11 +550,7 @@ if is_bnb_4bit_available():
|
|||||||
target_base_layer = target
|
target_base_layer = target
|
||||||
|
|
||||||
loaded_in_4bit = kwargs.get("loaded_in_4bit", False)
|
loaded_in_4bit = kwargs.get("loaded_in_4bit", False)
|
||||||
if (
|
if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
|
||||||
loaded_in_4bit
|
|
||||||
and is_bnb_4bit_available()
|
|
||||||
and isinstance(target_base_layer, bnb.nn.Linear4bit)
|
|
||||||
):
|
|
||||||
fourbit_kwargs = kwargs.copy()
|
fourbit_kwargs = kwargs.copy()
|
||||||
fourbit_kwargs.update(
|
fourbit_kwargs.update(
|
||||||
{
|
{
|
||||||
|
|||||||
@ -66,9 +66,7 @@ class LoftQConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"})
|
loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"})
|
||||||
loftq_iter: int = field(
|
loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
|
||||||
default=1, metadata={"help": "Alternating iterations for LoftQ"}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -103,25 +101,13 @@ class EvaConfig:
|
|||||||
are adjusted so that all LoRA gradients have the same scale regardless of their rank. Default is True.
|
are adjusted so that all LoRA gradients have the same scale regardless of their rank. Default is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rho: float = field(
|
rho: float = field(default=2.0, metadata={"help": "Rho value for EVA redistribution"})
|
||||||
default=2.0, metadata={"help": "Rho value for EVA redistribution"}
|
tau: float = field(default=0.99, metadata={"help": "Cosine similarity threshold for early stopping"})
|
||||||
)
|
use_label_mask: bool = field(default=True, metadata={"help": "Use label mask for EVA initialization"})
|
||||||
tau: float = field(
|
|
||||||
default=0.99,
|
|
||||||
metadata={"help": "Cosine similarity threshold for early stopping"},
|
|
||||||
)
|
|
||||||
use_label_mask: bool = field(
|
|
||||||
default=True, metadata={"help": "Use label mask for EVA initialization"}
|
|
||||||
)
|
|
||||||
label_mask_value: int = field(
|
label_mask_value: int = field(
|
||||||
default=-100,
|
default=-100, metadata={"help": "if use_label_mask=True the value to look for to mask out ignored tokens"}
|
||||||
metadata={
|
|
||||||
"help": "if use_label_mask=True the value to look for to mask out ignored tokens"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
whiten: bool = field(
|
|
||||||
default=False, metadata={"help": "Apply whitening to singular vectors"}
|
|
||||||
)
|
)
|
||||||
|
whiten: bool = field(default=False, metadata={"help": "Apply whitening to singular vectors"})
|
||||||
adjust_scaling_factors: bool = field(
|
adjust_scaling_factors: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Adjust LoRA scaling factors after the rank redistribution"},
|
metadata={"help": "Adjust LoRA scaling factors after the rank redistribution"},
|
||||||
@ -247,21 +233,16 @@ class LoraConfig(PeftConfig):
|
|||||||
)
|
)
|
||||||
exclude_modules: Optional[Union[list[str], str]] = field(
|
exclude_modules: Optional[Union[list[str], str]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "List of module names or regex expression of the module names to exclude from Lora."},
|
||||||
"help": "List of module names or regex expression of the module names to exclude from Lora."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
|
lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
|
||||||
lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
|
lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
|
||||||
fan_in_fan_out: bool = field(
|
fan_in_fan_out: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
|
||||||
"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
bias: Literal["none", "all", "lora_only"] = field(
|
bias: Literal["none", "all", "lora_only"] = field(
|
||||||
default="none",
|
default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}
|
||||||
metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"},
|
|
||||||
)
|
)
|
||||||
use_rslora: bool = field(
|
use_rslora: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@ -283,15 +264,7 @@ class LoraConfig(PeftConfig):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
init_lora_weights: (
|
init_lora_weights: (
|
||||||
bool
|
bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]
|
||||||
| Literal[
|
|
||||||
"gaussian",
|
|
||||||
"eva",
|
|
||||||
"olora",
|
|
||||||
"pissa",
|
|
||||||
"pissa_niter_[number of iters]",
|
|
||||||
"loftq",
|
|
||||||
]
|
|
||||||
) = field(
|
) = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
@ -445,72 +418,47 @@ class LoraConfig(PeftConfig):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
self.peft_type = PeftType.LORA
|
self.peft_type = PeftType.LORA
|
||||||
self.target_modules = (
|
self.target_modules = (
|
||||||
set(self.target_modules)
|
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
|
||||||
if isinstance(self.target_modules, list)
|
|
||||||
else self.target_modules
|
|
||||||
)
|
)
|
||||||
self.exclude_modules = (
|
self.exclude_modules = (
|
||||||
set(self.exclude_modules)
|
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
|
||||||
if isinstance(self.exclude_modules, list)
|
|
||||||
else self.exclude_modules
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# if target_modules is a regex expression, then layers_to_transform should be None
|
# if target_modules is a regex expression, then layers_to_transform should be None
|
||||||
if (
|
if isinstance(self.target_modules, str) and self.layers_to_transform is not None:
|
||||||
isinstance(self.target_modules, str)
|
raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.")
|
||||||
and self.layers_to_transform is not None
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"`layers_to_transform` cannot be used when `target_modules` is a str."
|
|
||||||
)
|
|
||||||
|
|
||||||
# if target_modules is a regex expression, then layers_pattern should be None
|
# if target_modules is a regex expression, then layers_pattern should be None
|
||||||
if isinstance(self.target_modules, str) and self.layers_pattern is not None:
|
if isinstance(self.target_modules, str) and self.layers_pattern is not None:
|
||||||
raise ValueError(
|
raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
|
||||||
"`layers_pattern` cannot be used when `target_modules` is a str."
|
|
||||||
)
|
|
||||||
|
|
||||||
# check for layers_to_transform and layers_pattern
|
# check for layers_to_transform and layers_pattern
|
||||||
if self.layers_pattern and not self.layers_to_transform:
|
if self.layers_pattern and not self.layers_to_transform:
|
||||||
raise ValueError(
|
raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
|
||||||
"When `layers_pattern` is specified, `layers_to_transform` must also be specified. "
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_dora and self.megatron_config:
|
if self.use_dora and self.megatron_config:
|
||||||
raise ValueError(
|
raise ValueError("DoRA does not support megatron_core, please set `use_dora=False`.")
|
||||||
"DoRA does not support megatron_core, please set `use_dora=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# handle init_lora_weights and loftq_config
|
# handle init_lora_weights and loftq_config
|
||||||
if self.init_lora_weights == "loftq":
|
if self.init_lora_weights == "loftq":
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
if not importlib.util.find_spec("scipy"):
|
if not importlib.util.find_spec("scipy"):
|
||||||
raise ImportError(
|
raise ImportError("The required package 'scipy' is not installed. Please install it to continue.")
|
||||||
"The required package 'scipy' is not installed. Please install it to continue."
|
|
||||||
)
|
|
||||||
if not self.loftq_config:
|
if not self.loftq_config:
|
||||||
raise ValueError(
|
raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")
|
||||||
"`loftq_config` must be specified when `init_lora_weights` is 'loftq'."
|
|
||||||
)
|
|
||||||
if not isinstance(self.loftq_config, dict):
|
if not isinstance(self.loftq_config, dict):
|
||||||
# convert loftq_config to dict
|
# convert loftq_config to dict
|
||||||
self.loftq_config = vars(self.loftq_config)
|
self.loftq_config = vars(self.loftq_config)
|
||||||
elif self.loftq_config:
|
elif self.loftq_config:
|
||||||
self.loftq_config = {}
|
self.loftq_config = {}
|
||||||
warnings.warn(
|
warnings.warn("`loftq_config` specified but will be ignored when `init_lora_weights` is not 'loftq'.")
|
||||||
"`loftq_config` specified but will be ignored when `init_lora_weights` is not 'loftq'."
|
|
||||||
)
|
|
||||||
|
|
||||||
elif self.init_lora_weights == "eva" and self.eva_config is None:
|
elif self.init_lora_weights == "eva" and self.eva_config is None:
|
||||||
warnings.warn(
|
warnings.warn("`init_lora_weights` is 'eva' but `eva_config` is not specified. Using default EVA config.")
|
||||||
"`init_lora_weights` is 'eva' but `eva_config` is not specified. Using default EVA config."
|
|
||||||
)
|
|
||||||
self.eva_config = EvaConfig()
|
self.eva_config = EvaConfig()
|
||||||
elif self.init_lora_weights != "eva" and self.eva_config is not None:
|
elif self.init_lora_weights != "eva" and self.eva_config is not None:
|
||||||
warnings.warn(
|
warnings.warn("`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'.")
|
||||||
"`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.lora_bias:
|
if self.lora_bias:
|
||||||
if self.init_lora_weights not in (True, False):
|
if self.init_lora_weights not in (True, False):
|
||||||
@ -519,9 +467,7 @@ class LoraConfig(PeftConfig):
|
|||||||
f"init_lora_weights={self.init_lora_weights} instead."
|
f"init_lora_weights={self.init_lora_weights} instead."
|
||||||
)
|
)
|
||||||
if self.use_dora:
|
if self.use_dora:
|
||||||
raise ValueError(
|
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")
|
||||||
"The argument lora_bias=True is not supported for DoRA, please pass use_dora=False"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
|
# Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
|
||||||
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
|
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
|
||||||
@ -531,10 +477,7 @@ class LoraConfig(PeftConfig):
|
|||||||
self.use_rslora
|
self.use_rslora
|
||||||
and (self.rank_pattern or self.alpha_pattern)
|
and (self.rank_pattern or self.alpha_pattern)
|
||||||
and (
|
and (
|
||||||
(
|
(isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa")))
|
||||||
isinstance(self.init_lora_weights, str)
|
|
||||||
and (self.init_lora_weights.startswith("pissa"))
|
|
||||||
)
|
|
||||||
or (self.init_lora_weights == "olora")
|
or (self.init_lora_weights == "olora")
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
@ -548,9 +491,7 @@ class LoraConfig(PeftConfig):
|
|||||||
|
|
||||||
self._custom_modules: Optional[dict[type[nn.Mmodule], type[nn.Module]]] = None
|
self._custom_modules: Optional[dict[type[nn.Mmodule], type[nn.Module]]] = None
|
||||||
|
|
||||||
def _register_custom_module(
|
def _register_custom_module(self, mapping: dict[type[nn.Mmodule], type[nn.Module]]) -> None:
|
||||||
self, mapping: dict[type[nn.Mmodule], type[nn.Module]]
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Experimental API to support providing custom LoRA layers.
|
Experimental API to support providing custom LoRA layers.
|
||||||
|
|
||||||
|
|||||||
@ -34,9 +34,7 @@ class DoraLinearLayer(nn.Module):
|
|||||||
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
|
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
|
||||||
return weight_norm
|
return weight_norm
|
||||||
|
|
||||||
def update_layer(
|
def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
|
||||||
self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False
|
|
||||||
) -> None:
|
|
||||||
# temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
|
# temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
|
||||||
dtype_is_fp16 = lora_A.dtype == torch.float16
|
dtype_is_fp16 = lora_A.dtype == torch.float16
|
||||||
if dtype_is_fp16:
|
if dtype_is_fp16:
|
||||||
@ -51,18 +49,14 @@ class DoraLinearLayer(nn.Module):
|
|||||||
|
|
||||||
weight = dequantize_module_weight(base_layer)
|
weight = dequantize_module_weight(base_layer)
|
||||||
if weight.data.ndim >= 4: # For handling LoRAs applied to Conv layers.
|
if weight.data.ndim >= 4: # For handling LoRAs applied to Conv layers.
|
||||||
lora_weight = torch.mm(
|
lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
|
||||||
lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1)
|
|
||||||
)
|
|
||||||
lora_weight = lora_weight.reshape(weight.shape)
|
lora_weight = lora_weight.reshape(weight.shape)
|
||||||
else:
|
else:
|
||||||
lora_weight = lora_B @ lora_A
|
lora_weight = lora_B @ lora_A
|
||||||
|
|
||||||
if dtype_is_fp16:
|
if dtype_is_fp16:
|
||||||
lora_weight = lora_weight.half()
|
lora_weight = lora_weight.half()
|
||||||
weight_norm = self.get_weight_norm(
|
weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling)
|
||||||
weight.to(lora_A.device), lora_weight, scaling
|
|
||||||
)
|
|
||||||
|
|
||||||
if place_on_cpu:
|
if place_on_cpu:
|
||||||
weight_norm = weight_norm.to("cpu")
|
weight_norm = weight_norm.to("cpu")
|
||||||
@ -75,9 +69,7 @@ class DoraLinearLayer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
|
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
|
||||||
# calculate the same but using forward.
|
# calculate the same but using forward.
|
||||||
x_eye = torch.eye(
|
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
|
||||||
lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype
|
|
||||||
)
|
|
||||||
lora_weight = lora_B(lora_A(x_eye)).T
|
lora_weight = lora_B(lora_A(x_eye)).T
|
||||||
|
|
||||||
magnitude = self.weight
|
magnitude = self.weight
|
||||||
@ -103,9 +95,7 @@ class DoraLinearLayer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
|
base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
|
||||||
|
|
||||||
result_dora = (
|
result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
|
||||||
mag_norm_scale - 1
|
|
||||||
) * base_result + mag_norm_scale * lora_result * scaling
|
|
||||||
|
|
||||||
return result_dora
|
return result_dora
|
||||||
|
|
||||||
@ -155,9 +145,7 @@ class _DoraConvNdLayer(DoraLinearLayer):
|
|||||||
output.
|
output.
|
||||||
"""
|
"""
|
||||||
weight = base_layer.weight
|
weight = base_layer.weight
|
||||||
lora_weight = torch.mm(
|
lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
|
||||||
lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)
|
|
||||||
)
|
|
||||||
lora_weight = lora_weight.reshape(weight.shape)
|
lora_weight = lora_weight.reshape(weight.shape)
|
||||||
magnitude = self.weight
|
magnitude = self.weight
|
||||||
weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
|
weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
|
||||||
|
|||||||
@ -38,9 +38,7 @@ if is_eetq_available():
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_dora:
|
if use_dora:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
|
||||||
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
LoraLayer.__init__(self, base_layer)
|
LoraLayer.__init__(self, base_layer)
|
||||||
@ -87,17 +85,11 @@ if is_eetq_available():
|
|||||||
result = result + output
|
result = result + output
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
|
raise AttributeError("Merging LoRA layers is not supported for Eetq layers.")
|
||||||
) -> None:
|
|
||||||
raise AttributeError(
|
|
||||||
"Merging LoRA layers is not supported for Eetq layers."
|
|
||||||
)
|
|
||||||
|
|
||||||
def unmerge(self) -> None:
|
def unmerge(self) -> None:
|
||||||
raise AttributeError(
|
raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.")
|
||||||
"Unmerging LoRA layers is not supported for Eetq layers."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
rep = super().__repr__()
|
rep = super().__repr__()
|
||||||
|
|||||||
@ -26,10 +26,7 @@ import torch.distributed as dist
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers.pytorch_utils import Conv1D
|
from transformers.pytorch_utils import Conv1D
|
||||||
|
|
||||||
from peft.tuners.tuners_utils import (
|
from peft.tuners.tuners_utils import _find_minimal_target_modules, check_target_module_exists
|
||||||
_find_minimal_target_modules,
|
|
||||||
check_target_module_exists,
|
|
||||||
)
|
|
||||||
from peft.utils.constants import MIN_TARGET_MODULES_FOR_OPTIMIZATION
|
from peft.utils.constants import MIN_TARGET_MODULES_FOR_OPTIMIZATION
|
||||||
from peft.utils.incremental_pca import IncrementalPCA
|
from peft.utils.incremental_pca import IncrementalPCA
|
||||||
from peft.utils.other import _get_submodules, get_pattern_key
|
from peft.utils.other import _get_submodules, get_pattern_key
|
||||||
@ -61,9 +58,7 @@ class _Hook:
|
|||||||
self.model_input = None
|
self.model_input = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _prepare_layer_inputs_fn_default(
|
def _prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> torch.Tensor:
|
||||||
layer_input, model_input, layer_name
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if isinstance(layer_input, torch.Tensor):
|
if isinstance(layer_input, torch.Tensor):
|
||||||
pass
|
pass
|
||||||
elif isinstance(layer_input, (tuple, list)):
|
elif isinstance(layer_input, (tuple, list)):
|
||||||
@ -88,28 +83,20 @@ class _Hook:
|
|||||||
|
|
||||||
# First gather sizes from all processes more efficiently
|
# First gather sizes from all processes more efficiently
|
||||||
local_size = torch.tensor([layer_input.shape[0]], device=layer_input.device)
|
local_size = torch.tensor([layer_input.shape[0]], device=layer_input.device)
|
||||||
all_sizes = torch.empty(
|
all_sizes = torch.empty(world_size, dtype=local_size.dtype, device=layer_input.device)
|
||||||
world_size, dtype=local_size.dtype, device=layer_input.device
|
|
||||||
)
|
|
||||||
dist.all_gather_into_tensor(all_sizes, local_size)
|
dist.all_gather_into_tensor(all_sizes, local_size)
|
||||||
all_sizes = all_sizes.tolist()
|
all_sizes = all_sizes.tolist()
|
||||||
|
|
||||||
# Find maximum size and pad tensors
|
# Find maximum size and pad tensors
|
||||||
padded_input = layer_input.new_zeros(
|
padded_input = layer_input.new_zeros((max(all_sizes), *layer_input.shape[1:]))
|
||||||
(max(all_sizes), *layer_input.shape[1:])
|
|
||||||
)
|
|
||||||
padded_input[: layer_input.shape[0]] = layer_input
|
padded_input[: layer_input.shape[0]] = layer_input
|
||||||
|
|
||||||
# Gather padded tensors
|
# Gather padded tensors
|
||||||
gathered_inputs = [
|
gathered_inputs = [torch.zeros_like(padded_input) for _ in range(world_size)]
|
||||||
torch.zeros_like(padded_input) for _ in range(world_size)
|
|
||||||
]
|
|
||||||
dist.all_gather(gathered_inputs, padded_input.contiguous())
|
dist.all_gather(gathered_inputs, padded_input.contiguous())
|
||||||
|
|
||||||
# Remove padding for each gathered tensor
|
# Remove padding for each gathered tensor
|
||||||
gathered_inputs = [
|
gathered_inputs = [tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes)]
|
||||||
tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Concatenate along batch dimension
|
# Concatenate along batch dimension
|
||||||
return torch.cat(gathered_inputs, dim=0)
|
return torch.cat(gathered_inputs, dim=0)
|
||||||
@ -166,9 +153,7 @@ class SVDHook(_Hook):
|
|||||||
states = self.gather_layer_inputs(states)
|
states = self.gather_layer_inputs(states)
|
||||||
# check if batch sizes is more than the number of components
|
# check if batch sizes is more than the number of components
|
||||||
if states.size(0) < self.n_components:
|
if states.size(0) < self.n_components:
|
||||||
print(
|
print(f"skipping SVD for {self.name} because there are less than {self.n_components} examples")
|
||||||
f"skipping SVD for {self.name} because there are less than {self.n_components} examples"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
self.svd.partial_fit(states.to(torch.float32))
|
self.svd.partial_fit(states.to(torch.float32))
|
||||||
# add if statement to check if we are in the first step where previous_components is None
|
# add if statement to check if we are in the first step where previous_components is None
|
||||||
@ -233,9 +218,7 @@ def get_device_with_meta_params(model: torch.nn.Module) -> torch.device:
|
|||||||
"""
|
"""
|
||||||
devices = list({p.device for p in model.parameters() if p.device.type != "meta"})
|
devices = list({p.device for p in model.parameters() if p.device.type != "meta"})
|
||||||
if len(devices) > 1:
|
if len(devices) > 1:
|
||||||
warnings.warn(
|
warnings.warn(f"Could not determine device, model has multiple devices: {devices}")
|
||||||
f"Could not determine device, model has multiple devices: {devices}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
return devices[0]
|
return devices[0]
|
||||||
|
|
||||||
@ -247,15 +230,11 @@ def move_inputs_to_device(inputs, device: Union[str, torch.device]):
|
|||||||
if hasattr(inputs, "to"):
|
if hasattr(inputs, "to"):
|
||||||
return inputs.to(device)
|
return inputs.to(device)
|
||||||
if isinstance(inputs, Mapping):
|
if isinstance(inputs, Mapping):
|
||||||
return type(inputs)(
|
return type(inputs)({k: move_inputs_to_device(v, device) for k, v in inputs.items()})
|
||||||
{k: move_inputs_to_device(v, device) for k, v in inputs.items()}
|
|
||||||
)
|
|
||||||
elif isinstance(inputs, (tuple, list)):
|
elif isinstance(inputs, (tuple, list)):
|
||||||
return type(inputs)(move_inputs_to_device(v, device) for v in inputs)
|
return type(inputs)(move_inputs_to_device(v, device) for v in inputs)
|
||||||
else:
|
else:
|
||||||
warnings.warn(
|
warnings.warn(f"input of type {type(inputs)} could not be moved to the correct device")
|
||||||
f"input of type {type(inputs)} could not be moved to the correct device"
|
|
||||||
)
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
@ -268,22 +247,14 @@ def prepare_model_inputs_fn_language_modeling(model_input, peft_config: LoraConf
|
|||||||
peft_config (LoraConfig): The configuration for the LoRA layers.
|
peft_config (LoraConfig): The configuration for the LoRA layers.
|
||||||
"""
|
"""
|
||||||
if not isinstance(model_input, dict):
|
if not isinstance(model_input, dict):
|
||||||
raise ValueError(
|
raise ValueError("When using `prepare_model_inputs_fn_language_modeling` inputs must be a dictionary")
|
||||||
"When using `prepare_model_inputs_fn_language_modeling` inputs must be a dictionary"
|
mask = model_input.get("attention_mask", torch.ones_like(model_input["input_ids"])).bool()
|
||||||
)
|
|
||||||
mask = model_input.get(
|
|
||||||
"attention_mask", torch.ones_like(model_input["input_ids"])
|
|
||||||
).bool()
|
|
||||||
if peft_config.eva_config.use_label_mask and hasattr(model_input, "labels"):
|
if peft_config.eva_config.use_label_mask and hasattr(model_input, "labels"):
|
||||||
mask = torch.logical_and(
|
mask = torch.logical_and(mask, model_input["labels"] != peft_config.eva_config.label_mask_value)
|
||||||
mask, model_input["labels"] != peft_config.eva_config.label_mask_value
|
|
||||||
)
|
|
||||||
return mask.nonzero()
|
return mask.nonzero()
|
||||||
|
|
||||||
|
|
||||||
def prepare_layer_inputs_fn_language_modeling(
|
def prepare_layer_inputs_fn_language_modeling(layer_input, model_input, layer_name) -> torch.Tensor:
|
||||||
layer_input, model_input, layer_name
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
"""
|
||||||
if not all items in the input should be used for SVD, this function can be used to get the indices of the items
|
if not all items in the input should be used for SVD, this function can be used to get the indices of the items
|
||||||
that should be used.
|
that should be used.
|
||||||
@ -328,21 +299,12 @@ def _get_eva_state_dict(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
# Computes the rank distribution for each layer based on the explained variance ratio.
|
# Computes the rank distribution for each layer based on the explained variance ratio.
|
||||||
# when rank_pattern flag is False, all values in max_components are the same
|
# when rank_pattern flag is False, all values in max_components are the same
|
||||||
def _get_rank_distribution(
|
def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components):
|
||||||
hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components
|
exp_vars = {k: h[0].svd.explained_variance_ratio_[: max_components[k]] for k, h in hooks.items()}
|
||||||
):
|
keys, values = zip(*[(k, c) for k, name in layer_hook_map.items() for c in exp_vars[name]])
|
||||||
exp_vars = {
|
|
||||||
k: h[0].svd.explained_variance_ratio_[: max_components[k]]
|
|
||||||
for k, h in hooks.items()
|
|
||||||
}
|
|
||||||
keys, values = zip(
|
|
||||||
*[(k, c) for k, name in layer_hook_map.items() for c in exp_vars[name]]
|
|
||||||
)
|
|
||||||
idx = torch.stack(values).argsort(descending=True)
|
idx = torch.stack(values).argsort(descending=True)
|
||||||
counts = Counter([keys[i] for i in idx[:rank_budget]])
|
counts = Counter([keys[i] for i in idx[:rank_budget]])
|
||||||
counts = {
|
counts = {k: counts.get(k, 0) for k in layer_hook_map.keys()} # add layers with 0 rank
|
||||||
k: counts.get(k, 0) for k in layer_hook_map.keys()
|
|
||||||
} # add layers with 0 rank
|
|
||||||
for k, k_hook in equal_inputs_map.items():
|
for k, k_hook in equal_inputs_map.items():
|
||||||
# ensure hook layers have the highest rank if they are equal to another layer
|
# ensure hook layers have the highest rank if they are equal to another layer
|
||||||
rank, rank_hook = counts[k], counts[k_hook]
|
rank, rank_hook = counts[k], counts[k_hook]
|
||||||
@ -394,11 +356,7 @@ def _get_eva_state_dict(
|
|||||||
fn = prepare_layer_inputs_fn.pop(name, None)
|
fn = prepare_layer_inputs_fn.pop(name, None)
|
||||||
else:
|
else:
|
||||||
fn = prepare_layer_inputs_fn
|
fn = prepare_layer_inputs_fn
|
||||||
hook = HashHook(
|
hook = HashHook(name=name, prepare_layer_inputs_fn=fn, gather_distributed_inputs=gather_distributed_inputs)
|
||||||
name=name,
|
|
||||||
prepare_layer_inputs_fn=fn,
|
|
||||||
gather_distributed_inputs=gather_distributed_inputs,
|
|
||||||
)
|
|
||||||
hook.model_input = model_inputs_for_hooks
|
hook.model_input = model_inputs_for_hooks
|
||||||
handle = module.register_forward_hook(hook)
|
handle = module.register_forward_hook(hook)
|
||||||
hooks[name] = (hook, handle)
|
hooks[name] = (hook, handle)
|
||||||
@ -407,10 +365,7 @@ def _get_eva_state_dict(
|
|||||||
)
|
)
|
||||||
max_components[name] = round(layer_rank * rho)
|
max_components[name] = round(layer_rank * rho)
|
||||||
rank_budget += layer_rank
|
rank_budget += layer_rank
|
||||||
if (
|
if isinstance(prepare_layer_inputs_fn, Mapping) and len(prepare_layer_inputs_fn) > 0:
|
||||||
isinstance(prepare_layer_inputs_fn, Mapping)
|
|
||||||
and len(prepare_layer_inputs_fn) > 0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"prepare_layer_inputs_fn is a mapping but the following module names were not found in the model: "
|
"prepare_layer_inputs_fn is a mapping but the following module names were not found in the model: "
|
||||||
f"{prepare_layer_inputs_fn.keys()}"
|
f"{prepare_layer_inputs_fn.keys()}"
|
||||||
@ -444,10 +399,7 @@ def _get_eva_state_dict(
|
|||||||
)
|
)
|
||||||
module = model.get_submodule(name)
|
module = model.get_submodule(name)
|
||||||
handle = module.register_forward_hook(hook)
|
handle = module.register_forward_hook(hook)
|
||||||
hooks[name] = (
|
hooks[name] = (hook, handle) # adding the old handle here so we dont get errors in the first forward pass
|
||||||
hook,
|
|
||||||
handle,
|
|
||||||
) # adding the old handle here so we dont get errors in the first forward pass
|
|
||||||
layer_hook_map = {**dict(zip(hooks.keys(), hooks.keys())), **equal_inputs_map}
|
layer_hook_map = {**dict(zip(hooks.keys(), hooks.keys())), **equal_inputs_map}
|
||||||
|
|
||||||
# start svd calculation
|
# start svd calculation
|
||||||
@ -489,9 +441,7 @@ def _get_eva_state_dict(
|
|||||||
layer_converged = list(convergence_dict.values()) + [
|
layer_converged = list(convergence_dict.values()) + [
|
||||||
convergence_dict[v] for v in equal_inputs_map.values()
|
convergence_dict[v] for v in equal_inputs_map.values()
|
||||||
]
|
]
|
||||||
pbar.set_description(
|
pbar.set_description(f"{sum(layer_converged)}/{len(layer_converged)} layers have converged")
|
||||||
f"{sum(layer_converged)}/{len(layer_converged)} layers have converged"
|
|
||||||
)
|
|
||||||
|
|
||||||
if all(convergence_dict.values()):
|
if all(convergence_dict.values()):
|
||||||
break
|
break
|
||||||
@ -503,17 +453,10 @@ def _get_eva_state_dict(
|
|||||||
if not all(hasattr(h[0].svd, "components_") for h in hooks.values()):
|
if not all(hasattr(h[0].svd, "components_") for h in hooks.values()):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
rank_dist = _get_rank_distribution(
|
rank_dist = _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components)
|
||||||
hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components
|
|
||||||
)
|
|
||||||
|
|
||||||
# check all custom hooks have been removed
|
# check all custom hooks have been removed
|
||||||
remaining_hooks = {
|
remaining_hooks = {n for n, m in model.named_modules() for v in m._forward_hooks.values() if isinstance(v, _Hook)}
|
||||||
n
|
|
||||||
for n, m in model.named_modules()
|
|
||||||
for v in m._forward_hooks.values()
|
|
||||||
if isinstance(v, _Hook)
|
|
||||||
}
|
|
||||||
if len(remaining_hooks) > 0:
|
if len(remaining_hooks) > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Found active hooks added by EVA that weren't properly removed: {remaining_hooks}. "
|
f"Found active hooks added by EVA that weren't properly removed: {remaining_hooks}. "
|
||||||
@ -567,12 +510,9 @@ def _load_eva_state_dict(
|
|||||||
other_module_names.append(name_in_base_model)
|
other_module_names.append(name_in_base_model)
|
||||||
continue
|
continue
|
||||||
# Regexp matching - Find key which matches current target_name in patterns provided
|
# Regexp matching - Find key which matches current target_name in patterns provided
|
||||||
r = peft_config.rank_pattern.get(
|
r = peft_config.rank_pattern.get(get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r)
|
||||||
get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r
|
|
||||||
)
|
|
||||||
alpha = peft_config.alpha_pattern.get(
|
alpha = peft_config.alpha_pattern.get(
|
||||||
get_pattern_key(peft_config.alpha_pattern.keys(), name),
|
get_pattern_key(peft_config.alpha_pattern.keys(), name), peft_config.lora_alpha
|
||||||
peft_config.lora_alpha,
|
|
||||||
)
|
)
|
||||||
if name in eva_state_dict:
|
if name in eva_state_dict:
|
||||||
w = eva_state_dict.pop(name)
|
w = eva_state_dict.pop(name)
|
||||||
@ -584,22 +524,12 @@ def _load_eva_state_dict(
|
|||||||
elif new_rank != r:
|
elif new_rank != r:
|
||||||
if peft_config.eva_config.adjust_scaling_factors:
|
if peft_config.eva_config.adjust_scaling_factors:
|
||||||
alpha *= new_rank / r
|
alpha *= new_rank / r
|
||||||
if (
|
if new_rank != r or module.lora_A[adapter_name].weight.device.type == "meta":
|
||||||
new_rank != r
|
module.update_layer(r=new_rank, lora_alpha=alpha, init_lora_weights="eva", **update_layer_kwargs)
|
||||||
or module.lora_A[adapter_name].weight.device.type == "meta"
|
|
||||||
):
|
|
||||||
module.update_layer(
|
|
||||||
r=new_rank,
|
|
||||||
lora_alpha=alpha,
|
|
||||||
init_lora_weights="eva",
|
|
||||||
**update_layer_kwargs,
|
|
||||||
)
|
|
||||||
module.lora_A[adapter_name].weight.copy_(w)
|
module.lora_A[adapter_name].weight.copy_(w)
|
||||||
new_target_modules.append(name_in_base_model)
|
new_target_modules.append(name_in_base_model)
|
||||||
else:
|
else:
|
||||||
module.update_layer(
|
module.update_layer(r=r, lora_alpha=alpha, init_lora_weights=True, **update_layer_kwargs)
|
||||||
r=r, lora_alpha=alpha, init_lora_weights=True, **update_layer_kwargs
|
|
||||||
)
|
|
||||||
missing_eva_inits.append(name_in_base_model)
|
missing_eva_inits.append(name_in_base_model)
|
||||||
new_rank = r
|
new_rank = r
|
||||||
# update rank pattern and alpha pattern
|
# update rank pattern and alpha pattern
|
||||||
@ -611,9 +541,7 @@ def _load_eva_state_dict(
|
|||||||
# update target modules if some lora layers have been removed due to their EVA rank being 0
|
# update target modules if some lora layers have been removed due to their EVA rank being 0
|
||||||
new_target_modules = new_target_modules + missing_eva_inits
|
new_target_modules = new_target_modules + missing_eva_inits
|
||||||
if len(new_target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION:
|
if len(new_target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION:
|
||||||
new_target_modules = _find_minimal_target_modules(
|
new_target_modules = _find_minimal_target_modules(new_target_modules, other_module_names)
|
||||||
new_target_modules, other_module_names
|
|
||||||
)
|
|
||||||
model.peft_config[adapter_name].target_modules = new_target_modules
|
model.peft_config[adapter_name].target_modules = new_target_modules
|
||||||
|
|
||||||
# set rank pattern obtained from EVA
|
# set rank pattern obtained from EVA
|
||||||
@ -636,12 +564,8 @@ def get_eva_state_dict(
|
|||||||
dataloader: Iterable,
|
dataloader: Iterable,
|
||||||
peft_config: Optional[LoraConfig] = None,
|
peft_config: Optional[LoraConfig] = None,
|
||||||
forward_fn: Optional[callable] = forward_fn_dict,
|
forward_fn: Optional[callable] = forward_fn_dict,
|
||||||
prepare_model_inputs_fn: Optional[
|
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
|
||||||
callable
|
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
|
||||||
] = prepare_model_inputs_fn_language_modeling,
|
|
||||||
prepare_layer_inputs_fn: Union[
|
|
||||||
callable, Dict[str, callable], None
|
|
||||||
] = prepare_layer_inputs_fn_language_modeling,
|
|
||||||
adapter_name: str = "default",
|
adapter_name: str = "default",
|
||||||
gather_distributed_inputs: bool = True,
|
gather_distributed_inputs: bool = True,
|
||||||
show_progress_bar: bool = True,
|
show_progress_bar: bool = True,
|
||||||
@ -689,9 +613,7 @@ def get_eva_state_dict(
|
|||||||
|
|
||||||
def target_module_check_fn_peft_model(name, module, unsupported_lora_modules):
|
def target_module_check_fn_peft_model(name, module, unsupported_lora_modules):
|
||||||
"check if a module is an adapter module via base_layer attribute"
|
"check if a module is an adapter module via base_layer attribute"
|
||||||
return hasattr(module, "base_layer") and not isinstance(
|
return hasattr(module, "base_layer") and not isinstance(module, unsupported_lora_modules)
|
||||||
module, unsupported_lora_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
def target_module_check_fn_default(name, module, peft_config):
|
def target_module_check_fn_default(name, module, peft_config):
|
||||||
"check if a module is an adapter module via target_modules"
|
"check if a module is an adapter module via target_modules"
|
||||||
@ -713,14 +635,11 @@ def get_eva_state_dict(
|
|||||||
if is_peft_model:
|
if is_peft_model:
|
||||||
ctx = model.disable_adapter()
|
ctx = model.disable_adapter()
|
||||||
target_module_check_fn = partial(
|
target_module_check_fn = partial(
|
||||||
target_module_check_fn_peft_model,
|
target_module_check_fn_peft_model, unsupported_lora_modules=UNSUPPORTED_LORA_MODULES
|
||||||
unsupported_lora_modules=UNSUPPORTED_LORA_MODULES,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ctx = nullcontext()
|
ctx = nullcontext()
|
||||||
target_module_check_fn = partial(
|
target_module_check_fn = partial(target_module_check_fn_default, peft_config=peft_config)
|
||||||
target_module_check_fn_default, peft_config=peft_config
|
|
||||||
)
|
|
||||||
|
|
||||||
with ctx:
|
with ctx:
|
||||||
eva_state_dict = _get_eva_state_dict(
|
eva_state_dict = _get_eva_state_dict(
|
||||||
@ -743,12 +662,8 @@ def initialize_lora_eva_weights(
|
|||||||
dataloader: Optional[Iterable] = None,
|
dataloader: Optional[Iterable] = None,
|
||||||
eva_state_dict: Optional[dict] = None,
|
eva_state_dict: Optional[dict] = None,
|
||||||
forward_fn: Optional[callable] = forward_fn_dict,
|
forward_fn: Optional[callable] = forward_fn_dict,
|
||||||
prepare_model_inputs_fn: Optional[
|
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
|
||||||
callable
|
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
|
||||||
] = prepare_model_inputs_fn_language_modeling,
|
|
||||||
prepare_layer_inputs_fn: Union[
|
|
||||||
callable, Dict[str, callable], None
|
|
||||||
] = prepare_layer_inputs_fn_language_modeling,
|
|
||||||
adapter_name: str = "default",
|
adapter_name: str = "default",
|
||||||
gather_distributed_inputs: bool = True,
|
gather_distributed_inputs: bool = True,
|
||||||
show_progress_bar: bool = True,
|
show_progress_bar: bool = True,
|
||||||
@ -800,15 +715,11 @@ def initialize_lora_eva_weights(
|
|||||||
# eva currently only works with a single active adapter
|
# eva currently only works with a single active adapter
|
||||||
# Important: when removing this requirement, make sure eva init works correctly if the new rank is 0.
|
# Important: when removing this requirement, make sure eva init works correctly if the new rank is 0.
|
||||||
if len(model.active_adapters) > 1:
|
if len(model.active_adapters) > 1:
|
||||||
raise ValueError(
|
raise ValueError("`initialize_lora_eva_weights` currently only works with a single active adapter")
|
||||||
"`initialize_lora_eva_weights` currently only works with a single active adapter"
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize_lora_eva_weights only works with `init_lora_weights='eva'`
|
# initialize_lora_eva_weights only works with `init_lora_weights='eva'`
|
||||||
if model.peft_config[adapter_name].init_lora_weights != "eva":
|
if model.peft_config[adapter_name].init_lora_weights != "eva":
|
||||||
raise ValueError(
|
raise ValueError("`initialize_lora_eva_weights` can only be used with `init_lora_weights='eva'`")
|
||||||
"`initialize_lora_eva_weights` can only be used with `init_lora_weights='eva'`"
|
|
||||||
)
|
|
||||||
|
|
||||||
# compute svd
|
# compute svd
|
||||||
if eva_state_dict is None:
|
if eva_state_dict is None:
|
||||||
|
|||||||
@ -39,9 +39,7 @@ class QuantLinear(torch.nn.Module, LoraLayer):
|
|||||||
LoraLayer.__init__(self, base_layer)
|
LoraLayer.__init__(self, base_layer)
|
||||||
|
|
||||||
if use_dora:
|
if use_dora:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
|
||||||
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
|
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
|
||||||
# for backwards compatibility
|
# for backwards compatibility
|
||||||
@ -111,9 +109,7 @@ def dispatch_gptq(
|
|||||||
gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
|
gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
|
||||||
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
|
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
|
||||||
|
|
||||||
if AutoGPTQQuantLinear is not None and isinstance(
|
if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
|
||||||
target_base_layer, AutoGPTQQuantLinear
|
|
||||||
):
|
|
||||||
new_module = QuantLinear(target, adapter_name, **kwargs)
|
new_module = QuantLinear(target, adapter_name, **kwargs)
|
||||||
target.qweight = target_base_layer.qweight
|
target.qweight = target_base_layer.qweight
|
||||||
|
|
||||||
|
|||||||
@ -45,9 +45,7 @@ if is_hqq_available():
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
if lora_bias:
|
if lora_bias:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")
|
||||||
f"{self.__class__.__name__} does not support lora_bias yet, set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
LoraLayer.__init__(self, base_layer)
|
LoraLayer.__init__(self, base_layer)
|
||||||
@ -65,9 +63,7 @@ if is_hqq_available():
|
|||||||
lora_bias=lora_bias,
|
lora_bias=lora_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Merge the active adapter weights into the base weights
|
Merge the active adapter weights into the base weights
|
||||||
|
|
||||||
@ -90,10 +86,7 @@ if is_hqq_available():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
layer = self.get_base_layer()
|
layer = self.get_base_layer()
|
||||||
quant_config = {
|
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
|
||||||
**copy.deepcopy(layer.quant_config),
|
|
||||||
"offload_meta": layer.offload_meta,
|
|
||||||
}
|
|
||||||
lora_data = self.get_delta_weight(active_adapter)
|
lora_data = self.get_delta_weight(active_adapter)
|
||||||
|
|
||||||
output = layer.dequantize()
|
output = layer.dequantize()
|
||||||
@ -102,28 +95,19 @@ if is_hqq_available():
|
|||||||
else:
|
else:
|
||||||
# handle dora
|
# handle dora
|
||||||
# since output already includes scaling, set it to 1 here
|
# since output already includes scaling, set it to 1 here
|
||||||
weight_norm = self._get_weight_norm(
|
weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach()
|
||||||
output, lora_data, scaling=1
|
|
||||||
).detach()
|
|
||||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter] / weight_norm
|
|
||||||
)
|
|
||||||
w_data = dora_factor.view(-1, 1) * (output + lora_data)
|
w_data = dora_factor.view(-1, 1) * (output + lora_data)
|
||||||
|
|
||||||
if safe_merge and not torch.isfinite(w_data).all():
|
if safe_merge and not torch.isfinite(w_data).all():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||||
)
|
)
|
||||||
new_hqq_layer = HQQLinear(
|
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
|
||||||
None,
|
|
||||||
quant_config,
|
|
||||||
compute_dtype=layer.compute_dtype,
|
|
||||||
device=layer.device,
|
|
||||||
)
|
|
||||||
quant_config.pop("offload_meta", None)
|
quant_config.pop("offload_meta", None)
|
||||||
new_hqq_layer.quantize(w_data, **quant_config)
|
new_hqq_layer.quantize(w_data, **quant_config)
|
||||||
self.base_layer = new_hqq_layer
|
self.base_layer = new_hqq_layer
|
||||||
@ -144,27 +128,17 @@ if is_hqq_available():
|
|||||||
|
|
||||||
lora_data = self.get_delta_weight(active_adapter)
|
lora_data = self.get_delta_weight(active_adapter)
|
||||||
layer = self.get_base_layer()
|
layer = self.get_base_layer()
|
||||||
quant_config = {
|
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
|
||||||
**copy.deepcopy(layer.quant_config),
|
|
||||||
"offload_meta": layer.offload_meta,
|
|
||||||
}
|
|
||||||
output = layer.dequantize()
|
output = layer.dequantize()
|
||||||
|
|
||||||
if not self.use_dora[active_adapter]:
|
if not self.use_dora[active_adapter]:
|
||||||
w_data = output - lora_data
|
w_data = output - lora_data
|
||||||
else:
|
else:
|
||||||
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter] / weight_norm
|
|
||||||
)
|
|
||||||
w_data = output.data / dora_factor.view(-1, 1) - lora_data
|
w_data = output.data / dora_factor.view(-1, 1) - lora_data
|
||||||
|
|
||||||
new_hqq_layer = HQQLinear(
|
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
|
||||||
None,
|
|
||||||
quant_config,
|
|
||||||
compute_dtype=layer.compute_dtype,
|
|
||||||
device=layer.device,
|
|
||||||
)
|
|
||||||
quant_config.pop("offload_meta", None)
|
quant_config.pop("offload_meta", None)
|
||||||
new_hqq_layer.quantize(w_data, **quant_config)
|
new_hqq_layer.quantize(w_data, **quant_config)
|
||||||
self.base_layer = new_hqq_layer
|
self.base_layer = new_hqq_layer
|
||||||
@ -188,13 +162,7 @@ if is_hqq_available():
|
|||||||
unique_adapters = set(adapter_names)
|
unique_adapters = set(adapter_names)
|
||||||
sub_batch_indices_list = []
|
sub_batch_indices_list = []
|
||||||
for adapter in unique_adapters:
|
for adapter in unique_adapters:
|
||||||
sub_batch_indices_list.append(
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
|
||||||
[
|
|
||||||
index
|
|
||||||
for index, item in enumerate(adapter_names)
|
|
||||||
if item == adapter
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, active_adapter in enumerate(unique_adapters):
|
for i, active_adapter in enumerate(unique_adapters):
|
||||||
if active_adapter == "__base__":
|
if active_adapter == "__base__":
|
||||||
@ -233,9 +201,7 @@ if is_hqq_available():
|
|||||||
self.unmerge()
|
self.unmerge()
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
elif adapter_names is not None:
|
elif adapter_names is not None:
|
||||||
result = self._mixed_batch_forward(
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||||
x, *args, adapter_names=adapter_names, **kwargs
|
|
||||||
)
|
|
||||||
elif self.merged:
|
elif self.merged:
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -25,21 +25,11 @@ from torch import svd_lowrank
|
|||||||
from transformers.pytorch_utils import Conv1D
|
from transformers.pytorch_utils import Conv1D
|
||||||
|
|
||||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||||
from peft.utils.integrations import (
|
from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type
|
||||||
dequantize_module_weight,
|
|
||||||
gather_params_ctx,
|
|
||||||
get_bnb_param_type,
|
|
||||||
)
|
|
||||||
from peft.utils.other import transpose
|
from peft.utils.other import transpose
|
||||||
|
|
||||||
from .config import LoraConfig
|
from .config import LoraConfig
|
||||||
from .dora import (
|
from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer
|
||||||
DoraConv2dLayer,
|
|
||||||
DoraConv3dLayer,
|
|
||||||
DoraEmbeddingLayer,
|
|
||||||
DoraLinearLayer,
|
|
||||||
_DoraConvNdLayer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoraLayer(BaseTunerLayer):
|
class LoraLayer(BaseTunerLayer):
|
||||||
@ -48,9 +38,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
# All names of other parameters that may contain adapter-related parameters
|
# All names of other parameters that may contain adapter-related parameters
|
||||||
other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")
|
other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
|
||||||
self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs
|
|
||||||
) -> None:
|
|
||||||
self.base_layer = base_layer
|
self.base_layer = base_layer
|
||||||
self.r = {}
|
self.r = {}
|
||||||
self.lora_alpha = {}
|
self.lora_alpha = {}
|
||||||
@ -79,15 +67,10 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
elif isinstance(base_layer, nn.Conv3d):
|
elif isinstance(base_layer, nn.Conv3d):
|
||||||
in_features, out_features = base_layer.in_channels, base_layer.out_channels
|
in_features, out_features = base_layer.in_channels, base_layer.out_channels
|
||||||
elif isinstance(base_layer, nn.Embedding):
|
elif isinstance(base_layer, nn.Embedding):
|
||||||
in_features, out_features = (
|
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
|
||||||
base_layer.num_embeddings,
|
|
||||||
base_layer.embedding_dim,
|
|
||||||
)
|
|
||||||
elif isinstance(base_layer, Conv1D):
|
elif isinstance(base_layer, Conv1D):
|
||||||
in_features, out_features = (
|
in_features, out_features = (
|
||||||
base_layer.weight.ds_shape
|
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
|
||||||
if hasattr(base_layer.weight, "ds_shape")
|
|
||||||
else base_layer.weight.shape
|
|
||||||
)
|
)
|
||||||
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
|
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
|
||||||
# QuantLinear
|
# QuantLinear
|
||||||
@ -95,40 +78,26 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
|
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
|
||||||
# Megatron ColumnParallelLinear,RowParallelLinear
|
# Megatron ColumnParallelLinear,RowParallelLinear
|
||||||
in_features, out_features = base_layer.input_size, base_layer.output_size
|
in_features, out_features = base_layer.input_size, base_layer.output_size
|
||||||
elif (
|
elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
|
||||||
hasattr(base_layer, "codebooks")
|
|
||||||
and base_layer.__class__.__name__ == "QuantizedLinear"
|
|
||||||
):
|
|
||||||
# AQLM QuantLinear
|
# AQLM QuantLinear
|
||||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||||
elif (
|
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
|
||||||
hasattr(base_layer, "w_bit")
|
|
||||||
and base_layer.__class__.__name__ == "WQLinear_GEMM"
|
|
||||||
):
|
|
||||||
# Awq layers
|
# Awq layers
|
||||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||||
elif base_layer.__class__.__name__ == "EetqLinear":
|
elif base_layer.__class__.__name__ == "EetqLinear":
|
||||||
# Eetq layers
|
# Eetq layers
|
||||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||||
elif (
|
elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
|
||||||
hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear"
|
|
||||||
):
|
|
||||||
# HQQ layers
|
# HQQ layers
|
||||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||||
else:
|
else:
|
||||||
# possibly support user provided custom layer types using dynamic dispatch
|
# possibly support user provided custom layer types using dynamic dispatch
|
||||||
if hasattr(base_layer, "in_features") and hasattr(
|
if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
|
||||||
base_layer, "out_features"
|
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||||
):
|
|
||||||
in_features, out_features = (
|
|
||||||
base_layer.in_features,
|
|
||||||
base_layer.out_features,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
in_features, out_features = None, None
|
in_features, out_features = None, None
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.",
|
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
|
||||||
UserWarning,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
@ -147,9 +116,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
):
|
):
|
||||||
# This code works for linear layers, override for other layer types
|
# This code works for linear layers, override for other layer types
|
||||||
if r <= 0:
|
if r <= 0:
|
||||||
raise ValueError(
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||||
f"`r` should be a positive integer value but the value passed is {r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.r[adapter_name] = r
|
self.r[adapter_name] = r
|
||||||
self.lora_alpha[adapter_name] = lora_alpha
|
self.lora_alpha[adapter_name] = lora_alpha
|
||||||
@ -173,9 +140,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
|
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
|
||||||
with gather_params_ctx(self.get_base_layer().weight):
|
with gather_params_ctx(self.get_base_layer().weight):
|
||||||
self.pissa_init(adapter_name, init_lora_weights)
|
self.pissa_init(adapter_name, init_lora_weights)
|
||||||
elif (
|
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
|
||||||
isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora"
|
|
||||||
):
|
|
||||||
with gather_params_ctx(self.get_base_layer().weight):
|
with gather_params_ctx(self.get_base_layer().weight):
|
||||||
self.olora_init(adapter_name)
|
self.olora_init(adapter_name)
|
||||||
elif init_lora_weights == "loftq":
|
elif init_lora_weights == "loftq":
|
||||||
@ -204,13 +169,9 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
if init_lora_weights is True:
|
if init_lora_weights is True:
|
||||||
# initialize A the same way as the default for nn.Linear and B to zero
|
# initialize A the same way as the default for nn.Linear and B to zero
|
||||||
# https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
|
# https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
|
||||||
nn.init.kaiming_uniform_(
|
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
|
||||||
self.lora_A[adapter_name].weight, a=math.sqrt(5)
|
|
||||||
)
|
|
||||||
elif init_lora_weights.lower() == "gaussian":
|
elif init_lora_weights.lower() == "gaussian":
|
||||||
nn.init.normal_(
|
nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name])
|
||||||
self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown initialization {init_lora_weights=}")
|
raise ValueError(f"Unknown initialization {init_lora_weights=}")
|
||||||
nn.init.zeros_(self.lora_B[adapter_name].weight)
|
nn.init.zeros_(self.lora_B[adapter_name].weight)
|
||||||
@ -249,11 +210,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
self.lora_A[adapter_name].weight.data = Rr.contiguous()
|
self.lora_A[adapter_name].weight.data = Rr.contiguous()
|
||||||
self.lora_B[adapter_name].weight.data = Qr.contiguous()
|
self.lora_B[adapter_name].weight.data = Qr.contiguous()
|
||||||
|
|
||||||
weight_tensor.data -= (
|
weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
|
||||||
scale_factor
|
|
||||||
* self.lora_B[adapter_name].weight
|
|
||||||
@ self.lora_A[adapter_name].weight
|
|
||||||
)
|
|
||||||
if bnb_param_type == "4bit":
|
if bnb_param_type == "4bit":
|
||||||
weight_tensor = orig_weight.__class__(
|
weight_tensor = orig_weight.__class__(
|
||||||
weight_tensor,
|
weight_tensor,
|
||||||
@ -292,9 +249,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
Uhr = Uh[: self.r[adapter_name]]
|
Uhr = Uh[: self.r[adapter_name]]
|
||||||
elif len(init_lora_weights.split("_niter_")) == 2:
|
elif len(init_lora_weights.split("_niter_")) == 2:
|
||||||
Vr, Sr, Ur = svd_lowrank(
|
Vr, Sr, Ur = svd_lowrank(
|
||||||
weight.data,
|
weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
|
||||||
self.r[adapter_name],
|
|
||||||
niter=int(init_lora_weights.split("_niter_")[-1]),
|
|
||||||
)
|
)
|
||||||
Sr /= self.scaling[adapter_name]
|
Sr /= self.scaling[adapter_name]
|
||||||
Uhr = Ur.t()
|
Uhr = Ur.t()
|
||||||
@ -335,18 +290,12 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
def dora_init(self, adapter_name: str) -> None:
|
def dora_init(self, adapter_name: str) -> None:
|
||||||
if not self.lora_magnitude_vector:
|
if not self.lora_magnitude_vector:
|
||||||
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
||||||
self.adapter_layer_names = self.adapter_layer_names[:] + (
|
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
|
||||||
"lora_magnitude_vector",
|
|
||||||
)
|
|
||||||
|
|
||||||
dora_layer = DoraLinearLayer(
|
dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(self, "fan_in_fan_out", False))
|
||||||
fan_in_fan_out=getattr(self, "fan_in_fan_out", False)
|
|
||||||
)
|
|
||||||
lora_A = self.lora_A[adapter_name].weight
|
lora_A = self.lora_A[adapter_name].weight
|
||||||
lora_B = self.lora_B[adapter_name].weight
|
lora_B = self.lora_B[adapter_name].weight
|
||||||
place_on_cpu = self.ephemeral_gpu_offload and (
|
place_on_cpu = self.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu")
|
||||||
lora_A.device.type == "cpu" or lora_B.device.type == "cpu"
|
|
||||||
)
|
|
||||||
if self.ephemeral_gpu_offload:
|
if self.ephemeral_gpu_offload:
|
||||||
if lora_A.device.type in ["cuda", "xpu"]:
|
if lora_A.device.type in ["cuda", "xpu"]:
|
||||||
lora_B = lora_B.to(lora_A.device)
|
lora_B = lora_B.to(lora_A.device)
|
||||||
@ -359,11 +308,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
lora_A = lora_A.to(lora_B.device)
|
lora_A = lora_A.to(lora_B.device)
|
||||||
scaling = self.scaling[adapter_name]
|
scaling = self.scaling[adapter_name]
|
||||||
dora_layer.update_layer(
|
dora_layer.update_layer(
|
||||||
base_layer=self.get_base_layer(),
|
base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling, place_on_cpu=place_on_cpu
|
||||||
lora_A=lora_A,
|
|
||||||
lora_B=lora_B,
|
|
||||||
scaling=scaling,
|
|
||||||
place_on_cpu=place_on_cpu,
|
|
||||||
)
|
)
|
||||||
self.lora_magnitude_vector[adapter_name] = dora_layer
|
self.lora_magnitude_vector[adapter_name] = dora_layer
|
||||||
|
|
||||||
@ -396,9 +341,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if scale is None:
|
if scale is None:
|
||||||
self.scaling[active_adapter] = (
|
self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
|
||||||
self.lora_alpha[active_adapter] / self.r[active_adapter]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.scaling[active_adapter] /= scale
|
self.scaling[active_adapter] /= scale
|
||||||
|
|
||||||
@ -440,9 +383,7 @@ class LoraLayer(BaseTunerLayer):
|
|||||||
unique_adapters = set(adapter_names)
|
unique_adapters = set(adapter_names)
|
||||||
sub_batch_indices_list = []
|
sub_batch_indices_list = []
|
||||||
for adapter in unique_adapters:
|
for adapter in unique_adapters:
|
||||||
sub_batch_indices_list.append(
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
|
||||||
[index for index, item in enumerate(adapter_names) if item == adapter]
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, active_adapter in enumerate(unique_adapters):
|
for i, active_adapter in enumerate(unique_adapters):
|
||||||
if active_adapter == "__base__":
|
if active_adapter == "__base__":
|
||||||
@ -508,9 +449,7 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
)
|
)
|
||||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Merge the active adapter weights into the base weights
|
Merge the active adapter weights into the base weights
|
||||||
|
|
||||||
@ -543,24 +482,15 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
# since delta_weight already includes scaling, set it to 1 here
|
# since delta_weight already includes scaling, set it to 1 here
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
self.lora_magnitude_vector[active_adapter]
|
self.lora_magnitude_vector[active_adapter]
|
||||||
.get_weight_norm(
|
.get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1)
|
||||||
orig_weights,
|
|
||||||
transpose(delta_weight, self.fan_in_fan_out),
|
|
||||||
scaling=1,
|
|
||||||
)
|
|
||||||
.detach()
|
.detach()
|
||||||
)
|
)
|
||||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight
|
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
|
||||||
/ weight_norm
|
|
||||||
)
|
|
||||||
dora_factor = transpose(
|
|
||||||
dora_factor.view(-1, 1), self.fan_in_fan_out
|
|
||||||
)
|
|
||||||
orig_weights = dora_factor * (orig_weights + delta_weight)
|
orig_weights = dora_factor * (orig_weights + delta_weight)
|
||||||
|
|
||||||
if not torch.isfinite(orig_weights).all():
|
if not torch.isfinite(orig_weights).all():
|
||||||
@ -588,9 +518,7 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
weight_norm = (
|
weight_norm = (
|
||||||
self.lora_magnitude_vector[active_adapter]
|
self.lora_magnitude_vector[active_adapter]
|
||||||
.get_weight_norm(
|
.get_weight_norm(
|
||||||
base_layer.weight,
|
base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1
|
||||||
transpose(delta_weight, self.fan_in_fan_out),
|
|
||||||
scaling=1,
|
|
||||||
)
|
)
|
||||||
.detach()
|
.detach()
|
||||||
)
|
)
|
||||||
@ -598,16 +526,9 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight
|
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
|
||||||
/ weight_norm
|
new_weight = dora_factor * (base_layer.weight.data + delta_weight)
|
||||||
)
|
|
||||||
dora_factor = transpose(
|
|
||||||
dora_factor.view(-1, 1), self.fan_in_fan_out
|
|
||||||
)
|
|
||||||
new_weight = dora_factor * (
|
|
||||||
base_layer.weight.data + delta_weight
|
|
||||||
)
|
|
||||||
base_layer.weight.data = new_weight
|
base_layer.weight.data = new_weight
|
||||||
|
|
||||||
if self.lora_bias[active_adapter]:
|
if self.lora_bias[active_adapter]:
|
||||||
@ -631,9 +552,7 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
weight.data -= delta_weight
|
weight.data -= delta_weight
|
||||||
else:
|
else:
|
||||||
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
|
||||||
)
|
|
||||||
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
|
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
|
||||||
weight.data = weight_orig
|
weight.data = weight_orig
|
||||||
|
|
||||||
@ -654,9 +573,7 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
# In case users wants to merge the adapter weights that are in
|
# In case users wants to merge the adapter weights that are in
|
||||||
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
||||||
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
||||||
cast_to_fp32 = device.type == "cpu" and (
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
|
||||||
dtype == torch.float16 or dtype == torch.bfloat16
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_A = self.lora_A[adapter].weight
|
weight_A = self.lora_A[adapter].weight
|
||||||
weight_B = self.lora_B[adapter].weight
|
weight_B = self.lora_B[adapter].weight
|
||||||
@ -665,9 +582,7 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
weight_A = weight_A.float()
|
weight_A = weight_A.float()
|
||||||
weight_B = weight_B.float()
|
weight_B = weight_B.float()
|
||||||
|
|
||||||
output_tensor = (
|
output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
|
||||||
transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
|
|
||||||
)
|
|
||||||
|
|
||||||
if cast_to_fp32:
|
if cast_to_fp32:
|
||||||
output_tensor = output_tensor.to(dtype=dtype)
|
output_tensor = output_tensor.to(dtype=dtype)
|
||||||
@ -687,9 +602,7 @@ class Linear(nn.Module, LoraLayer):
|
|||||||
self.unmerge()
|
self.unmerge()
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
elif adapter_names is not None:
|
elif adapter_names is not None:
|
||||||
result = self._mixed_batch_forward(
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||||
x, *args, adapter_names=adapter_names, **kwargs
|
|
||||||
)
|
|
||||||
elif self.merged:
|
elif self.merged:
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -748,9 +661,7 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
) -> None:
|
) -> None:
|
||||||
if lora_bias:
|
if lora_bias:
|
||||||
# lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter
|
# lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter
|
||||||
raise ValueError(
|
raise ValueError(f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}.")
|
||||||
f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}."
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
LoraLayer.__init__(self, base_layer)
|
LoraLayer.__init__(self, base_layer)
|
||||||
@ -768,20 +679,10 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update_layer(
|
def update_layer(
|
||||||
self,
|
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
|
||||||
adapter_name,
|
|
||||||
r,
|
|
||||||
lora_alpha,
|
|
||||||
lora_dropout,
|
|
||||||
init_lora_weights,
|
|
||||||
use_rslora,
|
|
||||||
use_dora,
|
|
||||||
lora_bias,
|
|
||||||
):
|
):
|
||||||
if r <= 0:
|
if r <= 0:
|
||||||
raise ValueError(
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||||
f"`r` should be a positive integer value but the value passed is {r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.r[adapter_name] = r
|
self.r[adapter_name] = r
|
||||||
self.lora_alpha[adapter_name] = lora_alpha
|
self.lora_alpha[adapter_name] = lora_alpha
|
||||||
@ -822,25 +723,18 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
def dora_init(self, adapter_name: str) -> None:
|
def dora_init(self, adapter_name: str) -> None:
|
||||||
if self.lora_magnitude_vector is None:
|
if self.lora_magnitude_vector is None:
|
||||||
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
||||||
self.adapter_layer_names = self.adapter_layer_names[:] + (
|
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
|
||||||
"lora_magnitude_vector",
|
|
||||||
)
|
|
||||||
|
|
||||||
dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
|
dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
|
||||||
lora_embedding_A = self.lora_embedding_A[adapter_name]
|
lora_embedding_A = self.lora_embedding_A[adapter_name]
|
||||||
lora_embedding_B = self.lora_embedding_B[adapter_name]
|
lora_embedding_B = self.lora_embedding_B[adapter_name]
|
||||||
scaling = self.scaling[adapter_name]
|
scaling = self.scaling[adapter_name]
|
||||||
dora_layer.update_layer(
|
dora_layer.update_layer(
|
||||||
base_layer=self.get_base_layer(),
|
base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
|
||||||
lora_A=lora_embedding_A,
|
|
||||||
lora_B=lora_embedding_B,
|
|
||||||
scaling=scaling,
|
|
||||||
)
|
)
|
||||||
self.lora_magnitude_vector[adapter_name] = dora_layer
|
self.lora_magnitude_vector[adapter_name] = dora_layer
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Merge the active adapter weights into the base weights
|
Merge the active adapter weights into the base weights
|
||||||
|
|
||||||
@ -887,9 +781,7 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
while len(self.merged_adapters) > 0:
|
while len(self.merged_adapters) > 0:
|
||||||
active_adapter = self.merged_adapters.pop()
|
active_adapter = self.merged_adapters.pop()
|
||||||
if active_adapter in self.lora_embedding_A.keys():
|
if active_adapter in self.lora_embedding_A.keys():
|
||||||
self.get_base_layer().weight.data -= self.get_delta_weight(
|
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
|
||||||
active_adapter
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_delta_weight(self, adapter) -> torch.Tensor:
|
def get_delta_weight(self, adapter) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -905,9 +797,7 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
# In case users wants to merge the adapter weights that are in
|
# In case users wants to merge the adapter weights that are in
|
||||||
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
||||||
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
||||||
cast_to_fp32 = device.type == "cpu" and (
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
|
||||||
dtype == torch.float16 or dtype == torch.bfloat16
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_A = self.lora_embedding_A[adapter]
|
weight_A = self.lora_embedding_A[adapter]
|
||||||
weight_B = self.lora_embedding_B[adapter]
|
weight_B = self.lora_embedding_B[adapter]
|
||||||
@ -937,9 +827,7 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
unique_adapters = set(adapter_names)
|
unique_adapters = set(adapter_names)
|
||||||
sub_batch_indices_list = []
|
sub_batch_indices_list = []
|
||||||
for adapter in unique_adapters:
|
for adapter in unique_adapters:
|
||||||
sub_batch_indices_list.append(
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
|
||||||
[index for index, item in enumerate(adapter_names) if item == adapter]
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, active_adapter in enumerate(unique_adapters):
|
for i, active_adapter in enumerate(unique_adapters):
|
||||||
if active_adapter == "__base__":
|
if active_adapter == "__base__":
|
||||||
@ -981,9 +869,7 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
self.unmerge()
|
self.unmerge()
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
elif adapter_names is not None:
|
elif adapter_names is not None:
|
||||||
result = self._mixed_batch_forward(
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||||
x, *args, adapter_names=adapter_names, **kwargs
|
|
||||||
)
|
|
||||||
elif self.merged:
|
elif self.merged:
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -1000,9 +886,7 @@ class Embedding(nn.Module, LoraLayer):
|
|||||||
after_A = self._embed(x, embedding_A)
|
after_A = self._embed(x, embedding_A)
|
||||||
result = result + (after_A @ embedding_B) * scaling
|
result = result + (after_A @ embedding_B) * scaling
|
||||||
else:
|
else:
|
||||||
mag_norm_scale, dora_result = self.lora_magnitude_vector[
|
mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](
|
||||||
active_adapter
|
|
||||||
](
|
|
||||||
x,
|
x,
|
||||||
lora_A=embedding_A,
|
lora_A=embedding_A,
|
||||||
lora_B=embedding_B,
|
lora_B=embedding_B,
|
||||||
@ -1053,20 +937,10 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update_layer(
|
def update_layer(
|
||||||
self,
|
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
|
||||||
adapter_name,
|
|
||||||
r,
|
|
||||||
lora_alpha,
|
|
||||||
lora_dropout,
|
|
||||||
init_lora_weights,
|
|
||||||
use_rslora,
|
|
||||||
use_dora,
|
|
||||||
lora_bias,
|
|
||||||
):
|
):
|
||||||
if r <= 0:
|
if r <= 0:
|
||||||
raise ValueError(
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||||
f"`r` should be a positive integer value but the value passed is {r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.r[adapter_name] = r
|
self.r[adapter_name] = r
|
||||||
self.lora_alpha[adapter_name] = lora_alpha
|
self.lora_alpha[adapter_name] = lora_alpha
|
||||||
@ -1083,12 +957,8 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
padding = base_layer.padding
|
padding = base_layer.padding
|
||||||
conv_layer = type(base_layer)
|
conv_layer = type(base_layer)
|
||||||
out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
|
out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
|
||||||
self.lora_A[adapter_name] = conv_layer(
|
self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
|
||||||
self.in_features, r, kernel_size, stride, padding, bias=False
|
self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=lora_bias)
|
||||||
)
|
|
||||||
self.lora_B[adapter_name] = conv_layer(
|
|
||||||
r, self.out_features, out_kernel, out_stride, bias=lora_bias
|
|
||||||
)
|
|
||||||
self.lora_bias[adapter_name] = lora_bias
|
self.lora_bias[adapter_name] = lora_bias
|
||||||
|
|
||||||
if use_rslora:
|
if use_rslora:
|
||||||
@ -1118,30 +988,21 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
def dora_init(self, adapter_name: str) -> None:
|
def dora_init(self, adapter_name: str) -> None:
|
||||||
if self.lora_magnitude_vector is None:
|
if self.lora_magnitude_vector is None:
|
||||||
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
||||||
self.adapter_layer_names = self.adapter_layer_names[:] + (
|
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
|
||||||
"lora_magnitude_vector",
|
|
||||||
)
|
|
||||||
|
|
||||||
dora_layer_class = self._get_dora_layer_class()
|
dora_layer_class = self._get_dora_layer_class()
|
||||||
dora_layer = dora_layer_class(fan_in_fan_out=False)
|
dora_layer = dora_layer_class(fan_in_fan_out=False)
|
||||||
lora_A = self.lora_A[adapter_name].weight
|
lora_A = self.lora_A[adapter_name].weight
|
||||||
lora_B = self.lora_B[adapter_name].weight
|
lora_B = self.lora_B[adapter_name].weight
|
||||||
scaling = self.scaling[adapter_name]
|
scaling = self.scaling[adapter_name]
|
||||||
dora_layer.update_layer(
|
dora_layer.update_layer(base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling)
|
||||||
base_layer=self.get_base_layer(),
|
|
||||||
lora_A=lora_A,
|
|
||||||
lora_B=lora_B,
|
|
||||||
scaling=scaling,
|
|
||||||
)
|
|
||||||
self.lora_magnitude_vector[adapter_name] = dora_layer
|
self.lora_magnitude_vector[adapter_name] = dora_layer
|
||||||
|
|
||||||
def _get_dora_layer_class(self) -> type[_DoraConvNdLayer]:
|
def _get_dora_layer_class(self) -> type[_DoraConvNdLayer]:
|
||||||
# Subclasses should override this method to return the appropriate DoraLayer class
|
# Subclasses should override this method to return the appropriate DoraLayer class
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Merge the active adapter weights inside the base weights
|
Merge the active adapter weights inside the base weights
|
||||||
|
|
||||||
@ -1182,13 +1043,8 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight
|
orig_weights = dora_factor.view(*self._get_dora_factor_view()) * (orig_weights + delta_weight)
|
||||||
/ weight_norm
|
|
||||||
)
|
|
||||||
orig_weights = dora_factor.view(
|
|
||||||
*self._get_dora_factor_view()
|
|
||||||
) * (orig_weights + delta_weight)
|
|
||||||
|
|
||||||
if not torch.isfinite(orig_weights).all():
|
if not torch.isfinite(orig_weights).all():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1220,10 +1076,7 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight
|
|
||||||
/ weight_norm
|
|
||||||
)
|
|
||||||
new_weight = dora_factor.view(*self._get_dora_factor_view()) * (
|
new_weight = dora_factor.view(*self._get_dora_factor_view()) * (
|
||||||
base_layer.weight.data + delta_weight
|
base_layer.weight.data + delta_weight
|
||||||
)
|
)
|
||||||
@ -1250,13 +1103,8 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
weight.data -= delta_weight
|
weight.data -= delta_weight
|
||||||
else:
|
else:
|
||||||
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
weight_orig = weight.data / dora_factor.view(*self._get_dora_factor_view()) - delta_weight
|
||||||
)
|
|
||||||
weight_orig = (
|
|
||||||
weight.data / dora_factor.view(*self._get_dora_factor_view())
|
|
||||||
- delta_weight
|
|
||||||
)
|
|
||||||
weight.data = weight_orig
|
weight.data = weight_orig
|
||||||
|
|
||||||
if self.lora_bias[active_adapter]:
|
if self.lora_bias[active_adapter]:
|
||||||
@ -1276,9 +1124,7 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
# In case users wants to merge the adapter weights that are in
|
# In case users wants to merge the adapter weights that are in
|
||||||
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
||||||
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
||||||
cast_to_fp32 = device.type == "cpu" and (
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
|
||||||
dtype == torch.float16 or dtype == torch.bfloat16
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_A = self.lora_A[adapter].weight
|
weight_A = self.lora_A[adapter].weight
|
||||||
weight_B = self.lora_B[adapter].weight
|
weight_B = self.lora_B[adapter].weight
|
||||||
@ -1290,9 +1136,9 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
|
# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
|
||||||
if self.get_base_layer().weight.size()[2:4] == (1, 1):
|
if self.get_base_layer().weight.size()[2:4] == (1, 1):
|
||||||
# conv2d 1x1
|
# conv2d 1x1
|
||||||
output_tensor = (
|
output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
|
||||||
weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)
|
3
|
||||||
).unsqueeze(2).unsqueeze(3) * self.scaling[adapter]
|
) * self.scaling[adapter]
|
||||||
else:
|
else:
|
||||||
output_tensor = (
|
output_tensor = (
|
||||||
self.conv_fn(
|
self.conv_fn(
|
||||||
@ -1320,9 +1166,7 @@ class _ConvNd(nn.Module, LoraLayer):
|
|||||||
self.unmerge()
|
self.unmerge()
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
elif adapter_names is not None:
|
elif adapter_names is not None:
|
||||||
result = self._mixed_batch_forward(
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||||
x, *args, adapter_names=adapter_names, **kwargs
|
|
||||||
)
|
|
||||||
elif self.merged:
|
elif self.merged:
|
||||||
result = self.base_layer(x, *args, **kwargs)
|
result = self.base_layer(x, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -1363,9 +1207,7 @@ class Conv2d(_ConvNd):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if not self._kernel_dim == 4:
|
if not self._kernel_dim == 4:
|
||||||
raise ValueError(
|
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
|
||||||
f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}"
|
|
||||||
)
|
|
||||||
self.conv_fn = F.conv2d
|
self.conv_fn = F.conv2d
|
||||||
|
|
||||||
def _get_dora_layer_class(self):
|
def _get_dora_layer_class(self):
|
||||||
@ -1377,9 +1219,7 @@ class Conv3d(_ConvNd):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if not self._kernel_dim == 5:
|
if not self._kernel_dim == 5:
|
||||||
raise ValueError(
|
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
|
||||||
f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}"
|
|
||||||
)
|
|
||||||
self.conv_fn = F.conv3d
|
self.conv_fn = F.conv3d
|
||||||
|
|
||||||
def _get_dora_layer_class(self):
|
def _get_dora_layer_class(self):
|
||||||
@ -1422,13 +1262,10 @@ def dispatch_default(
|
|||||||
elif isinstance(target_base_layer, Conv1D):
|
elif isinstance(target_base_layer, Conv1D):
|
||||||
if not kwargs["fan_in_fan_out"]:
|
if not kwargs["fan_in_fan_out"]:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
|
"fan_in_fan_out is set to False but the target module is `Conv1D`. " "Setting fan_in_fan_out to True."
|
||||||
"Setting fan_in_fan_out to True."
|
|
||||||
)
|
)
|
||||||
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
|
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
|
||||||
kwargs.update(lora_config.loftq_config)
|
kwargs.update(lora_config.loftq_config)
|
||||||
new_module = Linear(
|
new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
|
||||||
target, adapter_name, is_target_conv_1d_layer=True, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_module
|
return new_module
|
||||||
|
|||||||
@ -42,13 +42,7 @@ from peft.utils import (
|
|||||||
get_peft_model_state_dict,
|
get_peft_model_state_dict,
|
||||||
get_quantization_config,
|
get_quantization_config,
|
||||||
)
|
)
|
||||||
from peft.utils.merge_utils import (
|
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties
|
||||||
dare_linear,
|
|
||||||
dare_ties,
|
|
||||||
magnitude_prune,
|
|
||||||
task_arithmetic,
|
|
||||||
ties,
|
|
||||||
)
|
|
||||||
from peft.utils.other import get_pattern_key
|
from peft.utils.other import get_pattern_key
|
||||||
|
|
||||||
from .aqlm import dispatch_aqlm
|
from .aqlm import dispatch_aqlm
|
||||||
@ -143,12 +137,8 @@ class LoraModel(BaseTuner):
|
|||||||
|
|
||||||
prefix: str = "lora_"
|
prefix: str = "lora_"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
|
||||||
self, model, config, adapter_name, low_cpu_mem_usage: bool = False
|
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_new_adapter_config(self, config: LoraConfig) -> None:
|
def _check_new_adapter_config(self, config: LoraConfig) -> None:
|
||||||
"""
|
"""
|
||||||
@ -223,9 +213,7 @@ class LoraModel(BaseTuner):
|
|||||||
|
|
||||||
quant_methods = ["gptq", "aqlm", "awq"]
|
quant_methods = ["gptq", "aqlm", "awq"]
|
||||||
for quant_method in quant_methods:
|
for quant_method in quant_methods:
|
||||||
quantization_config = get_quantization_config(
|
quantization_config = get_quantization_config(self.model, method=quant_method)
|
||||||
self.model, method=quant_method
|
|
||||||
)
|
|
||||||
if quantization_config is not None:
|
if quantization_config is not None:
|
||||||
kwargs[f"{quant_method}_quantization_config"] = quantization_config
|
kwargs[f"{quant_method}_quantization_config"] = quantization_config
|
||||||
|
|
||||||
@ -244,9 +232,7 @@ class LoraModel(BaseTuner):
|
|||||||
lora_bias=lora_config.lora_bias,
|
lora_bias=lora_config.lora_bias,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_module = self._create_new_module(
|
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
|
||||||
lora_config, adapter_name, target, **kwargs
|
|
||||||
)
|
|
||||||
if adapter_name not in self.active_adapters:
|
if adapter_name not in self.active_adapters:
|
||||||
# adding an additional adapter: it is not automatically trainable
|
# adding an additional adapter: it is not automatically trainable
|
||||||
new_module.requires_grad_(False)
|
new_module.requires_grad_(False)
|
||||||
@ -283,16 +269,12 @@ class LoraModel(BaseTuner):
|
|||||||
weight = (
|
weight = (
|
||||||
child.qweight
|
child.qweight
|
||||||
if hasattr(child, "qweight")
|
if hasattr(child, "qweight")
|
||||||
else (
|
else child.W_q
|
||||||
child.W_q
|
|
||||||
if hasattr(child, "W_q")
|
if hasattr(child, "W_q")
|
||||||
else (
|
else child.weight
|
||||||
child.weight
|
|
||||||
if hasattr(child, "weight")
|
if hasattr(child, "weight")
|
||||||
else next(child.parameters())
|
else next(child.parameters())
|
||||||
)
|
)
|
||||||
)
|
|
||||||
)
|
|
||||||
if not any(p.device == meta for p in module.parameters()):
|
if not any(p.device == meta for p in module.parameters()):
|
||||||
module.to(weight.device)
|
module.to(weight.device)
|
||||||
|
|
||||||
@ -312,16 +294,10 @@ class LoraModel(BaseTuner):
|
|||||||
p.requires_grad = True
|
p.requires_grad = True
|
||||||
elif bias == "lora_only":
|
elif bias == "lora_only":
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if (
|
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
|
||||||
isinstance(m, LoraLayer)
|
|
||||||
and hasattr(m, "bias")
|
|
||||||
and m.bias is not None
|
|
||||||
):
|
|
||||||
m.bias.requires_grad = True
|
m.bias.requires_grad = True
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")
|
||||||
f"Requested bias: {bias}, is not implemented."
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_new_module(lora_config, adapter_name, target, **kwargs):
|
def _create_new_module(lora_config, adapter_name, target, **kwargs):
|
||||||
@ -375,9 +351,7 @@ class LoraModel(BaseTuner):
|
|||||||
|
|
||||||
new_module = None
|
new_module = None
|
||||||
for dispatcher in dispatchers:
|
for dispatcher in dispatchers:
|
||||||
new_module = dispatcher(
|
new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
|
||||||
target, adapter_name, lora_config=lora_config, **kwargs
|
|
||||||
)
|
|
||||||
if new_module is not None: # first match wins
|
if new_module is not None: # first match wins
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -396,19 +370,14 @@ class LoraModel(BaseTuner):
|
|||||||
try:
|
try:
|
||||||
return super().__getattr__(name) # defer to nn.Module's logic
|
return super().__getattr__(name) # defer to nn.Module's logic
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
if (
|
if name == "model": # see #1892: prevent infinite recursion if class is not initialized
|
||||||
name == "model"
|
|
||||||
): # see #1892: prevent infinite recursion if class is not initialized
|
|
||||||
raise
|
raise
|
||||||
return getattr(self.model, name)
|
return getattr(self.model, name)
|
||||||
|
|
||||||
def get_peft_config_as_dict(self, inference: bool = False):
|
def get_peft_config_as_dict(self, inference: bool = False):
|
||||||
config_dict = {}
|
config_dict = {}
|
||||||
for key, value in self.peft_config.items():
|
for key, value in self.peft_config.items():
|
||||||
config = {
|
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
|
||||||
k: v.value if isinstance(v, Enum) else v
|
|
||||||
for k, v in asdict(value).items()
|
|
||||||
}
|
|
||||||
if inference:
|
if inference:
|
||||||
config["inference_mode"] = True
|
config["inference_mode"] = True
|
||||||
config_dict[key] = config
|
config_dict[key] = config
|
||||||
@ -459,9 +428,7 @@ class LoraModel(BaseTuner):
|
|||||||
for module in self.model.modules():
|
for module in self.model.modules():
|
||||||
if isinstance(module, LoraLayer):
|
if isinstance(module, LoraLayer):
|
||||||
if module.merged:
|
if module.merged:
|
||||||
warnings.warn(
|
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
|
||||||
"Adapter cannot be set when the model is merged. Unmerging the model first."
|
|
||||||
)
|
|
||||||
module.unmerge()
|
module.unmerge()
|
||||||
module.set_adapter(adapter_name)
|
module.set_adapter(adapter_name)
|
||||||
self.active_adapter = adapter_name
|
self.active_adapter = adapter_name
|
||||||
@ -476,9 +443,7 @@ class LoraModel(BaseTuner):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
raise ValueError(
|
raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
|
||||||
"Cannot pass `adapter_names` when the model is in training mode."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that users only passed actually existing adapters.
|
# Check that users only passed actually existing adapters.
|
||||||
# Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
|
# Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
|
||||||
@ -491,18 +456,12 @@ class LoraModel(BaseTuner):
|
|||||||
unique_adapters = {name for name in adapter_names if name != "__base__"}
|
unique_adapters = {name for name in adapter_names if name != "__base__"}
|
||||||
unexpected_adapters = unique_adapters - expected_adapters
|
unexpected_adapters = unique_adapters - expected_adapters
|
||||||
if unexpected_adapters:
|
if unexpected_adapters:
|
||||||
raise ValueError(
|
raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
|
||||||
f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}"
|
|
||||||
)
|
|
||||||
|
|
||||||
hook_handles = []
|
hook_handles = []
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
if isinstance(module, LoraLayer) or isinstance(
|
if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
|
||||||
module, ModulesToSaveWrapper
|
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
|
||||||
):
|
|
||||||
pre_forward = partial(
|
|
||||||
_adapter_names_pre_forward_hook, adapter_names=adapter_names
|
|
||||||
)
|
|
||||||
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
|
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
|
||||||
hook_handles.append(handle)
|
hook_handles.append(handle)
|
||||||
|
|
||||||
@ -518,26 +477,17 @@ class LoraModel(BaseTuner):
|
|||||||
"""
|
"""
|
||||||
super()._check_merge_allowed()
|
super()._check_merge_allowed()
|
||||||
if getattr(self.model, "quantization_method", None) == "gptq":
|
if getattr(self.model, "quantization_method", None) == "gptq":
|
||||||
raise ValueError(
|
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
|
||||||
"Cannot merge LORA layers when the model is gptq quantized"
|
|
||||||
)
|
|
||||||
if self.peft_config.get("layer_replication"):
|
if self.peft_config.get("layer_replication"):
|
||||||
raise ValueError(
|
raise ValueError("Cannot merge LORA layers when base model layers are replicated")
|
||||||
"Cannot merge LORA layers when base model layers are replicated"
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _prepare_adapter_config(peft_config, model_config):
|
def _prepare_adapter_config(peft_config, model_config):
|
||||||
if peft_config.target_modules is None:
|
if peft_config.target_modules is None:
|
||||||
if (
|
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
|
||||||
model_config["model_type"]
|
|
||||||
not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
|
||||||
):
|
|
||||||
raise ValueError("Please specify `target_modules` in `peft_config`")
|
raise ValueError("Please specify `target_modules` in `peft_config`")
|
||||||
peft_config.target_modules = set(
|
peft_config.target_modules = set(
|
||||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
|
||||||
model_config["model_type"]
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
return peft_config
|
return peft_config
|
||||||
|
|
||||||
@ -551,9 +501,7 @@ class LoraModel(BaseTuner):
|
|||||||
if merge:
|
if merge:
|
||||||
self._check_merge_allowed()
|
self._check_merge_allowed()
|
||||||
|
|
||||||
key_list = [
|
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
|
||||||
key for key, _ in self.model.named_modules() if self.prefix not in key
|
|
||||||
]
|
|
||||||
desc = "Unloading " + ("and merging " if merge else "") + "model"
|
desc = "Unloading " + ("and merging " if merge else "") + "model"
|
||||||
for key in tqdm(key_list, disable=not progressbar, desc=desc):
|
for key in tqdm(key_list, disable=not progressbar, desc=desc):
|
||||||
try:
|
try:
|
||||||
@ -564,18 +512,14 @@ class LoraModel(BaseTuner):
|
|||||||
if hasattr(target, "base_layer"):
|
if hasattr(target, "base_layer"):
|
||||||
if merge:
|
if merge:
|
||||||
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
||||||
self._replace_module(
|
self._replace_module(parent, target_name, target.get_base_layer(), target)
|
||||||
parent, target_name, target.get_base_layer(), target
|
|
||||||
)
|
|
||||||
elif isinstance(target, ModulesToSaveWrapper):
|
elif isinstance(target, ModulesToSaveWrapper):
|
||||||
# save any additional trainable modules part of `modules_to_save`
|
# save any additional trainable modules part of `modules_to_save`
|
||||||
new_module = target.modules_to_save[target.active_adapter]
|
new_module = target.modules_to_save[target.active_adapter]
|
||||||
if hasattr(new_module, "base_layer"):
|
if hasattr(new_module, "base_layer"):
|
||||||
# check if the module is itself a tuner layer
|
# check if the module is itself a tuner layer
|
||||||
if merge:
|
if merge:
|
||||||
new_module.merge(
|
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
||||||
safe_merge=safe_merge, adapter_names=adapter_names
|
|
||||||
)
|
|
||||||
new_module = new_module.get_base_layer()
|
new_module = new_module.get_base_layer()
|
||||||
setattr(parent, target_name, new_module)
|
setattr(parent, target_name, new_module)
|
||||||
|
|
||||||
@ -595,11 +539,7 @@ class LoraModel(BaseTuner):
|
|||||||
# If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
|
# If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
|
||||||
# modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
|
# modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
|
||||||
# have modules for the adapters to be merged.
|
# have modules for the adapters to be merged.
|
||||||
modules_to_save_wrappers = [
|
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
|
||||||
module
|
|
||||||
for module in self.modules()
|
|
||||||
if isinstance(module, ModulesToSaveWrapper)
|
|
||||||
]
|
|
||||||
problematic_wrappers = [
|
problematic_wrappers = [
|
||||||
wrapper
|
wrapper
|
||||||
for wrapper in modules_to_save_wrappers
|
for wrapper in modules_to_save_wrappers
|
||||||
@ -615,13 +555,7 @@ class LoraModel(BaseTuner):
|
|||||||
combination_type = "linear" if len(adapters) == 1 else combination_type
|
combination_type = "linear" if len(adapters) == 1 else combination_type
|
||||||
|
|
||||||
adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
|
adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
|
||||||
if combination_type in (
|
if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
|
||||||
"linear",
|
|
||||||
"ties",
|
|
||||||
"dare_ties",
|
|
||||||
"dare_linear",
|
|
||||||
"magnitude_prune",
|
|
||||||
):
|
|
||||||
# all adapters ranks should be same, new rank is just this value
|
# all adapters ranks should be same, new rank is just this value
|
||||||
if len(set(adapters_ranks)) != 1:
|
if len(set(adapters_ranks)) != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -639,9 +573,7 @@ class LoraModel(BaseTuner):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid combination_type: {combination_type}")
|
raise ValueError(f"Invalid combination_type: {combination_type}")
|
||||||
|
|
||||||
target_module_types = [
|
target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
|
||||||
type(self.peft_config[adapter].target_modules) for adapter in adapters
|
|
||||||
]
|
|
||||||
if not target_module_types:
|
if not target_module_types:
|
||||||
raise ValueError(f"Found no adapter matching the names in {adapters}")
|
raise ValueError(f"Found no adapter matching the names in {adapters}")
|
||||||
if len(set(target_module_types)) > 1:
|
if len(set(target_module_types)) > 1:
|
||||||
@ -651,18 +583,13 @@ class LoraModel(BaseTuner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if target_module_types[0] is str:
|
if target_module_types[0] is str:
|
||||||
new_target_modules = "|".join(
|
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
|
||||||
f"({self.peft_config[adapter].target_modules})" for adapter in adapters
|
|
||||||
)
|
|
||||||
elif target_module_types[0] is set:
|
elif target_module_types[0] is set:
|
||||||
new_target_modules = reduce(
|
new_target_modules = reduce(
|
||||||
operator.or_,
|
operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
|
||||||
(self.peft_config[adapter].target_modules for adapter in adapters),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")
|
||||||
f"Invalid type {target_module_types[0]} found in target_modules"
|
|
||||||
)
|
|
||||||
|
|
||||||
return combination_type, new_rank, new_target_modules
|
return combination_type, new_rank, new_target_modules
|
||||||
|
|
||||||
@ -722,13 +649,11 @@ class LoraModel(BaseTuner):
|
|||||||
if adapter_name in list(self.peft_config.keys()):
|
if adapter_name in list(self.peft_config.keys()):
|
||||||
return
|
return
|
||||||
|
|
||||||
combination_type, new_rank, new_target_modules = (
|
combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
|
||||||
self._check_add_weighted_adapter(
|
|
||||||
adapters=adapters,
|
adapters=adapters,
|
||||||
combination_type=combination_type,
|
combination_type=combination_type,
|
||||||
svd_rank=svd_rank,
|
svd_rank=svd_rank,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self.peft_config[adapter_name] = replace(
|
self.peft_config[adapter_name] = replace(
|
||||||
self.peft_config[adapters[0]],
|
self.peft_config[adapters[0]],
|
||||||
@ -741,9 +666,7 @@ class LoraModel(BaseTuner):
|
|||||||
# Do we really need that?
|
# Do we really need that?
|
||||||
_freeze_adapter(self.model, adapter_name)
|
_freeze_adapter(self.model, adapter_name)
|
||||||
|
|
||||||
key_list = [
|
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
|
||||||
key for key, _ in self.model.named_modules() if self.prefix not in key
|
|
||||||
]
|
|
||||||
for key in key_list:
|
for key in key_list:
|
||||||
_, target, _ = _get_submodules(self.model, key)
|
_, target, _ = _get_submodules(self.model, key)
|
||||||
if isinstance(target, LoraLayer):
|
if isinstance(target, LoraLayer):
|
||||||
@ -769,17 +692,11 @@ class LoraModel(BaseTuner):
|
|||||||
current_adapter_lora_B = target.lora_embedding_B[adapter]
|
current_adapter_lora_B = target.lora_embedding_B[adapter]
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
loras_A.append(
|
loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter])
|
||||||
current_adapter_lora_A.data
|
|
||||||
* weight
|
|
||||||
* target.scaling[adapter]
|
|
||||||
)
|
|
||||||
loras_B.append(current_adapter_lora_B.data)
|
loras_B.append(current_adapter_lora_B.data)
|
||||||
|
|
||||||
if len(loras_A) == 0:
|
if len(loras_A) == 0:
|
||||||
raise ValueError(
|
raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")
|
||||||
"No matching LoRAs found. Please raise an issue on GitHub."
|
|
||||||
)
|
|
||||||
loras_A = torch.cat(loras_A, dim=0)
|
loras_A = torch.cat(loras_A, dim=0)
|
||||||
loras_B = torch.cat(loras_B, dim=1)
|
loras_B = torch.cat(loras_B, dim=1)
|
||||||
target_lora_A.data[: loras_A.shape[0], :] = loras_A
|
target_lora_A.data[: loras_A.shape[0], :] = loras_A
|
||||||
@ -791,8 +708,7 @@ class LoraModel(BaseTuner):
|
|||||||
"dare_ties_svd",
|
"dare_ties_svd",
|
||||||
"magnitude_prune_svd",
|
"magnitude_prune_svd",
|
||||||
]:
|
]:
|
||||||
target_lora_A.data, target_lora_B.data = (
|
target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter(
|
||||||
self._svd_generalized_task_arithmetic_weighted_adapter(
|
|
||||||
combination_type,
|
combination_type,
|
||||||
adapters,
|
adapters,
|
||||||
weights,
|
weights,
|
||||||
@ -806,23 +722,9 @@ class LoraModel(BaseTuner):
|
|||||||
full_matrices=svd_full_matrices,
|
full_matrices=svd_full_matrices,
|
||||||
driver=svd_driver,
|
driver=svd_driver,
|
||||||
)
|
)
|
||||||
)
|
elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]:
|
||||||
elif combination_type in [
|
target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter(
|
||||||
"linear",
|
combination_type, adapters, weights, target, density, majority_sign_method
|
||||||
"ties",
|
|
||||||
"dare_linear",
|
|
||||||
"dare_ties",
|
|
||||||
"magnitude_prune",
|
|
||||||
]:
|
|
||||||
target_lora_A.data, target_lora_B.data = (
|
|
||||||
self._generalized_task_arithmetic_weighted_adapter(
|
|
||||||
combination_type,
|
|
||||||
adapters,
|
|
||||||
weights,
|
|
||||||
target,
|
|
||||||
density,
|
|
||||||
majority_sign_method,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _svd_generalized_task_arithmetic_weighted_adapter(
|
def _svd_generalized_task_arithmetic_weighted_adapter(
|
||||||
@ -850,29 +752,21 @@ class LoraModel(BaseTuner):
|
|||||||
|
|
||||||
# if no valid adapter, nothing to do
|
# if no valid adapter, nothing to do
|
||||||
if len(valid_adapters) == 0:
|
if len(valid_adapters) == 0:
|
||||||
raise ValueError(
|
raise ValueError("No matching LoRAs found. Please raise an issue on Github.")
|
||||||
"No matching LoRAs found. Please raise an issue on Github."
|
|
||||||
)
|
|
||||||
delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters]
|
delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters]
|
||||||
valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device)
|
valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device)
|
||||||
if combination_type == "svd":
|
if combination_type == "svd":
|
||||||
delta_weight = task_arithmetic(delta_weight, valid_weights)
|
delta_weight = task_arithmetic(delta_weight, valid_weights)
|
||||||
elif combination_type == "ties_svd":
|
elif combination_type == "ties_svd":
|
||||||
delta_weight = ties(
|
delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method)
|
||||||
delta_weight, valid_weights, density, majority_sign_method
|
|
||||||
)
|
|
||||||
elif combination_type == "dare_linear_svd":
|
elif combination_type == "dare_linear_svd":
|
||||||
delta_weight = dare_linear(delta_weight, valid_weights, density)
|
delta_weight = dare_linear(delta_weight, valid_weights, density)
|
||||||
elif combination_type == "dare_ties_svd":
|
elif combination_type == "dare_ties_svd":
|
||||||
delta_weight = dare_ties(
|
delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method)
|
||||||
delta_weight, valid_weights, density, majority_sign_method
|
|
||||||
)
|
|
||||||
elif combination_type == "magnitude_prune_svd":
|
elif combination_type == "magnitude_prune_svd":
|
||||||
delta_weight = magnitude_prune(delta_weight, valid_weights, density)
|
delta_weight = magnitude_prune(delta_weight, valid_weights, density)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid value passed to combination type: {combination_type}")
|
||||||
f"Invalid value passed to combination type: {combination_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
conv2d = isinstance(target, Conv2d)
|
conv2d = isinstance(target, Conv2d)
|
||||||
if conv2d:
|
if conv2d:
|
||||||
@ -881,15 +775,11 @@ class LoraModel(BaseTuner):
|
|||||||
delta_weight = delta_weight.flatten(start_dim=1)
|
delta_weight = delta_weight.flatten(start_dim=1)
|
||||||
else:
|
else:
|
||||||
delta_weight = delta_weight.squeeze()
|
delta_weight = delta_weight.squeeze()
|
||||||
if (
|
if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding:
|
||||||
hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out
|
|
||||||
) or is_embedding:
|
|
||||||
delta_weight = delta_weight.T
|
delta_weight = delta_weight.T
|
||||||
|
|
||||||
# based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131
|
# based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131
|
||||||
U, S, Vh = torch.linalg.svd(
|
U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver)
|
||||||
delta_weight, full_matrices=full_matrices, driver=driver
|
|
||||||
)
|
|
||||||
U = U[:, :new_rank]
|
U = U[:, :new_rank]
|
||||||
S = S[:new_rank]
|
S = S[:new_rank]
|
||||||
U = U @ torch.diag(S)
|
U = U @ torch.diag(S)
|
||||||
@ -937,15 +827,11 @@ class LoraModel(BaseTuner):
|
|||||||
if combination_type == "linear":
|
if combination_type == "linear":
|
||||||
lora_deltas[i] = task_arithmetic(task_tensors, valid_weights)
|
lora_deltas[i] = task_arithmetic(task_tensors, valid_weights)
|
||||||
elif combination_type == "ties":
|
elif combination_type == "ties":
|
||||||
lora_deltas[i] = ties(
|
lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method)
|
||||||
task_tensors, valid_weights, density, majority_sign_method
|
|
||||||
)
|
|
||||||
elif combination_type == "dare_linear":
|
elif combination_type == "dare_linear":
|
||||||
lora_deltas[i] = dare_linear(task_tensors, valid_weights, density)
|
lora_deltas[i] = dare_linear(task_tensors, valid_weights, density)
|
||||||
elif combination_type == "dare_ties":
|
elif combination_type == "dare_ties":
|
||||||
lora_deltas[i] = dare_ties(
|
lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method)
|
||||||
task_tensors, valid_weights, density, majority_sign_method
|
|
||||||
)
|
|
||||||
elif combination_type == "magnitude_prune":
|
elif combination_type == "magnitude_prune":
|
||||||
lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density)
|
lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density)
|
||||||
else:
|
else:
|
||||||
@ -964,9 +850,7 @@ class LoraModel(BaseTuner):
|
|||||||
raise ValueError(f"Adapter {adapter_name} does not exist")
|
raise ValueError(f"Adapter {adapter_name} does not exist")
|
||||||
del self.peft_config[adapter_name]
|
del self.peft_config[adapter_name]
|
||||||
|
|
||||||
key_list = [
|
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
|
||||||
key for key, _ in self.model.named_modules() if self.prefix not in key
|
|
||||||
]
|
|
||||||
new_adapter = None
|
new_adapter = None
|
||||||
for key in key_list:
|
for key in key_list:
|
||||||
_, target, _ = _get_submodules(self.model, key)
|
_, target, _ = _get_submodules(self.model, key)
|
||||||
@ -978,10 +862,7 @@ class LoraModel(BaseTuner):
|
|||||||
self.active_adapter = new_adapter or []
|
self.active_adapter = new_adapter or []
|
||||||
|
|
||||||
def merge_and_unload(
|
def merge_and_unload(
|
||||||
self,
|
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
||||||
progressbar: bool = False,
|
|
||||||
safe_merge: bool = False,
|
|
||||||
adapter_names: Optional[list[str]] = None,
|
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
r"""
|
r"""
|
||||||
This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
|
This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
|
||||||
@ -1019,9 +900,7 @@ class LoraModel(BaseTuner):
|
|||||||
"""
|
"""
|
||||||
return self._unload_and_optionally_merge(merge=False)
|
return self._unload_and_optionally_merge(merge=False)
|
||||||
|
|
||||||
def subtract_mutated_init(
|
def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None):
|
||||||
self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
This function can calculate the updates of the [PiSSA | OLoRA] by comparing the parameters of the [PiSSA |
|
This function can calculate the updates of the [PiSSA | OLoRA] by comparing the parameters of the [PiSSA |
|
||||||
OLoRA] adapter in `output_state_dict` with the initial values of [PiSSA | OLoRA] in `adapter_name`, thus
|
OLoRA] adapter in `output_state_dict` with the initial values of [PiSSA | OLoRA] in `adapter_name`, thus
|
||||||
@ -1050,19 +929,11 @@ class LoraModel(BaseTuner):
|
|||||||
## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'.
|
## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'.
|
||||||
if "lora_A" in name:
|
if "lora_A" in name:
|
||||||
tensors_lora[name] = torch.cat(
|
tensors_lora[name] = torch.cat(
|
||||||
[
|
[output_state_dict[name], mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=0
|
||||||
output_state_dict[name],
|
|
||||||
mutated_init_state_dict[".".join(name.split(".")[1:])],
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
)
|
||||||
elif "lora_B" in name:
|
elif "lora_B" in name:
|
||||||
tensors_lora[name] = torch.cat(
|
tensors_lora[name] = torch.cat(
|
||||||
[
|
[output_state_dict[name], -mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=1
|
||||||
output_state_dict[name],
|
|
||||||
-mutated_init_state_dict[".".join(name.split(".")[1:])],
|
|
||||||
],
|
|
||||||
dim=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return tensors_lora
|
return tensors_lora
|
||||||
|
|||||||
@ -33,9 +33,7 @@ class TorchaoLoraLinear(Linear):
|
|||||||
# this is not strictly necessary, as kwargs are stored either way, but we want to error early if
|
# this is not strictly necessary, as kwargs are stored either way, but we want to error early if
|
||||||
# get_apply_tensor_subclass is missing.
|
# get_apply_tensor_subclass is missing.
|
||||||
if kwargs.get("lora_bias", False):
|
if kwargs.get("lora_bias", False):
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")
|
||||||
f"{self.__class__.__name__} does not support lora_bias yet, set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.get_apply_tensor_subclass = get_apply_tensor_subclass
|
self.get_apply_tensor_subclass = get_apply_tensor_subclass
|
||||||
@ -45,16 +43,10 @@ class TorchaoLoraLinear(Linear):
|
|||||||
# TODO: Not required once int4_weight_only is properly supported by torchao
|
# TODO: Not required once int4_weight_only is properly supported by torchao
|
||||||
base_layer = self.get_base_layer()
|
base_layer = self.get_base_layer()
|
||||||
weight = base_layer.weight
|
weight = base_layer.weight
|
||||||
if hasattr(weight, "layout_tensor") and (
|
if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
|
||||||
weight.layout_tensor.data.dtype != torch.int8
|
raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"{type(self).__name__} only supports int8 weights for now."
|
|
||||||
)
|
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
from torchao import quantize_
|
from torchao import quantize_
|
||||||
|
|
||||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||||
@ -151,10 +143,7 @@ def dispatch_torchao(
|
|||||||
from torchao.dtypes import AffineQuantizedTensor
|
from torchao.dtypes import AffineQuantizedTensor
|
||||||
from torchao.quantization import LinearActivationQuantizedTensor
|
from torchao.quantization import LinearActivationQuantizedTensor
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
|
||||||
target_base_layer.weight,
|
|
||||||
(AffineQuantizedTensor, LinearActivationQuantizedTensor),
|
|
||||||
):
|
|
||||||
new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)
|
new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)
|
||||||
|
|
||||||
return new_module
|
return new_module
|
||||||
|
|||||||
@ -54,17 +54,13 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if lora_bias:
|
if lora_bias:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")
|
||||||
f"{self.__class__.__name__} does not support lora_bias yet, set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
LoraLayer.__init__(self, base_layer=base_layer, **kwargs)
|
LoraLayer.__init__(self, base_layer=base_layer, **kwargs)
|
||||||
|
|
||||||
if use_dora:
|
if use_dora:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
|
||||||
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear)
|
self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear)
|
||||||
@ -117,9 +113,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
**parallel_linear_kwargs,
|
**parallel_linear_kwargs,
|
||||||
):
|
):
|
||||||
if r <= 0:
|
if r <= 0:
|
||||||
raise ValueError(
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||||
f"`r` should be a positive integer value but the value passed is {r}"
|
|
||||||
)
|
|
||||||
self.r[adapter_name] = r
|
self.r[adapter_name] = r
|
||||||
self.lora_alpha[adapter_name] = lora_alpha
|
self.lora_alpha[adapter_name] = lora_alpha
|
||||||
if lora_dropout > 0.0:
|
if lora_dropout > 0.0:
|
||||||
@ -142,19 +136,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
init_method=init_method,
|
init_method=init_method,
|
||||||
config=megatron_config,
|
config=megatron_config,
|
||||||
)
|
)
|
||||||
lora_b = nn.Linear(
|
lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32)
|
||||||
in_features=r,
|
|
||||||
out_features=self.out_features,
|
|
||||||
bias=False,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
lora_a = nn.Linear(
|
lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32)
|
||||||
in_features=self.in_features,
|
|
||||||
out_features=r,
|
|
||||||
bias=False,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
lora_b = self.backend.ColumnParallelLinear(
|
lora_b = self.backend.ColumnParallelLinear(
|
||||||
input_size=r,
|
input_size=r,
|
||||||
output_size=self.out_features,
|
output_size=self.out_features,
|
||||||
@ -174,9 +158,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
|
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
|
||||||
with gather_params_ctx(self.get_base_layer().weight):
|
with gather_params_ctx(self.get_base_layer().weight):
|
||||||
self.pissa_init(adapter_name, init_lora_weights)
|
self.pissa_init(adapter_name, init_lora_weights)
|
||||||
elif (
|
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
|
||||||
isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora"
|
|
||||||
):
|
|
||||||
with gather_params_ctx(self.get_base_layer().weight):
|
with gather_params_ctx(self.get_base_layer().weight):
|
||||||
self.olora_init(adapter_name)
|
self.olora_init(adapter_name)
|
||||||
elif init_lora_weights == "loftq":
|
elif init_lora_weights == "loftq":
|
||||||
@ -207,9 +189,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
self.unmerge()
|
self.unmerge()
|
||||||
result, bias = self.base_layer(x, *args, **kwargs)
|
result, bias = self.base_layer(x, *args, **kwargs)
|
||||||
elif adapter_names is not None:
|
elif adapter_names is not None:
|
||||||
raise ValueError(
|
raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.")
|
||||||
f"{self.__class__.__name__} does not support mixed_batch_forward yet."
|
|
||||||
)
|
|
||||||
elif self.merged:
|
elif self.merged:
|
||||||
result, bias = self.base_layer(x, *args, **kwargs)
|
result, bias = self.base_layer(x, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -245,9 +225,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
result = result.to(torch_result_dtype)
|
result = result.to(torch_result_dtype)
|
||||||
return result, bias
|
return result, bias
|
||||||
|
|
||||||
def merge(
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||||
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Merge the active adapter weights into the base weights
|
Merge the active adapter weights into the base weights
|
||||||
|
|
||||||
@ -280,24 +258,15 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
# since delta_weight already includes scaling, set it to 1 here
|
# since delta_weight already includes scaling, set it to 1 here
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
self.lora_magnitude_vector[active_adapter]
|
self.lora_magnitude_vector[active_adapter]
|
||||||
.get_weight_norm(
|
.get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1)
|
||||||
orig_weights,
|
|
||||||
transpose(delta_weight, self.fan_in_fan_out),
|
|
||||||
scaling=1,
|
|
||||||
)
|
|
||||||
.detach()
|
.detach()
|
||||||
)
|
)
|
||||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight
|
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
|
||||||
/ weight_norm
|
|
||||||
)
|
|
||||||
dora_factor = transpose(
|
|
||||||
dora_factor.view(-1, 1), self.fan_in_fan_out
|
|
||||||
)
|
|
||||||
orig_weights = dora_factor * (orig_weights + delta_weight)
|
orig_weights = dora_factor * (orig_weights + delta_weight)
|
||||||
|
|
||||||
if not torch.isfinite(orig_weights).all():
|
if not torch.isfinite(orig_weights).all():
|
||||||
@ -316,9 +285,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
weight_norm = (
|
weight_norm = (
|
||||||
self.lora_magnitude_vector[active_adapter]
|
self.lora_magnitude_vector[active_adapter]
|
||||||
.get_weight_norm(
|
.get_weight_norm(
|
||||||
base_layer.weight,
|
base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1
|
||||||
transpose(delta_weight, self.fan_in_fan_out),
|
|
||||||
scaling=1,
|
|
||||||
)
|
)
|
||||||
.detach()
|
.detach()
|
||||||
)
|
)
|
||||||
@ -326,16 +293,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
# cannot calculate it on the fly based on the merged weights when unmerging because its a
|
||||||
# different value
|
# different value
|
||||||
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight
|
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
|
||||||
/ weight_norm
|
new_weight = dora_factor * (base_layer.weight.data + delta_weight)
|
||||||
)
|
|
||||||
dora_factor = transpose(
|
|
||||||
dora_factor.view(-1, 1), self.fan_in_fan_out
|
|
||||||
)
|
|
||||||
new_weight = dora_factor * (
|
|
||||||
base_layer.weight.data + delta_weight
|
|
||||||
)
|
|
||||||
base_layer.weight.data = new_weight
|
base_layer.weight.data = new_weight
|
||||||
|
|
||||||
self.merged_adapters.append(active_adapter)
|
self.merged_adapters.append(active_adapter)
|
||||||
@ -356,9 +316,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
weight.data -= delta_weight
|
weight.data -= delta_weight
|
||||||
else:
|
else:
|
||||||
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
|
||||||
dora_factor = (
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||||
self.lora_magnitude_vector[active_adapter].weight / weight_norm
|
|
||||||
)
|
|
||||||
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
|
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
|
||||||
weight.data = weight_orig
|
weight.data = weight_orig
|
||||||
|
|
||||||
@ -376,9 +334,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
# In case users wants to merge the adapter weights that are in
|
# In case users wants to merge the adapter weights that are in
|
||||||
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
||||||
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
||||||
cast_to_fp32 = device.type == "cpu" and (
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
|
||||||
dtype == torch.float16 or dtype == torch.bfloat16
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_A = self.lora_A[adapter].weight
|
weight_A = self.lora_A[adapter].weight
|
||||||
weight_B = self.lora_B[adapter].weight
|
weight_B = self.lora_B[adapter].weight
|
||||||
@ -387,9 +343,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
|
|||||||
weight_A = weight_A.float()
|
weight_A = weight_A.float()
|
||||||
weight_B = weight_B.float()
|
weight_B = weight_B.float()
|
||||||
|
|
||||||
output_tensor = (
|
output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
|
||||||
transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
|
|
||||||
)
|
|
||||||
|
|
||||||
if cast_to_fp32:
|
if cast_to_fp32:
|
||||||
output_tensor = output_tensor.to(dtype=dtype)
|
output_tensor = output_tensor.to(dtype=dtype)
|
||||||
@ -425,17 +379,12 @@ def dispatch_megatron(
|
|||||||
|
|
||||||
if megatron_core and isinstance(
|
if megatron_core and isinstance(
|
||||||
target_base_layer,
|
target_base_layer,
|
||||||
(
|
(megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
|
||||||
megatron_core.tensor_parallel.ColumnParallelLinear,
|
|
||||||
megatron_core.tensor_parallel.RowParallelLinear,
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
megatron_kwargs = kwargs.copy()
|
megatron_kwargs = kwargs.copy()
|
||||||
megatron_config = lora_config.megatron_config
|
megatron_config = lora_config.megatron_config
|
||||||
if isinstance(megatron_config, dict):
|
if isinstance(megatron_config, dict):
|
||||||
transformer_config_class = (
|
transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
|
||||||
megatron_core.transformer.transformer_config.TransformerConfig
|
|
||||||
)
|
|
||||||
megatron_config = transformer_config_class(**lora_config.megatron_config)
|
megatron_config = transformer_config_class(**lora_config.megatron_config)
|
||||||
megatron_kwargs["megatron_config"] = megatron_config
|
megatron_kwargs["megatron_config"] = megatron_config
|
||||||
if megatron_kwargs["fan_in_fan_out"]:
|
if megatron_kwargs["fan_in_fan_out"]:
|
||||||
@ -446,10 +395,7 @@ def dispatch_megatron(
|
|||||||
)
|
)
|
||||||
megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
|
megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
|
||||||
new_module = LoraParallelLinear(
|
new_module = LoraParallelLinear(
|
||||||
base_layer=target,
|
base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
|
||||||
adapter_name=adapter_name,
|
|
||||||
backend=megatron_core.tensor_parallel,
|
|
||||||
**megatron_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_module
|
return new_module
|
||||||
|
|||||||
10
src/train.py
10
src/train.py
@ -1,7 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
|
sys.path.insert(0, "./transformers_repo/src/")
|
||||||
sys.path.insert(0, "transformers_repo/src/")
|
sys.path.insert(0, "./peft_repo/src/")
|
||||||
sys.path.insert(0, "peft_repo/src/")
|
|
||||||
|
|
||||||
from dataset_library.factory import get_dataset
|
from dataset_library.factory import get_dataset
|
||||||
|
|
||||||
@ -51,6 +50,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
|
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||||
|
|
||||||
|
# model = get_peft_model(model, peft_config)
|
||||||
model.add_adapter(peft_config)
|
model.add_adapter(peft_config)
|
||||||
|
|
||||||
elif model_args.peft_type == "LORA":
|
elif model_args.peft_type == "LORA":
|
||||||
@ -58,8 +58,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
|
peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
|
||||||
|
|
||||||
|
# model = get_peft_model(model, peft_config)
|
||||||
model.add_adapter(peft_config)
|
model.add_adapter(peft_config)
|
||||||
|
|
||||||
|
# if accelerator.is_local_main_process:
|
||||||
|
# model.print_trainable_parameters()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
peft_config = None
|
peft_config = None
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
||||||
--dataset_name gigaspeech \
|
--dataset_name CHEM \
|
||||||
--use_peft \
|
--use_peft \
|
||||||
--peft_type LORA \
|
--peft_type MMOELORA \
|
||||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||||
--lora_target_modules q_proj v_proj \
|
--lora_target_modules q_proj v_proj \
|
||||||
--per_device_train_batch_size 1 \
|
--per_device_train_batch_size 1 \
|
||||||
|
|||||||
103
uv.lock
generated
103
uv.lock
generated
@ -196,15 +196,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 },
|
{ url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cfgv"
|
|
||||||
version = "3.4.0"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "charset-normalizer"
|
name = "charset-normalizer"
|
||||||
version = "3.4.1"
|
version = "3.4.1"
|
||||||
@ -269,8 +260,6 @@ dependencies = [
|
|||||||
{ name = "numba" },
|
{ name = "numba" },
|
||||||
{ name = "peft" },
|
{ name = "peft" },
|
||||||
{ name = "pip" },
|
{ name = "pip" },
|
||||||
{ name = "pre-commit" },
|
|
||||||
{ name = "pytest" },
|
|
||||||
{ name = "requests" },
|
{ name = "requests" },
|
||||||
{ name = "rouge-score" },
|
{ name = "rouge-score" },
|
||||||
{ name = "safetensors" },
|
{ name = "safetensors" },
|
||||||
@ -303,8 +292,6 @@ requires-dist = [
|
|||||||
{ name = "numba", specifier = ">=0.60.0" },
|
{ name = "numba", specifier = ">=0.60.0" },
|
||||||
{ name = "peft", specifier = "==0.14.0" },
|
{ name = "peft", specifier = "==0.14.0" },
|
||||||
{ name = "pip", specifier = "==24.3.1" },
|
{ name = "pip", specifier = "==24.3.1" },
|
||||||
{ name = "pre-commit", specifier = ">=4.0.1" },
|
|
||||||
{ name = "pytest", specifier = ">=8.3.4" },
|
|
||||||
{ name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" },
|
{ name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" },
|
||||||
{ name = "rouge-score", specifier = ">=0.1.2" },
|
{ name = "rouge-score", specifier = ">=0.1.2" },
|
||||||
{ name = "safetensors", specifier = ">=0.5.2" },
|
{ name = "safetensors", specifier = ">=0.5.2" },
|
||||||
@ -401,15 +388,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252 },
|
{ url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "distlib"
|
|
||||||
version = "0.3.9"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "einops"
|
name = "einops"
|
||||||
version = "0.8.0"
|
version = "0.8.0"
|
||||||
@ -555,15 +533,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/6c/3f/50f6b25fafdcfb1c089187a328c95081abf882309afd86f4053951507cd1/huggingface_hub-0.27.1-py3-none-any.whl", hash = "sha256:1c5155ca7d60b60c2e2fc38cbb3ffb7f7c3adf48f824015b219af9061771daec", size = 450658 },
|
{ url = "https://files.pythonhosted.org/packages/6c/3f/50f6b25fafdcfb1c089187a328c95081abf882309afd86f4053951507cd1/huggingface_hub-0.27.1-py3-none-any.whl", hash = "sha256:1c5155ca7d60b60c2e2fc38cbb3ffb7f7c3adf48f824015b219af9061771daec", size = 450658 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "identify"
|
|
||||||
version = "2.6.5"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/cf/92/69934b9ef3c31ca2470980423fda3d00f0460ddefdf30a67adf7f17e2e00/identify-2.6.5.tar.gz", hash = "sha256:c10b33f250e5bba374fae86fb57f3adcebf1161bce7cdf92031915fd480c13bc", size = 99213 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/ec/fa/dce098f4cdf7621aa8f7b4f919ce545891f489482f0bfa5102f3eca8608b/identify-2.6.5-py2.py3-none-any.whl", hash = "sha256:14181a47091eb75b337af4c23078c9d09225cd4c48929f521f3bf16b09d02566", size = 99078 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "3.10"
|
version = "3.10"
|
||||||
@ -573,15 +542,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "iniconfig"
|
|
||||||
version = "2.0.0"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jinja2"
|
name = "jinja2"
|
||||||
version = "3.1.5"
|
version = "3.1.5"
|
||||||
@ -865,15 +825,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 },
|
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "nodeenv"
|
|
||||||
version = "1.9.1"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "numba"
|
name = "numba"
|
||||||
version = "0.60.0"
|
version = "0.60.0"
|
||||||
@ -1182,15 +1133,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 },
|
{ url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pluggy"
|
|
||||||
version = "1.5.0"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pooch"
|
name = "pooch"
|
||||||
version = "1.8.2"
|
version = "1.8.2"
|
||||||
@ -1205,22 +1147,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574 },
|
{ url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pre-commit"
|
|
||||||
version = "4.0.1"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "cfgv" },
|
|
||||||
{ name = "identify" },
|
|
||||||
{ name = "nodeenv" },
|
|
||||||
{ name = "pyyaml" },
|
|
||||||
{ name = "virtualenv" },
|
|
||||||
]
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/2e/c8/e22c292035f1bac8b9f5237a2622305bc0304e776080b246f3df57c4ff9f/pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2", size = 191678 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/16/8f/496e10d51edd6671ebe0432e33ff800aa86775d2d147ce7d43389324a525/pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878", size = 218713 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "propcache"
|
name = "propcache"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@ -1422,21 +1348,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
|
{ url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pytest"
|
|
||||||
version = "8.3.4"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
|
||||||
{ name = "iniconfig" },
|
|
||||||
{ name = "packaging" },
|
|
||||||
{ name = "pluggy" },
|
|
||||||
]
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.9.0.post0"
|
version = "2.9.0.post0"
|
||||||
@ -1938,20 +1849,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
|
{ url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "virtualenv"
|
|
||||||
version = "20.29.1"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "distlib" },
|
|
||||||
{ name = "filelock" },
|
|
||||||
{ name = "platformdirs" },
|
|
||||||
]
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/a7/ca/f23dcb02e161a9bba141b1c08aa50e8da6ea25e6d780528f1d385a3efe25/virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35", size = 7658028 }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/89/9b/599bcfc7064fbe5740919e78c5df18e5dceb0887e676256a1061bb5ae232/virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779", size = 4282379 },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wheel"
|
name = "wheel"
|
||||||
version = "0.45.1"
|
version = "0.45.1"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user