How to use SageAttention when fine-tuning DeepSeek-VL2 with ms-swift

SageAttention is reported to be considerably faster than flash_attention_2. But how can we use it when fine-tuning DeepSeek-VL2?

1. SageAttention is essentially a drop-in replacement for SDPA. SDPA stands for Scaled Dot-Product Attention, a multiplicative attention mechanism: in short, the Values are weighted according to how well the Query (Q) matches each Key. Since Query, Key, and Value all come from the input, SDPA is essentially a recombination of the input information.
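
As a refresher, SDPA computes softmax(QK^T / sqrt(d)) V. A minimal PyTorch sketch of that formula, illustrative only (no masking or dropout):

import math
import torch

def sdpa_reference(q, k, v):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # QK^T / sqrt(d)
    weights = torch.softmax(scores, dim=-1)                   # attention weights per query
    return weights @ v                                        # weighted sum of the Values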

2. SageAttention is built on the Triton package, which compiles Python-level kernel code into machine code for the target device, greatly accelerating the computation.
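
For context (not required for the steps below), a Triton kernel is ordinary Python decorated with @triton.jit and compiled to device code when it is launched. A minimal vector-add sketch, purely to show the programming model:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)                   # which block this program instance handles
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements                   # guard the tail of the tensor
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
add_kernel[(triton.cdiv(x.numel(), 1024),)](x, y, out, x.numel(), BLOCK=1024)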

3. As the official example shows, the basic way to use it is simply to replace the torch.nn.functional.scaled_dot_product_attention function:

from sageattention import sageattn
import torch.nn.functional as F
# ...
F.scaled_dot_product_attention = sageattn

So, to use SageAttention, all that is really needed is for the original model to support the SDPA attention implementation.
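
For a model that does support it, SDPA is selected through the attn_implementation argument of from_pretrained; the model id below is just a placeholder:

from transformers import AutoModelForCausalLM

# "some-org/some-model" is a placeholder; any model whose class declares
# _supports_sdpa = True can be loaded with the SDPA implementation.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",
    attn_implementation="sdpa",
)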

Unfortunately, as the official DeepSeek-VL2 GitHub repository shows, neither DeepseekVLV2ForCausalLM nor DeepseekV2ForCausalLM supports SDPA: neither class declares _supports_sdpa = True.

ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,

    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,

    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2
}

Nor does this ATTENTION_CLASSES mapping contain an SDPA implementation.

Therefore, DeepSeek-VL2 cannot use SageAttention directly; we have to modify DeepSeek's open-source code before SageAttention can be used.

The modification steps are as follows:

1. First, make DeepseekVLV2ForCausalLM and DeepseekV2ForCausalLM declare SDPA support by adding _supports_sdpa = True, so that _check_and_enable_sdpa in transformers/modeling_utils.py lets the check pass (a simplified sketch of that check follows the snippet below).

class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    _supports_sdpa = True  # added

class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
    _supports_sdpa = True  # added
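
For reference, the gate in transformers/modeling_utils.py looks roughly like this (paraphrased, not verbatim; details vary across transformers versions):

@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only=False):
    # Refuse SDPA for model classes that have not declared support for it.
    if not cls._supports_sdpa:
        if hard_check_only:
            raise ValueError(
                f"{cls.__name__} does not support an attention implementation through "
                "torch.nn.functional.scaled_dot_product_attention yet."
            )
        return config
    if not hard_check_only:
        config._attn_implementation = "sdpa"
    return config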

2. Next, add an SDPA attention implementation class to DeepSeek-VL2/deepseek_vl2/models/modeling_deepseek.py, inheriting from LlamaAttention. The body is essentially copied from LlamaSdpaAttention (copied rather than reused so it is easy to modify). The implementation is as follows; a standalone smoke test of sageattn follows the class:


class DeepSeekSdpaAttention(LlamaAttention):
    """
    Deepseek attention module using torch.nn.functional.scaled_dot_product_attention (replaced here by sageattn).
    This module inherits from `LlamaAttention`, as the weights of the module stay untouched. The only changes are
    on the forward pass to adapt to the SDPA API.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            if isinstance(position_embeddings, torch.Tensor):
                cos, sin = self.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings


        query_states, key_states = apply_rotary_pos_emb2(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True  #if causal_mask is None and q_len > 1 else False

        attn_output = sageattn(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
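
Before wiring the class into the model, a standalone smoke test of sageattn can catch installation or layout problems early. A minimal sketch, assuming a CUDA device, fp16 tensors, and the (batch, heads, seq_len, head_dim) layout used above; exact numerical agreement with SDPA is not expected because SageAttention quantizes internally:

import torch
import torch.nn.functional as F
from sageattention import sageattn

q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")

out_sage = sageattn(q, k, v, is_causal=True)
out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out_sage.shape, (out_sage - out_ref).abs().max().item())  # small difference expected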

3. Modify ATTENTION_CLASSES in modeling_deepseek.py to register the SDPA implementation (an illustrative lookup follows the mapping):


ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,

    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,

    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
    "mha_sdpa": DeepSeekSdpaAttention
}
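
How the new key is consumed depends on the dispatch code in modeling_deepseek.py; roughly, each decoder layer picks its attention class from this mapping. An illustrative (not verbatim) lookup, using the "mha_sdpa" key that ms-swift logs during training:

# Illustrative only: names other than ATTENTION_CLASSES are assumptions.
attn_implementation = "mha_sdpa"                    # key logged by ms-swift, see step 5
attn_cls = ATTENTION_CLASSES[attn_implementation]   # -> DeepSeekSdpaAttention
# self.self_attn = attn_cls(config=config, layer_idx=layer_idx)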

4. The sageattention and Triton versions used here are listed below; a quick version-check sketch follows.

Name: sageattention
Version: 1.0.6

Name: triton
Version: 3.2.0
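
To confirm a local environment matches, the installed versions can be checked from Python (a minimal sketch using only the standard library):

from importlib.metadata import version

print("sageattention:", version("sageattention"))  # expected: 1.0.6
print("triton:", version("triton"))                # expected: 3.2.0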

5. Run a training test:

swift sft  --model "deepseek-ai/deepseek-vl2-tiny"  --dataset  ../TEST.json  --attn_impl sdp

...
using attn_implementation: mha_sdpa
[INFO:swift] model.hf_device_map: {'': device(type='cuda', index=0)}

...
[INFO:swift] End time of running main: 2025-03-28 10:43:29.304164

Training completed successfully.

