SageAttention is reportedly much faster than FlashAttention-2. But how do you actually use it when training DeepSeek-VL2?
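1. At its core, SageAttention is an implementation of SDPA. SDPA stands for Scaled Dot-Product Attention, a multiplicative attention mechanism. Put simply, it weights the Values according to how well each Query matches each Key, i.e. softmax(QK^T / √d_k) · V. Since the Query, Key, and Value are all derived from the input, SDPA essentially recombines the input information.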
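2. SageAttention is built on Triton, a language and compiler for writing GPU kernels in Python. Triton compiles that Python code into machine code for the target GPU, which substantially speeds up the computation.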
3. The official example shows that it works by simply replacing the function torch.nn.functional.scaled_dot_product_attention:

```python
from sageattention import sageattn
import torch.nn.functional as F

# ...
F.scaled_dot_product_attention = sageattn
```
So to use SageAttention, the original model only needs to support the SDPA attention path.
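Point 1 above describes SDPA as weighting the Values by Query-Key match scores, softmax(QK^T / √d) · V. A minimal single-head sketch in plain Python (no batching, no PyTorch) makes that weighting concrete. `sdpa` and `softmax` are illustrative names, and `is_causal` only mimics the flag of the same name in `F.scaled_dot_product_attention`; this is a reference for the math, not how SageAttention computes it on the GPU.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max for numerical stability; -inf entries get weight 0.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa(q: Matrix, k: Matrix, v: Matrix, is_causal: bool = False) -> Matrix:
    """softmax(Q K^T / sqrt(d)) V for one attention head, no batching."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        # Scaled dot-product match score between query i and every key j.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if is_causal:
            # Query i may only attend to keys 0..i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        weights = softmax(scores)
        # Output row i is the weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


if __name__ == "__main__":
    q = [[1.0, 0.0], [0.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0]]
    print(sdpa(q, q, v, is_causal=True)[0])  # [1.0, 2.0]: the first query only sees the first key
```

With Q = K the scores depend only on how similar the tokens are, which is the "recombining the input" view from point 1.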
Unfortunately, the official DeepSeek-VL2 GitHub repository shows that neither DeepseekVLV2ForCausalLM nor DeepseekV2ForCausalLM supports SDPA: neither class declares `_supports_sdpa = True`.
```python
ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,
    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,
    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
}
```
The ATTENTION_CLASSES table above has no SDPA implementation either.
Therefore DeepSeek-VL2 cannot use SageAttention out of the box; the open-source DeepSeek code has to be modified first.
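Why the one-line patch from point 3 is not enough here: assigning to `F.scaled_dot_product_attention` only rebinds an attribute on the module object, so only code that looks that attribute up at call time picks up `sageattn`. Attention classes that never call SDPA (the eager and FlashAttention-2 entries in ATTENTION_CLASSES), or code that bound the function to a local name before the patch, are unaffected. The sketch below uses a stand-in module so it runs without PyTorch; `fake_F`, `original_sdpa`, and `sage_like` are hypothetical names.

```python
import types

# Hypothetical stand-in for torch.nn.functional, for illustration only.
fake_F = types.ModuleType("fake_functional")


def original_sdpa(q, k, v):
    return "original"


def sage_like(q, k, v):
    return "patched"


fake_F.scaled_dot_product_attention = original_sdpa


def attention_via_module(q, k, v):
    # Looks the function up on the module at call time, like `F.scaled_dot_product_attention(...)`.
    return fake_F.scaled_dot_product_attention(q, k, v)


# Binds the function object once, like `from torch.nn.functional import scaled_dot_product_attention`.
bound_early = fake_F.scaled_dot_product_attention


def attention_via_local_name(q, k, v):
    return bound_early(q, k, v)


# The monkey patch from the official example, applied to the stand-in module.
fake_F.scaled_dot_product_attention = sage_like

if __name__ == "__main__":
    print(attention_via_module(None, None, None))      # patched
    print(attention_via_local_name(None, None, None))  # original
```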
The modifications are as follows:
1. First, make DeepseekVLV2ForCausalLM and DeepseekV2ForCausalLM support SDPA by adding `_supports_sdpa = True`, so that `_check_and_enable_sdpa` in transformers/modeling_utils.py passes its check:
```python
class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    _supports_sdpa = True
    # ... rest of the class unchanged


class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
    _supports_sdpa = True
    # ... rest of the class unchanged
```
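Conceptually, `_supports_sdpa` is an opt-in class attribute that transformers consults before honoring an SDPA request. The following is a deliberately simplified sketch of that gate, not the actual transformers source: `PreTrainedModelSketch`, `WithoutFlag`, and `WithFlag` are hypothetical, and the real `_check_and_enable_sdpa` also checks other conditions (for example whether the installed PyTorch provides SDPA) and varies across transformers versions.

```python
class PreTrainedModelSketch:
    # Hypothetical base class: models must opt in to SDPA explicitly.
    _supports_sdpa = False

    @classmethod
    def check_and_enable_sdpa(cls, attn_implementation: str) -> str:
        # Simplified gate: refuse "sdpa" unless the class declares support.
        if attn_implementation == "sdpa" and not cls._supports_sdpa:
            raise ValueError(f"{cls.__name__} does not support SDPA yet")
        return attn_implementation


class WithoutFlag(PreTrainedModelSketch):
    pass


class WithFlag(PreTrainedModelSketch):
    # The one-line change that step 1 makes to the DeepSeek classes.
    _supports_sdpa = True
```

Because the flag is a plain class attribute, subclasses inherit it, which is why it has to be set on the concrete `...ForCausalLM` classes that are actually loaded.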
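2. Next, add an SDPA attention class to DeepSeek-VL2/deepseek_vl2/models/modeling_deepseek.py. It inherits from LlamaAttention and is essentially a copy of LlamaSdpaAttention (copied rather than subclassed so it is easy to modify), with the attention call replaced by `sageattn`: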
```python
# At the top of modeling_deepseek.py. torch, Optional, Tuple, Cache, logger, LlamaAttention
# and repeat_kv are assumed to be imported/defined in that file already.
from sageattention import sageattn


class DeepSeekSdpaAttention(LlamaAttention):
    """
    DeepSeek attention module that calls SageAttention's `sageattn` in place of
    `torch.nn.functional.scaled_dot_product_attention`. It inherits from `LlamaAttention`,
    so the module's weights stay untouched; only the forward pass is adapted to the SDPA API.
    """

    # Adapted from LlamaSdpaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # SDPA-style kernels do not return attention weights: fall back to the eager implementation.
            logger.warning_once(
                "DeepSeekSdpaAttention does not support `output_attentions=True`. "
                "Falling back to the manual attention implementation."
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        # (bsz, q_len, hidden) -> (bsz, num_heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        elif isinstance(position_embeddings, torch.Tensor):
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        # apply_rotary_pos_emb2: RoPE helper assumed to be defined elsewhere in modeling_deepseek.py
        query_states, key_states = apply_rotary_pos_emb2(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped K/V heads to match the number of query heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # LlamaSdpaAttention uses: is_causal = True if causal_mask is None and q_len > 1 else False.
        # Here causality is forced. Assumption (check your sageattention version): sageattn may not apply
        # a custom attn_mask or dropout, in which case padding masks are effectively ignored.
        is_causal = True
        attn_output = sageattn(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
```
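In the forward pass above, `repeat_kv` runs before the kernel call so that K and V carry as many heads as Q (grouped-query attention, where several query heads share one K/V head). A list-based sketch of the resulting head order, assuming it mirrors the Llama-style `repeat_kv` (expand, then reshape), shows which K/V head each query head ends up using; `repeat_kv_heads` and `kv_head_for_query_head` are hypothetical helpers.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_kv_heads(kv_heads: Sequence[T], n_rep: int) -> List[T]:
    """List analogue of repeat_kv: each K/V head is repeated n_rep times, in place."""
    if n_rep == 1:
        return list(kv_heads)
    return [head for head in kv_heads for _ in range(n_rep)]


def kv_head_for_query_head(q_head: int, n_rep: int) -> int:
    # After repetition, query head h lines up with original K/V head h // n_rep.
    return q_head // n_rep


if __name__ == "__main__":
    # 8 query heads sharing 2 K/V heads -> n_rep = 4
    print(repeat_kv_heads(["kv0", "kv1"], 4))
```

When the model has as many K/V heads as query heads, `num_key_value_groups` is 1 and the call is a no-op.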
3. Register the new class in ATTENTION_CLASSES in modeling_deepseek.py by adding an `mha_sdpa` entry:
```python
ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,
    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,
    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
    "mha_sdpa": DeepSeekSdpaAttention,
}
```
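How the new entry gets picked up: judging from the `using attn_implementation: mha_sdpa` line in the training log, the lookup key appears to be an `mha_`/`mla_` prefix plus the configured implementation name. The sketch below assumes exactly that scheme; `select_attention_class` and the empty classes are hypothetical stand-ins for the real modules. It also shows that there is still no `mla_sdpa` entry, so only MHA configurations reach the new class.

```python
# Empty stand-ins; in modeling_deepseek.py these are the real attention modules.
class DeepseekV2Attention:
    pass


class DeepseekV2FlashAttention2:
    pass


class LlamaAttention:
    pass


class LlamaFlashAttention2:
    pass


class DeepSeekSdpaAttention(LlamaAttention):
    pass


ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,
    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,
    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
    "mha_sdpa": DeepSeekSdpaAttention,
}


def select_attention_class(use_mla: bool, attn_implementation: str) -> type:
    # Assumed key scheme: "mla_" or "mha_" prefix + the configured implementation name.
    key = ("mla_" if use_mla else "mha_") + attn_implementation
    if key not in ATTENTION_CLASSES:
        raise KeyError(f"no attention class registered for {key!r}")
    return ATTENTION_CLASSES[key]
```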
4. The SageAttention and Triton versions used:

```
Name: sageattention
Version: 1.0.6
Name: triton
Version: 3.2.0
```
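To check a local environment against these versions without running `pip show`, a small standard-library sketch can read the installed package metadata. `TESTED_VERSIONS`, `installed_version`, and `compare_with_tested` are hypothetical names; other version combinations may also work but were not tested in this post.

```python
from importlib import metadata
from typing import Dict, Optional

# Versions reported in this post.
TESTED_VERSIONS = {"sageattention": "1.0.6", "triton": "3.2.0"}


def installed_version(package: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def compare_with_tested(tested: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {package: installed_version} for every package that differs from the tested version."""
    mismatches = {}
    for name, wanted in tested.items():
        have = installed_version(name)
        if have != wanted:
            mismatches[name] = have
    return mismatches


if __name__ == "__main__":
    print(compare_with_tested(TESTED_VERSIONS) or "matches the tested versions")
```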
5. Training test:

```
swift sft --model "deepseek-ai/deepseek-vl2-tiny" --dataset ../TEST.json --attn_impl sdp...
using attn_implementation: mha_sdpa
[INFO:swift] model.hf_device_map: {'': device(type='cuda', index=0)}
...
[INFO:swift] End time of running main: 2025-03-28 10:43:29.304164
```

Training ran successfully.