当前位置: 首页 > news >正文

手写mask|代码详解,TriangularCausalMask/ProbMask/LocalMask

文章目录

    • 一、TriangularCausalMask(三角因果掩码)
      • 功能与原理
        • 核心作用:
        • 实现细节:
        • 示例:
        • 应用场景
    • 二、ProbMask(概率掩码)
      • 功能与原理
        • 核心作用:
        • 实现细节:
        • 示例
        • 应用场景
    • 三、LocalMask(局部掩码)
      • 功能与原理
        • 核心作用:
        • 实现细节:
        • 示例
        • 应用场景
    • 三种方法对比

一、TriangularCausalMask(三角因果掩码)

class TriangularCausalMask():def __init__(self, B, L, S=None, device="cpu"):# B: 批次大小, L: 查询序列长度, S: 键/值序列长度(默认与L相同)if S is not None:mask_shape = [B, 1, L, S]  # 交叉注意力场景(Query与Key长度不同)else:mask_shape = [B, 1, L, L]  # 自注意力场景(Query与Key长度相同)with torch.no_grad():  # 不计算梯度(掩码在推理时固定)# 生成上三角矩阵(对角线及以下为0,对角线以上为1)# diagonal=1表示对角线向上偏移1位,即对角线本身为0self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)@propertydef mask(self):return self._mask  # 返回掩码张量(True表示需屏蔽的位置)

功能与原理

核心作用:

实现因果屏蔽(Causal Masking),确保序列中每个位置只能关注其过去或当前的位置,不能看到未来的信息。这是自回归模型(如语言模型、时序预测)的基础,避免预测时 “偷看” 未来数据。

实现细节:

通过torch.triu(…, diagonal=1)生成上三角矩阵,对角线及以下为False(允许关注),对角线以上为True(屏蔽)。
支持两种形状:

  • 当S=None时,掩码为[B, 1, L, L],适用于自注意力(Query 和 Key 长度相同)。
  • 当S≠None时,掩码为[B, 1, L, S],适用于交叉注意力(Query 长度为L,Key/Value 长度为S)。
示例:

在这里插入图片描述

应用场景
  • 自回归任务:如文本生成(GPT 系列)、时序预测(未来值仅依赖历史值)。
  • 交叉注意力场景:如 Encoder-Decoder 架构中,Decoder 的 Query 屏蔽未来 Token,而 Encoder 的 Key/Value 无需屏蔽(因 Encoder 处理全序列)。

二、ProbMask(概率掩码)

class ProbMask():def __init__(self, B, H, L, index, scores, device="cpu"):# B: 批次大小, H: 注意力头数, L: 查询序列长度# index: 选中的key位置索引(通常是top-k个最重要的位置)# scores: 注意力分数张量 [B, H, L, S]# 1. 创建基础三角掩码(屏蔽未来位置)_mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)# 2. 扩展掩码至四维 [B, H, L, S],适配批次和头数_mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])# 3. 根据index从扩展掩码中提取对应位置的掩码值# torch.arange(B)[:, None, None]: [B, 1, 1],批次索引# torch.arange(H)[None, :, None]: [1, H, 1],头索引# index: [B, H, L],每个位置选中的key索引indicator = _mask_ex[torch.arange(B)[:, None, None],torch.arange(H)[None, :, None],index, :].to(device)# 4. 调整形状与scores一致,得到最终掩码self._mask = indicator.view(scores.shape).to(device)@propertydef mask(self):return self._mask  # 返回掩码张量

功能与原理

核心作用:

在稀疏注意力机制(如 ProbSparse Attention)中,根据注意力分数动态选择关键位置,屏蔽冗余连接,降低计算复杂度。

实现细节:
  • 基础三角掩码:首先创建L×S的上三角掩码_mask(屏蔽未来位置)。
  • 扩展与索引:
    – 将掩码扩展为[B, H, L, S],适配批次和头数。
    – 通过index(通常是 top-k 个高注意力分数的位置索引)从扩展掩码中提取对应位置的屏蔽状态,生成最终掩码。
  • 关键变量:
    index:形状为[B, H, L],表示每个 Query 位置(L)在头(H)和批次(B)下选择的 Key 位置索引。
    scores:注意力分数,形状为[B, H, L, S],用于确定哪些 Key 位置重要。
  • 效果:仅屏蔽非关键位置(低注意力分数且为未来的位置),保留关键历史位置和当前位置,实现 “按需屏蔽”。
示例

在这里插入图片描述

应用场景

高效注意力机制(如长序列优化)

三、LocalMask(局部掩码)

class LocalMask():def __init__(self, B, L, S, device="cpu"):# B: 批次大小, L: 查询序列长度, S: 键/值序列长度mask_shape = [B, 1, L, S]with torch.no_grad():# 计算局部窗口大小(基于序列长度的对数)# 例如: L=8 → len=3, L=16 → len=4self.len = math.ceil(np.log2(L))# 掩码1: 三角掩码(屏蔽未来位置,同TriangularCausalMask)self._mask1 = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)# 掩码2: 反向三角掩码(屏蔽超过len步的历史位置)# diagonal=-self.len表示保留从当前位置向前数len个位置self._mask2 = ~torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=-self.len).to(device)# 合并两个掩码(同时屏蔽未来位置和过远的历史位置)self._mask = self._mask1 + self._mask2@propertydef mask(self):return self._mask  # 返回合并后的掩码

功能与原理

核心作用:

结合因果屏蔽和局部窗口屏蔽,限制每个位置只能关注其局部历史窗口内的位置,同时屏蔽未来位置。适用于需要捕捉短期依赖的任务,或降低长序列的计算复杂度。

实现细节:
  • 参数len:通过math.ceil(np.log2(L))计算局部窗口长度,例如:
    L=8 → len=3(log2(8)=3),窗口大小为3
    L=5 → len=3(log2(5)≈2.32→ceil为3)
  • 双重掩码:
    _mask1:上三角掩码(屏蔽未来位置,同 TriangularCausalMask)。
    _mask2:下三角掩码,diagonal=-len表示屏蔽超过前len个位置的历史区域。例如:
    len=3时,每个位置只能看到前 3 个历史位置(包括自己),更早的位置被屏蔽。
  • 掩码合并:_mask = _mask1 + _mask2,即同时屏蔽未来位置和过远的历史位置,仅保留最近的len个历史位置 + 当前位置。
示例

在这里插入图片描述

应用场景
  • 局部依赖建模:如语音识别(关注邻近帧)、文本摘要(聚焦上下文)。
  • 长序列优化:通过限制历史窗口大小,将注意力计算复杂度从O(L²)降至O(L×len),适用于L较大的场景(如视频帧序列)。

三种方法对比

掩码类型核心目标屏蔽逻辑计算复杂度典型场景
TriangularCausalMask保证因果关系(无未来泄露)硬性屏蔽所有未来位置O(L²)自回归生成、时序预测
ProbMask稀疏化注意力(减少计算)动态屏蔽未来位置中的低重要性区域O (L×k)(k 为关键位置数)长序列高效建模(如 ProbSparse)
LocalMask局部历史依赖建模屏蔽未来位置 + 过远历史位置O (L×len)(len 为固定窗口)短窗口依赖任务、长序列加速

相关文章:

  • 电子电路:全面深入了解晶振的定义、作用及应用
  • 01 RK3568调试4G 模块 EG800AK-CN
  • SpringCloud 分布式锁Redisson锁的重入性与看门狗机制 高并发 可重入
  • Python语法基础篇(包含类型转换、拷贝、可变对象/不可变对象,函数,拆包,异常,模块,闭包,装饰器)
  • 深度学习入门——基于多层感知机的MNIST手写数字识别
  • Blinko智能笔记系统实现跨平台同步与隐私保护的完整技术方案解析
  • 【C/C++】template 入门到高阶简单大纲
  • 经典SQL查询问题的练习第四天
  • AutoCompose - 携程自动编排【开源】
  • 【亲测有效】Mybatis-Plus中更新字段为null
  • pytorch3d+pytorch1.10+MinkowskiEngine安装
  • PyTorch--池化层(4)
  • Attention Is All You Need (Transformer) 以及Transformer pytorch实现
  • pytorch基本运算-导数和f-string
  • 互联网大厂Java求职面试:AI大模型与云原生技术的深度融合
  • MySQL关系型数据库学习
  • 第三发 DSP 点击控制系统
  • 【MATLAB代码】制导方法介绍与例程——三点法|三维空间,动态目标导引(订阅专栏后可直接查看源代码)
  • leetcode hot100 链表(一)
  • matlab实现求解兰伯特问题