当前位置: 首页 > news >正文

打卡第43天

作业: kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化 进阶:并拆分成多个文件

一、数据集准备

  1. Kaggle 选数据集

    • 登录 Kaggle,搜索关键词如 “image classification dataset”(例:CIFAR-10、猫狗分类、MNIST)。
    • 下载数据集(需注册账号,部分需同意协议)。
  2. 数据拆分文件

    • 用 Python 脚本拆分数据集为:
      • train/(训练集,约 80%)
      • val/(验证集,约 10%)
      • test/(测试集,约 10%)
    • 代码:
from sklearn.model_selection import train_test_split
import shutil, os# 假设原始数据在data/目录下,每个类别一个子文件夹
classes = os.listdir('data/')
for cls in classes:img_paths = [os.path.join('data/', cls, f) for f in os.listdir(f'data/{cls}/')]train_val, test = train_test_split(img_paths, test_size=0.1, random_state=42)train, val = train_test_split(train_val, test_size=0.111, random_state=42)  # 10%/90%拆验证集# 创建文件夹os.makedirs(f'train/{cls}', exist_ok=True)os.makedirs(f'val/{cls}', exist_ok=True)os.makedirs(f'test/{cls}', exist_ok=True)# 移动文件for path in train: shutil.move(path, f'train/{cls}/')for path in val: shutil.move(path, f'val/{cls}/')for path in test: shutil.move(path, f'test/{cls}/')

二、CNN 模型训练

  1. 环境搭建

    • 安装库:pip install torch torchvision torchaudio matplotlib grad-cam(PyTorch 版)。
  2. 构建 CNN 模型

    • 示例:

 

import torch
import torch.nn as nn
import torch.optim as optimclass CustomCNN(nn.Module):def __init__(self, num_classes=10):super().__init__()self.conv_layers = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc_layers = nn.Sequential(nn.Linear(32*7*7, 128),nn.ReLU(),nn.Dropout(0.5),nn.Linear(128, num_classes))def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)x = self.fc_layers(x)return xmodel = CustomCNN(num_classes=数据集类别数).to('cuda' if torch.cuda.is_available() else 'cpu')

3.训练流程

  • 定义损失函数、优化器、数据加载器,循环训练:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 10for epoch in range(epochs):model.train()for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 验证集评估...

三、Grad-CAM 可视化

  1. 安装工具

    • pip install grad-cam
  2. 生成热力图

    • 代码:
from grad_cam import GradCAM
from torchvision import transforms
import matplotlib.pyplot as plt# 加载单张测试图像
img_path = 'test/某类别/某图片.jpg'
img = Image.open(img_path).convert('RGB')
transform = transforms.Compose([transforms.Resize((224, 224)),  # 需与模型输入尺寸一致transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(img).unsqueeze(0).to(device)# 指定目标层(通常取最后一个卷积层)
target_layer = model.conv_layers[-1]  # 根据模型结构调整
cam = GradCAM(model=model, target_layer=target_layer)
grayscale_cam = cam(input_tensor=input_tensor)
grayscale_cam = grayscale_cam[0, :]  # 取第一个样本的热力图# 可视化
plt.imshow(img)
plt.imshow(grayscale_cam, cmap='jet', alpha=0.5)
plt.axis('off')
plt.savefig('grad_cam_visualization.jpg')
plt.show()

 四、文件结构建议

项目根目录/
├─ data/                # 原始数据集(Kaggle下载)
├─ train/               # 训练集(拆分后)
├─ val/                 # 验证集(拆分后)
├─ test/                # 测试集(拆分后)
├─ models/              # 保存训练好的模型
├─ visualizations/      # 保存Grad-CAM结果
├─ train.py            # 训练脚本
├─ split_data.py       # 数据拆分脚本
├─ grad_cam.py         # 可视化脚本
└─ requirements.txt    # 依赖库列表(pip freeze > requirements.txt)

进阶

  • 数据拆分工具:也可使用 Kaggle 自带的 “Split Data” 插件或 Python 的split-folders库简化操作。
  • 模型优化:尝试预训练模型(如 ResNet、VGG)迁移学习,提升准确率。
  • 可视化优化:用matplotlib调整热力图透明度、颜色映射,或叠加原始图像增强效果。

 @浙大疏锦行


文章转载自:

http://Nx7HsLGK.tsmcc.cn
http://MtPAZk74.tsmcc.cn
http://CYRKRENT.tsmcc.cn
http://JLrmyx9n.tsmcc.cn
http://LWkbMc6I.tsmcc.cn
http://mN7wzCLO.tsmcc.cn
http://o9qZwxMr.tsmcc.cn
http://4WPVBZRI.tsmcc.cn
http://Ld2eYsay.tsmcc.cn
http://eiVIGZOO.tsmcc.cn
http://44rkioWL.tsmcc.cn
http://KEEjVfEv.tsmcc.cn
http://rAXBmkYz.tsmcc.cn
http://WXQqj9LY.tsmcc.cn
http://X8D1mJPk.tsmcc.cn
http://QDdFsWxr.tsmcc.cn
http://bOSGZgFM.tsmcc.cn
http://PknzCmOA.tsmcc.cn
http://SruGGjWu.tsmcc.cn
http://3vaozHmy.tsmcc.cn
http://UyC4P8Tq.tsmcc.cn
http://QMuMIrWK.tsmcc.cn
http://CYshaKk5.tsmcc.cn
http://nNwRqCam.tsmcc.cn
http://mVuRcHSG.tsmcc.cn
http://xRLZg8z9.tsmcc.cn
http://rfU4G1Wg.tsmcc.cn
http://6kgjUmYn.tsmcc.cn
http://vzqwFBI5.tsmcc.cn
http://uDluudlF.tsmcc.cn
http://www.dtcms.com/a/227248.html

相关文章:

  • 操作系统:文件系统笔记
  • 【笔记】Windows 部署 Suna 开源项目完整流程记录
  • 探索大语言模型(LLM):参数量背后的“黄金公式”与Scaling Law的启示
  • Linux内核体系结构简析
  • 【Doris基础】Apache Doris中的Version概念解析:深入理解数据版本管理机制
  • 【001】利用github搭建静态网站_essay
  • 【MySQL】使用C语言连接数据库
  • 房屋租赁系统 Java+Vue.js+SpringBoot,包括房屋信息、看房申请、租赁合同、房屋报修、收租信息、维修数据、租客管理、公告管理模块
  • 机器学习——集成学习
  • 6.2本日总结
  • Oracle的Hint
  • 【GESP真题解析】第 6 集 GESP 三级 2023 年 9 月编程题 1:小杨的储蓄
  • ThreadLocal ,底层原理,强引用,弱引用,内存泄漏
  • 力扣HOT100之多维动态规划:64. 最小路径和
  • 普通二叉树 —— 最近公共祖先问题解析(Leetcode 236)
  • 力扣第452场周赛
  • BiliNote部署实践
  • docker使用sh脚本创建容器
  • mysql离线安装教程
  • 论文略读:LIMO: Less is More for Reasoning
  • Android Studio 之基础代码解析
  • NVM,Node.Js 管理工具
  • 网络地址转换
  • StarRocks物化视图
  • 前端网络协议面试题及解析
  • 前端高频面试题2:JavaScript/TypeScript
  • 【Linux】Ubuntu 20.04 英文系统显示中文字体异常
  • 【安全】VulnHub靶场 - W1R3S
  • CSP认证准备第四天-BFS(双端BFS/0-1BFS)和DFS
  • gcc编译构建流程-动态链接库