feat: update dataset processing logic, optimize image resizing, and add data-generation support

YunyaoZhou 2025-01-18 03:48:13 +08:00
parent a38ccf7042
commit da644a081d
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
23 changed files with 1001 additions and 303 deletions

Changed file: .pre-commit-config.yaml (new file, 11 additions)

@ -0,0 +1,11 @@
repos:
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.11

Changed file: pyproject.toml

@ -11,6 +11,7 @@ dependencies = [
"numba>=0.60.0",
"peft==0.14.0",
"pip==24.3.1",
"pre-commit>=4.0.1",
"requests==2.32.3",
"rouge-score>=0.1.2",
"safetensors>=0.5.2",
@ -30,16 +31,14 @@ requires-python = ">=3.11"
version = "0.1.0"
[project.optional-dependencies]
compile = [
"flash-attn>=2.7.2.post1",
]
compile = ["flash-attn>=2.7.2.post1"]
[tool.uv.sources]
markupsafe = {index = "pytorch"}
requests = {index = "pypi"}
torch = {index = "pytorch"}
torchaudio = {index = "pytorch"}
torchvision = {index = "pytorch"}
markupsafe = { index = "pytorch" }
requests = { index = "pypi" }
torch = { index = "pytorch" }
torchaudio = { index = "pytorch" }
torchvision = { index = "pytorch" }
[[tool.uv.index]]
name = "pypi"
@ -52,3 +51,7 @@ concurrent-builds = 4
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu124"
[tool.black]
line-length = 88
exclude = "transformers_repo|peft_repo|.venv"
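For reference, a minimal sketch of what the line-length setting above means in practice, using Black's public Python API (black.format_str and black.Mode); the snippet is illustrative and not part of this commit:

import black

# Re-format a source string with the same 88-column limit configured in [tool.black].
src = "result = some_function(argument_number_one, argument_number_two, argument_number_three, argument_number_four)"
formatted = black.format_str(src, mode=black.Mode(line_length=88))
print(formatted)  # re-wrapped so no line exceeds the 88-character limit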

Changed file: CHEM dataset module (CHEMDataset / CHEMDatasetForGeneration)

@ -25,9 +25,9 @@ class CHEMDataset(Dataset):
def _vis_processor(self, image: Image.Image):
width, height = image.size
if width > 600 or height > 600:
if width > 800 or height > 800:
max_size = max(width, height)
ratio = 600 / max_size
ratio = 800 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
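The resize threshold for CHEM images moves from 600 to 800 pixels on the longest side; a standalone sketch of the updated behavior (the helper name is illustrative, not from this repo):

from PIL import Image

def resize_to_max_side(image: Image.Image, max_side: int = 800) -> Image.Image:
    # Downscale only when either dimension exceeds max_side, preserving aspect ratio.
    width, height = image.size
    if width > max_side or height > max_side:
        ratio = max_side / max(width, height)
        image = image.resize(
            (int(width * ratio), int(height * ratio)), Image.Resampling.BILINEAR
        )
    return image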
@ -69,7 +69,7 @@ class CHEMDataset(Dataset):
return processed_data
def __len__(self):
return len(self.data) // 60
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
@ -147,12 +147,14 @@ class CHEMDatasetForGeneration(CHEMDataset):
],
},
]
return {
"image": image,
"chat": chat,
"answer": answer,
"original": sample["original"],
}
from .format import create_generate
return create_generate(
images=[image],
chat=chat,
answer=answer,
original=sample["original"],
)
if __name__ == "__main__":

Changed file: Gigaspeech dataset module (GigaspeechDataset / GigaspeechDatasetForGeneration)

@ -75,12 +75,14 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
],
},
]
from .format import create_generate
return {
"audio": (audio, sampling_rate),
"chat": chat,
"answer": text,
}
return create_generate(
audio=[(audio, sampling_rate)],
chat=chat,
answer=text,
original=sample,
)
if __name__ == "__main__":
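Both generation datasets above now build their samples through create_generate from the local .format module. That helper is not shown in this diff, so the following is only an assumed sketch of the record it might return, with field names inferred from the two call sites:

# Hypothetical sketch of the .format module's create_generate helper; the real
# implementation is not part of this commit's visible hunks.
def create_generate(chat, answer, original, images=None, audio=None):
    # Normalize generation samples to a single dict shape across modalities.
    sample = {"chat": chat, "answer": answer, "original": original}
    if images is not None:
        sample["images"] = images
    if audio is not None:
        sample["audio"] = audio
    return sample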

Changed file: sys.path bootstrap for the vendored transformers_repo and peft_repo checkouts

@ -1,4 +1,5 @@
import sys
sys.path.insert(0, "./transformers_repo/src/")
sys.path.insert(0, "./peft_repo/src/")

Changed file: peft package re-export (from .mapping)

@ -1 +1 @@
from .mapping import get_peft_config, get_peft_model, inject_adapter_in_model
from .mapping import get_peft_config, get_peft_model, inject_adapter_in_model

Changed file: vendored PEFT LoRA AQLM layer (AqlmLoraLinear)

@ -40,7 +40,9 @@ class AqlmLoraLinear(torch.nn.Module, LoraLayer):
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
)
super().__init__()
LoraLayer.__init__(self, base_layer)

Changed file: vendored PEFT LoRA AWQ layer (AwqLoraLinear, dispatch_awq)

@ -37,7 +37,9 @@ class AwqLoraLinear(torch.nn.Module, LoraLayer):
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
)
super().__init__()
LoraLayer.__init__(self, base_layer)
@ -107,7 +109,9 @@ def dispatch_awq(
if isinstance(target_base_layer, WQLinear_GEMM):
# Raise the error only at the dispatch level
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))
version_autoawq = packaging.version.parse(
importlib_metadata.version("autoawq")
)
if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
raise ImportError(

Changed file: vendored PEFT LoRA bitsandbytes layers (8-bit and 4-bit Linear)

@ -60,7 +60,9 @@ if is_bnb_available():
lora_bias=lora_bias,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
"""
Merge the active adapter weights into the base weights
@ -109,7 +111,9 @@ if is_bnb_available():
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight / weight_norm
)
w_data = dora_factor.view(-1, 1) * (output + lora_data)
if safe_merge and not torch.isfinite(w_data).all():
@ -118,10 +122,15 @@ if is_bnb_available():
)
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
w_data.to("cpu"),
requires_grad=False,
has_fp16_weights=weight.has_fp16_weights,
).to(weight.device)
if self.lora_bias[active_adapter]:
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
bias_data = (
self.get_base_layer().bias.data
+ self.lora_B[active_adapter].bias
)
if safe_merge and not torch.isfinite(bias_data):
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
@ -158,11 +167,15 @@ if is_bnb_available():
w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight / weight_norm
)
w_data = output.data / dora_factor.view(-1, 1) - lora_data
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
w_data.to("cpu"),
requires_grad=False,
has_fp16_weights=weight.has_fp16_weights,
).to(weight.device)
if self.lora_bias[active_adapter]:
@ -188,7 +201,13 @@ if is_bnb_available():
unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
sub_batch_indices_list.append(
[
index
for index, item in enumerate(adapter_names)
if item == adapter
]
)
for i, active_adapter in enumerate(unique_adapters):
if active_adapter == "__base__":
@ -227,7 +246,9 @@ if is_bnb_available():
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
result = self._mixed_batch_forward(
x, *args, adapter_names=adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
@ -330,7 +351,9 @@ if is_bnb_4bit_available():
lora_bias=lora_bias,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
"""
Merge the active adapter weights into the base weights
@ -375,7 +398,9 @@ if is_bnb_4bit_available():
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight / weight_norm
)
w_data = dora_factor.view(-1, 1) * (output + lora_data)
if safe_merge and not torch.isfinite(w_data).all():
@ -386,9 +411,14 @@ if is_bnb_4bit_available():
kwargs["bnb_quantized"] = False
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
self.get_base_layer().weight = bnb.nn.Params4bit(
w_data.to("cpu"), **kwargs
).to(weight.device)
if self.lora_bias[active_adapter]:
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
bias_data = (
self.get_base_layer().bias.data
+ self.lora_B[active_adapter].bias
)
if safe_merge and not torch.isfinite(bias_data):
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
@ -422,14 +452,18 @@ if is_bnb_4bit_available():
w_data = output - lora_data
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight / weight_norm
)
w_data = output.data / dora_factor.view(-1, 1) - lora_data
if "bnb_quantized" in kwargs:
kwargs["bnb_quantized"] = False
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
self.get_base_layer().weight = bnb.nn.Params4bit(
w_data.to("cpu"), **kwargs
).to(weight.device)
if self.lora_bias[active_adapter]:
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias
@ -452,7 +486,13 @@ if is_bnb_4bit_available():
unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
sub_batch_indices_list.append(
[
index
for index, item in enumerate(adapter_names)
if item == adapter
]
)
for i, active_adapter in enumerate(unique_adapters):
if active_adapter == "__base__":
@ -489,7 +529,9 @@ if is_bnb_4bit_available():
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
result = self._mixed_batch_forward(
x, *args, adapter_names=adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
@ -550,7 +592,11 @@ if is_bnb_4bit_available():
target_base_layer = target
loaded_in_4bit = kwargs.get("loaded_in_4bit", False)
if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
if (
loaded_in_4bit
and is_bnb_4bit_available()
and isinstance(target_base_layer, bnb.nn.Linear4bit)
):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{

Changed file: vendored PEFT LoRA config (LoftQConfig, EvaConfig, LoraConfig)

@ -66,7 +66,9 @@ class LoftQConfig:
"""
loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"})
loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
loftq_iter: int = field(
default=1, metadata={"help": "Alternating iterations for LoftQ"}
)
@dataclass
@ -101,13 +103,25 @@ class EvaConfig:
are adjusted so that all LoRA gradients have the same scale regardless of their rank. Default is True.
"""
rho: float = field(default=2.0, metadata={"help": "Rho value for EVA redistribution"})
tau: float = field(default=0.99, metadata={"help": "Cosine similarity threshold for early stopping"})
use_label_mask: bool = field(default=True, metadata={"help": "Use label mask for EVA initialization"})
label_mask_value: int = field(
default=-100, metadata={"help": "if use_label_mask=True the value to look for to mask out ignored tokens"}
rho: float = field(
default=2.0, metadata={"help": "Rho value for EVA redistribution"}
)
tau: float = field(
default=0.99,
metadata={"help": "Cosine similarity threshold for early stopping"},
)
use_label_mask: bool = field(
default=True, metadata={"help": "Use label mask for EVA initialization"}
)
label_mask_value: int = field(
default=-100,
metadata={
"help": "if use_label_mask=True the value to look for to mask out ignored tokens"
},
)
whiten: bool = field(
default=False, metadata={"help": "Apply whitening to singular vectors"}
)
whiten: bool = field(default=False, metadata={"help": "Apply whitening to singular vectors"})
adjust_scaling_factors: bool = field(
default=True,
metadata={"help": "Adjust LoRA scaling factors after the rank redistribution"},
@ -233,16 +247,21 @@ class LoraConfig(PeftConfig):
)
exclude_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={"help": "List of module names or regex expression of the module names to exclude from Lora."},
metadata={
"help": "List of module names or regex expression of the module names to exclude from Lora."
},
)
lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
fan_in_fan_out: bool = field(
default=False,
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
metadata={
"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"
},
)
bias: Literal["none", "all", "lora_only"] = field(
default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}
default="none",
metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"},
)
use_rslora: bool = field(
default=False,
@ -264,7 +283,15 @@ class LoraConfig(PeftConfig):
},
)
init_lora_weights: (
bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]
bool
| Literal[
"gaussian",
"eva",
"olora",
"pissa",
"pissa_niter_[number of iters]",
"loftq",
]
) = field(
default=True,
metadata={
@ -418,47 +445,72 @@ class LoraConfig(PeftConfig):
super().__post_init__()
self.peft_type = PeftType.LORA
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
set(self.target_modules)
if isinstance(self.target_modules, list)
else self.target_modules
)
self.exclude_modules = (
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
set(self.exclude_modules)
if isinstance(self.exclude_modules, list)
else self.exclude_modules
)
# if target_modules is a regex expression, then layers_to_transform should be None
if isinstance(self.target_modules, str) and self.layers_to_transform is not None:
raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.")
if (
isinstance(self.target_modules, str)
and self.layers_to_transform is not None
):
raise ValueError(
"`layers_to_transform` cannot be used when `target_modules` is a str."
)
# if target_modules is a regex expression, then layers_pattern should be None
if isinstance(self.target_modules, str) and self.layers_pattern is not None:
raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
raise ValueError(
"`layers_pattern` cannot be used when `target_modules` is a str."
)
# check for layers_to_transform and layers_pattern
if self.layers_pattern and not self.layers_to_transform:
raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
raise ValueError(
"When `layers_pattern` is specified, `layers_to_transform` must also be specified. "
)
if self.use_dora and self.megatron_config:
raise ValueError("DoRA does not support megatron_core, please set `use_dora=False`.")
raise ValueError(
"DoRA does not support megatron_core, please set `use_dora=False`."
)
# handle init_lora_weights and loftq_config
if self.init_lora_weights == "loftq":
import importlib
if not importlib.util.find_spec("scipy"):
raise ImportError("The required package 'scipy' is not installed. Please install it to continue.")
raise ImportError(
"The required package 'scipy' is not installed. Please install it to continue."
)
if not self.loftq_config:
raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")
raise ValueError(
"`loftq_config` must be specified when `init_lora_weights` is 'loftq'."
)
if not isinstance(self.loftq_config, dict):
# convert loftq_config to dict
self.loftq_config = vars(self.loftq_config)
elif self.loftq_config:
self.loftq_config = {}
warnings.warn("`loftq_config` specified but will be ignored when `init_lora_weights` is not 'loftq'.")
warnings.warn(
"`loftq_config` specified but will be ignored when `init_lora_weights` is not 'loftq'."
)
elif self.init_lora_weights == "eva" and self.eva_config is None:
warnings.warn("`init_lora_weights` is 'eva' but `eva_config` is not specified. Using default EVA config.")
warnings.warn(
"`init_lora_weights` is 'eva' but `eva_config` is not specified. Using default EVA config."
)
self.eva_config = EvaConfig()
elif self.init_lora_weights != "eva" and self.eva_config is not None:
warnings.warn("`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'.")
warnings.warn(
"`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'."
)
if self.lora_bias:
if self.init_lora_weights not in (True, False):
@ -467,7 +519,9 @@ class LoraConfig(PeftConfig):
f"init_lora_weights={self.init_lora_weights} instead."
)
if self.use_dora:
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")
raise ValueError(
"The argument lora_bias=True is not supported for DoRA, please pass use_dora=False"
)
# Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
@ -477,7 +531,10 @@ class LoraConfig(PeftConfig):
self.use_rslora
and (self.rank_pattern or self.alpha_pattern)
and (
(isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa")))
(
isinstance(self.init_lora_weights, str)
and (self.init_lora_weights.startswith("pissa"))
)
or (self.init_lora_weights == "olora")
)
):
@ -491,7 +548,9 @@ class LoraConfig(PeftConfig):
self._custom_modules: Optional[dict[type[nn.Module], type[nn.Module]]] = None
def _register_custom_module(self, mapping: dict[type[nn.Module], type[nn.Module]]) -> None:
def _register_custom_module(
self, mapping: dict[type[nn.Module], type[nn.Module]]
) -> None:
"""
Experimental API to support providing custom LoRA layers.
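These dataclasses are consumed when a LoRA adapter is configured; a minimal, hedged example of wiring EvaConfig into LoraConfig (hyperparameter values are placeholders):

from peft import LoraConfig
from peft.tuners.lora.config import EvaConfig

# rho controls EVA's rank redistribution budget, tau the SVD convergence threshold.
lora_config = LoraConfig(
    r=16,
    lora_alpha=8,
    target_modules=["q_proj", "v_proj"],
    init_lora_weights="eva",
    eva_config=EvaConfig(rho=2.0, tau=0.99),
)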

Changed file: vendored PEFT DoRA layers (DoraLinearLayer, _DoraConvNdLayer)

@ -34,7 +34,9 @@ class DoraLinearLayer(nn.Module):
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm
def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
def update_layer(
self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False
) -> None:
# temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
dtype_is_fp16 = lora_A.dtype == torch.float16
if dtype_is_fp16:
@ -49,14 +51,18 @@ class DoraLinearLayer(nn.Module):
weight = dequantize_module_weight(base_layer)
if weight.data.ndim >= 4: # For handling LoRAs applied to Conv layers.
lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
lora_weight = torch.mm(
lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1)
)
lora_weight = lora_weight.reshape(weight.shape)
else:
lora_weight = lora_B @ lora_A
if dtype_is_fp16:
lora_weight = lora_weight.half()
weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling)
weight_norm = self.get_weight_norm(
weight.to(lora_A.device), lora_weight, scaling
)
if place_on_cpu:
weight_norm = weight_norm.to("cpu")
@ -69,7 +75,9 @@ class DoraLinearLayer(nn.Module):
"""
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
x_eye = torch.eye(
lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype
)
lora_weight = lora_B(lora_A(x_eye)).T
magnitude = self.weight
@ -95,7 +103,9 @@ class DoraLinearLayer(nn.Module):
else:
base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
result_dora = (
mag_norm_scale - 1
) * base_result + mag_norm_scale * lora_result * scaling
return result_dora
@ -145,7 +155,9 @@ class _DoraConvNdLayer(DoraLinearLayer):
output.
"""
weight = base_layer.weight
lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
lora_weight = torch.mm(
lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)
)
lora_weight = lora_weight.reshape(weight.shape)
magnitude = self.weight
weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
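The reformatted result_dora expression above still computes DoRA's output correction as a magnitude-rescaled mix of the base output and the LoRA branch; a small numeric sketch of that one line (shapes and names are illustrative):

import torch

base_result = torch.randn(4, 8)    # output of the frozen base layer
lora_result = torch.randn(4, 8)    # output of the LoRA branch, before scaling
mag_norm_scale = torch.rand(1, 8)  # magnitude / weight_norm, broadcast over the batch
scaling = 0.5

result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
# The caller typically adds base_result separately, so the full output works out to
# mag_norm_scale * (base_result + scaling * lora_result).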

Changed file: vendored PEFT LoRA EETQ layer

@ -38,7 +38,9 @@ if is_eetq_available():
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
)
super().__init__()
LoraLayer.__init__(self, base_layer)
@ -85,11 +87,17 @@ if is_eetq_available():
result = result + output
return result
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
raise AttributeError("Merging LoRA layers is not supported for Eetq layers.")
def merge(
self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
) -> None:
raise AttributeError(
"Merging LoRA layers is not supported for Eetq layers."
)
def unmerge(self) -> None:
raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.")
raise AttributeError(
"Unmerging LoRA layers is not supported for Eetq layers."
)
def __repr__(self) -> str:
rep = super().__repr__()

Changed file: vendored PEFT EVA initialization (SVDHook, get_eva_state_dict, initialize_lora_eva_weights)

@ -26,7 +26,10 @@ import torch.distributed as dist
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D
from peft.tuners.tuners_utils import _find_minimal_target_modules, check_target_module_exists
from peft.tuners.tuners_utils import (
_find_minimal_target_modules,
check_target_module_exists,
)
from peft.utils.constants import MIN_TARGET_MODULES_FOR_OPTIMIZATION
from peft.utils.incremental_pca import IncrementalPCA
from peft.utils.other import _get_submodules, get_pattern_key
@ -58,7 +61,9 @@ class _Hook:
self.model_input = None
@staticmethod
def _prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> torch.Tensor:
def _prepare_layer_inputs_fn_default(
layer_input, model_input, layer_name
) -> torch.Tensor:
if isinstance(layer_input, torch.Tensor):
pass
elif isinstance(layer_input, (tuple, list)):
@ -83,20 +88,28 @@ class _Hook:
# First gather sizes from all processes more efficiently
local_size = torch.tensor([layer_input.shape[0]], device=layer_input.device)
all_sizes = torch.empty(world_size, dtype=local_size.dtype, device=layer_input.device)
all_sizes = torch.empty(
world_size, dtype=local_size.dtype, device=layer_input.device
)
dist.all_gather_into_tensor(all_sizes, local_size)
all_sizes = all_sizes.tolist()
# Find maximum size and pad tensors
padded_input = layer_input.new_zeros((max(all_sizes), *layer_input.shape[1:]))
padded_input = layer_input.new_zeros(
(max(all_sizes), *layer_input.shape[1:])
)
padded_input[: layer_input.shape[0]] = layer_input
# Gather padded tensors
gathered_inputs = [torch.zeros_like(padded_input) for _ in range(world_size)]
gathered_inputs = [
torch.zeros_like(padded_input) for _ in range(world_size)
]
dist.all_gather(gathered_inputs, padded_input.contiguous())
# Remove padding for each gathered tensor
gathered_inputs = [tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes)]
gathered_inputs = [
tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes)
]
# Concatenate along batch dimension
return torch.cat(gathered_inputs, dim=0)
@ -153,7 +166,9 @@ class SVDHook(_Hook):
states = self.gather_layer_inputs(states)
# check if batch sizes is more than the number of components
if states.size(0) < self.n_components:
print(f"skipping SVD for {self.name} because there are less than {self.n_components} examples")
print(
f"skipping SVD for {self.name} because there are less than {self.n_components} examples"
)
return
self.svd.partial_fit(states.to(torch.float32))
# add if statement to check if we are in the first step where previous_components is None
@ -218,7 +233,9 @@ def get_device_with_meta_params(model: torch.nn.Module) -> torch.device:
"""
devices = list({p.device for p in model.parameters() if p.device.type != "meta"})
if len(devices) > 1:
warnings.warn(f"Could not determine device, model has multiple devices: {devices}")
warnings.warn(
f"Could not determine device, model has multiple devices: {devices}"
)
return
return devices[0]
@ -230,11 +247,15 @@ def move_inputs_to_device(inputs, device: Union[str, torch.device]):
if hasattr(inputs, "to"):
return inputs.to(device)
if isinstance(inputs, Mapping):
return type(inputs)({k: move_inputs_to_device(v, device) for k, v in inputs.items()})
return type(inputs)(
{k: move_inputs_to_device(v, device) for k, v in inputs.items()}
)
elif isinstance(inputs, (tuple, list)):
return type(inputs)(move_inputs_to_device(v, device) for v in inputs)
else:
warnings.warn(f"input of type {type(inputs)} could not be moved to the correct device")
warnings.warn(
f"input of type {type(inputs)} could not be moved to the correct device"
)
return inputs
@ -247,14 +268,22 @@ def prepare_model_inputs_fn_language_modeling(model_input, peft_config: LoraConf
peft_config (LoraConfig): The configuration for the LoRA layers.
"""
if not isinstance(model_input, dict):
raise ValueError("When using `prepare_model_inputs_fn_language_modeling` inputs must be a dictionary")
mask = model_input.get("attention_mask", torch.ones_like(model_input["input_ids"])).bool()
raise ValueError(
"When using `prepare_model_inputs_fn_language_modeling` inputs must be a dictionary"
)
mask = model_input.get(
"attention_mask", torch.ones_like(model_input["input_ids"])
).bool()
if peft_config.eva_config.use_label_mask and hasattr(model_input, "labels"):
mask = torch.logical_and(mask, model_input["labels"] != peft_config.eva_config.label_mask_value)
mask = torch.logical_and(
mask, model_input["labels"] != peft_config.eva_config.label_mask_value
)
return mask.nonzero()
def prepare_layer_inputs_fn_language_modeling(layer_input, model_input, layer_name) -> torch.Tensor:
def prepare_layer_inputs_fn_language_modeling(
layer_input, model_input, layer_name
) -> torch.Tensor:
"""
if not all items in the input should be used for SVD, this function can be used to get the indices of the items
that should be used.
@ -299,12 +328,21 @@ def _get_eva_state_dict(
) -> dict:
# Computes the rank distribution for each layer based on the explained variance ratio.
# when rank_pattern flag is False, all values in max_components are the same
def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components):
exp_vars = {k: h[0].svd.explained_variance_ratio_[: max_components[k]] for k, h in hooks.items()}
keys, values = zip(*[(k, c) for k, name in layer_hook_map.items() for c in exp_vars[name]])
def _get_rank_distribution(
hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components
):
exp_vars = {
k: h[0].svd.explained_variance_ratio_[: max_components[k]]
for k, h in hooks.items()
}
keys, values = zip(
*[(k, c) for k, name in layer_hook_map.items() for c in exp_vars[name]]
)
idx = torch.stack(values).argsort(descending=True)
counts = Counter([keys[i] for i in idx[:rank_budget]])
counts = {k: counts.get(k, 0) for k in layer_hook_map.keys()} # add layers with 0 rank
counts = {
k: counts.get(k, 0) for k in layer_hook_map.keys()
} # add layers with 0 rank
for k, k_hook in equal_inputs_map.items():
# ensure hook layers have the highest rank if they are equal to another layer
rank, rank_hook = counts[k], counts[k_hook]
@ -356,7 +394,11 @@ def _get_eva_state_dict(
fn = prepare_layer_inputs_fn.pop(name, None)
else:
fn = prepare_layer_inputs_fn
hook = HashHook(name=name, prepare_layer_inputs_fn=fn, gather_distributed_inputs=gather_distributed_inputs)
hook = HashHook(
name=name,
prepare_layer_inputs_fn=fn,
gather_distributed_inputs=gather_distributed_inputs,
)
hook.model_input = model_inputs_for_hooks
handle = module.register_forward_hook(hook)
hooks[name] = (hook, handle)
@ -365,7 +407,10 @@ def _get_eva_state_dict(
)
max_components[name] = round(layer_rank * rho)
rank_budget += layer_rank
if isinstance(prepare_layer_inputs_fn, Mapping) and len(prepare_layer_inputs_fn) > 0:
if (
isinstance(prepare_layer_inputs_fn, Mapping)
and len(prepare_layer_inputs_fn) > 0
):
raise ValueError(
"prepare_layer_inputs_fn is a mapping but the following module names were not found in the model: "
f"{prepare_layer_inputs_fn.keys()}"
@ -399,7 +444,10 @@ def _get_eva_state_dict(
)
module = model.get_submodule(name)
handle = module.register_forward_hook(hook)
hooks[name] = (hook, handle) # adding the old handle here so we dont get errors in the first forward pass
hooks[name] = (
hook,
handle,
) # adding the old handle here so we dont get errors in the first forward pass
layer_hook_map = {**dict(zip(hooks.keys(), hooks.keys())), **equal_inputs_map}
# start svd calculation
@ -441,7 +489,9 @@ def _get_eva_state_dict(
layer_converged = list(convergence_dict.values()) + [
convergence_dict[v] for v in equal_inputs_map.values()
]
pbar.set_description(f"{sum(layer_converged)}/{len(layer_converged)} layers have converged")
pbar.set_description(
f"{sum(layer_converged)}/{len(layer_converged)} layers have converged"
)
if all(convergence_dict.values()):
break
@ -453,10 +503,17 @@ def _get_eva_state_dict(
if not all(hasattr(h[0].svd, "components_") for h in hooks.values()):
continue
rank_dist = _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components)
rank_dist = _get_rank_distribution(
hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components
)
# check all custom hooks have been removed
remaining_hooks = {n for n, m in model.named_modules() for v in m._forward_hooks.values() if isinstance(v, _Hook)}
remaining_hooks = {
n
for n, m in model.named_modules()
for v in m._forward_hooks.values()
if isinstance(v, _Hook)
}
if len(remaining_hooks) > 0:
raise ValueError(
f"Found active hooks added by EVA that weren't properly removed: {remaining_hooks}. "
@ -510,9 +567,12 @@ def _load_eva_state_dict(
other_module_names.append(name_in_base_model)
continue
# Regexp matching - Find key which matches current target_name in patterns provided
r = peft_config.rank_pattern.get(get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r)
r = peft_config.rank_pattern.get(
get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r
)
alpha = peft_config.alpha_pattern.get(
get_pattern_key(peft_config.alpha_pattern.keys(), name), peft_config.lora_alpha
get_pattern_key(peft_config.alpha_pattern.keys(), name),
peft_config.lora_alpha,
)
if name in eva_state_dict:
w = eva_state_dict.pop(name)
@ -524,12 +584,22 @@ def _load_eva_state_dict(
elif new_rank != r:
if peft_config.eva_config.adjust_scaling_factors:
alpha *= new_rank / r
if new_rank != r or module.lora_A[adapter_name].weight.device.type == "meta":
module.update_layer(r=new_rank, lora_alpha=alpha, init_lora_weights="eva", **update_layer_kwargs)
if (
new_rank != r
or module.lora_A[adapter_name].weight.device.type == "meta"
):
module.update_layer(
r=new_rank,
lora_alpha=alpha,
init_lora_weights="eva",
**update_layer_kwargs,
)
module.lora_A[adapter_name].weight.copy_(w)
new_target_modules.append(name_in_base_model)
else:
module.update_layer(r=r, lora_alpha=alpha, init_lora_weights=True, **update_layer_kwargs)
module.update_layer(
r=r, lora_alpha=alpha, init_lora_weights=True, **update_layer_kwargs
)
missing_eva_inits.append(name_in_base_model)
new_rank = r
# update rank pattern and alpha pattern
@ -541,7 +611,9 @@ def _load_eva_state_dict(
# update target modules if some lora layers have been removed due to their EVA rank being 0
new_target_modules = new_target_modules + missing_eva_inits
if len(new_target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION:
new_target_modules = _find_minimal_target_modules(new_target_modules, other_module_names)
new_target_modules = _find_minimal_target_modules(
new_target_modules, other_module_names
)
model.peft_config[adapter_name].target_modules = new_target_modules
# set rank pattern obtained from EVA
@ -564,8 +636,12 @@ def get_eva_state_dict(
dataloader: Iterable,
peft_config: Optional[LoraConfig] = None,
forward_fn: Optional[callable] = forward_fn_dict,
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
prepare_model_inputs_fn: Optional[
callable
] = prepare_model_inputs_fn_language_modeling,
prepare_layer_inputs_fn: Union[
callable, Dict[str, callable], None
] = prepare_layer_inputs_fn_language_modeling,
adapter_name: str = "default",
gather_distributed_inputs: bool = True,
show_progress_bar: bool = True,
@ -613,7 +689,9 @@ def get_eva_state_dict(
def target_module_check_fn_peft_model(name, module, unsupported_lora_modules):
"check if a module is an adapter module via base_layer attribute"
return hasattr(module, "base_layer") and not isinstance(module, unsupported_lora_modules)
return hasattr(module, "base_layer") and not isinstance(
module, unsupported_lora_modules
)
def target_module_check_fn_default(name, module, peft_config):
"check if a module is an adapter module via target_modules"
@ -635,11 +713,14 @@ def get_eva_state_dict(
if is_peft_model:
ctx = model.disable_adapter()
target_module_check_fn = partial(
target_module_check_fn_peft_model, unsupported_lora_modules=UNSUPPORTED_LORA_MODULES
target_module_check_fn_peft_model,
unsupported_lora_modules=UNSUPPORTED_LORA_MODULES,
)
else:
ctx = nullcontext()
target_module_check_fn = partial(target_module_check_fn_default, peft_config=peft_config)
target_module_check_fn = partial(
target_module_check_fn_default, peft_config=peft_config
)
with ctx:
eva_state_dict = _get_eva_state_dict(
@ -662,8 +743,12 @@ def initialize_lora_eva_weights(
dataloader: Optional[Iterable] = None,
eva_state_dict: Optional[dict] = None,
forward_fn: Optional[callable] = forward_fn_dict,
prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling,
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling,
prepare_model_inputs_fn: Optional[
callable
] = prepare_model_inputs_fn_language_modeling,
prepare_layer_inputs_fn: Union[
callable, Dict[str, callable], None
] = prepare_layer_inputs_fn_language_modeling,
adapter_name: str = "default",
gather_distributed_inputs: bool = True,
show_progress_bar: bool = True,
@ -715,11 +800,15 @@ def initialize_lora_eva_weights(
# eva currently only works with a single active adapter
# Important: when removing this requirement, make sure eva init works correctly if the new rank is 0.
if len(model.active_adapters) > 1:
raise ValueError("`initialize_lora_eva_weights` currently only works with a single active adapter")
raise ValueError(
"`initialize_lora_eva_weights` currently only works with a single active adapter"
)
# initialize_lora_eva_weights only works with `init_lora_weights='eva'`
if model.peft_config[adapter_name].init_lora_weights != "eva":
raise ValueError("`initialize_lora_eva_weights` can only be used with `init_lora_weights='eva'`")
raise ValueError(
"`initialize_lora_eva_weights` can only be used with `init_lora_weights='eva'`"
)
# compute svd
if eva_state_dict is None:

Changed file: vendored PEFT LoRA GPTQ layer (QuantLinear, dispatch_gptq)

@ -39,7 +39,9 @@ class QuantLinear(torch.nn.Module, LoraLayer):
LoraLayer.__init__(self, base_layer)
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
)
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
@ -109,7 +111,9 @@ def dispatch_gptq(
gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
if AutoGPTQQuantLinear is not None and isinstance(
target_base_layer, AutoGPTQQuantLinear
):
new_module = QuantLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight

Changed file: vendored PEFT LoRA HQQ layer

@ -45,7 +45,9 @@ if is_hqq_available():
**kwargs,
) -> None:
if lora_bias:
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support lora_bias yet, set it to False"
)
super().__init__()
LoraLayer.__init__(self, base_layer)
@ -63,7 +65,9 @@ if is_hqq_available():
lora_bias=lora_bias,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
"""
Merge the active adapter weights into the base weights
@ -86,7 +90,10 @@ if is_hqq_available():
continue
layer = self.get_base_layer()
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
quant_config = {
**copy.deepcopy(layer.quant_config),
"offload_meta": layer.offload_meta,
}
lora_data = self.get_delta_weight(active_adapter)
output = layer.dequantize()
@ -95,19 +102,28 @@ if is_hqq_available():
else:
# handle dora
# since output already includes scaling, set it to 1 here
weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach()
weight_norm = self._get_weight_norm(
output, lora_data, scaling=1
).detach()
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter] / weight_norm
)
w_data = dora_factor.view(-1, 1) * (output + lora_data)
if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
new_hqq_layer = HQQLinear(
None,
quant_config,
compute_dtype=layer.compute_dtype,
device=layer.device,
)
quant_config.pop("offload_meta", None)
new_hqq_layer.quantize(w_data, **quant_config)
self.base_layer = new_hqq_layer
@ -128,17 +144,27 @@ if is_hqq_available():
lora_data = self.get_delta_weight(active_adapter)
layer = self.get_base_layer()
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
quant_config = {
**copy.deepcopy(layer.quant_config),
"offload_meta": layer.offload_meta,
}
output = layer.dequantize()
if not self.use_dora[active_adapter]:
w_data = output - lora_data
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter] / weight_norm
)
w_data = output.data / dora_factor.view(-1, 1) - lora_data
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
new_hqq_layer = HQQLinear(
None,
quant_config,
compute_dtype=layer.compute_dtype,
device=layer.device,
)
quant_config.pop("offload_meta", None)
new_hqq_layer.quantize(w_data, **quant_config)
self.base_layer = new_hqq_layer
@ -162,7 +188,13 @@ if is_hqq_available():
unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
sub_batch_indices_list.append(
[
index
for index, item in enumerate(adapter_names)
if item == adapter
]
)
for i, active_adapter in enumerate(unique_adapters):
if active_adapter == "__base__":
@ -201,7 +233,9 @@ if is_hqq_available():
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
result = self._mixed_batch_forward(
x, *args, adapter_names=adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:

View File

@ -25,11 +25,21 @@ from torch import svd_lowrank
from transformers.pytorch_utils import Conv1D
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type
from peft.utils.integrations import (
dequantize_module_weight,
gather_params_ctx,
get_bnb_param_type,
)
from peft.utils.other import transpose
from .config import LoraConfig
from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer
from .dora import (
DoraConv2dLayer,
DoraConv3dLayer,
DoraEmbeddingLayer,
DoraLinearLayer,
_DoraConvNdLayer,
)
class LoraLayer(BaseTunerLayer):
@ -38,7 +48,9 @@ class LoraLayer(BaseTunerLayer):
# All names of other parameters that may contain adapter-related parameters
other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")
def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
def __init__(
self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs
) -> None:
self.base_layer = base_layer
self.r = {}
self.lora_alpha = {}
@ -67,10 +79,15 @@ class LoraLayer(BaseTunerLayer):
elif isinstance(base_layer, nn.Conv3d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
in_features, out_features = (
base_layer.num_embeddings,
base_layer.embedding_dim,
)
elif isinstance(base_layer, Conv1D):
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
base_layer.weight.ds_shape
if hasattr(base_layer.weight, "ds_shape")
else base_layer.weight.shape
)
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
# QuantLinear
@ -78,26 +95,40 @@ class LoraLayer(BaseTunerLayer):
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
# Megatron ColumnParallelLinear,RowParallelLinear
in_features, out_features = base_layer.input_size, base_layer.output_size
elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
elif (
hasattr(base_layer, "codebooks")
and base_layer.__class__.__name__ == "QuantizedLinear"
):
# AQLM QuantLinear
in_features, out_features = base_layer.in_features, base_layer.out_features
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
elif (
hasattr(base_layer, "w_bit")
and base_layer.__class__.__name__ == "WQLinear_GEMM"
):
# Awq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
elif base_layer.__class__.__name__ == "EetqLinear":
# Eetq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
elif (
hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear"
):
# HQQ layers
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
# possibly support user provided custom layer types using dynamic dispatch
if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
in_features, out_features = base_layer.in_features, base_layer.out_features
if hasattr(base_layer, "in_features") and hasattr(
base_layer, "out_features"
):
in_features, out_features = (
base_layer.in_features,
base_layer.out_features,
)
else:
in_features, out_features = None, None
warnings.warn(
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.",
UserWarning,
)
self.in_features = in_features
@ -116,7 +147,9 @@ class LoraLayer(BaseTunerLayer):
):
# This code works for linear layers, override for other layer types
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
raise ValueError(
f"`r` should be a positive integer value but the value passed is {r}"
)
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
@ -140,7 +173,9 @@ class LoraLayer(BaseTunerLayer):
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
with gather_params_ctx(self.get_base_layer().weight):
self.pissa_init(adapter_name, init_lora_weights)
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
elif (
isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora"
):
with gather_params_ctx(self.get_base_layer().weight):
self.olora_init(adapter_name)
elif init_lora_weights == "loftq":
@ -169,9 +204,13 @@ class LoraLayer(BaseTunerLayer):
if init_lora_weights is True:
# initialize A the same way as the default for nn.Linear and B to zero
# https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(
self.lora_A[adapter_name].weight, a=math.sqrt(5)
)
elif init_lora_weights.lower() == "gaussian":
nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name])
nn.init.normal_(
self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]
)
else:
raise ValueError(f"Unknown initialization {init_lora_weights=}")
nn.init.zeros_(self.lora_B[adapter_name].weight)
@ -210,7 +249,11 @@ class LoraLayer(BaseTunerLayer):
self.lora_A[adapter_name].weight.data = Rr.contiguous()
self.lora_B[adapter_name].weight.data = Qr.contiguous()
weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
weight_tensor.data -= (
scale_factor
* self.lora_B[adapter_name].weight
@ self.lora_A[adapter_name].weight
)
if bnb_param_type == "4bit":
weight_tensor = orig_weight.__class__(
weight_tensor,
@ -249,7 +292,9 @@ class LoraLayer(BaseTunerLayer):
Uhr = Uh[: self.r[adapter_name]]
elif len(init_lora_weights.split("_niter_")) == 2:
Vr, Sr, Ur = svd_lowrank(
weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
weight.data,
self.r[adapter_name],
niter=int(init_lora_weights.split("_niter_")[-1]),
)
Sr /= self.scaling[adapter_name]
Uhr = Ur.t()
@ -290,12 +335,18 @@ class LoraLayer(BaseTunerLayer):
def dora_init(self, adapter_name: str) -> None:
if not self.lora_magnitude_vector:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
self.adapter_layer_names = self.adapter_layer_names[:] + (
"lora_magnitude_vector",
)
dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(self, "fan_in_fan_out", False))
dora_layer = DoraLinearLayer(
fan_in_fan_out=getattr(self, "fan_in_fan_out", False)
)
lora_A = self.lora_A[adapter_name].weight
lora_B = self.lora_B[adapter_name].weight
place_on_cpu = self.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu")
place_on_cpu = self.ephemeral_gpu_offload and (
lora_A.device.type == "cpu" or lora_B.device.type == "cpu"
)
if self.ephemeral_gpu_offload:
if lora_A.device.type in ["cuda", "xpu"]:
lora_B = lora_B.to(lora_A.device)
@ -308,7 +359,11 @@ class LoraLayer(BaseTunerLayer):
lora_A = lora_A.to(lora_B.device)
scaling = self.scaling[adapter_name]
dora_layer.update_layer(
base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling, place_on_cpu=place_on_cpu
base_layer=self.get_base_layer(),
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
place_on_cpu=place_on_cpu,
)
self.lora_magnitude_vector[adapter_name] = dora_layer
@ -341,7 +396,9 @@ class LoraLayer(BaseTunerLayer):
continue
if scale is None:
self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
self.scaling[active_adapter] = (
self.lora_alpha[active_adapter] / self.r[active_adapter]
)
else:
self.scaling[active_adapter] /= scale
@ -383,7 +440,9 @@ class LoraLayer(BaseTunerLayer):
unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
sub_batch_indices_list.append(
[index for index, item in enumerate(adapter_names) if item == adapter]
)
for i, active_adapter in enumerate(unique_adapters):
if active_adapter == "__base__":
@ -449,7 +508,9 @@ class Linear(nn.Module, LoraLayer):
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
"""
Merge the active adapter weights into the base weights
@ -482,15 +543,24 @@ class Linear(nn.Module, LoraLayer):
# since delta_weight already includes scaling, set it to 1 here
weight_norm = (
self.lora_magnitude_vector[active_adapter]
.get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1)
.get_weight_norm(
orig_weights,
transpose(delta_weight, self.fan_in_fan_out),
scaling=1,
)
.detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight
/ weight_norm
)
dora_factor = transpose(
dora_factor.view(-1, 1), self.fan_in_fan_out
)
orig_weights = dora_factor * (orig_weights + delta_weight)
if not torch.isfinite(orig_weights).all():
@ -518,7 +588,9 @@ class Linear(nn.Module, LoraLayer):
weight_norm = (
self.lora_magnitude_vector[active_adapter]
.get_weight_norm(
base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1
base_layer.weight,
transpose(delta_weight, self.fan_in_fan_out),
scaling=1,
)
.detach()
)
@ -526,9 +598,16 @@ class Linear(nn.Module, LoraLayer):
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
new_weight = dora_factor * (base_layer.weight.data + delta_weight)
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight
/ weight_norm
)
dora_factor = transpose(
dora_factor.view(-1, 1), self.fan_in_fan_out
)
new_weight = dora_factor * (
base_layer.weight.data + delta_weight
)
base_layer.weight.data = new_weight
if self.lora_bias[active_adapter]:
@ -552,7 +631,9 @@ class Linear(nn.Module, LoraLayer):
weight.data -= delta_weight
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight / weight_norm
)
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
weight.data = weight_orig
@ -573,7 +654,9 @@ class Linear(nn.Module, LoraLayer):
# In case users wants to merge the adapter weights that are in
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
cast_to_fp32 = device.type == "cpu" and (
dtype == torch.float16 or dtype == torch.bfloat16
)
weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight
@ -582,7 +665,9 @@ class Linear(nn.Module, LoraLayer):
weight_A = weight_A.float()
weight_B = weight_B.float()
output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
output_tensor = (
transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
)
if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
@ -602,7 +687,9 @@ class Linear(nn.Module, LoraLayer):
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
result = self._mixed_batch_forward(
x, *args, adapter_names=adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
@ -661,7 +748,9 @@ class Embedding(nn.Module, LoraLayer):
) -> None:
if lora_bias:
# lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter
raise ValueError(f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}.")
raise ValueError(
f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}."
)
super().__init__()
LoraLayer.__init__(self, base_layer)
@ -679,10 +768,20 @@ class Embedding(nn.Module, LoraLayer):
)
def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
self,
adapter_name,
r,
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
use_dora,
lora_bias,
):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
raise ValueError(
f"`r` should be a positive integer value but the value passed is {r}"
)
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
@ -723,18 +822,25 @@ class Embedding(nn.Module, LoraLayer):
def dora_init(self, adapter_name: str) -> None:
if self.lora_magnitude_vector is None:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
self.adapter_layer_names = self.adapter_layer_names[:] + (
"lora_magnitude_vector",
)
dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
lora_embedding_A = self.lora_embedding_A[adapter_name]
lora_embedding_B = self.lora_embedding_B[adapter_name]
scaling = self.scaling[adapter_name]
dora_layer.update_layer(
base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
base_layer=self.get_base_layer(),
lora_A=lora_embedding_A,
lora_B=lora_embedding_B,
scaling=scaling,
)
self.lora_magnitude_vector[adapter_name] = dora_layer
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
"""
Merge the active adapter weights into the base weights
@ -781,7 +887,9 @@ class Embedding(nn.Module, LoraLayer):
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_embedding_A.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
self.get_base_layer().weight.data -= self.get_delta_weight(
active_adapter
)
def get_delta_weight(self, adapter) -> torch.Tensor:
"""
@ -797,7 +905,9 @@ class Embedding(nn.Module, LoraLayer):
# In case users wants to merge the adapter weights that are in
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
cast_to_fp32 = device.type == "cpu" and (
dtype == torch.float16 or dtype == torch.bfloat16
)
weight_A = self.lora_embedding_A[adapter]
weight_B = self.lora_embedding_B[adapter]
@ -827,7 +937,9 @@ class Embedding(nn.Module, LoraLayer):
unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
sub_batch_indices_list.append(
[index for index, item in enumerate(adapter_names) if item == adapter]
)
for i, active_adapter in enumerate(unique_adapters):
if active_adapter == "__base__":
@ -869,7 +981,9 @@ class Embedding(nn.Module, LoraLayer):
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
result = self._mixed_batch_forward(
x, *args, adapter_names=adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
@ -886,7 +1000,9 @@ class Embedding(nn.Module, LoraLayer):
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling
else:
mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](
mag_norm_scale, dora_result = self.lora_magnitude_vector[
active_adapter
](
x,
lora_A=embedding_A,
lora_B=embedding_B,
@ -937,10 +1053,20 @@ class _ConvNd(nn.Module, LoraLayer):
)
def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
self,
adapter_name,
r,
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
use_dora,
lora_bias,
):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
raise ValueError(
f"`r` should be a positive integer value but the value passed is {r}"
)
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
@ -957,8 +1083,12 @@ class _ConvNd(nn.Module, LoraLayer):
padding = base_layer.padding
conv_layer = type(base_layer)
out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=lora_bias)
self.lora_A[adapter_name] = conv_layer(
self.in_features, r, kernel_size, stride, padding, bias=False
)
self.lora_B[adapter_name] = conv_layer(
r, self.out_features, out_kernel, out_stride, bias=lora_bias
)
self.lora_bias[adapter_name] = lora_bias
if use_rslora:
@ -988,21 +1118,30 @@ class _ConvNd(nn.Module, LoraLayer):
def dora_init(self, adapter_name: str) -> None:
if self.lora_magnitude_vector is None:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
self.adapter_layer_names = self.adapter_layer_names[:] + (
"lora_magnitude_vector",
)
dora_layer_class = self._get_dora_layer_class()
dora_layer = dora_layer_class(fan_in_fan_out=False)
lora_A = self.lora_A[adapter_name].weight
lora_B = self.lora_B[adapter_name].weight
scaling = self.scaling[adapter_name]
dora_layer.update_layer(base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling)
dora_layer.update_layer(
base_layer=self.get_base_layer(),
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
)
self.lora_magnitude_vector[adapter_name] = dora_layer
def _get_dora_layer_class(self) -> type[_DoraConvNdLayer]:
# Subclasses should override this method to return the appropriate DoraLayer class
raise NotImplementedError
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
"""
Merge the active adapter weights inside the base weights
@ -1043,8 +1182,13 @@ class _ConvNd(nn.Module, LoraLayer):
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
orig_weights = dora_factor.view(*self._get_dora_factor_view()) * (orig_weights + delta_weight)
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight
/ weight_norm
)
orig_weights = dora_factor.view(
*self._get_dora_factor_view()
) * (orig_weights + delta_weight)
if not torch.isfinite(orig_weights).all():
raise ValueError(
@ -1076,7 +1220,10 @@ class _ConvNd(nn.Module, LoraLayer):
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight
/ weight_norm
)
new_weight = dora_factor.view(*self._get_dora_factor_view()) * (
base_layer.weight.data + delta_weight
)
@ -1103,8 +1250,13 @@ class _ConvNd(nn.Module, LoraLayer):
weight.data -= delta_weight
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
weight_orig = weight.data / dora_factor.view(*self._get_dora_factor_view()) - delta_weight
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight / weight_norm
)
weight_orig = (
weight.data / dora_factor.view(*self._get_dora_factor_view())
- delta_weight
)
weight.data = weight_orig
if self.lora_bias[active_adapter]:
@ -1124,7 +1276,9 @@ class _ConvNd(nn.Module, LoraLayer):
# In case users want to merge the adapter weights that are in
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
cast_to_fp32 = device.type == "cpu" and (
dtype == torch.float16 or dtype == torch.bfloat16
)
weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight
@ -1136,9 +1290,9 @@ class _ConvNd(nn.Module, LoraLayer):
# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
if self.get_base_layer().weight.size()[2:4] == (1, 1):
# conv2d 1x1
output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
3
) * self.scaling[adapter]
output_tensor = (
weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * self.scaling[adapter]
else:
output_tensor = (
self.conv_fn(
@ -1166,7 +1320,9 @@ class _ConvNd(nn.Module, LoraLayer):
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
result = self._mixed_batch_forward(
x, *args, adapter_names=adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
@ -1207,7 +1363,9 @@ class Conv2d(_ConvNd):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 4:
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
raise ValueError(
f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}"
)
self.conv_fn = F.conv2d
def _get_dora_layer_class(self):
@ -1219,7 +1377,9 @@ class Conv3d(_ConvNd):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 5:
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
raise ValueError(
f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}"
)
self.conv_fn = F.conv3d
def _get_dora_layer_class(self):
@ -1262,10 +1422,13 @@ def dispatch_default(
elif isinstance(target_base_layer, Conv1D):
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. " "Setting fan_in_fan_out to True."
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
kwargs.update(lora_config.loftq_config)
new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
new_module = Linear(
target, adapter_name, is_target_conv_1d_layer=True, **kwargs
)
return new_module
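For reference, the `_mixed_batch_forward` paths reformatted above split a batch by per-sample adapter name and run each sub-batch through its own adapter (or the base layer for "__base__"). A minimal standalone sketch of that grouping step, using hypothetical adapter names:

# Hypothetical per-sample adapter names; "__base__" marks rows that bypass LoRA.
adapter_names = ["default", "__base__", "default", "summarize"]

unique_adapters = set(adapter_names)
sub_batch_indices_list = [
    [index for index, item in enumerate(adapter_names) if item == adapter]
    for adapter in unique_adapters
]
# Each index list selects the rows of x that are routed through one adapter.
print(dict(zip(unique_adapters, sub_batch_indices_list)))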

View File

@ -42,7 +42,13 @@ from peft.utils import (
get_peft_model_state_dict,
get_quantization_config,
)
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties
from peft.utils.merge_utils import (
dare_linear,
dare_ties,
magnitude_prune,
task_arithmetic,
ties,
)
from peft.utils.other import get_pattern_key
from .aqlm import dispatch_aqlm
@ -137,8 +143,12 @@ class LoraModel(BaseTuner):
prefix: str = "lora_"
def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
def __init__(
self, model, config, adapter_name, low_cpu_mem_usage: bool = False
) -> None:
super().__init__(
model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage
)
def _check_new_adapter_config(self, config: LoraConfig) -> None:
"""
@ -213,7 +223,9 @@ class LoraModel(BaseTuner):
quant_methods = ["gptq", "aqlm", "awq"]
for quant_method in quant_methods:
quantization_config = get_quantization_config(self.model, method=quant_method)
quantization_config = get_quantization_config(
self.model, method=quant_method
)
if quantization_config is not None:
kwargs[f"{quant_method}_quantization_config"] = quantization_config
@ -232,7 +244,9 @@ class LoraModel(BaseTuner):
lora_bias=lora_config.lora_bias,
)
else:
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
new_module = self._create_new_module(
lora_config, adapter_name, target, **kwargs
)
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
@ -269,11 +283,15 @@ class LoraModel(BaseTuner):
weight = (
child.qweight
if hasattr(child, "qweight")
else child.W_q
if hasattr(child, "W_q")
else child.weight
if hasattr(child, "weight")
else next(child.parameters())
else (
child.W_q
if hasattr(child, "W_q")
else (
child.weight
if hasattr(child, "weight")
else next(child.parameters())
)
)
)
if not any(p.device == meta for p in module.parameters()):
module.to(weight.device)
@ -294,10 +312,16 @@ class LoraModel(BaseTuner):
p.requires_grad = True
elif bias == "lora_only":
for m in model.modules():
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
if (
isinstance(m, LoraLayer)
and hasattr(m, "bias")
and m.bias is not None
):
m.bias.requires_grad = True
else:
raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")
raise NotImplementedError(
f"Requested bias: {bias}, is not implemented."
)
@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
@ -351,7 +375,9 @@ class LoraModel(BaseTuner):
new_module = None
for dispatcher in dispatchers:
new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
new_module = dispatcher(
target, adapter_name, lora_config=lora_config, **kwargs
)
if new_module is not None: # first match wins
break
@ -370,14 +396,19 @@ class LoraModel(BaseTuner):
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
if name == "model": # see #1892: prevent infinite recursion if class is not initialized
if (
name == "model"
): # see #1892: prevent infinite recursion if class is not initialized
raise
return getattr(self.model, name)
def get_peft_config_as_dict(self, inference: bool = False):
config_dict = {}
for key, value in self.peft_config.items():
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
config = {
k: v.value if isinstance(v, Enum) else v
for k, v in asdict(value).items()
}
if inference:
config["inference_mode"] = True
config_dict[key] = config
@ -428,7 +459,9 @@ class LoraModel(BaseTuner):
for module in self.model.modules():
if isinstance(module, LoraLayer):
if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
warnings.warn(
"Adapter cannot be set when the model is merged. Unmerging the model first."
)
module.unmerge()
module.set_adapter(adapter_name)
self.active_adapter = adapter_name
@ -443,7 +476,9 @@ class LoraModel(BaseTuner):
return
if self.training:
raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
raise ValueError(
"Cannot pass `adapter_names` when the model is in training mode."
)
# Check that users only passed actually existing adapters.
# Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
@ -456,12 +491,18 @@ class LoraModel(BaseTuner):
unique_adapters = {name for name in adapter_names if name != "__base__"}
unexpected_adapters = unique_adapters - expected_adapters
if unexpected_adapters:
raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
raise ValueError(
f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}"
)
hook_handles = []
for module in self.modules():
if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
if isinstance(module, LoraLayer) or isinstance(
module, ModulesToSaveWrapper
):
pre_forward = partial(
_adapter_names_pre_forward_hook, adapter_names=adapter_names
)
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
hook_handles.append(handle)
@ -477,17 +518,26 @@ class LoraModel(BaseTuner):
"""
super()._check_merge_allowed()
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
raise ValueError(
"Cannot merge LORA layers when the model is gptq quantized"
)
if self.peft_config.get("layer_replication"):
raise ValueError("Cannot merge LORA layers when base model layers are replicated")
raise ValueError(
"Cannot merge LORA layers when base model layers are replicated"
)
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
if (
model_config["model_type"]
not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
):
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[
model_config["model_type"]
]
)
return peft_config
@ -501,7 +551,9 @@ class LoraModel(BaseTuner):
if merge:
self._check_merge_allowed()
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
key_list = [
key for key, _ in self.model.named_modules() if self.prefix not in key
]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
@ -512,14 +564,18 @@ class LoraModel(BaseTuner):
if hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
self._replace_module(
parent, target_name, target.get_base_layer(), target
)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
new_module.merge(
safe_merge=safe_merge, adapter_names=adapter_names
)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)
@ -539,7 +595,11 @@ class LoraModel(BaseTuner):
# If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
# modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
# have modules for the adapters to be merged.
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
modules_to_save_wrappers = [
module
for module in self.modules()
if isinstance(module, ModulesToSaveWrapper)
]
problematic_wrappers = [
wrapper
for wrapper in modules_to_save_wrappers
@ -555,7 +615,13 @@ class LoraModel(BaseTuner):
combination_type = "linear" if len(adapters) == 1 else combination_type
adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
if combination_type in (
"linear",
"ties",
"dare_ties",
"dare_linear",
"magnitude_prune",
):
# all adapter ranks should be the same; the new rank is just this value
if len(set(adapters_ranks)) != 1:
raise ValueError(
@ -573,7 +639,9 @@ class LoraModel(BaseTuner):
else:
raise ValueError(f"Invalid combination_type: {combination_type}")
target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
target_module_types = [
type(self.peft_config[adapter].target_modules) for adapter in adapters
]
if not target_module_types:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
@ -583,13 +651,18 @@ class LoraModel(BaseTuner):
)
if target_module_types[0] is str:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
new_target_modules = "|".join(
f"({self.peft_config[adapter].target_modules})" for adapter in adapters
)
elif target_module_types[0] is set:
new_target_modules = reduce(
operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
operator.or_,
(self.peft_config[adapter].target_modules for adapter in adapters),
)
else:
raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")
raise TypeError(
f"Invalid type {target_module_types[0]} found in target_modules"
)
return combination_type, new_rank, new_target_modules
@ -649,10 +722,12 @@ class LoraModel(BaseTuner):
if adapter_name in list(self.peft_config.keys()):
return
combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
adapters=adapters,
combination_type=combination_type,
svd_rank=svd_rank,
combination_type, new_rank, new_target_modules = (
self._check_add_weighted_adapter(
adapters=adapters,
combination_type=combination_type,
svd_rank=svd_rank,
)
)
self.peft_config[adapter_name] = replace(
@ -666,7 +741,9 @@ class LoraModel(BaseTuner):
# Do we really need that?
_freeze_adapter(self.model, adapter_name)
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
key_list = [
key for key, _ in self.model.named_modules() if self.prefix not in key
]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LoraLayer):
@ -692,11 +769,17 @@ class LoraModel(BaseTuner):
current_adapter_lora_B = target.lora_embedding_B[adapter]
else:
continue
loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter])
loras_A.append(
current_adapter_lora_A.data
* weight
* target.scaling[adapter]
)
loras_B.append(current_adapter_lora_B.data)
if len(loras_A) == 0:
raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")
raise ValueError(
"No matching LoRAs found. Please raise an issue on GitHub."
)
loras_A = torch.cat(loras_A, dim=0)
loras_B = torch.cat(loras_B, dim=1)
target_lora_A.data[: loras_A.shape[0], :] = loras_A
@ -708,23 +791,38 @@ class LoraModel(BaseTuner):
"dare_ties_svd",
"magnitude_prune_svd",
]:
target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter(
combination_type,
adapters,
weights,
new_rank,
target,
target_lora_A,
target_lora_B,
density,
majority_sign_method,
svd_clamp,
full_matrices=svd_full_matrices,
driver=svd_driver,
target_lora_A.data, target_lora_B.data = (
self._svd_generalized_task_arithmetic_weighted_adapter(
combination_type,
adapters,
weights,
new_rank,
target,
target_lora_A,
target_lora_B,
density,
majority_sign_method,
svd_clamp,
full_matrices=svd_full_matrices,
driver=svd_driver,
)
)
elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]:
target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter(
combination_type, adapters, weights, target, density, majority_sign_method
elif combination_type in [
"linear",
"ties",
"dare_linear",
"dare_ties",
"magnitude_prune",
]:
target_lora_A.data, target_lora_B.data = (
self._generalized_task_arithmetic_weighted_adapter(
combination_type,
adapters,
weights,
target,
density,
majority_sign_method,
)
)
def _svd_generalized_task_arithmetic_weighted_adapter(
@ -752,21 +850,29 @@ class LoraModel(BaseTuner):
# if no valid adapter, nothing to do
if len(valid_adapters) == 0:
raise ValueError("No matching LoRAs found. Please raise an issue on Github.")
raise ValueError(
"No matching LoRAs found. Please raise an issue on Github."
)
delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters]
valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device)
if combination_type == "svd":
delta_weight = task_arithmetic(delta_weight, valid_weights)
elif combination_type == "ties_svd":
delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method)
delta_weight = ties(
delta_weight, valid_weights, density, majority_sign_method
)
elif combination_type == "dare_linear_svd":
delta_weight = dare_linear(delta_weight, valid_weights, density)
elif combination_type == "dare_ties_svd":
delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method)
delta_weight = dare_ties(
delta_weight, valid_weights, density, majority_sign_method
)
elif combination_type == "magnitude_prune_svd":
delta_weight = magnitude_prune(delta_weight, valid_weights, density)
else:
raise ValueError(f"Invalid value passed to combination type: {combination_type}")
raise ValueError(
f"Invalid value passed to combination type: {combination_type}"
)
conv2d = isinstance(target, Conv2d)
if conv2d:
@ -775,11 +881,15 @@ class LoraModel(BaseTuner):
delta_weight = delta_weight.flatten(start_dim=1)
else:
delta_weight = delta_weight.squeeze()
if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding:
if (
hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out
) or is_embedding:
delta_weight = delta_weight.T
# based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131
U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver)
U, S, Vh = torch.linalg.svd(
delta_weight, full_matrices=full_matrices, driver=driver
)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
@ -827,11 +937,15 @@ class LoraModel(BaseTuner):
if combination_type == "linear":
lora_deltas[i] = task_arithmetic(task_tensors, valid_weights)
elif combination_type == "ties":
lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method)
lora_deltas[i] = ties(
task_tensors, valid_weights, density, majority_sign_method
)
elif combination_type == "dare_linear":
lora_deltas[i] = dare_linear(task_tensors, valid_weights, density)
elif combination_type == "dare_ties":
lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method)
lora_deltas[i] = dare_ties(
task_tensors, valid_weights, density, majority_sign_method
)
elif combination_type == "magnitude_prune":
lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density)
else:
@ -850,7 +964,9 @@ class LoraModel(BaseTuner):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
key_list = [
key for key, _ in self.model.named_modules() if self.prefix not in key
]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
@ -862,7 +978,10 @@ class LoraModel(BaseTuner):
self.active_adapter = new_adapter or []
def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
self,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names: Optional[list[str]] = None,
) -> torch.nn.Module:
r"""
This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
@ -900,7 +1019,9 @@ class LoraModel(BaseTuner):
"""
return self._unload_and_optionally_merge(merge=False)
def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None):
def subtract_mutated_init(
self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None
):
"""
This function can calculate the updates of the [PiSSA | OLoRA] by comparing the parameters of the [PiSSA |
OLoRA] adapter in `output_state_dict` with the initial values of [PiSSA | OLoRA] in `adapter_name`, thus
@ -929,11 +1050,19 @@ class LoraModel(BaseTuner):
## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'.
if "lora_A" in name:
tensors_lora[name] = torch.cat(
[output_state_dict[name], mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=0
[
output_state_dict[name],
mutated_init_state_dict[".".join(name.split(".")[1:])],
],
dim=0,
)
elif "lora_B" in name:
tensors_lora[name] = torch.cat(
[output_state_dict[name], -mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=1
[
output_state_dict[name],
-mutated_init_state_dict[".".join(name.split(".")[1:])],
],
dim=1,
)
return tensors_lora
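The concatenations in `subtract_mutated_init` above implement the identity from its comment, ΔW = B·A − B₀·A₀ = [B | −B₀] · [A ; A₀]. A quick numerical check, assuming the usual PEFT shapes `lora_A: (r, in_features)` and `lora_B: (out_features, r)`:

import torch

# Hypothetical sizes for the check.
r, d_in, d_out = 4, 16, 8
A, A0 = torch.randn(r, d_in), torch.randn(r, d_in)    # trained lora_A and its PiSSA/OLoRA init
B, B0 = torch.randn(d_out, r), torch.randn(d_out, r)  # trained lora_B and its init

delta_direct = B @ A - B0 @ A0
# Mirrors the dim=0 concat for "lora_A" and the negated dim=1 concat for "lora_B" above.
delta_stacked = torch.cat([B, -B0], dim=1) @ torch.cat([A, A0], dim=0)
assert torch.allclose(delta_direct, delta_stacked, atol=1e-5)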

View File

@ -33,7 +33,9 @@ class TorchaoLoraLinear(Linear):
# this is not strictly necessary, as kwargs are stored either way, but we want to error early if
# get_apply_tensor_subclass is missing.
if kwargs.get("lora_bias", False):
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support lora_bias yet, set it to False"
)
super().__init__(*args, **kwargs)
self.get_apply_tensor_subclass = get_apply_tensor_subclass
@ -43,10 +45,16 @@ class TorchaoLoraLinear(Linear):
# TODO: Not required once int4_weight_only is properly supported by torchao
base_layer = self.get_base_layer()
weight = base_layer.weight
if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")
if hasattr(weight, "layout_tensor") and (
weight.layout_tensor.data.dtype != torch.int8
):
raise ValueError(
f"{type(self).__name__} only supports int8 weights for now."
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
from torchao import quantize_
adapter_names = check_adapters_to_merge(self, adapter_names)
@ -143,7 +151,10 @@ def dispatch_torchao(
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import LinearActivationQuantizedTensor
if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
if isinstance(
target_base_layer.weight,
(AffineQuantizedTensor, LinearActivationQuantizedTensor),
):
new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)
return new_module
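`dispatch_torchao` above returns a module only when the base layer's weight is a torchao quantized tensor; otherwise it yields None and `_create_new_module` (reformatted earlier in this diff) falls through to the next dispatcher. A minimal sketch of that "first match wins" chain, with hypothetical dispatchers:

from typing import Callable, Optional

def dispatch_quantized(target: str) -> Optional[str]:
    return None  # pretend the layer is not quantized, so this dispatcher declines

def dispatch_default_linear(target: str) -> Optional[str]:
    return f"LoraLinear({target})"

dispatchers: list[Callable[[str], Optional[str]]] = [
    dispatch_quantized,
    dispatch_default_linear,
]
new_module = None
for dispatcher in dispatchers:
    new_module = dispatcher("q_proj")
    if new_module is not None:  # first match wins
        break
print(new_module)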

View File

@ -54,13 +54,17 @@ class LoraParallelLinear(nn.Module, LoraLayer):
**kwargs,
):
if lora_bias:
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support lora_bias yet, set it to False"
)
super().__init__()
LoraLayer.__init__(self, base_layer=base_layer, **kwargs)
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
raise ValueError(
f"{self.__class__.__name__} does not support DoRA yet, please set it to False"
)
self.backend = backend
self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear)
@ -113,7 +117,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
**parallel_linear_kwargs,
):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
raise ValueError(
f"`r` should be a positive integer value but the value passed is {r}"
)
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
@ -136,9 +142,19 @@ class LoraParallelLinear(nn.Module, LoraLayer):
init_method=init_method,
config=megatron_config,
)
lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32)
lora_b = nn.Linear(
in_features=r,
out_features=self.out_features,
bias=False,
dtype=torch.float32,
)
else:
lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32)
lora_a = nn.Linear(
in_features=self.in_features,
out_features=r,
bias=False,
dtype=torch.float32,
)
lora_b = self.backend.ColumnParallelLinear(
input_size=r,
output_size=self.out_features,
@ -158,7 +174,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
with gather_params_ctx(self.get_base_layer().weight):
self.pissa_init(adapter_name, init_lora_weights)
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
elif (
isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora"
):
with gather_params_ctx(self.get_base_layer().weight):
self.olora_init(adapter_name)
elif init_lora_weights == "loftq":
@ -189,7 +207,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
self.unmerge()
result, bias = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.")
raise ValueError(
f"{self.__class__.__name__} does not support mixed_batch_forward yet."
)
elif self.merged:
result, bias = self.base_layer(x, *args, **kwargs)
else:
@ -225,7 +245,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
result = result.to(torch_result_dtype)
return result, bias
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(
self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> None:
"""
Merge the active adapter weights into the base weights
@ -258,15 +280,24 @@ class LoraParallelLinear(nn.Module, LoraLayer):
# since delta_weight already includes scaling, set it to 1 here
weight_norm = (
self.lora_magnitude_vector[active_adapter]
.get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1)
.get_weight_norm(
orig_weights,
transpose(delta_weight, self.fan_in_fan_out),
scaling=1,
)
.detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight
/ weight_norm
)
dora_factor = transpose(
dora_factor.view(-1, 1), self.fan_in_fan_out
)
orig_weights = dora_factor * (orig_weights + delta_weight)
if not torch.isfinite(orig_weights).all():
@ -285,7 +316,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
weight_norm = (
self.lora_magnitude_vector[active_adapter]
.get_weight_norm(
base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1
base_layer.weight,
transpose(delta_weight, self.fan_in_fan_out),
scaling=1,
)
.detach()
)
@ -293,9 +326,16 @@ class LoraParallelLinear(nn.Module, LoraLayer):
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
# different value
self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
new_weight = dora_factor * (base_layer.weight.data + delta_weight)
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight
/ weight_norm
)
dora_factor = transpose(
dora_factor.view(-1, 1), self.fan_in_fan_out
)
new_weight = dora_factor * (
base_layer.weight.data + delta_weight
)
base_layer.weight.data = new_weight
self.merged_adapters.append(active_adapter)
@ -316,7 +356,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
weight.data -= delta_weight
else:
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = (
self.lora_magnitude_vector[active_adapter].weight / weight_norm
)
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight
weight.data = weight_orig
@ -334,7 +376,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
# In case users want to merge the adapter weights that are in
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
cast_to_fp32 = device.type == "cpu" and (
dtype == torch.float16 or dtype == torch.bfloat16
)
weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight
@ -343,7 +387,9 @@ class LoraParallelLinear(nn.Module, LoraLayer):
weight_A = weight_A.float()
weight_B = weight_B.float()
output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
output_tensor = (
transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
)
if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
@ -379,12 +425,17 @@ def dispatch_megatron(
if megatron_core and isinstance(
target_base_layer,
(megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
(
megatron_core.tensor_parallel.ColumnParallelLinear,
megatron_core.tensor_parallel.RowParallelLinear,
),
):
megatron_kwargs = kwargs.copy()
megatron_config = lora_config.megatron_config
if isinstance(megatron_config, dict):
transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
transformer_config_class = (
megatron_core.transformer.transformer_config.TransformerConfig
)
megatron_config = transformer_config_class(**lora_config.megatron_config)
megatron_kwargs["megatron_config"] = megatron_config
if megatron_kwargs["fan_in_fan_out"]:
@ -395,7 +446,10 @@ def dispatch_megatron(
)
megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
new_module = LoraParallelLinear(
base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
base_layer=target,
adapter_name=adapter_name,
backend=megatron_core.tensor_parallel,
**megatron_kwargs,
)
return new_module
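The `get_delta_weight` hunk above computes the LoRA update as `transpose(weight_B @ weight_A, fan_in_fan_out) * scaling`. A standalone sketch of that computation with hypothetical sizes, assuming the default (non-rsLoRA) scaling of `lora_alpha / r`:

import torch

r, d_in, d_out, lora_alpha = 8, 32, 32, 16
weight_A = torch.randn(r, d_in)   # lora_A.weight
weight_B = torch.zeros(d_out, r)  # lora_B.weight is zero-initialized, so the initial delta is zero
scaling = lora_alpha / r
fan_in_fan_out = False            # True for layers that store the weight as (in, out), e.g. Conv1D

delta = weight_B @ weight_A
if fan_in_fan_out:
    delta = delta.T               # what transpose(..., fan_in_fan_out) does above
delta = delta * scaling
print(delta.shape, float(delta.abs().max()))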

View File

@ -1,6 +1,7 @@
import sys
sys.path.insert(0, "./transformers_repo/src/")
sys.path.insert(0, "./peft_repo/src/")
sys.path.insert(0, "transformers_repo/src/")
sys.path.insert(0, "peft_repo/src/")
from dataset_library.factory import get_dataset
@ -50,7 +51,6 @@ if __name__ == "__main__":
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
# model = get_peft_model(model, peft_config)
model.add_adapter(peft_config)
elif model_args.peft_type == "LORA":
@ -58,12 +58,8 @@ if __name__ == "__main__":
peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
# model = get_peft_model(model, peft_config)
model.add_adapter(peft_config)
# if accelerator.is_local_main_process:
# model.print_trainable_parameters()
else:
peft_config = None
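For context, the LORA branch above attaches the adapter through the transformers PEFT integration (`model.add_adapter`) instead of wrapping the model with `get_peft_model`. A minimal sketch of that pattern, assuming a small causal-LM checkpoint purely for illustration (the script itself targets Qwen/Qwen2-VL-7B-Instruct with q_proj/v_proj, as in train.sh below):

from transformers import AutoModelForCausalLM
from peft import LoraConfig

# Hypothetical small checkpoint, used only to keep the sketch lightweight.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
peft_config = LoraConfig(target_modules=["q_proj", "v_proj"])
model.add_adapter(peft_config)  # in-place alternative to: model = get_peft_model(model, peft_config)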

View File

@ -3,7 +3,7 @@
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
--dataset_name CHEM \
--use_peft \
--peft_type MMOELORA \
--peft_type LORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
--lora_target_modules q_proj v_proj \
--per_device_train_batch_size 1 \
@ -12,4 +12,4 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml
--output_dir checkpoint/qwen2mmoe/ \
--bf16 \
--torch_dtype bfloat16 \
--logging_steps 30
--logging_steps 30

View File

@ -17,4 +17,4 @@ class ContinualScriptArguments(ScriptArguments):
class ContinualModelConfig(ModelConfig):
"""Model configuration for continual learning."""
peft_type: Optional[str] = None
peft_type: Optional[str] = None

68
uv.lock generated
View File

@ -196,6 +196,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 },
]
[[package]]
name = "cfgv"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 },
]
[[package]]
name = "charset-normalizer"
version = "3.4.1"
@ -260,6 +269,7 @@ dependencies = [
{ name = "numba" },
{ name = "peft" },
{ name = "pip" },
{ name = "pre-commit" },
{ name = "requests" },
{ name = "rouge-score" },
{ name = "safetensors" },
@ -292,6 +302,7 @@ requires-dist = [
{ name = "numba", specifier = ">=0.60.0" },
{ name = "peft", specifier = "==0.14.0" },
{ name = "pip", specifier = "==24.3.1" },
{ name = "pre-commit", specifier = ">=4.0.1" },
{ name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" },
{ name = "rouge-score", specifier = ">=0.1.2" },
{ name = "safetensors", specifier = ">=0.5.2" },
@ -388,6 +399,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252 },
]
[[package]]
name = "distlib"
version = "0.3.9"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 },
]
[[package]]
name = "einops"
version = "0.8.0"
@ -533,6 +553,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6c/3f/50f6b25fafdcfb1c089187a328c95081abf882309afd86f4053951507cd1/huggingface_hub-0.27.1-py3-none-any.whl", hash = "sha256:1c5155ca7d60b60c2e2fc38cbb3ffb7f7c3adf48f824015b219af9061771daec", size = 450658 },
]
[[package]]
name = "identify"
version = "2.6.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/cf/92/69934b9ef3c31ca2470980423fda3d00f0460ddefdf30a67adf7f17e2e00/identify-2.6.5.tar.gz", hash = "sha256:c10b33f250e5bba374fae86fb57f3adcebf1161bce7cdf92031915fd480c13bc", size = 99213 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/fa/dce098f4cdf7621aa8f7b4f919ce545891f489482f0bfa5102f3eca8608b/identify-2.6.5-py2.py3-none-any.whl", hash = "sha256:14181a47091eb75b337af4c23078c9d09225cd4c48929f521f3bf16b09d02566", size = 99078 },
]
[[package]]
name = "idna"
version = "3.10"
@ -825,6 +854,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 },
]
[[package]]
name = "nodeenv"
version = "1.9.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 },
]
[[package]]
name = "numba"
version = "0.60.0"
@ -1147,6 +1185,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574 },
]
[[package]]
name = "pre-commit"
version = "4.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cfgv" },
{ name = "identify" },
{ name = "nodeenv" },
{ name = "pyyaml" },
{ name = "virtualenv" },
]
sdist = { url = "https://files.pythonhosted.org/packages/2e/c8/e22c292035f1bac8b9f5237a2622305bc0304e776080b246f3df57c4ff9f/pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2", size = 191678 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/16/8f/496e10d51edd6671ebe0432e33ff800aa86775d2d147ce7d43389324a525/pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878", size = 218713 },
]
[[package]]
name = "propcache"
version = "0.2.1"
@ -1849,6 +1903,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
]
[[package]]
name = "virtualenv"
version = "20.29.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "distlib" },
{ name = "filelock" },
{ name = "platformdirs" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a7/ca/f23dcb02e161a9bba141b1c08aa50e8da6ea25e6d780528f1d385a3efe25/virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35", size = 7658028 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/89/9b/599bcfc7064fbe5740919e78c5df18e5dceb0887e676256a1061bb5ae232/virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779", size = 4282379 },
]
[[package]]
name = "wheel"
version = "0.45.1"