Python check-in Day 43 @浙大疏锦行
Assignment:
Find an image dataset on Kaggle, train a CNN on it, and visualize the results with Grad-CAM.
Advanced: split the code into multiple files (one possible layout is shown below).
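One possible project layout for the six files described in this post (the checkpoints/ directory is created automatically by utils.py during training):

project/
    config.py      # hyperparameters, paths, and device selection
    dataset.py     # transforms and ImageFolder datasets
    model.py       # ResNet-18 based classifier
    train.py       # training loop and checkpointing
    gradcam.py     # Grad-CAM and heatmap visualization
    utils.py       # checkpoint save/load helpers
    checkpoints/   # created during training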
1. Configuration (config.py)
import torch

class Config:
    # Dataset settings
    DATASET_PATH = "/path/to/kaggle/dataset"  # point this at the downloaded Kaggle dataset
    IMAGE_SIZE = 224
    BATCH_SIZE = 32

    # Model settings
    NUM_CLASSES = 10
    PRETRAINED = True

    # Training settings
    EPOCHS = 10
    LR = 0.001
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
2. Data loading (dataset.py)
from torchvision import transforms, datasets
from config import Config

def get_dataloaders():
    # Returns the train/test ImageFolder datasets; DataLoaders are built in train.py
    # Normalization uses the ImageNet statistics expected by the pretrained backbone
    train_transform = transforms.Compose([
        transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_transform = transforms.Compose([
        transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_set = datasets.ImageFolder(f"{Config.DATASET_PATH}/train", transform=train_transform)
    test_set = datasets.ImageFolder(f"{Config.DATASET_PATH}/test", transform=test_transform)
    return train_set, test_set
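Note that datasets.ImageFolder infers class labels from subdirectory names, so the Kaggle dataset needs one folder per class under each split (the class names here are placeholders):

<dataset root>/
    train/
        class_a/img001.jpg
        class_b/img002.jpg
    test/
        class_a/img101.jpg
        class_b/img102.jpg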
3. CNN model definition (model.py)
import torch.nn as nn
from torchvision import models
from config import Config

class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ResNet-18 backbone; newer torchvision prefers the `weights=` argument over `pretrained=`
        base_model = models.resnet18(pretrained=Config.PRETRAINED)
        num_features = base_model.fc.in_features
        # Replace the final fully connected layer to match the number of classes
        base_model.fc = nn.Linear(num_features, Config.NUM_CLASSES)
        self.model = base_model

    def forward(self, x):
        return self.model(x)

    def get_feature_maps(self):
        # Returns the last convolutional layer of layer4, used as the Grad-CAM target layer
        return self.model.layer4[-1].conv2
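A quick sanity check (my own sketch, not one of the assignment files) to confirm the wrapped ResNet-18 produces logits of shape (batch, NUM_CLASSES):

import torch
from config import Config
from model import CNNModel

model = CNNModel().to(Config.DEVICE)
dummy = torch.randn(2, 3, Config.IMAGE_SIZE, Config.IMAGE_SIZE).to(Config.DEVICE)
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([2, 10]) when NUM_CLASSES = 10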
4. Training script (train.py)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dataset import get_dataloaders
from model import CNNModel
from config import Config
from utils import save_checkpoint

def train():
    train_set, test_set = get_dataloaders()
    train_loader = DataLoader(train_set, batch_size=Config.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=Config.BATCH_SIZE)

    model = CNNModel().to(Config.DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LR)

    for epoch in range(Config.EPOCHS):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(Config.DEVICE), labels.to(Config.DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation code...
        save_checkpoint(model, epoch)

if __name__ == "__main__":
    train()
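The validation step above is left as a placeholder; a minimal sketch of what it could look like, computing top-1 accuracy on the test set (the evaluate function is my own addition and would be called on test_loader just before save_checkpoint):

def evaluate(model, test_loader):
    # Top-1 accuracy on the test set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(Config.DEVICE), labels.to(Config.DEVICE)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total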
5. Grad-CAM implementation (gradcam.py)
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

from config import Config

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        # Hooks capture the target layer's activations and gradients
        target_layer.register_forward_hook(self.save_activations)
        target_layer.register_full_backward_hook(self.save_gradients)

    def save_activations(self, module, input, output):
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def __call__(self, x, class_idx=None):
        # Forward pass
        output = self.model(x)
        if class_idx is None:
            class_idx = torch.argmax(output, dim=1)

        # Backward pass for the chosen class
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot)

        # Channel weights: global-average-pool the gradients
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(self.activations * weights, dim=1)
        cam = torch.relu(cam)

        # Post-processing: resize to the input size and normalize to [0, 1]
        cam = cam.squeeze().cpu().numpy()
        cam = cv2.resize(cam, (x.shape[3], x.shape[2]))
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        return cam

def visualize_gradcam(model, image_tensor, original_image):
    target_layer = model.get_feature_maps()
    gradcam = GradCAM(model, target_layer)
    cam = gradcam(image_tensor.unsqueeze(0).to(Config.DEVICE))

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV returns BGR; matplotlib expects RGB
    heatmap = np.float32(heatmap) / 255

    superimposed_img = heatmap + np.float32(original_image)
    superimposed_img = superimposed_img / np.max(superimposed_img)

    plt.imshow(superimposed_img)
    plt.axis('off')
    plt.show()
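Putting the pieces together, a hedged sketch of a driver script (the file name main.py, the checkpoint name, and sample.jpg are assumptions for illustration) that loads a trained checkpoint, preprocesses one image, and calls visualize_gradcam:

# main.py
import cv2
import numpy as np
from torchvision import transforms

from config import Config
from model import CNNModel
from utils import load_checkpoint
from gradcam import visualize_gradcam

# Load a trained model (checkpoint path is an example)
model = CNNModel().to(Config.DEVICE)
model = load_checkpoint(model, "checkpoints/checkpoint_9.pth")
model.eval()

# Read one image and convert it to an RGB float image in [0, 1] for the overlay (path is an example)
bgr = cv2.imread("sample.jpg")
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
rgb = cv2.resize(rgb, (Config.IMAGE_SIZE, Config.IMAGE_SIZE))
original_image = np.float32(rgb) / 255.0

# Apply the same normalization used for the test set
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image_tensor = preprocess(rgb)

visualize_gradcam(model, image_tensor, original_image)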
6. Utility functions (utils.py)
import torch
import os

def save_checkpoint(model, epoch, path="checkpoints"):
    # Save the model weights for the given epoch under the checkpoints directory
    if not os.path.exists(path):
        os.makedirs(path)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
    }, f"{path}/checkpoint_{epoch}.pth")

def load_checkpoint(model, path):
    # Restore model weights from a saved checkpoint
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model