当前位置: 首页 > news >正文

从零构建TransformerP1-了解设计

欢迎来到啾啾的博客🐱。
记录学习点滴。分享工作思考和实用技巧,偶尔也分享一些杂谈💬。
有很多很多不足的地方,欢迎评论交流,感谢您的阅读和评论😄。

目录

  • 引言
  • 1 概念回顾
    • 1.1 序列任务
      • 1.1.1 将序列变成模型能处理的形式
      • 1.1.2 其他类型的序列
      • 1.1.3 判断一个任务是不是序列任务的 checklist
    • 1.2 长距离依赖
    • 1.3 元素间关系
  • 2 Transformer设计流程
  • 3 第一步:分析你的“问题”——输入输出结构
  • 4 第二步:选择 Transformer 的哪种“模式”?
    • 4.1 如何选择?
  • 5 第三步:设计核心组件(你需要定义什么?)
    • 5.1 Embedding层
    • 5.2 注意力机制(Attention)
      • 5.2.1 位置编码(Positional Encoding)
        • 5.2.1.1 固定位置编码
        • 5.2.1.2 可学习位置编码(更常用)
    • 5.3 模型头(Head)——输出层设计
    • 5.4 损失函数(Loss Function)
  • 6 第四步:训练流程(必须有的循环)
  • 7 第五步:验证与推理策略
  • 8 设计一个“新闻分类”模型
    • 8.1 设计过程
  • 9 总结:设计 Transformer 的 checklist
    • 9.1 最后一句话:

引言

AI使用声明:在内容整理、结构优化和语言表达的过程中,我使用了人工智能(AI)工具作为辅助。

在之前的《Transformer:从入门到放弃》一篇中,我们已经对Transformer架构有了基本的了解。

本篇,让我们基于PyTorch,一步一步实现一个简化但完整、可运行的 Transformer 模型,手写一些核心组件,包括:

  • 注意力提示(Attention Cue)
  • 注意力评分函数(Attention Scoring Functions)
  • 自注意力(Self-Attention)和多头注意力(Multi-Head Attention)
  • 位置编码(Positional Encoding)
  • 编码器(Encoder)和解码器(Decoder)
    最终将它们组装成一个完整的Encoder-Decoder架构的Transformer。

开始阅读前,让我们思考一个问题:“设计一个Transformer需要什么?”

并且思考"给一段新闻文本,判断属于哪个类别(体育、科技、娱乐…)",即一个文本分类任务用的Transformer模型要怎么设计?

补充资料:《happy-llm》第二章(强烈推荐)

1 概念回顾

1.1 序列任务

输入或输出是“有序元素组成的序列”的任务。 这里的“序列”就像一条有顺序的链条,每个位置上是一个元素(比如一个词、一个音素、一个动作)。

序列任务 = 输入或输出是一个“有顺序的元素链”,且顺序影响语义的任务。

任务类型输入序列输出序列说明
文本分类"这部电影真好看!"[电影, 是, 好看, ...]"正面"输入是序列,输出是单个标签
机器翻译"Hello world"[Hello, world]"你好 世界"输入和输出都是序列
文本生成"从前有一只""从前有一只小猫..."输出是逐步生成的序列
语音识别音频波形 → 帧序列"今天天气很好"声音是时间序列
时间序列预测[1, 2, 3, 4, 5][6]数值型序列预测未来值
命名实体识别(NER)"小明在北京上班"[人名, 地点, 组织]每个词都有一个标签
对话系统用户说:“你好” → 模型回复:“你好呀!”多轮对话历史输入输出都是对话序列
👉 这些任务的共同点:数据是有“时间”或“顺序”维度的

Transformer 特别适合序列任务,其核心优势是能并行处理序列 + 能用 自注意力机制 直接建模序列中任意两个元素的关系(无论多远)。

1.1.1 将序列变成模型能处理的形式

原始文本不能直接输入神经网络,要转换成“向量序列”。

  • 步骤 1:分词(Tokenization)
句子:"I love Transformers"
分词 → ["I", "love", "Transformers"]
  • 步骤 2:转为词 ID
词表:{"I": 1, "love": 2, "Transformers": 3}
→ [1, 2, 3]
  • 步骤 3:嵌入 + 位置编码
# 每个 ID 映射为向量
embeddings = nn.Embedding(vocab_size, d_model)  # (3,) → (3, d_model)# 加上位置信息
x = embeddings + positional_encoding
# 最终输入:(seq_len, d_model) 的向量序列

我们知道,模型处理向量时,需要保证向量维度的一致,也就是序列长度的一致。
不同句子长度不同,怎么批量训练呢?

解决方案:Padding + Mask

句子1: "Hello"           → [Hello, <pad>, <pad>]
句子2: "I love you"      → [I, love, you]
句子3: "OK"              → [OK, <pad>, <pad>]
  • 统一补到最长长度
  • 使用 attention mask 告诉模型:“忽略 <pad> 位置”

👉 这样就能批量处理变长序列了。

  • seq_len与d_model
维度是否可变如何处理作用
序列长度(seq_len)batch 间可变,batch 内必须一致padding + mask控制上下文长度
向量维度(d_model)全程必须一致模型设计时固定控制表示能力

🔁 它们的关系是:

  • seq_len 决定了“时间/顺序维度”的大小
  • d_model 决定了“特征维度”的大小
  • 一起构成输入张量:(batch_size, seq_len, d_model)
  • 向量就是1D的张量

1.1.2 其他类型的序列

序列类型示例是否适合 Transformer
文本序列句子、文档✅ 最经典应用
时间序列股价、气温✅ 可以(如 TimeSformer)
音频序列语音波形、MFCC 特征✅ Whisper 就是 Transformer
视频序列一连串图像帧✅ VideoBERT、TimeSformer
DNA 序列ATCG…✅ 生物信息学中使用
用户行为序列点击、浏览、购买✅ 推荐系统中常用
程序代码函数、语句✅ CodeBERT、Codex

🌟 所以 Transformer 不只是“NLP 模型”,而是“序列建模通用架构

1.1.3 判断一个任务是不是序列任务的 checklist

✅ 如果你回答“是”,那它很可能是序列任务:

问题是?
输入是一段文字、一句话、一段语音吗?
输出要生成一段文本或翻译结果吗?
数据有时间顺序(如股价、日志)吗?
元素之间的顺序会影响含义吗?
不同样本的长度不一样吗?

1.2 长距离依赖

一段新闻文本,如"苹果公司发布了新款iPhone,搭载A17芯片..."
开头提到“苹果”,结尾提到“发布会”,中间隔了很多词,但这两个信息共同决定这是“科技”新闻。

关键信息可能相隔很远,但需要两个相隔很远的信息共同决定语义,这就是长距离依赖。

  • 模型对比
模型处理长距离依赖能力原因
RNN/LSTM⚠️ 有限(随距离衰减)信息需逐步传递,梯度消失
CNN⚠️ 有限(需多层堆叠)感受野有限,n层CNN只能看到2ⁿ距离
Transformer优秀自注意力直接建模任意距离关系

1.3 元素间关系

一段新闻文本,如"苹果公司发布了新款iPhone,搭载A17芯片..."
“苹果”是公司还是水果?需要结合上下文(如“发布”、“芯片”)判断语义 → 必须建模词与词之间的关系。

这样词与词的关系就是元素间关系。

是否需要建模“元素间关系”是 NLP 的本质问题。

例子问题需要关系建模吗?
"苹果很好吃"“苹果”是水果还是公司?✅ 需要看上下文(“好吃” → 水果)
"苹果发布了新手机"同上✅ “发布” → 公司
👉 这些都依赖词与词之间的语义关系,而 Transformer 的自注意力机制正是为此设计的:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
它允许每个词(Query)去“查询”其他所有词(Key),根据相关性加权聚合信息(Value)。直接连接任意距离元素

伪代码:Transformer的自注意力机制如何解决歧义,建立元素间关系

  def resolve_ambiguity(word, context):# word作为Query,context作为Key-Valueattention_weights = softmax(word @ context.T / sqrt(d_k))# 权重反映相关性:"苹果"对"发布"的权重高,对"好吃"的权重低resolved_meaning = attention_weights @ contextreturn resolved_meaning

2 Transformer设计流程

首先,从“问题”出发 → 决定“是否用 Transformer”和“怎么设计”?

你的问题↓
输入是什么?输出是什么?(数据结构)↓
是序列任务吗?有长距离依赖吗?↓
是否需要建模“元素间关系”?↓
→ 是:考虑 Transformer
→ 否:可能 CNN/RNN/MLP 更合适↓
选择架构:Encoder-only?Decoder-only?Encoder-Decoder?↓
定义组件:Embedding、Attention、FFN、Head...↓
训练循环:Loss、Optimization、Evaluation

![[从零构建Transformer.png]]

文本分类任务分析如下:

你的问题:新闻分类↓
输入:文本序列 → 是序列任务 ✅↓
是否有长距离依赖? → 是(如首尾关键词呼应)✅↓
是否需要建模词间关系? → 是(如歧义消解)✅↓
→ 推荐使用 Transformer(而非 MLP/RNN/CNN)↓
输出是单个类别 → 不需要生成 → 用 Encoder-only 架构↓
选择 [CLS] 或 平均池化 获取句向量↓
接分类头(Linear + Softmax)

3 第一步:分析你的“问题”——输入输出结构

Transformer 最擅长处理“序列”或“结构化关系”问题

任务类型输入输出是否适合Transformer
文本分类句子[I love you]类别positive✅ 是(Encoder-only)
机器翻译源语言句子[Hello world]目标语言句子[你好世界]✅ 是(Encoder-Decoder)
文本生成提示[Once upon a time]续写故事✅ 是(Decoder-only)
图像分类图像像素类别标签✅ 可以(ViT:把图切成 patch)
时间序列预测历史数据[1,2,3,4]未来值[5]✅ 是(类似 seq2seq)
表格数据分类特征列[age, income, ...]标签❌ 通常不用(MLP 更好)

4 第二步:选择 Transformer 的哪种“模式”?

Transformer 不只有一种结构!根据任务不同,有三种主流变体:

架构结构典型任务列子
Encoder-only只保留编码器分类、NER、句向量BERT、RoBERTa
Decoder-only只保留解码器(带掩码)文本生成GPT 系列
Encoder-Decoder编码器 + 解码器机器翻译、摘要T5、BART

4.1 如何选择?

有一些简单的例子。

问题推荐模式
“这段话是正面还是负面?”Encoder-only
“把英文翻译成中文”Encoder-Decoder
“续写这篇文章”Decoder-only
“回答一个问题”Decoder-only(如 ChatGPT)
“提取实体:人名、地点”Encoder-only

5 第三步:设计核心组件(你需要定义什么?)

"设计-实现"映射表

设计决策PyTorch实现代码示例
输入表示
文本任务nn.Embedding + 位置编码self.token_emb = nn.Embedding(vocab_size, d_model)
图像任务Patch Embeddingx = x.unfold(2, patch_size, patch_size).reshape(...)
位置编码
固定位置编码Sinusoidal PE见下方完整实现
可学习位置编码nn.Embedding(max_len, d_model)self.pos_emb = nn.Embedding(512, d_model)
注意力类型
自注意力nn.MultiheadAttentionself.attn = nn.MultiheadAttention(d_model, n_heads)
因果注意力+ maskattn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

一旦确定了架构,就要设计以下模块:

5.1 Embedding层

输入表示(Input Representation),需要把原始输入变成向量序列,在深度学习中,承担这个任务的组件就是 Embedding 层。

Embedding 层其实是一个存储固定大小的词典的嵌入向量查找表。也就是说,在输入神经网络之前,我们往往会先让自然语言输入通过分词器 tokenizer,分词器的作用是把自然语言输入切分成 token 并转化成一个固定的 index。例如,如果我们将词表大小设为 4,输入“我喜欢你”,那么,分词器可以将输入转化成:

input: 我
output: 0input: 喜欢
output: 1input:你
output: 2
  • 文本:词嵌入 + 位置编码
  • 图像:将图像切分为 patch,每个 patch 线性投影为向量(ViT)
  • 音频:频谱图切块 → 向量
  • 多模态:文本向量 + 图像向量拼接或对齐

✅ 关键:所有输入都要变成 (batch_size, seq_len, d_model) 的张量


5.2 注意力机制(Attention)

根据任务决定注意力类型:

类型用途是否允许查看未来
自注意力(Self-Attention)建模序列内部关系是(Encoder) / 否(Decoder)
交叉注意力(Cross-Attention)解码器关注编码器输出
因果注意力(Causal Attention)生成时防止泄露未来信息❌ 不能看后面

⚠️ 解码器中的自注意力必须使用 掩码(mask) 来屏蔽未来 token


5.2.1 位置编码(Positional Encoding)

因为 Transformer 没有顺序感知,必须采用位置编码机制来保留序列的位置信息:

  • 固定位置编码:正弦函数(原论文)。适合序列长度固定的任务。
  • 可学习位置编码nn.Embedding(seq_len, d_model)。适合大多数情况,模型能自动适应数据分布
  • 相对位置编码:更高级,建模相对距离
5.2.1.1 固定位置编码

Sinusoidal PE

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)  # (1, max_len, d_model)self.register_buffer('pe', pe)def forward(self, x):seq_len = x.size(1)return x + self.pe[:, :seq_len, :]
5.2.1.2 可学习位置编码(更常用)
self.pos_embedding = nn.Embedding(max_position_embeddings, d_model)
positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
x = x + self.pos_embedding(positions)

5.3 模型头(Head)——输出层设计

根据任务设计最后的输出层:

任务输出头
分类Linear(d_model, num_classes)+ CrossEntropy
回归Linear(d_model, 1)+ MSE
序列生成Linear(d_model, vocab_size)+ Softmax + CTC Loss
命名实体识别(NER)每个 token 输出类别(seq_len × classes
问答输出起始和结束位置(两个线性层)

5.4 损失函数(Loss Function)

任务损失函数
分类交叉熵CrossEntropyLoss
生成语言模型损失(预测下一个词)
回归均方误差MSELoss
多标签分类BCEWithLogitsLoss
对比学习InfoNCE Loss(如 CLIP)

6 第四步:训练流程(必须有的循环)

无论什么任务,训练都遵循这个循环(神经网络的训练循环):

for epoch in epochs:for batch in dataloader:# 1. 前向传播output = model(input_ids, attention_mask=mask)# 2. 计算损失loss = loss_fn(output, labels)# 3. 反向传播loss.backward()# 4. 更新参数optimizer.step()optimizer.zero_grad()# 5. 记录日志print(f"Loss: {loss.item()}")

✅ 这个循环是通用的,但 modelloss_fn 要根据任务定制

7 第五步:验证与推理策略

任务推理方式
分类[CLS] 或平均池化后分类
生成自回归生成(一次一个 token)
翻译Beam Search 提高质量
问答找 start 和 end 位置

8 设计一个“新闻分类”模型

给一段新闻文本,判断属于哪个类别(体育、科技、娱乐…)

8.1 设计过程

步骤决策
1. 输入输出输入:句子;输出:类别 → 分类任务
2. 架构选择Encoder-only(不需要生成)
3. 输入表示Token Embedding + Positional Encoding
4. 模型结构6 层 Encoder,每层 Multi-Head Attention + FFN
5. 输出头[CLS] 位置或平均池化 → Linear → 分类
6. 损失函数CrossEntropyLoss
7. 优化器AdamW
8. 推理方式前向传播 → argmax

👉 这就是 BERT 做分类的方式!


9 总结:设计 Transformer 的 checklist

📌 使用建议:逐行打勾 ✅ 或 ❌,根据答案组合决定是否使用 Transformer 及具体架构

类别问题是?决策指引
任务性质1. 输入或输出是序列吗?(文本、时间序列、音频等)✅/❌❌ 否 → 考虑 MLP/XGBoost/CNN
✅ 是 → 进入下一步
2. 序列长度是否可变或较长?(>50)✅/❌✅ 是 → Transformer 优势明显
⚠️ 否 → CNN/RNN 也可考虑
3. 元素之间的顺序是否影响语义?✅/❌✅ 是 → 排除词袋模型(Bag-of-Words)
语义复杂性4. 是否存在长距离依赖?(首尾信息关联)✅/❌✅ 是 → Transformer 显著优于 RNN/CNN
⚠️ 否 → 简单模型可能足够
5. 是否需要上下文才能理解局部语义?(如歧义消解)✅/❌✅ 是 → 自注意力机制的核心优势
例:“苹果”是水果还是公司?
6. 是否需要建模全局结构关系?(如句法、逻辑)✅/❌✅ 是 → Transformer 更适合
生成需求7. 是否需要生成序列?(如翻译、摘要、对话)✅/❌✅ 是 → 必须用 Decoder-only 或 Encoder-Decoder
❌ 否 → Encoder-only 足够
8. 是否允许模型看到未来 token?✅/❌❌ 否(如自回归生成)→ 必须使用 因果掩码(causal mask)
✅ 是 → 可用双向注意力(如 BERT)
输入输出设计9. 如何表示输入?文本:Token Embedding + PE
图像:Patch Embedding
音频:Spectrogram + Conv
10. 如何表示位置?固定 sinusoidal PE / 可学习 Position Embedding / 相对位置编码
11. 输出头(Head)如何设计?分类:[CLS] 或 平均池化 + Linear
序列标注:每个 token 输出
生成:LM Head(vocab_size 输出)
训练与评估12. 使用什么损失函数?分类:CrossEntropy
生成:Language Modeling Loss
回归:MSE
13. 如何评估?分类:Accuracy/F1
生成:BLEU/ROUGE/METEOR
语义相似:Cosine Similarity
  • 🧠 使用示例:新闻分类任务
问题回答决策
1. 是序列任务?→ 考虑 Transformer
2. 序列较长?→ Transformer 优势
4. 有长距离依赖?强烈推荐 Transformer
5. 需要上下文理解?→ 自注意力必要
7. 需要生成?→ 用 Encoder-only
8. 能看未来?→ 可用双向注意力(BERT-style)
11. 输出头?[CLS] token + 分类头
12. 损失函数?CrossEntropyLoss

👉 结论:使用 BERT-style Encoder-only 模型,最佳选择。

9.1 最后一句话:

Transformer 不是一个“万能黑箱”,而是一个“模块化工具箱”
你要做的不是“套公式”,而是:
🔍 理解问题 → 拆解结构 → 选择组件 → 组装模型

http://www.dtcms.com/a/321277.html

相关文章:

  • FreeRTOS入门知识(初识RTOS)(一)
  • Nginx 部署前端项目、负载均衡与反向代理
  • Seaborn 学习笔记
  • DigitalProductId解密算法php版
  • 「安全发」ISV对接支付宝+小猎系统
  • Prometheus 通过读取文件中的配置来监控目标
  • [ MySQL 数据库 ] 环境安装配置和使用
  • Rocky Linux 安装 Google Chrome 浏览器
  • (附源码)基于SpringBoot的高校爱心捐助平台的设计与实现
  • USB (Universal Serial Bus,通用串行总线)
  • K次取反后最大化的数组和
  • [案例十] NX二次开发批量替换组件功能(装配环境)
  • 【Open3D】基础操作之三维数据结构的高效组织和管理
  • 【FreeRTOS】任务间通讯3:互斥量- Mutex
  • ctrl+alt+方向键导致屏幕旋转的解决方法
  • 基于双块轻量级神经网络的无人机拍摄的风力涡轮机图像去雾方法
  • No time to train! Training-Free Reference-Based Instance Segmentation之论文阅读
  • 机场风云:AI 云厂商的暗战,广告大战一触即发
  • 【实战】Dify从0到100进阶--中药科普助手(2)
  • 用browse实现菜单功能的方法
  • 快速上手 Ollama:强大的开源语言模型框架
  • Docker的安装使用以及常见的网络问题
  • 数据库恢复技术:保障数据安全的关键
  • DeepSeek辅助编写的带缓存检查的数据库查询缓存系统
  • Odoo 18 → Odoo 19 功能改动对比表
  • 基于Web的交互式坐标系变换矩阵计算工具
  • 时间复杂度计算(以for循环为例)
  • BBH详解:面向大模型的高阶推理评估基准与数据集分析
  • 轻松实现浏览器自动化——AI浏览器自动化框架Stagehand
  • 力扣 hot100 Day69