Add an initialization file for the PEFT library, update dataset import paths, modify the training script to support new PEFT types and configurations, add a continual-learning model config class, add a PEFT type enum, and update the evaluation and training logic to match the new structure

YunyaoZhou 2025-01-02 02:44:58 +08:00
parent aef0f6834e
commit 2cd1bb4993
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
26 changed files with 1673 additions and 20 deletions

@@ -84,7 +84,7 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
 if __name__ == "__main__":
     from transformers import Qwen2VLProcessor
-    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
+    from dataset_library.OCRVQADataset import OCRVQADatasetForGeneration
     processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
     dataset = OCRVQADatasetForGeneration(

@@ -1,5 +1,5 @@
 import torch
-from datasets_library.factory import get_dataset
+from dataset_library.factory import get_dataset
 from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments
 from trl import (

@@ -0,0 +1 @@
from .mapping import get_peft_config, get_peft_model

src/peft_library/mapping.py (new file, 288 lines added)

@@ -0,0 +1,288 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Optional
import torch
from peft.tuners.xlora.model import XLoraModel
from peft.config import PeftConfig
from peft.mixed_model import PeftMixedModel
from peft.peft_model import (
PeftModel,
PeftModelForCausalLM,
PeftModelForFeatureExtraction,
PeftModelForQuestionAnswering,
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from peft.tuners import (
AdaLoraConfig,
AdaLoraModel,
AdaptionPromptConfig,
BOFTConfig,
BOFTModel,
BoneConfig,
BoneModel,
CPTConfig,
CPTEmbedding,
FourierFTConfig,
FourierFTModel,
HRAConfig,
HRAModel,
IA3Config,
IA3Model,
LNTuningConfig,
LNTuningModel,
LoHaConfig,
LoHaModel,
LoKrConfig,
LoKrModel,
LoraConfig,
LoraModel,
MultitaskPromptTuningConfig,
OFTConfig,
OFTModel,
PolyConfig,
PolyModel,
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
VBLoRAConfig,
VBLoRAModel,
VeraConfig,
VeraModel,
XLoraConfig,
)
from .tuners import MMOELoraConfigS, MMOELoraModelS, MMOELoraModel, MMOELoraConfig
from peft.tuners.tuners_utils import BaseTuner
from peft.utils import _prepare_prompt_learning_config
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
if TYPE_CHECKING:
from transformers import PreTrainedModel
MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = {
"SEQ_CLS": PeftModelForSequenceClassification,
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
"CAUSAL_LM": PeftModelForCausalLM,
"TOKEN_CLS": PeftModelForTokenClassification,
"QUESTION_ANS": PeftModelForQuestionAnswering,
"FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
}
PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = {
"ADAPTION_PROMPT": AdaptionPromptConfig,
"PROMPT_TUNING": PromptTuningConfig,
"PREFIX_TUNING": PrefixTuningConfig,
"P_TUNING": PromptEncoderConfig,
"LORA": LoraConfig,
"LOHA": LoHaConfig,
"LORAPLUS": LoraConfig,
"LOKR": LoKrConfig,
"ADALORA": AdaLoraConfig,
"BOFT": BOFTConfig,
"IA3": IA3Config,
"MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
"OFT": OFTConfig,
"POLY": PolyConfig,
"LN_TUNING": LNTuningConfig,
"VERA": VeraConfig,
"FOURIERFT": FourierFTConfig,
"XLORA": XLoraConfig,
"HRA": HRAConfig,
"VBLORA": VBLoRAConfig,
"CPT": CPTConfig,
"BONE": BoneConfig,
"MMOELORA": MMOELoraConfig,
"MMOELORAS": MMOELoraConfigS,
}
PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = {
"LORA": LoraModel,
"LOHA": LoHaModel,
"LOKR": LoKrModel,
"ADALORA": AdaLoraModel,
"BOFT": BOFTModel,
"IA3": IA3Model,
"OFT": OFTModel,
"POLY": PolyModel,
"LN_TUNING": LNTuningModel,
"VERA": VeraModel,
"FOURIERFT": FourierFTModel,
"XLORA": XLoraModel,
"HRA": HRAModel,
"VBLORA": VBLoRAModel,
"CPT": CPTEmbedding,
"BONE": BoneModel,
"MMOELORA": MMOELoraModel,
"MMOELORAS": MMOELoraModelS,
}
def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
"""
Returns a Peft config object from a dictionary.
Args:
config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
"""
return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
def get_peft_model(
model: PreTrainedModel,
peft_config: PeftConfig,
adapter_name: str = "default",
mixed: bool = False,
autocast_adapter_dtype: bool = True,
revision: Optional[str] = None,
low_cpu_mem_usage: bool = False,
) -> PeftModel | PeftMixedModel:
"""
Returns a Peft model object from a model and a config.
Args:
model ([`transformers.PreTrainedModel`]):
Model to be wrapped.
peft_config ([`PeftConfig`]):
Configuration object containing the parameters of the Peft model.
adapter_name (`str`, `optional`, defaults to `"default"`):
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
mixed (`bool`, `optional`, defaults to `False`):
Whether to allow mixing different (compatible) adapter types.
autocast_adapter_dtype (`bool`, *optional*):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect
select PEFT tuners.
revision (`str`, `optional`, defaults to `main`):
The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
the base model
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as
False if you intend on training the model, unless the adapter weights will be replaced by different weights
before training starts.
"""
model_config = BaseTuner.get_model_config(model)
old_name = peft_config.base_model_name_or_path
new_name = model.__dict__.get("name_or_path", None)
peft_config.base_model_name_or_path = new_name
if (old_name is not None) and (old_name != new_name):
warnings.warn(
f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
"Please ensure that the correct base model is loaded when loading this checkpoint."
)
if revision is not None:
if peft_config.revision is not None and peft_config.revision != revision:
warnings.warn(
f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}"
)
peft_config.revision = revision
if (
(isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"]))
and (peft_config.init_lora_weights == "eva")
and not low_cpu_mem_usage
):
warnings.warn(
"lora with eva initialization used with low_cpu_mem_usage=False. "
"Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
)
prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
if prefix and adapter_name in prefix:
warnings.warn(
f"Adapter name {adapter_name} should not be contained in the prefix {prefix}."
"This may lead to reinitialization of the adapter weights during loading."
)
if mixed:
# note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
if (
peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys()
and not peft_config.is_prompt_learning
):
return PeftModel(
model,
peft_config,
adapter_name=adapter_name,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if peft_config.is_prompt_learning:
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
model,
peft_config,
adapter_name=adapter_name,
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)
def inject_adapter_in_model(
peft_config: PeftConfig,
model: torch.nn.Module,
adapter_name: str = "default",
low_cpu_mem_usage: bool = False,
) -> torch.nn.Module:
r"""
A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API
calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods.
Args:
peft_config (`PeftConfig`):
Configuration object containing the parameters of the Peft model.
model (`torch.nn.Module`):
The input model where the adapter will be injected.
adapter_name (`str`, `optional`, defaults to `"default"`):
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.
"""
if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
raise ValueError(
"`create_and_replace` does not support prompt learning and adaption prompt yet."
)
if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys():
raise ValueError(
f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`."
)
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
# By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
peft_model = tuner_cls(
model,
peft_config,
adapter_name=adapter_name,
low_cpu_mem_usage=low_cpu_mem_usage,
)
return peft_model.model
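
For orientation, here is a minimal usage sketch of the two functions this file exposes, assuming the peft_library package above is importable; the tiny base model name and the LoRA hyperparameters are placeholders, not part of the commit:

from transformers import AutoModelForCausalLM
from peft_library import get_peft_config, get_peft_model

# The "peft_type" key selects the entry in PEFT_TYPE_TO_CONFIG_MAPPING above;
# "MMOELORA" or "MMOELORAS" would be resolved the same way.
peft_config = get_peft_config(
    {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["c_attn"],
    }
)
base_model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
# task_type CAUSAL_LM routes through MODEL_TYPE_TO_PEFT_MODEL_MAPPING to PeftModelForCausalLM.
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()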

@@ -0,0 +1,2 @@
from .mmoelora.mmoelora import MMOELoraModel, MMOELoraConfig
from .mmoelora.mmoeloraS import MMOELoraModelS, MMOELoraConfigS

@@ -0,0 +1,467 @@
# -*- encoding: utf-8 -*-
# here put the import lib
import importlib
import re
import warnings
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D
from peft_library.utils.peft_types import PeftType
from peft_library.utils.constants import TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING
from peft.utils.other import _freeze_adapter, _get_submodules, transpose
from peft.tuners.lora import (
LoraConfig,
LoraLayer,
LoraModel,
)
def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None
@dataclass
class MMOELoraConfig(LoraConfig):
"""
This is the configuration class to store the configuration of a [`~peft.MMOELora`]
"""
task_num: int = field(default=2, metadata={"help": "The number of tasks."})
task_embedding_dim: int = field(default=64)
expert_num: int = field(default=4)
def __post_init__(self):
self.peft_type = PeftType.MMOELORA
class MMOELoraModel(LoraModel):
"""
Create MMOELoRA (MMOE based LoRA) model from a pretrained transformers model.
"""
def __init__(self, model, config, adapter_name):
nn.Module.__init__(self)
self.model = model
self.forward = self.model.forward
self.peft_config = config
self.add_adapter(adapter_name, self.peft_config[adapter_name])
def add_adapter(self, adapter_name, config=None):
if config is not None: # get the lora config
model_config = (
self.model.config.to_dict()
if hasattr(self.model.config, "to_dict")
else self.model.config
)
config = self._prepare_mmoelora_config(config, model_config) # load config
            self.peft_config[adapter_name] = config  # substitute the original config
self._find_and_replace(adapter_name)
if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
raise ValueError(
"MMOELoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
)
self._mark_only_adapters_as_trainable(self.model)
if self.peft_config[adapter_name].inference_mode:
_freeze_adapter(self.model, adapter_name)
def _find_and_replace(self, adapter_name):
"""Replace the target `Linear` module with LoRA layer (Linear+LoRA)"""
lora_config = self.peft_config[adapter_name]
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if loaded_in_8bit and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
is_target_modules_in_base_model = False
kwargs = {
"r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
"task_num": lora_config.task_num,
"task_embedding_dim": lora_config.task_embedding_dim,
"expert_num": lora_config.expert_num,
}
key_list = [
key for key, _ in self.model.named_modules()
        ]  # all modules in the raw model
for key in key_list:
# find the corresponding modules. target module has been split into list.
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(
key.endswith(target_key)
for target_key in lora_config.target_modules
)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
bias = target.bias is not None
if isinstance(target, MMOELoraLayer):
target.update_layer(
adapter_name,
lora_config.init_r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
else:
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
raise NotImplementedError
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = (
target.in_features,
target.out_features,
)
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = (
lora_config.fan_in_fan_out
) = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape
if hasattr(target.weight, "ds_shape")
else target.weight.shape
)
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = (
lora_config.fan_in_fan_out
) = True
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
)
new_module = MMOELoraLinear(
adapter_name, in_features, out_features, bias=bias, **kwargs
)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {lora_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.model, name)
@staticmethod
def _prepare_mmoelora_config(peft_config, model_config):
if peft_config.target_modules is None:
if (
model_config["model_type"]
not in TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING
):
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = (
TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING[
model_config["model_type"]
]
)
if peft_config.inference_mode:
peft_config.merge_weights = True
return peft_config
class MMOELoraLayer(LoraLayer):
def __init__(self, in_features: int, out_features: int, expert_num: int):
super().__init__(in_features, out_features)
self.expert_num = expert_num
def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights
):
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
if r > 0:
self.lora_A.update(
nn.ModuleDict(
{adapter_name: MMOELinearA(self.in_features, r, self.expert_num)}
)
)
self.lora_B.update(
nn.ModuleDict(
{adapter_name: MMOELinearB(r, self.out_features, self.expert_num)}
)
)
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
# initialize A the same way as the default for nn.Linear and B to zero
for i in range(self.expert_num):
nn.init.normal_(
self.lora_A[adapter_name].loraA[i].mlp.weight, mean=0.0, std=0.01
)
nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight)
class MMOELoraLinear(nn.Linear, MMOELoraLayer):
# Lora implemented in a dense layer
# nn.Linear is the pretrained weights in LLM, MMOELoraLayer is the designed trainable Lora
def __init__(
self,
adapter_name: str,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.expert_num = kwargs.pop("expert_num", True)
self.task_num = kwargs.pop("task_num", True)
self.te_dim = kwargs.pop("task_embedding_dim", True)
nn.Linear.__init__(self, in_features, out_features, **kwargs)
MMOELoraLayer.__init__(
self,
in_features=in_features,
out_features=out_features,
expert_num=self.expert_num,
)
# init the Gate network
self.lora_task_embedding = nn.ModuleDict({})
self.lora_gate = nn.ModuleDict({})
self.lora_task_embedding.update(
nn.ModuleDict({adapter_name: nn.Embedding(self.task_num + 1, self.te_dim)})
)
self.lora_gate.update(
nn.ModuleDict({adapter_name: Gate(self.te_dim, self.expert_num)})
)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T
nn.Linear.reset_parameters(self)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
def merge(self, task_id):
if self.active_adapter not in self.lora_A.keys():
return
if self.merged:
warnings.warn("Already merged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
expert_weight = self.lora_gate[self.active_adapter](
self.lora_task_embedding[self.active_adapter](task_id)
)
for i in range(self.expert_num):
lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
self.weight.data += (
transpose(
lora_B_weights @ lora_A_weights,
self.fan_in_fan_out,
)
* self.scaling[self.active_adapter]
* expert_weight[..., i]
)
self.merged = True
def unmerge(self, task_id):
if self.active_adapter not in self.lora_A.keys():
return
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
expert_weight = self.lora_gate[self.active_adapter](
self.lora_task_embedding[self.active_adapter](task_id)
)
for i in range(self.expert_num):
lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
self.weight.data -= (
transpose(
lora_B_weights @ lora_A_weights,
self.fan_in_fan_out,
)
* self.scaling[self.active_adapter]
* expert_weight[..., i]
)
self.merged = False
def forward(self, x: torch.Tensor, **kwargs):
task_id = kwargs["task_id"]
previous_dtype = x.dtype
if (
self.active_adapter not in self.lora_A.keys()
): # No adapter, directly use linear
return F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
if self.disable_adapters: # No adapter
if (
self.r[self.active_adapter] > 0 and self.merged
): # merge the adapter to linear
self.unmerge(task_id)
result = F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
elif (
self.r[self.active_adapter] > 0 and not self.merged
): # general lora process
result = F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)
expert_weight = self.lora_gate[self.active_adapter](
self.lora_task_embedding[self.active_adapter](task_id)
)
for i in range(self.expert_num):
result += ( # lora process
self.lora_B[self.active_adapter].loraB[i](
self.lora_A[self.active_adapter].loraA[i](
self.lora_dropout[self.active_adapter](x)
),
)
* self.scaling[self.active_adapter]
* expert_weight[..., i].unsqueeze(-1).unsqueeze(0)
)
else:
result = F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
result = result.to(previous_dtype)
return result
class MMOELinearA(nn.Module):
"""MMOE based LoRA block"""
def __init__(self, in_features, out_features, expert_num) -> None:
super().__init__()
self.expert_num = expert_num
self.in_features, self.out_features = in_features, out_features
self.loraA = nn.ModuleList([])
assert (
self.out_features % self.expert_num == 0
) # lora rank should be divided by expert number
self.r = self.out_features // self.expert_num
for _ in range(self.expert_num):
self.loraA.append(Expert(self.in_features, self.r))
def forward(self, x):
"""input x is a vector, return output is a list"""
outputs = []
for i in range(self.expert_num):
outputs.append(self.loraA[i](x))
return outputs
class MMOELinearB(nn.Module):
"""MMOE based LoRA block"""
def __init__(self, in_features, out_features, expert_num) -> None:
super().__init__()
self.expert_num = expert_num
self.in_features, self.out_features = in_features, out_features
self.loraB = nn.ModuleList([])
assert self.in_features % self.expert_num == 0
self.r = self.in_features // self.expert_num
for _ in range(self.expert_num):
self.loraB.append(Expert(self.r, self.out_features))
def forward(self, x):
"""input x is a list, return output is also a list"""
outputs = []
for i in range(self.expert_num):
outputs.append(self.loraB[i](x[i]))
return outputs
class Expert(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.in_features, self.out_features = in_features, out_features
self.mlp = nn.Linear(self.in_features, self.out_features, bias=False)
self.weight = self.mlp.weight
def forward(self, x):
# LoRA A or B block
y = self.mlp(x)
return y
class Gate(nn.Module):
def __init__(self, input_size, expert_num):
super().__init__()
        # use an embedding instead of a linear layer
        self.GateL = nn.Linear(input_size, expert_num, bias=False)
        self.act = nn.Softmax(dim=1)  # dim 0 is the batch size
def forward(self, x):
y = self.GateL(x)
y = self.act(y)
return y
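
A self-contained sketch of how the expert and gate modules above combine into the weighted low-rank update used in MMOELoraLinear.forward; the import path peft_library.tuners.mmoelora.mmoelora is assumed from the package layout, and all sizes are illustrative:

import torch
from peft_library.tuners.mmoelora.mmoelora import MMOELinearA, MMOELinearB, Gate  # assumed path

in_features, out_features, r, expert_num, te_dim = 32, 48, 8, 4, 16
lora_A = MMOELinearA(in_features, r, expert_num)   # 4 experts, each mapping 32 -> 8 // 4 = 2
lora_B = MMOELinearB(r, out_features, expert_num)  # 4 experts, each mapping 2 -> 48
gate = Gate(te_dim, expert_num)                    # task embedding -> softmax weights over experts

x = torch.randn(3, in_features)                    # a batch of 3 inputs
task_embedding = torch.randn(1, te_dim)            # MMOELoraLinear gets this from nn.Embedding(task_id)
expert_weight = gate(task_embedding)               # shape (1, expert_num), rows sum to 1

expert_outputs = lora_B(lora_A(x))                 # list of expert_num tensors, each (3, 48)
delta = sum(
    expert_outputs[i] * expert_weight[..., i].unsqueeze(-1)
    for i in range(expert_num)
)
print(delta.shape)  # torch.Size([3, 48]); this is what gets scaled and added to the frozen linear output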

@@ -0,0 +1,227 @@
# -*- encoding: utf-8 -*-
# here put the import lib
import re
import importlib
import warnings
from dataclasses import dataclass, field
from .mmoelora import MMOELoraModel, MMOELoraLinear, MMOELoraLayer
from peft.tuners.lora import LoraConfig
import torch
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D
# from ..utils import _get_submodules, transpose, PeftType
from peft.utils.other import _get_submodules, transpose
from peft.utils.peft_types import PeftType
def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None
# TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
@dataclass
class MMOELoraConfigS(LoraConfig):
"""
This is the configuration class to store the configuration of a [`~peft.MMOELora`]
"""
task_num: int = field(default=2, metadata={"help": "The number of tasks."})
task_embedding_dim: int = field(default=64)
expert_num: int = field(default=4)
def __post_init__(self):
self.peft_type = PeftType.MMOELORAS
class MMOELoraModelS(MMOELoraModel):
def __init__(self, model, config, adapter_name):
super().__init__(model, config, adapter_name)
def _find_and_replace(self, adapter_name):
"""Replace the target `Linear` module with LoRA layer (Linear+LoRA)"""
lora_config = self.peft_config[adapter_name]
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if loaded_in_8bit and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
is_target_modules_in_base_model = False
kwargs = {
"r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
"task_num": lora_config.task_num,
"task_embedding_dim": lora_config.task_embedding_dim,
"expert_num": lora_config.expert_num,
}
key_list = [
key for key, _ in self.model.named_modules()
        ]  # all modules in the raw model
for key in key_list:
# find the corresponding modules. target module has been split into list.
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(
key.endswith(target_key)
for target_key in lora_config.target_modules
)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
bias = target.bias is not None
if isinstance(target, MMOELoraLayer):
target.update_layer(
adapter_name,
lora_config.init_r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
else:
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
raise NotImplementedError
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = (
target.in_features,
target.out_features,
)
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = (
lora_config.fan_in_fan_out
) = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape
if hasattr(target.weight, "ds_shape")
else target.weight.shape
)
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = (
lora_config.fan_in_fan_out
) = True
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
)
new_module = MMOELoraLinearS(
adapter_name, in_features, out_features, bias=bias, **kwargs
)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {lora_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
class MMOELoraLinearS(MMOELoraLinear):
def __init__(
self,
adapter_name: str,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0,
fan_in_fan_out: bool = False,
**kwargs,
):
super().__init__(
adapter_name,
in_features,
out_features,
r,
lora_alpha,
lora_dropout,
fan_in_fan_out,
**kwargs,
)
def unmerge(self, expert_weight):
if self.active_adapter not in self.lora_A.keys():
return
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
for i in range(self.expert_num):
lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
self.weight.data -= (
transpose(
lora_B_weights @ lora_A_weights,
self.fan_in_fan_out,
)
* self.scaling[self.active_adapter]
* expert_weight[..., i]
)
self.merged = False
def forward(self, x: torch.Tensor, **kwargs):
expert_weight = kwargs["task_id"]
previous_dtype = x.dtype
if (
self.active_adapter not in self.lora_A.keys()
): # No adapter, directly use linear
return F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
if self.disable_adapters: # No adapter
if (
self.r[self.active_adapter] > 0 and self.merged
): # merge the adapter to linear
self.unmerge(expert_weight)
result = F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
elif (
self.r[self.active_adapter] > 0 and not self.merged
): # general lora process
result = F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)
for i in range(self.expert_num):
result += ( # lora process
self.lora_B[self.active_adapter].loraB[i](
self.lora_A[self.active_adapter].loraA[i](
self.lora_dropout[self.active_adapter](x)
),
)
* self.scaling[self.active_adapter]
* expert_weight[..., i].unsqueeze(-1).unsqueeze(0)
)
else:
result = F.linear(
x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
)
result = result.to(previous_dtype)
return result
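
The main behavioural difference from mmoelora.py is the forward contract: MMOELoraLinearS reads kwargs["task_id"] as an already-computed expert weight tensor (expert_weight = kwargs["task_id"] above) rather than an integer task id that is first embedded and gated. A hedged, pure-PyTorch sketch of how a caller might precompute those weights; the layer construction is left commented because it normally happens inside MMOELoraModelS._find_and_replace:

import torch
import torch.nn as nn

expert_num, te_dim, task_num = 4, 16, 2
task_embedding = nn.Embedding(task_num + 1, te_dim)
gate = nn.Sequential(nn.Linear(te_dim, expert_num, bias=False), nn.Softmax(dim=1))

task_id = torch.tensor([1])
expert_weight = gate(task_embedding(task_id))  # shape (1, expert_num), rows sum to 1
print(expert_weight.shape, expert_weight.sum(dim=1))

# layer_s = MMOELoraLinearS("default", 32, 48, r=8, ...)  # normally built by _find_and_replace
# y = layer_s(x, task_id=expert_weight)                   # "task_id" here carries the gate output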

@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
from .other import (
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING,
CONFIG_NAME,
WEIGHTS_NAME,
_set_trainable,
bloom_model_postprocess_past_key_value,
prepare_model_for_int8_training,
shift_tokens_right,
transpose,
_get_submodules,
_set_adapter,
_freeze_adapter,
ModulesToSaveWrapper,
)
from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict

@@ -0,0 +1,176 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import json
import os
from dataclasses import asdict, dataclass, field
from typing import Optional, Union
from huggingface_hub import hf_hub_download
from transformers.utils import PushToHubMixin
from .other import CONFIG_NAME
class PeftType(str, enum.Enum):
PROMPT_TUNING = "PROMPT_TUNING"
P_TUNING = "P_TUNING"
PREFIX_TUNING = "PREFIX_TUNING"
LORA = "LORA"
ADALORA = "ADALORA"
ADAPTION_PROMPT = "ADAPTION_PROMPT"
MMOELORAS = "MMOELORAS"
class TaskType(str, enum.Enum):
SEQ_CLS = "SEQ_CLS"
SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
CAUSAL_LM = "CAUSAL_LM"
TOKEN_CLS = "TOKEN_CLS"
CAUSAL_LMS = "CAUSAL_LMS"
@dataclass
class PeftConfigMixin(PushToHubMixin):
r"""
This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to
push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
Args:
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
"""
peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
@property
def __dict__(self):
return asdict(self)
def to_dict(self):
return self.__dict__
def save_pretrained(self, save_directory, **kwargs):
r"""
This method saves the configuration of your adapter model in a directory.
Args:
save_directory (`str`):
The directory where the configuration will be saved.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`]
method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
output_dict = self.__dict__
output_path = os.path.join(save_directory, CONFIG_NAME)
# save it
with open(output_path, "w") as writer:
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
r"""
This method loads the configuration of your adapter model from a directory.
Args:
pretrained_model_name_or_path (`str`):
The directory or the Hub repository id where the configuration is saved.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments passed along to the child class initialization.
"""
path = (
os.path.join(pretrained_model_name_or_path, subfolder)
if subfolder is not None
else pretrained_model_name_or_path
)
if os.path.isfile(os.path.join(path, CONFIG_NAME)):
config_file = os.path.join(path, CONFIG_NAME)
else:
try:
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder)
except Exception:
raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
loaded_attributes = cls.from_json_file(config_file)
config = cls(**kwargs)
for key, value in loaded_attributes.items():
if hasattr(config, key):
setattr(config, key, value)
return config
@classmethod
def from_json_file(cls, path_json_file, **kwargs):
r"""
Loads a configuration file from a json file.
Args:
path_json_file (`str`):
The path to the json file.
"""
with open(path_json_file, "r") as file:
json_object = json.load(file)
return json_object
@dataclass
class PeftConfig(PeftConfigMixin):
"""
This is the base configuration class to store the configuration of a [`PeftModel`].
Args:
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
"""
base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
@dataclass
class PromptLearningConfig(PeftConfig):
"""
This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or
[`PromptTuning`].
Args:
num_virtual_tokens (`int`): The number of virtual tokens to use.
token_dim (`int`): The hidden embedding dimension of the base transformer model.
num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
num_attention_heads (`int`): The number of attention heads in the base transformer model.
num_layers (`int`): The number of layers in the base transformer model.
"""
num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
token_dim: int = field(
default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
)
num_transformer_submodules: Optional[int] = field(
default=None, metadata={"help": "Number of transformer submodules"}
)
num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
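
A small round-trip sketch of the save/load helpers defined above, assuming these files live under peft_library/utils/; the directory name and model id are placeholders:

from peft_library.utils.config import PeftConfig, PeftType, TaskType  # assumed path

config = PeftConfig(
    base_model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",
    peft_type=PeftType.MMOELORAS,
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
)
config.save_pretrained("./adapter_dir")            # writes adapter_config.json (CONFIG_NAME)
reloaded = PeftConfig.from_pretrained("./adapter_dir")
print(reloaded.peft_type, reloaded.task_type)      # values survive the JSON round trip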

@@ -0,0 +1,4 @@
from peft.utils.constants import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

@@ -0,0 +1,250 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import torch
# needed for prefix-tuning of bloom model
def bloom_model_postprocess_past_key_value(past_key_values):
past_key_values = torch.cat(past_key_values)
total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
keys = past_key_values[: total_layers // 2]
keys = keys.transpose(2, 3).reshape(
total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
)
values = past_key_values[total_layers // 2 :]
values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)
return tuple(zip(keys, values))
def prepare_model_for_int8_training(
model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
):
r"""
This method wraps the entire protocol for preparing a model before running a training. This includes:
        1- Cast the layernorm in fp32
        2- Make the output embedding layer require grads
        3- Upcast the lm head to fp32
    Args:
        model (`transformers.PreTrainedModel`):
The loaded model from `transformers`
"""
loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
for name, param in model.named_parameters():
# freeze base model's layers
param.requires_grad = False
if loaded_in_8bit:
# cast layer norm in fp32 for stability for 8bit models
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
if loaded_in_8bit and use_gradient_checkpointing:
# For backward compatibility
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()
if hasattr(model, output_embedding_layer_name):
output_embedding_layer = getattr(model, output_embedding_layer_name)
input_dtype = output_embedding_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
r"""
Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted
in fp32
"""
def forward(self, x):
return super().forward(x.to(input_dtype)).to(torch.float32)
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
return model
# copied from transformers.models.bart.modeling_bart
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
pad_token_id (`int`): The id of the `padding` token.
decoder_start_token_id (`int`): The id of the `start` token.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
class ModulesToSaveWrapper(torch.nn.Module):
def __init__(self, module_to_save, adapter_name):
super().__init__()
self.original_module = module_to_save
self.modules_to_save = torch.nn.ModuleDict({})
self.update(adapter_name)
self.active_adapter = adapter_name
def update(self, adapter_name):
self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
def forward(self, *args, **kwargs):
if self.active_adapter not in self.modules_to_save:
return self.original_module(*args, **kwargs)
return self.modules_to_save[self.active_adapter](*args, **kwargs)
def _get_submodules(model, key):
parent = model.get_submodule(".".join(key.split(".")[:-1]))
target_name = key.split(".")[-1]
target = model.get_submodule(key)
return parent, target, target_name
def _freeze_adapter(model, adapter_name):
for n, p in model.named_parameters():
if adapter_name in n:
p.requires_grad = False
def _set_trainable(model, adapter_name):
key_list = [key for key, _ in model.named_modules()]
for key in key_list:
target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
if target_module_found:
parent, target, target_name = _get_submodules(model, key)
if isinstance(target, ModulesToSaveWrapper):
target.update(adapter_name)
else:
for param in target.parameters():
param.requires_grad = True
setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))
def _set_adapter(model, adapter_name):
for module in model.modules():
if isinstance(module, ModulesToSaveWrapper):
module.active_adapter = adapter_name
def fsdp_auto_wrap_policy(model):
import functools
import os
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
def lambda_policy_fn(module):
if (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
):
return True
return False
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
FullyShardedDataParallelPlugin.get_module_class_from_name(
model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
),
),
)
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
return auto_wrap_policy
def transpose(weight, fan_in_fan_out):
return weight.T if fan_in_fan_out else weight
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"blip-2": ["q", "v", "q_proj", "v_proj"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "value"],
"xlm-roberta": ["query", "value"],
"electra": ["query", "value"],
"deberta-v2": ["query_proj", "value_proj"],
"deberta": ["in_proj"],
"layoutlm": ["query", "value"],
"llama": ["q_proj", "v_proj"],
"chatglm": ["query_key_value"],
}
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "k", "v", "o", "wi", "wo"],
"mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
"bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
# "gpt2": ["c_attn"],
# "bloom": ["query_key_value"],
"opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
# "gptj": ["q_proj", "v_proj"],
# "gpt_neox": ["query_key_value"],
# "gpt_neo": ["q_proj", "v_proj"],
# "bert": ["query", "value"],
"roberta": ["query", "key", "value", "dense"],
# "xlm-roberta": ["query", "value"],
# "electra": ["query", "value"],
"deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
# "deberta": ["in_proj"],
# "layoutlm": ["query", "value"],
}
TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
"bloom": bloom_model_postprocess_past_key_value,
}
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"
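
Two of the helpers above are easy to sanity-check in isolation; a short worked example, with the peft_library.utils.other import path assumed:

import torch
from peft_library.utils.other import shift_tokens_right, transpose  # assumed path

input_ids = torch.tensor([[5, 6, 7, -100]])
shifted = shift_tokens_right(input_ids, pad_token_id=0, decoder_start_token_id=2)
print(shifted)  # tensor([[2, 5, 6, 7]]): the start token is prepended and the trailing -100 is shifted out

w = torch.arange(6).reshape(2, 3)
print(transpose(w, True).shape)   # torch.Size([3, 2]) when fan_in_fan_out is True
print(transpose(w, False).shape)  # torch.Size([2, 3]) otherwise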

@@ -0,0 +1,51 @@
import enum
class PeftType(str, enum.Enum):
"""
Enum class for the different types of adapters in PEFT.
Supported PEFT types:
- PROMPT_TUNING
- MULTITASK_PROMPT_TUNING
- P_TUNING
- PREFIX_TUNING
- LORA
- ADALORA
- BOFT
- ADAPTION_PROMPT
- IA3
- LOHA
- LOKR
- OFT
- XLORA
- POLY
- LN_TUNING
- VERA
- FOURIERFT
- HRA
    - BONE
    - VBLORA
    - CPT
    - MMOELORA
    - MMOELORAS
    """
PROMPT_TUNING = "PROMPT_TUNING"
MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
P_TUNING = "P_TUNING"
PREFIX_TUNING = "PREFIX_TUNING"
LORA = "LORA"
ADALORA = "ADALORA"
BOFT = "BOFT"
ADAPTION_PROMPT = "ADAPTION_PROMPT"
IA3 = "IA3"
LOHA = "LOHA"
LOKR = "LOKR"
OFT = "OFT"
POLY = "POLY"
LN_TUNING = "LN_TUNING"
VERA = "VERA"
FOURIERFT = "FOURIERFT"
XLORA = "XLORA"
HRA = "HRA"
VBLORA = "VBLORA"
CPT = "CPT"
BONE = "BONE"
MMOELORAS = "MMOELORAS"
MMOELORA = "MMOELORA"
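
Because PeftType mixes in str, members compare equal to their string values, so code can check peft_config.peft_type against either the enum member or the plain string that ends up in serialized configs (mmoelora.py's __post_init__ assigns PeftType.MMOELORA directly). A tiny illustration, using the same import path as mmoelora.py:

from peft_library.utils.peft_types import PeftType

print(PeftType.MMOELORA == "MMOELORA")              # True: str-valued enum
print(PeftType("MMOELORAS") is PeftType.MMOELORAS)  # True: lookup by value
print([t.value for t in PeftType][-2:])             # ['MMOELORAS', 'MMOELORA'], the two new members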

@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import PeftType, PromptLearningConfig
def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"):
"""
Get the state dict of the Peft model.
Args:
model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
the model should be the underlying model/unwrapped model (i.e. model.module).
state_dict (`dict`, *optional*, defaults to `None`):
The state dict of the model. If not provided, the state dict of the model
will be used.
"""
config = model.peft_config[adapter_name]
if state_dict is None:
state_dict = model.state_dict()
if config.peft_type in (PeftType.LORA, PeftType.ADALORA,
PeftType.MMOELORAS):
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
bias = config.bias
if bias == "none": # filter out all lora parameters
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
elif bias == "all":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
for k in state_dict:
if "lora_" in k:
to_return[k] = state_dict[k]
bias_name = k.split("lora_")[0] + "bias"
if bias_name in state_dict:
to_return[bias_name] = state_dict[bias_name]
else:
raise NotImplementedError
to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
if config.peft_type == PeftType.ADALORA:
rank_pattern = config.rank_pattern
if rank_pattern is not None:
rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()}
config.rank_pattern = rank_pattern
to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name)
elif config.peft_type == PeftType.ADAPTION_PROMPT:
to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")}
elif isinstance(config, PromptLearningConfig):
to_return = {}
if config.inference_mode:
prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
else:
prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
to_return["prompt_embeddings"] = prompt_embeddings
else:
raise NotImplementedError
if model.modules_to_save is not None:
for key, value in state_dict.items():
if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
to_return[key.replace("modules_to_save.", "")] = value
to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
return to_return
def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"):
"""
Set the state dict of the Peft model.
Args:
model ([`PeftModel`]): The Peft model.
peft_model_state_dict (`dict`): The state dict of the Peft model.
"""
config = model.peft_config[adapter_name]
state_dict = {}
if model.modules_to_save is not None:
for key, value in peft_model_state_dict.items():
if any(module_name in key for module_name in model.modules_to_save):
for module_name in model.modules_to_save:
if module_name in key:
key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
break
state_dict[key] = value
else:
state_dict = peft_model_state_dict
if config.peft_type in (PeftType.LORA, PeftType.ADALORA,
PeftType.MMOELORAS):
peft_model_state_dict = {}
for k, v in state_dict.items():
if "lora_" in k:
suffix = k.split("lora_")[1]
if "." in suffix:
suffix_to_replace = ".".join(suffix.split(".")[1:])
k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
else:
k = f"{k}.{adapter_name}"
peft_model_state_dict[k] = v
else:
peft_model_state_dict[k] = v
if config.peft_type == PeftType.ADALORA:
rank_pattern = config.rank_pattern
if rank_pattern is not None:
model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT:
peft_model_state_dict = state_dict
else:
raise NotImplementedError
model.load_state_dict(peft_model_state_dict, strict=False)
if isinstance(config, PromptLearningConfig):
model.prompt_encoder[adapter_name].embedding.load_state_dict(
{"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
)
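
A hedged sketch of how the two helpers above could be used to checkpoint only the adapter weights. It leans on the installed peft package for model wrapping, uses a placeholder tiny model, and assumes the peft_library.utils.save_and_load import path; whether the vendored key handling matches the installed peft's parameter naming depends on the peft version, so treat this as a sketch rather than a guaranteed recipe:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from peft_library.utils.save_and_load import (  # assumed path
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
model = get_peft_model(base, LoraConfig(r=4, target_modules=["c_attn"]))

adapter_state = get_peft_model_state_dict(model)   # only the lora_* tensors, adapter name stripped
torch.save(adapter_state, "adapter_model.bin")     # matches WEIGHTS_NAME in other.py

set_peft_model_state_dict(model, torch.load("adapter_model.bin"))
print(len(adapter_state), "adapter tensors saved and restored")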

@@ -1,32 +1,46 @@
 import torch
-from datasets_library.factory import get_dataset
+from dataset_library.factory import get_dataset
 from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments
 from trl import (
     ModelConfig,
     TrlParser,
     get_kbit_device_map,
-    get_peft_config,
+    # get_peft_config,
     get_quantization_config,
 )
-from peft import get_peft_model
+from peft_library import get_peft_model, get_peft_config
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments
+from utils.args import ContinualScriptArguments, ContinualModelConfig
 if __name__ == "__main__":
-    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig))
+    parser = TrlParser(
+        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
+    )
     script_args, training_args, model_args = parser.parse_args_and_config()
     # for type hint
     if 0 == 1:
         script_args = ContinualScriptArguments()
         training_args = TrainingArguments()
-        model_args = ModelConfig()
+        model_args = ContinualModelConfig()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+    # peft_config = get_peft_config(dict(**vars(model_args)))
+    if model_args.peft_type == "MMOELora":
+        from peft_library.tuners import MMOELoraConfig
+        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
+    elif model_args.peft_type == "LORA":
+        from peft.tuners.lora import LoraConfig
+        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
+    else:
+        peft_config = None
     torch_dtype = (
         model_args.torch_dtype
         if model_args.torch_dtype in ["auto", None]
@@ -62,9 +76,6 @@ if __name__ == "__main__":
     collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
     collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
-    peft_config = get_peft_config(model_args)
-    model = get_peft_model(model, peft_config)
     ################
     # Dataset
     ################
@@ -73,8 +84,10 @@ if __name__ == "__main__":
     accelerator = create_accelerator_and_postprocess(training_args)
-    if accelerator.is_local_main_process:
-        model.print_trainable_parameters()
+    if peft_config is not None:
+        model = get_peft_model(model, peft_config)
+        if accelerator.is_local_main_process:
+            model.print_trainable_parameters()
     for dataset_name in script_args.dataset_name:
         dataset = get_dataset(dataset_name)
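
Note that the dispatch added above is case-sensitive: --peft_type MMOELora selects the MMOE config, --peft_type LORA selects plain LoRA, and anything else falls through to peft_config = None so no adapter is injected. A standalone, hedged reproduction of that branch for quick experimentation; FakeModelArgs is a stand-in for the parsed ContinualModelConfig, not something defined in the repo:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeModelArgs:  # stand-in for the parsed ContinualModelConfig
    peft_type: Optional[str] = None
    lora_target_modules: Optional[list] = None

def build_peft_config(model_args):
    if model_args.peft_type == "MMOELora":
        from peft_library.tuners import MMOELoraConfig
        return MMOELoraConfig(target_modules=model_args.lora_target_modules)
    elif model_args.peft_type == "LORA":
        from peft.tuners.lora import LoraConfig
        return LoraConfig(target_modules=model_args.lora_target_modules)
    return None

print(build_peft_config(FakeModelArgs("LORA", ["q_proj", "v_proj"])))
print(build_peft_config(FakeModelArgs()))  # None: training proceeds without wrapping the model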

@@ -1,8 +1,9 @@
 #!/bin/bash
-accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py \
+accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
 --dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \
 --use_peft \
+--peft_type LORA \
 --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
 --lora_target_modules q_proj v_proj \
 --per_device_train_batch_size 1 \

@@ -1,17 +1,21 @@
 from dataclasses import dataclass, field
 from typing import Optional
+from trl import ScriptArguments, ModelConfig
+from transformers import TrainingArguments
 @dataclass
-class ContinualScriptArguments:
+class ContinualScriptArguments(ScriptArguments):
     """Script arguments for continual learning."""
     dataset_name: list[str] = field(
         default_factory=lambda: ["cifar10", "cifar100", "imagenet2012"]
     )
-    dataset_config: Optional[str] = None
-    dataset_train_split: str = "train"
-    dataset_test_split: str = "test"
     dataset_generation_split: str = "generation"
-    gradient_checkpointing_use_reentrant: bool = False
-    ignore_bias_buffers: bool = False
+@dataclass
+class ContinualModelConfig(ModelConfig):
+    """Model configuration for continual learning."""
+    peft_type: Optional[str] = None
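
Since ContinualModelConfig only adds a peft_type field on top of trl's ModelConfig, it can be exercised with a plain HfArgumentParser. A hedged sketch: the repo-relative utils.args import mirrors what train.py uses, and the argument values are the ones passed by train.sh above:

from transformers import HfArgumentParser
from utils.args import ContinualModelConfig  # repo-relative import, as in train.py

parser = HfArgumentParser(ContinualModelConfig)
(model_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--model_name_or_path", "Qwen/Qwen2-VL-7B-Instruct",
        "--peft_type", "LORA",
        "--lora_target_modules", "q_proj", "v_proj",
    ]
)
print(model_args.peft_type, model_args.lora_target_modules)  # LORA ['q_proj', 'v_proj']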