feat: Update the Python version and dependency management; optimize the training script to support Flash Attention 2

YunyaoZhou 2025-01-11 03:53:23 +08:00
parent b766c21c9b
commit 90b3181f3f
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
8 changed files with 1807 additions and 16 deletions

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.11.7

View File

@ -1,3 +1,3 @@
#!/bin/bash
uv venv --python 3.11.7
pip install -U torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
uv sync
uv sync --extra compile
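
The setup script now provisions the environment with uv instead of installing the CUDA wheels directly through pip: it creates a Python 3.11.7 virtual environment, syncs the locked dependencies, and then builds flash-attn through the optional compile extra. A minimal post-install sanity check, written as a sketch (the expected versions are the pins from the pyproject.toml below; none of this code is part of the commit):

import importlib.util

import torch

print("torch:", torch.__version__)            # expected 2.5.1+cu124
print("CUDA runtime:", torch.version.cuda)    # expected 12.4
print("CUDA available:", torch.cuda.is_available())

if importlib.util.find_spec("flash_attn") is not None:
    import flash_attn
    print("flash-attn:", flash_attn.__version__)  # present only after `uv sync --extra compile`
else:
    print("flash-attn not installed; run `uv sync --extra compile`")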

46
pyproject.toml Normal file
View File

@ -0,0 +1,46 @@
[project]
dependencies = [
"accelerate==1.2.1",
"datasets==3.2.0",
"deepspeed==0.16.2",
"evaluate==0.4.3",
"markupsafe==2.1.5",
"peft==0.14.0",
"pip==24.3.1",
"requests==2.32.3",
"setuptools>=70.0.0",
"torch==2.5.1+cu124",
"torchaudio==2.5.1+cu124",
"torchvision==0.20.1+cu124",
"transformers==4.46.1",
"trl==0.13.0",
"wheel>=0.45.1",
]
description = "Add your description here"
name = "cl-lmm"
readme = "README.md"
requires-python = ">=3.11"
version = "0.1.0"
[project.optional-dependencies]
compile = [
"flash-attn>=2.7.2.post1",
]
[tool.uv.sources]
markupsafe = {index = "pytorch"}
requests = {index = "pypi"}
torch = {index = "pytorch"}
torchaudio = {index = "pytorch"}
torchvision = {index = "pytorch"}
[[tool.uv.index]]
name = "pypi"
url = "https://pypi.org/simple"
[tool.uv]
no-build-isolation-package = ["flash-attn"]
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu124"

View File

@ -1,11 +0,0 @@
accelerate==1.2.1
deepspeed==0.16.2
evaluate==0.4.3
peft==0.14.0
pillow==10.2.0
torch==2.5.1+cu124
torchaudio==2.5.1+cu124
torchvision==0.20.1+cu124
transformers==4.46.1
trl==0.13.0
pillow==9.5.0

View File

@ -1,5 +1,12 @@
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
from transformers.models.qwen2_vl.modeling_qwen2_vl import *

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    flash_attn_varlen_func = None

from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
import torch
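
The flash-attn imports are guarded by is_flash_attn_2_available(), so the module still imports cleanly on machines where the compile extra was not installed. A hedged sketch of the same pattern used to pick an attention implementation at load time (pick_attn_implementation is a hypothetical helper, not part of this commit):

from transformers.utils import is_flash_attn_2_available


def pick_attn_implementation(preferred: str = "flash_attention_2") -> str:
    # Hypothetical helper: fall back to SDPA when flash-attn is not installed.
    if preferred == "flash_attention_2" and not is_flash_attn_2_available():
        return "sdpa"
    return preferred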
@ -246,9 +253,145 @@ class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified):

        return attn_output, None, past_key_value


class Qwen2VLFlashAttention2_modified(Qwen2VLAttention_modified):
    """
    Qwen2VL flash attention module, following the Qwen2VL attention module. This module inherits from
    `Qwen2VLAttention` as the weights of the module stay untouched. The only required change is in the forward
    pass, where it needs to correctly call the public API of Flash Attention and handle padding tokens if the
    input contains any. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle the
        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, **kwargs)
        key_states = self.k_proj(hidden_states, **kwargs)
        value_states = self.v_proj(hidden_states, **kwargs)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, the layer norms are usually cast to float32 for training stability, so the input hidden
        # states can get silently cast to float32. Hence, we need to cast them back to float16 to be sure
        # everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to have been silently cast to float32; this might be because "
                f"you have upcast embedding or layer norm layers to float32. We will cast the input back to "
                f"{target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


QWEN2_VL_ATTENTION_CLASSES = {
    "eager": Qwen2VLAttention,
    # "flash_attention_2": Qwen2VLFlashAttention2,  (superseded in this commit by the modified class below)
    "flash_attention_2": Qwen2VLFlashAttention2_modified,
    "sdpa": Qwen2VLSdpaAttention_modified,
}
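
QWEN2_VL_ATTENTION_CLASSES is the table consulted when the model config resolves its attention implementation, so pointing the "flash_attention_2" key at the modified class is what routes Flash Attention 2 requests through the forward pass above. A usage sketch, assuming a stock Qwen2-VL checkpoint (the checkpoint id is illustrative; the training scripts in this repo pass the model id via CLI arguments):

import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",            # illustrative checkpoint id
    torch_dtype=torch.bfloat16,             # FA2 expects fp16/bf16 activations
    attn_implementation="flash_attention_2",
)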

View File

@ -62,6 +62,7 @@ if __name__ == "__main__":
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    print(model)
    from model_library.qwen2vl import (
        collate_fn_for_train,
        collate_fn_for_evaluate,
@ -85,13 +86,16 @@ if __name__ == "__main__":
        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)
        # model = inject_adapter_in_model(peft_config, model)
        if accelerator.is_local_main_process:
            model.print_trainable_parameters()

    elif model_args.peft_type == "LORA":
        from peft.tuners.lora import LoraConfig

        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)
        if accelerator.is_local_main_process:
            model.print_trainable_parameters()
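
Both PEFT branches now print the trainable-parameter summary only on the local main process, which keeps multi-GPU logs readable. For reference, a minimal LoRA configuration matching the --lora_target_modules q_proj v_proj flag used in the launch script below; the rank, alpha, and dropout values are illustrative defaults rather than values taken from this commit:

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # mirrors --lora_target_modules q_proj v_proj
    r=8,                                  # illustrative rank
    lora_alpha=16,                        # illustrative scaling factor
    lora_dropout=0.05,                    # illustrative dropout
)
# model = get_peft_model(model, peft_config)  # as done in the training script above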

View File

@ -8,7 +8,8 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --attn_implementation flash_attention_2 \
    --gradient_accumulation_steps 16 \
    --output_dir checkpoint/sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16
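
Besides switching the attention kernel with --attn_implementation flash_attention_2, the launch script doubles --gradient_accumulation_steps from 8 to 16, which doubles the number of samples contributing to each optimizer step. A small sketch of the arithmetic (the GPU count is an assumption; the script does not fix it):

def effective_batch_size(per_device: int, grad_accum: int, num_processes: int) -> int:
    # Samples contributing to one optimizer step across all processes.
    return per_device * grad_accum * num_processes


# With --per_device_train_batch_size 1 and --gradient_accumulation_steps 16 on, say, 4 GPUs:
print(effective_batch_size(1, 16, 4))  # 64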

1607
uv.lock generated Normal file

File diff suppressed because it is too large