feat: 使得MOELORA支持task_id以及其他参数的传递

This commit is contained in:
YunyaoZhou 2025-01-07 15:07:08 +08:00
parent b40e0290a7
commit a1bb0f7c8c
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
2 changed files with 332 additions and 5 deletions

View File

@ -1,29 +1,346 @@
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention, # from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
from transformers.models.qwen2_vl.modeling_qwen2_vl import * from transformers.models.qwen2_vl.modeling_qwen2_vl import *
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
import torch import torch
from typing import Optional, List, Union, Tuple from typing import Optional, List, Union, Tuple
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from torch import nn
from torch.nn import functional as F
from torch import Tensor
class LinearLayer(nn.Linear):
def forward(self, input: Tensor, **kwargs) -> Tensor:
return F.linear(input, self.weight, self.bias)
class Qwen2VLAttention_modified(Qwen2VLAttention): class Qwen2VLAttention_modified(Qwen2VLAttention):
def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.q_proj = LinearLayer(
self.hidden_size, self.num_heads * self.head_dim, bias=True
)
self.k_proj = LinearLayer(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj = LinearLayer(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.o_proj = LinearLayer(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
if past_key_value is not None:
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
attn_weights = torch.where(
torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Qwen2Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, **kwargs)
key_states = self.k_proj(hidden_states, **kwargs)
value_states = self.v_proj(hidden_states, **kwargs)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
if past_key_value is not None:
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
QWEN2_VL_ATTENTION_CLASSES = { QWEN2_VL_ATTENTION_CLASSES = {
"eager": Qwen2VLAttention, "eager": Qwen2VLAttention,
"flash_attention_2": Qwen2VLFlashAttention2, "flash_attention_2": Qwen2VLFlashAttention2,
"sdpa": Qwen2VLSdpaAttention, "sdpa": Qwen2VLSdpaAttention_modified,
} }
class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer): class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer):
def __init__(self, config: Qwen2VLConfig, layer_idx: int): def __init__(self, config: Qwen2VLConfig, layer_idx: int):
super().__init__(config, layer_idx) super().__init__(config, layer_idx)
self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](
config, layer_idx
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class Qwen2VLModel_modified(Qwen2VLModel): class Qwen2VLModel_modified(Qwen2VLModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[Qwen2VLDecoderLayer_modified(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] [
Qwen2VLDecoderLayer_modified(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
) )
def forward( def forward(
@ -174,7 +491,16 @@ class Qwen2VLModel_modified(Qwen2VLModel):
class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration): class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
config.vision_config
)
self.model = Qwen2VLModel_modified(config) self.model = Qwen2VLModel_modified(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing
self.post_init()
def forward( def forward(
self, self,

View File

@ -151,9 +151,10 @@ class MMOELoraLinear(nn.Module, MMOELoraLayer):
def forward(self, x: torch.Tensor, *args, **kwargs): def forward(self, x: torch.Tensor, *args, **kwargs):
self._check_forward_args(x, *args, **kwargs) self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None) adapter_names = kwargs.pop("adapter_names", None)
task_id = kwargs.pop( # task_id = kwargs.pop(
"task_id", torch.tensor([0] * len(x), dtype=torch.long).to(x.device) # "task_id", torch.tensor([0] * len(x), dtype=torch.long).to(x.device)
) # )
task_id = kwargs.pop("task_id", torch.tensor([0] * len(x), dtype=torch.long))
previous_dtype = x.dtype previous_dtype = x.dtype
if self.disable_adapters: # No adapter if self.disable_adapters: # No adapter