当前位置: 首页 > news >正文

Vision Transformer(ViT)模型实例化PyTorch逐行实现

为了让大家更好地理解,我们将从零开始,逐步构建 ViT 的各个核心组件,并最终将它们组合成一个完整的模型。我们会以一个在 CIFAR-10 数据集上应用的实例来贯穿整个讲解过程。

ViT 核心思想

在讲解代码之前,我们先快速回顾一下 ViT 的核心思想,这有助于理解代码每一部分的目的。

图片切块 (Image to Patches): 传统 CNN 逐像素处理图像,而 ViT 模仿 NLP 中处理单词 (Token) 的方式。它将一幅图像 (H*W*C) 切割成一个个小块 (Patch),每个小块大小为 P*P*C。

展平与线性投射 (Patch Flattening & Linear Projection): 将每个小块展平成一个一维向量,然后通过一个全连接层(线性投射)将其映射到一个固定的维度 D,这个向量就成为了 Transformer 的 "Token"。

类别令牌 (Class Token): 模仿 BERT 的 [CLS] 令牌,在所有 Patch Token 的最前面加入一个可学习的 [CLS] Token。这个 Token 最终将用于图像分类。

位置编码 (Positional Embedding): Transformer 本身不包含位置信息。为了让模型知道每个 Patch 的原始位置,我们需要为每个 Token(包括 [CLS] Token)添加一个可学习的位置编码。

Transformer 编码器 (Transformer Encoder): 将带有位置编码的 Token 序列输入到标准的 Transformer Encoder 中。Encoder 由多层堆叠而成,每一层都包含一个多头自注意力模块 (Multi-Head Self-Attention) 和一个前馈网络 (Feed-Forward Network)

分类头 (MLP Head): 将 Transformer Encoder 输出的 [CLS] Token 对应的向量,送入一个简单的多层感知机(MLP),最终输出分类结果。

实例设定

我们将以 CIFAR-10 数据集为例。

图片尺寸 (image_size): 32*32*3

Patch 尺寸 (patch_size): 4*4 (我们可以选择 8x8 或 16x16,这里用 4x4 举例)

类别数 (num_classes): 10

嵌入维度 (dim): 512 (每个 Patch 展平后映射到的维度)

Transformer Encoder 层数 (depth): 6

多头注意力头数 (heads): 8

MLP 内部维度 (mlp_dim): 2048

根据这些设定,我们可以计算出:

每张图片的 Patch 数量 (num_patches): (32/4)x(32/4)=8x8=64

PyTorch 代码逐行实现

我们将按照 ViT 的思想,一步步构建代码。

1. Patch Embedding (图像切块与线性投射)

这是 ViT 的第一步,我们的目标是将一个 (B, C, H, W) 的图像张量,转换成一个 (B, N, D) 的 Token 序列张量,其中 B 是批量大小,N 是 Patch 数量,D 是嵌入维度。

一个巧妙高效的实现方法是使用二维卷积

思想: 我们可以设置一个卷积层,其卷积核大小 (kernel_size)步长 (stride) 都等于 patch_size。这样,卷积核每次滑动的区域恰好就是一个不重叠的 Patch。卷积的输出通道数设为我们想要的嵌入维度 dim

import torch
from torch import nnclass PatchEmbedding(nn.Module):"""将图像分割成块并进行线性嵌入。参数:image_size (int): 输入图像的尺寸 (假设为正方形)。patch_size (int): 每个图像块的尺寸 (假设为正方形)。in_channels (int): 输入图像的通道数。dim (int): 线性投射后的嵌入维度。"""def __init__(self, image_size, patch_size, in_channels, dim):super().__init__()self.patch_size = patch_size# 检查图像尺寸是否能被 patch 尺寸整除if not (image_size % patch_size == 0):raise ValueError("error")# 计算 patch 的数量self.num_patches = (image_size // patch_size) ** 2# 核心:使用 Conv2d 实现 patch 化和线性投射# kernel_size 和 stride 都设为 patch_size,实现不重叠的块分割# out_channels 设为嵌入维度 dimself.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):# 输入 x 的形状: (B, C, H, W)# 例如: (B, 3, 32, 32)# 经过卷积层,将图像转换为 patch 的特征图# 输出形状: (B, dim, H/P, W/P)# 例如: (B, 512, 8, 8)x = self.projection(x)# 将特征图展平# .flatten(2) 将从第2个维度开始展平 (H/P 和 W/P 维度)# 输出形状: (B, dim, N) 其中 N = (H/P) * (W/P)# 例如: (B, 512, 64)x = x.flatten(2)# 交换维度,以匹配 Transformer 输入格式 (B, N, D)# 输出形状: (B, N, dim)# 例如: (B, 64, 512)x = x.transpose(1, 2)return x
2. Transformer Encoder Block

Transformer Encoder 由多个相同的块 (Block) 堆叠而成。每个块包含两个主要部分:

多头自注意力 (Multi-Head Self-Attention)

前馈网络 (Feed-Forward Network / MLP)

每个部分都伴随着残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。

class TransformerEncoderBlock(nn.Module):"""标准的 Transformer Encoder 块。参数:dim (int): 输入的 token 维度。heads (int): 多头注意力的头数。mlp_dim (int): MLP 层的隐藏维度。dropout (float): Dropout 的概率。"""def __init__(self, dim, heads, mlp_dim, dropout=0.1):super().__init__()# 第一个 LayerNormself.norm1 = nn.LayerNorm(dim)# 多头自注意力模块# PyTorch 内置的 MultiheadAttention 期望输入形状为 (N, B, D),# 但我们通常使用 (B, N, D)。设置 batch_first=True 可以解决这个问题。self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)# 第二个 LayerNormself.norm2 = nn.LayerNorm(dim)# MLP / 前馈网络self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim),nn.GELU(),  # ViT 论文中使用的激活函数nn.Dropout(dropout),nn.Linear(mlp_dim, dim),nn.Dropout(dropout))def forward(self, x):# x 的形状: (B, N, D)# 1. 多头自注意力部分# 残差连接: x + Attention(LayerNorm(x))x_norm = self.norm1(x)# 注意力模块返回 attn_output 和 attn_weights,我们只需要前者attn_output, _ = self.attention(x_norm, x_norm, x_norm)x = x + attn_output# 2. 前馈网络部分# 残差连接: x + MLP(LayerNorm(x))x_norm = self.norm2(x)mlp_output = self.mlp(x_norm)x = x + mlp_outputreturn x
3. 完整的 Vision Transformer 模型

现在,我们将所有组件整合在一起。

class VisionTransformer(nn.Module):"""Vision Transformer 模型。参数:image_size (int): 输入图像尺寸。patch_size (int): Patch 尺寸。in_channels (int): 输入通道数。num_classes (int): 分类类别数。dim (int): 嵌入维度。depth (int): Transformer Encoder 层数。heads (int): 多头注意力头数。mlp_dim (int): MLP 隐藏维度。dropout (float): Dropout 概率。"""def __init__(self, image_size, patch_size, in_channels, num_classes,dim, depth, heads, mlp_dim, dropout=0.1):super().__init__()# 1. Patch Embeddingself.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)# 计算 patch 数量num_patches = self.patch_embedding.num_patches# 2. Class Token# 这是一个可学习的参数,维度为 (1, 1, D)# '1' 个 batch,'1' 个 token,'D' 维self.cls_token = nn.Parameter(torch.randn(1, 1, dim))# 3. Positional Embedding# 这也是一个可学习的参数# 长度为 num_patches + 1 (为了包含 cls_token)# 维度为 (1, N+1, D)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.dropout = nn.Dropout(dropout)# 4. Transformer Encoder# 使用 nn.Sequential 将多个 Encoder Block 堆叠起来self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)])# 5. MLP Head (分类头)self.mlp_head = nn.Sequential(nn.LayerNorm(dim), # 在送入分类头前先进行一次 LayerNormnn.Linear(dim, num_classes))def forward(self, img):# img 形状: (B, C, H, W)# 1. 获取 Patch Embedding# x 形状: (B, N, D)x = self.patch_embedding(img)b, n, d = x.shape  # b: batch_size, n: num_patches, d: dim# 2. 添加 Class Token# 将 cls_token 复制 b 份,拼接到 x 的最前面# cls_tokens 形状: (B, 1, D)cls_tokens = self.cls_token.expand(b, -1, -1) # x 形状变为: (B, N+1, D)x = torch.cat((cls_tokens, x), dim=1)# 3. 添加 Positional Embedding# pos_embedding 形状是 (1, N+1, D),利用广播机制直接相加x += self.pos_embeddingx = self.dropout(x)# 4. 通过 Transformer Encoder# x 形状不变: (B, N+1, D)x = self.transformer_encoder(x)# 5. 提取 Class Token 的输出用于分类# 只取序列的第一个 token (cls_token) 的输出# x 形状: (B, D)cls_token_output = x[:, 0]# 6. 通过 MLP Head 得到最终的分类 logits# output 形状: (B, num_classes)output = self.mlp_head(cls_token_output)return output

完整模型与实例

现在我们把所有代码放在一起,并用我们之前设定的 CIFAR-10 参数来实例化模型,看看它的输入和输出。

import torch
from torch import nn# --- 组件 1: PatchEmbedding ---
class PatchEmbedding(nn.Module):def __init__(self, image_size, patch_size, in_channels, dim):super().__init__()if not (image_size % patch_size == 0):raise ValueError("Image dimensions must be divisible by the patch size.")self.num_patches = (image_size // patch_size) ** 2self.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):x = self.projection(x)x = x.flatten(2)x = x.transpose(1, 2)return x# --- 组件 2: TransformerEncoderBlock ---
class TransformerEncoderBlock(nn.Module):def __init__(self, dim, heads, mlp_dim, dropout=0.1):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attention = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)self.norm2 = nn.LayerNorm(dim)self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(mlp_dim, dim),nn.Dropout(dropout))def forward(self, x):attn_output, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))x = x + attn_outputmlp_output = self.mlp(self.norm2(x))x = x + mlp_outputreturn x# --- 主模型: VisionTransformer ---
class VisionTransformer(nn.Module):def __init__(self, image_size, patch_size, in_channels, num_classes,dim, depth, heads, mlp_dim, dropout=0.1):super().__init__()self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)num_patches = self.patch_embedding.num_patchesself.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.dropout = nn.Dropout(dropout)self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)])self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img):x = self.patch_embedding(img)b, n, d = x.shapecls_tokens = self.cls_token.expand(b, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embeddingx = self.dropout(x)x = self.transformer_encoder(x)cls_token_output = x[:, 0]output = self.mlp_head(cls_token_output)return output# --- 实例化并测试 ---# CIFAR-10 实例参数
BATCH_SIZE = 4
IMAGE_SIZE = 32
IN_CHANNELS = 3
PATCH_SIZE = 4
NUM_CLASSES = 10
DIM = 512
DEPTH = 6
HEADS = 8
MLP_DIM = 2048# 创建模型实例
vit_model = VisionTransformer(image_size=IMAGE_SIZE,patch_size=PATCH_SIZE,in_channels=IN_CHANNELS,num_classes=NUM_CLASSES,dim=DIM,depth=DEPTH,heads=HEADS,mlp_dim=MLP_DIM
)# 创建一个假的输入图像张量 (Batch, Channels, Height, Width)
dummy_img = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)# 将图像输入模型
logits = vit_model(dummy_img)# 打印输出的形状
print(f"输入图像形状: {dummy_img.shape}")
print(f"模型输出 (Logits) 形状: {logits.shape}")# 检查输出形状是否正确
assert logits.shape == (BATCH_SIZE, NUM_CLASSES)
print("\n模型构建成功,输入输出形状正确!")

http://www.dtcms.com/a/308209.html

相关文章:

  • 从 MySQL 迁移到 TiDB:使用 SQL-Replay 工具进行真实线上流量回放测试 SOP
  • SpringBoot3.x入门到精通系列:1.2 开发环境搭建
  • 25-vue-photo-preview的使用及使用过程中的问题解决方案
  • 实战教程 ---- Nginx结合Lua实现WAF拦截并可视化配置教程框架
  • 走进computed,了解computed的前世今生
  • 【云故事探索】NO.16:阿里云弹性计算加速精准学 AI 教育普惠落地
  • 谁在托举Agent?阿里云抢滩Agent Infra新赛道
  • 安装 docker compose v2版 笔记250731
  • 对接八大应用渠道
  • Tomcat,WebLogic等中间件漏洞实战解析
  • 大模型流式长链接场景下 k8s 优雅退出 JAVA
  • 用 MyBatis + MySQL 实现高效的批量 Upsert
  • 关于tresos Studio(EB)的MCAL配置之GtmCfg
  • 性能测试篇 :Jmeter监控服务器性能
  • Golang 语言的编程技巧之类型
  • 基础组件(六):网络缓冲区设计 和 定时器方案
  • TTS语音合成|GPT-SoVITS语音合成服务器部署,实现http访问
  • Vue3+Vite项目如何简单使用tsx
  • nl2sql grpo强化学习训练,加大数据量和轮数后,准确率没提升,反而下降了,如何调整
  • PostgreSQL dblink 与 Spring Boot @Transactional 的事务整合
  • Text2SQL 智能问答系统开发-预定义模板(二)
  • docker离线安装mysql镜像
  • 记录几个SystemVerilog的语法——覆盖率
  • 基于MATLAB的GUI来对不同的(彩色或灰色)图像进行图像增强
  • 【国内电子数据取证厂商龙信科技】内存取证
  • 法式基因音响品牌SK(SINGKING AUDIO)如何以硬核科技重塑专业音频版图
  • 防御保护第一次作业
  • AI Gateway 分析:OpenRouter vs Higress
  • python基础语法3,组合数据类型(简单易上手的python语法教学)(课后习题)
  • BFT平台:打造科研教育“最强机器人矩阵”