当前位置: 首页 > news >正文

day43 CNN及Grad-CAM实战——以Intel Image Classification为例

步骤一:数据准备

📌 基本信息
名称:Intel Image Classification Dataset

来源:由 Intel 提供,广泛用于图像分类任务入门(原始来源:Kaggle)

任务类型:多类别图像分类(多类场景识别)

图像数量:25,000+ 张图像

图像尺寸:150×150 彩色 JPG 图像

类别数量:6 类

🏷️ 类别(共 6 类)
每一张图片都属于以下六类之一,每类对应一个子文件夹:

类别名称 内容描述
buildings 各种建筑场景(如城市楼宇)
forest 林地、森林风景图
glacier 冰川、雪地场景
mountain 山地或丘陵环境
sea 海洋、水体场景
street 城市街景、人行道、公路等

这些类别大多属于自然或人造景观,非常适合做场景识别模型实验。

📁 数据结构
通常包含以下三个子文件夹(可能因版本不同略有差异):

Intel Image Classification/
├── seg_train/ ← 训练集(每类一个子文件夹)
├── seg_test/ ← 测试集(结构同上)
├── seg_pred/ 或 prediction/ ← 预测集(未标注图像,用于模型测试)

子结构示例(以训练集为例):

seg_train/
├── buildings/
│ ├── img_1.jpg
│ └── …
├── forest/
├── glacier/
├── mountain/
├── sea/
└── street/
📊 图像数量示例(以 Kaggle 数据集为例):
数据集部分
训练集 2,000 张
测试集 500 张
预测集 无标签

from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 路径适配
train_dir = "./data/Intel Image Classification/seg_train"
test_dir = "./data/Intel Image Classification/seg_test"
# 图像预处理
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
])# 加载训练与测试集
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
test_dataset = datasets.ImageFolder(test_dir, transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

CNN训练代码(此处以resnet18为例)

import torch
import torch.nn as nn
from torchvision import modelsdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))
model = model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)# 训练循环
for epoch in range(5):model.train()running_loss = 0for imgs, labels in train_loader:imgs, labels = imgs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(imgs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader):.4f}")
d:\Anaconda\envs\DL\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
d:\Anaconda\envs\DL\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\Administrator/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:04<00:00, 10.8MB/s]Epoch 1, Loss: 0.2814
Epoch 2, Loss: 0.1280
Epoch 3, Loss: 0.0626
Epoch 4, Loss: 0.0495
Epoch 5, Loss: 0.0345

Grad-CAM可视化(可视化 /data/Intel Image Classification/seg_test/buildings 中每个类别的随机任一图片)

import os
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_imagedef show_gradcam_side_by_side_classes(model,folder_path,target_layers,class_names=None,device="cuda" if torch.cuda.is_available() else "cpu"
):model.eval()transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])all_classes = class_names if class_names else os.listdir(folder_path)cam = GradCAM(model=model, target_layers=target_layers)for cls in all_classes:cls_path = os.path.join(folder_path, cls)image_files = [f for f in os.listdir(cls_path) if f.lower().endswith((".jpg", ".jpeg", ".png"))]img_file = random.choice(image_files)img_path = os.path.join(cls_path, img_file)print(f"\n▶️ 类别: {cls} | 图像: {img_file}")image = Image.open(img_path).convert("RGB")input_tensor = transform(image).unsqueeze(0).to(device)output = model(input_tensor)pred_class = output.argmax().item()grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(pred_class)])grayscale_cam = grayscale_cam[0, :]rgb_img = np.array(image.resize((224, 224))) / 255.0visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)# 横向并排显示fig, axes = plt.subplots(1, 2, figsize=(10, 4))axes[0].imshow(rgb_img)axes[0].set_title(f"[原图] 类别: {cls}")axes[0].axis('off')axes[1].imshow(visualization)axes[1].set_title(f"[Grad-CAM] 预测: {class_names[pred_class]}")axes[1].axis('off')plt.tight_layout()plt.show()show_gradcam_side_by_side_classes(model=model,folder_path=r".\data\Intel Image Classification\seg_test",target_layers=[model.layer4[-1]],class_names=train_dataset.classes
)
▶️ 类别: buildings | 图像: 21496.jpg

在这里插入图片描述

▶️ 类别: forest | 图像: 20328.jpg

在这里插入图片描述

▶️ 类别: glacier | 图像: 23448.jpg

在这里插入图片描述

▶️ 类别: mountain | 图像: 22117.jpg

在这里插入图片描述

▶️ 类别: sea | 图像: 20172.jpg

在这里插入图片描述

▶️ 类别: street | 图像: 20572.jpg

在这里插入图片描述

http://www.dtcms.com/a/287249.html

相关文章:

  • JAVA中的Collections 类
  • [论文阅读] 人工智能 + 软件工程 | 强化学习在软件工程中的全景扫描:从应用到未来
  • ABP VNext + Temporal:分布式工作流与 Saga
  • 当OT遇见IT:Apache IoTDB如何用“时序空间一体化“破解工业物联网数据孤岛困局
  • 时序数据库选型实战:Apache IoTDB技术深度解析
  • Bicep入门篇
  • 如何解决pip安装报错ModuleNotFoundError: No module named ‘pillow’问题
  • C/C++---文件读取
  • kotlin部分常用特性总结
  • Node.js net.Socket.destroy()深入解析
  • 海思3516cv610 NPU学习
  • 【C语言进阶】题目练习(3)
  • kafka--基础知识点--6.1--LEO、HW、LW
  • Validation - Spring Boot项目中参数检验的利器
  • web.m3u8流媒体视频处理
  • Flutter基础(前端教程①③-单例)
  • 定时器与间歇函数
  • Web3.0与元宇宙:区块链驱动的数字新生态解析
  • 【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - snowNLP库实现中文情感分析
  • 如何增强LLM(大语言模型)的“置信度”和“自信心” :LLM的“自信”不是“什么都能答”,而是“该答的答得准,不该答的敢说不”。
  • 【unity游戏开发入门到精通——3D篇】3D光源之——unity使用Lens Flare (SRP) 组件实现太阳耀斑镜头光晕效果
  • 《Origin画百图》之多分类矩阵散点图
  • 2025最新版 Go语言Goland 专业安装及配置(超详细)
  • 华为仓颉编程语言语法简介与示例
  • 从0开始学习R语言--Day51--PH检验
  • 操作系统-分布式同步
  • 【REACT18.x】creat-react-app在添加eslint时报错Environment key “jest/globals“ is unknown
  • Spring AI 项目实战(十九):Spring Boot + AI + Vue3 + OSS + DashScope 构建多模态视觉理解平台(附完整源码)
  • 在 .NET Core 中创建 Web Socket API
  • Redis 如何保证高并发与高可用