当前位置: 首页 > news >正文

从代码学习深度学习 - 注意力汇聚:注意力评分函数(加性和点积注意力) PyTorch 版

文章目录

  • 前言
  • 一、掩蔽 Softmax 操作
    • 1.1 sequence_mask
    • 1.2 masked_softmax
    • 1.3 测试代码
  • 二、加性注意力 (Additive Attention)
    • 2.1 实现解析
    • 2.2 测试代码
  • 三、点积注意力 (Dot Product Attention)
    • 3.1 实现解析
    • 3.2 测试代码
  • 四、可视化注意力权重
    • 4.1 可视化点积注意力的权重
  • 总结


前言

在深度学习领域,注意力机制(Attention Mechanism)已经成为许多模型的核心组件,尤其是在自然语言处理(NLP)和计算机视觉任务中。注意力机制的核心思想是通过计算查询(Query)与键(Key)之间的相关性,动态地为值(Value)分配权重,从而聚焦于最重要的信息。本篇博客将通过 PyTorch 代码,深入探讨注意力汇聚(Attention Pooling)的两种常见评分函数:加性注意力(Additive Attention)和点积注意力(Dot Product Attention)。我们将从代码实现入手,逐步解析其原理,并通过可视化工具展示注意力权重的分布。
在这里插入图片描述

本文的目标读者是对深度学习有一定基础、希望通过代码理解注意力机制的实现细节的开发者。所有代码均基于 PyTorch,并在 Jupyter Notebook 中运行和测试。让我们开始吧!


一、掩蔽 Softmax 操作

在注意力机制中,掩蔽 Softmax(Masked Softmax)是一个关键步骤,用于确保模型只关注序列中的有效部分,避免对填充(padding)数据产生影响。我们先来看两个核心函数的实现:sequence_maskmasked_softmax

1.1 sequence_mask

sequence_mask 函数用于在序列中屏蔽不相关的项。它接收输入序列张量 X、有效长度张量 valid_len,并将无效位置替换为指定值(默认值为 0)。

import torch
import torch.nn as nn

def sequence_mask(X, valid_len, value=0):
    """
    在序列中屏蔽不相关的项
    
    参数:
        X: 输入序列张量,维度 [batch_size, maxlen]
        valid_len: 有效长度张量,维度 [batch_size]
        value: 填充值,标量,默认为0
    
    返回:
        X: 屏蔽后的序列张量,维度 [batch_size, maxlen]
    
    Defined in :numref:`sec_seq2seq_decoder`
    """
    # 获取序列的最大长度,维度为标量
    maxlen = X.size(1)
    
    # 创建掩码矩阵
    # torch.arange(maxlen): 生成 [0, 1, ..., maxlen-1] 的序列,维度 [maxlen]
    # [None, :] 将其扩展为 [1, maxlen]
    # valid_len[:, None] 将 [batch_size] 扩展为 [batch_size, 1]
    # 比较结果 mask 维度为 [batch_size, maxlen]
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    
    # 使用掩码将 X 中无效位置设为 value
    # ~mask 为反向掩码,选择需要填充的位置
    X[~mask] = value
    
    return X

这个函数的工作原理是:

  1. 通过 torch.arange(maxlen) 生成一个从 0 到 maxlen-1 的序列,并扩展为与批量大小匹配的形状。
  2. 使用广播机制,将 valid_len 与生成的序列比较,生成布尔掩码 mask
  3. 根据掩码,将无效位置(即超出有效长度的部分)替换为 value

1.2 masked_softmax

masked_softmax 函数在 Softmax 操作中加入掩蔽机制,确保无效位置的注意力权重为 0。

def masked_softmax(X, valid_lens):
    """
    通过在最后一个轴上掩蔽元素来执行softmax操作
    
    参数:
        X: 三维张量 (batch_size, seq_len, feature_dim)
        valid_lens: 一维张量 (batch_size,) 或二维张量 (batch_size, seq_len),表示有效长度
        
    返回:
        经过masked softmax处理的张量 (batch_size, seq_len, feature_dim)
    """
    if valid_lens is None:
        # 当没有指定有效长度时,直接执行标准softmax
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape  # shape: (batch_size, seq_len, feature_dim)
        if valid_lens.dim() == 1:
            # 将一维的valid_lens重复扩展到与X的第二维匹配
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # 将二维的valid_lens展平为一维
            valid_lens = valid_lens.reshape(-1)
            
        # 在最后一轴上对被掩蔽的元素使用非常大的负值替换,使其softmax输出为0
        X = sequence_mask(X.reshape(-1, shape[-1]), 
                          valid_lens,
                          value=-1e6)
        # 执

相关文章:

  • SQL问题分析与诊断(8)——其他工具和技术
  • ECMAScript 7~10 新特性
  • RLAgent note
  • 数据结构与算法-动态规划-线性动态规划,0-1背包,多重背包,完全背包,有依赖的背包,分组背包,背包计数,背包路径
  • 取消echarts地图悬浮时默认黄色高亮
  • Sigma-Delta ADC调制器的拓扑结构分类
  • java中的JNI调用c库
  • 若依微服务集成Flowable仿钉钉工作流
  • 【JavaScript】十八、页面加载事件和页面滚动事件
  • 基于AI的Web应用防火墙(AppWall)实战:漏洞拦截与威胁情报集成
  • 深入理解Java反射
  • 导入 Excel 批量替换文件名称及扩展名
  • react中通过 EventEmitter 在组件间传递状态
  • QTreeWidget 手动设置选中项后不高亮的问题
  • rbd块设备的id修改
  • 纳米软件储能电源模块自动化测试深度解析
  • Git版本管理系列:(三)远程仓库
  • vxe-table4.6 + vue3.2 + ant-design-vue 3.x 实现对列的显示、隐藏、排序
  • MYSQL-创建和使用表
  • Higress: 阿里巴巴高性能云原生API网关详解
  • 平台类网站有哪些/企业网站建设的基本流程
  • 三年疫情最后成了闹剧/优化方案电子版
  • 嘉兴本地推广网站/体验式营销
  • 国内做网站比较好的公司有哪些/互联网营销案例
  • 怎样帮人做网站挣钱/给我免费播放片高清在线观看
  • wordpress 微商网站/什么是软文营销?