ViT Algorithm Walkthrough: Tensor Shapes, Formulas, and Key Code for Every Step from Raw Pixels to Output Logits
This post unrolls the Vision Transformer (ViT) pipeline step by step, from raw pixels to the network's output logits, giving the tensor shape, formula, and key code for every stage.
🎯 Goal
Input: one RGB image x ∈ ℝ^(H×W×C)
Output: K-class logits y ∈ ℝ^K
(ImageNet settings are the running example throughout: H = W = 224, C = 3, K = 1000, patch = 16, D = 768, L = 12, heads = 12)
🔍 0. Symbol Table
| Symbol | Meaning | Example value |
|---|---|---|
| P | patch side length | 16 |
| N | number of patches | (224/16)² = 196 |
| D | hidden dimension | 768 |
| L | number of Transformer layers | 12 |
| h | number of attention heads | 12 |
| D_h | per-head dimension | D/h = 64 |
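As a quick sanity check, the derived quantities in the table follow directly from the settings above (a minimal sketch; the variable names simply mirror the symbol table):
# Sanity check: derive N, patch_dim, D_h from the settings
H = W = 224; C = 3; P = 16; D = 768; h = 12
N = (H // P) * (W // P)  # 14 * 14 = 196 patches
patch_dim = C * P * P    # 3 * 16 * 16 = 768 values per flattened patch
D_h = D // h             # 768 / 12 = 64 per-head dimension
assert (N, patch_dim, D_h) == (196, 768, 64)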
🧮 1. Patch Splitting & Linear Projection (Patch Embedding)
1.1 Image → patches
# Pseudocode (single image; the batch dimension is added back in section 2)
x = x.view(3, 224, 224)
patches = x.unfold(1, 16, 16).unfold(2, 16, 16) # (3, 14, 14, 16, 16)
patches = patches.permute(1, 2, 0, 3, 4).reshape(196, 3*16*16) # (196, 768)
1.2 Linear projection
W_proj = nn.Linear(768, 768) # formula: z_0^(i) = W_proj · patches[i]
z_patch = W_proj(patches) # (196, 768)
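In practice, the unfold-then-linear step is usually fused into a single strided convolution with kernel_size = stride = P; the PatchEmbed module in section 7 does exactly that. A minimal sketch verifying the equivalence (the weight-copying lines are the illustrative part):
import torch, torch.nn as nn

x = torch.randn(3, 224, 224)
# Route A: unfold + linear, exactly as in 1.1-1.2
patches = x.unfold(1, 16, 16).unfold(2, 16, 16).permute(1, 2, 0, 3, 4).reshape(196, 768)
linear = nn.Linear(768, 768)
z_a = linear(patches)

# Route B: one strided Conv2d carrying the same weights
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
conv.weight.data = linear.weight.data.view(768, 3, 16, 16)
conv.bias.data = linear.bias.data
z_b = conv(x.unsqueeze(0)).flatten(2).transpose(1, 2).squeeze(0)  # (196, 768)

assert torch.allclose(z_a, z_b, atol=1e-4)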
🎫 2. Class Token & Position Embedding
2.1 Learnable cls token
cls_token = nn.Parameter(torch.zeros(1, 1, 768)) # (1, 1, D), matching section 7
z_cls = cls_token.expand(batch_size, -1, -1) # (B, 1, D)
2.2 Concatenate
z_0 = torch.cat([z_cls, z_patch], dim=1) # (B, 197, 768); z_patch is (B, 196, 768) once batched
2.3 Learnable position embedding
pos_embed = nn.Parameter(torch.zeros(1, 197, 768)) # (1, 197, D)
z_0 = z_0 + pos_embed # broadcast addition over the batch dimension
Shape: (B, 197, 768), and it stays that way through the entire encoder.
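Sections 2.1-2.3 in one runnable piece (a minimal sketch; z_patch is random noise standing in for the real patch embeddings):
import torch, torch.nn as nn

B = 4
z_patch = torch.randn(B, 196, 768)               # stand-in for real patch embeddings
cls_token = nn.Parameter(torch.zeros(1, 1, 768))
pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
z_0 = torch.cat([cls_token.expand(B, -1, -1), z_patch], dim=1) + pos_embed  # broadcast over B
assert z_0.shape == (B, 197, 768)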
🔁 3. Transformer Encoder (L = 12 layers)
Each layer is pre-norm: LN → MSA → residual add, then LN → FFN → residual add, repeated L times. (Note the order: ViT applies LayerNorm before each sub-block, not after, which is what the code below and in section 7 does.)
3.1 Layer Normalization (LN)
z_norm = LayerNorm(z) # (B, 197, 768); keep the un-normalized z for the residual connection
3.2 Multi-Head Self-Attention (MSA)
3.2.1 Linear projection to QKV
W_qkv = nn.Linear(768, 3*768) # one projection yields Q, K, V in a single pass
qkv = W_qkv(z_norm) # (B, 197, 3*768)
q, k, v = qkv.chunk(3, dim=-1) # each (B, 197, 768)
3.2.2 Split into heads
q = q.view(B, 197, 12, 64).transpose(1, 2) # (B, 12, 197, 64)
k = k.view(B, 197, 12, 64).transpose(1, 2)
v = v.view(B, 197, 12, 64).transpose(1, 2)
3.2.3 Scaled dot-product attention
attn = (q @ k.transpose(-2, -1)) / sqrt(64) # (B, 12, 197, 197)
attn = softmax(attn, dim=-1) # each row sums to 1
attn = dropout(attn, p=0.0) # can be set to, e.g., 0.1
out = attn @ v # (B, 12, 197, 64)
3.2.4 Merge heads & output projection
out = out.transpose(1, 2).reshape(B, 197, 768) # (B, 197, 768)
out = W_o(out) # output projection, nn.Linear(768, 768) (self.proj in section 7)
z = z + dropout(out) # residual: add back the pre-LN input
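The matmul → softmax → matmul of 3.2.3 can also be written with PyTorch's fused attention call (this assumes PyTorch ≥ 2.0, where F.scaled_dot_product_attention is available); a minimal equivalence sketch:
import torch, torch.nn.functional as F

q, k, v = (torch.randn(2, 12, 197, 64) for _ in range(3))
# Manual attention, as written in 3.2.3
out_manual = ((q @ k.transpose(-2, -1)) / 64 ** 0.5).softmax(dim=-1) @ v
# Fused equivalent; the default scale is already 1/sqrt(D_h)
out_fused = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_manual, out_fused, atol=1e-4)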
3.3 FFN (two linear layers + GELU)
ffn1 = nn.Linear(768, 3072)
ffn2 = nn.Linear(3072, 768)
z = z + dropout(ffn2(gelu(ffn1(LayerNorm(z)))))
The output shape is still (B, 197, 768).
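As a worked check on these dimensions, the per-layer parameter count falls straight out of the shapes above (a back-of-the-envelope sketch; counts include biases):
D, D_ffn = 768, 3072
msa = D * 3 * D + 3 * D   # QKV projection: 1,771,776
msa += D * D + D          # output projection W_o: 590,592
ffn = D * D_ffn + D_ffn   # first linear: 2,362,368
ffn += D_ffn * D + D      # second linear: 2,360,064
ln = 2 * 2 * D            # two LayerNorms (scale + bias each): 3,072
print(msa + ffn + ln)     # 7,087,872 per layer; x12 layers ≈ 85.1M
Adding the patch projection, position embeddings, cls token, final LayerNorm, and the 1000-way head brings the total to roughly 86M, the familiar ViT-Base size.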
🏁 4. Extract the cls Token & Classification Head
4.1 Take the cls token
z_cls = LayerNorm(z[:, 0]) # (B, 768); a final LayerNorm precedes the head (self.norm in section 7)
4.2 Linear classifier
head = nn.Linear(768, 1000)
logits = head(z_cls) # (B, 1000)
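At inference time, class probabilities and the top-1 prediction come from a softmax over the logits (a minimal sketch):
probs = logits.softmax(dim=-1) # (B, 1000), each row sums to 1
pred = probs.argmax(dim=-1)    # (B,) predicted class index per image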
📈 5. Loss & Training
criterion = nn.CrossEntropyLoss() # applies log-softmax internally, so raw logits go in
loss = criterion(logits, labels)
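A minimal training-step sketch built around this loss (the AdamW settings here are illustrative placeholders, not the paper's schedule):
model = ViT()  # the full model from section 7 below
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

images = torch.randn(8, 3, 224, 224)   # stand-in batch
labels = torch.randint(0, 1000, (8,))
loss = criterion(model(images), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()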
🧵 6. Full Shape Chain (one shape per stage)
| Stage | Shape |
|---|---|
| Input image | (B, 3, 224, 224) |
| Patch unfold + projection | (B, 196, 768) |
| + cls token | (B, 197, 768) |
| + pos embed | (B, 197, 768) |
| LN → QKV | (B, 197, 3×768) |
| Split into heads | (B, 12, 197, 64) |
| Attention matrix | (B, 12, 197, 197) |
| Weighted V | (B, 12, 197, 64) |
| Merge heads | (B, 197, 768) |
| FFN hidden | (B, 197, 3072) |
| FFN output | (B, 197, 768) |
| Take cls | (B, 768) |
| Logits | (B, 1000) |
🧪 7. Key Code (PyTorch style, ready to paste)
import torch, torch.nn as nn, torch.nn.functional as F
from einops import rearrange

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.n = img_size // patch_size  # 14
        self.proj = nn.Conv2d(3, embed_dim, patch_size, patch_size)  # unfold + linear in one strided conv

    def forward(self, x):
        x = self.proj(x)                          # (B, 768, 14, 14)
        x = rearrange(x, 'b d h w -> b (h w) d')  # (B, 196, 768)
        return x

class MSA(nn.Module):
    def __init__(self, dim=768, heads=12, dropout=0.):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.to_qkv(x).view(B, N, 3, self.heads, D // self.heads).permute(2, 0, 3, 1, 4)  # (3, B, h, N, dh)
        q, k, v = qkv[0], qkv[1], qkv[2]            # each (B, h, N, dh)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v                              # (B, h, N, dh)
        out = out.transpose(1, 2).reshape(B, N, D)  # merge heads
        return self.proj(out)

class Block(nn.Module):
    def __init__(self, dim=768, heads=12, mlp_ratio=4, drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MSA(dim, heads, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # pre-norm MSA + residual
        x = x + self.mlp(self.norm2(x))   # pre-norm FFN + residual
        return x

class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, dim)
        n_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.blocks = nn.ModuleList([Block(dim, heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                        # (B, 196, 768)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, 768)
        x = torch.cat((cls_tokens, x), dim=1) + self.pos_embed  # (B, 197, 768)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x[:, 0])  # (B, 768)
        return self.head(x)     # (B, 1000)
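A quick smoke test of the model above (untrained weights; checks shapes and parameter count only):
model = ViT()
x = torch.randn(2, 3, 224, 224)
assert model(x).shape == (2, 1000)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # ≈ 86.6M for this ViT-Base configuration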
🧩 8. Summary (a mnemonic)
"A convolution cuts the image into slices; each slice becomes a word; add a token and learn positions; stack self-attention layers; take the cls head to classify; softmax turns it into probabilities."
