import warnings
from typing import Optional

import torch
from torch import nn
from peft.tuners.lora import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer

from .config import MMOELoraConfig


class MMOELoraLayer(LoraLayer):
    def __init__(
        self,
        expert_num: int,
        base_layer: nn.Linear = None,
    ):
        super().__init__(base_layer=base_layer)
        self.expert_num = expert_num

    def update_layer(
        self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, **kwargs
    ):
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

        # Actual trainable parameters: the total rank r is split evenly across the experts.
        if r > 0:
            self.lora_A.update(
                nn.ModuleDict(
                    {adapter_name: MMOELinearA(self.in_features, r, self.expert_num)}
                )
            )
            self.lora_B.update(
                nn.ModuleDict(
                    {adapter_name: MMOELinearB(r, self.out_features, self.expert_num)}
                )
            )
            self.scaling[adapter_name] = lora_alpha / r
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)

        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            # initialize each expert's A from N(0, 0.01) and its B to zero,
            # so a freshly added adapter leaves the base layer output unchanged
            for i in range(self.expert_num):
                nn.init.normal_(
                    self.lora_A[adapter_name].loraA[i].mlp.weight, mean=0.0, std=0.01
                )
                nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight)
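

# A minimal sketch (not part of the original module) of the gating math the adapter
# implements: for task t, the effective weight update is
#     delta_W(t) = (lora_alpha / r) * sum_i gate_i(t) * (B_i @ A_i)
# which is the quantity the forward pass applies to x and the commented-out merge()
# in MMOELoraLinear below would fold into the base weight. All sizes and names are
# illustrative only.
def _gated_delta_weight_example():
    in_features, out_features, r, expert_num, lora_alpha = 32, 64, 8, 4, 16
    r_per_expert = r // expert_num  # each expert holds an equal share of the rank
    scaling = lora_alpha / r

    # per-expert low-rank factors (same shapes as Expert.mlp.weight below)
    A = [torch.randn(r_per_expert, in_features) * 0.01 for _ in range(expert_num)]
    B = [torch.randn(out_features, r_per_expert) * 0.01 for _ in range(expert_num)]
    gate_weights = torch.softmax(torch.randn(expert_num), dim=0)  # one task's gate output

    delta_w = sum(
        gate_weights[i] * scaling * (B[i] @ A[i]) for i in range(expert_num)
    )
    assert delta_w.shape == (out_features, in_features)
    return delta_w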


class MMOELoraLinear(nn.Module, MMOELoraLayer):
    # LoRA implemented in a dense layer:
    # base_layer is the pretrained nn.Linear of the LLM; MMOELoraLayer holds the trainable LoRA experts
    def __init__(
        self,
        base_layer: nn.Linear,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        super().__init__()
        MMOELoraLayer.__init__(
            self, expert_num=kwargs.pop("expert_num", 2), base_layer=base_layer
        )

        init_lora_weights = kwargs.pop("init_lora_weights", True)
        # task_num and task_embedding_dim are expected to be passed in via kwargs
        # (typically forwarded from the MMOELoraConfig)
        self.task_num = kwargs.pop("task_num", True)
        self.te_dim = kwargs.pop("task_embedding_dim", True)

        # init the Gate network: one task embedding table and one gate per adapter
        self.lora_task_embedding = nn.ModuleDict({})
        self.lora_gate = nn.ModuleDict({})
        self.lora_task_embedding.update(
            nn.ModuleDict({adapter_name: nn.Embedding(self.task_num + 1, self.te_dim)})
        )
        self.lora_gate.update(
            nn.ModuleDict({adapter_name: Gate(self.te_dim, self.expert_num)})
        )

        self.fan_in_fan_out = fan_in_fan_out

        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self._active_adapter = adapter_name

    # NOTE: re-enabling merge()/unmerge() also requires importing `transpose`
    # (e.g. from peft.utils.other).
    # def merge(self, task_id):
    #     if self._active_adapter not in self.lora_A.keys():
    #         return
    #     if self.merged:
    #         warnings.warn("Already merged. Nothing to do.")
    #         return
    #     if self.r[self._active_adapter] > 0:
    #         expert_weight = self.lora_gate[self._active_adapter](
    #             self.lora_task_embedding[self._active_adapter](task_id)
    #         )
    #         for i in range(self.expert_num):
    #             lora_A_weights = self.lora_A[self._active_adapter].loraA[i].mlp.weight
    #             lora_B_weights = self.lora_B[self._active_adapter].loraB[i].mlp.weight
    #             self.base_layer.weight.data += (
    #                 transpose(
    #                     lora_B_weights @ lora_A_weights,
    #                     self.fan_in_fan_out,
    #                 )
    #                 * self.scaling[self._active_adapter]
    #                 * expert_weight[..., i]
    #             )
    #         self.merged = True

    # def unmerge(self, task_id):
    #     if self._active_adapter not in self.lora_A.keys():
    #         return
    #     if not self.merged:
    #         warnings.warn("Already unmerged. Nothing to do.")
    #         return
    #     if self.r[self._active_adapter] > 0:
    #         expert_weight = self.lora_gate[self._active_adapter](
    #             self.lora_task_embedding[self._active_adapter](task_id)
    #         )
    #         for i in range(self.expert_num):
    #             lora_A_weights = self.lora_A[self._active_adapter].loraA[i].mlp.weight
    #             lora_B_weights = self.lora_B[self._active_adapter].loraB[i].mlp.weight
    #             self.base_layer.weight.data -= (
    #                 transpose(
    #                     lora_B_weights @ lora_A_weights,
    #                     self.fan_in_fan_out,
    #                 )
    #                 * self.scaling[self._active_adapter]
    #                 * expert_weight[..., i]
    #             )
    #         self.merged = False

    def forward(self, x: torch.Tensor, *args, **kwargs):
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)  # mixed-batch inference is not supported here
        # default to task 0 for every sample if no task_id is provided
        task_id = kwargs.pop(
            "task_id", torch.tensor([0] * len(x), dtype=torch.long).to(x.device)
        )
        previous_dtype = x.dtype

        if self.disable_adapters:  # no adapter
            # if self.merged:
            #     self.unmerge(task_id)
            # result = self.base_layer(x, *args, **kwargs)  # TODO: check this
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:  # adapters already merged into the base weights
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            torch_result_dtype = result.dtype
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                x = x.to(self.lora_A[active_adapter].loraA[0].mlp.weight.dtype)
                # per-sample mixture weights over the experts, conditioned on task_id
                expert_weight = self.lora_gate[active_adapter](
                    self.lora_task_embedding[active_adapter](task_id)
                )
                for i in range(self.expert_num):
                    result += (
                        self.lora_B[active_adapter].loraB[i](
                            self.lora_A[active_adapter].loraA[i](
                                self.lora_dropout[active_adapter](x)
                            )
                        )
                        * self.scaling[active_adapter]
                        * expert_weight[..., i].view([x.size(0)] + [1] * (x.dim() - 1))
                    )
            result = result.to(torch_result_dtype)

        result = result.to(previous_dtype)
        return result

    def __repr__(self):
        return "MMOElora." + super().__repr__()
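

# A minimal usage sketch (not part of the original module), assuming a recent PEFT
# version that provides the LoraLayer / BaseTunerLayer APIs imported above. It shows
# how MMOELoraLinear wraps a plain nn.Linear and how task_id is passed at call time;
# because every expert's B matrix is zero-initialised, a freshly created adapter is a
# no-op and the wrapped layer reproduces the base layer's output. All sizes are
# illustrative only.
def _mmoelora_linear_usage_example():
    base = nn.Linear(32, 32)
    mmoe = MMOELoraLinear(
        base,
        adapter_name="default",
        r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        expert_num=4,
        task_num=2,
        task_embedding_dim=8,
    )

    x = torch.randn(5, 32)
    task_id = torch.zeros(5, dtype=torch.long)  # every sample belongs to task 0
    out = mmoe(x, task_id=task_id)

    # zero-initialised B matrices mean the adapter contributes nothing yet
    assert torch.allclose(out, base(x))
    return out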


class MMOELinearA(nn.Module):
    """MMOE-based LoRA A block: one down-projection expert per expert."""

    def __init__(self, in_features: int, out_features: int, expert_num: int) -> None:
        super().__init__()
        self.expert_num = expert_num
        self.in_features = in_features
        self.out_features = out_features
        self.loraA = nn.ModuleList([])
        # the LoRA rank must be divisible by the number of experts
        assert self.out_features % self.expert_num == 0
        self.r = self.out_features // self.expert_num
        for _ in range(self.expert_num):
            self.loraA.append(Expert(self.in_features, self.r))

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Input x is a tensor; the output is a list with one tensor per expert."""
        outputs = []
        for i in range(self.expert_num):
            outputs.append(self.loraA[i](x))
        return outputs


class MMOELinearB(nn.Module):
    """MMOE-based LoRA B block: one up-projection expert per expert."""

    def __init__(self, in_features: int, out_features: int, expert_num: int) -> None:
        super().__init__()
        self.expert_num = expert_num
        self.in_features = in_features
        self.out_features = out_features
        self.loraB = nn.ModuleList([])
        # the LoRA rank must be divisible by the number of experts
        assert self.in_features % self.expert_num == 0
        self.r = self.in_features // self.expert_num
        for _ in range(self.expert_num):
            self.loraB.append(Expert(self.r, self.out_features))

    def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
        """Input x is a list with one tensor per expert; the output is also a list."""
        outputs = []
        for i in range(self.expert_num):
            outputs.append(self.loraB[i](x[i]))
        return outputs


class Expert(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.mlp = nn.Linear(self.in_features, self.out_features, bias=False)

    def forward(self, x):
        # a single LoRA A or B factor
        y = self.mlp(x)
        return y


class Gate(nn.Module):
    def __init__(self, input_size: int, expert_num: int):
        super().__init__()
        # the gate maps the task embedding (used in place of a learned input projection)
        # to per-expert logits
        self.GateL = nn.Linear(input_size, expert_num, bias=False)
        self.act = nn.Softmax(dim=1)  # dim 0 is the batch size

    def forward(self, x):
        y = self.GateL(x)
        y = self.act(y)
        return y
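

# A minimal sketch (not part of the original module) of the data flow through Gate,
# MMOELinearA and MMOELinearB for a single adapter: the gate turns a task embedding
# into a softmax distribution over experts, MMOELinearA produces one low-rank
# projection per expert, and MMOELinearB maps each of them back to the output size.
# All sizes and names are illustrative only.
def _gate_and_expert_shapes_example():
    batch, in_features, out_features = 4, 32, 64
    r, expert_num, te_dim, task_num = 8, 4, 16, 3

    task_embedding = nn.Embedding(task_num + 1, te_dim)
    gate = Gate(te_dim, expert_num)
    lora_A = MMOELinearA(in_features, r, expert_num)   # each expert: in_features -> r // expert_num
    lora_B = MMOELinearB(r, out_features, expert_num)  # each expert: r // expert_num -> out_features

    x = torch.randn(batch, in_features)
    task_id = torch.randint(0, task_num, (batch,))

    expert_weight = gate(task_embedding(task_id))  # (batch, expert_num), rows sum to 1
    a_out = lora_A(x)                              # list of expert_num tensors, each (batch, r // expert_num)
    b_out = lora_B(a_out)                          # list of expert_num tensors, each (batch, out_features)

    assert torch.allclose(expert_weight.sum(dim=1), torch.ones(batch))
    assert len(a_out) == expert_num and a_out[0].shape == (batch, r // expert_num)
    assert len(b_out) == expert_num and b_out[0].shape == (batch, out_features)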


def dispatch_default(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config: MMOELoraConfig,
    **kwargs,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    # if isinstance(target_base_layer, torch.nn.Embedding):
    #     embedding_kwargs = kwargs.copy()
    #     embedding_kwargs.pop("fan_in_fan_out", None)
    #     embedding_kwargs.update(lora_config.loftq_config)
    #     new_module = Embedding(target, adapter_name, **embedding_kwargs)
    # elif isinstance(target_base_layer, torch.nn.Conv2d):
    #     kwargs.update(lora_config.loftq_config)
    #     new_module = Conv2d(target, adapter_name, **kwargs)
    # elif isinstance(target_base_layer, torch.nn.Conv3d):
    #     kwargs.update(lora_config.loftq_config)
    #     new_module = Conv3d(target, adapter_name, **kwargs)
    if isinstance(target_base_layer, torch.nn.Linear):
        if kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        kwargs.update(lora_config.loftq_config)
        new_module = MMOELoraLinear(target, adapter_name, **kwargs)
    # elif isinstance(target_base_layer, Conv1D):
    #     if not kwargs["fan_in_fan_out"]:
    #         warnings.warn(
    #             "fan_in_fan_out is set to False but the target module is `Conv1D`. "
    #             "Setting fan_in_fan_out to True."
    #         )
    #         kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
    #     kwargs.update(lora_config.loftq_config)
    #     new_module = Linear(
    #         target, adapter_name, is_target_conv_1d_layer=True, **kwargs
    #     )

    return new_module
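

# A minimal usage sketch for dispatch_default (not part of the original module).
# SimpleNamespace is a hypothetical stand-in for a real MMOELoraConfig; it only
# carries the two attributes dispatch_default reads here (fan_in_fan_out and
# loftq_config), and the kwargs mirror what the surrounding tuner is assumed to
# forward for each target module.
def _dispatch_default_example():
    from types import SimpleNamespace

    target = torch.nn.Linear(32, 32)
    fake_config = SimpleNamespace(fan_in_fan_out=False, loftq_config={})  # hypothetical stand-in
    kwargs = {
        "r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.0,
        "fan_in_fan_out": False,
        "expert_num": 2,
        "task_num": 2,
        "task_embedding_dim": 8,
    }
    module = dispatch_default(target, "default", lora_config=fake_config, **kwargs)
    assert isinstance(module, MMOELoraLinear)
    return module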