当前位置: 首页 > news >正文

Python 训练营打卡 Day 43

以猫狗图像辨别的新数据集为例,用CNN网络进行训练并用Grad-CAM做可视化

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from sklearn.model_selection import train_test_split
import os# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 训练集数据增强
train_transform = transforms.Compose([transforms.Resize((32, 32)),  # 调整为32×32transforms.RandomRotation(10),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 验证集仅需基础预处理
val_transform = transforms.Compose([transforms.Resize((32, 32)),  # 调整为32×32transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 数据集根目录
DATASET_ROOT = r'C:\Users\Lenovo\Desktop\archive\cats_vs_dogs_dataset'# 定义数据变换(训练集含增强,验证集无增强)
train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomRotation(10),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])val_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载完整数据集(训练+验证)
full_dataset = datasets.ImageFolder(root=DATASET_ROOT,transform=train_transform  # 初始使用训练集变换
)# 划分训练集和验证集(8:2比例)
total_samples = len(full_dataset)
train_samples = int(0.8 * total_samples)
val_samples = total_samples - train_samples# 随机划分(使用固定种子确保可复现)
torch.manual_seed(42)
train_dataset, val_dataset = random_split(full_dataset, [train_samples, val_samples], generator=torch.Generator().manual_seed(42)
)# 为验证集单独设置变换(移除数据增强)
val_dataset.dataset.transform = val_transform# 创建数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
)# 查看数据集信息
class_names = full_dataset.classes
print(f"数据集类别: {class_names}")
print(f"训练集样本数: {len(train_dataset)}")
print(f"验证集样本数: {len(val_dataset)}")class CNN(nn.Module):def __init__(self, num_classes=2):super(CNN, self).__init__()# 卷积层配置self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)  # 32→32self.bn1 = nn.BatchNorm2d(32)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32→16self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 16→16self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 16→8self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)  # 8→8self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 8→4# 全连接层输入维度:128通道 × 4×4特征图 = 2048self.fc1 = nn.Linear(128 * 4 * 4, 512)self.dropout = nn.Dropout(0.5)self.fc2 = nn.Linear(512, num_classes)def forward(self, x):x = self.pool1(self.relu1(self.bn1(self.conv1(x))))x = self.pool2(self.relu2(self.bn2(self.conv2(x))))x = self.pool3(self.relu3(self.bn3(self.conv3(x))))# 展平x = x.view(-1, 128 * 4 * 4)x = self.dropout(self.relu3(self.fc1(x)))x = self.fc2(x)return xdef train(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs):best_acc = 0.0best_model_path = 'best_cnn_model.pth'all_iter_losses = []iter_indices = []train_acc_history = []val_acc_history = []train_loss_history = []val_loss_history = []for epoch in range(epochs):# 训练阶段model.train()running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计准确率running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} 'f'| Loss: {iter_loss:.4f} | Acc: {100.*correct/total:.2f}%')# 计算训练指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_acc_history.append(epoch_train_acc)train_loss_history.append(epoch_train_loss)# 验证阶段model.eval()val_loss = 0correct_val = 0total_val = 0with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)output = model(data)val_loss += criterion(output, target).item()_, predicted = output.max(1)total_val += target.size(0)correct_val += predicted.eq(target).sum().item()epoch_val_loss = val_loss / len(val_loader)epoch_val_acc = 100. * correct_val / total_valval_acc_history.append(epoch_val_acc)val_loss_history.append(epoch_val_loss)# 更新学习率scheduler.step(epoch_val_loss)# 保存最佳模型if epoch_val_acc > best_acc:best_acc = epoch_val_acctorch.save(model.state_dict(), best_model_path)print(f'保存最佳模型 (Epoch {epoch+1} | Acc: {best_acc:.2f}%)')print(f'Epoch {epoch+1}/{epochs} | Train Loss: {epoch_train_loss:.4f} | 'f'Train Acc: {epoch_train_acc:.2f}% | Val Acc: {epoch_val_acc:.2f}%')# 加载最佳模型model.load_state_dict(torch.load(best_model_path))return best_acc, (train_acc_history, val_acc_history, train_loss_history, val_loss_history)def plot_epoch_metrics(train_acc, val_acc, train_loss, val_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 4))# 绘制准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, val_acc, 'r-', label='验证准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('训练和验证准确率')plt.legend()plt.grid(True)# 绘制损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, val_loss, 'r-', label='验证损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('训练和验证损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()def visualize_gradcam(model, val_loader, class_names, device, num_samples=5):# 选择目标层(最后一个卷积层)target_layers = [model.conv3]# 创建GradCAM对象cam = GradCAM(model=model, target_layers=target_layers, use_cuda=device.type == 'cuda')model.eval()fig, axes = plt.subplots(num_samples, 2, figsize=(10, 4*num_samples))for i in range(num_samples):# 获取样本inputs, labels = next(iter(val_loader))input_tensor = inputs[0].unsqueeze(0).to(device)true_label = labels[0].item()# 预测with torch.no_grad():outputs = model(input_tensor)_, pred = torch.max(outputs, 1)pred = pred.item()# 生成Grad-CAM热力图grayscale_cam = cam(input_tensor=input_tensor, targets=None)grayscale_cam = grayscale_cam[0, :]  # 取第一个样本的热力图# 预处理原始图像用于可视化img = input_tensor[0].cpu().permute(1, 2, 0).numpy()img = (img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]))img = np.clip(img, 0, 1)# 叠加热力图visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)# 显示原始图像axes[i, 0].imshow(img)axes[i, 0].set_title(f'原始图像\n真实: {class_names[true_label]}, 预测: {class_names[pred]}')axes[i, 0].axis('off')# 显示Grad-CAM结果axes[i, 1].imshow(visualization)axes[i, 1].set_title('Grad-CAM热力图')axes[i, 1].axis('off')plt.tight_layout()plt.show()# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")# 初始化模型(适应32×32输入)
model = CNN(num_classes=len(class_names)).to(device)# 定义损失函数、优化器和学习率调度器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True
)# 训练模型
print("开始训练CNN模型...")
best_acc, metrics = train(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs=20)
print(f"训练完成!最佳验证准确率: {best_acc:.2f}%")# 绘制训练指标
train_acc, val_acc, train_loss, val_loss = metrics
plot_epoch_metrics(train_acc, val_acc, train_loss, val_loss)# 可视化Grad-CAM结果
visualize_gradcam(model, val_loader, class_names, device, num_samples=5)
使用设备: cuda
开始训练CNN模型...
Epoch 1/20 | Batch 100/541 | Loss: 0.6872 | Acc: 61.47%
Epoch 1/20 | Batch 200/541 | Loss: 0.6624 | Acc: 64.19%
Epoch 1/20 | Batch 300/541 | Loss: 0.5880 | Acc: 66.16%
Epoch 1/20 | Batch 400/541 | Loss: 0.5256 | Acc: 67.46%
Epoch 1/20 | Batch 500/541 | Loss: 0.5808 | Acc: 68.56%
保存最佳模型 (Epoch 1 | Acc: 76.11%)
Epoch 1/20 | Train Loss: 0.5969 | Train Acc: 68.75% | Val Acc: 76.11%
Epoch 2/20 | Batch 100/541 | Loss: 0.5069 | Acc: 73.16%
Epoch 2/20 | Batch 200/541 | Loss: 0.4214 | Acc: 74.80%
Epoch 2/20 | Batch 300/541 | Loss: 0.5005 | Acc: 75.47%
Epoch 2/20 | Batch 400/541 | Loss: 0.4932 | Acc: 75.99%
Epoch 2/20 | Batch 500/541 | Loss: 0.2958 | Acc: 76.34%
保存最佳模型 (Epoch 2 | Acc: 77.15%)
Epoch 2/20 | Train Loss: 0.4893 | Train Acc: 76.54% | Val Acc: 77.15%
Epoch 3/20 | Batch 100/541 | Loss: 0.5376 | Acc: 80.34%
Epoch 3/20 | Batch 200/541 | Loss: 0.4955 | Acc: 80.27%
Epoch 3/20 | Batch 300/541 | Loss: 0.3023 | Acc: 79.84%
Epoch 3/20 | Batch 400/541 | Loss: 0.4594 | Acc: 79.97%
Epoch 3/20 | Batch 500/541 | Loss: 0.3883 | Acc: 80.11%
保存最佳模型 (Epoch 3 | Acc: 81.61%)
Epoch 3/20 | Train Loss: 0.4306 | Train Acc: 80.06% | Val Acc: 81.61%
Epoch 4/20 | Batch 100/541 | Loss: 0.3557 | Acc: 81.66%
Epoch 4/20 | Batch 200/541 | Loss: 0.2884 | Acc: 82.02%
...
Epoch 20/20 | Batch 400/541 | Loss: 0.0146 | Acc: 99.88%
Epoch 20/20 | Batch 500/541 | Loss: 0.0139 | Acc: 99.88%
Epoch 20/20 | Train Loss: 0.0056 | Train Acc: 99.88% | Val Acc: 85.62%
训练完成!最佳验证准确率: 85.96%

相关文章:

  • 玄机-日志分析-IIS日志分析
  • JavaWeb:前端工程化-ElementPlus
  • Hot100 Day02(移动0,乘最多水的容器、三数之和、接雨水)
  • 区块链+AI融合实战:智能合约如何结合机器学习优化DeFi风控?
  • 2025年五一数学建模竞赛A题-支路车流量推测问题详细建模与源代码编写(一)
  • 守护生命律动:进行性核上性麻痹的专业健康护理指南
  • Python爬虫:trafilatura 的详细使用(快速提取正文和评论以及结构,转换为 TXT、CSV 和 XML)
  • SD卡通过读取bin文件替代读取图片格式文件来提高LCD显示速度
  • 34.2STM32下的can总线外设_csdn
  • GQA(Grouped Query Attention):分组注意力机制的原理与实践《三》
  • Linux 环境下 PPP 拨号的嵌入式开发实现
  • 网络可靠性的定义与核心要素
  • 用户 xxx is not in the sudoers file.
  • FEMFAT许可分析中的关键指标
  • CentOS在vmware局域网内搭建DHCP服务器【踩坑记录】
  • html2canvas v1.0.0-alpha.12版本文本重叠问题修复
  • qt+vs Generated File下的moc_和ui_文件丢失导致 error LNK2001
  • Unity安卓平台开发,启动app并传参
  • 使用 SseEmitter 实现 Spring Boot 后端的流式传输和前端的数据接收
  • 麒麟+ARM架构安装mysql8的操作指南
  • 哪个网站做logo/推广图片大全
  • 室内设计素材网站大全/适合奖励自己的网站免费
  • 网页模板下载网站/如何投放网络广告
  • 做网站如何语音对话/怎样做好网络推广呀
  • 一键生成海报的网站/手机百度网盘网页版登录入口
  • 高米店网站建设公司/b2b商务平台