251 lines
9.1 KiB
Python
251 lines
9.1 KiB
Python
# coding=utf-8
|
|
# Copyright 2023-present the HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import copy
|
|
|
|
import torch
|
|
|
|
|
|
# Prefix-tuning post-processing hook for BLOOM models (registered under "bloom" in
# TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING).
def bloom_model_postprocess_past_key_value(past_key_values):
    """
    Regroup prefix-tuning past key/values into per-layer ``(key, value)`` pairs for BLOOM.

    The tensors in ``past_key_values`` are concatenated along dim 0. The result must be 5-D and is read
    as ``(total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim)``. The first half
    along dim 0 supplies the keys and the second half supplies the values.

    Returns:
        A tuple of ``total_layers // 2`` ``(key, value)`` pairs, where each key has shape
        ``(batch_size * num_attention_heads, head_dim, num_virtual_tokens)`` and each value has shape
        ``(batch_size * num_attention_heads, num_virtual_tokens, head_dim)``.
    """
    stacked = torch.cat(past_key_values)
    total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = stacked.shape
    half = total_layers // 2
    merged_heads = batch_size * num_attention_heads

    # NOTE(review): transpose(2, 3) swaps the heads axis and the virtual-token axis before the
    # reshape; kept exactly as-is - confirm it matches the layout produced by the prefix encoder.
    keys = stacked[:half].transpose(2, 3).reshape(half, merged_heads, head_dim, num_virtual_tokens)
    values = stacked[half:].reshape(half, merged_heads, num_virtual_tokens, head_dim)

    # Iterating a tensor walks dim 0, so this yields one (key, value) pair per layer.
    return tuple(zip(keys, values))
|
|
|
|
|
|
def prepare_model_for_int8_training(
    model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=("layer_norm",)
):
    r"""
    Prepare a loaded model (typically an 8-bit one) for training. This:

    1. freezes every parameter of the model;
    2. if ``model.is_loaded_in_8bit`` is truthy, casts every 1-D parameter whose name contains one of
       ``layer_norm_names`` to fp32;
    3. if the model is loaded in 8-bit and ``use_gradient_checkpointing`` is set, makes the output of the
       input embeddings require grad (via ``model.enable_input_require_grads()`` when available, otherwise a
       forward hook) and calls ``model.gradient_checkpointing_enable()``;
    4. if ``model`` has an attribute named ``output_embedding_layer_name``, wraps that layer so its input is
       cast to the layer's weight dtype and its output is upcast to fp32.

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`. It is modified in place.
        output_embedding_layer_name (`str`, defaults to `"lm_head"`):
            Attribute name of the output embedding layer on `model`. Nothing is wrapped if it is absent.
        use_gradient_checkpointing (`bool`, defaults to `True`):
            Whether to enable gradient checkpointing (only applied to 8-bit models).
        layer_norm_names (iterable of `str`, or a single `str`, defaults to `("layer_norm",)`):
            Substrings identifying the layer-norm parameters to upcast. A single string is treated as one name.

    Returns:
        The same `model` object.
    """
    if isinstance(layer_norm_names, str):
        # A bare string would otherwise be iterated character by character and match almost any parameter.
        layer_norm_names = (layer_norm_names,)
    else:
        # Materialize once so one-shot iterables (e.g. generators) are checked against every parameter.
        layer_norm_names = tuple(layer_norm_names)

    loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)

    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

        # cast 1-D layer-norm parameters to fp32 for stability in 8-bit models
        if loaded_in_8bit and param.ndim == 1 and any(ln_name in name for ln_name in layer_norm_names):
            param.data = param.data.to(torch.float32)

    if loaded_in_8bit and use_gradient_checkpointing:
        # Presumably needed so gradients can flow through checkpointed blocks while the embeddings are
        # frozen - confirm. The hook is a fallback for models without `enable_input_require_grads`.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, inputs, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()

    if hasattr(model, output_embedding_layer_name):
        output_embedding_layer = getattr(model, output_embedding_layer_name)
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):
            r"""
            Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted
            in fp32
            """

            def forward(self, x):
                # Feed the wrapped layer in its own weight dtype, then upcast the result to fp32.
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model
|
|
|
|
|
|
# copied from transformers.models.bart.modeling_bart
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right, putting ``decoder_start_token_id`` in the first position.

    The last token of every row is dropped. Any ``-100`` in the shifted result is replaced by
    ``pad_token_id``. ``input_ids`` itself is not modified.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
        pad_token_id (`int`): The id of the `padding` token.
        decoder_start_token_id (`int`): The id of the `start` token.

    Raises:
        ValueError: if ``pad_token_id`` is ``None``.
    """
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1].clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # -100 is the ignore index used in labels; it is not a valid decoder input id.
    ignored = shifted == -100
    shifted.masked_fill_(ignored, pad_token_id)
    return shifted
|
|
|
|
|
|
class ModulesToSaveWrapper(torch.nn.Module):
    """
    Keep an original module alongside one deep copy per adapter name, and route calls to the copy
    belonging to ``active_adapter``.

    If the active adapter has no entry in ``modules_to_save``, calls fall back to ``original_module``.
    """

    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict()
        # `update` only reads `original_module` and writes `modules_to_save`, so it can run
        # before `active_adapter` exists.
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        """Store a fresh deep copy of ``original_module`` under ``adapter_name``, replacing any existing one."""
        self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

    def forward(self, *args, **kwargs):
        """Call the active adapter's copy, or ``original_module`` when that adapter has no copy."""
        if self.active_adapter in self.modules_to_save:
            module = self.modules_to_save[self.active_adapter]
        else:
            module = self.original_module
        return module(*args, **kwargs)
|
|
|
|
|
|
def _get_submodules(model, key):
    """
    Resolve a dotted module path inside ``model``.

    Returns:
        ``(parent, target, target_name)``: the module that owns the last path component, the module
        at ``key``, and that last component's name. A key without dots has ``model`` as its parent.
    """
    parent_key, _, target_name = key.rpartition(".")
    parent = model.get_submodule(parent_key)
    target = model.get_submodule(key)
    return parent, target, target_name
|
|
|
|
|
|
def _freeze_adapter(model, adapter_name):
    """Set ``requires_grad = False`` on every parameter whose qualified name contains ``adapter_name`` as a substring."""
    matching = (param for param_name, param in model.named_parameters() if adapter_name in param_name)
    for param in matching:
        param.requires_grad = False
|
|
|
|
|
|
def _set_trainable(model, adapter_name):
    """
    Make the modules named in ``model.modules_to_save`` trainable for ``adapter_name``.

    Each module whose qualified name ends with one of the ``model.modules_to_save`` entries is handled in
    one of two ways. If it is already a ``ModulesToSaveWrapper``, a fresh copy is added for
    ``adapter_name``. Otherwise, all of its parameters get ``requires_grad = True`` and it is replaced in
    its parent by a new ``ModulesToSaveWrapper``.
    """
    # Snapshot the names up front: the loop swaps modules in the tree as it goes.
    module_names = [name for name, _ in model.named_modules()]
    for module_name in module_names:
        if not any(module_name.endswith(suffix) for suffix in model.modules_to_save):
            continue

        parent, target, target_name = _get_submodules(model, module_name)
        if isinstance(target, ModulesToSaveWrapper):
            target.update(adapter_name)
            continue

        for param in target.parameters():
            param.requires_grad = True
        setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))
|
|
|
|
|
|
def _set_adapter(model, adapter_name):
    """Make ``adapter_name`` the active adapter of every ``ModulesToSaveWrapper`` inside ``model``."""
    wrappers = [module for module in model.modules() if isinstance(module, ModulesToSaveWrapper)]
    for wrapper in wrappers:
        wrapper.active_adapter = adapter_name
|
|
|
|
|
|
def fsdp_auto_wrap_policy(model):
    """
    Build an FSDP auto-wrap policy for a PEFT model.

    The returned policy is an ``_or_policy`` of two sub-policies. It wraps a module when either:

    * the module has no children and has a ``weight`` that requires grad, or
    * the module is an instance of ``PrefixEncoder``, ``PromptEncoder``, ``PromptEmbedding``, or of a class
      found in ``model`` by a name listed in the ``FSDP_TRANSFORMER_CLS_TO_WRAP`` environment variable.
      The variable may hold one or more comma-separated class names. Empty entries are ignored.

    Raises:
        ValueError: if a class name from ``FSDP_TRANSFORMER_CLS_TO_WRAP`` cannot be found in ``model``.
    """
    import functools
    import os

    from accelerate import FullyShardedDataParallelPlugin
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        # Wrap trainable leaf modules on their own.
        return bool(
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    transformer_layer_cls = [PrefixEncoder, PromptEncoder, PromptEmbedding]
    cls_names = os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
    for cls_name in (name.strip() for name in cls_names.split(",")):
        if not cls_name:
            continue
        layer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_name)
        if layer_cls is None:
            # Previously a `None` ended up in `transformer_layer_cls` and failed later inside
            # `isinstance` with an unhelpful TypeError; fail early with the offending name instead.
            raise ValueError(
                f"Could not find the transformer layer class `{cls_name}` "
                "(from FSDP_TRANSFORMER_CLS_TO_WRAP) in the model."
            )
        transformer_layer_cls.append(layer_cls)

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=tuple(transformer_layer_cls),
    )

    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
|
|
|
|
|
|
def transpose(weight, fan_in_fan_out):
    """Return ``weight.T`` when ``fan_in_fan_out`` is truthy, otherwise ``weight`` unchanged."""
    if fan_in_fan_out:
        return weight.T
    return weight
|
|
|
|
|
|
# Default LoRA target module names, keyed by model type (presumably the Hugging Face config
# `model_type` - confirm against the caller). Each value is a separate list object.
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "blip-2": ["q", "v", "q_proj", "v_proj"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
    "gpt_neo": ["q_proj", "v_proj"],
    "bert": ["query", "value"],
    "roberta": ["query", "value"],
    "xlm-roberta": ["query", "value"],
    "electra": ["query", "value"],
    "deberta-v2": ["query_proj", "value_proj"],
    "deberta": ["in_proj"],
    "layoutlm": ["query", "value"],
    "llama": ["q_proj", "v_proj"],
    "chatglm": ["query_key_value"],
}
|
|
|
|
# Default AdaLoRA target module names, keyed by model type. These cover more projections than the
# LoRA defaults. Model types present in the LoRA mapping but absent here (gpt2, bloom, gptj, gpt_neox,
# gpt_neo, bert, xlm-roberta, electra, deberta, layoutlm) have no AdaLoRA default. They are
# presumably not yet validated with AdaLoRA - TODO confirm. Their LoRA defaults would be the natural
# starting point.
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "k", "v", "o", "wi", "wo"],
    "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
    "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    "roberta": ["query", "key", "value", "dense"],
    "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
}
|
|
|
|
# The MMOE-LoRA mappings reuse the LoRA defaults. Both names are bound to the *same* dict object as
# TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING (no copy is made), so an in-place change through
# any of the three names is visible through all of them.
TRANSFORMERS_MODELS_TO_MMOELORAS_TARGET_MODULES_MAPPING = (
    TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING
) = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

# Prefix-tuning past-key-value post-processing functions, keyed by model type.
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {"bloom": bloom_model_postprocess_past_key_value}
|
|
|
|
# File names for adapter artifacts (presumably the weights and config written/read when saving and
# loading adapters - confirm against the save/load code).
WEIGHTS_NAME, CONFIG_NAME = "adapter_model.bin", "adapter_config.json"
|