Check-in: Day 54
Homework:
1. Train an Inception network on CIFAR-10 and observe its accuracy.
2. Ablation study: add a residual connection and a CBAM module, and ablate each one separately.
import os

# GPU-memory allocator tweak: must be set BEFORE torch is imported,
# otherwise the setting is ignored.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --------------------------
# 1. Data preparation (CIFAR-10)
# --------------------------
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Smaller batch sizes to fit GPU memory
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False, num_workers=2)
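
# A quick optional sanity check (my own addition, not part of the homework):
# pull one batch to confirm tensor shapes and that Normalize roughly
# zero-centers the data. Remove once verified.
images, labels = next(iter(train_loader))
print(images.shape)                                # torch.Size([64, 3, 32, 32])
print(images.mean().item(), images.std().item())   # roughly 0 and 1 after Normalize
print(labels[:8])                                  # class indices in [0, 9]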
# --------------------------
# 2. Model definitions (memory-optimized)
# --------------------------
class CBAM(nn.Module):
    """Lightweight CBAM module (channel + spatial attention)."""
    def __init__(self, channels, reduction_ratio=8):
        super(CBAM, self).__init__()
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction_ratio, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction_ratio, channels, kernel_size=1),
            nn.Sigmoid()
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Channel attention
        channel = self.channel_attention(x) * x
        # Spatial attention: pool across channels, then 7x7 conv
        spatial_avg = torch.mean(channel, dim=1, keepdim=True)
        spatial_max, _ = torch.max(channel, dim=1, keepdim=True)
        spatial = torch.cat([spatial_avg, spatial_max], dim=1)
        spatial = self.spatial_attention(spatial) * channel
        return spatial


class InceptionModule(nn.Module):
    """Slimmed-down Inception module (channel counts halved)."""
    def __init__(self, in_channels, use_residual=False, use_cbam=False):
        super(InceptionModule, self).__init__()
        self.use_residual = use_residual
        self.use_cbam = use_cbam
        # Branch 1: 1x1 conv
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        # Branch 2: 1x1 -> 3x3
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 24, kernel_size=3, padding=1),
            nn.BatchNorm2d(24),
            nn.ReLU()
        )
        # Branch 3: 1x1 -> 5x5
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 24, kernel_size=5, padding=2),
            nn.BatchNorm2d(24),
            nn.ReLU()
        )
        # Branch 4: 3x3 max-pool -> 1x1
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        # CBAM module (if enabled); branches concatenate to 16+24+24+16 = 80 channels
        if use_cbam:
            self.cbam = CBAM(16 + 24 + 24 + 16)
        # Residual connection (if enabled): 1x1 conv projects the input
        # to the concatenated channel count so the shapes match
        if use_residual:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, 16 + 24 + 24 + 16, kernel_size=1),
                nn.BatchNorm2d(16 + 24 + 24 + 16)
            )

    def forward(self, x):
        residual = x
        # Parallel branches
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        # Concatenate branch outputs along the channel dimension
        out = torch.cat([b1, b2, b3, b4], dim=1)
        # Apply CBAM
        if self.use_cbam:
            out = self.cbam(out)
        # Residual connection
        if self.use_residual:
            residual = self.residual(residual)
            out = out + residual
        return out


class InceptionNet(nn.Module):
    """Complete lightweight Inception network."""
    def __init__(self, num_classes=10, use_residual=False, use_cbam=False):
        super(InceptionNet, self).__init__()
        # Stem (channel count halved)
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        # Stacked Inception modules; each outputs 16+24+24+16 = 80 channels
        self.inception1 = InceptionModule(32, use_residual, use_cbam)
        self.inception2 = InceptionModule(16 + 24 + 24 + 16, use_residual, use_cbam)
        self.inception3 = InceptionModule(16 + 24 + 24 + 16, use_residual, use_cbam)
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16 + 24 + 24 + 16, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.inception3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
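
# Before training, a shape check (my own addition) confirms that all four
# ablation variants map a CIFAR-10 batch to 10 logits, and shows how much
# parameter overhead the residual and CBAM options add. Remove once verified.
x = torch.randn(2, 3, 32, 32)
for res in (False, True):
    for cbam in (False, True):
        net = InceptionNet(use_residual=res, use_cbam=cbam)
        n_params = sum(p.numel() for p in net.parameters())
        print(f'residual={res}, cbam={cbam}: out={tuple(net(x).shape)}, params={n_params}')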
# --------------------------
# 3. Training and evaluation functions (with mixed precision)
# --------------------------
def train_model(model, train_loader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    # Mixed-precision training (note: creating the scaler here resets its
    # loss scale every epoch; creating it once per run would carry the state)
    scaler = torch.cuda.amp.GradScaler()
    for inputs, labels in tqdm(train_loader, desc=f'Train Epoch {epoch}'):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    train_loss = running_loss / len(train_loader)
    train_acc = 100. * correct / total
    return train_loss, train_acc


def evaluate_model(model, test_loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    test_loss = running_loss / len(test_loader)
    test_acc = 100. * correct / total
    return test_loss, test_acc
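
# A cheap smoke test (optional, my own addition): an untrained model should
# sit near the 10% chance level on CIFAR-10, which confirms the data and
# evaluation plumbing before committing to the full runs. Remove afterwards.
smoke_model = InceptionNet().to(device)
smoke_loss, smoke_acc = evaluate_model(smoke_model, test_loader, nn.CrossEntropyLoss())
print(f'Untrained baseline: loss={smoke_loss:.4f}, acc={smoke_acc:.2f}%')  # expect ~10%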
# --------------------------
# 4. Main ablation pipeline
# --------------------------
def run_ablation(use_residual=False, use_cbam=False):
    print(f"\nRunning experiment: Residual={use_residual}, CBAM={use_cbam}")
    # Initialize model
    model = InceptionNet(use_residual=use_residual, use_cbam=use_cbam).to(device)
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    # Training history
    train_losses, train_accs = [], []
    test_losses, test_accs = [], []
    # Training loop (epoch count kept small to save time)
    for epoch in range(15):
        train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, epoch)
        test_loss, test_acc = evaluate_model(model, test_loader, criterion)
        scheduler.step()
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}% | '
              f'Test Loss: {test_loss:.4f}, Acc: {test_acc:.2f}%')
    # Save weights (filename says 'best', but these are the last-epoch weights)
    torch.save(model.state_dict(), f'best_model_res{use_residual}_cbam{use_cbam}.pth')
    # Plot accuracy curves
    plt.figure()
    plt.plot(train_accs, label='Train Acc')
    plt.plot(test_accs, label='Test Acc')
    plt.title(f'Accuracy (Residual={use_residual}, CBAM={use_cbam})')
    plt.legend()
    plt.savefig(f'result_res{use_residual}_cbam{use_cbam}.png')
    plt.close()
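
# Since run_ablation only writes weights to disk, reloading one checkpoint
# (sketch, my own addition; assumes the True/True run has finished and produced
# the file) verifies that the saved .pth round-trips into a fresh model.
def check_checkpoint():
    reloaded = InceptionNet(use_residual=True, use_cbam=True).to(device)
    state = torch.load('best_model_resTrue_cbamTrue.pth', map_location=device)
    reloaded.load_state_dict(state)
    reloaded.eval()  # ready for inference or re-evaluation with evaluate_model
    return reloaded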
# --------------------------
# 5. Run all ablation experiments
# --------------------------
if __name__ == '__main__':
    # Baseline (no residual, no CBAM)
    run_ablation(use_residual=False, use_cbam=False)
    # Residual only
    run_ablation(use_residual=True, use_cbam=False)
    # CBAM only
    run_ablation(use_residual=False, use_cbam=True)
    # Residual + CBAM
    run_ablation(use_residual=True, use_cbam=True)
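
# The four calls above enumerate the 2x2 ablation grid by hand. An equivalent,
# more compact driver (a matter of taste; use INSTEAD of the block above):
#
# from itertools import product
#
# if __name__ == '__main__':
#     for use_residual, use_cbam in product([False, True], repeat=2):
#         run_ablation(use_residual=use_residual, use_cbam=use_cbam)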