当前位置: 首页 > news >正文

从代码学习深度学习 - 含并行连结的网络(GoogLeNet)PyTorch版

文章目录

  • 前言
  • 一、GoogLeNet的理论基础
    • 1.1 背景与创新点
    • 1.2. Inception模块的工作原理
  • 二、完整代码实现与解析
    • 2.1. 环境准备与工具函数
    • 2.2. 数据加载 - Fashion-MNIST
    • 2.3. Inception模块设计
    • 2.4. GoogLeNet完整模型
    • 2.5. 训练函数
    • 2.6. 运行训练
  • 三、训练结果与分析
    • 3.1. 性能分析
    • 3.2. 可视化结果
    • 3.3. 模型局限性
  • 四、扩展与改进建议
  • 总结


前言

深度学习近年来在计算机视觉、自然语言处理等领域取得了巨大成功,而卷积神经网络(CNN)作为其核心支柱之一,推动了许多突破性应用。GoogLeNet(Inception v1)是2014年ImageNet挑战赛(ILSVRC)的冠军模型,以其创新的Inception模块和高效设计脱颖而出。它不仅在性能上超越了当时的经典模型(如AlexNet和VGG),还在参数量和计算复杂度上实现了优化。

本文将通过PyTorch实现一个简化的GoogLeNet版本,并结合完整的代码和详细解析,帮助读者从实践中掌握深度学习的核心概念。我们将从GoogLeNet的理论基础讲起,逐步剖析代码实现,最后在Fashion-MNIST数据集上进行训练和结果分析。这不仅是一次从理论到实践的学习之旅,也是一个理解现代CNN设计思想的机会。


一、GoogLeNet的理论基础

1.1 背景与创新点

在GoogLeNet之前,CNN模型(如AlexNet、VGG)倾向于通过加深网络层数或增加卷积核大小来提升性能,但这往往导致参数量激增和计算资源浪费。GoogLeNet提出了一个全新的思路:通过多尺度特征提取计算效率优化来提升性能。其核心创新是Inception模块,它通过并行使用不同大小的卷积核(1x1、3x3、5x5)和池化操作,在单一层内捕获多种尺度的特征。

此外,GoogLeNet引入了以下关键技术:

  • 1x1卷积:用于降维,减少通道数,从而降低计算量。
  • 全局平均池化:替换传统全连接层,减少参数量并增强泛化能力。
  • 辅助分类器:在网络中间添加分支,增强梯度传播,缓解深层网络训练中的梯度消失问题。

原始GoogLeNet有22层,但参数量仅为VGG-19的1/12,展现了其高效性。在本文中,我们将实现一个简化版,如下图所示,专注于Inception模块,并省略辅助分类器,以适应Fashion-MNIST数据集的较小规模。
在这里插入图片描述

1.2. Inception模块的工作原理

Inception模块的核心思想是“多路径并行”,如下图所示。它通过以下四条路径提取特征:
在这里插入图片描述

  1. 1x1卷积:直接提取局部特征,降低计算成本。
  2. 1x1卷积 + 3x3卷积:先降维再进行中等尺度卷积。
  3. 1x1卷积 + 5x5卷积:先降维再捕获更大范围特征。
  4. 3x3最大池化 + 1x1卷积:保留空间信息并调整通道数。

这些路径的输出最后沿通道维度拼接,形成一个更丰富的特征表示。这种设计既增加了网络宽度(width),又避免了单纯加深网络带来的过拟合风险。

二、完整代码实现与解析

以下是完整的PyTorch代码实现,包括所有工具函数、数据加载、模型定义和训练逻辑。

2.1. 环境准备与工具函数

我们首先定义一些工具类,用于计时、累加指标和计算准确率。这些工具在深度学习实验中非常实用。

import time
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import multiprocessing

class Timer:
    """记录多次运行时间"""
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        self.tik = time.time()

    def stop(self):
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        return sum(self.times) / len(self.times)

    def sum(self):
        return sum(self.times)

    def cumsum(self):
        return np.array(self.times).cumsum().tolist()

class Accumulator:
    """在 n 个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.to(y.dtype) == y
    return float(cmp.to(y.dtype).sum())

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device
    metric = Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            metric.add(accuracy(net(X), y), y.numel())
    
http://www.dtcms.com/a/98192.html

相关文章:

  • 淘宝双十一大促监控系统开发:实时追踪爆品数据与流量波动
  • 谷粒微服务高级篇学习笔记整理---异步线程池
  • SQL Server数据库引擎服务启动失败:端口冲突
  • 电源系统的热设计与热管理--以反激式充电器为例
  • 1688 店铺清单及全商品数据、关键词检索 API 介绍
  • 【蓝桥杯】每日练习 Day15
  • 【自用记录】本地关联GitHub以及遇到的问题
  • 从代码学习深度学习 - 使用块的网络(VGG)PyTorch版
  • 谈谈你对多态的理解
  • coding ability 展开第七幕(前缀和算法——进阶巩固)超详细!!!!
  • 算法基础——二叉树
  • Java 程序员面试题:从基础到高阶的深度解析
  • Elasticsearch 完全指南
  • 【HarmonyOS 5】初学者如何高效的学习鸿蒙?
  • Bitnode和Bitree有什么区别 为什么Bitree前多了*
  • 缴纳过路费--并查集+优先队列
  • Qt进阶开发:Graphics View图形视图框架
  • QT 跨平台发布指南
  • 枚举算法-day2
  • python 列表-元组-集合-字典
  • 软件工程之软件开发模型(瀑布、迭代、敏捷、DevOps)
  • 综述速读|086.04.24.Retrieval-Augmented Generation for AI-Generated Content A Survey
  • 深度学习处理时间序列(6)
  • 自学-python-基础-注释、数据类型、运算符、判断、循环
  • 树莓派超全系列文档--(13)如何使用raspi-config工具其二
  • 中断管理常用API详解(三)
  • flatMap 介绍及作用
  • C#连接sqlite数据库实现增删改查
  • 大模型最新面试题系列:微调篇之微调框架(二)
  • AI赋能python数据处理、分析与预测操作流程