# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
from transformers.models.qwen2_vl.modeling_qwen2_vl import *
from transformers.modeling_outputs import BaseModelOutputWithPast

import torch
from torch import nn  # used below for nn.ModuleList; the star import above is not guaranteed to re-export `nn`
from typing import Optional, List, Union, Tuple
from torch.nn import CrossEntropyLoss


class Qwen2VLAttention_modified(Qwen2VLAttention):
    # Placeholder subclass kept as a hook for custom attention; currently identical to Qwen2VLAttention.
    pass


QWEN2_VL_ATTENTION_CLASSES = {
    "eager": Qwen2VLAttention,
    "flash_attention_2": Qwen2VLFlashAttention2,
    "sdpa": Qwen2VLSdpaAttention,
}
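# Note: inside transformers, Qwen2VLDecoderLayer builds its self-attention from the
# QWEN2_VL_ATTENTION_CLASSES dict that lives in modeling_qwen2_vl, keyed by
# `config._attn_implementation`. Re-declaring the dict here does not patch the library's copy,
# so Qwen2VLAttention_modified only takes effect if it is wired in explicitly (for example by
# assigning it into that module's dict before the layers are constructed).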


class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer):
    def __init__(self, config: Qwen2VLConfig, layer_idx: int):
        super().__init__(config, layer_idx)


class Qwen2VLModel_modified(Qwen2VLModel):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen2VLDecoderLayer_modified(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(
                3, inputs_embeds.shape[0], -1
            )
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
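        # position_ids now has shape (3, batch_size, seq_len): one index stream per M-RoPE axis
        # (temporal, height, width). When derived from cache_position as above the three streams
        # are identical; get_rope_index in the wrapper model supplies distinct streams for
        # image/video tokens.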
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    **kwargs,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
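# Illustrative sketch (not from the original file): the modified backbone can also be used on its
# own, built from a config and run on token ids; the checkpoint name below is simply the one used
# in the docstring example further down and is only an assumption about which weights to load.
#
#   config = Qwen2VLConfig.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   backbone = Qwen2VLModel_modified(config)
#   out = backbone(input_ids=torch.randint(0, config.vocab_size, (1, 8)))
#   print(out.last_hidden_state.shape)  # (1, 8, config.hidden_size)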


class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2VLModel_modified(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
|
|
Args:
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
|
|
Returns:
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from PIL import Image
|
|
>>> import requests
|
|
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
|
|
|
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
|
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
|
|
|
>>> messages = [
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "image"},
|
|
{"type": "text", "text": "What is shown in this image?"},
|
|
],
|
|
},
|
|
]
|
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
|
|
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
|
|
|
>>> # Generate
|
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
|
```"""
|
|
|
|
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
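            # At this point every image/video placeholder position in inputs_embeds has been
            # overwritten in place (via masked_scatter) with the matching vision-encoder feature.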
        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if (
            position_ids is None
            and input_ids is not None
            and (attention_mask is None or attention_mask.ndim == 2)
        ):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                cache_position is not None and cache_position[0] == 0
            ) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    cache_position[0] + self.rope_deltas
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
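

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It mirrors the docstring example
# above but instantiates the *_modified classes; the checkpoint name is the one from that
# example and is only an assumption about which weights you want to load.
if __name__ == "__main__":
    import requests
    from PIL import Image
    from transformers import AutoProcessor

    model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

    # Decode a short answer; this exercises the modified decoder stack end to end.
    generate_ids = model.generate(**inputs, max_new_tokens=30)
    print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])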