当前位置: 首页 > news >正文

python打卡day43@浙大疏锦行

作业:

在 Kaggle 上找到一个图像数据集,用 CNN 网络进行训练,并用 Grad-CAM 做可视化

进阶:将代码拆分成多个文件

一、配置文件 (config.py)

import torch


class Config:
    """Static settings shared by the dataset, model and training modules."""

    # --- dataset ---
    # Root folder of the downloaded dataset (placeholder value).
    DATASET_PATH = "/path/to/kaggle/dataset"
    # Images are resized to IMAGE_SIZE x IMAGE_SIZE before batching.
    IMAGE_SIZE = 224
    BATCH_SIZE = 32

    # --- model ---
    NUM_CLASSES = 10
    PRETRAINED = True

    # --- training ---
    EPOCHS = 10
    LR = 0.001

    # Use the GPU whenever PyTorch reports one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"

二、数据加载 (dataset.py)

from torchvision import transforms, datasets
from config import Config


def get_dataloaders():
    """Build the training and test image datasets.

    Despite its name, this function returns two ``ImageFolder`` datasets, not
    ``DataLoader`` objects; wrapping them in loaders is left to the caller.

    Both splits are resized to ``Config.IMAGE_SIZE`` squared, converted to
    tensors and normalised with the same per-channel mean/std. Only the
    training split additionally gets a random horizontal flip.

    Returns:
        A ``(train_set, test_set)`` tuple read from the ``train`` and ``test``
        sub-folders of ``Config.DATASET_PATH``.
    """
    side = Config.IMAGE_SIZE

    def _resize():
        return transforms.Resize((side, side))

    def _to_normalised_tensor():
        # Fresh instances on every call so the two pipelines share no state.
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    train_steps = [
        _resize(),
        transforms.RandomHorizontalFlip(),
        *_to_normalised_tensor(),
    ]
    eval_steps = [_resize(), *_to_normalised_tensor()]

    train_pipeline = transforms.Compose(train_steps)
    eval_pipeline = transforms.Compose(eval_steps)

    train_set = datasets.ImageFolder(
        f"{Config.DATASET_PATH}/train", transform=train_pipeline
    )
    test_set = datasets.ImageFolder(
        f"{Config.DATASET_PATH}/test", transform=eval_pipeline
    )
    return train_set, test_set

三、CNN模型定义 (model.py)

import torch.nn as nn
from torchvision import models
from config import Config


class CNNModel(nn.Module):
    """ResNet-18 backbone whose final layer is resized to ``Config.NUM_CLASSES``."""

    def __init__(self):
        """Create the backbone and swap in a new classification head.

        ImageNet weights are loaded when ``Config.PRETRAINED`` is true.
        """
        super().__init__()
        if hasattr(models, "ResNet18_Weights"):
            # Newer torchvision: ``pretrained=`` is deprecated in favour of
            # ``weights=``. IMAGENET1K_V1 is the checkpoint that
            # ``pretrained=True`` used to select.
            weights = (
                models.ResNet18_Weights.IMAGENET1K_V1
                if Config.PRETRAINED
                else None
            )
            base_model = models.resnet18(weights=weights)
        else:
            # Older torchvision without the weights enum.
            base_model = models.resnet18(pretrained=Config.PRETRAINED)

        # Replace the original fully connected head with one sized for this task.
        num_features = base_model.fc.in_features
        base_model.fc = nn.Linear(num_features, Config.NUM_CLASSES)
        self.model = base_model

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

    def get_feature_maps(self):
        """Return the layer to hook for Grad-CAM.

        Note that this returns the module itself (the second convolution of the
        last block in ``layer4``), not a tensor of feature maps.
        """
        return self.model.layer4[-1].conv2

四、训练脚本 (train.py)

import torch
from torch.utils.data import DataLoader
from dataset import get_dataloaders
from model import CNNModel
from config import Config
from utils import save_checkpoint


def train():
    """Train ``CNNModel`` on the training split and checkpoint every epoch.

    Batch size, device, learning rate and epoch count all come from ``Config``.
    The model is optimised with Adam against a cross-entropy loss.
    """
    train_set, test_set = get_dataloaders()
    train_loader = DataLoader(
        train_set, batch_size=Config.BATCH_SIZE, shuffle=True
    )
    # Prepared for the validation pass, which is still a placeholder below.
    test_loader = DataLoader(test_set, batch_size=Config.BATCH_SIZE)  # noqa: F841

    model = CNNModel().to(Config.DEVICE)
    # This module only imports ``torch`` (there is no ``nn`` alias), so the
    # loss is reached through ``torch.nn`` to avoid a NameError.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LR)

    for epoch in range(Config.EPOCHS):
        model.train()
        for inputs, labels in train_loader:
            inputs = inputs.to(Config.DEVICE)
            labels = labels.to(Config.DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # TODO: validation on ``test_loader`` (elided in the original listing).

        # NOTE(review): assumed to run once per epoch, since it needs ``epoch``;
        # the collapsed source does not show the indentation. Confirm.
        save_checkpoint(model, epoch)


if __name__ == "__main__":
    train()

五、Grad-CAM实现 (gradcam.py)

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt


class GradCAM:
    """Compute a Grad-CAM heatmap for one layer of a model.

    Hooks on ``target_layer`` record its forward output and the gradient
    flowing back into that output. Calling the instance combines them into a
    class-activation map. Intended for a single image (batch size 1).
    """

    def __init__(self, model, target_layer):
        """Attach the forward and backward hooks to ``target_layer``."""
        self.model = model
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached again; otherwise every
        # new GradCAM instance leaves two more hooks on the layer.
        self._handles = [
            target_layer.register_forward_hook(self.save_activations)
        ]
        # ``register_backward_hook`` is deprecated; prefer the full variant
        # and fall back only on PyTorch versions that lack it.
        if hasattr(target_layer, "register_full_backward_hook"):
            backward_handle = target_layer.register_full_backward_hook(
                self.save_gradients
            )
        else:
            backward_handle = target_layer.register_backward_hook(
                self.save_gradients
            )
        self._handles.append(backward_handle)

    def remove_hooks(self):
        """Detach the hooks registered by this instance."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def save_activations(self, module, input, output):
        """Forward hook: remember the layer output, detached from the graph."""
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def __call__(self, x, class_idx=None):
        """Return the Grad-CAM map for ``x`` as a 2-D NumPy array in [0, 1].

        Args:
            x: Input batch; only the first sample is explained. The last two
                dimensions give the size the map is resized to.
            class_idx: Class to explain. Defaults to the predicted class.

        Returns:
            Array of shape ``(x.shape[2], x.shape[3])``. It is all zeros when
            the raw map is constant.
        """
        # Forward pass
        output = self.model(x)
        if class_idx is None:
            class_idx = torch.argmax(output, dim=1)

        # Backward pass from the selected class score only
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot)

        # Channel weights are the spatially averaged gradients
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(self.activations * weights, dim=1)
        cam = torch.relu(cam)

        # Post-processing: resize to the input resolution and rescale to [0, 1]
        cam = cam.squeeze().cpu().numpy()
        cam = cv2.resize(cam, (x.shape[3], x.shape[2]))
        low = np.min(cam)
        span = np.max(cam) - low
        if span > 0:
            cam = (cam - low) / span
        else:
            # A constant map (e.g. all zeros after ReLU) would divide by zero
            # and fill the result with NaN.
            cam = np.zeros_like(cam)
        return cam


def visualize_gradcam(model, image_tensor, original_image):
    """Overlay the Grad-CAM heatmap of ``image_tensor`` on ``original_image``.

    Args:
        model: Network exposing ``get_feature_maps()``, which returns the layer
            to hook.
        image_tensor: One preprocessed image without a batch dimension.
        original_image: Image to draw under the heatmap.
            NOTE(review): assumed to be RGB, float-like in [0, 1], with the
            same height and width as ``image_tensor``. Confirm with callers.
    """
    target_layer = model.get_feature_maps()
    # ``Config`` is not imported in this module, so read the device from the
    # model instead of ``Config.DEVICE``.
    device = next(model.parameters()).device

    # Explain the model in inference mode, then restore the caller's mode.
    was_training = model.training
    model.eval()
    gradcam = GradCAM(model, target_layer)
    try:
        cam = gradcam(image_tensor.unsqueeze(0).to(device))
    finally:
        gradcam.remove_hooks()
        model.train(was_training)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; matplotlib displays RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    superimposed_img = heatmap + np.float32(original_image)
    superimposed_img = superimposed_img / np.max(superimposed_img)

    plt.imshow(superimposed_img)
    plt.axis('off')
    plt.show()

六、工具函数 (utils.py)

import torch
import os


def save_checkpoint(model, epoch, path="checkpoints"):
    """Save the model weights and epoch number to ``<path>/checkpoint_<epoch>.pth``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        epoch: Epoch number; stored in the file and used in its name.
        path: Target directory, created (with parents) if it is missing.
    """
    # ``exist_ok`` avoids the check-then-create race of testing for the
    # directory first.
    os.makedirs(path, exist_ok=True)
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        },
        os.path.join(path, f"checkpoint_{epoch}.pth"),
    )


def load_checkpoint(model, path):
    """Load the weights stored at ``path`` into ``model`` and return it.

    Args:
        model: Module with the same architecture as the saved one.
        path: Checkpoint file written by ``save_checkpoint``.

    Returns:
        The same ``model`` object, with its parameters overwritten.
    """
    # Map tensors to the CPU first so a checkpoint written on a GPU machine
    # still loads on a CPU-only host; ``load_state_dict`` then copies the
    # values onto whatever device ``model`` already lives on.
    # SECURITY: ``torch.load`` unpickles the file. Only load checkpoints from
    # trusted sources, or pass ``weights_only=True`` where PyTorch supports it.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

相关文章:

  • 软件开发项目管理工具选型及禅道开源版安装
  • 从0开始学vue:vue3和vue2的关系
  • 《信号与系统》--期末总结V1.0
  • 【算法训练营Day05】哈希表part1
  • 逐步检索增强推理的跨知识库路由学习
  • Ubuntu22.04 安装 CUDA12.8
  • 类和对象:实现日期类
  • MATLAB 安装与使用详细教程
  • gcc符号表生成机制
  • 【位运算】只出现⼀次的数字 II(medium)
  • 【latex】易遗忘的表达
  • esp32 platformio lvgl_gif的使用和踩坑情况
  • Qt OpenGL 3D 编程入门
  • 2 Studying《Effective STL》
  • 使用ArcPy批量处理矢量数据
  • Linux系统基本操作命令(系统信息查看)
  • MyBatis04:SpringBoot整合MyBatis——多表关联|延迟加载|MyBatisX插件|SQL注解
  • Linux 基础指令入门指南:解锁命令行的实用密码
  • 常见 Web 安全问题
  • MySQL中的锁
  • aspcms做双语网站修改配置/网站推广的常用方法
  • 左侧导航栏网站模板/360推广助手
  • pos机网站模板/抖音seo排名软件
  • 网站正在建设中的图片大全/网站优化排名的方法
  • 网站3d展示怎么做的/推广注册app赚钱平台
  • 微信怎么建小网站/域名查询站长之家