Transformer架构发展历史
Transformer架构详解:起源、原理与应用
目录
- Transformer的起源与历史背景
- Transformer核心架构详解
- 自注意力机制深入解析
- Transformer在结构化数据中的应用
- Transformer在图像数据中的应用
- 性能优化与变体
- 实际应用案例
1. Transformer的起源与历史背景
1.1 深度学习序列建模的演进
在Transformer出现之前,自然语言处理领域主要依赖以下架构:
循环神经网络时代(2010-2017)
- RNN (Recurrent Neural Networks):基础循环架构,存在梯度消失问题
- LSTM (Long Short-Term Memory):通过门控机制解决长期依赖问题
- GRU (Gated Recurrent Units):LSTM的简化版本
- Seq2Seq模型:编码器-解码器架构,用于机器翻译
RNN架构的局限性:
- 顺序计算:无法并行化,训练效率低
- 长期依赖问题:即使使用LSTM,处理超长序列仍然困难
- 梯度传播路径长:容易出现梯度消失或爆炸
- 信息瓶颈:编码器需要将所有信息压缩到固定维度的向量中
1.2 注意力机制的诞生
Bahdanau Attention (2014)
- 由Dzmitry Bahdanau等人提出
- 允许解码器在生成每个词时关注输入序列的不同部分
- 解决了Seq2Seq模型的信息瓶颈问题
核心思想:
对于输出序列的每个位置,计算其与输入序列所有位置的相关性
相关性越高的位置获得更大的权重
1.3 “Attention is All You Need” 论文
发表信息:
- 时间:2017年6月(arXiv)
- 作者团队:Google Brain和Google Research
- 核心作者:Ashish Vaswani, Noam Shazeer, Niki Parmar等
- 发表会议:NIPS 2017(现NeurIPS)
革命性贡献:
- 完全抛弃循环结构:首次提出纯注意力架构
- 并行化计算:大幅提升训练和推理速度
- 多头注意力机制:从多个子空间捕获信息
- 位置编码:通过三角函数编码位置信息
- 残差连接与层归一化:稳定深层网络训练
影响力:
- 截至2024年,论文引用次数超过10万次
- 催生了BERT、GPT、T5等众多里程碑模型
- 从NLP扩展到CV、语音、生物信息学等多个领域
2. Transformer核心架构详解
2.1 整体架构
Transformer采用编码器-解码器(Encoder-Decoder)架构:
输入序列 → [嵌入层 + 位置编码] → 编码器栈 → 解码器栈 → 输出概率分布↓上下文向量
核心组件:
- 编码器(Encoder):6层堆叠,每层包含多头自注意力和前馈网络
- 解码器(Decoder):6层堆叠,包含掩码自注意力、交叉注意力和前馈网络
- 注意力机制:核心计算单元
- 位置编码:注入序列位置信息
2.2 编码器详细结构
单个编码器层包含:
# 伪代码表示
class EncoderLayer:def forward(x):# 1. 多头自注意力attn_output = MultiHeadAttention(Q=x, K=x, V=x)x = LayerNorm(x + attn_output) # 残差连接 + 层归一化# 2. 位置前馈网络ffn_output = FeedForward(x)x = LayerNorm(x + ffn_output) # 残差连接 + 层归一化return x
参数配置(原论文):
- 模型维度 (d_model):512
- 头数 (num_heads):8
- 前馈网络维度 (d_ff):2048
- Dropout率:0.1
2.3 解码器详细结构
单个解码器层包含:
# 伪代码表示
class DecoderLayer:def forward(x, encoder_output):# 1. 掩码多头自注意力(防止看到未来信息)masked_attn = MaskedMultiHeadAttention(Q=x, K=x, V=x)x = LayerNorm(x + masked_attn)# 2. 编码器-解码器注意力(交叉注意力)cross_attn = MultiHeadAttention(Q=x, K=encoder_output, V=encoder_output)x = LayerNorm(x + cross_attn)# 3. 位置前馈网络ffn_output = FeedForward(x)x = LayerNorm(x + ffn_output)return x
关键差异:
- 第一个注意力层使用掩码,确保位置i只能关注i之前的位置
- 增加交叉注意力层,Query来自解码器,Key和Value来自编码器输出
2.4 位置编码(Positional Encoding)
由于Transformer没有循环或卷积结构,需要显式注入位置信息。
公式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中:
pos:序列中的位置(0到max_len-1)i:维度索引(0到d_model/2-1)d_model:模型维度
特点:
- 确定性:不需要学习,节省参数
- 外推性:理论上可以处理训练时未见过的序列长度
- 相对位置关系:PE(pos+k)可以表示为PE(pos)的线性函数
可视化示例:
位置0: [sin(0/10000^0), cos(0/10000^0), sin(0/10000^(2/512)), ...]
位置1: [sin(1/10000^0), cos(1/10000^0), sin(1/10000^(2/512)), ...]
位置2: [sin(2/10000^0), cos(2/10000^0), sin(2/10000^(2/512)), ...]
...
3. 自注意力机制深入解析
3.1 缩放点积注意力(Scaled Dot-Product Attention)
核心公式:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
计算步骤:
-
计算相似度得分:
Scores = QK^T 维度:(seq_len_q, d_k) × (d_k, seq_len_k) = (seq_len_q, seq_len_k) -
缩放:
Scaled_Scores = Scores / √d_k- 目的:防止点积结果过大导致softmax梯度过小
- 当d_k较大时,点积方差为d_k,缩放使方差归一化
-
应用Softmax:
Attention_Weights = softmax(Scaled_Scores)- 每一行的权重和为1
- 表示Query对所有Key的注意力分布
-
加权求和:
Output = Attention_Weights × V 维度:(seq_len_q, seq_len_k) × (seq_len_k, d_v) = (seq_len_q, d_v)
示例计算(简化版):
import numpy as np# 假设我们有3个词,每个词的表示维度为4
Q = np.array([[1, 0, 1, 0],[0, 1, 0, 1],[1, 1, 0, 0]])K = np.array([[1, 0, 1, 0],[0, 1, 0, 1],[1, 1, 0, 0]])V = np.array([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]])# 1. 计算QK^T
scores = np.dot(Q, K.T) # 结果:[[2, 0, 2], [0, 2, 1], [2, 1, 2]]# 2. 缩放
d_k = Q.shape[-1]
scaled_scores = scores / np.sqrt(d_k) # 除以2# 3. Softmax
attention_weights = np.exp(scaled_scores) / np.exp(scaled_scores).sum(axis=-1, keepdims=True)# 4. 加权求和
output = np.dot(attention_weights, V)
3.2 多头注意力(Multi-Head Attention)
核心思想:
- 不是只计算一次注意力,而是并行计算h次
- 每个"头"学习不同的注意力模式
- 类似于CNN中的多通道
公式:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O其中 head_i = Attention(QW^Q_i, KW^K_i, VW^V_i)
参数矩阵:
- W^Q_i ∈ ℝ^(d_model × d_k):Query投影矩阵
- W^K_i ∈ ℝ^(d_model × d_k):Key投影矩阵
- W^V_i ∈ ℝ^(d_model × d_v):Value投影矩阵
- W^O ∈ ℝ^(h·d_v × d_model):输出投影矩阵
典型配置:
d_model = 512
h = 8
d_k = d_v = d_model / h = 64
计算流程图:
输入 (batch, seq_len, d_model=512)↓
分割成8个头,每个头维度64↓
并行计算8个注意力↓Head 1: Attention(Q1, K1, V1) → (batch, seq_len, 64)Head 2: Attention(Q2, K2, V2) → (batch, seq_len, 64)...Head 8: Attention(Q8, K8, V8) → (batch, seq_len, 64)↓
拼接 (batch, seq_len, 512)↓
线性投影 W^O↓
输出 (batch, seq_len, d_model=512)
为什么使用多头?
-
不同表示子空间:每个头可以关注不同方面的信息
- 某个头可能关注语法关系
- 某个头可能关注语义相似性
- 某个头可能关注位置关系
-
增加模型容量:在不增加参数量的情况下增强表达能力
-
并行计算:多个头可以同时计算
3.3 掩码机制(Masking)
填充掩码(Padding Mask):
# 用于处理变长序列
# 假设序列长度为5,实际有效长度为3
sequence = [token1, token2, token3, <PAD>, <PAD>]
padding_mask = [0, 0, 0, 1, 1] # 1表示需要掩码的位置# 在计算注意力前应用
scores = scores.masked_fill(padding_mask == 1, -1e9)
# Softmax后,被掩码的位置权重接近0
前瞻掩码(Look-Ahead Mask):
# 用于解码器自注意力,防止看到未来信息
# 序列长度为4的掩码矩阵
look_ahead_mask = [[0, 1, 1, 1], # 位置0只能看到位置0[0, 0, 1, 1], # 位置1只能看到位置0-1[0, 0, 0, 1], # 位置2只能看到位置0-2[0, 0, 0, 0], # 位置3可以看到所有位置
]
# 0表示允许注意,1表示掩码
3.4 前馈网络(Feed-Forward Network)
结构:
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
特点:
- 两层全连接网络
- 中间层使用ReLU激活
- 每个位置独立应用(position-wise)
- 参数在所有位置共享
维度变化:
(batch, seq_len, d_model=512) ↓ 第一层
(batch, seq_len, d_ff=2048)↓ ReLU + 第二层
(batch, seq_len, d_model=512)
作用:
- 增加模型的非线性变换能力
- 每个位置可以进行独立的特征变换
- 类似于1×1卷积的作用
4. Transformer在结构化数据中的应用
4.1 结构化数据的挑战
表格数据特点:
- 异构特征:数值型、类别型混合
- 特征交互:特征之间存在复杂的非线性关系
- 数据量相对较小:通常几千到几十万样本
- 可解释性需求:金融、医疗等领域需要理解模型决策
传统方法:
- 树模型:XGBoost, LightGBM, CatBoost(表格数据的黄金标准)
- 深度学习:MLP、Wide & Deep、DeepFM
- 挑战:深度学习模型在小规模表格数据上通常不如树模型
4.2 TabTransformer (2020)
论文:“TabTransformer: Tabular Data Modeling Using Contextual Embeddings”
核心创新:
- 将类别特征转换为嵌入向量
- 使用Transformer学习类别特征之间的上下文关系
- 数值特征保持原样或简单归一化
架构设计:
类别特征 → 嵌入层 → Transformer编码器 → 拼接数值特征 → MLP → 输出具体流程:
1. 类别特征处理:- 特征1: [A] → Embedding → [e1_1, e1_2, ..., e1_d]- 特征2: [B] → Embedding → [e2_1, e2_2, ..., e2_d]- ...2. Transformer编码:- 输入: (num_categorical_features, embedding_dim)- 多层自注意力捕获特征间关系- 输出: (num_categorical_features, embedding_dim)3. 特征融合:- Flatten编码后的类别特征- 拼接原始数值特征- 通过MLP进行最终预测
代码示例:
import torch
import torch.nn as nnclass TabTransformer(nn.Module):def __init__(self, num_categories, # 每个类别特征的类别数num_numerical, # 数值特征数量embedding_dim=32, # 嵌入维度num_layers=6, # Transformer层数num_heads=8, # 注意力头数ffn_dim=128, # 前馈网络维度output_dim=1): # 输出维度super().__init__()# 类别特征嵌入self.embeddings = nn.ModuleList([nn.Embedding(num_cat, embedding_dim) for num_cat in num_categories])# Transformer编码器encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim,nhead=num_heads,dim_feedforward=ffn_dim,batch_first=True)self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)# 输出MLPtotal_dim = len(num_categories) * embedding_dim + num_numericalself.mlp = nn.Sequential(nn.Linear(total_dim, 256),nn.ReLU(),nn.Dropout(0.1),nn.Linear(256, 128),nn.ReLU(),nn.Dropout(0.1),nn.Linear(128, output_dim))def forward(self, categorical_features, numerical_features):# 嵌入类别特征embeddings = [emb(categorical_features[:, i]) for i, emb in enumerate(self.embeddings)]# (batch, num_categorical, embedding_dim)cat_embedded = torch.stack(embeddings, dim=1)# Transformer编码# (batch, num_categorical, embedding_dim)encoded = self.transformer(cat_embedded)# Flatten类别特征encoded_flat = encoded.reshape(encoded.size(0), -1)# 拼接数值特征combined = torch.cat([encoded_flat, numerical_features], dim=1)# MLP预测output = self.mlp(combined)return output# 使用示例
model = TabTransformer(num_categories=[10, 5, 20, 8], # 4个类别特征num_numerical=15, # 15个数值特征embedding_dim=32,num_layers=4,num_heads=4
)# 假设batch_size=64
cat_features = torch.randint(0, 10, (64, 4)) # 类别特征
num_features = torch.randn(64, 15) # 数值特征
output = model(cat_features, num_features)
优势:
- 特征交互:自动学习类别特征间的复杂关系
- 鲁棒性:对缺失值和噪声有一定容忍度
- 迁移学习:预训练的嵌入可以迁移到相关任务
实验结果(原论文):
- 在多个UCI数据集上超越传统MLP
- 在某些数据集上接近甚至超越树模型
- 特别适合类别特征较多的场景
4.3 FT-Transformer (2021)
论文:“Revisiting Deep Learning Models for Tabular Data”
核心改进:
- 所有特征都嵌入化:数值特征也转换为嵌入
- 特征级注意力:每个特征作为一个token
- 更简洁的架构:去除额外的MLP层
架构:
所有特征 → 特征嵌入 → [CLS] Token → Transformer → [CLS]输出 → 预测特征嵌入方式:
- 类别特征: Embedding(category_value)
- 数值特征: Linear(numerical_value) + Feature_embedding
代码示例:
class FTTransformer(nn.Module):def __init__(self, num_features,num_categories,embedding_dim=64,num_layers=3,num_heads=8):super().__init__()# 类别特征嵌入self.cat_embeddings = nn.ModuleList([nn.Embedding(num_cat, embedding_dim) for num_cat in num_categories])# 数值特征线性投影self.num_projections = nn.ModuleList([nn.Linear(1, embedding_dim) for _ in range(num_features)])# 特征位置嵌入total_features = len(num_categories) + num_featuresself.feature_embeddings = nn.Parameter(torch.randn(1, total_features, embedding_dim))# CLS tokenself.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))# Transformerencoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim,nhead=num_heads,dim_feedforward=embedding_dim * 4,batch_first=True)self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)# 输出层self.output = nn.Linear(embedding_dim, 1)def forward(self, categorical_features, numerical_features):batch_size = categorical_features.size(0)# 处理类别特征cat_embeds = [emb(categorical_features[:, i]) for i, emb in enumerate(self.cat_embeddings)]# 处理数值特征num_embeds = [proj(numerical_features[:, i:i+1]) for i, proj in enumerate(self.num_projections)]# 拼接所有特征嵌入all_embeds = cat_embeds + num_embedsfeatures = torch.stack(all_embeds, dim=1) # (batch, num_features, dim)# 添加特征位置嵌入features = features + self.feature_embeddings# 添加CLS tokencls_tokens = self.cls_token.expand(batch_size, -1, -1)features = torch.cat([cls_tokens, features], dim=1)# Transformer编码encoded = self.transformer(features)# 使用CLS token的输出进行预测cls_output = encoded[:, 0, :]output = self.output(cls_output)return output
优势:
- 统一处理所有类型的特征
- 更好的特征交互学习
- 在多个基准测试中达到SOTA
4.4 时间序列预测中的应用
Temporal Fusion Transformer (TFT, 2019)
场景:多变量时间序列预测
架构特点:
- 变量选择网络:学习哪些特征重要
- 时序融合解码器:融合不同时间尺度的信息
- 多头注意力:捕获长期依赖关系
应用:
- 电力负荷预测
- 股票价格预测
- 零售需求预测
代码框架:
class TemporalFusionTransformer(nn.Module):def __init__(self, config):super().__init__()# 变量选择网络self.variable_selection = VariableSelectionNetwork(input_dim=config.num_features,hidden_dim=config.hidden_dim)# LSTM编码器(处理历史序列)self.encoder_lstm = nn.LSTM(input_size=config.hidden_dim,hidden_size=config.hidden_dim,num_layers=1,batch_first=True)# LSTM解码器(生成未来序列)self.decoder_lstm = nn.LSTM(input_size=config.hidden_dim,hidden_size=config.hidden_dim,num_layers=1,batch_first=True)# 时序注意力层self.temporal_attention = nn.MultiheadAttention(embed_dim=config.hidden_dim,num_heads=config.num_heads,batch_first=True)# 门控残差网络self.grn = GatedResidualNetwork(config.hidden_dim)# 输出层self.output_layer = nn.Linear(config.hidden_dim, config.forecast_horizon)
4.5 推荐系统中的应用
BERT4Rec (2019)
思想:将用户行为序列看作"句子",物品看作"词"
架构:
用户行为序列: [item1, item2, [MASK], item4, item5]↓
Transformer编码器↓
预测被mask的item
训练方式:
- 随机mask序列中的物品
- 模型预测被mask的物品
- 类似BERT的预训练方式
优势:
- 双向建模:考虑前后文信息
- 捕获长期兴趣
- 自监督学习:无需额外标注
5. Transformer在图像数据中的应用
5.1 从CNN到Transformer的转变
传统CNN的局限:
- 局部感受野:卷积核只能看到局部区域
- 归纳偏置强:平移不变性和局部性是硬编码的
- 长距离依赖:需要堆叠很多层才能建立全局关系
Transformer的优势:
- 全局感受野:每个位置都能关注整个图像
- 灵活的归纳偏置:通过数据学习而非硬编码
- 可扩展性:性能随数据量和模型大小持续提升
5.2 Vision Transformer (ViT, 2020)
论文:“An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”(Google Research)
核心思想:将图像分割成固定大小的patch,每个patch作为一个token
架构流程:
1. 图像分割输入图像 (224×224×3)↓分割成patches (14×14个patches,每个16×16×3)2. Patch嵌入每个patch → flatten → 线性投影(16×16×3 = 768维) → 嵌入空间 (D维,如768)3. 添加位置嵌入+ 可学习的位置编码4. 添加[CLS] token[CLS] + patch_1 + patch_2 + ... + patch_1965. Transformer编码器多层自注意力 + FFN6. 分类头[CLS] token的输出 → MLP → 类别概率
详细代码实现:
import torch
import torch.nn as nnclass PatchEmbedding(nn.Module):"""将图像分割成patches并嵌入"""def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size // patch_size) ** 2# 使用卷积实现patch提取和线性投影self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):# x: (batch, 3, 224, 224)x = self.projection(x) # (batch, embed_dim, 14, 14)x = x.flatten(2) # (batch, embed_dim, 196)x = x.transpose(1, 2) # (batch, 196, embed_dim)return xclass VisionTransformer(nn.Module):def __init__(self, img_size=224,patch_size=16,in_channels=3,num_classes=1000,embed_dim=768,depth=12, # Transformer层数num_heads=12,mlp_ratio=4.0, # FFN隐藏层维度 = embed_dim * mlp_ratiodropout=0.1):super().__init__()# Patch嵌入self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patches# CLS tokenself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# 位置嵌入(可学习)self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.pos_drop = nn.Dropout(p=dropout)# Transformer编码器encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim,nhead=num_heads,dim_feedforward=int(embed_dim * mlp_ratio),dropout=dropout,batch_first=True)self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)# 分类头self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes)# 初始化权重nn.init.trunc_normal_(self.pos_embed, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)def forward(self, x):batch_size = x.shape[0]# Patch嵌入x = self.patch_embed(x) # (batch, 196, 768)# 添加CLS tokencls_tokens = self.cls_token.expand(batch_size, -1, -1)x = torch.cat([cls_tokens, x], dim=1) # (batch, 197, 768)# 添加位置嵌入x = x + self.pos_embedx = self.pos_drop(x)# Transformer编码x = self.transformer(x)# 分类(使用CLS token)x = self.norm(x[:, 0])x = self.head(x)return x# 模型变体
def vit_base_patch16_224():"""ViT-Base: 86M参数"""return VisionTransformer(img_size=224,patch_size=16,embed_dim=768,depth=12,num_heads=12)def vit_large_patch16_224():"""ViT-Large: 307M参数"""return VisionTransformer(img_size=224,patch_size=16,embed_dim=1024,depth=24,num_heads=16)def vit_huge_patch14_224():"""ViT-Huge: 632M参数"""return VisionTransformer(img_size=224,patch_size=14,embed_dim=1280,depth=32,num_heads=16)
关键发现(原论文):
-
数据量是关键:
- 在ImageNet-1K(130万图像)上:ViT < ResNet
- 在ImageNet-21K(1400万图像)上:ViT ≈ ResNet
- 在JFT-300M(3亿图像)上:ViT > ResNet
-
归纳偏置:
- ViT几乎没有图像特定的归纳偏置
- 完全依赖数据学习,因此需要大规模数据
-
计算效率:
- 相同精度下,ViT训练成本更低
- 更容易扩展到更大模型
5.3 Swin Transformer (2021)
论文:“Swin Transformer: Hierarchical Vision Transformer using Shifted Windows”(Microsoft Research Asia)
动机:
- ViT计算复杂度是O(N²),其中N是patch数量
- 对于高分辨率图像,计算量巨大
- 缺少层次化结构,难以用于密集预测任务
核心创新:
- 局部窗口注意力:限制注意力在固定大小的窗口内
- 滑动窗口机制:通过移位窗口建立跨窗口连接
- 层次化结构:类似CNN的金字塔结构
架构设计:
Stage 1: 56×56 patches↓ Window Attention↓ Shifted Window AttentionStage 2: 28×28 patches (Patch Merging)↓ Window Attention↓ Shifted Window AttentionStage 3: 14×14 patches (Patch Merging)↓ Window Attention↓ Shifted Window AttentionStage 4: 7×7 patches (Patch Merging)↓ Window Attention↓ Shifted Window Attention
窗口注意力示例:
class WindowAttention(nn.Module):"""在固定大小窗口内计算注意力"""def __init__(self, dim, window_size, num_heads):super().__init__()self.dim = dimself.window_size = window_size # (window_height, window_width)self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5# 相对位置偏置self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))# QKV投影self.qkv = nn.Linear(dim, dim * 3)self.proj = nn.Linear(dim, dim)def forward(self, x, mask=None):"""x: (num_windows*batch, window_size*window_size, C)mask: (num_windows, window_size*window_size, window_size*window_size)"""B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)qkv = qkv.permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]# 计算注意力q = q * self.scaleattn = (q @ k.transpose(-2, -1))# 添加相对位置偏置# ... (省略细节)# 应用掩码(用于shifted window)if mask is not None:attn = attn.masked_fill(mask == 0, float('-inf'))attn = attn.softmax(dim=-1)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)return xclass SwinTransformerBlock(nn.Module):"""Swin Transformer基本块"""def __init__(self, dim, num_heads, window_size=7, shift_size=0):super().__init__()self.dim = dimself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.norm1 = nn.LayerNorm(dim)self.attn = WindowAttention(dim, (window_size, window_size), num_heads)self.norm2 = nn.LayerNorm(dim)self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim),nn.GELU(),nn.Linear(4 * dim, dim))def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeshortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# 滑动窗口(如果shift_size > 0)if self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# 分割成窗口x_windows = window_partition(shifted_x, self.window_size)# 窗口注意力attn_windows = self.attn(x_windows)# 合并窗口shifted_x = window_reverse(attn_windows, self.window_size, H, W)# 反向滑动if self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)# FFNx = shortcut + xx = x + self.mlp(self.norm2(x))return x
优势:
- 计算效率:O(N)复杂度,可处理高分辨率图像
- 层次化表示:适合检测、分割等任务
- 性能优异:ImageNet分类、COCO检测都达到SOTA
实验结果:
- ImageNet-1K: Top-1 87.3% (Swin-L)
- COCO目标检测: 58.7 box AP
- ADE20K语义分割: 53.5 mIoU
5.4 Detection Transformer (DETR, 2020)
论文:“End-to-End Object Detection with Transformers”(Facebook AI)
革命性创新:
- 第一个端到端的目标检测器
- 不需要NMS(非极大值抑制)
- 不需要anchor boxes
架构:
输入图像↓
CNN Backbone (ResNet) → 特征图↓
Flatten + 位置编码↓
Transformer编码器↓
Object Queries(可学习)↓
Transformer解码器↓
FFN → N个预测 (类别 + 边界框)↓
匈牙利匹配 + Loss
关键组件:
- Object Queries:
# N个可学习的查询向量(N=100)
self.query_embed = nn.Embedding(num_queries, hidden_dim)# 在解码器中使用
query_pos = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)
- 二分匹配:
def hungarian_matching(pred_boxes, pred_logits, target_boxes, target_labels):"""使用匈牙利算法进行预测和GT的最优匹配"""# 计算分类成本cost_class = -pred_logits[:, target_labels]# 计算L1距离成本cost_bbox = torch.cdist(pred_boxes, target_boxes, p=1)# 计算GIoU成本cost_giou = -generalized_box_iou(pred_boxes, target_boxes)# 总成本C = cost_class + 5 * cost_bbox + 2 * cost_giou# 匈牙利算法indices = linear_sum_assignment(C.cpu())return indices
完整实现框架:
class DETR(nn.Module):def __init__(self, num_classes, num_queries=100, hidden_dim=256):super().__init__()# CNN骨干网络self.backbone = resnet50(pretrained=True)# 降维self.conv = nn.Conv2d(2048, hidden_dim, 1)# 位置编码self.pos_encoder = PositionEmbeddingSine(hidden_dim // 2)# Transformerself.transformer = nn.Transformer(d_model=hidden_dim,nhead=8,num_encoder_layers=6,num_decoder_layers=6,dim_feedforward=2048)# Object queriesself.query_embed = nn.Embedding(num_queries, hidden_dim)# 预测头self.class_embed = nn.Linear(hidden_dim, num_classes + 1) # +1 for no-objectself.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) # (x, y, w, h)def forward(self, images):# 特征提取features = self.backbone(images) # (batch, 2048, H, W)features = self.conv(features) # (batch, 256, H, W)# 位置编码pos = self.pos_encoder(features)# Flattenbatch_size = features.shape[0]features = features.flatten(2).permute(2, 0, 1) # (HW, batch, 256)pos = pos.flatten(2).permute(2, 0, 1)# Object queriesquery_embed = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)# Transformermemory = self.transformer.encoder(features, pos=pos)hs = self.transformer.decoder(query_embed, memory, memory_key_padding_mask=None,pos=pos,query_pos=query_embed)# 预测outputs_class = self.class_embed(hs) # (num_queries, batch, num_classes+1)outputs_coord = self.bbox_embed(hs).sigmoid() # (num_queries, batch, 4)return {'pred_logits': outputs_class[-1],'pred_boxes': outputs_coord[-1]}
优势:
- 真正的端到端训练
- 全局推理能力
- 代码简洁,易于扩展
局限性:
- 训练收敛较慢(需要500 epochs)
- 小物体检测性能较弱
- 计算量较大
5.5 图像分割中的应用
Segmenter (2021)
将ViT应用于语义分割:
图像 → Patch嵌入 → Transformer编码器 → 逐patch分类 → 上采样 → 分割图
Mask2Former (2022)
结合DETR思想进行实例分割和全景分割:
图像 → Backbone → Pixel Decoder → Transformer Decoder → Mask预测
6. 性能优化与变体
6.1 高效注意力机制
Linformer (2020)
- 将注意力复杂度从O(N²)降到O(N)
- 方法:低秩矩阵近似K和V
# 核心思想
K_projected = K @ E # E: (N, k) 其中 k << N
V_projected = V @ F # F: (N, k)
Attention = softmax(Q @ K_projected.T) @ V_projected
Performer (2020)
- 使用快速注意力算法(FAVOR+)
- 线性复杂度,无近似误差
Reformer (2020)
- 局部敏感哈希(LSH)注意力
- 只计算相似query和key的注意力
6.2 位置编码改进
相对位置编码:
# 不是编码绝对位置,而是编码相对距离
relative_position = position_i - position_j
bias = learned_bias[relative_position]
attention_score = (Q @ K.T) + bias
旋转位置编码(RoPE):
- 通过旋转矩阵注入位置信息
- 用于LLaMA、PaLM等大模型
ALiBi(Attention with Linear Biases):
- 在注意力分数上添加线性偏置
- 外推性能优异
6.3 稀疏注意力
局部注意力:
# 只关注周围k个位置
for i in range(seq_len):attend_to = range(max(0, i-k), min(seq_len, i+k+1))
Strided注意力:
# 每隔s个位置关注一次
attend_to = [0, s, 2s, 3s, ...]
Longformer:
- 结合局部注意力和全局注意力
- 可处理长达16k的序列
7. 实际应用案例
7.1 案例1:电商推荐系统
场景:用户商品点击序列预测
数据:
- 用户历史点击:[item_1, item_2, …, item_n]
- 用户特征:年龄、性别、地域
- 商品特征:类别、价格、品牌
方案:
class RecommendationTransformer(nn.Module):def __init__(self, num_items, embedding_dim=128):super().__init__()# 商品嵌入self.item_embedding = nn.Embedding(num_items, embedding_dim)# Transformer编码器self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=embedding_dim,nhead=8,dim_feedforward=512,batch_first=True),num_layers=4)# 预测头self.output = nn.Linear(embedding_dim, num_items)def forward(self, item_sequence):# 嵌入x = self.item_embedding(item_sequence)# 位置编码x = x + self.positional_encoding(x)# Transformer编码encoded = self.transformer(x)# 预测下一个商品logits = self.output(encoded[:, -1, :])return logits
效果:
- 相比LSTM提升5-10%的点击率
- 能够捕获长期用户兴趣
7.2 案例2:医学图像分割
场景:CT扫描肿瘤分割
挑战:
- 医学图像分辨率高(512×512甚至更大)
- 标注数据有限
- 需要高精度
方案:使用Swin-UNETR
输入CT切片↓
Swin Transformer编码器(多尺度特征)↓
UNet风格解码器↓
分割结果
优势:
- 全局上下文建模
- 层次化特征提取
- 在BTCV数据集上达到SOTA
7.3 案例3:金融风控
场景:信用卡欺诈检测
数据特点:
- 高维特征:100+维
- 混合类型:交易金额(数值)、商户类别(类别)
- 类别不平衡:欺诈样本<1%
方案:使用TabTransformer
model = TabTransformer(num_categories=[1000, 50, 20, ...], # 商户ID、地区、卡类型num_numerical=45, # 交易金额、时间等embedding_dim=64,num_layers=6,num_heads=8
)# 训练时使用Focal Loss处理类别不平衡
criterion = FocalLoss(alpha=0.25, gamma=2.0)
效果:
- F1-score提升3-5%
- 特征交互自动学习,减少特征工程
7.4 案例4:自动驾驶
场景:3D目标检测
数据:激光雷达点云 + 相机图像
方案:TransFusion
点云 → PointNet提取特征↓
相机图像 → CNN提取特征↓
Transformer融合多模态特征↓
检测头 → 3D边界框
创新点:
- 跨模态注意力机制
- 查询向量同时关注点云和图像特征
总结
Transformer的核心优势
- 并行化:不同于RNN的顺序计算,可以充分利用GPU
- 长距离依赖:通过注意力机制直接建立任意位置间的连接
- 可解释性:注意力权重可视化,了解模型关注点
- 可扩展性:性能随模型大小和数据量持续提升
- 通用性:从NLP到CV,从表格数据到多模态
未来发展方向
-
效率优化:
- 更高效的注意力机制(Flash Attention)
- 模型压缩与量化
- 稀疏模型(MoE)
-
架构创新:
- 更好的位置编码
- 动态深度网络
- 神经架构搜索
-
应用拓展:
- 多模态融合(CLIP, DALL-E)
- 科学计算(AlphaFold)
- 强化学习(Decision Transformer)
-
理论理解:
- 为什么Transformer有效?
- 如何设计更好的归纳偏置?
- 泛化性理论
学习资源
论文必读:
- Attention is All You Need (2017)
- BERT: Pre-training of Deep Bidirectional Transformers (2018)
- An Image is Worth 16x16 Words (ViT, 2020)
- Swin Transformer (2021)
代码实践:
- Hugging Face Transformers库
- PyTorch官方教程
- Annotated Transformer
课程推荐:
- Stanford CS224N (NLP with Deep Learning)
- Stanford CS231N (Computer Vision)
- Deep Learning Specialization (Andrew Ng)
文档版本:v1.0
最后更新:2024年
作者:Claude (Anthropic)
