feat✨: Update the Python version and dependency management; optimize the training script to support Flash Attention 2

This commit is contained in:
parent b766c21c9b
commit 90b3181f3f
.python-version (new file, 1 line)
@@ -0,0 +1 @@
+3.11.7
@@ -1,3 +1,3 @@
 #!/bin/bash
-uv venv --python 3.11.7
-pip install -U torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
+uv sync
+uv sync --extra compile
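With the setup script reduced to `uv sync` plus `uv sync --extra compile`, the base environment is resolved from pyproject.toml first and flash-attn is built afterwards as an optional extra. A quick check along these lines (illustrative only, not part of this commit) confirms the pinned cu124 wheels and the flash-attn build are both visible in the synced environment:

# Illustrative sanity check after `uv sync --extra compile`; not part of this commit.
import torch
import flash_attn

print("torch:", torch.__version__)            # expect 2.5.1+cu124 per pyproject.toml
print("CUDA runtime:", torch.version.cuda)    # expect 12.4
print("flash-attn:", flash_attn.__version__)  # expect >= 2.7.2.post1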
pyproject.toml (new file, 46 lines)
@@ -0,0 +1,46 @@
+[project]
+dependencies = [
+    "accelerate==1.2.1",
+    "datasets==3.2.0",
+    "deepspeed==0.16.2",
+    "evaluate==0.4.3",
+    "markupsafe==2.1.5",
+    "peft==0.14.0",
+    "pip==24.3.1",
+    "requests==2.32.3",
+    "setuptools>=70.0.0",
+    "torch==2.5.1+cu124",
+    "torchaudio==2.5.1+cu124",
+    "torchvision==0.20.1+cu124",
+    "transformers==4.46.1",
+    "trl==0.13.0",
+    "wheel>=0.45.1",
+]
+description = "Add your description here"
+name = "cl-lmm"
+readme = "README.md"
+requires-python = ">=3.11"
+version = "0.1.0"
+
+[project.optional-dependencies]
+compile = [
+    "flash-attn>=2.7.2.post1",
+]
+
+[tool.uv.sources]
+markupsafe = {index = "pytorch"}
+requests = {index = "pypi"}
+torch = {index = "pytorch"}
+torchaudio = {index = "pytorch"}
+torchvision = {index = "pytorch"}
+
+[[tool.uv.index]]
+name = "pypi"
+url = "https://pypi.org/simple"
+
+[tool.uv]
+no-build-isolation-package = ["flash-attn"]
+
+[[tool.uv.index]]
+name = "pytorch"
+url = "https://download.pytorch.org/whl/cu124"
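Two details in this file work together: the `pytorch` index serves the cu124 wheels for torch/torchvision/torchaudio, and `no-build-isolation-package = ["flash-attn"]` lets the flash-attn extra build against the torch already installed by the first `uv sync` (its build imports torch, so it cannot be compiled in an isolated environment before torch exists). Downstream, the modeling code only takes the Flash Attention path when the extra is actually present; a small check like the following mirrors that guard (a sketch, assuming it is run inside the uv-managed venv):

# Sketch: confirm transformers can see the optional flash-attn install from the "compile" extra.
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    print("flash_attention_2 backend available")
else:
    print("flash-attn not installed; the model falls back to sdpa/eager attention")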
@@ -1,11 +0,0 @@
-accelerate==1.2.1
-deepspeed==0.16.2
-evaluate==0.4.3
-peft==0.14.0
-pillow==10.2.0
-torch==2.5.1+cu124
-torchaudio==2.5.1+cu124
-torchvision==0.20.1+cu124
-transformers==4.46.1
-trl==0.13.0
-pillow==9.5.0
@@ -1,5 +1,12 @@
 # from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
 from transformers.models.qwen2_vl.modeling_qwen2_vl import *
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+else:
+    flash_attn_varlen_func = None
 from transformers.cache_utils import DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 import torch
@@ -246,9 +253,145 @@ class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified):
         return attn_output, None, past_key_value

+
+class Qwen2VLFlashAttention2_modified(Qwen2VLAttention_modified):
+    """
+    Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
+    as the weights of the module stay untouched. The only required change would be on the forward pass
+    where it needs to correctly call the public API of flash attention and deal with padding tokens
+    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+    config.max_window_layers layers.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            Tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # will become mandatory in v4.46
+        **kwargs,
+    ):
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states, **kwargs)
+        key_states = self.k_proj(hidden_states, **kwargs)
+        value_states = self.v_proj(hidden_states, **kwargs)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
+        # Because the input can be padded, the absolute sequence length depends on the max position id.
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+
+        query_states, key_states = apply_multimodal_rotary_pos_emb(
+            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+        )
+
+        if past_key_value is not None:
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "cache_position": cache_position,
+            }  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in float16 just to be sure everything works as expected.
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        if (
+            self.config.use_sliding_window
+            and getattr(self.config, "sliding_window", None) is not None
+            and self.layer_idx >= self.config.max_window_layers
+        ):
+            sliding_window = self.config.sliding_window
+        else:
+            sliding_window = None
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            sliding_window=sliding_window,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
 QWEN2_VL_ATTENTION_CLASSES = {
     "eager": Qwen2VLAttention,
-    "flash_attention_2": Qwen2VLFlashAttention2,
+    "flash_attention_2": Qwen2VLFlashAttention2_modified,
     "sdpa": Qwen2VLSdpaAttention_modified,
 }

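The only change to the registry is that the `flash_attention_2` key now points at the modified class, so any Qwen2-VL layer built while the config requests `flash_attention_2` picks up this implementation. A minimal illustration of that lookup (same-module context assumed; the surrounding decoder-layer wiring is not shown in this diff):

# Illustration only: how the registry above is consumed (upstream decoder layers index it
# with config._attn_implementation and instantiate the resulting class).
attn_cls = QWEN2_VL_ATTENTION_CLASSES["flash_attention_2"]
print(attn_cls.__name__)  # -> Qwen2VLFlashAttention2_modified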
@@ -62,6 +62,7 @@ if __name__ == "__main__":
         trust_remote_code=model_args.trust_remote_code,
         **model_kwargs,
     )
+    print(model)
     from model_library.qwen2vl import (
         collate_fn_for_train,
         collate_fn_for_evaluate,
@@ -85,13 +86,16 @@ if __name__ == "__main__":
         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)

         model = get_peft_model(model, peft_config)
-        # model = inject_adapter_in_model(peft_config, model)
+
+        if accelerator.is_local_main_process:
+            model.print_trainable_parameters()
     elif model_args.peft_type == "LORA":
         from peft.tuners.lora import LoraConfig

         peft_config = LoraConfig(target_modules=model_args.lora_target_modules)

         model = get_peft_model(model, peft_config)

         if accelerator.is_local_main_process:
             model.print_trainable_parameters()

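Both branches now report trainable parameters on the local main process only, which avoids duplicated output under accelerate. The `LORA` branch uses the standard PEFT API (`MMOELoraConfig` is repo-specific and not reproduced here); a self-contained sketch of that branch with a placeholder checkpoint:

# Sketch of the standard-LoRA branch; checkpoint id and target modules are placeholders.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")  # placeholder checkpoint
peft_config = LoraConfig(target_modules=["q_proj", "v_proj"])    # matches --lora_target_modules
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # gated behind accelerator.is_local_main_process in the script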
@@ -8,7 +8,8 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
-    --gradient_accumulation_steps 8 \
+    --attn_implementation flash_attention_2 \
+    --gradient_accumulation_steps 16 \
     --output_dir checkpoint/sft-llava-1.5-7b-hf \
     --bf16 \
     --torch_dtype bfloat16
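The run now requests the Flash Attention 2 backend explicitly and doubles gradient accumulation from 8 to 16, giving an effective per-device batch of 16 sequences at train batch size 1. Inside the training script this flag presumably reaches the model loader; a hedged sketch of the resulting call (the checkpoint id is a placeholder, and the script's actual model_kwargs are not shown in this hunk):

# Sketch of what --attn_implementation flash_attention_2 and --torch_dtype bfloat16 imply
# for model loading; the checkpoint id is a placeholder, not taken from this diff.
import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",              # placeholder checkpoint
    torch_dtype=torch.bfloat16,               # matches --bf16 / --torch_dtype bfloat16
    attn_implementation="flash_attention_2",  # dispatches to Qwen2VLFlashAttention2(_modified)
)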