Merge branch 'release/0.1.1'

YunyaoZhou 2025-01-07 15:12:53 +08:00
commit 4d65809c34
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
9 changed files with 897 additions and 58 deletions

src/dataset_library/chem.py (new file, 181 lines added)

@@ -0,0 +1,181 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os
class ChemDataset(Dataset):
def __init__(
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_path (string): path to the annotation file
"""
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(ann_path, split=1)[:200]
elif split == "test":
self.data = self.create_data(ann_path, split=3)[:200]
def create_data(self, ann_path, split=1):
processed_data = []
with open(ann_path, "r") as f:
data = json.load(f)
for k in data.keys():
if data[k]["split"] != split:
continue # 1 for training, 2 for validation, 3 for test
ext = os.path.splitext(data[k]["imageURL"])[1]
imageFile = k + ext
assert len(data[k]["questions"]) == len(data[k]["answers"])
for q, a in zip(data[k]["questions"], data[k]["answers"]):
if os.path.exists(os.path.join(self.vis_root, imageFile)):
processed_data.append(
{
"question": q,
"answer": a,
"image_path": imageFile,
"image_id": k,
"title": data[k]["title"],
"genre": data[k]["genre"],
}
)
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
"RGB"
)
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
},
],
},
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
]
return {
"image": image,
"chat": chat,
"image_id": sample["image_id"],
}
class OCRVQADatasetForGeneration(Dataset):
def __init__(
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_path (string): path to the annotation file
"""
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
if split == "train":
self.data = self.create_data(ann_path, split=1)[:200]
elif split == "test":
self.data = self.create_data(ann_path, split=3)[:200]
# self.instruction_pool = [
# "[vqa] {}",
# "[vqa] Based on the image, respond to this question with a short answer: {}",
# ]
def create_data(self, ann_path, split=1):
processed_data = []
with open(ann_path, "r") as f:
data = json.load(f)
for k in data.keys():
if data[k]["split"] != split:
continue # 1 for training, 2 for validation, 3 for test
ext = os.path.splitext(data[k]["imageURL"])[1]
imageFile = k + ext
assert len(data[k]["questions"]) == len(data[k]["answers"])
for q, a in zip(data[k]["questions"], data[k]["answers"]):
if os.path.exists(os.path.join(self.vis_root, imageFile)):
processed_data.append(
{
"question": q,
"answer": a,
"image_path": imageFile,
"image_id": k,
"title": data[k]["title"],
"genre": data[k]["genre"],
}
)
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
"RGB"
)
question = sample["question"]
answer = sample["answer"]
if self.vis_processor is not None:
image = self.vis_processor(image)
if self.text_processor is not None:
question = self.text_processor(question)
answer = self.text_processor(answer)
chat = [
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
},
],
}
# {"role": "assistant", "content": answer},
]
return {
"image": image,
"chat": chat,
"answer": answer,
"image_id": sample["image_id"],
}
if __name__ == "__main__":
dataset = ChemDataset(
"/home/zyy/research/accelerate/dataset/chem/images/",
"/home/zyy/research/accelerate/dataset/chem/qwen_data/conversations_loc_train.jsonl",
split="train",
)
print(len(dataset))
print(dataset[0])
dataset = OCRVQADatasetForGeneration(
"/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
"/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
split="train",
)
print(len(dataset))
print(dataset[0])
pass
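
For reference, these samples are meant to feed the Qwen2-VL collate functions touched later in this commit: each item exposes a PIL "image" plus a "chat" conversation. Below is a rough sketch of how a collate step could consume them; this is an illustration only, and the actual collate_fn_for_train in model_library may differ.

from transformers import Qwen2VLProcessor

def collate_sketch(examples, processor: Qwen2VLProcessor):
    # Render each chat (user turn with an image placeholder plus the assistant
    # answer) into the model's prompt text via the processor's chat template.
    texts = [
        processor.apply_chat_template(example["chat"], tokenize=False)
        for example in examples
    ]
    images = [example["image"] for example in examples]
    # Tokenize the texts and preprocess the images into one padded batch.
    return processor(text=texts, images=images, return_tensors="pt", padding=True)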


@@ -51,7 +51,7 @@ if __name__ == "__main__":
     print(model)
     if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        from collatefn_library.qwen2 import (
+        from model_library.qwen2 import (
             collate_fn_for_train,
             collate_fn_for_evaluate,
         )


@@ -0,0 +1,4 @@
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
from .model import Qwen2VLForConditionalGeneration_modified
__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]


@@ -10,7 +10,6 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
     ]
     # print(texts)
     images = [example["image"] for example in examples]
-
     # Tokenize the texts and process the images
     batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -52,7 +51,7 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
             now_index += 1
         now_index += 1
     batch["labels"] = labels
-    # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
+    batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
     return batch
@@ -80,48 +79,5 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
     # pixel_values torch.Size([3888, 1176])
     # image_grid_thw torch.Size([3, 3])
     # answers_ids torch.Size([3, 10])
     # answers_mask torch.Size([3, 10])
     return batch
-
-
-if __name__ == "__main__":
-    from transformers import Qwen2VLProcessor
-    from dataset_library.OCRVQADataset import OCRVQADatasetForGeneration
-    processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
-    dataset = OCRVQADatasetForGeneration(
-        vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
-        ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
-        split="train",
-    )
-    examples = [dataset[i] for i in range(3)]
-    # print(collate_fn_for_evaluate(examples, processor))
-    from transformers import AutoModelForCausalLM
-    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
-    from accelerate import Accelerator
-    from torch.utils.data import DataLoader
-    val_dataloader = DataLoader(
-        dataset,
-        batch_size=3,
-        collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
-    )
-    accelerator = Accelerator()
-    model = accelerator.prepare(model)
-    val_dataloader = accelerator.prepare(val_dataloader)
-    for batch in val_dataloader:
-        completion = model.generate(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-            pixel_values=batch["pixel_values"],
-            image_grid_thw=batch["image_grid_thw"],
-            max_length=100,
-        )
-        target = batch["answers_ids"]
-        generated_text = [out_ids[len(in_ids) :] for out_ids, in_ids in zip(completion, batch["input_ids"])]
-        generated_text = processor.tokenizer.batch_decode(generated_text)
-        target_text = processor.tokenizer.batch_decode(target)
-        print(generated_text, target_text)


@@ -0,0 +1,696 @@
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
from transformers.models.qwen2_vl.modeling_qwen2_vl import *
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
import torch
from typing import Optional, List, Union, Tuple
from torch.nn import CrossEntropyLoss
from torch import nn
from torch.nn import functional as F
from torch import Tensor
class LinearLayer(nn.Linear):
def forward(self, input: Tensor, **kwargs) -> Tensor:
return F.linear(input, self.weight, self.bias)
class Qwen2VLAttention_modified(Qwen2VLAttention):
def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.q_proj = LinearLayer(
self.hidden_size, self.num_heads * self.head_dim, bias=True
)
self.k_proj = LinearLayer(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj = LinearLayer(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.o_proj = LinearLayer(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
if past_key_value is not None:
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
attn_weights = torch.where(
torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Qwen2Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, **kwargs)
key_states = self.k_proj(hidden_states, **kwargs)
value_states = self.v_proj(hidden_states, **kwargs)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
if past_key_value is not None:
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
QWEN2_VL_ATTENTION_CLASSES = {
"eager": Qwen2VLAttention,
"flash_attention_2": Qwen2VLFlashAttention2,
"sdpa": Qwen2VLSdpaAttention_modified,
}
class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer):
def __init__(self, config: Qwen2VLConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](
config, layer_idx
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class Qwen2VLModel_modified(Qwen2VLModel):
def __init__(self, config):
super().__init__(config)
self.layers = nn.ModuleList(
[
Qwen2VLDecoderLayer_modified(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# the hard coded `3` is for temporal, height and width.
if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(
3, inputs_embeds.shape[0], -1
)
elif position_ids.dim() == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
**kwargs,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
config.vision_config
)
self.model = Qwen2VLModel_modified(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if (
position_ids is None
and input_ids is not None
and (attention_mask is None or attention_mask.ndim == 2)
):
# calculate RoPE index once per generation in the pre-fill stage only
if (
cache_position is not None and cache_position[0] == 0
) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
cache_position[0] + self.rope_deltas
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Qwen2VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
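
The point of swapping nn.Linear for LinearLayer in the attention projections, and of threading **kwargs through every modified forward above, is that extra per-batch tensors (such as the task_id that collate_fn_for_train now puts into the batch) can travel from the top-level forward call down to the projection layers, where an adapter such as MMOELoraLinear can pop and use them. A minimal standalone illustration of that mechanism (not part of the commit; the class name here is hypothetical):

import torch
from torch import nn, Tensor
from torch.nn import functional as F

class KwargTolerantLinear(nn.Linear):
    # Same trick as LinearLayer above: accept and ignore unexpected keyword
    # arguments instead of raising, so callers can always pass task_id along.
    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return F.linear(input, self.weight, self.bias)

x = torch.randn(2, 8)
plain = nn.Linear(8, 4)
tolerant = KwargTolerantLinear(8, 4)

tolerant(x, task_id=torch.zeros(2, dtype=torch.long))  # extra kwarg is silently ignored
# plain(x, task_id=...) would raise TypeError: unexpected keyword argument 'task_id'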


@@ -151,9 +151,10 @@ class MMOELoraLinear(nn.Module, MMOELoraLayer):
     def forward(self, x: torch.Tensor, *args, **kwargs):
         self._check_forward_args(x, *args, **kwargs)
         adapter_names = kwargs.pop("adapter_names", None)
-        task_id = kwargs.pop(
-            "task_id", torch.tensor([0] * len(x), dtype=torch.long).to(x.device)
-        )
+        # task_id = kwargs.pop(
+        #     "task_id", torch.tensor([0] * len(x), dtype=torch.long).to(x.device)
+        # )
+        task_id = kwargs.pop("task_id", torch.tensor([0] * len(x), dtype=torch.long))
         previous_dtype = x.dtype
         if self.disable_adapters:  # No adapter
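
The new fallback no longer moves the default tensor to x.device: if a task_id keyword reaches this layer (the collate change above now adds one to every training batch), that tensor is used as-is; otherwise a CPU tensor of zeros with one entry per example is created. A small standalone illustration of the changed default (names hypothetical):

import torch

def pop_task_id(x: torch.Tensor, **kwargs) -> torch.Tensor:
    # Mirrors the new line: a caller-supplied task_id wins; the zero-filled
    # fallback stays on the CPU rather than being moved to x.device.
    return kwargs.pop("task_id", torch.tensor([0] * len(x), dtype=torch.long))

x = torch.randn(3, 16)
print(pop_task_id(x))                                   # tensor([0, 0, 0])
print(pop_task_id(x, task_id=torch.tensor([2, 0, 1])))  # tensor([2, 0, 1])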


@@ -53,14 +53,16 @@ if __name__ == "__main__":
         padding_side="left",
     )
-    model = AutoModelForVision2Seq.from_pretrained(
-        model_args.model_name_or_path,
-        trust_remote_code=model_args.trust_remote_code,
-        **model_kwargs,
-    )
     if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        from collatefn_library.qwen2 import (
+        # from transformers import Qwen2VLForConditionalGeneration
+        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
+
+        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
+            **model_kwargs,
+        )
+        from model_library.qwen2vl import (
             collate_fn_for_train,
             collate_fn_for_evaluate,
         )


@@ -1,7 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Optional
 from trl import ScriptArguments, ModelConfig
-from transformers import TrainingArguments


 @dataclass