视觉Transformer(Vision Transformer , ViT )
将传统的Transformer架构应用于计算机视觉任务的模型,2020年由Dosovitskiy等人提出,挑战了卷积神经网络(CNN)长期以来的主导地位。
一 核心思想
由于Transformer不像卷积使用卷积核,而是采用了全局自注意力机制,每一个参数与其他所有参数都有交互。因此为了处理图像,ViT将图像分割为不重叠的图像块(patches),转换为一维序列后输入Transformer进行全局信息建模。
二 建模流程
2.1图像分块(Patch Embedding)
输入图像:假设为 (图像的大小,C为通道数)
分块操作:将图像分割为 N 个 P×P 的小块(例如16x16),得到序列长度。
展平与线性投影:每个Patch展平为一维向量 ,通过可学习的线性层投影到维度
,得到Patch Embeddings。
(eg:展平为一维,变为长度256×3=768的一维向量,并投影到D=768维的空间。)
公式:
为投影矩阵,
是位置编码,
是分类token。
2.2 位置编码(Position Encoding)
作用:保留图像块的空间位置信息(Transformer本身对顺序不敏感)。
方式:可学习的位置编码(ViT)或固定正弦编码(原版Transformer)。
2.3 Transformer Encoder
层结构:由多头自注意力(MSA)和多层感知机(MLP)交替组成,每层包含残差连接和层归一化(LayerNorm)。、
MLP的作用:
对自注意力层输出的特征进行非线性变换和重组。
提供模型容量,增强表达能力。
2.4 分类头
取出[class]
token对应的输出,经MLP得到类别概率。
三 核心组件
1.多头自注意力(MSA)
Query-Key-Value计算:通过线性变换生成Q、K、V矩阵,计算注意力权重。
多头机制:将Q/K/V拆分为 h 个头,独立计算后拼接,增强捕捉不同子空间特征的能力。
2. 前馈网络(MLP)
由两个全连接层组成,中间通过GELU激活函数:
四 ViT的特性
全局建模能力:自注意力机制允许任意两个图像块间交互,克服了CNN的局部性限制。
对数据量的依赖:需大规模数据(如JFT-300M)预训练,小数据容易过拟合。
计算复杂度:序列长度 N 的平方级复杂度。例如,输入序列越长(密集任务如分割),计算成本显著增加。
五 ViT vs CNN
特性 | ViT | CNN |
---|---|---|
感受野 | 全局(来自自注意力) | 局部(通过堆叠扩大) |
归纳偏差 | 弱(无先验假设) | 强(平移不变性、局部性) |
数据需求 | 依赖大数据 | 小数据友好 |
计算效率 | 高分辨率图像效率低 | 适合高分辨率任务 |
六 应用场景
图像分类:ViT在ImageNet上达到与CNN相当甚至更优的准确率。
目标检测(如DETR):使用Transformer编码器-解码器进行端到端检测。
图像生成(如TransGAN):结合生成对抗网络。
医学图像分析:利用全局上下文处理病变区域的长程依赖。
七 pytorch中使用ViT处理CIFAR10
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTConfig
import matplotlib.pyplot as plt
import numpy as np# 解决PyCharm科学模式与Matplotlib的兼容性问题
plt.switch_backend('TkAgg')# ----------- 环境验证 -----------
def check_environment():"""检查必要组件的安装状态"""try:assert torch.__version__ >= '1.10.0', "需要PyTorch ≥1.10"assert sys.version_info >= (3, 7), "需要Python ≥3.7"assert np.__version__.startswith(('1.', '2.')), "NumPy需要1.x或2.x版本"except AssertionError as e:print(f"环境错误: {e}")sys.exit(1)# ----------- 数据预处理 -----------
class CIFAR10Processor:"""封装数据加载和处理逻辑"""def __init__(self):self.transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])def get_loaders(self, batch_size=32):"""获取数据加载器"""# 自动检测PyCharm项目根目录下的data文件夹data_root = './data' if os.path.exists('./data') else '../data'train_set = datasets.CIFAR10(root=data_root, train=True, download=True,transform=self.transform)test_set = datasets.CIFAR10(root=data_root, train=False, download=True,transform=self.transform)# Windows下必须设置num_workers=0train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)return train_loader, test_loader# ----------- 模型定义 -----------
def create_vit_model(device):"""创建轻量化ViT模型"""config = ViTConfig(image_size=224,patch_size=16,num_channels=3,num_labels=10,hidden_size=256, # 原始为768,减小参数规模num_hidden_layers=4,num_attention_heads=4,intermediate_size=1024,)model = ViTForImageClassification(config).to(device)print(f"模型参数量:{sum(p.numel() for p in model.parameters()):,}")return model# ----------- 训练流程 -----------
class Trainer:def __init__(self, model, device):self.model = modelself.device = deviceself.criterion = nn.CrossEntropyLoss()self.optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=10)def train_epoch(self, train_loader):"""单轮训练"""self.model.train()total_loss = 0.0for images, labels in train_loader:images, labels = images.to(self.device), labels.to(self.device)self.optimizer.zero_grad()outputs = self.model(images).logitsloss = self.criterion(outputs, labels)loss.backward()nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)self.optimizer.step()total_loss += loss.item() * images.size(0)return total_loss / len(train_loader.dataset)def evaluate(self, test_loader):"""模型评估"""self.model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(self.device), labels.to(self.device)outputs = self.model(images).logits_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()return 100 * correct / total# ----------- 主程序 -----------
def main():check_environment()# 设备检测device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"当前设备: {device}")# 数据加载processor = CIFAR10Processor()train_loader, test_loader = processor.get_loaders(batch_size=32)# 模型初始化model = create_vit_model(device)# 训练配置trainer = Trainer(model, device)num_epochs = 15best_acc = 0.0train_losses = []accuracies = []# 训练循环for epoch in range(num_epochs):loss = trainer.train_epoch(train_loader)acc = trainer.evaluate(test_loader)train_losses.append(loss)accuracies.append(acc)trainer.scheduler.step()print(f"Epoch [{epoch+1}/{num_epochs}] "f"Loss: {loss:.4f} | Acc: {acc:.2f}% "f"LR: {trainer.scheduler.get_last_lr()[0]:.2e}")# 保存最佳模型if acc > best_acc:best_acc = acctorch.save(model.state_dict(), "best_vit.pth")# 可视化结果plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Training Loss')plt.title("Loss Curve")plt.xlabel("Epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(accuracies, label='Validation Accuracy')plt.title("Accuracy Curve")plt.xlabel("Epoch")plt.ylabel("Accuracy (%)")plt.tight_layout()plt.savefig('training_results.png')plt.show()if __name__ == '__main__':# 解决Windows多进程问题torch.multiprocessing.freeze_support()# 设置NumPy兼容模式(可选)os.environ['NUMPY_EXPERIMENTAL_ARRAY_FUNCTION'] = '0'main()