深度学习---PyTorch 神经网络工具箱
一、神经网络核心组件
神经网络的运行依赖四大核心组件,各组件功能明确、协同工作,具体如下表所示:
组件 | 功能描述 |
---|---|
层 | 神经网络的基础结构单元,核心作用是对输入张量进行数据变换(如通过权重运算),最终输出新张量 |
模型 | 由多个 “层” 按特定逻辑组合而成的完整网络结构,是实现分类、回归等任务的核心载体 |
损失函数 | 参数学习的目标函数,用于计算模型预测值(Y')与真实值(Y)的差异,模型通过最小化该函数优化参数 |
优化器 | 负责执行 “最小化损失函数” 的算法,通过调整模型中的权重等参数,使损失值逐步降低,提升模型性能 |
二、构建神经网络的主要工具
PyTorch 中构建网络的核心工具为nn.Module
和nn.functional
,二者适用场景与使用方式差异显著,具体对比及说明如下:
(一)核心工具对比
对比维度 | nn.Module | nn.functional |
---|---|---|
本质特性 | 面向对象,需继承该类定义网络模块 | 纯函数式接口,直接调用函数实现功能 |
参数管理 | 自动提取、定义和管理可学习参数(如 weight、bias),无需手动干预 | 需手动定义 weight、bias 等参数,且每次调用需手动传入,代码复用性差 |
适用场景 | 卷积层(nn.Conv2d )、全连接层(nn.Linear )、Dropout 层等需参数学习的组件 | 激活函数(如nn.functional.relu )、池化层(如nn.functional.max_pool2d )等无参数学习的操作 |
语法格式 | 写法为nn.Xxx ,需先实例化(如self.linear = nn.Linear(in_dim, out_dim) ),再调用实例处理数据 | 写法为nn.functional.xxx ,直接调用函数(如x = nn.functional.relu(x) ) |
状态转换 | Dropout 等需区分 “训练 / 测试” 状态的组件,调用model.eval() 可自动切换状态 | 无自动状态转换功能,需手动控制(如额外传入training 参数) |
容器兼容性 | 可与nn.Sequential 等模型容器结合使用,便于层的顺序组合 | 无法与nn.Sequential 结合,需手动按顺序调用函数 |
(二)nn.Module
的核心作用
nn.Module
是构建网络的基础类,其核心优势在于自动管理可学习参数,并支持与多种工具协同完成网络构建与训练,具体流程关联如下:
- 网络构建:通过继承
nn.Module
定义自定义模型类,在__init__
方法中定义网络层(如Linear
、Conv2d
); - 正向传播:需在模型类中实现
forward
方法,定义数据在层间的流动逻辑(即正向传播过程); - 反向传播:依托
torch.autograd
自动计算梯度(无需手动实现反向传播逻辑); - 参数更新:结合
torch.optim
(如SGD
、Adam
)的step()
方法,更新通过nn.Module
管理的参数; - 容器支持:可与
nn.Sequential
、nn.ModuleList
、nn.ModuleDict
等容器结合,灵活组织复杂网络结构。
三、模型构建的三种核心方法
PyTorch 提供三种主流模型构建方式,可根据网络复杂度和灵活性需求选择:
构建方式 | 核心逻辑 | 适用场景 |
---|---|---|
继承nn.Module 基类构建 | 自定义模型类,手动在__init__ 定义层、在forward 定义数据流动 | 网络结构复杂(如含多分支、自定义运算),需灵活控制正向传播逻辑 |
使用nn.Sequential 按层顺序构建 | 将层按执行顺序传入nn.Sequential ,无需手动实现forward | 网络为 “线性结构”(无分支、无复杂运算),如简单的全连接网络、基础卷积网络 |
继承nn.Module + 模型容器封装 | 在自定义nn.Module 类中,用nn.Sequential /nn.ModuleList /nn.ModuleDict 封装部分层 | 网络可拆分为多个子模块(如特征提取块、分类块),需平衡灵活性与代码简洁性 |
关键说明
nn.Sequential
:按顺序组合层,支持 “可变参数传入”“add_module
方法添加层”“OrderedDict
指定层名称” 三种使用方式;- 模型容器(
nn.Sequential
/nn.ModuleList
/nn.ModuleDict
)的核心作用是简化层的管理与调用,避免代码冗余,尤其适用于多层重复或动态层结构的场景。