
DAY43 check-in

@浙大疏锦行

Find an image dataset on Kaggle, train a CNN on it, and visualize what the network attends to with Grad-CAM.

Advanced: split the project into multiple files.

fruit_cnn_project/
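├─ data/                # Dataset (create this folder manually, then add the images)
│  ├─ train/            # Training images
│  └─ val/              # Validation images
├─ models/              # Model definitions
│  └─ cnn_model.py      # CNN architecture
├─ utils/               # Helper functions
│  ├─ dataset_utils.py  # Data loading and preprocessing
│  ├─ grad_cam.py       # Grad-CAM visualization
│  └─ train_utils.py    # Training and evaluation
├─ main.py              # Main program
└─ requirements.txt     # Dependency list (optional)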
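The `data/` folder has to be filled by hand. A minimal sketch, assuming the downloaded Kaggle dataset is a hypothetical `raw/` directory with one sub-folder per class (e.g. `raw/apple/*.jpg`): it copies a reproducible random split into `data/train/<class>` and `data/val/<class>`, the one-folder-per-class layout that `ImageFolder` reads. The helper name `split_dataset`, the `raw` path, the 80/20 ratio and the seed are illustrative assumptions, not part of the original project.

```python
import random
import shutil
from pathlib import Path


def split_dataset(raw_dir, out_dir, val_ratio=0.2, seed=42):
    """Copy raw_dir/<class>/<file> into out_dir/train/<class> and out_dir/val/<class>.

    Hypothetical helper: assumes raw_dir holds one sub-folder per class.
    Returns {class_name: (num_train, num_val)}.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for class_dir in sorted(p for p in Path(raw_dir).iterdir() if p.is_dir()):
        files = sorted(f for f in class_dir.iterdir() if f.is_file())
        rng.shuffle(files)
        n_val = int(len(files) * val_ratio)  # the first n_val shuffled files go to val
        splits = {'val': files[:n_val], 'train': files[n_val:]}
        for split, split_files in splits.items():
            target = Path(out_dir) / split / class_dir.name
            target.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, target / f.name)
        counts[class_dir.name] = (len(splits['train']), len(splits['val']))
    return counts
```

Usage: `split_dataset('raw', 'data')`. Copying instead of moving leaves the downloaded data untouched, so the split can be redone with another seed.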
# Part 1: Import libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
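# Jupyter only: delete the next line when running this as a plain script (main.py)
%matplotlib inline


# Part 2: Data loading and preprocessing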
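def load_data():
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),  # fc1 in SimpleCNN assumes 224x224 inputs
        transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ])
    # ImageFolder expects one sub-folder per class, e.g. data/train/apple/*.jpg
    train_dataset = datasets.ImageFolder(root='data/train', transform=data_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    # Evaluate on the validation split, matching the data/val folder in the project layout
    test_dataset = datasets.ImageFolder(root='data/val', transform=data_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    return train_loader, test_loader


# Part 3: Model definition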
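class SimpleCNN(nn.Module):
    def __init__(self, num_classes=2):  # set num_classes to the number of class folders in data/train
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)   # 3x224x224 -> 16x224x224
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)                            # halves height and width
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 16x112x112 -> 32x112x112
        self.fc1 = nn.Linear(32 * 56 * 56, 128)                   # 56 = 224 / 2 / 2 after two poolings
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))  # -> 16x112x112
        x = self.pool(self.relu(self.conv2(x)))  # -> 32x56x56
        x = torch.flatten(x, 1)                  # keep the batch dimension, flatten the rest
        x = self.relu(self.fc1(x))
        x = self.fc2(x)                          # raw logits; CrossEntropyLoss applies softmax
        return x


# Part 4: Model training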
train_loader, _ = load_data()
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
model.train()  # training mode (matters once dropout or batch-norm layers are added)
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.4f}')

torch.save(model.state_dict(), 'trained_model.pth')


# Part 5: Model testing
_, test_loader = load_data()
model = SimpleCNN()
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # index of the largest logit = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the test images: {100 * correct / total:.2f}%')


# Part 6: Grad-CAM visualization (fixed version)
def grad_cam(model, image, target_class_index, target_layer=None):
    """Return a Grad-CAM heat map (numpy array in [0, 1], same H x W as image) for one (C, H, W) tensor."""
    if target_layer is None:
        target_layer = model.conv2  # last convolutional layer of SimpleCNN
    store = {}

    def forward_hook(module, inputs, output):
        store['activation'] = output  # feature maps A^k, shape (1, K, h, w)

        def save_gradient(grad):
            store['gradient'] = grad  # d(class score)/dA^k, same shape as the activation

        output.register_hook(save_gradient)  # runs when backward() reaches this tensor

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        model.zero_grad()
        output = model(image.unsqueeze(0))        # (1, num_classes)
        output[0, target_class_index].backward()  # backpropagate the target class score only
    finally:
        handle.remove()  # remove the hook so repeated calls do not stack hooks

    activation = store['activation'].detach()[0]  # (K, h, w)
    gradients = store['gradient'].detach()[0]     # (K, h, w): gradients w.r.t. the layer, not the input
    weights = gradients.mean(dim=(1, 2))          # one weight per channel: spatially averaged gradient
    cam = F.relu((weights[:, None, None] * activation).sum(dim=0))  # weighted sum, positive part only
    cam = F.interpolate(cam[None, None], size=tuple(image.shape[1:]),
                        mode='bilinear', align_corners=False)[0, 0]  # upsample to the input size
    cam = cam - cam.min()
    if cam.max() > 0:
        cam = cam / cam.max()  # scale to [0, 1]; an all-zero map is left as-is
    return cam.numpy()


# Visualize the first few test images
dataiter = iter(test_loader)
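images, labels = next(dataiter)  # DataLoader iterators have no .next() method; use built-in next()


def denormalize(img):
    """Undo Normalize((0.5,)*3, (0.5,)*3) so imshow receives an H x W x C array in [0, 1]."""
    return (img * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()


for i in range(min(5, len(images))):  # visualize the first 5 images (fewer if the batch is smaller)
    image = images[i]
    label = labels[i].item()
    cam = grad_cam(model, image, label)  # explain the ground-truth class
    shown = denormalize(image)

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(shown)
    plt.title(f'Original Image (Class: {label})')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(shown)
    plt.imshow(cam, cmap='jet', alpha=0.5)  # overlay the heat map on the image
    plt.title('Grad-CAM Visualization')
    plt.axis('off')
    plt.tight_layout()
    plt.show()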
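`fc1` hard-codes `32 * 56 * 56` input features, which only holds for 224x224 inputs. The usual output-size rule for convolution and pooling layers (dilation 1), out = floor((in + 2 * padding - kernel) / stride) + 1, lets you check that number. The pure-Python sketch below walks one side of the image through the layers of `SimpleCNN`; the helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d / MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def simple_cnn_flat_features(input_size=224, channels=32):
    """Number of features entering fc1 for a square input of side input_size."""
    s = conv_out(input_size, kernel=3, padding=1)  # conv1: 3x3, padding 1 keeps the size
    s = conv_out(s, kernel=2, stride=2)            # pool: 2x2, stride 2 halves it
    s = conv_out(s, kernel=3, padding=1)           # conv2
    s = conv_out(s, kernel=2, stride=2)            # pool
    return channels * s * s


print(simple_cnn_flat_features(224))  # 100352, i.e. 32 * 56 * 56
print(simple_cnn_flat_features(128))  # 32768, so a different Resize needs a different fc1
```

If you change the `Resize` size, either recompute this value or insert a layer such as `nn.AdaptiveAvgPool2d` before the classifier so the flattened size no longer depends on the input resolution.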

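For reference, the standard Grad-CAM formulation weights each feature map $A^k$ of the target layer by the spatial average of the gradient of the class score $y^c$ with respect to that map, then keeps only the positive part of the weighted sum. The gradients therefore have to be taken with respect to the layer's activations, not the input image, so that there is exactly one weight per channel $k$. For `conv2` in `SimpleCNN` with 224x224 inputs, $K = 32$ and $h = w = 112$, because the hook sees the output before the second pooling.

```latex
\alpha_k^c = \frac{1}{Z}\sum_{i=1}^{h}\sum_{j=1}^{w}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad Z = h\,w,
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left(\sum_{k=1}^{K}\alpha_k^c\,A^k\right)
```

The resulting $h \times w$ map is then upsampled to the input size and min-max scaled to $[0, 1]$ for display.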

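The channel-weighting step can be checked in isolation with NumPy. This sketch assumes you already have one image's activations and gradients for the target layer as arrays of shape (K, h, w); the function name `combine_grad_cam` is hypothetical. It also avoids the division by zero that plain min-max scaling hits when the map is all zeros.

```python
import numpy as np


def combine_grad_cam(activations, gradients):
    """Grad-CAM map from one layer's activations and gradients, each of shape (K, h, w).

    Returns an (h, w) float array scaled to [0, 1].
    """
    if activations.shape != gradients.shape:
        raise ValueError('need one gradient value per activation value')
    weights = gradients.mean(axis=(1, 2))             # alpha_k, shape (K,)
    cam = np.tensordot(weights, activations, axes=1)  # sum_k alpha_k * A^k, shape (h, w)
    cam = np.maximum(cam, 0)                          # ReLU: keep positive evidence only
    cam = cam - cam.min()
    peak = cam.max()
    return cam / peak if peak > 0 else cam            # an all-zero map stays all zero


A = np.array([[[1.0, 0.0], [0.0, 2.0]],
              [[0.0, 3.0], [1.0, 0.0]]])
G = np.array([[[1.0, 1.0], [1.0, 1.0]],       # channel 0: alpha = 1
              [[-1.0, -1.0], [-1.0, -1.0]]])  # channel 1: alpha = -1
print(combine_grad_cam(A, G).tolist())  # [[0.5, 0.0], [0.0, 1.0]]
```

Channel 1 has a negative weight, so it suppresses the locations where it fires; only the regions supported by channel 0 survive the ReLU.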
