diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..2301b08
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
+  - repo: https://github.com/psf/black-pre-commit-mirror
+    rev: 24.8.0
+    hooks:
+      - id: black
+        # It is recommended to specify the latest version of Python
+        # supported by your project here, or alternatively use
+        # pre-commit's default_language_version, see
+        # https://pre-commit.com/#top_level-default_language_version
+        language_version: python3.11
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 131d5c0..5b87ceb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,7 @@ dependencies = [
     "numba>=0.60.0",
     "peft==0.14.0",
     "pip==24.3.1",
+    "pre-commit>=4.0.1",
     "requests==2.32.3",
     "rouge-score>=0.1.2",
     "safetensors>=0.5.2",
@@ -30,16 +31,14 @@ requires-python = ">=3.11"
 version = "0.1.0"
 
 [project.optional-dependencies]
-compile = [
-    "flash-attn>=2.7.2.post1",
-]
+compile = ["flash-attn>=2.7.2.post1"]
 
 [tool.uv.sources]
-markupsafe = {index = "pytorch"}
-requests = {index = "pypi"}
-torch = {index = "pytorch"}
-torchaudio = {index = "pytorch"}
-torchvision = {index = "pytorch"}
+markupsafe = { index = "pytorch" }
+requests = { index = "pypi" }
+torch = { index = "pytorch" }
+torchaudio = { index = "pytorch" }
+torchvision = { index = "pytorch" }
 
 [[tool.uv.index]]
 name = "pypi"
@@ -52,3 +51,7 @@ concurrent-builds = 4
 [[tool.uv.index]]
 name = "pytorch"
 url = "https://download.pytorch.org/whl/cu124"
+
+[tool.black]
+line-length = 88
+exclude = "transformers_repo|peft_repo|.venv"
diff --git a/src/dataset_library/CHEM.py b/src/dataset_library/CHEM.py
index 357f584..92b5e9b 100644
--- a/src/dataset_library/CHEM.py
+++ b/src/dataset_library/CHEM.py
@@ -25,9 +25,9 @@ class CHEMDataset(Dataset):
 
     def _vis_processor(self, image: Image.Image):
         width, height = image.size
-        if width > 600 or height > 600:
+        if width > 800 or height > 800:
             max_size = max(width, height)
-            ratio = 600 / max_size
+            ratio = 800 / max_size
             new_width = int(width * ratio)
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
@@ -69,7 +69,7 @@ class CHEMDataset(Dataset):
         return processed_data
 
     def __len__(self):
-        return len(self.data) // 60
+        return len(self.data)
 
     def __getitem__(self, index):
         sample = self.data[index]
@@ -147,12 +147,14 @@ class CHEMDatasetForGeneration(CHEMDataset):
                 ],
             },
         ]
-        return {
-            "image": image,
-            "chat": chat,
-            "answer": answer,
-            "original": sample["original"],
-        }
+        from .format import create_generate
+
+        return create_generate(
+            images=[image],
+            chat=chat,
+            answer=answer,
+            original=sample["original"],
+        )
 
 
 if __name__ == "__main__":
diff --git a/src/dataset_library/GigaspeechDataset.py b/src/dataset_library/GigaspeechDataset.py
index e03be1b..d92fd1b 100644
--- a/src/dataset_library/GigaspeechDataset.py
+++ b/src/dataset_library/GigaspeechDataset.py
@@ -75,12 +75,14 @@ class GigaspeechDatasetForGeneration(GigaspeechDataset):
                 ],
             },
         ]
+        from .format import create_generate
 
-        return {
-            "audio": (audio, sampling_rate),
-            "chat": chat,
-            "answer": text,
-        }
+        return create_generate(
+            audio=[(audio, sampling_rate)],
+            chat=chat,
+            answer=text,
+            original=sample,
+        )
 
 
 if __name__ == "__main__":
diff --git a/src/evaluation.py b/src/evaluation.py
index 1a0a16f..5898e51 100644
--- a/src/evaluation.py
+++ b/src/evaluation.py
@@ -1,4 +1,5 @@
 import sys
+
sys.path.insert(0, "./transformers_repo/src/") sys.path.insert(0, "./peft_repo/src/") diff --git a/src/peft_library/__init__.py b/src/peft_library/__init__.py index b32e81d..0a9f545 100644 --- a/src/peft_library/__init__.py +++ b/src/peft_library/__init__.py @@ -1 +1 @@ -from .mapping import get_peft_config, get_peft_model, inject_adapter_in_model \ No newline at end of file +from .mapping import get_peft_config, get_peft_model, inject_adapter_in_model diff --git a/src/peft_library/tuners/lora/aqlm.py b/src/peft_library/tuners/lora/aqlm.py index 1715e62..a9709a3 100644 --- a/src/peft_library/tuners/lora/aqlm.py +++ b/src/peft_library/tuners/lora/aqlm.py @@ -40,7 +40,9 @@ class AqlmLoraLinear(torch.nn.Module, LoraLayer): **kwargs, ): if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + raise ValueError( + f"{self.__class__.__name__} does not support DoRA yet, please set it to False" + ) super().__init__() LoraLayer.__init__(self, base_layer) diff --git a/src/peft_library/tuners/lora/awq.py b/src/peft_library/tuners/lora/awq.py index 86989d9..afa2c39 100644 --- a/src/peft_library/tuners/lora/awq.py +++ b/src/peft_library/tuners/lora/awq.py @@ -37,7 +37,9 @@ class AwqLoraLinear(torch.nn.Module, LoraLayer): **kwargs, ): if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + raise ValueError( + f"{self.__class__.__name__} does not support DoRA yet, please set it to False" + ) super().__init__() LoraLayer.__init__(self, base_layer) @@ -107,7 +109,9 @@ def dispatch_awq( if isinstance(target_base_layer, WQLinear_GEMM): # Raise the error only at the dispatch level AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0") - version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) + version_autoawq = packaging.version.parse( + importlib_metadata.version("autoawq") + ) if AUTOAWQ_MINIMUM_VERSION > version_autoawq: raise ImportError( diff --git a/src/peft_library/tuners/lora/bnb.py b/src/peft_library/tuners/lora/bnb.py index fbb1c71..38024d3 100644 --- a/src/peft_library/tuners/lora/bnb.py +++ b/src/peft_library/tuners/lora/bnb.py @@ -60,7 +60,9 @@ if is_bnb_available(): lora_bias=lora_bias, ) - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: """ Merge the active adapter weights into the base weights @@ -109,7 +111,9 @@ if is_bnb_available(): # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight / weight_norm + ) w_data = dora_factor.view(-1, 1) * (output + lora_data) if safe_merge and not torch.isfinite(w_data).all(): @@ -118,10 +122,15 @@ if is_bnb_available(): ) self.get_base_layer().weight = bnb.nn.Int8Params( - w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + w_data.to("cpu"), + requires_grad=False, + has_fp16_weights=weight.has_fp16_weights, ).to(weight.device) if self.lora_bias[active_adapter]: - bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias + bias_data = ( + self.get_base_layer().bias.data + + self.lora_B[active_adapter].bias + ) if safe_merge and not torch.isfinite(bias_data): 
raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" @@ -158,11 +167,15 @@ if is_bnb_available(): w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data else: weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight / weight_norm + ) w_data = output.data / dora_factor.view(-1, 1) - lora_data self.get_base_layer().weight = bnb.nn.Int8Params( - w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + w_data.to("cpu"), + requires_grad=False, + has_fp16_weights=weight.has_fp16_weights, ).to(weight.device) if self.lora_bias[active_adapter]: @@ -188,7 +201,13 @@ if is_bnb_available(): unique_adapters = set(adapter_names) sub_batch_indices_list = [] for adapter in unique_adapters: - sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + sub_batch_indices_list.append( + [ + index + for index, item in enumerate(adapter_names) + if item == adapter + ] + ) for i, active_adapter in enumerate(unique_adapters): if active_adapter == "__base__": @@ -227,7 +246,9 @@ if is_bnb_available(): self.unmerge() result = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: - result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + result = self._mixed_batch_forward( + x, *args, adapter_names=adapter_names, **kwargs + ) elif self.merged: result = self.base_layer(x, *args, **kwargs) else: @@ -330,7 +351,9 @@ if is_bnb_4bit_available(): lora_bias=lora_bias, ) - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: """ Merge the active adapter weights into the base weights @@ -375,7 +398,9 @@ if is_bnb_4bit_available(): # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight / weight_norm + ) w_data = dora_factor.view(-1, 1) * (output + lora_data) if safe_merge and not torch.isfinite(w_data).all(): @@ -386,9 +411,14 @@ if is_bnb_4bit_available(): kwargs["bnb_quantized"] = False kwargs["requires_grad"] = False kwargs.pop("data", None) - self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device) + self.get_base_layer().weight = bnb.nn.Params4bit( + w_data.to("cpu"), **kwargs + ).to(weight.device) if self.lora_bias[active_adapter]: - bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias + bias_data = ( + self.get_base_layer().bias.data + + self.lora_B[active_adapter].bias + ) if safe_merge and not torch.isfinite(bias_data): raise ValueError( f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" @@ -422,14 +452,18 @@ if is_bnb_4bit_available(): w_data = output - lora_data else: weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight / weight_norm + ) w_data = output.data / dora_factor.view(-1, 1) - lora_data if "bnb_quantized" in kwargs: kwargs["bnb_quantized"] = False kwargs["requires_grad"] = False kwargs.pop("data", None) - self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device) + self.get_base_layer().weight = bnb.nn.Params4bit( + w_data.to("cpu"), **kwargs + ).to(weight.device) if self.lora_bias[active_adapter]: self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias @@ -452,7 +486,13 @@ if is_bnb_4bit_available(): unique_adapters = set(adapter_names) sub_batch_indices_list = [] for adapter in unique_adapters: - sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + sub_batch_indices_list.append( + [ + index + for index, item in enumerate(adapter_names) + if item == adapter + ] + ) for i, active_adapter in enumerate(unique_adapters): if active_adapter == "__base__": @@ -489,7 +529,9 @@ if is_bnb_4bit_available(): self.unmerge() result = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: - result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + result = self._mixed_batch_forward( + x, *args, adapter_names=adapter_names, **kwargs + ) elif self.merged: result = self.base_layer(x, *args, **kwargs) else: @@ -550,7 +592,11 @@ if is_bnb_4bit_available(): target_base_layer = target loaded_in_4bit = kwargs.get("loaded_in_4bit", False) - if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): + if ( + loaded_in_4bit + and is_bnb_4bit_available() + and isinstance(target_base_layer, bnb.nn.Linear4bit) + ): fourbit_kwargs = kwargs.copy() fourbit_kwargs.update( { diff --git a/src/peft_library/tuners/lora/config.py b/src/peft_library/tuners/lora/config.py index e11ad7f..955e203 100644 --- a/src/peft_library/tuners/lora/config.py +++ b/src/peft_library/tuners/lora/config.py @@ -66,7 +66,9 @@ class LoftQConfig: """ loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"}) - loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"}) + loftq_iter: int = field( + default=1, metadata={"help": "Alternating iterations for LoftQ"} + ) @dataclass @@ -101,13 +103,25 @@ class EvaConfig: are adjusted so that all LoRA gradients have the same scale regardless of their rank. Default is True. 
""" - rho: float = field(default=2.0, metadata={"help": "Rho value for EVA redistribution"}) - tau: float = field(default=0.99, metadata={"help": "Cosine similarity threshold for early stopping"}) - use_label_mask: bool = field(default=True, metadata={"help": "Use label mask for EVA initialization"}) - label_mask_value: int = field( - default=-100, metadata={"help": "if use_label_mask=True the value to look for to mask out ignored tokens"} + rho: float = field( + default=2.0, metadata={"help": "Rho value for EVA redistribution"} + ) + tau: float = field( + default=0.99, + metadata={"help": "Cosine similarity threshold for early stopping"}, + ) + use_label_mask: bool = field( + default=True, metadata={"help": "Use label mask for EVA initialization"} + ) + label_mask_value: int = field( + default=-100, + metadata={ + "help": "if use_label_mask=True the value to look for to mask out ignored tokens" + }, + ) + whiten: bool = field( + default=False, metadata={"help": "Apply whitening to singular vectors"} ) - whiten: bool = field(default=False, metadata={"help": "Apply whitening to singular vectors"}) adjust_scaling_factors: bool = field( default=True, metadata={"help": "Adjust LoRA scaling factors after the rank redistribution"}, @@ -233,16 +247,21 @@ class LoraConfig(PeftConfig): ) exclude_modules: Optional[Union[list[str], str]] = field( default=None, - metadata={"help": "List of module names or regex expression of the module names to exclude from Lora."}, + metadata={ + "help": "List of module names or regex expression of the module names to exclude from Lora." + }, ) lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"}) lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"}) fan_in_fan_out: bool = field( default=False, - metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + metadata={ + "help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)" + }, ) bias: Literal["none", "all", "lora_only"] = field( - default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"} + default="none", + metadata={"help": "Bias type for Lora. 
Can be 'none', 'all' or 'lora_only'"}, ) use_rslora: bool = field( default=False, @@ -264,7 +283,15 @@ class LoraConfig(PeftConfig): }, ) init_lora_weights: ( - bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"] + bool + | Literal[ + "gaussian", + "eva", + "olora", + "pissa", + "pissa_niter_[number of iters]", + "loftq", + ] ) = field( default=True, metadata={ @@ -418,47 +445,72 @@ class LoraConfig(PeftConfig): super().__post_init__() self.peft_type = PeftType.LORA self.target_modules = ( - set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + set(self.target_modules) + if isinstance(self.target_modules, list) + else self.target_modules ) self.exclude_modules = ( - set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules + set(self.exclude_modules) + if isinstance(self.exclude_modules, list) + else self.exclude_modules ) # if target_modules is a regex expression, then layers_to_transform should be None - if isinstance(self.target_modules, str) and self.layers_to_transform is not None: - raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + if ( + isinstance(self.target_modules, str) + and self.layers_to_transform is not None + ): + raise ValueError( + "`layers_to_transform` cannot be used when `target_modules` is a str." + ) # if target_modules is a regex expression, then layers_pattern should be None if isinstance(self.target_modules, str) and self.layers_pattern is not None: - raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") + raise ValueError( + "`layers_pattern` cannot be used when `target_modules` is a str." + ) # check for layers_to_transform and layers_pattern if self.layers_pattern and not self.layers_to_transform: - raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ") + raise ValueError( + "When `layers_pattern` is specified, `layers_to_transform` must also be specified. " + ) if self.use_dora and self.megatron_config: - raise ValueError("DoRA does not support megatron_core, please set `use_dora=False`.") + raise ValueError( + "DoRA does not support megatron_core, please set `use_dora=False`." + ) # handle init_lora_weights and loftq_config if self.init_lora_weights == "loftq": import importlib if not importlib.util.find_spec("scipy"): - raise ImportError("The required package 'scipy' is not installed. Please install it to continue.") + raise ImportError( + "The required package 'scipy' is not installed. Please install it to continue." + ) if not self.loftq_config: - raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.") + raise ValueError( + "`loftq_config` must be specified when `init_lora_weights` is 'loftq'." + ) if not isinstance(self.loftq_config, dict): # convert loftq_config to dict self.loftq_config = vars(self.loftq_config) elif self.loftq_config: self.loftq_config = {} - warnings.warn("`loftq_config` specified but will be ignored when `init_lora_weights` is not 'loftq'.") + warnings.warn( + "`loftq_config` specified but will be ignored when `init_lora_weights` is not 'loftq'." + ) elif self.init_lora_weights == "eva" and self.eva_config is None: - warnings.warn("`init_lora_weights` is 'eva' but `eva_config` is not specified. Using default EVA config.") + warnings.warn( + "`init_lora_weights` is 'eva' but `eva_config` is not specified. Using default EVA config." 
+ ) self.eva_config = EvaConfig() elif self.init_lora_weights != "eva" and self.eva_config is not None: - warnings.warn("`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'.") + warnings.warn( + "`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'." + ) if self.lora_bias: if self.init_lora_weights not in (True, False): @@ -467,7 +519,9 @@ class LoraConfig(PeftConfig): f"init_lora_weights={self.init_lora_weights} instead." ) if self.use_dora: - raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False") + raise ValueError( + "The argument lora_bias=True is not supported for DoRA, please pass use_dora=False" + ) # Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot # be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends @@ -477,7 +531,10 @@ class LoraConfig(PeftConfig): self.use_rslora and (self.rank_pattern or self.alpha_pattern) and ( - (isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa"))) + ( + isinstance(self.init_lora_weights, str) + and (self.init_lora_weights.startswith("pissa")) + ) or (self.init_lora_weights == "olora") ) ): @@ -491,7 +548,9 @@ class LoraConfig(PeftConfig): self._custom_modules: Optional[dict[type[nn.Mmodule], type[nn.Module]]] = None - def _register_custom_module(self, mapping: dict[type[nn.Mmodule], type[nn.Module]]) -> None: + def _register_custom_module( + self, mapping: dict[type[nn.Mmodule], type[nn.Module]] + ) -> None: """ Experimental API to support providing custom LoRA layers. diff --git a/src/peft_library/tuners/lora/dora.py b/src/peft_library/tuners/lora/dora.py index 9d8cd9a..1c2a256 100644 --- a/src/peft_library/tuners/lora/dora.py +++ b/src/peft_library/tuners/lora/dora.py @@ -34,7 +34,9 @@ class DoraLinearLayer(nn.Module): weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype) return weight_norm - def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None: + def update_layer( + self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False + ) -> None: # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2 dtype_is_fp16 = lora_A.dtype == torch.float16 if dtype_is_fp16: @@ -49,14 +51,18 @@ class DoraLinearLayer(nn.Module): weight = dequantize_module_weight(base_layer) if weight.data.ndim >= 4: # For handling LoRAs applied to Conv layers. - lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1)) + lora_weight = torch.mm( + lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1) + ) lora_weight = lora_weight.reshape(weight.shape) else: lora_weight = lora_B @ lora_A if dtype_is_fp16: lora_weight = lora_weight.half() - weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling) + weight_norm = self.get_weight_norm( + weight.to(lora_A.device), lora_weight, scaling + ) if place_on_cpu: weight_norm = weight_norm.to("cpu") @@ -69,7 +75,9 @@ class DoraLinearLayer(nn.Module): """ # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead, # calculate the same but using forward. 
- x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype) + x_eye = torch.eye( + lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype + ) lora_weight = lora_B(lora_A(x_eye)).T magnitude = self.weight @@ -95,7 +103,9 @@ class DoraLinearLayer(nn.Module): else: base_result = F.linear(x, transpose(weight, self.fan_in_fan_out)) - result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling + result_dora = ( + mag_norm_scale - 1 + ) * base_result + mag_norm_scale * lora_result * scaling return result_dora @@ -145,7 +155,9 @@ class _DoraConvNdLayer(DoraLinearLayer): output. """ weight = base_layer.weight - lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = torch.mm( + lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1) + ) lora_weight = lora_weight.reshape(weight.shape) magnitude = self.weight weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling) diff --git a/src/peft_library/tuners/lora/eetq.py b/src/peft_library/tuners/lora/eetq.py index d1eb79b..ef10297 100644 --- a/src/peft_library/tuners/lora/eetq.py +++ b/src/peft_library/tuners/lora/eetq.py @@ -38,7 +38,9 @@ if is_eetq_available(): **kwargs, ): if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + raise ValueError( + f"{self.__class__.__name__} does not support DoRA yet, please set it to False" + ) super().__init__() LoraLayer.__init__(self, base_layer) @@ -85,11 +87,17 @@ if is_eetq_available(): result = result + output return result - def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: - raise AttributeError("Merging LoRA layers is not supported for Eetq layers.") + def merge( + self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None + ) -> None: + raise AttributeError( + "Merging LoRA layers is not supported for Eetq layers." + ) def unmerge(self) -> None: - raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.") + raise AttributeError( + "Unmerging LoRA layers is not supported for Eetq layers." 
+ ) def __repr__(self) -> str: rep = super().__repr__() diff --git a/src/peft_library/tuners/lora/eva.py b/src/peft_library/tuners/lora/eva.py index 8512dfd..8f7418a 100644 --- a/src/peft_library/tuners/lora/eva.py +++ b/src/peft_library/tuners/lora/eva.py @@ -26,7 +26,10 @@ import torch.distributed as dist from tqdm import tqdm from transformers.pytorch_utils import Conv1D -from peft.tuners.tuners_utils import _find_minimal_target_modules, check_target_module_exists +from peft.tuners.tuners_utils import ( + _find_minimal_target_modules, + check_target_module_exists, +) from peft.utils.constants import MIN_TARGET_MODULES_FOR_OPTIMIZATION from peft.utils.incremental_pca import IncrementalPCA from peft.utils.other import _get_submodules, get_pattern_key @@ -58,7 +61,9 @@ class _Hook: self.model_input = None @staticmethod - def _prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> torch.Tensor: + def _prepare_layer_inputs_fn_default( + layer_input, model_input, layer_name + ) -> torch.Tensor: if isinstance(layer_input, torch.Tensor): pass elif isinstance(layer_input, (tuple, list)): @@ -83,20 +88,28 @@ class _Hook: # First gather sizes from all processes more efficiently local_size = torch.tensor([layer_input.shape[0]], device=layer_input.device) - all_sizes = torch.empty(world_size, dtype=local_size.dtype, device=layer_input.device) + all_sizes = torch.empty( + world_size, dtype=local_size.dtype, device=layer_input.device + ) dist.all_gather_into_tensor(all_sizes, local_size) all_sizes = all_sizes.tolist() # Find maximum size and pad tensors - padded_input = layer_input.new_zeros((max(all_sizes), *layer_input.shape[1:])) + padded_input = layer_input.new_zeros( + (max(all_sizes), *layer_input.shape[1:]) + ) padded_input[: layer_input.shape[0]] = layer_input # Gather padded tensors - gathered_inputs = [torch.zeros_like(padded_input) for _ in range(world_size)] + gathered_inputs = [ + torch.zeros_like(padded_input) for _ in range(world_size) + ] dist.all_gather(gathered_inputs, padded_input.contiguous()) # Remove padding for each gathered tensor - gathered_inputs = [tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes)] + gathered_inputs = [ + tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes) + ] # Concatenate along batch dimension return torch.cat(gathered_inputs, dim=0) @@ -153,7 +166,9 @@ class SVDHook(_Hook): states = self.gather_layer_inputs(states) # check if batch sizes is more than the number of components if states.size(0) < self.n_components: - print(f"skipping SVD for {self.name} because there are less than {self.n_components} examples") + print( + f"skipping SVD for {self.name} because there are less than {self.n_components} examples" + ) return self.svd.partial_fit(states.to(torch.float32)) # add if statement to check if we are in the first step where previous_components is None @@ -218,7 +233,9 @@ def get_device_with_meta_params(model: torch.nn.Module) -> torch.device: """ devices = list({p.device for p in model.parameters() if p.device.type != "meta"}) if len(devices) > 1: - warnings.warn(f"Could not determine device, model has multiple devices: {devices}") + warnings.warn( + f"Could not determine device, model has multiple devices: {devices}" + ) return return devices[0] @@ -230,11 +247,15 @@ def move_inputs_to_device(inputs, device: Union[str, torch.device]): if hasattr(inputs, "to"): return inputs.to(device) if isinstance(inputs, Mapping): - return type(inputs)({k: move_inputs_to_device(v, device) for k, v in inputs.items()}) + 
return type(inputs)( + {k: move_inputs_to_device(v, device) for k, v in inputs.items()} + ) elif isinstance(inputs, (tuple, list)): return type(inputs)(move_inputs_to_device(v, device) for v in inputs) else: - warnings.warn(f"input of type {type(inputs)} could not be moved to the correct device") + warnings.warn( + f"input of type {type(inputs)} could not be moved to the correct device" + ) return inputs @@ -247,14 +268,22 @@ def prepare_model_inputs_fn_language_modeling(model_input, peft_config: LoraConf peft_config (LoraConfig): The configuration for the LoRA layers. """ if not isinstance(model_input, dict): - raise ValueError("When using `prepare_model_inputs_fn_language_modeling` inputs must be a dictionary") - mask = model_input.get("attention_mask", torch.ones_like(model_input["input_ids"])).bool() + raise ValueError( + "When using `prepare_model_inputs_fn_language_modeling` inputs must be a dictionary" + ) + mask = model_input.get( + "attention_mask", torch.ones_like(model_input["input_ids"]) + ).bool() if peft_config.eva_config.use_label_mask and hasattr(model_input, "labels"): - mask = torch.logical_and(mask, model_input["labels"] != peft_config.eva_config.label_mask_value) + mask = torch.logical_and( + mask, model_input["labels"] != peft_config.eva_config.label_mask_value + ) return mask.nonzero() -def prepare_layer_inputs_fn_language_modeling(layer_input, model_input, layer_name) -> torch.Tensor: +def prepare_layer_inputs_fn_language_modeling( + layer_input, model_input, layer_name +) -> torch.Tensor: """ if not all items in the input should be used for SVD, this function can be used to get the indices of the items that should be used. @@ -299,12 +328,21 @@ def _get_eva_state_dict( ) -> dict: # Computes the rank distribution for each layer based on the explained variance ratio. 
# when rank_pattern flag is False, all values in max_components are the same - def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components): - exp_vars = {k: h[0].svd.explained_variance_ratio_[: max_components[k]] for k, h in hooks.items()} - keys, values = zip(*[(k, c) for k, name in layer_hook_map.items() for c in exp_vars[name]]) + def _get_rank_distribution( + hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components + ): + exp_vars = { + k: h[0].svd.explained_variance_ratio_[: max_components[k]] + for k, h in hooks.items() + } + keys, values = zip( + *[(k, c) for k, name in layer_hook_map.items() for c in exp_vars[name]] + ) idx = torch.stack(values).argsort(descending=True) counts = Counter([keys[i] for i in idx[:rank_budget]]) - counts = {k: counts.get(k, 0) for k in layer_hook_map.keys()} # add layers with 0 rank + counts = { + k: counts.get(k, 0) for k in layer_hook_map.keys() + } # add layers with 0 rank for k, k_hook in equal_inputs_map.items(): # ensure hook layers have the highest rank if they are equal to another layer rank, rank_hook = counts[k], counts[k_hook] @@ -356,7 +394,11 @@ def _get_eva_state_dict( fn = prepare_layer_inputs_fn.pop(name, None) else: fn = prepare_layer_inputs_fn - hook = HashHook(name=name, prepare_layer_inputs_fn=fn, gather_distributed_inputs=gather_distributed_inputs) + hook = HashHook( + name=name, + prepare_layer_inputs_fn=fn, + gather_distributed_inputs=gather_distributed_inputs, + ) hook.model_input = model_inputs_for_hooks handle = module.register_forward_hook(hook) hooks[name] = (hook, handle) @@ -365,7 +407,10 @@ def _get_eva_state_dict( ) max_components[name] = round(layer_rank * rho) rank_budget += layer_rank - if isinstance(prepare_layer_inputs_fn, Mapping) and len(prepare_layer_inputs_fn) > 0: + if ( + isinstance(prepare_layer_inputs_fn, Mapping) + and len(prepare_layer_inputs_fn) > 0 + ): raise ValueError( "prepare_layer_inputs_fn is a mapping but the following module names were not found in the model: " f"{prepare_layer_inputs_fn.keys()}" @@ -399,7 +444,10 @@ def _get_eva_state_dict( ) module = model.get_submodule(name) handle = module.register_forward_hook(hook) - hooks[name] = (hook, handle) # adding the old handle here so we dont get errors in the first forward pass + hooks[name] = ( + hook, + handle, + ) # adding the old handle here so we dont get errors in the first forward pass layer_hook_map = {**dict(zip(hooks.keys(), hooks.keys())), **equal_inputs_map} # start svd calculation @@ -441,7 +489,9 @@ def _get_eva_state_dict( layer_converged = list(convergence_dict.values()) + [ convergence_dict[v] for v in equal_inputs_map.values() ] - pbar.set_description(f"{sum(layer_converged)}/{len(layer_converged)} layers have converged") + pbar.set_description( + f"{sum(layer_converged)}/{len(layer_converged)} layers have converged" + ) if all(convergence_dict.values()): break @@ -453,10 +503,17 @@ def _get_eva_state_dict( if not all(hasattr(h[0].svd, "components_") for h in hooks.values()): continue - rank_dist = _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components) + rank_dist = _get_rank_distribution( + hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components + ) # check all custom hooks have been removed - remaining_hooks = {n for n, m in model.named_modules() for v in m._forward_hooks.values() if isinstance(v, _Hook)} + remaining_hooks = { + n + for n, m in model.named_modules() + for v in m._forward_hooks.values() + if isinstance(v, _Hook) + 
} if len(remaining_hooks) > 0: raise ValueError( f"Found active hooks added by EVA that weren't properly removed: {remaining_hooks}. " @@ -510,9 +567,12 @@ def _load_eva_state_dict( other_module_names.append(name_in_base_model) continue # Regexp matching - Find key which matches current target_name in patterns provided - r = peft_config.rank_pattern.get(get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r) + r = peft_config.rank_pattern.get( + get_pattern_key(peft_config.rank_pattern.keys(), name), peft_config.r + ) alpha = peft_config.alpha_pattern.get( - get_pattern_key(peft_config.alpha_pattern.keys(), name), peft_config.lora_alpha + get_pattern_key(peft_config.alpha_pattern.keys(), name), + peft_config.lora_alpha, ) if name in eva_state_dict: w = eva_state_dict.pop(name) @@ -524,12 +584,22 @@ def _load_eva_state_dict( elif new_rank != r: if peft_config.eva_config.adjust_scaling_factors: alpha *= new_rank / r - if new_rank != r or module.lora_A[adapter_name].weight.device.type == "meta": - module.update_layer(r=new_rank, lora_alpha=alpha, init_lora_weights="eva", **update_layer_kwargs) + if ( + new_rank != r + or module.lora_A[adapter_name].weight.device.type == "meta" + ): + module.update_layer( + r=new_rank, + lora_alpha=alpha, + init_lora_weights="eva", + **update_layer_kwargs, + ) module.lora_A[adapter_name].weight.copy_(w) new_target_modules.append(name_in_base_model) else: - module.update_layer(r=r, lora_alpha=alpha, init_lora_weights=True, **update_layer_kwargs) + module.update_layer( + r=r, lora_alpha=alpha, init_lora_weights=True, **update_layer_kwargs + ) missing_eva_inits.append(name_in_base_model) new_rank = r # update rank pattern and alpha pattern @@ -541,7 +611,9 @@ def _load_eva_state_dict( # update target modules if some lora layers have been removed due to their EVA rank being 0 new_target_modules = new_target_modules + missing_eva_inits if len(new_target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION: - new_target_modules = _find_minimal_target_modules(new_target_modules, other_module_names) + new_target_modules = _find_minimal_target_modules( + new_target_modules, other_module_names + ) model.peft_config[adapter_name].target_modules = new_target_modules # set rank pattern obtained from EVA @@ -564,8 +636,12 @@ def get_eva_state_dict( dataloader: Iterable, peft_config: Optional[LoraConfig] = None, forward_fn: Optional[callable] = forward_fn_dict, - prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling, - prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling, + prepare_model_inputs_fn: Optional[ + callable + ] = prepare_model_inputs_fn_language_modeling, + prepare_layer_inputs_fn: Union[ + callable, Dict[str, callable], None + ] = prepare_layer_inputs_fn_language_modeling, adapter_name: str = "default", gather_distributed_inputs: bool = True, show_progress_bar: bool = True, @@ -613,7 +689,9 @@ def get_eva_state_dict( def target_module_check_fn_peft_model(name, module, unsupported_lora_modules): "check if a module is an adapter module via base_layer attribute" - return hasattr(module, "base_layer") and not isinstance(module, unsupported_lora_modules) + return hasattr(module, "base_layer") and not isinstance( + module, unsupported_lora_modules + ) def target_module_check_fn_default(name, module, peft_config): "check if a module is an adapter module via target_modules" @@ -635,11 +713,14 @@ def get_eva_state_dict( if is_peft_model: ctx = model.disable_adapter() 
target_module_check_fn = partial( - target_module_check_fn_peft_model, unsupported_lora_modules=UNSUPPORTED_LORA_MODULES + target_module_check_fn_peft_model, + unsupported_lora_modules=UNSUPPORTED_LORA_MODULES, ) else: ctx = nullcontext() - target_module_check_fn = partial(target_module_check_fn_default, peft_config=peft_config) + target_module_check_fn = partial( + target_module_check_fn_default, peft_config=peft_config + ) with ctx: eva_state_dict = _get_eva_state_dict( @@ -662,8 +743,12 @@ def initialize_lora_eva_weights( dataloader: Optional[Iterable] = None, eva_state_dict: Optional[dict] = None, forward_fn: Optional[callable] = forward_fn_dict, - prepare_model_inputs_fn: Optional[callable] = prepare_model_inputs_fn_language_modeling, - prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None] = prepare_layer_inputs_fn_language_modeling, + prepare_model_inputs_fn: Optional[ + callable + ] = prepare_model_inputs_fn_language_modeling, + prepare_layer_inputs_fn: Union[ + callable, Dict[str, callable], None + ] = prepare_layer_inputs_fn_language_modeling, adapter_name: str = "default", gather_distributed_inputs: bool = True, show_progress_bar: bool = True, @@ -715,11 +800,15 @@ def initialize_lora_eva_weights( # eva currently only works with a single active adapter # Important: when removing this requirement, make sure eva init works correctly if the new rank is 0. if len(model.active_adapters) > 1: - raise ValueError("`initialize_lora_eva_weights` currently only works with a single active adapter") + raise ValueError( + "`initialize_lora_eva_weights` currently only works with a single active adapter" + ) # initialize_lora_eva_weights only works with `init_lora_weights='eva'` if model.peft_config[adapter_name].init_lora_weights != "eva": - raise ValueError("`initialize_lora_eva_weights` can only be used with `init_lora_weights='eva'`") + raise ValueError( + "`initialize_lora_eva_weights` can only be used with `init_lora_weights='eva'`" + ) # compute svd if eva_state_dict is None: diff --git a/src/peft_library/tuners/lora/gptq.py b/src/peft_library/tuners/lora/gptq.py index d482601..859109d 100644 --- a/src/peft_library/tuners/lora/gptq.py +++ b/src/peft_library/tuners/lora/gptq.py @@ -39,7 +39,9 @@ class QuantLinear(torch.nn.Module, LoraLayer): LoraLayer.__init__(self, base_layer) if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + raise ValueError( + f"{self.__class__.__name__} does not support DoRA yet, please set it to False" + ) # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter # for backwards compatibility @@ -109,7 +111,9 @@ def dispatch_gptq( gptq_quantization_config = kwargs.get("gptq_quantization_config", None) AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) - if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear): + if AutoGPTQQuantLinear is not None and isinstance( + target_base_layer, AutoGPTQQuantLinear + ): new_module = QuantLinear(target, adapter_name, **kwargs) target.qweight = target_base_layer.qweight diff --git a/src/peft_library/tuners/lora/hqq.py b/src/peft_library/tuners/lora/hqq.py index d623f7a..20f2c22 100644 --- a/src/peft_library/tuners/lora/hqq.py +++ b/src/peft_library/tuners/lora/hqq.py @@ -45,7 +45,9 @@ if is_hqq_available(): **kwargs, ) -> None: if lora_bias: - raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False") + raise 
ValueError( + f"{self.__class__.__name__} does not support lora_bias yet, set it to False" + ) super().__init__() LoraLayer.__init__(self, base_layer) @@ -63,7 +65,9 @@ if is_hqq_available(): lora_bias=lora_bias, ) - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: """ Merge the active adapter weights into the base weights @@ -86,7 +90,10 @@ if is_hqq_available(): continue layer = self.get_base_layer() - quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta} + quant_config = { + **copy.deepcopy(layer.quant_config), + "offload_meta": layer.offload_meta, + } lora_data = self.get_delta_weight(active_adapter) output = layer.dequantize() @@ -95,19 +102,28 @@ if is_hqq_available(): else: # handle dora # since output already includes scaling, set it to 1 here - weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach() + weight_norm = self._get_weight_norm( + output, lora_data, scaling=1 + ).detach() # We need to cache weight_norm because it has to be based on the original weights. We # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter] / weight_norm + ) w_data = dora_factor.view(-1, 1) * (output + lora_data) if safe_merge and not torch.isfinite(w_data).all(): raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device) + new_hqq_layer = HQQLinear( + None, + quant_config, + compute_dtype=layer.compute_dtype, + device=layer.device, + ) quant_config.pop("offload_meta", None) new_hqq_layer.quantize(w_data, **quant_config) self.base_layer = new_hqq_layer @@ -128,17 +144,27 @@ if is_hqq_available(): lora_data = self.get_delta_weight(active_adapter) layer = self.get_base_layer() - quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta} + quant_config = { + **copy.deepcopy(layer.quant_config), + "offload_meta": layer.offload_meta, + } output = layer.dequantize() if not self.use_dora[active_adapter]: w_data = output - lora_data else: weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") - dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter] / weight_norm + ) w_data = output.data / dora_factor.view(-1, 1) - lora_data - new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device) + new_hqq_layer = HQQLinear( + None, + quant_config, + compute_dtype=layer.compute_dtype, + device=layer.device, + ) quant_config.pop("offload_meta", None) new_hqq_layer.quantize(w_data, **quant_config) self.base_layer = new_hqq_layer @@ -162,7 +188,13 @@ if is_hqq_available(): unique_adapters = set(adapter_names) sub_batch_indices_list = [] for adapter in unique_adapters: - sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + sub_batch_indices_list.append( + [ + index + for index, item in enumerate(adapter_names) + if item == adapter + ] + ) for i, active_adapter in enumerate(unique_adapters): if active_adapter == "__base__": 
@@ -201,7 +233,9 @@ if is_hqq_available(): self.unmerge() result = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: - result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + result = self._mixed_batch_forward( + x, *args, adapter_names=adapter_names, **kwargs + ) elif self.merged: result = self.base_layer(x, *args, **kwargs) else: diff --git a/src/peft_library/tuners/lora/layer.py b/src/peft_library/tuners/lora/layer.py index 8965c5e..764fe44 100644 --- a/src/peft_library/tuners/lora/layer.py +++ b/src/peft_library/tuners/lora/layer.py @@ -25,11 +25,21 @@ from torch import svd_lowrank from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge -from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type +from peft.utils.integrations import ( + dequantize_module_weight, + gather_params_ctx, + get_bnb_param_type, +) from peft.utils.other import transpose from .config import LoraConfig -from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer +from .dora import ( + DoraConv2dLayer, + DoraConv3dLayer, + DoraEmbeddingLayer, + DoraLinearLayer, + _DoraConvNdLayer, +) class LoraLayer(BaseTunerLayer): @@ -38,7 +48,9 @@ class LoraLayer(BaseTunerLayer): # All names of other parameters that may contain adapter-related parameters other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout") - def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None: + def __init__( + self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs + ) -> None: self.base_layer = base_layer self.r = {} self.lora_alpha = {} @@ -67,10 +79,15 @@ class LoraLayer(BaseTunerLayer): elif isinstance(base_layer, nn.Conv3d): in_features, out_features = base_layer.in_channels, base_layer.out_channels elif isinstance(base_layer, nn.Embedding): - in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim + in_features, out_features = ( + base_layer.num_embeddings, + base_layer.embedding_dim, + ) elif isinstance(base_layer, Conv1D): in_features, out_features = ( - base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + base_layer.weight.ds_shape + if hasattr(base_layer.weight, "ds_shape") + else base_layer.weight.shape ) elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): # QuantLinear @@ -78,26 +95,40 @@ class LoraLayer(BaseTunerLayer): elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): # Megatron ColumnParallelLinear,RowParallelLinear in_features, out_features = base_layer.input_size, base_layer.output_size - elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear": + elif ( + hasattr(base_layer, "codebooks") + and base_layer.__class__.__name__ == "QuantizedLinear" + ): # AQLM QuantLinear in_features, out_features = base_layer.in_features, base_layer.out_features - elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": + elif ( + hasattr(base_layer, "w_bit") + and base_layer.__class__.__name__ == "WQLinear_GEMM" + ): # Awq layers in_features, out_features = base_layer.in_features, base_layer.out_features elif base_layer.__class__.__name__ == "EetqLinear": # Eetq layers in_features, out_features = base_layer.in_features, base_layer.out_features - elif hasattr(base_layer, "W_q") and 
base_layer.__class__.__name__ == "HQQLinear": + elif ( + hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear" + ): # HQQ layers in_features, out_features = base_layer.in_features, base_layer.out_features else: # possibly support user provided custom layer types using dynamic dispatch - if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"): - in_features, out_features = base_layer.in_features, base_layer.out_features + if hasattr(base_layer, "in_features") and hasattr( + base_layer, "out_features" + ): + in_features, out_features = ( + base_layer.in_features, + base_layer.out_features, + ) else: in_features, out_features = None, None warnings.warn( - f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning + f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", + UserWarning, ) self.in_features = in_features @@ -116,7 +147,9 @@ class LoraLayer(BaseTunerLayer): ): # This code works for linear layers, override for other layer types if r <= 0: - raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + raise ValueError( + f"`r` should be a positive integer value but the value passed is {r}" + ) self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha @@ -140,7 +173,9 @@ class LoraLayer(BaseTunerLayer): if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): with gather_params_ctx(self.get_base_layer().weight): self.pissa_init(adapter_name, init_lora_weights) - elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora": + elif ( + isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora" + ): with gather_params_ctx(self.get_base_layer().weight): self.olora_init(adapter_name) elif init_lora_weights == "loftq": @@ -169,9 +204,13 @@ class LoraLayer(BaseTunerLayer): if init_lora_weights is True: # initialize A the same way as the default for nn.Linear and B to zero # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 - nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_( + self.lora_A[adapter_name].weight, a=math.sqrt(5) + ) elif init_lora_weights.lower() == "gaussian": - nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]) + nn.init.normal_( + self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name] + ) else: raise ValueError(f"Unknown initialization {init_lora_weights=}") nn.init.zeros_(self.lora_B[adapter_name].weight) @@ -210,7 +249,11 @@ class LoraLayer(BaseTunerLayer): self.lora_A[adapter_name].weight.data = Rr.contiguous() self.lora_B[adapter_name].weight.data = Qr.contiguous() - weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight + weight_tensor.data -= ( + scale_factor + * self.lora_B[adapter_name].weight + @ self.lora_A[adapter_name].weight + ) if bnb_param_type == "4bit": weight_tensor = orig_weight.__class__( weight_tensor, @@ -249,7 +292,9 @@ class LoraLayer(BaseTunerLayer): Uhr = Uh[: self.r[adapter_name]] elif len(init_lora_weights.split("_niter_")) == 2: Vr, Sr, Ur = svd_lowrank( - weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1]) + weight.data, + self.r[adapter_name], + niter=int(init_lora_weights.split("_niter_")[-1]), ) Sr /= self.scaling[adapter_name] Uhr = Ur.t() @@ -290,12 +335,18 @@ class LoraLayer(BaseTunerLayer): def 
dora_init(self, adapter_name: str) -> None: if not self.lora_magnitude_vector: # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters - self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) + self.adapter_layer_names = self.adapter_layer_names[:] + ( + "lora_magnitude_vector", + ) - dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(self, "fan_in_fan_out", False)) + dora_layer = DoraLinearLayer( + fan_in_fan_out=getattr(self, "fan_in_fan_out", False) + ) lora_A = self.lora_A[adapter_name].weight lora_B = self.lora_B[adapter_name].weight - place_on_cpu = self.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu") + place_on_cpu = self.ephemeral_gpu_offload and ( + lora_A.device.type == "cpu" or lora_B.device.type == "cpu" + ) if self.ephemeral_gpu_offload: if lora_A.device.type in ["cuda", "xpu"]: lora_B = lora_B.to(lora_A.device) @@ -308,7 +359,11 @@ class LoraLayer(BaseTunerLayer): lora_A = lora_A.to(lora_B.device) scaling = self.scaling[adapter_name] dora_layer.update_layer( - base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling, place_on_cpu=place_on_cpu + base_layer=self.get_base_layer(), + lora_A=lora_A, + lora_B=lora_B, + scaling=scaling, + place_on_cpu=place_on_cpu, ) self.lora_magnitude_vector[adapter_name] = dora_layer @@ -341,7 +396,9 @@ class LoraLayer(BaseTunerLayer): continue if scale is None: - self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter] + self.scaling[active_adapter] = ( + self.lora_alpha[active_adapter] / self.r[active_adapter] + ) else: self.scaling[active_adapter] /= scale @@ -383,7 +440,9 @@ class LoraLayer(BaseTunerLayer): unique_adapters = set(adapter_names) sub_batch_indices_list = [] for adapter in unique_adapters: - sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + sub_batch_indices_list.append( + [index for index, item in enumerate(adapter_names) if item == adapter] + ) for i, active_adapter in enumerate(unique_adapters): if active_adapter == "__base__": @@ -449,7 +508,9 @@ class Linear(nn.Module, LoraLayer): ) self.is_target_conv_1d_layer = is_target_conv_1d_layer - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: """ Merge the active adapter weights into the base weights @@ -482,15 +543,24 @@ class Linear(nn.Module, LoraLayer): # since delta_weight already includes scaling, set it to 1 here weight_norm = ( self.lora_magnitude_vector[active_adapter] - .get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1) + .get_weight_norm( + orig_weights, + transpose(delta_weight, self.fan_in_fan_out), + scaling=1, + ) .detach() ) # We need to cache weight_norm because it has to be based on the original weights. 
We # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm - dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out) + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight + / weight_norm + ) + dora_factor = transpose( + dora_factor.view(-1, 1), self.fan_in_fan_out + ) orig_weights = dora_factor * (orig_weights + delta_weight) if not torch.isfinite(orig_weights).all(): @@ -518,7 +588,9 @@ class Linear(nn.Module, LoraLayer): weight_norm = ( self.lora_magnitude_vector[active_adapter] .get_weight_norm( - base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1 + base_layer.weight, + transpose(delta_weight, self.fan_in_fan_out), + scaling=1, ) .detach() ) @@ -526,9 +598,16 @@ class Linear(nn.Module, LoraLayer): # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm - dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out) - new_weight = dora_factor * (base_layer.weight.data + delta_weight) + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight + / weight_norm + ) + dora_factor = transpose( + dora_factor.view(-1, 1), self.fan_in_fan_out + ) + new_weight = dora_factor * ( + base_layer.weight.data + delta_weight + ) base_layer.weight.data = new_weight if self.lora_bias[active_adapter]: @@ -552,7 +631,9 @@ class Linear(nn.Module, LoraLayer): weight.data -= delta_weight else: weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight / weight_norm + ) weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight weight.data = weight_orig @@ -573,7 +654,9 @@ class Linear(nn.Module, LoraLayer): # In case users wants to merge the adapter weights that are in # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to # (b)float16 because some CPUs have slow bf16/fp16 matmuls. 
- cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + cast_to_fp32 = device.type == "cpu" and ( + dtype == torch.float16 or dtype == torch.bfloat16 + ) weight_A = self.lora_A[adapter].weight weight_B = self.lora_B[adapter].weight @@ -582,7 +665,9 @@ class Linear(nn.Module, LoraLayer): weight_A = weight_A.float() weight_B = weight_B.float() - output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter] + output_tensor = ( + transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter] + ) if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) @@ -602,7 +687,9 @@ class Linear(nn.Module, LoraLayer): self.unmerge() result = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: - result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + result = self._mixed_batch_forward( + x, *args, adapter_names=adapter_names, **kwargs + ) elif self.merged: result = self.base_layer(x, *args, **kwargs) else: @@ -661,7 +748,9 @@ class Embedding(nn.Module, LoraLayer): ) -> None: if lora_bias: # lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter - raise ValueError(f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}.") + raise ValueError( + f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}." + ) super().__init__() LoraLayer.__init__(self, base_layer) @@ -679,10 +768,20 @@ class Embedding(nn.Module, LoraLayer): ) def update_layer( - self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias + self, + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + use_rslora, + use_dora, + lora_bias, ): if r <= 0: - raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + raise ValueError( + f"`r` should be a positive integer value but the value passed is {r}" + ) self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha @@ -723,18 +822,25 @@ class Embedding(nn.Module, LoraLayer): def dora_init(self, adapter_name: str) -> None: if self.lora_magnitude_vector is None: # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters - self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) + self.adapter_layer_names = self.adapter_layer_names[:] + ( + "lora_magnitude_vector", + ) dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True) lora_embedding_A = self.lora_embedding_A[adapter_name] lora_embedding_B = self.lora_embedding_B[adapter_name] scaling = self.scaling[adapter_name] dora_layer.update_layer( - base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling + base_layer=self.get_base_layer(), + lora_A=lora_embedding_A, + lora_B=lora_embedding_B, + scaling=scaling, ) self.lora_magnitude_vector[adapter_name] = dora_layer - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: """ Merge the active adapter weights into the base weights @@ -781,7 +887,9 @@ class Embedding(nn.Module, LoraLayer): while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() if active_adapter in self.lora_embedding_A.keys(): - self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + self.get_base_layer().weight.data -= self.get_delta_weight( + active_adapter + ) def 
get_delta_weight(self, adapter) -> torch.Tensor: """ @@ -797,7 +905,9 @@ class Embedding(nn.Module, LoraLayer): # In case users wants to merge the adapter weights that are in # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to # (b)float16 because some CPUs have slow bf16/fp16 matmuls. - cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + cast_to_fp32 = device.type == "cpu" and ( + dtype == torch.float16 or dtype == torch.bfloat16 + ) weight_A = self.lora_embedding_A[adapter] weight_B = self.lora_embedding_B[adapter] @@ -827,7 +937,9 @@ class Embedding(nn.Module, LoraLayer): unique_adapters = set(adapter_names) sub_batch_indices_list = [] for adapter in unique_adapters: - sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + sub_batch_indices_list.append( + [index for index, item in enumerate(adapter_names) if item == adapter] + ) for i, active_adapter in enumerate(unique_adapters): if active_adapter == "__base__": @@ -869,7 +981,9 @@ class Embedding(nn.Module, LoraLayer): self.unmerge() result = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: - result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + result = self._mixed_batch_forward( + x, *args, adapter_names=adapter_names, **kwargs + ) elif self.merged: result = self.base_layer(x, *args, **kwargs) else: @@ -886,7 +1000,9 @@ class Embedding(nn.Module, LoraLayer): after_A = self._embed(x, embedding_A) result = result + (after_A @ embedding_B) * scaling else: - mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter]( + mag_norm_scale, dora_result = self.lora_magnitude_vector[ + active_adapter + ]( x, lora_A=embedding_A, lora_B=embedding_B, @@ -937,10 +1053,20 @@ class _ConvNd(nn.Module, LoraLayer): ) def update_layer( - self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias + self, + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + use_rslora, + use_dora, + lora_bias, ): if r <= 0: - raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + raise ValueError( + f"`r` should be a positive integer value but the value passed is {r}" + ) self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha @@ -957,8 +1083,12 @@ class _ConvNd(nn.Module, LoraLayer): padding = base_layer.padding conv_layer = type(base_layer) out_kernel = out_stride = (1,) * (self._kernel_dim - 2) - self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False) - self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=lora_bias) + self.lora_A[adapter_name] = conv_layer( + self.in_features, r, kernel_size, stride, padding, bias=False + ) + self.lora_B[adapter_name] = conv_layer( + r, self.out_features, out_kernel, out_stride, bias=lora_bias + ) self.lora_bias[adapter_name] = lora_bias if use_rslora: @@ -988,21 +1118,30 @@ class _ConvNd(nn.Module, LoraLayer): def dora_init(self, adapter_name: str) -> None: if self.lora_magnitude_vector is None: # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters - self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) + self.adapter_layer_names = self.adapter_layer_names[:] + ( + "lora_magnitude_vector", + ) dora_layer_class = self._get_dora_layer_class() dora_layer = 
dora_layer_class(fan_in_fan_out=False) lora_A = self.lora_A[adapter_name].weight lora_B = self.lora_B[adapter_name].weight scaling = self.scaling[adapter_name] - dora_layer.update_layer(base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling) + dora_layer.update_layer( + base_layer=self.get_base_layer(), + lora_A=lora_A, + lora_B=lora_B, + scaling=scaling, + ) self.lora_magnitude_vector[adapter_name] = dora_layer def _get_dora_layer_class(self) -> type[_DoraConvNdLayer]: # Subclasses should override this method to return the appropriate DoraLayer class raise NotImplementedError - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: """ Merge the active adapter weights inside the base weights @@ -1043,8 +1182,13 @@ class _ConvNd(nn.Module, LoraLayer): # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm - orig_weights = dora_factor.view(*self._get_dora_factor_view()) * (orig_weights + delta_weight) + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight + / weight_norm + ) + orig_weights = dora_factor.view( + *self._get_dora_factor_view() + ) * (orig_weights + delta_weight) if not torch.isfinite(orig_weights).all(): raise ValueError( @@ -1076,7 +1220,10 @@ class _ConvNd(nn.Module, LoraLayer): # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight + / weight_norm + ) new_weight = dora_factor.view(*self._get_dora_factor_view()) * ( base_layer.weight.data + delta_weight ) @@ -1103,8 +1250,13 @@ class _ConvNd(nn.Module, LoraLayer): weight.data -= delta_weight else: weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm - weight_orig = weight.data / dora_factor.view(*self._get_dora_factor_view()) - delta_weight + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight / weight_norm + ) + weight_orig = ( + weight.data / dora_factor.view(*self._get_dora_factor_view()) + - delta_weight + ) weight.data = weight_orig if self.lora_bias[active_adapter]: @@ -1124,7 +1276,9 @@ class _ConvNd(nn.Module, LoraLayer): # In case users wants to merge the adapter weights that are in # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to # (b)float16 because some CPUs have slow bf16/fp16 matmuls. 
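The _ConvNd hunks above are likewise pure re-wraps. For orientation, a small sketch of the factor pair that update_layer builds for a Conv2d base layer: lora_A mirrors the base kernel, stride and padding and maps channels down to the rank r, while lora_B is a 1x1 convolution back up to the output channels (toy sizes, illustrative only, not the PEFT API):

import torch
import torch.nn as nn

base = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
r, scaling = 8, 0.5

# lora_A keeps the base layer's kernel/stride/padding, lora_B is a 1x1 conv
lora_A = nn.Conv2d(base.in_channels, r, base.kernel_size, base.stride,
                   base.padding, bias=False)
lora_B = nn.Conv2d(r, base.out_channels, kernel_size=(1, 1), stride=(1, 1),
                   bias=False)

x = torch.randn(2, 32, 16, 16)
out = base(x) + lora_B(lora_A(x)) * scaling   # low-rank update on top of the frozen conv
print(out.shape)                              # torch.Size([2, 64, 16, 16])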
- cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + cast_to_fp32 = device.type == "cpu" and ( + dtype == torch.float16 or dtype == torch.bfloat16 + ) weight_A = self.lora_A[adapter].weight weight_B = self.lora_B[adapter].weight @@ -1136,9 +1290,9 @@ class _ConvNd(nn.Module, LoraLayer): # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 if self.get_base_layer().weight.size()[2:4] == (1, 1): # conv2d 1x1 - output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( - 3 - ) * self.scaling[adapter] + output_tensor = ( + weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) * self.scaling[adapter] else: output_tensor = ( self.conv_fn( @@ -1166,7 +1320,9 @@ class _ConvNd(nn.Module, LoraLayer): self.unmerge() result = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: - result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + result = self._mixed_batch_forward( + x, *args, adapter_names=adapter_names, **kwargs + ) elif self.merged: result = self.base_layer(x, *args, **kwargs) else: @@ -1207,7 +1363,9 @@ class Conv2d(_ConvNd): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if not self._kernel_dim == 4: - raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}") + raise ValueError( + f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}" + ) self.conv_fn = F.conv2d def _get_dora_layer_class(self): @@ -1219,7 +1377,9 @@ class Conv3d(_ConvNd): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if not self._kernel_dim == 5: - raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}") + raise ValueError( + f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}" + ) self.conv_fn = F.conv3d def _get_dora_layer_class(self): @@ -1262,10 +1422,13 @@ def dispatch_default( elif isinstance(target_base_layer, Conv1D): if not kwargs["fan_in_fan_out"]: warnings.warn( - "fan_in_fan_out is set to False but the target module is `Conv1D`. " "Setting fan_in_fan_out to True." + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." 
) kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True kwargs.update(lora_config.loftq_config) - new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) + new_module = Linear( + target, adapter_name, is_target_conv_1d_layer=True, **kwargs + ) return new_module diff --git a/src/peft_library/tuners/lora/model.py b/src/peft_library/tuners/lora/model.py index ba45202..e8c11d3 100644 --- a/src/peft_library/tuners/lora/model.py +++ b/src/peft_library/tuners/lora/model.py @@ -42,7 +42,13 @@ from peft.utils import ( get_peft_model_state_dict, get_quantization_config, ) -from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties +from peft.utils.merge_utils import ( + dare_linear, + dare_ties, + magnitude_prune, + task_arithmetic, + ties, +) from peft.utils.other import get_pattern_key from .aqlm import dispatch_aqlm @@ -137,8 +143,12 @@ class LoraModel(BaseTuner): prefix: str = "lora_" - def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: - super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) + def __init__( + self, model, config, adapter_name, low_cpu_mem_usage: bool = False + ) -> None: + super().__init__( + model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage + ) def _check_new_adapter_config(self, config: LoraConfig) -> None: """ @@ -213,7 +223,9 @@ class LoraModel(BaseTuner): quant_methods = ["gptq", "aqlm", "awq"] for quant_method in quant_methods: - quantization_config = get_quantization_config(self.model, method=quant_method) + quantization_config = get_quantization_config( + self.model, method=quant_method + ) if quantization_config is not None: kwargs[f"{quant_method}_quantization_config"] = quantization_config @@ -232,7 +244,9 @@ class LoraModel(BaseTuner): lora_bias=lora_config.lora_bias, ) else: - new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + new_module = self._create_new_module( + lora_config, adapter_name, target, **kwargs + ) if adapter_name not in self.active_adapters: # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) @@ -269,11 +283,15 @@ class LoraModel(BaseTuner): weight = ( child.qweight if hasattr(child, "qweight") - else child.W_q - if hasattr(child, "W_q") - else child.weight - if hasattr(child, "weight") - else next(child.parameters()) + else ( + child.W_q + if hasattr(child, "W_q") + else ( + child.weight + if hasattr(child, "weight") + else next(child.parameters()) + ) + ) ) if not any(p.device == meta for p in module.parameters()): module.to(weight.device) @@ -294,10 +312,16 @@ class LoraModel(BaseTuner): p.requires_grad = True elif bias == "lora_only": for m in model.modules(): - if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: + if ( + isinstance(m, LoraLayer) + and hasattr(m, "bias") + and m.bias is not None + ): m.bias.requires_grad = True else: - raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + raise NotImplementedError( + f"Requested bias: {bias}, is not implemented." 
+ ) @staticmethod def _create_new_module(lora_config, adapter_name, target, **kwargs): @@ -351,7 +375,9 @@ class LoraModel(BaseTuner): new_module = None for dispatcher in dispatchers: - new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs) + new_module = dispatcher( + target, adapter_name, lora_config=lora_config, **kwargs + ) if new_module is not None: # first match wins break @@ -370,14 +396,19 @@ class LoraModel(BaseTuner): try: return super().__getattr__(name) # defer to nn.Module's logic except AttributeError: - if name == "model": # see #1892: prevent infinite recursion if class is not initialized + if ( + name == "model" + ): # see #1892: prevent infinite recursion if class is not initialized raise return getattr(self.model, name) def get_peft_config_as_dict(self, inference: bool = False): config_dict = {} for key, value in self.peft_config.items(): - config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + config = { + k: v.value if isinstance(v, Enum) else v + for k, v in asdict(value).items() + } if inference: config["inference_mode"] = True config_dict[key] = config @@ -428,7 +459,9 @@ class LoraModel(BaseTuner): for module in self.model.modules(): if isinstance(module, LoraLayer): if module.merged: - warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + warnings.warn( + "Adapter cannot be set when the model is merged. Unmerging the model first." + ) module.unmerge() module.set_adapter(adapter_name) self.active_adapter = adapter_name @@ -443,7 +476,9 @@ class LoraModel(BaseTuner): return if self.training: - raise ValueError("Cannot pass `adapter_names` when the model is in training mode.") + raise ValueError( + "Cannot pass `adapter_names` when the model is in training mode." + ) # Check that users only passed actually existing adapters. # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. 
Still, we want @@ -456,12 +491,18 @@ class LoraModel(BaseTuner): unique_adapters = {name for name in adapter_names if name != "__base__"} unexpected_adapters = unique_adapters - expected_adapters if unexpected_adapters: - raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}") + raise ValueError( + f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}" + ) hook_handles = [] for module in self.modules(): - if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper): - pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names) + if isinstance(module, LoraLayer) or isinstance( + module, ModulesToSaveWrapper + ): + pre_forward = partial( + _adapter_names_pre_forward_hook, adapter_names=adapter_names + ) handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) hook_handles.append(handle) @@ -477,17 +518,26 @@ class LoraModel(BaseTuner): """ super()._check_merge_allowed() if getattr(self.model, "quantization_method", None) == "gptq": - raise ValueError("Cannot merge LORA layers when the model is gptq quantized") + raise ValueError( + "Cannot merge LORA layers when the model is gptq quantized" + ) if self.peft_config.get("layer_replication"): - raise ValueError("Cannot merge LORA layers when base model layers are replicated") + raise ValueError( + "Cannot merge LORA layers when base model layers are replicated" + ) @staticmethod def _prepare_adapter_config(peft_config, model_config): if peft_config.target_modules is None: - if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + if ( + model_config["model_type"] + not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING + ): raise ValueError("Please specify `target_modules` in `peft_config`") peft_config.target_modules = set( - TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[ + model_config["model_type"] + ] ) return peft_config @@ -501,7 +551,9 @@ class LoraModel(BaseTuner): if merge: self._check_merge_allowed() - key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + key_list = [ + key for key, _ in self.model.named_modules() if self.prefix not in key + ] desc = "Unloading " + ("and merging " if merge else "") + "model" for key in tqdm(key_list, disable=not progressbar, desc=desc): try: @@ -512,14 +564,18 @@ class LoraModel(BaseTuner): if hasattr(target, "base_layer"): if merge: target.merge(safe_merge=safe_merge, adapter_names=adapter_names) - self._replace_module(parent, target_name, target.get_base_layer(), target) + self._replace_module( + parent, target_name, target.get_base_layer(), target + ) elif isinstance(target, ModulesToSaveWrapper): # save any additional trainable modules part of `modules_to_save` new_module = target.modules_to_save[target.active_adapter] if hasattr(new_module, "base_layer"): # check if the module is itself a tuner layer if merge: - new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + new_module.merge( + safe_merge=safe_merge, adapter_names=adapter_names + ) new_module = new_module.get_base_layer() setattr(parent, target_name, new_module) @@ -539,7 +595,11 @@ class LoraModel(BaseTuner): # If more than one of the adapters targets the same module with modules_to_save, raise an error, as these # modules cannot be merged. 
First, find the ModulesToSaveWrapper instances in the model, then check if they # have modules for the adapters to be merged. - modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)] + modules_to_save_wrappers = [ + module + for module in self.modules() + if isinstance(module, ModulesToSaveWrapper) + ] problematic_wrappers = [ wrapper for wrapper in modules_to_save_wrappers @@ -555,7 +615,13 @@ class LoraModel(BaseTuner): combination_type = "linear" if len(adapters) == 1 else combination_type adapters_ranks = [self.peft_config[adapter].r for adapter in adapters] - if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"): + if combination_type in ( + "linear", + "ties", + "dare_ties", + "dare_linear", + "magnitude_prune", + ): # all adapters ranks should be same, new rank is just this value if len(set(adapters_ranks)) != 1: raise ValueError( @@ -573,7 +639,9 @@ class LoraModel(BaseTuner): else: raise ValueError(f"Invalid combination_type: {combination_type}") - target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters] + target_module_types = [ + type(self.peft_config[adapter].target_modules) for adapter in adapters + ] if not target_module_types: raise ValueError(f"Found no adapter matching the names in {adapters}") if len(set(target_module_types)) > 1: @@ -583,13 +651,18 @@ class LoraModel(BaseTuner): ) if target_module_types[0] is str: - new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters) + new_target_modules = "|".join( + f"({self.peft_config[adapter].target_modules})" for adapter in adapters + ) elif target_module_types[0] is set: new_target_modules = reduce( - operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters) + operator.or_, + (self.peft_config[adapter].target_modules for adapter in adapters), ) else: - raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules") + raise TypeError( + f"Invalid type {target_module_types[0]} found in target_modules" + ) return combination_type, new_rank, new_target_modules @@ -649,10 +722,12 @@ class LoraModel(BaseTuner): if adapter_name in list(self.peft_config.keys()): return - combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter( - adapters=adapters, - combination_type=combination_type, - svd_rank=svd_rank, + combination_type, new_rank, new_target_modules = ( + self._check_add_weighted_adapter( + adapters=adapters, + combination_type=combination_type, + svd_rank=svd_rank, + ) ) self.peft_config[adapter_name] = replace( @@ -666,7 +741,9 @@ class LoraModel(BaseTuner): # Do we really need that? _freeze_adapter(self.model, adapter_name) - key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + key_list = [ + key for key, _ in self.model.named_modules() if self.prefix not in key + ] for key in key_list: _, target, _ = _get_submodules(self.model, key) if isinstance(target, LoraLayer): @@ -692,11 +769,17 @@ class LoraModel(BaseTuner): current_adapter_lora_B = target.lora_embedding_B[adapter] else: continue - loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter]) + loras_A.append( + current_adapter_lora_A.data + * weight + * target.scaling[adapter] + ) loras_B.append(current_adapter_lora_B.data) if len(loras_A) == 0: - raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.") + raise ValueError( + "No matching LoRAs found. 
Please raise an issue on GitHub." + ) loras_A = torch.cat(loras_A, dim=0) loras_B = torch.cat(loras_B, dim=1) target_lora_A.data[: loras_A.shape[0], :] = loras_A @@ -708,23 +791,38 @@ class LoraModel(BaseTuner): "dare_ties_svd", "magnitude_prune_svd", ]: - target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter( - combination_type, - adapters, - weights, - new_rank, - target, - target_lora_A, - target_lora_B, - density, - majority_sign_method, - svd_clamp, - full_matrices=svd_full_matrices, - driver=svd_driver, + target_lora_A.data, target_lora_B.data = ( + self._svd_generalized_task_arithmetic_weighted_adapter( + combination_type, + adapters, + weights, + new_rank, + target, + target_lora_A, + target_lora_B, + density, + majority_sign_method, + svd_clamp, + full_matrices=svd_full_matrices, + driver=svd_driver, + ) ) - elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]: - target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter( - combination_type, adapters, weights, target, density, majority_sign_method + elif combination_type in [ + "linear", + "ties", + "dare_linear", + "dare_ties", + "magnitude_prune", + ]: + target_lora_A.data, target_lora_B.data = ( + self._generalized_task_arithmetic_weighted_adapter( + combination_type, + adapters, + weights, + target, + density, + majority_sign_method, + ) ) def _svd_generalized_task_arithmetic_weighted_adapter( @@ -752,21 +850,29 @@ class LoraModel(BaseTuner): # if no valid adapter, nothing to do if len(valid_adapters) == 0: - raise ValueError("No matching LoRAs found. Please raise an issue on Github.") + raise ValueError( + "No matching LoRAs found. Please raise an issue on Github." + ) delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters] valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device) if combination_type == "svd": delta_weight = task_arithmetic(delta_weight, valid_weights) elif combination_type == "ties_svd": - delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method) + delta_weight = ties( + delta_weight, valid_weights, density, majority_sign_method + ) elif combination_type == "dare_linear_svd": delta_weight = dare_linear(delta_weight, valid_weights, density) elif combination_type == "dare_ties_svd": - delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method) + delta_weight = dare_ties( + delta_weight, valid_weights, density, majority_sign_method + ) elif combination_type == "magnitude_prune_svd": delta_weight = magnitude_prune(delta_weight, valid_weights, density) else: - raise ValueError(f"Invalid value passed to combination type: {combination_type}") + raise ValueError( + f"Invalid value passed to combination type: {combination_type}" + ) conv2d = isinstance(target, Conv2d) if conv2d: @@ -775,11 +881,15 @@ class LoraModel(BaseTuner): delta_weight = delta_weight.flatten(start_dim=1) else: delta_weight = delta_weight.squeeze() - if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding: + if ( + hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out + ) or is_embedding: delta_weight = delta_weight.T # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131 - U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver) + U, S, Vh = torch.linalg.svd( + delta_weight, full_matrices=full_matrices, driver=driver + ) U = U[:, :new_rank] S = 
S[:new_rank] U = U @ torch.diag(S) @@ -827,11 +937,15 @@ class LoraModel(BaseTuner): if combination_type == "linear": lora_deltas[i] = task_arithmetic(task_tensors, valid_weights) elif combination_type == "ties": - lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method) + lora_deltas[i] = ties( + task_tensors, valid_weights, density, majority_sign_method + ) elif combination_type == "dare_linear": lora_deltas[i] = dare_linear(task_tensors, valid_weights, density) elif combination_type == "dare_ties": - lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method) + lora_deltas[i] = dare_ties( + task_tensors, valid_weights, density, majority_sign_method + ) elif combination_type == "magnitude_prune": lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density) else: @@ -850,7 +964,9 @@ class LoraModel(BaseTuner): raise ValueError(f"Adapter {adapter_name} does not exist") del self.peft_config[adapter_name] - key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + key_list = [ + key for key, _ in self.model.named_modules() if self.prefix not in key + ] new_adapter = None for key in key_list: _, target, _ = _get_submodules(self.model, key) @@ -862,7 +978,10 @@ class LoraModel(BaseTuner): self.active_adapter = new_adapter or [] def merge_and_unload( - self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + self, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, ) -> torch.nn.Module: r""" This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model @@ -900,7 +1019,9 @@ class LoraModel(BaseTuner): """ return self._unload_and_optionally_merge(merge=False) - def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None): + def subtract_mutated_init( + self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None + ): """ This function can calculate the updates of the [PiSSA | OLoRA] by comparing the parameters of the [PiSSA | OLoRA] adapter in `output_state_dict` with the initial values of [PiSSA | OLoRA] in `adapter_name`, thus @@ -929,11 +1050,19 @@ class LoraModel(BaseTuner): ## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'. if "lora_A" in name: tensors_lora[name] = torch.cat( - [output_state_dict[name], mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=0 + [ + output_state_dict[name], + mutated_init_state_dict[".".join(name.split(".")[1:])], + ], + dim=0, ) elif "lora_B" in name: tensors_lora[name] = torch.cat( - [output_state_dict[name], -mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=1 + [ + output_state_dict[name], + -mutated_init_state_dict[".".join(name.split(".")[1:])], + ], + dim=1, ) return tensors_lora diff --git a/src/peft_library/tuners/lora/torchao.py b/src/peft_library/tuners/lora/torchao.py index 05ffbc3..3a5b022 100644 --- a/src/peft_library/tuners/lora/torchao.py +++ b/src/peft_library/tuners/lora/torchao.py @@ -33,7 +33,9 @@ class TorchaoLoraLinear(Linear): # this is not strictly necessary, as kwargs are stored either way, but we want to error early if # get_apply_tensor_subclass is missing. 
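The add_weighted_adapter hunks above, including the *_svd branch, are also formatting-only. A compact sketch of the rank-reduction step they wrap: the combined delta is factored back into a (lora_B, lora_A) pair of rank new_rank with a truncated SVD (the optional clamping and the conv/embedding reshapes are left out; toy tensors, not the PEFT API):

import torch

torch.manual_seed(0)
out_f, in_f, new_rank = 32, 64, 8

# weighted / pruned combination of the adapters' delta weights (task
# arithmetic, ties, dare, magnitude prune all end up producing one such matrix)
delta_weight = torch.randn(out_f, in_f)

U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=False)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)      # fold the singular values into the B factor
Vh = Vh[:new_rank, :]

lora_B, lora_A = U, Vh     # shapes: (out_f, new_rank) and (new_rank, in_f)
print(lora_B.shape, lora_A.shape,
      torch.linalg.norm(delta_weight - lora_B @ lora_A).item())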
if kwargs.get("lora_bias", False): - raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False") + raise ValueError( + f"{self.__class__.__name__} does not support lora_bias yet, set it to False" + ) super().__init__(*args, **kwargs) self.get_apply_tensor_subclass = get_apply_tensor_subclass @@ -43,10 +45,16 @@ class TorchaoLoraLinear(Linear): # TODO: Not required once int4_weight_only is properly supported by torchao base_layer = self.get_base_layer() weight = base_layer.weight - if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8): - raise ValueError(f"{type(self).__name__} only supports int8 weights for now.") + if hasattr(weight, "layout_tensor") and ( + weight.layout_tensor.data.dtype != torch.int8 + ): + raise ValueError( + f"{type(self).__name__} only supports int8 weights for now." + ) - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: from torchao import quantize_ adapter_names = check_adapters_to_merge(self, adapter_names) @@ -143,7 +151,10 @@ def dispatch_torchao( from torchao.dtypes import AffineQuantizedTensor from torchao.quantization import LinearActivationQuantizedTensor - if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)): + if isinstance( + target_base_layer.weight, + (AffineQuantizedTensor, LinearActivationQuantizedTensor), + ): new_module = TorchaoLoraLinear(target, adapter_name, **kwargs) return new_module diff --git a/src/peft_library/tuners/lora/tp_layer.py b/src/peft_library/tuners/lora/tp_layer.py index e12609d..e6e0e0d 100644 --- a/src/peft_library/tuners/lora/tp_layer.py +++ b/src/peft_library/tuners/lora/tp_layer.py @@ -54,13 +54,17 @@ class LoraParallelLinear(nn.Module, LoraLayer): **kwargs, ): if lora_bias: - raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False") + raise ValueError( + f"{self.__class__.__name__} does not support lora_bias yet, set it to False" + ) super().__init__() LoraLayer.__init__(self, base_layer=base_layer, **kwargs) if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + raise ValueError( + f"{self.__class__.__name__} does not support DoRA yet, please set it to False" + ) self.backend = backend self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear) @@ -113,7 +117,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): **parallel_linear_kwargs, ): if r <= 0: - raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + raise ValueError( + f"`r` should be a positive integer value but the value passed is {r}" + ) self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha if lora_dropout > 0.0: @@ -136,9 +142,19 @@ class LoraParallelLinear(nn.Module, LoraLayer): init_method=init_method, config=megatron_config, ) - lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32) + lora_b = nn.Linear( + in_features=r, + out_features=self.out_features, + bias=False, + dtype=torch.float32, + ) else: - lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32) + lora_a = nn.Linear( + in_features=self.in_features, + out_features=r, + bias=False, + dtype=torch.float32, + ) lora_b = self.backend.ColumnParallelLinear( input_size=r, output_size=self.out_features, @@ -158,7 
+174,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): with gather_params_ctx(self.get_base_layer().weight): self.pissa_init(adapter_name, init_lora_weights) - elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora": + elif ( + isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora" + ): with gather_params_ctx(self.get_base_layer().weight): self.olora_init(adapter_name) elif init_lora_weights == "loftq": @@ -189,7 +207,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): self.unmerge() result, bias = self.base_layer(x, *args, **kwargs) elif adapter_names is not None: - raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.") + raise ValueError( + f"{self.__class__.__name__} does not support mixed_batch_forward yet." + ) elif self.merged: result, bias = self.base_layer(x, *args, **kwargs) else: @@ -225,7 +245,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): result = result.to(torch_result_dtype) return result, bias - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + def merge( + self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> None: """ Merge the active adapter weights into the base weights @@ -258,15 +280,24 @@ class LoraParallelLinear(nn.Module, LoraLayer): # since delta_weight already includes scaling, set it to 1 here weight_norm = ( self.lora_magnitude_vector[active_adapter] - .get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1) + .get_weight_norm( + orig_weights, + transpose(delta_weight, self.fan_in_fan_out), + scaling=1, + ) .detach() ) # We need to cache weight_norm because it has to be based on the original weights. 
We # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm - dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out) + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight + / weight_norm + ) + dora_factor = transpose( + dora_factor.view(-1, 1), self.fan_in_fan_out + ) orig_weights = dora_factor * (orig_weights + delta_weight) if not torch.isfinite(orig_weights).all(): @@ -285,7 +316,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): weight_norm = ( self.lora_magnitude_vector[active_adapter] .get_weight_norm( - base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1 + base_layer.weight, + transpose(delta_weight, self.fan_in_fan_out), + scaling=1, ) .detach() ) @@ -293,9 +326,16 @@ class LoraParallelLinear(nn.Module, LoraLayer): # cannot calculate it on the fly based on the merged weights when unmerging because its a # different value self._cache_store(f"{active_adapter}-weight_norm", weight_norm) - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm - dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out) - new_weight = dora_factor * (base_layer.weight.data + delta_weight) + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight + / weight_norm + ) + dora_factor = transpose( + dora_factor.view(-1, 1), self.fan_in_fan_out + ) + new_weight = dora_factor * ( + base_layer.weight.data + delta_weight + ) base_layer.weight.data = new_weight self.merged_adapters.append(active_adapter) @@ -316,7 +356,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): weight.data -= delta_weight else: weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") - dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm + dora_factor = ( + self.lora_magnitude_vector[active_adapter].weight / weight_norm + ) weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight weight.data = weight_orig @@ -334,7 +376,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): # In case users wants to merge the adapter weights that are in # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to # (b)float16 because some CPUs have slow bf16/fp16 matmuls. 
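The LoraParallelLinear hunks above re-wrap the Megatron tensor-parallel variant, whose merge path mirrors the dense Linear one. The sharding idea behind pairing one parallel LoRA factor with a plain nn.Linear can be checked in a single process without Megatron; a hedged sketch of the column-parallel case, where lora_B is sharded like the base weight and lora_A is replicated:

import torch

torch.manual_seed(0)
in_f, out_f, r, world_size = 8, 6, 2, 2
shard = out_f // world_size

x = torch.randn(3, in_f)
base_w = torch.randn(out_f, in_f)    # output features sharded across ranks (rows of the (out, in) weight)
lora_A = torch.randn(r, in_f)        # replicated on every rank
lora_B = torch.randn(out_f, r)       # sharded the same way as base_w

full = x @ (base_w + lora_B @ lora_A).T
per_rank = [
    x @ (base_w[k * shard:(k + 1) * shard]
         + lora_B[k * shard:(k + 1) * shard] @ lora_A).T
    for k in range(world_size)
]
# concatenating the per-rank output slices reproduces the unsharded result
assert torch.allclose(full, torch.cat(per_rank, dim=1), atol=1e-5)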
- cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + cast_to_fp32 = device.type == "cpu" and ( + dtype == torch.float16 or dtype == torch.bfloat16 + ) weight_A = self.lora_A[adapter].weight weight_B = self.lora_B[adapter].weight @@ -343,7 +387,9 @@ class LoraParallelLinear(nn.Module, LoraLayer): weight_A = weight_A.float() weight_B = weight_B.float() - output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter] + output_tensor = ( + transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter] + ) if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) @@ -379,12 +425,17 @@ def dispatch_megatron( if megatron_core and isinstance( target_base_layer, - (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear), + ( + megatron_core.tensor_parallel.ColumnParallelLinear, + megatron_core.tensor_parallel.RowParallelLinear, + ), ): megatron_kwargs = kwargs.copy() megatron_config = lora_config.megatron_config if isinstance(megatron_config, dict): - transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig + transformer_config_class = ( + megatron_core.transformer.transformer_config.TransformerConfig + ) megatron_config = transformer_config_class(**lora_config.megatron_config) megatron_kwargs["megatron_config"] = megatron_config if megatron_kwargs["fan_in_fan_out"]: @@ -395,7 +446,10 @@ def dispatch_megatron( ) megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False new_module = LoraParallelLinear( - base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs + base_layer=target, + adapter_name=adapter_name, + backend=megatron_core.tensor_parallel, + **megatron_kwargs, ) return new_module diff --git a/src/train.py b/src/train.py index 5808322..5edbac1 100644 --- a/src/train.py +++ b/src/train.py @@ -1,6 +1,7 @@ import sys -sys.path.insert(0, "./transformers_repo/src/") -sys.path.insert(0, "./peft_repo/src/") + +sys.path.insert(0, "transformers_repo/src/") +sys.path.insert(0, "peft_repo/src/") from dataset_library.factory import get_dataset @@ -50,7 +51,6 @@ if __name__ == "__main__": peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules) - # model = get_peft_model(model, peft_config) model.add_adapter(peft_config) elif model_args.peft_type == "LORA": @@ -58,12 +58,8 @@ if __name__ == "__main__": peft_config = LoraConfig(target_modules=model_args.lora_target_modules) - # model = get_peft_model(model, peft_config) model.add_adapter(peft_config) - # if accelerator.is_local_main_process: - # model.print_trainable_parameters() - else: peft_config = None diff --git a/src/train.sh b/src/train.sh index 3c56c2f..2b5624f 100755 --- a/src/train.sh +++ b/src/train.sh @@ -3,7 +3,7 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \ --dataset_name CHEM \ --use_peft \ - --peft_type MMOELORA \ + --peft_type LORA \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ --lora_target_modules q_proj v_proj \ --per_device_train_batch_size 1 \ @@ -12,4 +12,4 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml --output_dir checkpoint/qwen2mmoe/ \ --bf16 \ --torch_dtype bfloat16 \ - --logging_steps 30 + --logging_steps 30 diff --git a/src/utils/args.py b/src/utils/args.py index 6c66992..e0b1fcf 100644 --- a/src/utils/args.py +++ b/src/utils/args.py @@ -17,4 +17,4 @@ class 
ContinualScriptArguments(ScriptArguments): class ContinualModelConfig(ModelConfig): """Model configuration for continual learning.""" - peft_type: Optional[str] = None \ No newline at end of file + peft_type: Optional[str] = None diff --git a/uv.lock b/uv.lock index 3631262..a6bdf59 100644 --- a/uv.lock +++ b/uv.lock @@ -196,6 +196,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 }, ] +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 }, +] + [[package]] name = "charset-normalizer" version = "3.4.1" @@ -260,6 +269,7 @@ dependencies = [ { name = "numba" }, { name = "peft" }, { name = "pip" }, + { name = "pre-commit" }, { name = "requests" }, { name = "rouge-score" }, { name = "safetensors" }, @@ -292,6 +302,7 @@ requires-dist = [ { name = "numba", specifier = ">=0.60.0" }, { name = "peft", specifier = "==0.14.0" }, { name = "pip", specifier = "==24.3.1" }, + { name = "pre-commit", specifier = ">=4.0.1" }, { name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" }, { name = "rouge-score", specifier = ">=0.1.2" }, { name = "safetensors", specifier = ">=0.5.2" }, @@ -388,6 +399,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252 }, ] +[[package]] +name = "distlib" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 }, +] + [[package]] name = "einops" version = "0.8.0" @@ -533,6 +553,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/3f/50f6b25fafdcfb1c089187a328c95081abf882309afd86f4053951507cd1/huggingface_hub-0.27.1-py3-none-any.whl", hash = "sha256:1c5155ca7d60b60c2e2fc38cbb3ffb7f7c3adf48f824015b219af9061771daec", size = 450658 }, ] +[[package]] +name = "identify" +version = "2.6.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cf/92/69934b9ef3c31ca2470980423fda3d00f0460ddefdf30a67adf7f17e2e00/identify-2.6.5.tar.gz", hash = "sha256:c10b33f250e5bba374fae86fb57f3adcebf1161bce7cdf92031915fd480c13bc", size = 99213 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ec/fa/dce098f4cdf7621aa8f7b4f919ce545891f489482f0bfa5102f3eca8608b/identify-2.6.5-py2.py3-none-any.whl", hash = "sha256:14181a47091eb75b337af4c23078c9d09225cd4c48929f521f3bf16b09d02566", size = 99078 }, +] + [[package]] name = "idna" version = "3.10" @@ -825,6 +854,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 }, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, +] + [[package]] name = "numba" version = "0.60.0" @@ -1147,6 +1185,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574 }, ] +[[package]] +name = "pre-commit" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/c8/e22c292035f1bac8b9f5237a2622305bc0304e776080b246f3df57c4ff9f/pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2", size = 191678 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/8f/496e10d51edd6671ebe0432e33ff800aa86775d2d147ce7d43389324a525/pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878", size = 218713 }, +] + [[package]] name = "propcache" version = "0.2.1" @@ -1849,6 +1903,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, ] +[[package]] +name = "virtualenv" +version = "20.29.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/ca/f23dcb02e161a9bba141b1c08aa50e8da6ea25e6d780528f1d385a3efe25/virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35", size = 7658028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/9b/599bcfc7064fbe5740919e78c5df18e5dceb0887e676256a1061bb5ae232/virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779", size = 4282379 }, +] + [[package]] name = "wheel" version = "0.45.1"