From 2cd1bb4993d4d91afe966c7394405b08c0ec9e07 Mon Sep 17 00:00:00 2001
From: YunyaoZhou
Date: Thu, 2 Jan 2025 02:44:58 +0800
Subject: [PATCH] Add PEFT library initialization files, update dataset import
 paths, modify the training script to support new PEFT types and
 configurations, add a continual-learning model config class, add a PEFT type
 enum, and update the evaluation and training logic to fit the new structure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/collatefn_library/qwen2.py | 2 +-
 .../accelerate_configs/deepspeed_zero1.yaml | 0
 .../accelerate_configs/deepspeed_zero2.yaml | 0
 .../accelerate_configs/deepspeed_zero3.yaml | 0
 .../accelerate_configs/fsdp_qlora.yaml | 0
 .../accelerate_configs/multi_gpu.yaml | 0
 .../accelerate_configs/single_gpu.yaml | 0
 .../OCRVQADataset.py | 0
 .../__init__.py | 0
 .../factory.py | 0
 src/evaluation.py | 2 +-
 src/peft_library/__init__.py | 1 +
 src/peft_library/mapping.py | 288 +++++++++++
 src/peft_library/tuners/__init__.py | 2 +
 src/peft_library/tuners/mmoelora/__init__.py | 0
 src/peft_library/tuners/mmoelora/mmoelora.py | 467 ++++++++++++++++++
 src/peft_library/tuners/mmoelora/mmoeloraS.py | 227 +++++++++
 src/peft_library/utils/__init__.py | 39 ++
 src/peft_library/utils/config.py | 176 +++++++
 src/peft_library/utils/constants.py | 4 +
 src/peft_library/utils/other.py | 250 ++++++++++
 src/peft_library/utils/peft_types.py | 51 ++
 src/peft_library/utils/save_and_load.py | 130 +++++
 src/train.py | 35 +-
 src/train.sh | 3 +-
 src/utils/args.py | 16 +-
 26 files changed, 1673 insertions(+), 20 deletions(-)
 rename src/{ => configs}/accelerate_configs/deepspeed_zero1.yaml (100%)
 rename src/{ => configs}/accelerate_configs/deepspeed_zero2.yaml (100%)
 rename src/{ => configs}/accelerate_configs/deepspeed_zero3.yaml (100%)
 rename src/{ => configs}/accelerate_configs/fsdp_qlora.yaml (100%)
 rename src/{ => configs}/accelerate_configs/multi_gpu.yaml (100%)
 rename src/{ => configs}/accelerate_configs/single_gpu.yaml (100%)
 rename src/{datasets_library => dataset_library}/OCRVQADataset.py (100%)
 rename src/{datasets_library => dataset_library}/__init__.py (100%)
 rename src/{datasets_library => dataset_library}/factory.py (100%)
 create mode 100644 src/peft_library/__init__.py
 create mode 100644 src/peft_library/mapping.py
 create mode 100644 src/peft_library/tuners/__init__.py
 create mode 100644 src/peft_library/tuners/mmoelora/__init__.py
 create mode 100644 src/peft_library/tuners/mmoelora/mmoelora.py
 create mode 100644 src/peft_library/tuners/mmoelora/mmoeloraS.py
 create mode 100644 src/peft_library/utils/__init__.py
 create mode 100644 src/peft_library/utils/config.py
 create mode 100644 src/peft_library/utils/constants.py
 create mode 100644 src/peft_library/utils/other.py
 create mode 100644 src/peft_library/utils/peft_types.py
 create mode 100644 src/peft_library/utils/save_and_load.py

diff --git a/src/collatefn_library/qwen2.py b/src/collatefn_library/qwen2.py
index ea9b7db..028382e 100644
--- a/src/collatefn_library/qwen2.py
+++ b/src/collatefn_library/qwen2.py
@@ -84,7 +84,7 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
 if __name__ == "__main__":
     from transformers import Qwen2VLProcessor
-    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
+    from dataset_library.OCRVQADataset import OCRVQADatasetForGeneration
     processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
     dataset = OCRVQADatasetForGeneration(
diff --git a/src/accelerate_configs/deepspeed_zero1.yaml b/src/configs/accelerate_configs/deepspeed_zero1.yaml
similarity index 100%
rename from src/accelerate_configs/deepspeed_zero1.yaml
rename to src/configs/accelerate_configs/deepspeed_zero1.yaml
diff --git a/src/accelerate_configs/deepspeed_zero2.yaml b/src/configs/accelerate_configs/deepspeed_zero2.yaml
similarity index 100%
rename from src/accelerate_configs/deepspeed_zero2.yaml
rename to src/configs/accelerate_configs/deepspeed_zero2.yaml
diff --git a/src/accelerate_configs/deepspeed_zero3.yaml b/src/configs/accelerate_configs/deepspeed_zero3.yaml
similarity index 100%
rename from src/accelerate_configs/deepspeed_zero3.yaml
rename to src/configs/accelerate_configs/deepspeed_zero3.yaml
diff --git a/src/accelerate_configs/fsdp_qlora.yaml b/src/configs/accelerate_configs/fsdp_qlora.yaml
similarity index 100%
rename from src/accelerate_configs/fsdp_qlora.yaml
rename to src/configs/accelerate_configs/fsdp_qlora.yaml
diff --git a/src/accelerate_configs/multi_gpu.yaml b/src/configs/accelerate_configs/multi_gpu.yaml
similarity index 100%
rename from src/accelerate_configs/multi_gpu.yaml
rename to src/configs/accelerate_configs/multi_gpu.yaml
diff --git a/src/accelerate_configs/single_gpu.yaml b/src/configs/accelerate_configs/single_gpu.yaml
similarity index 100%
rename from src/accelerate_configs/single_gpu.yaml
rename to src/configs/accelerate_configs/single_gpu.yaml
diff --git a/src/datasets_library/OCRVQADataset.py b/src/dataset_library/OCRVQADataset.py
similarity index 100%
rename from src/datasets_library/OCRVQADataset.py
rename to src/dataset_library/OCRVQADataset.py
diff --git a/src/datasets_library/__init__.py b/src/dataset_library/__init__.py
similarity index 100%
rename from src/datasets_library/__init__.py
rename to src/dataset_library/__init__.py
diff --git a/src/datasets_library/factory.py b/src/dataset_library/factory.py
similarity index 100%
rename from src/datasets_library/factory.py
rename to src/dataset_library/factory.py
diff --git a/src/evaluation.py b/src/evaluation.py
index e306f9b..3851553 100644
--- a/src/evaluation.py
+++ b/src/evaluation.py
@@ -1,5 +1,5 @@
 import torch
-from datasets_library.factory import get_dataset
+from dataset_library.factory import get_dataset
 from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments
 from trl import (
diff --git a/src/peft_library/__init__.py b/src/peft_library/__init__.py
new file mode 100644
index 0000000..faad1ec
--- /dev/null
+++ b/src/peft_library/__init__.py
@@ -0,0 +1 @@
+from .mapping import get_peft_config, get_peft_model
\ No newline at end of file
diff --git a/src/peft_library/mapping.py b/src/peft_library/mapping.py
new file mode 100644
index 0000000..b6630c0
--- /dev/null
+++ b/src/peft_library/mapping.py
@@ -0,0 +1,288 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from peft.tuners.xlora.model import XLoraModel + +from peft.config import PeftConfig +from peft.mixed_model import PeftMixedModel +from peft.peft_model import ( + PeftModel, + PeftModelForCausalLM, + PeftModelForFeatureExtraction, + PeftModelForQuestionAnswering, + PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, + PeftModelForTokenClassification, +) +from peft.tuners import ( + AdaLoraConfig, + AdaLoraModel, + AdaptionPromptConfig, + BOFTConfig, + BOFTModel, + BoneConfig, + BoneModel, + CPTConfig, + CPTEmbedding, + FourierFTConfig, + FourierFTModel, + HRAConfig, + HRAModel, + IA3Config, + IA3Model, + LNTuningConfig, + LNTuningModel, + LoHaConfig, + LoHaModel, + LoKrConfig, + LoKrModel, + LoraConfig, + LoraModel, + MultitaskPromptTuningConfig, + OFTConfig, + OFTModel, + PolyConfig, + PolyModel, + PrefixTuningConfig, + PromptEncoderConfig, + PromptTuningConfig, + VBLoRAConfig, + VBLoRAModel, + VeraConfig, + VeraModel, + XLoraConfig, +) +from .tuners import MMOELoraConfigS, MMOELoraModelS, MMOELoraModel, MMOELoraConfig +from peft.tuners.tuners_utils import BaseTuner +from peft.utils import _prepare_prompt_learning_config +from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = { + "SEQ_CLS": PeftModelForSequenceClassification, + "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, + "CAUSAL_LM": PeftModelForCausalLM, + "TOKEN_CLS": PeftModelForTokenClassification, + "QUESTION_ANS": PeftModelForQuestionAnswering, + "FEATURE_EXTRACTION": PeftModelForFeatureExtraction, +} + +PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = { + "ADAPTION_PROMPT": AdaptionPromptConfig, + "PROMPT_TUNING": PromptTuningConfig, + "PREFIX_TUNING": PrefixTuningConfig, + "P_TUNING": PromptEncoderConfig, + "LORA": LoraConfig, + "LOHA": LoHaConfig, + "LORAPLUS": LoraConfig, + "LOKR": LoKrConfig, + "ADALORA": AdaLoraConfig, + "BOFT": BOFTConfig, + "IA3": IA3Config, + "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig, + "OFT": OFTConfig, + "POLY": PolyConfig, + "LN_TUNING": LNTuningConfig, + "VERA": VeraConfig, + "FOURIERFT": FourierFTConfig, + "XLORA": XLoraConfig, + "HRA": HRAConfig, + "VBLORA": VBLoRAConfig, + "CPT": CPTConfig, + "BONE": BoneConfig, + "MMOELORA": MMOELoraConfig, + "MMOELORAS": MMOELoraConfigS, +} + +PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = { + "LORA": LoraModel, + "LOHA": LoHaModel, + "LOKR": LoKrModel, + "ADALORA": AdaLoraModel, + "BOFT": BOFTModel, + "IA3": IA3Model, + "OFT": OFTModel, + "POLY": PolyModel, + "LN_TUNING": LNTuningModel, + "VERA": VeraModel, + "FOURIERFT": FourierFTModel, + "XLORA": XLoraModel, + "HRA": HRAModel, + "VBLORA": VBLoRAModel, + "CPT": CPTEmbedding, + "BONE": BoneModel, + "MMOELORA": MMOELoraModel, + "MMOELORAS": MMOELoraModelS, +} + + +def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig: + """ + Returns a Peft config object from a dictionary. 
+ + Args: + config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters. + """ + + return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) + + +def get_peft_model( + model: PreTrainedModel, + peft_config: PeftConfig, + adapter_name: str = "default", + mixed: bool = False, + autocast_adapter_dtype: bool = True, + revision: Optional[str] = None, + low_cpu_mem_usage: bool = False, +) -> PeftModel | PeftMixedModel: + """ + Returns a Peft model object from a model and a config. + + Args: + model ([`transformers.PreTrainedModel`]): + Model to be wrapped. + peft_config ([`PeftConfig`]): + Configuration object containing the parameters of the Peft model. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + mixed (`bool`, `optional`, defaults to `False`): + Whether to allow mixing different (compatible) adapter types. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. + revision (`str`, `optional`, defaults to `main`): + The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for + the base model + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as + False if you intend on training the model, unless the adapter weights will be replaced by different weights + before training starts. + """ + model_config = BaseTuner.get_model_config(model) + old_name = peft_config.base_model_name_or_path + new_name = model.__dict__.get("name_or_path", None) + peft_config.base_model_name_or_path = new_name + + if (old_name is not None) and (old_name != new_name): + warnings.warn( + f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. " + "Please ensure that the correct base model is loaded when loading this checkpoint." + ) + + if revision is not None: + if peft_config.revision is not None and peft_config.revision != revision: + warnings.warn( + f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}" + ) + peft_config.revision = revision + + if ( + (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"])) + and (peft_config.init_lora_weights == "eva") + and not low_cpu_mem_usage + ): + warnings.warn( + "lora with eva initialization used with low_cpu_mem_usage=False. " + "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization." + ) + + prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type) + if prefix and adapter_name in prefix: + warnings.warn( + f"Adapter name {adapter_name} should not be contained in the prefix {prefix}." + "This may lead to reinitialization of the adapter weights during loading." 
+ ) + + if mixed: + # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it + return PeftMixedModel(model, peft_config, adapter_name=adapter_name) + + if ( + peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() + and not peft_config.is_prompt_learning + ): + return PeftModel( + model, + peft_config, + adapter_name=adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + if peft_config.is_prompt_learning: + peft_config = _prepare_prompt_learning_config(peft_config, model_config) + return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]( + model, + peft_config, + adapter_name=adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + +def inject_adapter_in_model( + peft_config: PeftConfig, + model: torch.nn.Module, + adapter_name: str = "default", + low_cpu_mem_usage: bool = False, +) -> torch.nn.Module: + r""" + A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning + methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API + calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods. + + Args: + peft_config (`PeftConfig`): + Configuration object containing the parameters of the Peft model. + model (`torch.nn.Module`): + The input model where the adapter will be injected. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. + """ + if peft_config.is_prompt_learning or peft_config.is_adaption_prompt: + raise ValueError( + "`create_and_replace` does not support prompt learning and adaption prompt yet." + ) + + if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys(): + raise ValueError( + f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`." + ) + + tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] + + # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules. 
+ peft_model = tuner_cls( + model, + peft_config, + adapter_name=adapter_name, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + return peft_model.model diff --git a/src/peft_library/tuners/__init__.py b/src/peft_library/tuners/__init__.py new file mode 100644 index 0000000..1aa9210 --- /dev/null +++ b/src/peft_library/tuners/__init__.py @@ -0,0 +1,2 @@ +from .mmoelora.mmoelora import MMOELoraModel, MMOELoraConfig +from .mmoelora.mmoeloraS import MMOELoraModelS, MMOELoraConfigS \ No newline at end of file diff --git a/src/peft_library/tuners/mmoelora/__init__.py b/src/peft_library/tuners/mmoelora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/peft_library/tuners/mmoelora/mmoelora.py b/src/peft_library/tuners/mmoelora/mmoelora.py new file mode 100644 index 0000000..7a25a42 --- /dev/null +++ b/src/peft_library/tuners/mmoelora/mmoelora.py @@ -0,0 +1,467 @@ +# -*- encoding: utf-8 -*- +# here put the import lib +import importlib +import re +import warnings +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D + +from peft_library.utils.peft_types import PeftType + +from peft_library.utils.constants import TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING +from peft.utils.other import _freeze_adapter, _get_submodules, transpose + + +from peft.tuners.lora import ( + LoraConfig, + LoraLayer, + LoraModel, +) + + + +def is_bnb_available(): + return importlib.util.find_spec("bitsandbytes") is not None + + +@dataclass +class MMOELoraConfig(LoraConfig): + """ + This is the configuration class to store the configuration of a [`~peft.MMOELora`] + """ + + task_num: int = field(default=2, metadata={"help": "The number of tasks."}) + task_embedding_dim: int = field(default=64) + expert_num: int = field(default=4) + + def __post_init__(self): + self.peft_type = PeftType.MMOELORA + + +class MMOELoraModel(LoraModel): + """ + Create MMOELoRA (MMOE based LoRA) model from a pretrained transformers model. + """ + + def __init__(self, model, config, adapter_name): + nn.Module.__init__(self) + self.model = model + self.forward = self.model.forward + self.peft_config = config + self.add_adapter(adapter_name, self.peft_config[adapter_name]) + + def add_adapter(self, adapter_name, config=None): + if config is not None: # get the lora config + model_config = ( + self.model.config.to_dict() + if hasattr(self.model.config, "to_dict") + else self.model.config + ) + config = self._prepare_mmoelora_config(config, model_config) # load config + self.peft_config[adapter_name] = config # subsititue the original config + self._find_and_replace(adapter_name) + if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none": + raise ValueError( + "MMOELoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters." + ) + + self._mark_only_adapters_as_trainable(self.model) + if self.peft_config[adapter_name].inference_mode: + _freeze_adapter(self.model, adapter_name) + + def _find_and_replace(self, adapter_name): + """Replace the target `Linear` module with LoRA layer (Linear+LoRA)""" + lora_config = self.peft_config[adapter_name] + loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) + if loaded_in_8bit and not is_bnb_available(): + raise ImportError( + "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " + "You can install it with `pip install bitsandbytes`." 
+ ) + is_target_modules_in_base_model = False + kwargs = { + "r": lora_config.r, + "lora_alpha": lora_config.lora_alpha, + "lora_dropout": lora_config.lora_dropout, + "fan_in_fan_out": lora_config.fan_in_fan_out, + "init_lora_weights": lora_config.init_lora_weights, + "task_num": lora_config.task_num, + "task_embedding_dim": lora_config.task_embedding_dim, + "expert_num": lora_config.expert_num, + } + key_list = [ + key for key, _ in self.model.named_modules() + ] # all module in raw model + for key in key_list: + # find the corresponding modules. target module has been split into list. + if isinstance(lora_config.target_modules, str): + target_module_found = re.fullmatch(lora_config.target_modules, key) + else: + target_module_found = any( + key.endswith(target_key) + for target_key in lora_config.target_modules + ) + if target_module_found: + if not is_target_modules_in_base_model: + is_target_modules_in_base_model = True + parent, target, target_name = _get_submodules(self.model, key) + bias = target.bias is not None + if isinstance(target, MMOELoraLayer): + target.update_layer( + adapter_name, + lora_config.init_r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + else: + if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): + raise NotImplementedError + else: + if isinstance(target, torch.nn.Linear): + in_features, out_features = ( + target.in_features, + target.out_features, + ) + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = ( + lora_config.fan_in_fan_out + ) = False + elif isinstance(target, Conv1D): + in_features, out_features = ( + target.weight.ds_shape + if hasattr(target.weight, "ds_shape") + else target.weight.shape + ) + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = ( + lora_config.fan_in_fan_out + ) = True + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." + ) + new_module = MMOELoraLinear( + adapter_name, in_features, out_features, bias=bias, **kwargs + ) + + self._replace_module(parent, target_name, new_module, target) + if not is_target_modules_in_base_model: + raise ValueError( + f"Target modules {lora_config.target_modules} not found in the base model. " + f"Please check the target modules and try again." 
+ ) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + @staticmethod + def _prepare_mmoelora_config(peft_config, model_config): + if peft_config.target_modules is None: + if ( + model_config["model_type"] + not in TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING + ): + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = ( + TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING[ + model_config["model_type"] + ] + ) + if peft_config.inference_mode: + peft_config.merge_weights = True + return peft_config + + +class MMOELoraLayer(LoraLayer): + + def __init__(self, in_features: int, out_features: int, expert_num: int): + + super().__init__(in_features, out_features) + self.expert_num = expert_num + + def update_layer( + self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights + ): + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + if r > 0: + self.lora_A.update( + nn.ModuleDict( + {adapter_name: MMOELinearA(self.in_features, r, self.expert_num)} + ) + ) + self.lora_B.update( + nn.ModuleDict( + {adapter_name: MMOELinearB(r, self.out_features, self.expert_num)} + ) + ) + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + self.to(self.weight.device) + + def reset_lora_parameters(self, adapter_name): + if adapter_name in self.lora_A.keys(): + # initialize A the same way as the default for nn.Linear and B to zero + for i in range(self.expert_num): + nn.init.normal_( + self.lora_A[adapter_name].loraA[i].mlp.weight, mean=0.0, std=0.01 + ) + nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight) + + +class MMOELoraLinear(nn.Linear, MMOELoraLayer): + # Lora implemented in a dense layer + # nn.Linear is the pretrained weights in LLM, MMOELoraLayer is the designed trainable Lora + def __init__( + self, + adapter_name: str, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + **kwargs, + ): + init_lora_weights = kwargs.pop("init_lora_weights", True) + self.expert_num = kwargs.pop("expert_num", True) + self.task_num = kwargs.pop("task_num", True) + self.te_dim = kwargs.pop("task_embedding_dim", True) + + nn.Linear.__init__(self, in_features, out_features, **kwargs) + MMOELoraLayer.__init__( + self, + in_features=in_features, + out_features=out_features, + expert_num=self.expert_num, + ) + + # init the Gate network + self.lora_task_embedding = nn.ModuleDict({}) + self.lora_gate = nn.ModuleDict({}) + self.lora_task_embedding.update( + nn.ModuleDict({adapter_name: nn.Embedding(self.task_num + 1, self.te_dim)}) + ) + self.lora_gate.update( + nn.ModuleDict({adapter_name: Gate(self.te_dim, self.expert_num)}) + ) + + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + self.fan_in_fan_out = fan_in_fan_out + if fan_in_fan_out: + self.weight.data = self.weight.data.T + + nn.Linear.reset_parameters(self) + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, 
init_lora_weights) + self.active_adapter = adapter_name + + def merge(self, task_id): + if self.active_adapter not in self.lora_A.keys(): + return + if self.merged: + warnings.warn("Already merged. Nothing to do.") + return + if self.r[self.active_adapter] > 0: + expert_weight = self.lora_gate[self.active_adapter]( + self.lora_task_embedding[self.active_adapter](task_id) + ) + for i in range(self.expert_num): + lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight + lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight + self.weight.data += ( + transpose( + lora_B_weights @ lora_A_weights, + self.fan_in_fan_out, + ) + * self.scaling[self.active_adapter] + * expert_weight[..., i] + ) + self.merged = True + + def unmerge(self, task_id): + if self.active_adapter not in self.lora_A.keys(): + return + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + if self.r[self.active_adapter] > 0: + expert_weight = self.lora_gate[self.active_adapter]( + self.lora_task_embedding[self.active_adapter](task_id) + ) + for i in range(self.expert_num): + lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight + lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight + self.weight.data -= ( + transpose( + lora_B_weights @ lora_A_weights, + self.fan_in_fan_out, + ) + * self.scaling[self.active_adapter] + * expert_weight[..., i] + ) + self.merged = False + + def forward(self, x: torch.Tensor, **kwargs): + task_id = kwargs["task_id"] + previous_dtype = x.dtype + + if ( + self.active_adapter not in self.lora_A.keys() + ): # No adapter, directly use linear + return F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + if self.disable_adapters: # No adapter + if ( + self.r[self.active_adapter] > 0 and self.merged + ): # merge the adapter to linear + self.unmerge(task_id) + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + elif ( + self.r[self.active_adapter] > 0 and not self.merged + ): # general lora process + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + + x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype) + + expert_weight = self.lora_gate[self.active_adapter]( + self.lora_task_embedding[self.active_adapter](task_id) + ) + for i in range(self.expert_num): + result += ( # lora process + self.lora_B[self.active_adapter].loraB[i]( + self.lora_A[self.active_adapter].loraA[i]( + self.lora_dropout[self.active_adapter](x) + ), + ) + * self.scaling[self.active_adapter] + * expert_weight[..., i].unsqueeze(-1).unsqueeze(0) + ) + else: + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + + result = result.to(previous_dtype) + + return result + + +class MMOELinearA(nn.Module): + """MMOE based LoRA block""" + + def __init__(self, in_features, out_features, expert_num) -> None: + + super().__init__() + + self.expert_num = expert_num + self.in_features, self.out_features = in_features, out_features + self.loraA = nn.ModuleList([]) + + assert ( + self.out_features % self.expert_num == 0 + ) # lora rank should be divided by expert number + self.r = self.out_features // self.expert_num + + for _ in range(self.expert_num): + self.loraA.append(Expert(self.in_features, self.r)) + + def forward(self, x): + """input x is a vector, return output is a list""" + outputs = [] + for i in range(self.expert_num): + outputs.append(self.loraA[i](x)) + + return outputs + + +class 
MMOELinearB(nn.Module): + """MMOE based LoRA block""" + + def __init__(self, in_features, out_features, expert_num) -> None: + + super().__init__() + + self.expert_num = expert_num + self.in_features, self.out_features = in_features, out_features + self.loraB = nn.ModuleList([]) + + assert self.in_features % self.expert_num == 0 + self.r = self.in_features // self.expert_num + + for _ in range(self.expert_num): + self.loraB.append(Expert(self.r, self.out_features)) + + def forward(self, x): + """input x is a list, return output is also a list""" + outputs = [] + for i in range(self.expert_num): + outputs.append(self.loraB[i](x[i])) + + return outputs + + +class Expert(nn.Module): + + def __init__(self, in_features, out_features): + + super().__init__() + + self.in_features, self.out_features = in_features, out_features + self.mlp = nn.Linear(self.in_features, self.out_features, bias=False) + self.weight = self.mlp.weight + + def forward(self, x): + # LoRA A or B block + y = self.mlp(x) + + return y + + +class Gate(nn.Module): + + def __init__(self, input_size, expert_num): + + super().__init__() + # 使用embedding来代替线性层 + self.GateL = nn.Linear(input_size, expert_num, bias=False) + self.act = nn.Softmax(dim=1) # 第0维为batch size + + def forward(self, x): + + y = self.GateL(x) + y = self.act(y) + + return y diff --git a/src/peft_library/tuners/mmoelora/mmoeloraS.py b/src/peft_library/tuners/mmoelora/mmoeloraS.py new file mode 100644 index 0000000..bdd450c --- /dev/null +++ b/src/peft_library/tuners/mmoelora/mmoeloraS.py @@ -0,0 +1,227 @@ +# -*- encoding: utf-8 -*- +# here put the import lib +import re +import importlib +import warnings +from dataclasses import dataclass, field +from .mmoelora import MMOELoraModel, MMOELoraLinear, MMOELoraLayer +from peft.tuners.lora import LoraConfig +import torch +import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D + +# from ..utils import _get_submodules, transpose, PeftType +from peft.utils.other import _get_submodules, transpose +from peft.utils.peft_types import PeftType + + +def is_bnb_available(): + return importlib.util.find_spec("bitsandbytes") is not None + + +# TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING + + +@dataclass +class MMOELoraConfigS(LoraConfig): + """ + This is the configuration class to store the configuration of a [`~peft.MMOELora`] + """ + + task_num: int = field(default=2, metadata={"help": "The number of tasks."}) + task_embedding_dim: int = field(default=64) + expert_num: int = field(default=4) + + def __post_init__(self): + self.peft_type = PeftType.MMOELORAS + + +class MMOELoraModelS(MMOELoraModel): + + def __init__(self, model, config, adapter_name): + + super().__init__(model, config, adapter_name) + + def _find_and_replace(self, adapter_name): + """Replace the target `Linear` module with LoRA layer (Linear+LoRA)""" + lora_config = self.peft_config[adapter_name] + loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) + if loaded_in_8bit and not is_bnb_available(): + raise ImportError( + "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " + "You can install it with `pip install bitsandbytes`." 
+ ) + is_target_modules_in_base_model = False + kwargs = { + "r": lora_config.r, + "lora_alpha": lora_config.lora_alpha, + "lora_dropout": lora_config.lora_dropout, + "fan_in_fan_out": lora_config.fan_in_fan_out, + "init_lora_weights": lora_config.init_lora_weights, + "task_num": lora_config.task_num, + "task_embedding_dim": lora_config.task_embedding_dim, + "expert_num": lora_config.expert_num, + } + key_list = [ + key for key, _ in self.model.named_modules() + ] # all module in raw model + for key in key_list: + # find the corresponding modules. target module has been split into list. + if isinstance(lora_config.target_modules, str): + target_module_found = re.fullmatch(lora_config.target_modules, key) + else: + target_module_found = any( + key.endswith(target_key) + for target_key in lora_config.target_modules + ) + if target_module_found: + if not is_target_modules_in_base_model: + is_target_modules_in_base_model = True + parent, target, target_name = _get_submodules(self.model, key) + bias = target.bias is not None + if isinstance(target, MMOELoraLayer): + target.update_layer( + adapter_name, + lora_config.init_r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + else: + if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): + raise NotImplementedError + else: + if isinstance(target, torch.nn.Linear): + in_features, out_features = ( + target.in_features, + target.out_features, + ) + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = ( + lora_config.fan_in_fan_out + ) = False + elif isinstance(target, Conv1D): + in_features, out_features = ( + target.weight.ds_shape + if hasattr(target.weight, "ds_shape") + else target.weight.shape + ) + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = ( + lora_config.fan_in_fan_out + ) = True + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." + ) + new_module = MMOELoraLinearS( + adapter_name, in_features, out_features, bias=bias, **kwargs + ) + + self._replace_module(parent, target_name, new_module, target) + if not is_target_modules_in_base_model: + raise ValueError( + f"Target modules {lora_config.target_modules} not found in the base model. " + f"Please check the target modules and try again." + ) + + +class MMOELoraLinearS(MMOELoraLinear): + + def __init__( + self, + adapter_name: str, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0, + fan_in_fan_out: bool = False, + **kwargs, + ): + + super().__init__( + adapter_name, + in_features, + out_features, + r, + lora_alpha, + lora_dropout, + fan_in_fan_out, + **kwargs, + ) + + def unmerge(self, expert_weight): + if self.active_adapter not in self.lora_A.keys(): + return + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + if self.r[self.active_adapter] > 0: + for i in range(self.expert_num): + lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight + lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight + self.weight.data -= ( + transpose( + lora_B_weights @ lora_A_weights, + self.fan_in_fan_out, + ) + * self.scaling[self.active_adapter] + * expert_weight[..., i] + ) + self.merged = False + + def forward(self, x: torch.Tensor, **kwargs): + expert_weight = kwargs["task_id"] + previous_dtype = x.dtype + + if ( + self.active_adapter not in self.lora_A.keys() + ): # No adapter, directly use linear + return F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + if self.disable_adapters: # No adapter + if ( + self.r[self.active_adapter] > 0 and self.merged + ): # merge the adapter to linear + self.unmerge(expert_weight) + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + elif ( + self.r[self.active_adapter] > 0 and not self.merged + ): # general lora process + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + + x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype) + + for i in range(self.expert_num): + result += ( # lora process + self.lora_B[self.active_adapter].loraB[i]( + self.lora_A[self.active_adapter].loraA[i]( + self.lora_dropout[self.active_adapter](x) + ), + ) + * self.scaling[self.active_adapter] + * expert_weight[..., i].unsqueeze(-1).unsqueeze(0) + ) + else: + result = F.linear( + x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias + ) + + result = result.to(previous_dtype) + + return result diff --git a/src/peft_library/utils/__init__.py b/src/peft_library/utils/__init__.py new file mode 100644 index 0000000..902a4a1 --- /dev/null +++ b/src/peft_library/utils/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all + +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType +from .other import ( + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING, + CONFIG_NAME, + WEIGHTS_NAME, + _set_trainable, + bloom_model_postprocess_past_key_value, + prepare_model_for_int8_training, + shift_tokens_right, + transpose, + _get_submodules, + _set_adapter, + _freeze_adapter, + ModulesToSaveWrapper, +) +from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict diff --git a/src/peft_library/utils/config.py b/src/peft_library/utils/config.py new file mode 100644 index 0000000..7174e83 --- /dev/null +++ b/src/peft_library/utils/config.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum +import json +import os +from dataclasses import asdict, dataclass, field +from typing import Optional, Union + +from huggingface_hub import hf_hub_download +from transformers.utils import PushToHubMixin + +from .other import CONFIG_NAME + + +class PeftType(str, enum.Enum): + PROMPT_TUNING = "PROMPT_TUNING" + P_TUNING = "P_TUNING" + PREFIX_TUNING = "PREFIX_TUNING" + LORA = "LORA" + ADALORA = "ADALORA" + ADAPTION_PROMPT = "ADAPTION_PROMPT" + MMOELORAS = "MMOELORAS" + + +class TaskType(str, enum.Enum): + SEQ_CLS = "SEQ_CLS" + SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" + CAUSAL_LM = "CAUSAL_LM" + TOKEN_CLS = "TOKEN_CLS" + CAUSAL_LMS = "CAUSAL_LMS" + + +@dataclass +class PeftConfigMixin(PushToHubMixin): + r""" + This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all + PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to + push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a + directory. The method `from_pretrained` will load the configuration of your adapter model from a directory. + + Args: + peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. + """ + peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."}) + + @property + def __dict__(self): + return asdict(self) + + def to_dict(self): + return self.__dict__ + + def save_pretrained(self, save_directory, **kwargs): + r""" + This method saves the configuration of your adapter model in a directory. + + Args: + save_directory (`str`): + The directory where the configuration will be saved. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`] + method. 
+ """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + output_dict = self.__dict__ + output_path = os.path.join(save_directory, CONFIG_NAME) + + # save it + with open(output_path, "w") as writer: + writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs): + r""" + This method loads the configuration of your adapter model from a directory. + + Args: + pretrained_model_name_or_path (`str`): + The directory or the Hub repository id where the configuration is saved. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments passed along to the child class initialization. + """ + path = ( + os.path.join(pretrained_model_name_or_path, subfolder) + if subfolder is not None + else pretrained_model_name_or_path + ) + if os.path.isfile(os.path.join(path, CONFIG_NAME)): + config_file = os.path.join(path, CONFIG_NAME) + else: + try: + config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder) + except Exception: + raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") + + loaded_attributes = cls.from_json_file(config_file) + + config = cls(**kwargs) + + for key, value in loaded_attributes.items(): + if hasattr(config, key): + setattr(config, key, value) + + return config + + @classmethod + def from_json_file(cls, path_json_file, **kwargs): + r""" + Loads a configuration file from a json file. + + Args: + path_json_file (`str`): + The path to the json file. + """ + with open(path_json_file, "r") as file: + json_object = json.load(file) + + return json_object + + +@dataclass +class PeftConfig(PeftConfigMixin): + """ + This is the base configuration class to store the configuration of a [`PeftModel`]. + + Args: + peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. + task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. + inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode. + """ + + base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."}) + peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"}) + task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"}) + inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"}) + + +@dataclass +class PromptLearningConfig(PeftConfig): + """ + This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or + [`PromptTuning`]. + + Args: + num_virtual_tokens (`int`): The number of virtual tokens to use. + token_dim (`int`): The hidden embedding dimension of the base transformer model. + num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model. + num_attention_heads (`int`): The number of attention heads in the base transformer model. + num_layers (`int`): The number of layers in the base transformer model. 
+ """ + + num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"}) + token_dim: int = field( + default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"} + ) + num_transformer_submodules: Optional[int] = field( + default=None, metadata={"help": "Number of transformer submodules"} + ) + num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"}) + num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"}) diff --git a/src/peft_library/utils/constants.py b/src/peft_library/utils/constants.py new file mode 100644 index 0000000..2377a9e --- /dev/null +++ b/src/peft_library/utils/constants.py @@ -0,0 +1,4 @@ +from peft.utils.constants import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING + +TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING +TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING \ No newline at end of file diff --git a/src/peft_library/utils/other.py b/src/peft_library/utils/other.py new file mode 100644 index 0000000..4bccbf3 --- /dev/null +++ b/src/peft_library/utils/other.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import torch + + +# needed for prefix-tuning of bloom model +def bloom_model_postprocess_past_key_value(past_key_values): + past_key_values = torch.cat(past_key_values) + total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape + keys = past_key_values[: total_layers // 2] + keys = keys.transpose(2, 3).reshape( + total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens + ) + values = past_key_values[total_layers // 2 :] + values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim) + + return tuple(zip(keys, values)) + + +def prepare_model_for_int8_training( + model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"] +): + r""" + This method wraps the entire protocol for preparing a model before running a training. 
This includes: + 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm + head to fp32 + + Args: + model, (`transformers.PreTrainedModel`): + The loaded model from `transformers` + """ + loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False) + + for name, param in model.named_parameters(): + # freeze base model's layers + param.requires_grad = False + + if loaded_in_8bit: + # cast layer norm in fp32 for stability for 8bit models + if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): + param.data = param.data.to(torch.float32) + + if loaded_in_8bit and use_gradient_checkpointing: + # For backward compatibility + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # enable gradient checkpointing for memory efficiency + model.gradient_checkpointing_enable() + + if hasattr(model, output_embedding_layer_name): + output_embedding_layer = getattr(model, output_embedding_layer_name) + input_dtype = output_embedding_layer.weight.dtype + + class CastOutputToFloat(torch.nn.Sequential): + r""" + Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted + in fp32 + + """ + + def forward(self, x): + return super().forward(x.to(input_dtype)).to(torch.float32) + + setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) + + return model + + +# copied from transformers.models.bart.modeling_bart +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids + pad_token_id (`int`): The id of the `padding` token. + decoder_start_token_id (`int`): The id of the `start` token. 
+ """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class ModulesToSaveWrapper(torch.nn.Module): + def __init__(self, module_to_save, adapter_name): + super().__init__() + self.original_module = module_to_save + self.modules_to_save = torch.nn.ModuleDict({}) + self.update(adapter_name) + self.active_adapter = adapter_name + + def update(self, adapter_name): + self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) + + def forward(self, *args, **kwargs): + if self.active_adapter not in self.modules_to_save: + return self.original_module(*args, **kwargs) + return self.modules_to_save[self.active_adapter](*args, **kwargs) + + +def _get_submodules(model, key): + parent = model.get_submodule(".".join(key.split(".")[:-1])) + target_name = key.split(".")[-1] + target = model.get_submodule(key) + return parent, target, target_name + + +def _freeze_adapter(model, adapter_name): + for n, p in model.named_parameters(): + if adapter_name in n: + p.requires_grad = False + + +def _set_trainable(model, adapter_name): + key_list = [key for key, _ in model.named_modules()] + for key in key_list: + target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save) + if target_module_found: + parent, target, target_name = _get_submodules(model, key) + if isinstance(target, ModulesToSaveWrapper): + target.update(adapter_name) + else: + for param in target.parameters(): + param.requires_grad = True + setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name)) + + +def _set_adapter(model, adapter_name): + for module in model.modules(): + if isinstance(module, ModulesToSaveWrapper): + module.active_adapter = adapter_name + + +def fsdp_auto_wrap_policy(model): + import functools + import os + + from accelerate import FullyShardedDataParallelPlugin + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=( + PrefixEncoder, + PromptEncoder, + PromptEmbedding, + FullyShardedDataParallelPlugin.get_module_class_from_name( + model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "") + ), + ), + ) + + auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) + return auto_wrap_policy + + +def transpose(weight, fan_in_fan_out): + return weight.T if fan_in_fan_out else weight + + +TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { + "t5": ["q", "v"], + "mt5": ["q", "v"], + "bart": ["q_proj", "v_proj"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "blip-2": ["q", "v", "q_proj", "v_proj"], + "opt": ["q_proj", "v_proj"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": 
["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "value"], + "xlm-roberta": ["query", "value"], + "electra": ["query", "value"], + "deberta-v2": ["query_proj", "value_proj"], + "deberta": ["in_proj"], + "layoutlm": ["query", "value"], + "llama": ["q_proj", "v_proj"], + "chatglm": ["query_key_value"], +} + +TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = { + "t5": ["q", "k", "v", "o", "wi", "wo"], + "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"], + "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + # "gpt2": ["c_attn"], + # "bloom": ["query_key_value"], + "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], + # "gptj": ["q_proj", "v_proj"], + # "gpt_neox": ["query_key_value"], + # "gpt_neo": ["q_proj", "v_proj"], + # "bert": ["query", "value"], + "roberta": ["query", "key", "value", "dense"], + # "xlm-roberta": ["query", "value"], + # "electra": ["query", "value"], + "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"], + # "deberta": ["in_proj"], + # "layoutlm": ["query", "value"], +} + +TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING +TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING +TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = { + "bloom": bloom_model_postprocess_past_key_value, +} + +WEIGHTS_NAME = "adapter_model.bin" +CONFIG_NAME = "adapter_config.json" diff --git a/src/peft_library/utils/peft_types.py b/src/peft_library/utils/peft_types.py new file mode 100644 index 0000000..2e5b717 --- /dev/null +++ b/src/peft_library/utils/peft_types.py @@ -0,0 +1,51 @@ +import enum + +class PeftType(str, enum.Enum): + """ + Enum class for the different types of adapters in PEFT. + + Supported PEFT types: + - PROMPT_TUNING + - MULTITASK_PROMPT_TUNING + - P_TUNING + - PREFIX_TUNING + - LORA + - ADALORA + - BOFT + - ADAPTION_PROMPT + - IA3 + - LOHA + - LOKR + - OFT + - XLORA + - POLY + - LN_TUNING + - VERA + - FOURIERFT + - HRA + - BONE + """ + + PROMPT_TUNING = "PROMPT_TUNING" + MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING" + P_TUNING = "P_TUNING" + PREFIX_TUNING = "PREFIX_TUNING" + LORA = "LORA" + ADALORA = "ADALORA" + BOFT = "BOFT" + ADAPTION_PROMPT = "ADAPTION_PROMPT" + IA3 = "IA3" + LOHA = "LOHA" + LOKR = "LOKR" + OFT = "OFT" + POLY = "POLY" + LN_TUNING = "LN_TUNING" + VERA = "VERA" + FOURIERFT = "FOURIERFT" + XLORA = "XLORA" + HRA = "HRA" + VBLORA = "VBLORA" + CPT = "CPT" + BONE = "BONE" + MMOELORAS = "MMOELORAS" + MMOELORA = "MMOELORA" diff --git a/src/peft_library/utils/save_and_load.py b/src/peft_library/utils/save_and_load.py new file mode 100644 index 0000000..fc914a0 --- /dev/null +++ b/src/peft_library/utils/save_and_load.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .config import PeftType, PromptLearningConfig + + +def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): + """ + Get the state dict of the Peft model. + + Args: + model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, + the model should be the underlying model/unwrapped model (i.e. model.module). + state_dict (`dict`, *optional*, defaults to `None`): + The state dict of the model. If not provided, the state dict of the model + will be used. + """ + config = model.peft_config[adapter_name] + if state_dict is None: + state_dict = model.state_dict() + if config.peft_type in (PeftType.LORA, PeftType.ADALORA, + PeftType.MMOELORAS): + # to_return = lora_state_dict(model, bias=model.peft_config.bias) + # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` + # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP + bias = config.bias + if bias == "none": # filter out all lora parameters + to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} + elif bias == "all": + to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + for k in state_dict: + if "lora_" in k: + to_return[k] = state_dict[k] + bias_name = k.split("lora_")[0] + "bias" + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} + + if config.peft_type == PeftType.ADALORA: + rank_pattern = config.rank_pattern + if rank_pattern is not None: + rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()} + config.rank_pattern = rank_pattern + to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name) + + elif config.peft_type == PeftType.ADAPTION_PROMPT: + to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} + elif isinstance(config, PromptLearningConfig): + to_return = {} + if config.inference_mode: + prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight + else: + prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) + to_return["prompt_embeddings"] = prompt_embeddings + else: + raise NotImplementedError + if model.modules_to_save is not None: + for key, value in state_dict.items(): + if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): + to_return[key.replace("modules_to_save.", "")] = value + + to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} + return to_return + + +def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"): + """ + Set the state dict of the Peft model. + + Args: + model ([`PeftModel`]): The Peft model. + peft_model_state_dict (`dict`): The state dict of the Peft model. 
+ """ + config = model.peft_config[adapter_name] + state_dict = {} + if model.modules_to_save is not None: + for key, value in peft_model_state_dict.items(): + if any(module_name in key for module_name in model.modules_to_save): + for module_name in model.modules_to_save: + if module_name in key: + key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}") + break + state_dict[key] = value + else: + state_dict = peft_model_state_dict + + if config.peft_type in (PeftType.LORA, PeftType.ADALORA, + PeftType.MMOELORAS): + peft_model_state_dict = {} + for k, v in state_dict.items(): + if "lora_" in k: + suffix = k.split("lora_")[1] + if "." in suffix: + suffix_to_replace = ".".join(suffix.split(".")[1:]) + k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") + else: + k = f"{k}.{adapter_name}" + peft_model_state_dict[k] = v + else: + peft_model_state_dict[k] = v + if config.peft_type == PeftType.ADALORA: + rank_pattern = config.rank_pattern + if rank_pattern is not None: + model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) + elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT: + peft_model_state_dict = state_dict + else: + raise NotImplementedError + + model.load_state_dict(peft_model_state_dict, strict=False) + if isinstance(config, PromptLearningConfig): + model.prompt_encoder[adapter_name].embedding.load_state_dict( + {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True + ) diff --git a/src/train.py b/src/train.py index 4dd9c27..a36df28 100644 --- a/src/train.py +++ b/src/train.py @@ -1,32 +1,46 @@ import torch -from datasets_library.factory import get_dataset +from dataset_library.factory import get_dataset from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments from trl import ( ModelConfig, TrlParser, get_kbit_device_map, - get_peft_config, + # get_peft_config, get_quantization_config, ) -from peft import get_peft_model +from peft_library import get_peft_model, get_peft_config from utils.trainer import ContinualTrainer -from utils.args import ContinualScriptArguments +from utils.args import ContinualScriptArguments, ContinualModelConfig if __name__ == "__main__": - parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig)) + parser = TrlParser( + (ContinualScriptArguments, TrainingArguments, ContinualModelConfig) + ) script_args, training_args, model_args = parser.parse_args_and_config() # for type hint if 0 == 1: script_args = ContinualScriptArguments() training_args = TrainingArguments() - model_args = ModelConfig() + model_args = ContinualModelConfig() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} + # peft_config = get_peft_config(dict(**vars(model_args))) + if model_args.peft_type == "MMOELora": + from peft_library.tuners import MMOELoraConfig + + peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules) + elif model_args.peft_type == "LORA": + from peft.tuners.lora import LoraConfig + + peft_config = LoraConfig(target_modules=model_args.lora_target_modules) + else: + peft_config = None + torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] @@ -62,9 +76,6 @@ if __name__ == "__main__": collate_fn_for_train = partial(collate_fn_for_train, processor=processor) collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor) - peft_config = 
get_peft_config(model_args) - model = get_peft_model(model, peft_config) - ################ # Dataset ################ @@ -73,8 +84,10 @@ if __name__ == "__main__": accelerator = create_accelerator_and_postprocess(training_args) - if accelerator.is_local_main_process: - model.print_trainable_parameters() + if peft_config is not None: + model = get_peft_model(model, peft_config) + if accelerator.is_local_main_process: + model.print_trainable_parameters() for dataset_name in script_args.dataset_name: dataset = get_dataset(dataset_name) diff --git a/src/train.sh b/src/train.sh index 4834c50..b81ed77 100755 --- a/src/train.sh +++ b/src/train.sh @@ -1,8 +1,9 @@ #!/bin/bash -accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py \ +accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \ --dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \ --use_peft \ + --peft_type LORA \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ --lora_target_modules q_proj v_proj \ --per_device_train_batch_size 1 \ diff --git a/src/utils/args.py b/src/utils/args.py index e2298a6..55668cb 100644 --- a/src/utils/args.py +++ b/src/utils/args.py @@ -1,17 +1,21 @@ from dataclasses import dataclass, field from typing import Optional +from trl import ScriptArguments, ModelConfig +from transformers import TrainingArguments @dataclass -class ContinualScriptArguments: +class ContinualScriptArguments(ScriptArguments): """Script arguments for continual learning.""" dataset_name: list[str] = field( default_factory=lambda: ["cifar10", "cifar100", "imagenet2012"] ) - dataset_config: Optional[str] = None - dataset_train_split: str = "train" - dataset_test_split: str = "test" dataset_generation_split: str = "generation" - gradient_checkpointing_use_reentrant: bool = False - ignore_bias_buffers: bool = False + + +@dataclass +class ContinualModelConfig(ModelConfig): + """Model configuration for continual learning.""" + + peft_type: Optional[str] = None \ No newline at end of file
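
A minimal usage sketch of the ModulesToSaveWrapper introduced in the vendored utils above. The import path is assumed (the hunk does not show its file name, but the patch layout suggests peft_library/utils/other.py), and the small Linear head is a placeholder for a real task head:

    import torch
    from peft_library.utils.other import ModulesToSaveWrapper  # path assumed from this patch's layout

    head = torch.nn.Linear(16, 2)                     # placeholder task head
    wrapped = ModulesToSaveWrapper(head, adapter_name="task_a")
    wrapped.update("task_b")                          # deep-copies a fresh trainable head for a second task
    wrapped.active_adapter = "task_b"                 # route forward() to that copy

    out = wrapped(torch.randn(1, 16))                 # runs modules_to_save["task_b"]

Each adapter name keeps its own trainable copy of the wrapped module, while the original module remains available as a fallback whenever the active adapter has no entry.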
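The per-architecture target-module tables added above can be used to pick LoRA injection points from a model's config.model_type. A short illustrative lookup, again assuming the mapping lives in peft_library.utils.other:

    from peft_library.utils.other import (  # path assumed
        TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    )

    model_type = "llama"  # e.g. model.config.model_type
    target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model_type)
    print(target_modules)  # ['q_proj', 'v_proj']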
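A hedged sketch of the adapter save/load round trip that save_and_load.py enables. The tiny Sequential is a placeholder for the Qwen2-VL model used in train.py, the LoRA model is built with upstream peft.get_peft_model purely for the demo, and the vendored helpers are assumed to behave like their upstream counterparts on a standard LoRA-wrapped model:

    import torch
    from peft import LoraConfig, get_peft_model  # upstream PEFT, used only to build a demo model
    from peft_library.utils.save_and_load import (
        get_peft_model_state_dict,
        set_peft_model_state_dict,
    )

    base = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
    model = get_peft_model(base, LoraConfig(target_modules=["0", "1"], r=4))

    adapter_state = get_peft_model_state_dict(model)   # keeps only lora_* tensors, adapter name stripped
    torch.save(adapter_state, "adapter_model.bin")      # file name matches WEIGHTS_NAME

    restored = torch.load("adapter_model.bin", map_location="cpu")
    set_peft_model_state_dict(model, restored)           # adapter name re-inserted, loaded with strict=False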
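Finally, a sketch of how the new ContinualModelConfig.peft_type field selects a PEFT configuration, mirroring the branch added to train.py. The helper name and the target-module list are illustrative, and MMOELoraConfig is assumed to accept target_modules as shown in the diff:

    from typing import Optional

    from peft.tuners.lora import LoraConfig


    def build_peft_config(peft_type: Optional[str], target_modules: list[str]):
        if peft_type == "MMOELora":
            from peft_library.tuners import MMOELoraConfig
            return MMOELoraConfig(target_modules=target_modules)
        if peft_type == "LORA":
            return LoraConfig(target_modules=target_modules)
        return None  # no PEFT config: train.py falls back to full fine-tuning


    peft_config = build_peft_config("LORA", ["q_proj", "v_proj"])  # as passed via --peft_type in train.sh

Note that train.py compares against the mixed-case string "MMOELora", so the value passed to --peft_type must use that exact spelling to select the MMoE variant, whereas plain LoRA is selected with the upper-case "LORA" used in train.sh.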