当前位置: 首页 > news >正文

GPT - 多头注意力机制(Multi-Head Attention)模块

本节代码实现了一个多头注意力机制(Multi-Head Attention)模块,它是Transformer架构中的核心组件之一。
 

⭐关于多头自注意力机制的数学原理请见文章:

Transformer - 多头自注意力机制复现-CSDN博客
本节要求理解原理后手敲实现多头注意力机制

1. 初始化部分

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_project = nn.Linear(d_model, d_model)
        self.k_project = nn.Linear(d_model, d_model)
        self.v_project = nn.Linear(d_model, d_model)
        self.o_project = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
  • d_model:模型的维度,表示输入的特征维度。

  • num_heads:注意力头的数量。多头注意力机制将输入分成多个不同的“头”,每个头学习不同的特征,最后再将这些特征合并起来。

  • d_k:每个头的维度,计算公式为d_model // num_heads。例如,如果d_model=512num_heads=8,则每个头的维度为512 // 8 = 64

  • q_projectk_projectv_project:这三个线性层分别用于将输入x投影到查询(Query)、键(Key)和值(Value)空间。投影后的维度仍然是d_model

  • o_project:输出投影层,将多头注意力的结果再次投影到d_model维度。

  • dropout:用于防止过拟合的Dropout层。

2. 前向传播部分

def forward(self, x, attn_mask=None):
    batch_size, seq_len, d_model = x.shape
    Q = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    K = self.k_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    V = self.v_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
  • 输入x:形状为(batch_size, seq_len, d_model),其中seq_len是序列长度。

  • 投影操作

    • 使用q_projectk_projectv_project将输入x分别投影到查询(Q)、键(K)和值(V)空间。

    • 投影后的张量形状为(batch_size, seq_len, d_model)

  • 多头拆分

    • 使用.view(batch_size, seq_len, self.num_heads, self.d_k)将投影后的张量拆分成多个头,形状变为(batch_size, seq_len, num_heads, d_k)

    • 使用.transpose(1, 2)将头的维度提到前面,形状变为(batch_size, num_heads, seq_len, d_k)

    atten_scores = Q @ K.transpose(2, 3) / math.sqrt(self.d_k)
  • 计算注意力分数

    • 使用矩阵乘法@计算QK的点积,K.transpose(2, 3)K的形状变为(batch_size, num_heads, d_k, seq_len)

    • 点积结果的形状为(batch_size, num_heads, seq_len, seq_len)表示每个位置之间的注意力分数。

    • 除以math.sqrt(self.d_k)是为了防止点积结果过大,导致梯度消失或爆炸。

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(1)
        atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)
  • 注意力掩码(关于掩码的具体实现将在下一篇文章进行讲解)

    • 如果提供了注意力掩码attn_mask,则使用unsqueeze(1)将掩码的形状扩展为(batch_size, 1, seq_len, seq_len)

    • 使用masked_fill将掩码为0的位置的注意力分数设置为一个非常小的值(如-1e9),这样在softmax计算时,这些位置的注意力权重会接近0。

    atten_scores = torch.softmax(atten_scores, dim=-1)
    out = atten_scores @ V
  • 归一化注意力分数

    • 使用torch.softmax对注意力分数进行归一化,形状仍为(batch_size, num_heads, seq_len, seq_len)

  • 计算加权和

    • 使用矩阵乘法@将归一化后的注意力分数与V相乘,得到每个头的加权和,形状为(batch_size, num_heads, seq_len, d_k)

    out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    out = self.o_project(out)
    return self.dropout(out)
  • 合并多头结果

    • 使用.transpose(1, 2)将头的维度放回原来的位置,形状变为(batch_size, seq_len, num_heads, d_k)

    • 使用.contiguous().view(batch_size, seq_len, d_model)将多头结果合并成一个张量,形状为(batch_size, seq_len, d_model)

  • 输出投影

    • 使用o_project将合并后的结果再次投影到d_model维度。

  • Dropout

    • 使用dropout层对输出进行Dropout操作,防止过拟合。

需复现完整代码

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_project = nn.Linear(d_model, d_model)
        self.k_project = nn.Linear(d_model, d_model)
        self.v_project = nn.Linear(d_model, d_model)
        self.o_project = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        
        batch_size, seq_len, d_model = x.shape
        Q = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        atten_scores = Q @ K.transpose(2, 3) / math.sqrt(self.d_k)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)
            atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)

        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.o_project(out)
        return self.dropout(out)

相关文章:

  • AI应用开发平台 和 通用自动化工作流工具 的详细对比,涵盖定义、核心功能、典型工具、适用场景及优缺点分析
  • CTF web入门之文件包含
  • SAP BDC:企业数据管理的新纪元
  • flink部署使用(flink-connector-jdbc)连接达梦数据库并写入读取数据
  • NO.85十六届蓝桥杯备战|动态规划-经典线性DP|最长上升子序列|合唱队形|最长公共子序列|编辑距离(C++)
  • FreeRTOS入门与工程实践-基于STM32F103(一)(单片机程序设计模式,FreeRTOS源码概述,内存管理,任务管理,同步互斥与通信,队列,信号量)
  • BGP分解实验·23——BGP选路原则之路由器标识
  • 最新版IDEA超详细图文安装教程(适用Mac系统)附安装包及补丁2025最新教程
  • 首批 | 云轴科技ZStack通过电子标准院云上部署DeepSeek验证测试
  • Tkinter高级布局与窗口管理
  • Node.js中util模块详解
  • 【golang/jsonrpc】go-ethereum中json rpc初步使用(websocket版本)
  • vue3使用keep-alive缓存组件与踩坑日记
  • [实战] 二分查找与哈希表查找:原理、对比与C语言实现(附完整C代码)
  • PostgreSQL 实例运行状态全面检查
  • 考研数据结构精讲:数组与特殊矩阵的压缩存储技巧(包含真题及解析)
  • 大数据面试问答-Hadoop/Hive/HDFS/Yarn
  • 基于SpringBoot汽车零件商城系统设计和实现(源码+文档+部署讲解)
  • vue3+nodeJs+webSocket实现聊天功能
  • stack overflow国内无法访问原因
  • 长江画派创始人之一、美术家鲁慕迅逝世,享年98岁
  • 乡村快递取件“跑腿费”屡禁不止?云南元江县公布举报电话
  • 澎湃研究所“营商环境研究伙伴计划”启动
  • 黄玮接替周继红出任国家体育总局游泳运动管理中心主任
  • 黄晨光任中科院空间应用工程与技术中心党委书记、副主任
  • 复旦设立新文科发展基金,校友曹国伟、王长田联合捐赠1亿元