当前位置: 首页 > news >正文

打卡Day51

day43的时候我们安排大家对自己找的数据集用简单cnn训练,现在可以尝试下借助这几天的知识来实现精度的进一步提高

Kaggle图像分类与可视化方案

一、数据准备(修改 src/data/preprocessing.py )

# ... existing code ...
def create_dataloader(data_path, batch_size=32):transform = transforms.Compose([transforms.Resize((256, 256)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),  # 新增数据增强transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])dataset = datasets.ImageFolder(data_path, transform=transform)return DataLoader(dataset, batch_size=batch_size, shuffle=True)

二、高效CNN模型(修改 src/models/train.py )

from efficientnet_pytorch import EfficientNet  # 需安装 pip install efficientnet-pytorchclass CustomEfficientNet(nn.Module):def __init__(self, num_classes):super().__init__()self.base = EfficientNet.from_pretrained('efficientnet-b3')self.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(1536, num_classes))def forward(self, x):features = self.base.extract_features(x)return self.classifier(nn.functional.adaptive_avg_pool2d(features, 1).squeeze())

三、Grad-CAM可视化(新增 src/visualization/gradcam.py )

class GradCAM:def __init__(self, model, target_layer):self.model = model.eval()self.target_layer = target_layerself.activations = []self.gradients = []target_layer.register_forward_hook(self.save_activation)target_layer.register_backward_hook(self.save_gradient)def save_activation(self, module, input, output):self.activations.append(output)def save_gradient(self, module, grad_input, grad_output):self.gradients.append(grad_output[0])def generate(self, input_img, class_idx=None):# ... 完整实现见项目文件 ...
http://www.dtcms.com/a/244380.html

相关文章:

  • force命令的使用
  • 青藏高原地区多源融合降水数据(1998-2017)
  • 【Unity3D优化】优化多语言字体包大小
  • NuGet 从入门到精进全解析
  • Transformers KV Caching 图解
  • h5fortran 简介与使用指南
  • vue前端面试题——记录一次面试当中遇到的题(1)
  • 冒险岛的魔法果实-多重背包
  • 关于有害的过度使用 std::move
  • SCADA|测试KingSCADA4.0信创版采集汇川PLC AC810数据
  • python学习打卡day50
  • A. Dr. TC
  • RPG24.设置武器伤害(二):将效果应用于目标
  • RabbitMQ可靠和延迟队列
  • 接收rabbitmq消息
  • 中心化交易所(CEX)架构:高并发撮合引擎与合规安全体系
  • [蓝桥杯 2024 国 Python B] 设计
  • TripGenie:畅游济南旅行规划助手:个人工作纪实(二十四)
  • Arduino入门教程:1、Arduino硬件介绍
  • LAN、WAN、WLAN、VLAN 、VPN对比
  • Java异步编程深度解析:从基础到复杂场景的难题拆解
  • 动态多目标进化算法:VARE(Vector Autoregressive Evolution)求解DF1-DF14,提供完整MATLAB代码
  • [服务器] Amazon Lightsail SSH连接黑屏的常见原因及解决方案
  • 曼昆《经济学原理》第九版 第十七章寡头垄断
  • 【leetcode】36. 有效的数独
  • 【Axure高保真原型】中继器表格更多操作
  • API:解锁数字化协作的钥匙及开放实现路径深度剖析
  • 产品升级 | 新一代高性能数据采集平台BRICK2 X11,助力ADAS与自动驾驶开发
  • 【AI】模型vs算法(以自动驾驶为例)
  • RPA与Agent技术如何结合,以实现跨系统、跨平台的工作流程自动化?