Add the PEFT library initialization file, update dataset import paths, modify the training script to support new PEFT types and configurations, add a continual-learning model configuration class, add a PEFT type enum, and update the evaluation and training logic for the new structure
This commit is contained in:
parent aef0f6834e
commit 2cd1bb4993
@@ -84,7 +84,7 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):

 if __name__ == "__main__":
     from transformers import Qwen2VLProcessor
-    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
+    from dataset_library.OCRVQADataset import OCRVQADatasetForGeneration

     processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
     dataset = OCRVQADatasetForGeneration(
@@ -1,5 +1,5 @@
 import torch
-from datasets_library.factory import get_dataset
+from dataset_library.factory import get_dataset
 from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments

 from trl import (
1  src/peft_library/__init__.py  Normal file
@@ -0,0 +1 @@
from .mapping import get_peft_config, get_peft_model
288  src/peft_library/mapping.py  Normal file
@@ -0,0 +1,288 @@
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from peft.tuners.xlora.model import XLoraModel
|
||||
|
||||
from peft.config import PeftConfig
|
||||
from peft.mixed_model import PeftMixedModel
|
||||
from peft.peft_model import (
|
||||
PeftModel,
|
||||
PeftModelForCausalLM,
|
||||
PeftModelForFeatureExtraction,
|
||||
PeftModelForQuestionAnswering,
|
||||
PeftModelForSeq2SeqLM,
|
||||
PeftModelForSequenceClassification,
|
||||
PeftModelForTokenClassification,
|
||||
)
|
||||
from peft.tuners import (
|
||||
AdaLoraConfig,
|
||||
AdaLoraModel,
|
||||
AdaptionPromptConfig,
|
||||
BOFTConfig,
|
||||
BOFTModel,
|
||||
BoneConfig,
|
||||
BoneModel,
|
||||
CPTConfig,
|
||||
CPTEmbedding,
|
||||
FourierFTConfig,
|
||||
FourierFTModel,
|
||||
HRAConfig,
|
||||
HRAModel,
|
||||
IA3Config,
|
||||
IA3Model,
|
||||
LNTuningConfig,
|
||||
LNTuningModel,
|
||||
LoHaConfig,
|
||||
LoHaModel,
|
||||
LoKrConfig,
|
||||
LoKrModel,
|
||||
LoraConfig,
|
||||
LoraModel,
|
||||
MultitaskPromptTuningConfig,
|
||||
OFTConfig,
|
||||
OFTModel,
|
||||
PolyConfig,
|
||||
PolyModel,
|
||||
PrefixTuningConfig,
|
||||
PromptEncoderConfig,
|
||||
PromptTuningConfig,
|
||||
VBLoRAConfig,
|
||||
VBLoRAModel,
|
||||
VeraConfig,
|
||||
VeraModel,
|
||||
XLoraConfig,
|
||||
)
|
||||
from .tuners import MMOELoraConfigS, MMOELoraModelS, MMOELoraModel, MMOELoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTuner
|
||||
from peft.utils import _prepare_prompt_learning_config
|
||||
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = {
|
||||
"SEQ_CLS": PeftModelForSequenceClassification,
|
||||
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
|
||||
"CAUSAL_LM": PeftModelForCausalLM,
|
||||
"TOKEN_CLS": PeftModelForTokenClassification,
|
||||
"QUESTION_ANS": PeftModelForQuestionAnswering,
|
||||
"FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
|
||||
}
|
||||
|
||||
PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = {
|
||||
"ADAPTION_PROMPT": AdaptionPromptConfig,
|
||||
"PROMPT_TUNING": PromptTuningConfig,
|
||||
"PREFIX_TUNING": PrefixTuningConfig,
|
||||
"P_TUNING": PromptEncoderConfig,
|
||||
"LORA": LoraConfig,
|
||||
"LOHA": LoHaConfig,
|
||||
"LORAPLUS": LoraConfig,
|
||||
"LOKR": LoKrConfig,
|
||||
"ADALORA": AdaLoraConfig,
|
||||
"BOFT": BOFTConfig,
|
||||
"IA3": IA3Config,
|
||||
"MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
|
||||
"OFT": OFTConfig,
|
||||
"POLY": PolyConfig,
|
||||
"LN_TUNING": LNTuningConfig,
|
||||
"VERA": VeraConfig,
|
||||
"FOURIERFT": FourierFTConfig,
|
||||
"XLORA": XLoraConfig,
|
||||
"HRA": HRAConfig,
|
||||
"VBLORA": VBLoRAConfig,
|
||||
"CPT": CPTConfig,
|
||||
"BONE": BoneConfig,
|
||||
"MMOELORA": MMOELoraConfig,
|
||||
"MMOELORAS": MMOELoraConfigS,
|
||||
}
|
||||
|
||||
PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = {
|
||||
"LORA": LoraModel,
|
||||
"LOHA": LoHaModel,
|
||||
"LOKR": LoKrModel,
|
||||
"ADALORA": AdaLoraModel,
|
||||
"BOFT": BOFTModel,
|
||||
"IA3": IA3Model,
|
||||
"OFT": OFTModel,
|
||||
"POLY": PolyModel,
|
||||
"LN_TUNING": LNTuningModel,
|
||||
"VERA": VeraModel,
|
||||
"FOURIERFT": FourierFTModel,
|
||||
"XLORA": XLoraModel,
|
||||
"HRA": HRAModel,
|
||||
"VBLORA": VBLoRAModel,
|
||||
"CPT": CPTEmbedding,
|
||||
"BONE": BoneModel,
|
||||
"MMOELORA": MMOELoraModel,
|
||||
"MMOELORAS": MMOELoraModelS,
|
||||
}
|
||||
|
||||
|
||||
def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
|
||||
"""
|
||||
Returns a Peft config object from a dictionary.
|
||||
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
|
||||
"""
|
||||
|
||||
return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
|
||||
|
||||
|
||||
def get_peft_model(
|
||||
model: PreTrainedModel,
|
||||
peft_config: PeftConfig,
|
||||
adapter_name: str = "default",
|
||||
mixed: bool = False,
|
||||
autocast_adapter_dtype: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> PeftModel | PeftMixedModel:
|
||||
"""
|
||||
Returns a Peft model object from a model and a config.
|
||||
|
||||
Args:
|
||||
model ([`transformers.PreTrainedModel`]):
|
||||
Model to be wrapped.
|
||||
peft_config ([`PeftConfig`]):
|
||||
Configuration object containing the parameters of the Peft model.
|
||||
adapter_name (`str`, `optional`, defaults to `"default"`):
|
||||
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
|
||||
mixed (`bool`, `optional`, defaults to `False`):
|
||||
Whether to allow mixing different (compatible) adapter types.
|
||||
autocast_adapter_dtype (`bool`, *optional*):
|
||||
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
|
||||
using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect
|
||||
select PEFT tuners.
|
||||
revision (`str`, `optional`, defaults to `main`):
|
||||
The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
|
||||
the base model
|
||||
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
|
||||
Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as
|
||||
False if you intend on training the model, unless the adapter weights will be replaced by different weights
|
||||
before training starts.
|
||||
"""
|
||||
model_config = BaseTuner.get_model_config(model)
|
||||
old_name = peft_config.base_model_name_or_path
|
||||
new_name = model.__dict__.get("name_or_path", None)
|
||||
peft_config.base_model_name_or_path = new_name
|
||||
|
||||
if (old_name is not None) and (old_name != new_name):
|
||||
warnings.warn(
|
||||
f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
|
||||
"Please ensure that the correct base model is loaded when loading this checkpoint."
|
||||
)
|
||||
|
||||
if revision is not None:
|
||||
if peft_config.revision is not None and peft_config.revision != revision:
|
||||
warnings.warn(
|
||||
f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}"
|
||||
)
|
||||
peft_config.revision = revision
|
||||
|
||||
if (
|
||||
(isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"]))
|
||||
and (peft_config.init_lora_weights == "eva")
|
||||
and not low_cpu_mem_usage
|
||||
):
|
||||
warnings.warn(
|
||||
"lora with eva initialization used with low_cpu_mem_usage=False. "
|
||||
"Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
|
||||
)
|
||||
|
||||
prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
|
||||
if prefix and adapter_name in prefix:
|
||||
warnings.warn(
|
||||
f"Adapter name {adapter_name} should not be contained in the prefix {prefix}."
|
||||
"This may lead to reinitialization of the adapter weights during loading."
|
||||
)
|
||||
|
||||
if mixed:
|
||||
# note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
|
||||
return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
|
||||
|
||||
if (
|
||||
peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys()
|
||||
and not peft_config.is_prompt_learning
|
||||
):
|
||||
return PeftModel(
|
||||
model,
|
||||
peft_config,
|
||||
adapter_name=adapter_name,
|
||||
autocast_adapter_dtype=autocast_adapter_dtype,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
if peft_config.is_prompt_learning:
|
||||
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
|
||||
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
|
||||
model,
|
||||
peft_config,
|
||||
adapter_name=adapter_name,
|
||||
autocast_adapter_dtype=autocast_adapter_dtype,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
|
||||
def inject_adapter_in_model(
|
||||
peft_config: PeftConfig,
|
||||
model: torch.nn.Module,
|
||||
adapter_name: str = "default",
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> torch.nn.Module:
|
||||
r"""
|
||||
A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
|
||||
methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API
|
||||
calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods.
|
||||
|
||||
Args:
|
||||
peft_config (`PeftConfig`):
|
||||
Configuration object containing the parameters of the Peft model.
|
||||
model (`torch.nn.Module`):
|
||||
The input model where the adapter will be injected.
|
||||
adapter_name (`str`, `optional`, defaults to `"default"`):
|
||||
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
|
||||
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
|
||||
Create empty adapter weights on meta device. Useful to speed up the loading process.
|
||||
"""
|
||||
if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
|
||||
raise ValueError(
|
||||
"`create_and_replace` does not support prompt learning and adaption prompt yet."
|
||||
)
|
||||
|
||||
if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys():
|
||||
raise ValueError(
|
||||
f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`."
|
||||
)
|
||||
|
||||
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
|
||||
|
||||
# By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
|
||||
peft_model = tuner_cls(
|
||||
model,
|
||||
peft_config,
|
||||
adapter_name=adapter_name,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
return peft_model.model
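For orientation, here is a minimal usage sketch of the two entry points defined above; the checkpoint name and hyperparameters are placeholders rather than values from this repository, and the stock LORA type is used because the end-to-end wiring of the newly registered MMOELORA/MMOELORAS entries depends on code outside this diff:

# Hedged sketch: exercise get_peft_config / get_peft_model with placeholder values.
from transformers import AutoModelForCausalLM
from peft_library import get_peft_config, get_peft_model

config = get_peft_config(
    {
        "peft_type": "LORA",  # resolved through PEFT_TYPE_TO_CONFIG_MAPPING above
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["q_proj", "v_proj"],
    }
)
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # placeholder checkpoint
peft_model = get_peft_model(base_model, config)  # -> PeftModelForCausalLM
peft_model.print_trainable_parameters()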
2  src/peft_library/tuners/__init__.py  Normal file
@@ -0,0 +1,2 @@
from .mmoelora.mmoelora import MMOELoraModel, MMOELoraConfig
from .mmoelora.mmoeloraS import MMOELoraModelS, MMOELoraConfigS
0  src/peft_library/tuners/mmoelora/__init__.py  Normal file
467  src/peft_library/tuners/mmoelora/mmoelora.py  Normal file
@@ -0,0 +1,467 @@
# -*- encoding: utf-8 -*-
# here put the import lib
import importlib
import re
import warnings
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from peft_library.utils.peft_types import PeftType

from peft_library.utils.constants import TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING
from peft.utils.other import _freeze_adapter, _get_submodules, transpose

from peft.tuners.lora import (
    LoraConfig,
    LoraLayer,
    LoraModel,
)

try:  # optional dependency; imported up front so the 8-bit branch below can reference `bnb`
    import bitsandbytes as bnb
except ImportError:
    bnb = None


def is_bnb_available():
    return importlib.util.find_spec("bitsandbytes") is not None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMOELoraConfig(LoraConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`~peft.MMOELora`]
|
||||
"""
|
||||
|
||||
task_num: int = field(default=2, metadata={"help": "The number of tasks."})
|
||||
task_embedding_dim: int = field(default=64)
|
||||
expert_num: int = field(default=4)
|
||||
|
||||
def __post_init__(self):
|
||||
self.peft_type = PeftType.MMOELORA
|
||||
|
||||
|
||||
class MMOELoraModel(LoraModel):
|
||||
"""
|
||||
Create MMOELoRA (MMOE based LoRA) model from a pretrained transformers model.
|
||||
"""
|
||||
|
||||
def __init__(self, model, config, adapter_name):
|
||||
nn.Module.__init__(self)
|
||||
self.model = model
|
||||
self.forward = self.model.forward
|
||||
self.peft_config = config
|
||||
self.add_adapter(adapter_name, self.peft_config[adapter_name])
|
||||
|
||||
def add_adapter(self, adapter_name, config=None):
|
||||
if config is not None: # get the lora config
|
||||
model_config = (
|
||||
self.model.config.to_dict()
|
||||
if hasattr(self.model.config, "to_dict")
|
||||
else self.model.config
|
||||
)
|
||||
config = self._prepare_mmoelora_config(config, model_config) # load config
self.peft_config[adapter_name] = config  # substitute the original config
self._find_and_replace(adapter_name)
|
||||
if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
|
||||
raise ValueError(
|
||||
"MMOELoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
|
||||
)
|
||||
|
||||
self._mark_only_adapters_as_trainable(self.model)
|
||||
if self.peft_config[adapter_name].inference_mode:
|
||||
_freeze_adapter(self.model, adapter_name)
|
||||
|
||||
def _find_and_replace(self, adapter_name):
|
||||
"""Replace the target `Linear` module with LoRA layer (Linear+LoRA)"""
|
||||
lora_config = self.peft_config[adapter_name]
|
||||
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
|
||||
if loaded_in_8bit and not is_bnb_available():
|
||||
raise ImportError(
|
||||
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
is_target_modules_in_base_model = False
|
||||
kwargs = {
|
||||
"r": lora_config.r,
|
||||
"lora_alpha": lora_config.lora_alpha,
|
||||
"lora_dropout": lora_config.lora_dropout,
|
||||
"fan_in_fan_out": lora_config.fan_in_fan_out,
|
||||
"init_lora_weights": lora_config.init_lora_weights,
|
||||
"task_num": lora_config.task_num,
|
||||
"task_embedding_dim": lora_config.task_embedding_dim,
|
||||
"expert_num": lora_config.expert_num,
|
||||
}
|
||||
key_list = [
|
||||
key for key, _ in self.model.named_modules()
|
||||
] # all module in raw model
|
||||
for key in key_list:
|
||||
# find the corresponding modules. target module has been split into list.
|
||||
if isinstance(lora_config.target_modules, str):
|
||||
target_module_found = re.fullmatch(lora_config.target_modules, key)
|
||||
else:
|
||||
target_module_found = any(
|
||||
key.endswith(target_key)
|
||||
for target_key in lora_config.target_modules
|
||||
)
|
||||
if target_module_found:
|
||||
if not is_target_modules_in_base_model:
|
||||
is_target_modules_in_base_model = True
|
||||
parent, target, target_name = _get_submodules(self.model, key)
|
||||
bias = target.bias is not None
|
||||
if isinstance(target, MMOELoraLayer):
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
lora_config.init_r,
|
||||
lora_config.lora_alpha,
|
||||
lora_config.lora_dropout,
|
||||
lora_config.init_lora_weights,
|
||||
)
|
||||
else:
|
||||
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
if isinstance(target, torch.nn.Linear):
|
||||
in_features, out_features = (
|
||||
target.in_features,
|
||||
target.out_features,
|
||||
)
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = (
|
||||
lora_config.fan_in_fan_out
|
||||
) = False
|
||||
elif isinstance(target, Conv1D):
|
||||
in_features, out_features = (
|
||||
target.weight.ds_shape
|
||||
if hasattr(target.weight, "ds_shape")
|
||||
else target.weight.shape
|
||||
)
|
||||
if not kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
|
||||
"Setting fan_in_fan_out to True."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = (
|
||||
lora_config.fan_in_fan_out
|
||||
) = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Target module {target} is not supported. "
|
||||
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
|
||||
)
|
||||
new_module = MMOELoraLinear(
|
||||
adapter_name, in_features, out_features, bias=bias, **kwargs
|
||||
)
|
||||
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
if not is_target_modules_in_base_model:
|
||||
raise ValueError(
|
||||
f"Target modules {lora_config.target_modules} not found in the base model. "
|
||||
f"Please check the target modules and try again."
|
||||
)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Forward missing attributes to the wrapped module."""
|
||||
try:
|
||||
return super().__getattr__(name) # defer to nn.Module's logic
|
||||
except AttributeError:
|
||||
return getattr(self.model, name)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_mmoelora_config(peft_config, model_config):
|
||||
if peft_config.target_modules is None:
|
||||
if (
|
||||
model_config["model_type"]
|
||||
not in TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING
|
||||
):
|
||||
raise ValueError("Please specify `target_modules` in `peft_config`")
|
||||
peft_config.target_modules = (
|
||||
TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING[
|
||||
model_config["model_type"]
|
||||
]
|
||||
)
|
||||
if peft_config.inference_mode:
|
||||
peft_config.merge_weights = True
|
||||
return peft_config
|
||||
|
||||
|
||||
class MMOELoraLayer(LoraLayer):
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, expert_num: int):
|
||||
|
||||
super().__init__(in_features, out_features)
|
||||
self.expert_num = expert_num
|
||||
|
||||
def update_layer(
|
||||
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights
|
||||
):
|
||||
self.r[adapter_name] = r
|
||||
self.lora_alpha[adapter_name] = lora_alpha
|
||||
if lora_dropout > 0.0:
|
||||
lora_dropout_layer = nn.Dropout(p=lora_dropout)
|
||||
else:
|
||||
lora_dropout_layer = nn.Identity()
|
||||
|
||||
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A.update(
|
||||
nn.ModuleDict(
|
||||
{adapter_name: MMOELinearA(self.in_features, r, self.expert_num)}
|
||||
)
|
||||
)
|
||||
self.lora_B.update(
|
||||
nn.ModuleDict(
|
||||
{adapter_name: MMOELinearB(r, self.out_features, self.expert_num)}
|
||||
)
|
||||
)
|
||||
self.scaling[adapter_name] = lora_alpha / r
|
||||
if init_lora_weights:
|
||||
self.reset_lora_parameters(adapter_name)
|
||||
self.to(self.weight.device)
|
||||
|
||||
def reset_lora_parameters(self, adapter_name):
|
||||
if adapter_name in self.lora_A.keys():
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
for i in range(self.expert_num):
|
||||
nn.init.normal_(
|
||||
self.lora_A[adapter_name].loraA[i].mlp.weight, mean=0.0, std=0.01
|
||||
)
|
||||
nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight)
|
||||
|
||||
|
||||
class MMOELoraLinear(nn.Linear, MMOELoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
# nn.Linear is the pretrained weights in LLM, MMOELoraLayer is the designed trainable Lora
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name: str,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
**kwargs,
|
||||
):
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
self.expert_num = kwargs.pop("expert_num", True)
|
||||
self.task_num = kwargs.pop("task_num", True)
|
||||
self.te_dim = kwargs.pop("task_embedding_dim", True)
|
||||
|
||||
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
||||
MMOELoraLayer.__init__(
|
||||
self,
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
expert_num=self.expert_num,
|
||||
)
|
||||
|
||||
# init the Gate network
|
||||
self.lora_task_embedding = nn.ModuleDict({})
|
||||
self.lora_gate = nn.ModuleDict({})
|
||||
self.lora_task_embedding.update(
|
||||
nn.ModuleDict({adapter_name: nn.Embedding(self.task_num + 1, self.te_dim)})
|
||||
)
|
||||
self.lora_gate.update(
|
||||
nn.ModuleDict({adapter_name: Gate(self.te_dim, self.expert_num)})
|
||||
)
|
||||
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
nn.Linear.reset_parameters(self)
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.active_adapter = adapter_name
|
||||
|
||||
def merge(self, task_id):
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
return
|
||||
if self.merged:
|
||||
warnings.warn("Already merged. Nothing to do.")
|
||||
return
|
||||
if self.r[self.active_adapter] > 0:
|
||||
expert_weight = self.lora_gate[self.active_adapter](
|
||||
self.lora_task_embedding[self.active_adapter](task_id)
|
||||
)
|
||||
for i in range(self.expert_num):
|
||||
lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
|
||||
lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
|
||||
self.weight.data += (
|
||||
transpose(
|
||||
lora_B_weights @ lora_A_weights,
|
||||
self.fan_in_fan_out,
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
* expert_weight[..., i]
|
||||
)
|
||||
self.merged = True
|
||||
|
||||
def unmerge(self, task_id):
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
return
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
if self.r[self.active_adapter] > 0:
|
||||
expert_weight = self.lora_gate[self.active_adapter](
|
||||
self.lora_task_embedding[self.active_adapter](task_id)
|
||||
)
|
||||
for i in range(self.expert_num):
|
||||
lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
|
||||
lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
|
||||
self.weight.data -= (
|
||||
transpose(
|
||||
lora_B_weights @ lora_A_weights,
|
||||
self.fan_in_fan_out,
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
* expert_weight[..., i]
|
||||
)
|
||||
self.merged = False
|
||||
|
||||
def forward(self, x: torch.Tensor, **kwargs):
|
||||
task_id = kwargs["task_id"]
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if (
|
||||
self.active_adapter not in self.lora_A.keys()
|
||||
): # No adapter, directly use linear
|
||||
return F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
if self.disable_adapters: # No adapter
|
||||
if (
|
||||
self.r[self.active_adapter] > 0 and self.merged
|
||||
): # merge the adapter to linear
|
||||
self.unmerge(task_id)
|
||||
result = F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
elif (
|
||||
self.r[self.active_adapter] > 0 and not self.merged
|
||||
): # general lora process
|
||||
result = F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
|
||||
x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)
|
||||
|
||||
expert_weight = self.lora_gate[self.active_adapter](
|
||||
self.lora_task_embedding[self.active_adapter](task_id)
|
||||
)
|
||||
for i in range(self.expert_num):
|
||||
result += ( # lora process
|
||||
self.lora_B[self.active_adapter].loraB[i](
|
||||
self.lora_A[self.active_adapter].loraA[i](
|
||||
self.lora_dropout[self.active_adapter](x)
|
||||
),
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
* expert_weight[..., i].unsqueeze(-1).unsqueeze(0)
|
||||
)
|
||||
else:
|
||||
result = F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class MMOELinearA(nn.Module):
|
||||
"""MMOE based LoRA block"""
|
||||
|
||||
def __init__(self, in_features, out_features, expert_num) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.expert_num = expert_num
|
||||
self.in_features, self.out_features = in_features, out_features
|
||||
self.loraA = nn.ModuleList([])
|
||||
|
||||
assert (
|
||||
self.out_features % self.expert_num == 0
|
||||
) # lora rank should be divided by expert number
|
||||
self.r = self.out_features // self.expert_num
|
||||
|
||||
for _ in range(self.expert_num):
|
||||
self.loraA.append(Expert(self.in_features, self.r))
|
||||
|
||||
def forward(self, x):
|
||||
"""input x is a vector, return output is a list"""
|
||||
outputs = []
|
||||
for i in range(self.expert_num):
|
||||
outputs.append(self.loraA[i](x))
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class MMOELinearB(nn.Module):
|
||||
"""MMOE based LoRA block"""
|
||||
|
||||
def __init__(self, in_features, out_features, expert_num) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.expert_num = expert_num
|
||||
self.in_features, self.out_features = in_features, out_features
|
||||
self.loraB = nn.ModuleList([])
|
||||
|
||||
assert self.in_features % self.expert_num == 0
|
||||
self.r = self.in_features // self.expert_num
|
||||
|
||||
for _ in range(self.expert_num):
|
||||
self.loraB.append(Expert(self.r, self.out_features))
|
||||
|
||||
def forward(self, x):
|
||||
"""input x is a list, return output is also a list"""
|
||||
outputs = []
|
||||
for i in range(self.expert_num):
|
||||
outputs.append(self.loraB[i](x[i]))
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Expert(nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.in_features, self.out_features = in_features, out_features
|
||||
self.mlp = nn.Linear(self.in_features, self.out_features, bias=False)
|
||||
self.weight = self.mlp.weight
|
||||
|
||||
def forward(self, x):
|
||||
# LoRA A or B block
|
||||
y = self.mlp(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class Gate(nn.Module):
|
||||
|
||||
def __init__(self, input_size, expert_num):
|
||||
|
||||
super().__init__()
|
||||
# use an embedding in place of a linear layer
self.GateL = nn.Linear(input_size, expert_num, bias=False)
self.act = nn.Softmax(dim=1)  # dim 0 is the batch size
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
y = self.GateL(x)
|
||||
y = self.act(y)
|
||||
|
||||
return y
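To make the expert/gate interaction above easier to follow, here is a small self-contained sketch of the combination performed in MMOELoraLinear.forward; the dimensions are hypothetical and the LoRA scaling factor is omitted:

# Standalone sketch of how MMOELoraLinear.forward mixes the per-expert LoRA outputs
# (hypothetical sizes and random weights; not code from this commit).
import torch
import torch.nn as nn

batch, in_features, out_features = 2, 16, 16
r, expert_num, task_num, te_dim = 8, 4, 2, 64

lora_A = nn.ModuleList([nn.Linear(in_features, r // expert_num, bias=False) for _ in range(expert_num)])
lora_B = nn.ModuleList([nn.Linear(r // expert_num, out_features, bias=False) for _ in range(expert_num)])
task_embedding = nn.Embedding(task_num + 1, te_dim)
gate = nn.Sequential(nn.Linear(te_dim, expert_num, bias=False), nn.Softmax(dim=1))

x = torch.randn(batch, in_features)
task_id = torch.tensor([0])
expert_weight = gate(task_embedding(task_id))  # (1, expert_num), sums to 1
result = torch.zeros(batch, out_features)
for i in range(expert_num):
    # each expert is a small LoRA pair; its output is weighted by the task-conditioned gate
    result = result + lora_B[i](lora_A[i](x)) * expert_weight[..., i].unsqueeze(-1)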
227  src/peft_library/tuners/mmoelora/mmoeloraS.py  Normal file
@@ -0,0 +1,227 @@
# -*- encoding: utf-8 -*-
# here put the import lib
import re
import importlib
import warnings
from dataclasses import dataclass, field
from .mmoelora import MMOELoraModel, MMOELoraLinear, MMOELoraLayer
from peft.tuners.lora import LoraConfig
import torch
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

# from ..utils import _get_submodules, transpose, PeftType
from peft.utils.other import _get_submodules, transpose
from peft.utils.peft_types import PeftType

try:  # optional dependency; imported up front so the 8-bit branch in `_find_and_replace` can reference `bnb`
    import bitsandbytes as bnb
except ImportError:
    bnb = None


def is_bnb_available():
    return importlib.util.find_spec("bitsandbytes") is not None
|
||||
|
||||
|
||||
# TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMOELoraConfigS(LoraConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`~peft.MMOELora`]
|
||||
"""
|
||||
|
||||
task_num: int = field(default=2, metadata={"help": "The number of tasks."})
|
||||
task_embedding_dim: int = field(default=64)
|
||||
expert_num: int = field(default=4)
|
||||
|
||||
def __post_init__(self):
|
||||
self.peft_type = PeftType.MMOELORAS
|
||||
|
||||
|
||||
class MMOELoraModelS(MMOELoraModel):
|
||||
|
||||
def __init__(self, model, config, adapter_name):
|
||||
|
||||
super().__init__(model, config, adapter_name)
|
||||
|
||||
def _find_and_replace(self, adapter_name):
|
||||
"""Replace the target `Linear` module with LoRA layer (Linear+LoRA)"""
|
||||
lora_config = self.peft_config[adapter_name]
|
||||
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
|
||||
if loaded_in_8bit and not is_bnb_available():
|
||||
raise ImportError(
|
||||
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
is_target_modules_in_base_model = False
|
||||
kwargs = {
|
||||
"r": lora_config.r,
|
||||
"lora_alpha": lora_config.lora_alpha,
|
||||
"lora_dropout": lora_config.lora_dropout,
|
||||
"fan_in_fan_out": lora_config.fan_in_fan_out,
|
||||
"init_lora_weights": lora_config.init_lora_weights,
|
||||
"task_num": lora_config.task_num,
|
||||
"task_embedding_dim": lora_config.task_embedding_dim,
|
||||
"expert_num": lora_config.expert_num,
|
||||
}
|
||||
key_list = [
|
||||
key for key, _ in self.model.named_modules()
|
||||
] # all module in raw model
|
||||
for key in key_list:
|
||||
# find the corresponding modules. target module has been split into list.
|
||||
if isinstance(lora_config.target_modules, str):
|
||||
target_module_found = re.fullmatch(lora_config.target_modules, key)
|
||||
else:
|
||||
target_module_found = any(
|
||||
key.endswith(target_key)
|
||||
for target_key in lora_config.target_modules
|
||||
)
|
||||
if target_module_found:
|
||||
if not is_target_modules_in_base_model:
|
||||
is_target_modules_in_base_model = True
|
||||
parent, target, target_name = _get_submodules(self.model, key)
|
||||
bias = target.bias is not None
|
||||
if isinstance(target, MMOELoraLayer):
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
lora_config.init_r,
|
||||
lora_config.lora_alpha,
|
||||
lora_config.lora_dropout,
|
||||
lora_config.init_lora_weights,
|
||||
)
|
||||
else:
|
||||
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
if isinstance(target, torch.nn.Linear):
|
||||
in_features, out_features = (
|
||||
target.in_features,
|
||||
target.out_features,
|
||||
)
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = (
|
||||
lora_config.fan_in_fan_out
|
||||
) = False
|
||||
elif isinstance(target, Conv1D):
|
||||
in_features, out_features = (
|
||||
target.weight.ds_shape
|
||||
if hasattr(target.weight, "ds_shape")
|
||||
else target.weight.shape
|
||||
)
|
||||
if not kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
|
||||
"Setting fan_in_fan_out to True."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = (
|
||||
lora_config.fan_in_fan_out
|
||||
) = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Target module {target} is not supported. "
|
||||
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
|
||||
)
|
||||
new_module = MMOELoraLinearS(
|
||||
adapter_name, in_features, out_features, bias=bias, **kwargs
|
||||
)
|
||||
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
if not is_target_modules_in_base_model:
|
||||
raise ValueError(
|
||||
f"Target modules {lora_config.target_modules} not found in the base model. "
|
||||
f"Please check the target modules and try again."
|
||||
)
|
||||
|
||||
|
||||
class MMOELoraLinearS(MMOELoraLinear):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name: str,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0,
|
||||
fan_in_fan_out: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
adapter_name,
|
||||
in_features,
|
||||
out_features,
|
||||
r,
|
||||
lora_alpha,
|
||||
lora_dropout,
|
||||
fan_in_fan_out,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def unmerge(self, expert_weight):
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
return
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
if self.r[self.active_adapter] > 0:
|
||||
for i in range(self.expert_num):
|
||||
lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
|
||||
lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
|
||||
self.weight.data -= (
|
||||
transpose(
|
||||
lora_B_weights @ lora_A_weights,
|
||||
self.fan_in_fan_out,
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
* expert_weight[..., i]
|
||||
)
|
||||
self.merged = False
|
||||
|
||||
def forward(self, x: torch.Tensor, **kwargs):
|
||||
expert_weight = kwargs["task_id"]
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if (
|
||||
self.active_adapter not in self.lora_A.keys()
|
||||
): # No adapter, directly use linear
|
||||
return F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
if self.disable_adapters: # No adapter
|
||||
if (
|
||||
self.r[self.active_adapter] > 0 and self.merged
|
||||
): # merge the adapter to linear
|
||||
self.unmerge(expert_weight)
|
||||
result = F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
elif (
|
||||
self.r[self.active_adapter] > 0 and not self.merged
|
||||
): # general lora process
|
||||
result = F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
|
||||
x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)
|
||||
|
||||
for i in range(self.expert_num):
|
||||
result += ( # lora process
|
||||
self.lora_B[self.active_adapter].loraB[i](
|
||||
self.lora_A[self.active_adapter].loraA[i](
|
||||
self.lora_dropout[self.active_adapter](x)
|
||||
),
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
* expert_weight[..., i].unsqueeze(-1).unsqueeze(0)
|
||||
)
|
||||
else:
|
||||
result = F.linear(
|
||||
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
|
||||
)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
|
||||
return result
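One note on the S variant above: its forward treats the task_id keyword as the expert mixture weights themselves, so the gate and task-embedding path of MMOELoraLinear is bypassed at call time. A hedged sketch of the calling convention, with made-up weights:

import torch

expert_weight = torch.tensor([[0.4, 0.3, 0.2, 0.1]])  # hypothetical mixture over expert_num = 4 experts
# output = mmoelora_linear_s(x, task_id=expert_weight)  # layer instance as defined in this file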
39  src/peft_library/utils/__init__.py  Normal file
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
from .other import (
    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING,
    CONFIG_NAME,
    WEIGHTS_NAME,
    _set_trainable,
    bloom_model_postprocess_past_key_value,
    prepare_model_for_int8_training,
    shift_tokens_right,
    transpose,
    _get_submodules,
    _set_adapter,
    _freeze_adapter,
    ModulesToSaveWrapper,
)
from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
176  src/peft_library/utils/config.py  Normal file
@@ -0,0 +1,176 @@
# coding=utf-8
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers.utils import PushToHubMixin
|
||||
|
||||
from .other import CONFIG_NAME
|
||||
|
||||
|
||||
class PeftType(str, enum.Enum):
|
||||
PROMPT_TUNING = "PROMPT_TUNING"
|
||||
P_TUNING = "P_TUNING"
|
||||
PREFIX_TUNING = "PREFIX_TUNING"
|
||||
LORA = "LORA"
|
||||
ADALORA = "ADALORA"
|
||||
ADAPTION_PROMPT = "ADAPTION_PROMPT"
|
||||
MMOELORAS = "MMOELORAS"
|
||||
|
||||
|
||||
class TaskType(str, enum.Enum):
|
||||
SEQ_CLS = "SEQ_CLS"
|
||||
SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
|
||||
CAUSAL_LM = "CAUSAL_LM"
|
||||
TOKEN_CLS = "TOKEN_CLS"
|
||||
CAUSAL_LMS = "CAUSAL_LMS"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeftConfigMixin(PushToHubMixin):
|
||||
r"""
|
||||
This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
|
||||
PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to
|
||||
push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
|
||||
directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
|
||||
|
||||
Args:
|
||||
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
|
||||
"""
|
||||
peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
|
||||
|
||||
@property
|
||||
def __dict__(self):
|
||||
return asdict(self)
|
||||
|
||||
def to_dict(self):
|
||||
return self.__dict__
|
||||
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
r"""
|
||||
This method saves the configuration of your adapter model in a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory where the configuration will be saved.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`]
|
||||
method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
output_dict = self.__dict__
|
||||
output_path = os.path.join(save_directory, CONFIG_NAME)
|
||||
|
||||
# save it
|
||||
with open(output_path, "w") as writer:
|
||||
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
|
||||
r"""
|
||||
This method loads the configuration of your adapter model from a directory.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str`):
|
||||
The directory or the Hub repository id where the configuration is saved.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Additional keyword arguments passed along to the child class initialization.
|
||||
"""
|
||||
path = (
|
||||
os.path.join(pretrained_model_name_or_path, subfolder)
|
||||
if subfolder is not None
|
||||
else pretrained_model_name_or_path
|
||||
)
|
||||
if os.path.isfile(os.path.join(path, CONFIG_NAME)):
|
||||
config_file = os.path.join(path, CONFIG_NAME)
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder)
|
||||
except Exception:
|
||||
raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
|
||||
|
||||
loaded_attributes = cls.from_json_file(config_file)
|
||||
|
||||
config = cls(**kwargs)
|
||||
|
||||
for key, value in loaded_attributes.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, path_json_file, **kwargs):
|
||||
r"""
|
||||
Loads a configuration file from a json file.
|
||||
|
||||
Args:
|
||||
path_json_file (`str`):
|
||||
The path to the json file.
|
||||
"""
|
||||
with open(path_json_file, "r") as file:
|
||||
json_object = json.load(file)
|
||||
|
||||
return json_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeftConfig(PeftConfigMixin):
|
||||
"""
|
||||
This is the base configuration class to store the configuration of a [`PeftModel`].
|
||||
|
||||
Args:
|
||||
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
|
||||
task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
|
||||
inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
|
||||
"""
|
||||
|
||||
base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
|
||||
peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
|
||||
task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
|
||||
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptLearningConfig(PeftConfig):
|
||||
"""
|
||||
This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or
|
||||
[`PromptTuning`].
|
||||
|
||||
Args:
|
||||
num_virtual_tokens (`int`): The number of virtual tokens to use.
|
||||
token_dim (`int`): The hidden embedding dimension of the base transformer model.
|
||||
num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
|
||||
num_attention_heads (`int`): The number of attention heads in the base transformer model.
|
||||
num_layers (`int`): The number of layers in the base transformer model.
|
||||
"""
|
||||
|
||||
num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
|
||||
token_dim: int = field(
|
||||
default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
|
||||
)
|
||||
num_transformer_submodules: Optional[int] = field(
|
||||
default=None, metadata={"help": "Number of transformer submodules"}
|
||||
)
|
||||
num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
|
||||
num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
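A short, hedged round-trip sketch for the configuration classes above (the directory path is illustrative):

# Sketch: serialize a PeftConfig to adapter_config.json and load it back.
from peft_library.utils.config import PeftConfig, TaskType

cfg = PeftConfig(peft_type="LORA", task_type=TaskType.CAUSAL_LM, inference_mode=False)
cfg.save_pretrained("/tmp/adapter_demo")  # writes adapter_config.json (CONFIG_NAME)
reloaded = PeftConfig.from_pretrained("/tmp/adapter_demo")
assert reloaded.task_type == TaskType.CAUSAL_LM  # str-based enum compares equal to its value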
4  src/peft_library/utils/constants.py  Normal file
@@ -0,0 +1,4 @@
from peft.utils.constants import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
250  src/peft_library/utils/other.py  Normal file
@@ -0,0 +1,250 @@
# coding=utf-8
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# needed for prefix-tuning of bloom model
|
||||
def bloom_model_postprocess_past_key_value(past_key_values):
|
||||
past_key_values = torch.cat(past_key_values)
|
||||
total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
|
||||
keys = past_key_values[: total_layers // 2]
|
||||
keys = keys.transpose(2, 3).reshape(
|
||||
total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
|
||||
)
|
||||
values = past_key_values[total_layers // 2 :]
|
||||
values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)
|
||||
|
||||
return tuple(zip(keys, values))
|
||||
|
||||
|
||||
def prepare_model_for_int8_training(
|
||||
model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
|
||||
):
|
||||
r"""
|
||||
This method wraps the entire protocol for preparing a model before running a training. This includes:
|
||||
1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
|
||||
head to fp32
|
||||
|
||||
Args:
|
||||
model, (`transformers.PreTrainedModel`):
|
||||
The loaded model from `transformers`
|
||||
"""
|
||||
loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
# freeze base model's layers
|
||||
param.requires_grad = False
|
||||
|
||||
if loaded_in_8bit:
|
||||
# cast layer norm in fp32 for stability for 8bit models
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if loaded_in_8bit and use_gradient_checkpointing:
|
||||
# For backward compatibility
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
# enable gradient checkpointing for memory efficiency
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
if hasattr(model, output_embedding_layer_name):
|
||||
output_embedding_layer = getattr(model, output_embedding_layer_name)
|
||||
input_dtype = output_embedding_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
r"""
|
||||
Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted
|
||||
in fp32
|
||||
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# copied from transformers.models.bart.modeling_bart
|
||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
"""
|
||||
Shift input ids one token to the right.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
|
||||
pad_token_id (`int`): The id of the `padding` token.
|
||||
decoder_start_token_id (`int`): The id of the `start` token.
|
||||
"""
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
||||
shifted_input_ids[:, 0] = decoder_start_token_id
|
||||
|
||||
if pad_token_id is None:
|
||||
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
class ModulesToSaveWrapper(torch.nn.Module):
|
||||
def __init__(self, module_to_save, adapter_name):
|
||||
super().__init__()
|
||||
self.original_module = module_to_save
|
||||
self.modules_to_save = torch.nn.ModuleDict({})
|
||||
self.update(adapter_name)
|
||||
self.active_adapter = adapter_name
|
||||
|
||||
def update(self, adapter_name):
|
||||
self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.active_adapter not in self.modules_to_save:
|
||||
return self.original_module(*args, **kwargs)
|
||||
return self.modules_to_save[self.active_adapter](*args, **kwargs)
|
||||
|
||||
|
||||
def _get_submodules(model, key):
|
||||
parent = model.get_submodule(".".join(key.split(".")[:-1]))
|
||||
target_name = key.split(".")[-1]
|
||||
target = model.get_submodule(key)
|
||||
return parent, target, target_name
|
||||
|
||||
|
||||
def _freeze_adapter(model, adapter_name):
|
||||
for n, p in model.named_parameters():
|
||||
if adapter_name in n:
|
||||
p.requires_grad = False
|
||||
|
||||
|
||||
def _set_trainable(model, adapter_name):
|
||||
key_list = [key for key, _ in model.named_modules()]
|
||||
for key in key_list:
|
||||
target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
|
||||
if target_module_found:
|
||||
parent, target, target_name = _get_submodules(model, key)
|
||||
if isinstance(target, ModulesToSaveWrapper):
|
||||
target.update(adapter_name)
|
||||
else:
|
||||
for param in target.parameters():
|
||||
param.requires_grad = True
|
||||
setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))
|
||||
|
||||
|
||||
def _set_adapter(model, adapter_name):
|
||||
for module in model.modules():
|
||||
if isinstance(module, ModulesToSaveWrapper):
|
||||
module.active_adapter = adapter_name
|
||||
|
||||
|
||||
def fsdp_auto_wrap_policy(model):
|
||||
import functools
|
||||
import os
|
||||
|
||||
from accelerate import FullyShardedDataParallelPlugin
|
||||
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
|
||||
from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
|
||||
|
||||
def lambda_policy_fn(module):
|
||||
if (
|
||||
len(list(module.named_children())) == 0
|
||||
and getattr(module, "weight", None) is not None
|
||||
and module.weight.requires_grad
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
|
||||
transformer_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls=(
|
||||
PrefixEncoder,
|
||||
PromptEncoder,
|
||||
PromptEmbedding,
|
||||
FullyShardedDataParallelPlugin.get_module_class_from_name(
|
||||
model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
|
||||
return auto_wrap_policy
|
||||
|
||||
|
||||
def transpose(weight, fan_in_fan_out):
|
||||
return weight.T if fan_in_fan_out else weight
|
||||
|
||||
|
||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
|
||||
"t5": ["q", "v"],
|
||||
"mt5": ["q", "v"],
|
||||
"bart": ["q_proj", "v_proj"],
|
||||
"gpt2": ["c_attn"],
|
||||
"bloom": ["query_key_value"],
|
||||
"blip-2": ["q", "v", "q_proj", "v_proj"],
|
||||
"opt": ["q_proj", "v_proj"],
|
||||
"gptj": ["q_proj", "v_proj"],
|
||||
"gpt_neox": ["query_key_value"],
|
||||
"gpt_neo": ["q_proj", "v_proj"],
|
||||
"bert": ["query", "value"],
|
||||
"roberta": ["query", "value"],
|
||||
"xlm-roberta": ["query", "value"],
|
||||
"electra": ["query", "value"],
|
||||
"deberta-v2": ["query_proj", "value_proj"],
|
||||
"deberta": ["in_proj"],
|
||||
"layoutlm": ["query", "value"],
|
||||
"llama": ["q_proj", "v_proj"],
|
||||
"chatglm": ["query_key_value"],
|
||||
}
|
||||
|
||||
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
|
||||
"t5": ["q", "k", "v", "o", "wi", "wo"],
|
||||
"mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
|
||||
"bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
|
||||
# "gpt2": ["c_attn"],
|
||||
# "bloom": ["query_key_value"],
|
||||
"opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
|
||||
# "gptj": ["q_proj", "v_proj"],
|
||||
# "gpt_neox": ["query_key_value"],
|
||||
# "gpt_neo": ["q_proj", "v_proj"],
|
||||
# "bert": ["query", "value"],
|
||||
"roberta": ["query", "key", "value", "dense"],
|
||||
# "xlm-roberta": ["query", "value"],
|
||||
# "electra": ["query", "value"],
|
||||
"deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
|
||||
# "deberta": ["in_proj"],
|
||||
# "layoutlm": ["query", "value"],
|
||||
}
|
||||
|
||||
TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
||||
TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
||||
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
|
||||
"bloom": bloom_model_postprocess_past_key_value,
|
||||
}
|
||||
|
||||
WEIGHTS_NAME = "adapter_model.bin"
|
||||
CONFIG_NAME = "adapter_config.json"
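As a quick sanity check of the helper above, a worked example of shift_tokens_right with made-up token ids:

import torch

from peft_library.utils.other import shift_tokens_right

labels = torch.tensor([[5, 6, -100, -100]])  # hypothetical label ids with ignore-index padding
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
# tensor([[2, 5, 6, 0]]): ids shift right, the start token is prepended,
# and the remaining -100 is replaced by pad_token_id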
|
51
src/peft_library/utils/peft_types.py
Normal file
51
src/peft_library/utils/peft_types.py
Normal file
@ -0,0 +1,51 @@
import enum


class PeftType(str, enum.Enum):
    """
    Enum class for the different types of adapters in PEFT.

    Supported PEFT types:
    - PROMPT_TUNING
    - MULTITASK_PROMPT_TUNING
    - P_TUNING
    - PREFIX_TUNING
    - LORA
    - ADALORA
    - BOFT
    - ADAPTION_PROMPT
    - IA3
    - LOHA
    - LOKR
    - OFT
    - XLORA
    - POLY
    - LN_TUNING
    - VERA
    - FOURIERFT
    - HRA
    - BONE
    - VBLORA
    - CPT
    - MMOELORA
    - MMOELORAS
    """

    PROMPT_TUNING = "PROMPT_TUNING"
    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
    P_TUNING = "P_TUNING"
    PREFIX_TUNING = "PREFIX_TUNING"
    LORA = "LORA"
    ADALORA = "ADALORA"
    BOFT = "BOFT"
    ADAPTION_PROMPT = "ADAPTION_PROMPT"
    IA3 = "IA3"
    LOHA = "LOHA"
    LOKR = "LOKR"
    OFT = "OFT"
    POLY = "POLY"
    LN_TUNING = "LN_TUNING"
    VERA = "VERA"
    FOURIERFT = "FOURIERFT"
    XLORA = "XLORA"
    HRA = "HRA"
    VBLORA = "VBLORA"
    CPT = "CPT"
    BONE = "BONE"
    MMOELORAS = "MMOELORAS"
    MMOELORA = "MMOELORA"
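
# Illustrative usage (not part of this commit): because PeftType mixes in str, members can
# be built from plain strings (e.g. parsed CLI flags) and compared against string literals.
if __name__ == "__main__":
    peft_type = PeftType("MMOELORA")       # value lookup: string -> enum member
    assert peft_type is PeftType.MMOELORA
    assert peft_type == "MMOELORA"         # str mixin makes direct string comparison work
    print([member.value for member in PeftType])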
130
src/peft_library/utils/save_and_load.py
Normal file
@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import PeftType, PromptLearningConfig


def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"):
    """
    Get the state dict of the Peft model.

    Args:
        model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
            the model should be the underlying model/unwrapped model (i.e. model.module).
        state_dict (`dict`, *optional*, defaults to `None`):
            The state dict of the model. If not provided, the state dict of the model will be used.
    """
    config = model.peft_config[adapter_name]
    if state_dict is None:
        state_dict = model.state_dict()
    if config.peft_type in (PeftType.LORA, PeftType.ADALORA, PeftType.MMOELORAS):
        # to_return = lora_state_dict(model, bias=model.peft_config.bias)
        # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
        # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
        bias = config.bias
        if bias == "none":  # filter out all lora parameters
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
        elif bias == "all":
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
        elif bias == "lora_only":
            to_return = {}
            for k in state_dict:
                if "lora_" in k:
                    to_return[k] = state_dict[k]
                    bias_name = k.split("lora_")[0] + "bias"
                    if bias_name in state_dict:
                        to_return[bias_name] = state_dict[bias_name]
        else:
            raise NotImplementedError
        to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}

        if config.peft_type == PeftType.ADALORA:
            rank_pattern = config.rank_pattern
            if rank_pattern is not None:
                rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()}
                config.rank_pattern = rank_pattern
                to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name)

    elif config.peft_type == PeftType.ADAPTION_PROMPT:
        to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")}
    elif isinstance(config, PromptLearningConfig):
        to_return = {}
        if config.inference_mode:
            prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
        else:
            prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
        to_return["prompt_embeddings"] = prompt_embeddings
    else:
        raise NotImplementedError
    if model.modules_to_save is not None:
        for key, value in state_dict.items():
            if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
                to_return[key.replace("modules_to_save.", "")] = value

    to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
    return to_return
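
# Illustrative sketch (not part of this commit): a typical save path pairs the helper above
# with the "adapter_model.bin" convention (WEIGHTS_NAME elsewhere in this library). The
# helper name `save_adapter` and its arguments are placeholders.
import os

import torch


def save_adapter(peft_model, save_dir, adapter_name="default"):
    os.makedirs(save_dir, exist_ok=True)
    adapter_state = get_peft_model_state_dict(peft_model, adapter_name=adapter_name)
    torch.save(adapter_state, os.path.join(save_dir, "adapter_model.bin"))
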

def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"):
    """
    Set the state dict of the Peft model.

    Args:
        model ([`PeftModel`]): The Peft model.
        peft_model_state_dict (`dict`): The state dict of the Peft model.
    """
    config = model.peft_config[adapter_name]
    state_dict = {}
    if model.modules_to_save is not None:
        for key, value in peft_model_state_dict.items():
            if any(module_name in key for module_name in model.modules_to_save):
                for module_name in model.modules_to_save:
                    if module_name in key:
                        key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
                        break
            state_dict[key] = value
    else:
        state_dict = peft_model_state_dict

    if config.peft_type in (PeftType.LORA, PeftType.ADALORA, PeftType.MMOELORAS):
        peft_model_state_dict = {}
        for k, v in state_dict.items():
            if "lora_" in k:
                suffix = k.split("lora_")[1]
                if "." in suffix:
                    suffix_to_replace = ".".join(suffix.split(".")[1:])
                    k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
                else:
                    k = f"{k}.{adapter_name}"
                peft_model_state_dict[k] = v
            else:
                peft_model_state_dict[k] = v
        if config.peft_type == PeftType.ADALORA:
            rank_pattern = config.rank_pattern
            if rank_pattern is not None:
                model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
    elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT:
        peft_model_state_dict = state_dict
    else:
        raise NotImplementedError

    model.load_state_dict(peft_model_state_dict, strict=False)
    if isinstance(config, PromptLearningConfig):
        model.prompt_encoder[adapter_name].embedding.load_state_dict(
            {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
        )
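
# Illustrative sketch (not part of this commit): the reverse of the save path sketched
# above; `load_adapter` and its arguments are placeholders.
def load_adapter(peft_model, save_dir, adapter_name="default"):
    adapter_state = torch.load(os.path.join(save_dir, "adapter_model.bin"), map_location="cpu")
    set_peft_model_state_dict(peft_model, adapter_state, adapter_name=adapter_name)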
35
src/train.py
@ -1,32 +1,46 @@
import torch
from datasets_library.factory import get_dataset
from dataset_library.factory import get_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments

from trl import (
    ModelConfig,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    # get_peft_config,
    get_quantization_config,
)
from peft import get_peft_model
from peft_library import get_peft_model, get_peft_config

from utils.trainer import ContinualTrainer
from utils.args import ContinualScriptArguments
from utils.args import ContinualScriptArguments, ContinualModelConfig


if __name__ == "__main__":
    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig))
    parser = TrlParser(
        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
    )
    script_args, training_args, model_args = parser.parse_args_and_config()
    # for type hint
    if 0 == 1:
        script_args = ContinualScriptArguments()
        training_args = TrainingArguments()
        model_args = ModelConfig()
        model_args = ContinualModelConfig()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # peft_config = get_peft_config(dict(**vars(model_args)))
    if model_args.peft_type == "MMOELora":
        from peft_library.tuners import MMOELoraConfig

        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
    elif model_args.peft_type == "LORA":
        from peft.tuners.lora import LoraConfig

        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
    else:
        peft_config = None
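
    # Illustrative sketch (not part of this commit): additional PEFT types could be wired in
    # with the same pattern; AdaLoraConfig is used here purely as an example.
    # elif model_args.peft_type == "ADALORA":
    #     from peft.tuners.adalora import AdaLoraConfig
    #
    #     peft_config = AdaLoraConfig(target_modules=model_args.lora_target_modules)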

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
@ -62,9 +76,6 @@ if __name__ == "__main__":
    collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
    collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    peft_config = get_peft_config(model_args)
    model = get_peft_model(model, peft_config)

    ################
    # Dataset
    ################
@ -73,8 +84,10 @@ if __name__ == "__main__":

    accelerator = create_accelerator_and_postprocess(training_args)

    if accelerator.is_local_main_process:
        model.print_trainable_parameters()
    if peft_config is not None:
        model = get_peft_model(model, peft_config)
        if accelerator.is_local_main_process:
            model.print_trainable_parameters()

    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)
@ -1,8 +1,9 @@
#!/bin/bash

accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py \
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
    --dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \
    --use_peft \
    --peft_type LORA \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
@ -1,17 +1,21 @@
from dataclasses import dataclass, field
from typing import Optional
from trl import ScriptArguments, ModelConfig
from transformers import TrainingArguments


@dataclass
class ContinualScriptArguments:
class ContinualScriptArguments(ScriptArguments):
    """Script arguments for continual learning."""

    dataset_name: list[str] = field(
        default_factory=lambda: ["cifar10", "cifar100", "imagenet2012"]
    )
    dataset_config: Optional[str] = None
    dataset_train_split: str = "train"
    dataset_test_split: str = "test"
    dataset_generation_split: str = "generation"
    gradient_checkpointing_use_reentrant: bool = False
    ignore_bias_buffers: bool = False


@dataclass
class ContinualModelConfig(ModelConfig):
    """Model configuration for continual learning."""

    peft_type: Optional[str] = None
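
# Illustrative usage (not part of this commit): ContinualModelConfig simply extends trl's
# ModelConfig with a peft_type field, so it can be populated directly or by TrlParser from
# a --peft_type flag; the values below are placeholders.
if __name__ == "__main__":
    model_args = ContinualModelConfig(
        model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",
        peft_type="MMOELora",
    )
    assert model_args.peft_type == "MMOELora"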