当前位置: 首页 > news >正文

ViT算法流程——从 原始像素 → 网络输出 logits 的 每一步张量形状、公式、关键代码

Vision Transformer(ViT)算法流程,从 原始像素 → 网络输出 logits每一步张量形状、公式、关键代码 全部展开。


🎯 目标

输入:一张 RGB 图像 x ∈ ℝH×W×C
输出:K 类 logits y ∈ ℝK
(以下以 ImageNet 为例:H=W=224,C=3,K=1000,patch=16,D=768,L=12,heads=12)


🔍 0. 符号表

符号含义数值例
Ppatch 边长16
Npatch 个数(224/16)² = 196
D隐藏维度768
LTransformer 层数12
h注意力头数12
D_h单头维度D/h = 64

🧮 1. 图像分块 & 线性投影(Patch Embedding)

1.1 图像 → 块

# 伪代码
x = x.view(3, 224, 224)
patches = x.unfold(1, 16, 16).unfold(2, 16, 16)  # (3, 14, 14, 16, 16)
patches = patches.permute(1,2,0,3,4).reshape(196, 3*16*16)  # (196, 768)

1.2 线性投影

W_proj = nn.Linear(768, 768)  # 公式: z_0^(i) = W_proj * patches[i]
z_patch = W_proj(patches)     # (196, 768)

🎫 2. 添加类别令牌 & 位置编码

2.1 可学习 cls token

cls_token = nn.Parameter(torch.zeros(1, 768))  # (1, D)
z_cls = cls_token.expand(batch_size, 1, 768)   # (B, 1, D)

2.2 拼接

z_0 = torch.cat([z_cls, z_patch], dim=1)  # (B, 197, 768)

2.3 可学习位置编码

pos_embed = nn.Parameter(torch.zeros(1, 197, 768))  # (1, 197, D)
z_0 = z_0 + pos_embed  # 广播相加

形状:(B, 197, 768),此后永远不变。


🔁 3. Transformer 编码器(共 L=12 层)

每层包含 MSA → Add&Norm → FFN → Add&Norm,循环 L 次。

3.1 层归一化(LN)

z = LayerNorm(z)        # (B, 197, 768)

3.2 多头自注意力(MSA)

3.2.1 线性映射 QKV
W_qkv = nn.Linear(768, 3*768)  # 一次性输出
qkv = W_qkv(z)                 # (B, 197, 3*768)
q, k, v = qkv.chunk(3, dim=-1) # 各 (B, 197, 768)
3.2.2 拆多头
q = q.view(B, 197, 12, 64).transpose(1,2)  # (B, 12, 197, 64)
k = k.view(B, 197, 12, 64).transpose(1,2)
v = v.view(B, 197, 12, 64).transpose(1,2)
3.2.3 缩放点积注意力
attn = (q @ k.transpose(-2,-1)) / sqrt(64)  # (B,12,197,197)
attn = softmax(attn, dim=-1)                # 行和=1
attn = dropout(attn, p=0.0)                 # 可设 0.1
out = attn @ v                              # (B,12,197,64)
3.2.4 合并头
out = out.transpose(1,2).reshape(B, 197, 768)  # (B,197,768)
z = z + dropout(out)                             # 残差

3.3 FFN(两线性 + GELU)

ffn = Linear(768, 3072)
ffn2 = Linear(3072, 768)
z = z + dropout(ffn2(gelu(ffn(LayerNorm(z)))))

输出形状仍为 (B, 197, 768)


🏁 4. 提取 cls token & 分类头

4.1 取 cls token

z_cls = z[:, 0]  # (B, 768)

4.2 线性分类器

head = nn.Linear(768, 1000)
logits = head(z_cls)  # (B, 1000)

📈 5. 损失 & 训练

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

🧵 6. 完整形状链路(一行一形状)

阶段形状
输入图像(B, 3, 224, 224)
patch unfold(B, 196, 768)
+cls token(B, 197, 768)
+pos embed(B, 197, 768)
LN → QKV(B, 197, 3×768)
多头拆分(B, 12, 197, 64)
注意力矩阵(B, 12, 197, 197)
加权 V(B, 12, 197, 64)
合并头(B, 197, 768)
FFN 隐藏(B, 197, 3072)
FFN 输出(B, 197, 768)
取 cls(B, 768)
logits(B, 1000)

🧪 7. 关键代码片段(PyTorch 风格,可直接粘贴)

import torch, torch.nn as nn, torch.nn.functional as F
from einops import rearrangeclass PatchEmbed(nn.Module):def __init__(self, img_size=224, patch_size=16, embed_dim=768):super().__init__()self.n = img_size // patch_size  # 14self.proj = nn.Conv2d(3, embed_dim, patch_size, patch_size)  # 一次性完成 unfold + lineardef forward(self, x):x = self.proj(x)                          # (B, 768, 14, 14)x = rearrange(x, 'b d h w -> b (h w) d')  # (B, 196, 768)return xclass MSA(nn.Module):def __init__(self, dim=768, heads=12, dropout=0.):super().__init__()self.heads = headsself.scale = (dim // heads) ** -0.5self.to_qkv = nn.Linear(dim, dim*3)self.attn_drop = nn.Dropout(dropout)self.proj = nn.Linear(dim, dim)def forward(self, x):B, N, D = x.shapeqkv = self.to_qkv(x).view(B, N, 3, self.heads, D//self.heads).permute(2,0,3,1,4)  # (3,B,h,N,dh)q, k, v = qkv[0], qkv[1], qkv[2]  # 各 (B,h,N,dh)attn = (q @ k.transpose(-2,-1)) * self.scaleattn = F.softmax(attn, dim=-1)attn = self.attn_drop(attn)out = attn @ v                     # (B,h,N,dh)out = out.transpose(1,2).reshape(B, N, D)return self.proj(out)class Block(nn.Module):def __init__(self, dim=768, heads=12, mlp_ratio=4, drop=0.):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attn = MSA(dim, heads, drop)self.norm2 = nn.LayerNorm(dim)self.mlp = nn.Sequential(nn.Linear(dim, int(dim*mlp_ratio)),nn.GELU(),nn.Dropout(drop),nn.Linear(int(dim*mlp_ratio), dim),nn.Dropout(drop))def forward(self, x):x = x + self.attn(self.norm1(x))x = x + self.mlp(self.norm2(x))return xclass ViT(nn.Module):def __init__(self, img_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12):super().__init__()self.patch_embed = PatchEmbed(img_size, patch_size, dim)n_patches = (img_size//patch_size)**2self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))self.pos_embed = nn.Parameter(torch.zeros(1, n_patches+1, dim))nn.init.trunc_normal_(self.pos_embed, std=0.02)self.blocks = nn.ModuleList([Block(dim, heads) for _ in range(depth)])self.norm = nn.LayerNorm(dim)self.head = nn.Linear(dim, num_classes)def forward(self, x):B = x.shape[0]x = self.patch_embed(x)                                # (B,196,768)cls_tokens = self.cls_token.expand(B, -1, -1)          # (B,1,768)x = torch.cat((cls_tokens, x), dim=1) + self.pos_embed # (B,197,768)for blk in self.blocks:x = blk(x)x = self.norm(x[:, 0])                                 # (B,768)return self.head(x)                                    # (B,1000)

🧩 8. 小结(背口诀)

“卷积变切片,切片当单词;加 token 学位置,堆叠自注意力;取头做分类,softmax 出概率。”

http://www.dtcms.com/a/536555.html

相关文章:

  • 前端与移动开发之 CSS vs QSS
  • 上那个网站找手工活做网上项目外包
  • 网站建设项目开发响应式学校网站模板下载
  • CICD之git
  • 零基础从头教学Linux(Day 57)
  • 综合网站推广的含义天津网站建设如何
  • Playwright中Browser的实现类深度解析-Browser方法速查手册
  • 智能指针完全指南
  • 数字阵列雷达(三)——系统工作原理(接收)
  • linux动态库加载方式:dlopen和直接链接.so库的区别?
  • 可克达拉市建设局网站呼和浩特做网站的地方
  • 插入排序:扑克牌式的排序算法!
  • 如何实现简单的HTTP代理服务器
  • vscode断点使用
  • 做自己网站做站长网站模板对seo的影响
  • Rust中的异常处理方式
  • ETCD 学习使用
  • 新能源汽车故障诊断与排除虚拟实训软件——赋能职业教育新工具
  • 自用提示词02 || Prompt Engineering || RAG数据切分 || 作用:通过LLM将文档切分成chunks
  • 网站开发实战作业答案成功网站案例有哪些
  • 对电子商务网站建设与管理的理解我想做个网站推广怎么做
  • 青少年机器人技术(六级)等级考试试卷-实操题(2025年9月)
  • Spring Boot核心知识点全解析
  • 如何在Qt QML中定义枚举浅谈
  • 6 mysql源码中的查询逻辑
  • 网站a记录的是做cname网页设计欣赏分析
  • Optuna 黑科技自动化超参数优化框架详解
  • 江西省第二届职业技能大赛网络安全赛题 应急响应
  • 网站制作哪家好又便宜东莞建设企业网站
  • 提高命令行运行效率-正则 表达式