当前位置: 首页 > news >正文

从零用 NumPy 实现单层 Transformer 解码器(Decoder-Only)

200行代码理解LLM Attention+自解码推理

  • 从零用 NumPy 实现单层 Transformer 解码器(Decoder-Only)
    • 一、整体流程
    • 二、核心组件实现
      • 1. 正弦位置编码
      • 2. Layer Normalization
      • 3. 数值稳定的 Softmax
      • 4. 多头自注意力(含因果掩码)
      • 5. 前馈网络(MLP + GELU)
      • 6. 单层 Transformer 的前向函数
    • 三、运行与验证
    • 四、总结

从零用 NumPy 实现单层 Transformer 解码器(Decoder-Only)

近年来,大语言模型(LLM)大多基于 Transformer Decoder 架构,例如 GPT、LLaMA 等。在这篇文章中,我们将用 纯 NumPy 实现一个单层的 Decoder-Only Transformer,并支持因果掩码多头注意力GELU 激活等核心特性。

一、整体流程

单层 Pre-LN Decoder 的标准计算步骤为:

  1. 添加位置编码(Positional Encoding)
  2. LayerNorm(预归一化)
  3. 多头自注意力(Multi-Head Attention)
  4. 输出投影 W O W_O WO
  5. 残差连接①
  6. LayerNorm(第二次)
  7. 前馈网络(MLP + GELU)
  8. 残差连接②
  9. 词表投影预测下一个 Token

流程示意图如下(省略 Batch):

X_in → +PE → LN → MHA(+mask) → W_O → +残差①→ LN → MLP(GELU) → +残差② → W_vocab^T → softmax

二、核心组件实现

1. 正弦位置编码

我们使用原论文《Attention Is All You Need》中的正弦位置编码,为每个位置生成固定的向量并加到嵌入上:

def add_cos_embedding(X_l):seq_len, embed_dim = X_l.shapefor i in range(seq_len):for j in range(embed_dim):if j % 2 == 0:X_l[i, j] += np.sin(float(i) / (10000 ** (j / embed_dim)))else:X_l[i, j] += np.cos(float(i-1) / (10000 ** ((j-1) / embed_dim)))return X_l

这里用到了不同频率的正弦/余弦,偶数维用 sin,奇数维用 cos,从而编码位置信息。


2. Layer Normalization

LayerNorm 是 Transformer 的标配归一化方法,这里我们用每个 token 自身的均值和方差来归一化:

def layer_norm(X_l, beta, gamma, eps=1e-6):mean = np.mean(X_l, axis=-1, keepdims=True)std = np.std(X_l, axis=-1, keepdims=True)X_l = (X_l - mean) / (std + eps)return X_l * gamma + beta

注意,这里 gammabeta 是可学习参数,对所有 token 共享。


3. 数值稳定的 Softmax

为了避免指数溢出,我们在计算 exp 前先减去行最大值:

def softmax(X_l, eps=1e-6):X_l = X_l - np.max(X_l, axis=-1, keepdims=True)exp_x = np.exp(X_l)return exp_x / (np.sum(exp_x, axis=-1, keepdims=True) + eps)

4. 多头自注意力(含因果掩码)

在 Decoder 中,我们必须确保当前 token 不能看到未来的信息,因此需要因果掩码(Causal Mask):

def compute_attention(X_q, X_k, X_v, heads, attn_dim, mask):head_outputs = []for i in range(heads):scores = np.dot(X_q[i], X_k[i].T) / np.sqrt(attn_dim)scores = scores + mask   # 将不可见位置加上 -1e9attn = softmax(scores)out = np.dot(attn, X_v[i])head_outputs.append(out)return np.concatenate(head_outputs, axis=-1)  # 拼接所有头

其中 mask 是一个 [seq_len, seq_len] 的矩阵,上三角(未来位置)为 -1e9,下三角为 0。


5. 前馈网络(MLP + GELU)

MLP 部分使用了常见的 GELU 激活函数:

def gelu(x):return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))def mlp
http://www.dtcms.com/a/329803.html

相关文章:

  • 《红黑树驱动的Map/Set实现:C++高效关联容器全解析》
  • 基于微信小程序的生态农产销售管理的设计与实现/基于C#的生态农产销售系统的设计与实现、基于asp.net的农产销售系统的设计与实现
  • Ubuntu24.04桌面版安装wps
  • 深入分析Linux文件系统核心原理架构与实现机制
  • RS485转profinet网关接M8-11 系列 RFID 读卡模块实现读取卡号输出
  • 元数据与反射:揭开程序的“自我认知”能力
  • 【递归、搜索与回溯算法】穷举、暴搜、深搜、回溯、剪枝
  • 第七章:OLED温湿度显示系统
  • 数据库连接池如何进行空闲管理
  • 光伏板横铺VS竖铺,布局决定发电量!
  • MySQL数据库知识体系总结 20250813
  • iOS混淆工具有哪些?数据安全与隐私合规下的防护实践
  • [ai]垂直agent|意图识别|槽位提取
  • 基于Tensorflow2.15的图像分类系统
  • MySQL三大存储引擎对比:InnoDB vs MyISAM vs MEMORY
  • 【Unity3D】Spine黑线(预乘问题)、贴图边缘裁剪问题
  • Effective C++ 条款39:明智而审慎地使用private继承
  • RabbitMQ:Windows版本安装部署
  • Java研学-RabbitMQ(六)
  • 基于js和html的点名应用
  • B站 韩顺平 笔记 (Day 17)
  • Spring Security 前后端分离场景下的会话并发管理
  • Spring Boot项目调用第三方接口的三种方式比较
  • 【Linux学习|黑马笔记|Day3】root用户、查看权限控制信息、chmod、chown、快捷键、软件安装、systemctl、软连接、日期与时区
  • Go 微服务限流与熔断最佳实践:滑动窗口、令牌桶与自适应阈值
  • NLP学习之Transformer(1)
  • 深度学习(4):数据加载器
  • Redis7学习——Redis的初认识
  • 51c自动驾驶~合集14
  • Docker:快速部署 Temporal 工作流引擎的技术指南