当前位置: 首页 > news >正文

手撕 Decoder

Happy-LLM:从零开始的大语言模型原理与实践教程.pdf P24

Decoder Layer

一个Decoder Layer内的数据流动顺序为:

Input (x)↓
LayerNorm 1↓
Masked Self-Attention↓
x + Attention(x)↓
LayerNorm 2↓
Cross-Attention (with enc_out)↓
h = x + Attention(x, enc_out)↓
LayerNorm 3↓
MLP (Feed-Forward Network)↓
out = h + MLP(h)

该实现依然与标准transformer不一样,以下代码里,先归一化再残差连接(Pre-LayerNorm),而标准transformer则相反

代码

class DecoderLayer(nn.Module):def __init__(self, args):super().__init__()self.attention_norm_1 = LayerNorm(args.n_embd)self.mask_attention = MultiHeadAttention(args, is_causal=True)self.attention_norm_2 = LayerNorm(args.n_embd)self.attention = MultiHeadAttention(args, is_causal=False)self.ffn_norm = LayerNorm(args.n_embd)self.feed_forward = MLP(args)def forward(self, x, enc_out):# Layer Normnorm_x = self.attention_norm_1(x)# 掩码自注意力x = x + self.mask_attention.forward(norm_x, norm_x, norm_x)# 多头注意力norm_x = self.attention_norm_2(x)h = x + self.attention.forward(norm_x, enc_out, enc_out)# 经过前馈神经网络out = h + self.feed_forward.forward(self.ffn_norm(h))return out

初始化了三个子层(Masked Multi-Head Attention, Cross Attention, Feed Forward),每个子层都含有一个Layer Norm,三个层归一化的函数相同,均为LayerNorm(args.n_embd),仅名称不同

代码逻辑实质上是三个重复的x_{out}=x+SubLayer(LN(x))

搭建Decoder

 class Decoder(nn.Module):'''解码器'''def __init__(self, args):super(Decoder, self).__init__() # ⼀个 Decoder 由 N 个 Decoder Layer 组成self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layer)])self.norm = LayerNorm(args.n_embd)def forward(self, x, enc_out):"Pass the input (and mask) through each layer in turn."for layer in self.layers:x = layer(x, enc_out)return self.norm(x)

[DecoderLayer(args) for _ in range(args.n_layer)] 通过循环生成 args.n_layer 个 DecoderLayer 实例;nn.ModuleList 将生成的 DecoderLayer 实例列表包装为 nn.ModuleList,动态创建并注册多个解码器层,构建符合 Transformer 架构的解码器

代码末尾的 self.norm(x) 是对所有层处理后的最终输出进行归一化

参考文章

Happy-LLM:从零开始的大语言模型原理与实践教程.pdf P24

http://www.dtcms.com/a/258061.html

相关文章:

  • Day 2:Shell变量解密——从“Hello World“到会“记忆“的脚本
  • C语言数组介绍 -- 一维数组和二维数组的创建、初始化、下标、遍历、存储,C99 变长数组
  • Zynq + FreeRTOS + YAFFS2 + SQLite3 集成指南
  • python计算长方形的周长 2025年3月青少年电子学会等级考试 中小学生python编程等级考试一级真题答案解析
  • Vibe Coding - 使用cursor从PRD到TASK精准分解执行
  • 《内心强大不怯场》读书笔记3
  • 智能营销系统对企业的应用价值
  • 【Java面试】你是怎么控制缓存的更新?
  • Linux内核网络栈的智慧:skb->cb控制缓冲区的设计哲学
  • sudo安装pip包的影响
  • 有哪些词编码模型
  • 相机标定与3D重建技术通俗讲解
  • Python基础(​​FAISS​和​​Chroma​)
  • 每日算法刷题Day36 6.23:leetcode枚举技巧枚举中间4道题,用时1h30min
  • VLN论文复现——VLFM(ICRA最佳论文)
  • 【图像】ubuntu中图像处理
  • 可编辑精品PPT | 企业数字化商业平台客户中台解决方案客户CRM数据中台方案
  • 支持java8的kafka版本
  • 73页精品PPT | 大数据平台规划与数据价值挖掘应用咨询项目解决方案
  • 【Docker基础】Docker容器管理:docker pause详解
  • 龙虎榜——20250623
  • AI-Sphere-Butler之如何将豆包桌面版对接到AI全能管家~新玩法(一)
  • 如何实现财务自由
  • EEG 分类攻略1- theta, alpha, beta和gamma频谱
  • 学习Linux进程冻结技术
  • OpenCV——霍夫变换
  • 一些想法。。。
  • Mermaid学习第二部
  • Unreal Engine附着组件调用区别
  • 【C语言】解决VScode中文乱码问题