# 神经网络-Day49 (Neural Networks - Day 49)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# ====================== Configuration & device check ======================
# Enable a Chinese-capable font so CJK labels render in matplotlib figures.
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with CJK fonts

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# ====================== Data preprocessing & loading ======================
# Training set: data augmentation followed by normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test set: normalization only (no augmentation at evaluation time).
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load the CIFAR-10 dataset.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)

# Build the data loaders (shuffle only the training stream).
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ====================== CBAM module definitions ======================
## 通道注意力模块
class ChannelAttention(nn.Module):
    """Channel attention: reweights feature-map channels from pooled statistics."""

    def __init__(self, in_channels: int, ratio: int = 16):
        """
        Args:
            in_channels: number of input channels.
            ratio: bottleneck reduction ratio for the shared MLP (default 16).
        """
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # global max pooling
        # Shared MLP: squeeze channels by `ratio`, then expand back.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // ratio, in_channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()  # squashes fused logits into (0, 1) weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel attention to ``x`` of shape (B, C, H, W)."""
        batch, channels = x.shape[0], x.shape[1]
        squeeze_avg = self.fc(self.avg_pool(x).view(batch, channels))
        squeeze_max = self.fc(self.max_pool(x).view(batch, channels))
        weights = self.sigmoid(squeeze_avg + squeeze_max)
        return x * weights.view(batch, channels, 1, 1)


## Spatial attention module
class SpatialAttention(nn.Module):
    """Spatial attention: reweights spatial positions from channel-pooled maps."""

    def __init__(self, kernel_size: int = 7):
        """
        Args:
            kernel_size: convolution kernel size (default 7).
        """
        super().__init__()
        # Two input maps (channel-wise avg + max) -> one attention logit map.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply spatial attention to ``x`` of shape (B, C, H, W)."""
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        gate = self.sigmoid(self.conv(stacked))
        return x * gate


## CBAM composite module
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel then spatial attention."""

    def __init__(self, in_channels: int, ratio: int = 16, kernel_size: int = 7):
        """
        Args:
            in_channels: number of input channels.
            ratio: channel-attention reduction ratio (default 16).
            kernel_size: spatial-attention kernel size (default 7).
        """
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, ratio)
        self.spatial_attn = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Refine ``x`` (B, C, H, W): channel attention first, spatial second."""
        return self.spatial_attn(self.channel_attn(x))


# ====================== CNN model with CBAM ======================
class CBAM_CNN(nn.Module):
    """Three-stage CNN for CIFAR-10 (32x32 RGB) with a CBAM block after each stage."""

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels, spatial 32x32 -> 16x16.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.cbam1 = CBAM(in_channels=32)
        # Stage 2: 32 -> 64 channels, spatial 16x16 -> 8x8.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.cbam2 = CBAM(in_channels=64)
        # Stage 3: 64 -> 128 channels, spatial 8x8 -> 4x4.
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.cbam3 = CBAM(in_channels=128)
        # Classifier head: flatten 128*4*4 features -> 10 class logits.
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images (B, 3, 32, 32) to class logits (B, 10)."""
        stages = (
            (self.conv_block1, self.cbam1),
            (self.conv_block2, self.cbam2),
            (self.conv_block3, self.cbam3),
        )
        for conv_block, attn in stages:
            x = attn(conv_block(x))
        return self.fc_layers(x.view(x.size(0), -1))
# Initialize model, loss function, and optimizer.
model = CBAM_CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate when the test loss stops improving for 3 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)


def train(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,
    device: torch.device,
    epochs: int,
) -> float:
    """Run the full train/evaluate loop and plot the resulting curves.

    Args:
        model: model to train.
        train_loader: training data loader.
        test_loader: test data loader.
        criterion: loss function.
        optimizer: optimizer.
        scheduler: LR scheduler stepped with the epoch test loss.
        device: compute device.
        epochs: number of training epochs.

    Returns:
        Test accuracy (%) of the final epoch.
    """
    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []
    all_iter_losses = []  # per-batch losses across all epochs
    iter_indices = []     # global iteration index for each recorded loss

    for epoch in range(epochs):
        # BUGFIX: re-enter train mode every epoch. The evaluation phase below
        # leaves the model in eval mode; the original called model.train() only
        # once before the loop, so from epoch 2 onward BatchNorm ran with frozen
        # statistics and Dropout was disabled during training.
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        # Training phase.
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Record iteration-level loss.
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)

            running_loss += iter_loss
            _, predicted = output.max(1)
            total_train += target.size(0)
            correct_train += predicted.eq(target).sum().item()

            # Print progress every 100 batches.
            if (batch_idx + 1) % 100 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                print(f"Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} "
                      f"| 单Batch损失: {iter_loss:.4f} | 平均损失: {avg_loss:.4f}")

        # Epoch-level training metrics.
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct_train / total_train
        train_loss_history.append(epoch_train_loss)
        train_acc_history.append(epoch_train_acc)

        # Evaluation phase.
        model.eval()
        test_loss = 0.0
        correct_test = 0
        total_test = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()

        # Epoch-level test metrics.
        # BUGFIX: the original split this expression across two lines with no
        # continuation ("100." / "* correct_test / total_test"), a SyntaxError.
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        test_loss_history.append(epoch_test_loss)
        test_acc_history.append(epoch_test_acc)

        # Adjust the learning rate based on the test loss.
        scheduler.step(epoch_test_loss)

        # Epoch summary.
        print(f"Epoch {epoch+1}/{epochs} 完成 | "
              f"Train Acc: {epoch_train_acc:.2f}% | Test Acc: {epoch_test_acc:.2f}%")

    # Plot the training curves.
    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    return epoch_test_acc


# ====================== Plotting functions ======================
def plot_iter_losses(losses: list, indices: list) -> None:
    """Plot the per-iteration (per-batch) training loss curve."""
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    plt.xlabel('Iteration (Batch序号)')
    plt.ylabel('Loss值')
    plt.title('训练过程中每个Batch的损失变化')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_epoch_metrics(
    train_acc: list,
    test_acc: list,
    train_loss: list,
    test_loss: list,
) -> None:
    """Plot epoch-level accuracy and loss curves, train vs. test, side by side."""
    epochs = range(1, len(train_acc) + 1)
    plt.figure(figsize=(12, 4))

    # Accuracy subplot.
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b-', label='Train Accuracy')
    plt.plot(epochs, test_acc, 'r-', label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('训练与测试准确率对比')
    plt.legend()
    plt.grid(True)

    # Loss subplot.
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'b-', label='Train Loss')
    plt.plot(epochs, test_loss, 'r-', label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss值')
    plt.title('训练与测试损失对比')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# ====================== Run training ======================
epochs = 50
print("开始训练带CBAM的CNN模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

# Uncomment the lines below to save the trained weights:
# torch.save(model.state_dict(), 'cifar10_cbam_cnn_model.pth')
# print("模型已保存为: cifar10_cbam_cnn_model.pth")
# Author credit: @浙大疏锦行