当前位置: 首页 > news >正文

Python打卡第51天

@浙大疏锦行

作业:

day43的时候我们安排大家对自己找的数据集用简单cnn训练,现在可以尝试下借助这几天的知识来实现精度的进一步提高

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import torch.nn.functional as F
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import random# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 数据集路径
data_dir = r"D:\archive (1)\MY_data"# 数据预处理和增强
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomRotation(10),transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_transform)# 创建数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)# 获取类别名称
classes = train_dataset.classes
print(f"类别: {classes}")# CBAM注意力机制实现
class ChannelAttention(nn.Module):def __init__(self, in_channels, reduction_ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),nn.ReLU(),nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x_cat = torch.cat([avg_out, max_out], dim=1)out = self.conv(x_cat)return self.sigmoid(out)class CBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):super(CBAM, self).__init__()self.channel_attention = ChannelAttention(in_channels, reduction_ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):x = x * self.channel_attention(x)x = x * self.spatial_attention(x)return x# 定义改进的CNN模型(支持多种预训练模型和CBAM注意力机制)
class EnhancedFruitClassifier(nn.Module):def __init__(self, num_classes=10, model_name='resnet18', use_cbam=True):super(EnhancedFruitClassifier, self).__init__()self.use_cbam = use_cbam# 根据选择加载不同的预训练模型if model_name == 'resnet18':self.model = models.resnet18(pretrained=True)in_features = self.model.fc.in_features# 保存原始层以便后续使用self.features = nn.Sequential(*list(self.model.children())[:-2])self.avgpool = self.model.avgpoolelif model_name == 'resnet50':self.model = models.resnet50(pretrained=True)in_features = self.model.fc.in_featuresself.features = nn.Sequential(*list(self.model.children())[:-2])self.avgpool = self.model.avgpoolelif model_name == 'efficientnet_b0':self.model = models.efficientnet_b0(pretrained=True)in_features = self.model.classifier[1].in_featuresself.features = nn.Sequential(*list(self.model.children())[:-1])self.avgpool = nn.AdaptiveAvgPool2d(1)else:raise ValueError(f"不支持的模型: {model_name}")# 冻结大部分预训练层for param in list(self.model.parameters())[:-5]:param.requires_grad = False# 添加CBAM注意力机制if use_cbam:self.cbam = CBAM(in_features)# 修改最后一层以适应我们的分类任务self.fc = nn.Linear(in_features, num_classes)def forward(self, x):# 特征提取x = self.features(x)# 应用CBAM注意力机制if self.use_cbam:x = self.cbam(x)# 全局池化x = self.avgpool(x)x = torch.flatten(x, 1)# 分类x = self.fc(x)return x# 初始化模型 - 可以选择不同的预训练模型和是否使用CBAM
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = EnhancedFruitClassifier(num_classes=len(classes),model_name='resnet18',  # 可选: 'resnet18', 'resnet50', 'efficientnet_b0'use_cbam=True
).to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)# 训练模型
def train_model(model, train_loader, criterion, optimizer, scheduler, device, epochs=10):model.train()for epoch in range(epochs):running_loss = 0.0correct = 0total = 0progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))for i, (inputs, labels) in progress_bar:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()progress_bar.set_description(f"Epoch {epoch+1}/{epochs}, "f"Loss: {running_loss/(i+1):.4f}, "f"Acc: {100.*correct/total:.2f}%")scheduler.step()print(f"Epoch {epoch+1}/{epochs}, "f"Train Loss: {running_loss/len(train_loader):.4f}, "f"Train Acc: {100.*correct/total:.2f}%")return model# 评估模型
def evaluate_model(model, test_loader, device):model.eval()correct = 0total = 0class_correct = list(0. for i in range(len(classes)))class_total = list(0. for i in range(len(classes)))with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()# 计算每个类别的准确率for i in range(len(labels)):label = labels[i]class_correct[label] += (predicted[i] == label).item()class_total[label] += 1print(f"测试集整体准确率: {100.*correct/total:.2f}%")# 打印每个类别的准确率for i in range(len(classes)):if class_total[i] > 0:print(f"{classes[i]} 类别的准确率: {100.*class_correct[i]/class_total[i]:.2f}%")else:print(f"{classes[i]} 类别的样本数为0")return 100.*correct/total# Grad-CAM实现
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注册钩子self.hook_handles = []# 保存梯度的钩子def backward_hook(module, grad_in, grad_out):self.gradients = grad_out[0]return None# 保存激活值的钩子def forward_hook(module, input, output):self.activations = outputreturn Noneself.hook_handles.append(target_layer.register_forward_hook(forward_hook))self.hook_handles.append(target_layer.register_backward_hook(backward_hook))def __call__(self, x, class_idx=None):# 前向传播model_output = self.model(x)if class_idx is None:class_idx = torch.argmax(model_output, dim=1)# 构建one-hot向量one_hot = torch.zeros_like(model_output)one_hot[0, class_idx] = 1# 反向传播self.model.zero_grad()model_output.backward(gradient=one_hot, retain_graph=True)# 计算权重(全局平均池化梯度)weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)# 加权组合激活映射cam = torch.sum(weights * self.activations, dim=1).squeeze()# ReLU激活,因为我们只关心对类别有正贡献的区域cam = F.relu(cam)# 归一化if torch.max(cam) > 0:cam = cam / torch.max(cam)# 调整大小到输入图像尺寸cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False).squeeze()return cam.detach().cpu().numpy(), class_idx.item()def remove_hooks(self):for handle in self.hook_handles:handle.remove()# 可视化Grad-CAM结果
def visualize_gradcam(img_path, model, target_layer, classes, device):# 加载并预处理图像img = Image.open(img_path).convert('RGB')img_tensor = test_transform(img).unsqueeze(0).to(device)# 初始化Grad-CAMgrad_cam = GradCAM(model, target_layer)# 获取Grad-CAM热力图cam, pred_class = grad_cam(img_tensor)# 反归一化图像以便显示img_np = img_tensor.squeeze().cpu().numpy().transpose((1, 2, 0))img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])img_np = np.clip(img_np, 0, 1)# 调整热力图大小heatmap = cv2.resize(cam, (img_np.shape[1], img_np.shape[0]))# 创建彩色热力图heatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)heatmap = np.float32(heatmap) / 255# 叠加原始图像和热力图superimposed_img = heatmap * 0.4 + img_npsuperimposed_img = np.clip(superimposed_img, 0, 1)# 显示结果plt.figure(figsize=(15, 5))plt.subplot(131)plt.imshow(img_np)plt.title('原始图像')plt.axis('off')plt.subplot(132)plt.imshow(cam, cmap='jet')plt.title('Grad-CAM热力图')plt.axis('off')plt.subplot(133)plt.imshow(superimposed_img)plt.title(f'叠加图像\n预测类别: {classes[pred_class]}')plt.axis('off')plt.tight_layout()plt.show()# 预测函数
def predict_image(img_path, model, classes, device):# 加载并预处理图像img = Image.open(img_path).convert('RGB')img_tensor = test_transform(img).unsqueeze(0).to(device)# 预测model.eval()with torch.no_grad():outputs = model(img_tensor)probs = F.softmax(outputs, dim=1)top_probs, top_classes = probs.topk(5, dim=1)# 打印预测结果print(f"图像: {os.path.basename(img_path)}")print("预测结果:")for i in range(top_probs.size(1)):print(f"{classes[top_classes[0, i]]}: {top_probs[0, i].item() * 100:.2f}%")return top_classes[0, 0].item()# 主函数
def main():# 训练模型print("开始训练模型...")trained_model = train_model(model, train_loader, criterion, optimizer, scheduler, device, epochs=5)# 评估模型print("\n评估模型...")evaluate_model(trained_model, test_loader, device)# 保存模型model_path = "fruit_classifier.pth"torch.save(trained_model.state_dict(), model_path)print(f"\n模型已保存至: {model_path}")# 可视化Grad-CAM结果print("\n可视化Grad-CAM结果...")# 从测试集中随机选择几张图像进行可视化predict_dir = os.path.join(data_dir, 'predict')if os.path.exists(predict_dir):# 使用predict目录中的图像image_files = [os.path.join(predict_dir, f) for f in os.listdir(predict_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]if len(image_files) > 0:# 随机选择2张图像sample_images = random.sample(image_files, min(2, len(image_files)))for img_path in sample_images:print(f"\n处理图像: {img_path}")# 预测图像类别pred_class = predict_image(img_path, trained_model, classes, device)# 可视化Grad-CAMif hasattr(trained_model, 'model') and hasattr(trained_model.model, 'layer4'):# 对于ResNet系列模型visualize_gradcam(img_path, trained_model, trained_model.model.layer4[-1].conv2, classes, device)else:# 对于其他模型,使用最后一个特征层visualize_gradcam(img_path, trained_model, list(trained_model.features.children())[-1], classes, device)else:print(f"predict目录为空,无法进行可视化")else:print(f"predict目录不存在,无法进行可视化")if __name__ == "__main__":main()
类别: ['Apple', 'Banana', 'avocado', 'cherry', 'kiwi', 'mango', 'orange', 'pinenapple', 'strawberries', 'watermelon']
开始训练模型...
Epoch 1/5, Loss: 0.8748, Acc: 74.23%: 100%|██████████| 72/72 [00:08<00:00,  8.66it/s]
Epoch 1/5, Train Loss: 0.8748, Train Acc: 74.23%
Epoch 2/5, Loss: 0.4802, Acc: 83.83%: 100%|██████████| 72/72 [00:07<00:00, 10.02it/s]
Epoch 2/5, Train Loss: 0.4802, Train Acc: 83.83%
Epoch 3/5, Loss: 0.4239, Acc: 86.35%: 100%|██████████| 72/72 [00:07<00:00,  9.69it/s]
Epoch 3/5, Train Loss: 0.4239, Train Acc: 86.35%
Epoch 4/5, Loss: 0.4179, Acc: 85.96%: 100%|██████████| 72/72 [00:07<00:00,  9.64it/s]
Epoch 4/5, Train Loss: 0.4179, Train Acc: 85.96%
Epoch 5/5, Loss: 0.3747, Acc: 87.44%: 100%|██████████| 72/72 [00:07<00:00,  9.68it/s]
Epoch 5/5, Train Loss: 0.3747, Train Acc: 87.44%评估模型...
测试集整体准确率: 66.83%
Apple 类别的准确率: 80.90%
Banana 类别的准确率: 0.00%
avocado 类别的准确率: 1.89%
cherry 类别的准确率: 93.33%
kiwi 类别的准确率: 93.33%
mango 类别的准确率: 48.57%
orange 类别的准确率: 97.94%
pinenapple 类别的准确率: 96.19%
strawberries 类别的准确率: 90.29%
watermelon 类别的准确率: 71.43%模型已保存至: fruit_classifier.pth可视化Grad-CAM结果...处理图像: D:\archive (1)\MY_data\predict\img_341.jpeg
图像: img_341.jpeg
预测结果:
mango: 90.52%
orange: 3.99%
kiwi: 2.45%
avocado: 1.86%
Apple: 0.98%

处理图像: D:\archive (1)\MY_data\predict\1.jpeg
图像: 1.jpeg
预测结果:
Apple: 95.86%
cherry: 2.94%
Banana: 0.71%
avocado: 0.24%
strawberries: 0.19%

相关文章:

  • 文献管理软件EndNote下载与安装教程(详细教程)2025最新版详细图文安装教程
  • MySQL查看连接情况
  • 力扣-347.前K个高频元素
  • (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为使用深度强化学习和模仿学习训练智能代理的环境
  • 建造者模式(Builder Pattern)
  • Go 通道(Channel)入门与基础使用
  • ZZU-ARM汇编语言实验2
  • 41页PPT | 基于AI制造企业解决方案架构设计智能制造AI人工智能应用智能质检人工智能质检建设
  • 在C# 中使用建造者模式
  • Spring cloud-k8s容器化部署
  • 同步与异步:软件工程中的时空艺术与实践智慧-以蜻蜓hr人才系统举例-优雅草卓伊凡
  • 记录rust滥用lazy_static导致的一个bug
  • 论文笔记 - 《Implementing block-sparse matrix multiplication kernels using Triton》
  • Linux【7】------Linux系统编程(进程间通信IPC)
  • docker-compose和docker下载
  • mysql DQL(javaweb第七天)
  • 博客:基本框架设计(下)
  • 搭建第一个 Vite 项目
  • 【读论文】DiffPhyCon 扩散物理系统控制
  • 【Django】性能优化-普通版
  • 东莞wordpress/seo在线论坛
  • 登录注册网站怎么做/网推项目平台
  • wordpress 本地头像/小程序seo
  • 做网站怎么做放大图片/网站优化是什么
  • 做企业网站软件/2021年度关键词有哪些
  • 自己架设网站服务器/2021年新闻摘抄