当前位置: 首页 > news >正文

仅仅使用pytorch来手撕transformer架构(2):多头注意力MultiHeadAttention类的实现和向前传播

手撕MultiHeadAttention 类的代码,结合具体的例子来说明每一步的作用和计算过程。

往期文章:
仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播

最适合小白入门的Transformer介绍

1. 初始化方法 __init__

def __init__(self, embed_size, heads):
    super(MultiHeadAttention, self).__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size // heads

    assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"

    self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

1.1参数解释

  • embed_size:嵌入向量的维度,表示每个输入向量的大小。
  • heads:注意力头的数量。多头注意力机制将输入分割成多个“头”,每个头学习不同的特征。
  • head_dim:每个注意力头的维度大小,计算公式为 embed_size // heads。这意味着每个头处理的特征子集的大小。

1.2线性变换层

  • self.valuesself.keysself.queries

    • 这些是线性变换层,用于将输入的嵌入向量分别转换为值(Values)、键(Keys)和查询(Queries)。
    • 每个线性层的输入和输出维度都是 self.head_dim,因为每个头处理的特征子集大小为 self.head_dim
    • 使用 bias=False 是为了简化计算,避免引入额外的偏置项。
  • self.fc_out

    • 在多头注意力计算完成后,将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度 embed_size

2. 前向传播方法 forward

def forward(self, values, keys, query, mask):
    N = query.shape[0]  # Batch size
    value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

2.1输入参数

  • valueskeysquery
    • 这三个输入张量的形状通常为 (batch_size, seq_len, embed_size)
    • 它们分别对应于值(Values)、键(Keys)和查询(Queries)。
  • mask
    • 用于遮蔽某些位置的注意力权重,避免模型关注到不应该关注的部分(例如,解码器中的未来信息)。

2.2多头注意力计算过程

2.2.1 将输入嵌入分割为多个头:
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
  • 将输入的嵌入向量分割成 heads 个头,每个头的维度为 self.head_dim
  • 例如,如果 embed_size = 256heads = 8,则 self.head_dim = 32,每个头处理 32 维的特征。
  • 重塑后的形状为 (N, seq_len, heads, head_dim)
2.2.2 线性变换:
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
  • 对每个头的值、键和查询分别进行线性变换。
  • 这一步将输入特征投影到不同的子空间中,使得每个头可以学习不同的特征。
2.2.3计算注意力分数(Attention Scores):
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
  • 使用 torch.einsum 计算查询和键之间的点积,得到注意力分数矩阵。
  • 公式 nqhd,nkhd->nhqk 表示:
    • n:批量大小(Batch Size)。
    • q:查询序列的长度。
    • k:键序列的长度。
    • h:头的数量。
    • d:每个头的维度。
  • 输出的 energy 形状为 (N, heads, query_len, key_len)
2.2.4应用掩码(Masking):
if mask is not None:
    energy = energy.masked_fill(mask == 0, float("-1e20"))
  • 如果提供了掩码,将掩码为 0 的位置的注意力分数设置为一个非常小的值(如 -1e20),这样在后续的 softmax 计算中,这些位置的权重会趋近于 0。
2.2.5计算注意力权重:
attention = torch.softmax(energy / (self.embed_size ** (0.5)), dim=3)
  • 对注意力分数进行 softmax 归一化,得到注意力权重。
  • 除以 sqrt(embed_size) 是为了缩放点积结果,避免梯度消失或爆炸。
2.2.6应用注意力权重:
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
    N, query_len, self.heads * self.head_dim
)
  • 使用 torch.einsum 将注意力权重与值相乘,得到加权的值。
  • 公式 nhql,nlhd->nqhd 表示:
    • n:批量大小。
    • h:头的数量。
    • q:查询序列的长度。
    • l:值序列的长度。
    • d:每个头的维度。
  • 输出的 out 形状为 (N, query_len, heads * self.head_dim)
2.2.7线性变换输出:
out = self.fc_out(out)
  • 将所有头的输出拼接起来,并通过一个线性层将维度转换回原始的嵌入维度 embed_size

3. 示例矩阵计算

假设:

  • embed_size = 4
  • heads = 2
  • head_dim = embed_size // heads = 2
  • 输入序列长度为 3,批量大小为 1。

3.1输入张量

values = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]], dtype=torch.float32)
keys = torch.tensor([[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], dtype=torch.float32)
query = torch.tensor([[[25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36]]], dtype=torch.float32)
mask = None

3.2重塑为多头

values = values.reshape(1, 3, 2, 2)  # (N, value_len, heads, head_dim)
keys = keys.reshape(1, 3, 2, 2)
queries = query.reshape(1, 3, 2, 2)

3.3线性变换

假设线性变换层的权重为单位矩阵(简化计算),则:

values = self.values(values)  # 不改变值
keys = self.keys(keys)
queries = self.queries(queries)

3.4计算注意力分数

energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

假设:

  • queries = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
  • keys = [[[13, 14], [15, 16]], [[17, 18], [19, 20]], [[21, 22], [23, 24]]]

计算点积:

energy = [
    [
        [[1*13 + 2*14, 1*15 + 2*16], [1*17 + 2*18, 1*19 + 2*20]],

完整代码:

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "嵌入尺寸需要被头部整除"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

作者码字不易,觉得有用的话不妨点个赞吧,关注我,持续为您更新AI的优质内容。

相关文章:

  • 侯捷 C++ 课程学习笔记:C++内存管理机制
  • Qt 初识
  • Unity Android出包
  • Mysql高频面试题
  • Gemini 2.0 Flash
  • AQS及派生类
  • AI日报 - 2025年3月11日
  • Spring Cloud 负载均衡器架构选型
  • 什么是 MyBatis? 它的优点和缺点是什么?
  • [NewStarCTF 2023 公开赛道]ez_sql1 【sqlmap使用/大小写绕过】
  • 万字技术指南STM32F103C8T6 + ESP8266-01 连接 OneNet 平台 MQTT/HTTP
  • Hexo博客Icarus主题不蒜子 UV、PV 统计数据初始化配置
  • (done) MIT6.S081 Lec15 Crash recovery 学习笔记
  • tcp/ip协议配置参数有哪些?tcp/ip协议需要设置的参数有哪些
  • JAVA面试_进阶部分_深入理解socket网络异常
  • 每日一题----------String 和StringBuffer和StringBuiler重点
  • STM32步进电机驱动全解析(上) | 零基础入门STM32第五十七步
  • WLAN(无线局域网)安全
  • Java网络爬虫工程
  • Docker基础之运行原理
  • 魔都眼|84岁美琪大戏院焕新回归:《SIX》开启中国首演
  • 是否担心关税战等外部因素冲击中国经济?外交部:有能力、有条件、有底气
  • 第九届丝绸之路国际博览会在西安开幕
  • 引入AI Mode聊天机器人,Gemini 2.5 Pro加持,谷歌重塑搜索智能
  • 小满:一庭栀子香
  • 4月中国常青游戏榜:32款游戏吸金近34亿元,腾讯、网易占半壁江山,《原神》再跌出前十