Merge branch 'release/0.1.2'
This commit is contained in:
commit
d0e13cfbac
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.11.7
|
@ -1,3 +1,3 @@
|
||||
#!/bin/bash
|
||||
uv venv --python 3.11.7
|
||||
pip install -U torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
|
||||
uv sync
|
||||
uv sync --extra compile
|
||||
|
50
pyproject.toml
Normal file
50
pyproject.toml
Normal file
@ -0,0 +1,50 @@
|
||||
[project]
|
||||
dependencies = [
|
||||
"accelerate==1.2.1",
|
||||
"datasets==3.2.0",
|
||||
"deepspeed==0.16.2",
|
||||
"evaluate==0.4.3",
|
||||
"librosa>=0.10.2.post1",
|
||||
"markupsafe==2.1.5",
|
||||
"numba>=0.60.0",
|
||||
"peft==0.14.0",
|
||||
"pip==24.3.1",
|
||||
"requests==2.32.3",
|
||||
"setuptools>=70.0.0",
|
||||
"soundfile>=0.13.0",
|
||||
"torch==2.5.1+cu124",
|
||||
"torchaudio==2.5.1+cu124",
|
||||
"torchvision==0.20.1+cu124",
|
||||
"transformers==4.48.0",
|
||||
"trl==0.13.0",
|
||||
"wheel>=0.45.1",
|
||||
]
|
||||
description = "Add your description here"
|
||||
name = "cl-lmm"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
version = "0.1.0"
|
||||
|
||||
[project.optional-dependencies]
|
||||
compile = [
|
||||
"flash-attn>=2.7.2.post1",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
markupsafe = {index = "pytorch"}
|
||||
requests = {index = "pypi"}
|
||||
torch = {index = "pytorch"}
|
||||
torchaudio = {index = "pytorch"}
|
||||
torchvision = {index = "pytorch"}
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pypi"
|
||||
url = "https://pypi.org/simple"
|
||||
|
||||
[tool.uv]
|
||||
no-build-isolation-package = ["flash-attn"]
|
||||
concurrent-builds = 4
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
@ -1,11 +0,0 @@
|
||||
accelerate==1.2.1
|
||||
deepspeed==0.16.2
|
||||
evaluate==0.4.3
|
||||
peft==0.14.0
|
||||
pillow==10.2.0
|
||||
torch==2.5.1+cu124
|
||||
torchaudio==2.5.1+cu124
|
||||
torchvision==0.20.1+cu124
|
||||
transformers==4.46.1
|
||||
trl==0.13.0
|
||||
pillow==9.5.0
|
97
src/dataset_library/GigaspeechDataset.py
Normal file
97
src/dataset_library/GigaspeechDataset.py
Normal file
@ -0,0 +1,97 @@
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
import os
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
class GigaspeechDataset(Dataset):
|
||||
def __init__(self, audio_processor=None, text_processor=None, split="train"):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
|
||||
self.audio_processor = audio_processor
|
||||
self.text_processor = text_processor
|
||||
gs = load_dataset("speechcolab/gigaspeech", "xs")
|
||||
self.data = gs[split]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
audio = sample["audio"]["array"]
|
||||
sampling_rate = sample["audio"]["sampling_rate"]
|
||||
text = sample["text"]
|
||||
|
||||
if self.audio_processor is not None:
|
||||
audio = self.audio_processor(audio)
|
||||
if self.text_processor is not None:
|
||||
text = self.text_processor(text)
|
||||
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio_url": ""},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please convert the audio to text",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": text}]},
|
||||
]
|
||||
return {
|
||||
"audio": (audio, sampling_rate),
|
||||
"chat": chat,
|
||||
}
|
||||
|
||||
|
||||
class GigaspeechDatasetForGeneration(GigaspeechDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
audio = sample["audio"]["array"]
|
||||
sampling_rate = sample["audio"]["sampling_rate"]
|
||||
text = sample["text"]
|
||||
|
||||
if self.audio_processor is not None:
|
||||
audio = self.audio_processor(audio)
|
||||
if self.text_processor is not None:
|
||||
text = self.text_processor(text)
|
||||
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "audio_url": ""},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please convert the audio to text",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
return {
|
||||
"audio": (audio, sampling_rate),
|
||||
"chat": chat,
|
||||
"answer": text,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = GigaspeechDataset(
|
||||
split="train",
|
||||
)
|
||||
print(len(dataset))
|
||||
print(dataset[0])
|
||||
dataset = GigaspeechDatasetForGeneration(
|
||||
split="train",
|
||||
)
|
||||
print(len(dataset))
|
||||
print(dataset[0])
|
||||
pass
|
@ -107,7 +107,7 @@ class OCRVQADataset(Dataset):
|
||||
}
|
||||
|
||||
|
||||
class OCRVQADatasetForGeneration(Dataset):
|
||||
class OCRVQADatasetForGeneration(OCRVQADataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
|
@ -6,7 +6,7 @@ def get_dataset(
|
||||
dataset_name, base_path="/home/zyy/research/accelerate/dataset"
|
||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||
if dataset_name == "OCR_VQA_200K":
|
||||
if dataset_name == "OCR-VQA-200K":
|
||||
import os.path as osp
|
||||
from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
|
||||
@ -48,4 +48,13 @@ def get_dataset(
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
|
||||
if dataset_name == "gigaspeech":
|
||||
from .GigaspeechDataset import GigaspeechDataset, GigaspeechDatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": GigaspeechDataset(split="train"),
|
||||
"test": GigaspeechDataset(split="test"),
|
||||
"generation": GigaspeechDatasetForGeneration(split="test"),
|
||||
}
|
||||
return dataset
|
||||
|
105
src/model_library/factory.py
Normal file
105
src/model_library/factory.py
Normal file
@ -0,0 +1,105 @@
|
||||
import torch
|
||||
from trl import (
|
||||
get_kbit_device_map,
|
||||
# get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
def get_model(model_args):
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=(
|
||||
get_kbit_device_map() if quantization_config is not None else None
|
||||
),
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
from transformers import Qwen2VLProcessor
|
||||
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
processor = Qwen2VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
print(model)
|
||||
from model_library.qwen2vl import (
|
||||
collate_fn_for_train,
|
||||
collate_fn_for_evaluate,
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=(
|
||||
get_kbit_device_map() if quantization_config is not None else None
|
||||
),
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
|
||||
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
processor = Qwen2AudioProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
print(model)
|
||||
from model_library.qwen2audio import (
|
||||
collate_fn_for_train,
|
||||
collate_fn_for_evaluate,
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "VITA-MLLM/VITA-1.5":
|
||||
# from transformers import
|
||||
# from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
print(model)
|
||||
from model_library.qwen2vl import (
|
||||
collate_fn_for_train,
|
||||
collate_fn_for_evaluate,
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
4
src/model_library/qwen2audio/__init__.py
Normal file
4
src/model_library/qwen2audio/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
||||
from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]
|
87
src/model_library/qwen2audio/collate_fn.py
Normal file
87
src/model_library/qwen2audio/collate_fn.py
Normal file
@ -0,0 +1,87 @@
|
||||
from transformers import Qwen2AudioProcessor
|
||||
|
||||
|
||||
def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(example["chat"], tokenize=False)
|
||||
for example in examples
|
||||
]
|
||||
audios = [example["audio"][0] for example in examples]
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(
|
||||
text=texts,
|
||||
audios=audios,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
sampling_rate=examples[0]["audio"][1],
|
||||
)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
# 对<|im_start|>system *** <|im_end|>\n加掩码
|
||||
im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
|
||||
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||
system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
|
||||
user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
|
||||
assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
|
||||
# enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
|
||||
# print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
|
||||
# 151644 151645 8948 872 77091 None 151643
|
||||
|
||||
for i, label in enumerate(labels):
|
||||
now_index = 0
|
||||
while now_index < len(label):
|
||||
if label[now_index] == im_start_token_id:
|
||||
label[now_index] = -100
|
||||
now_index += 1
|
||||
if (
|
||||
label[now_index] == system_token_id
|
||||
or label[now_index] == user_token_id
|
||||
):
|
||||
while label[now_index] != im_end_token_id:
|
||||
label[now_index] = -100
|
||||
now_index += 1
|
||||
label[now_index] = -100
|
||||
elif label[now_index] == assistant_token_id:
|
||||
label[now_index] = -100
|
||||
label[now_index + 1] = -100
|
||||
now_index += 2
|
||||
while (
|
||||
now_index < len(label) and label[now_index] != im_end_token_id
|
||||
):
|
||||
now_index += 1
|
||||
now_index += 1
|
||||
batch["labels"] = labels
|
||||
# batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["chat"], tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
# print(texts)
|
||||
audios = [example["audio"] for example in examples]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, audios=audios, return_tensors="pt", padding=True)
|
||||
|
||||
answers = [example["answer"] for example in examples]
|
||||
answers = processor(text=answers, return_tensors="pt", padding=True)
|
||||
batch["answers_ids"] = answers["input_ids"]
|
||||
batch["answers_mask"] = answers["attention_mask"]
|
||||
# input_ids torch.Size([3, 370])
|
||||
# attention_mask torch.Size([3, 370])
|
||||
# pixel_values torch.Size([3888, 1176])
|
||||
# image_grid_thw torch.Size([3, 3])
|
||||
# answers_ids torch.Size([3, 10])
|
||||
# answers_mask torch.Size([3, 10])
|
||||
return batch
|
856
src/model_library/qwen2audio/model.py
Normal file
856
src/model_library/qwen2audio/model.py
Normal file
@ -0,0 +1,856 @@
|
||||
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
is_flash_attn_2_available,
|
||||
Qwen2VLAttention,
|
||||
Qwen2VLFlashAttention2,
|
||||
Qwen2VLDecoderLayer,
|
||||
Qwen2VLModel,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLConfig,
|
||||
logger,
|
||||
apply_rotary_pos_emb_vision,
|
||||
Cache,
|
||||
apply_multimodal_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
math,
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
Qwen2VisionTransformerPretrainedModel,
|
||||
)
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
import torch
|
||||
from typing import Optional, List, Union, Tuple
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class LinearLayer(nn.Linear):
|
||||
def forward(self, input: Tensor, **kwargs) -> Tensor:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
class Qwen2VLAttention_modified(Qwen2VLAttention):
|
||||
def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__(config, layer_idx)
|
||||
self.q_proj = LinearLayer(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.k_proj = LinearLayer(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.v_proj = LinearLayer(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.o_proj = LinearLayer(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# Fix precision issues in Qwen2-VL float16 inference
|
||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||
if query_states.dtype == torch.float16:
|
||||
attn_weights = torch.where(
|
||||
torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.attention_dropout, training=self.training
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified):
|
||||
"""
|
||||
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Qwen2Attention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states, **kwargs)
|
||||
key_states = self.k_proj(hidden_states, **kwargs)
|
||||
value_states = self.v_proj(hidden_states, **kwargs)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class Qwen2VLFlashAttention2_modified(Qwen2VLAttention_modified):
|
||||
"""
|
||||
Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
|
||||
as the weights of the module stays untouched. The only required change would be on the forward pass
|
||||
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
||||
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
||||
config.max_window_layers layers.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states, **kwargs)
|
||||
key_states = self.k_proj(hidden_states, **kwargs)
|
||||
value_states = self.v_proj(hidden_states, **kwargs)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
QWEN2_VL_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2VLAttention,
|
||||
"flash_attention_2": Qwen2VLFlashAttention2_modified,
|
||||
"sdpa": Qwen2VLSdpaAttention_modified,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer):
|
||||
def __init__(self, config: Qwen2VLConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, layer_idx
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Qwen2VLModel_modified(Qwen2VLModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Qwen2VLDecoderLayer_modified(config, layer_idx)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You must specify exactly one of input_ids or inputs_embeds"
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# torch.jit.trace() doesn't support cache objects in the output
|
||||
if use_cache and past_key_values is None and not torch.jit.is_tracing():
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
)
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
# the hard coded `3` is for temporal, height and width.
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.view(1, 1, -1).expand(
|
||||
3, inputs_embeds.shape[0], -1
|
||||
)
|
||||
elif position_ids.dim() == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
|
||||
config.vision_config
|
||||
)
|
||||
self.model = Qwen2VLModel_modified(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.rope_deltas = None # cache rope_deltas here
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
|
||||
>>> messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(
|
||||
inputs_embeds.device, inputs_embeds.dtype
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(
|
||||
inputs_embeds.device, inputs_embeds.dtype
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if (
|
||||
position_ids is None
|
||||
and input_ids is not None
|
||||
and (attention_mask is None or attention_mask.ndim == 2)
|
||||
):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
cache_position is not None and cache_position[0] == 0
|
||||
) or self.rope_deltas is None:
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
cache_position[0] + self.rope_deltas
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Qwen2VLCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
rope_deltas=self.rope_deltas,
|
||||
)
|
@ -1,5 +1,29 @@
|
||||
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import *
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
is_flash_attn_2_available,
|
||||
Qwen2VLAttention,
|
||||
Qwen2VLFlashAttention2,
|
||||
Qwen2VLDecoderLayer,
|
||||
Qwen2VLModel,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLConfig,
|
||||
logger,
|
||||
apply_rotary_pos_emb_vision,
|
||||
Cache,
|
||||
apply_multimodal_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
math,
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
Qwen2VisionTransformerPretrainedModel,
|
||||
)
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
import torch
|
||||
@ -246,9 +270,145 @@ class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified):
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class Qwen2VLFlashAttention2_modified(Qwen2VLAttention_modified):
|
||||
"""
|
||||
Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
|
||||
as the weights of the module stays untouched. The only required change would be on the forward pass
|
||||
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
||||
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
||||
config.max_window_layers layers.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states, **kwargs)
|
||||
key_states = self.k_proj(hidden_states, **kwargs)
|
||||
value_states = self.v_proj(hidden_states, **kwargs)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
QWEN2_VL_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2VLAttention,
|
||||
"flash_attention_2": Qwen2VLFlashAttention2,
|
||||
"flash_attention_2": Qwen2VLFlashAttention2_modified,
|
||||
"sdpa": Qwen2VLSdpaAttention_modified,
|
||||
}
|
||||
|
||||
|
52
src/train.py
52
src/train.py
@ -1,16 +1,10 @@
|
||||
import torch
|
||||
from dataset_library.factory import get_dataset
|
||||
from transformers import (
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from trl import (
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
# get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from peft_library import get_peft_model
|
||||
|
||||
@ -32,44 +26,11 @@ if __name__ == "__main__":
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
# peft_config = get_peft_config(dict(**vars(model_args)))
|
||||
from model_library.factory import get_model
|
||||
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
|
||||
model_args
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||
# from transformers import Qwen2VLForConditionalGeneration
|
||||
from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
from model_library.qwen2vl import (
|
||||
collate_fn_for_train,
|
||||
collate_fn_for_evaluate,
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
################
|
||||
# Dataset
|
||||
@ -85,13 +46,16 @@ if __name__ == "__main__":
|
||||
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
|
||||
|
||||
model = get_peft_model(model, peft_config)
|
||||
# model = inject_adapter_in_model(peft_config, model)
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
model.print_trainable_parameters()
|
||||
elif model_args.peft_type == "LORA":
|
||||
from peft.tuners.lora import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
|
||||
peft_config = LoraConfig(target_modules=model_args.lora_target_modules, r=2)
|
||||
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
model.print_trainable_parameters()
|
||||
|
||||
|
@ -1,14 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
||||
--dataset_name CHEM \
|
||||
--dataset_name gigaspeech \
|
||||
--use_peft \
|
||||
--peft_type MMOELORA \
|
||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||
--peft_type LORA \
|
||||
--model_name_or_path Qwen/Qwen2-Audio-7B-Instruct \
|
||||
--lora_target_modules q_proj v_proj \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--output_dir checkpoint/sft-llava-1.5-7b-hf \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16
|
||||
|
@ -62,15 +62,12 @@ def create_accelerator_and_postprocess(args):
|
||||
"when using FSDP."
|
||||
)
|
||||
|
||||
def propagate_args_to_deepspeed(auto_find_batch_size=False):
|
||||
if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
|
||||
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
|
||||
|
||||
ds_plugin = accelerator.state.deepspeed_plugin
|
||||
|
||||
ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
|
||||
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
|
||||
ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size)
|
||||
|
||||
if is_deepspeed_enabled and getattr(args, "hf_deepspeed_config", None) is None:
|
||||
propagate_args_to_deepspeed()
|
||||
ds_plugin.hf_ds_config.trainer_config_process(args, auto_find_batch_size=False)
|
||||
return accelerator
|
||||
|
@ -1,6 +1,11 @@
|
||||
# _________________________________________________________
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers.trainer import (
|
||||
Trainer,
|
||||
_is_peft_model,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
)
|
||||
|
||||
|
||||
class ContinualTrainer(Trainer):
|
||||
@ -13,8 +18,7 @@ class ContinualTrainer(Trainer):
|
||||
def create_accelerator_and_postprocess(self):
|
||||
if self.accelerator is not None:
|
||||
self.is_deepspeed_enabled = (
|
||||
getattr(self.accelerator.state, "deepspeed_plugin", None)
|
||||
is not None
|
||||
getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
)
|
||||
self.is_fsdp_enabled = (
|
||||
getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||
@ -23,3 +27,54 @@ class ContinualTrainer(Trainer):
|
||||
return
|
||||
else:
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
if (
|
||||
self.label_smoother is not None or self.compute_loss_func is not None
|
||||
) and "labels" in inputs:
|
||||
labels = inputs.pop("labels")
|
||||
else:
|
||||
labels = None
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
|
||||
if labels is not None:
|
||||
unwrapped_model = self.accelerator.unwrap_model(model)
|
||||
if _is_peft_model(unwrapped_model):
|
||||
model_name = unwrapped_model.base_model.model._get_name()
|
||||
else:
|
||||
model_name = unwrapped_model._get_name()
|
||||
# User-defined compute_loss function
|
||||
if self.compute_loss_func is not None:
|
||||
loss = self.compute_loss_func(
|
||||
outputs, labels, num_items_in_batch=num_items_in_batch
|
||||
)
|
||||
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
||||
loss = self.label_smoother(outputs, labels, shift_labels=True)
|
||||
else:
|
||||
loss = self.label_smoother(outputs, labels)
|
||||
else:
|
||||
if isinstance(outputs, dict) and "loss" not in outputs:
|
||||
raise ValueError(
|
||||
"The model did not return a loss from the inputs, only the following keys: "
|
||||
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
||||
)
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
Loading…
Reference in New Issue
Block a user