当前位置: 首页 > news >正文

PyTorch 神经网络工具箱核心知识点总结

一、神经网络核心组件

PyTorch 构建和训练神经网络的基础依赖四大核心组件,各组件职责明确、协同工作:

组件核心作用
神经网络的基本结构单元,负责将输入张量通过参数化变换(如线性变换、卷积)转换为输出张量
模型由多层按特定逻辑组合而成的网络整体,实现从输入到预测输出的端到端映射
损失函数定义模型预测值与真实值的 “误差”,是参数优化的目标函数(如交叉熵损失、MSE)
优化器依据损失函数的梯度,调整模型参数以最小化损失(如 SGD、Adam)

组件工作流:输入数据经 “层→模型” 得到预测值,损失函数计算预测误差,优化器根据误差梯度更新模型参数,形成闭环训练。

二、构建神经网络的核心工具

PyTorch 提供两类核心工具用于网络构建:nn.Module 和 nn.functional,二者功能互补但使用场景不同。

1. 工具对比与特点

对比维度nn.Modulenn.functional
本质面向对象的类(继承自nn.Module纯函数(无状态)
参数管理自动定义、管理weight/bias等可学习参数需手动定义、传入参数,无自动管理
常用场景卷积层(nn.Conv2d)、全连接层(nn.Linear)、Dropout 层(nn.Dropout激活函数(F.relu)、池化层(F.max_pool2d)、损失计算(F.cross_entropy
nn.Sequential兼容支持(可直接作为nn.Sequential的元素)不支持(无法嵌入nn.Sequential
状态转换(如 Dropout)调用model.eval()后自动切换训练 / 测试状态需手动传入training=True/False控制状态

2. 工具使用示例

  • nn.Module 写法:需先实例化类并传入参数,再调用实例处理数据

    python

    运行

    linear = nn.Linear(in_features=784, out_features=300)  # 实例化全连接层
    x = linear(torch.randn(1, 784))  # 传入数据计算输出
    
  • nn.functional 写法:直接调用函数,需手动传入参数

    python

    运行

    weight = torch.randn(300, 784)  # 手动定义权重
    bias = torch.randn(300)         # 手动定义偏置
    x = F.linear(torch.randn(1, 784), weight, bias)  # 手动传入参数
    

三、PyTorch 模型构建的三种核心方法

PyTorch 支持多种模型构建方式,可根据网络复杂度和灵活性需求选择。

方法 1:直接继承nn.Module基类

适用场景:网络结构复杂、需自定义前向传播逻辑(如含分支、残差连接)。核心步骤

  1. __init__方法中定义网络层(如nn.Linearnn.Flatten);
  2. 重写forward方法,定义数据在层间的传播路径。

示例代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Model_Seq(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Model_Seq, self).__init__()# 1. 定义网络层self.flatten = nn.Flatten()  # 展平层(28×28→784)self.linear1 = nn.Linear(in_dim, n_hidden_1)  # 全连接层1self.bn1 = nn.BatchNorm1d(n_hidden_1)  # 批量归一化层1self.linear2 = nn.Linear(n_hidden_1, n_hidden_2)  # 全连接层2self.bn2 = nn.BatchNorm1d(n_hidden_2)  # 批量归一化层2self.out = nn.Linear(n_hidden_2, out_dim)  # 输出层def forward(self, x):# 2. 定义前向传播x = self.flatten(x)x = F.relu(self.bn1(self.linear1(x)))  # 线性→归一化→激活x = F.relu(self.bn2(self.linear2(x)))x = F.softmax(self.out(x), dim=1)  # 输出概率分布return x# 实例化模型
in_dim, n_hidden_1, n_hidden_2, out_dim = 28*28, 300, 100, 10
model = Model_Seq(in_dim, n_hidden_1, n_hidden_2, out_dim)

方法 2:使用nn.Sequential按层顺序构建

适用场景:网络为 “线性堆叠结构”(无分支、无复杂逻辑),代码简洁高效。nn.Sequential支持三种构建方式,核心差异在于是否为层指定名称:

构建方式特点示例代码片段
1. 可变参数无法指定层名,按顺序自动命名(0,1,2...)nn.Sequential(nn.Flatten(), nn.Linear(784, 300), nn.ReLU())
2. add_module手动指定层名,动态添加层seq = nn.Sequential(); seq.add_module("flatten", nn.Flatten())
3. OrderedDict用有序字典指定层名,结构清晰from collections import OrderedDict; nn.Sequential(OrderedDict([("flatten", nn.Flatten())]))

示例(add_module方式)

seq_model = nn.Sequential()
seq_model.add_module("flatten", nn.Flatten())
seq_model.add_module("linear1", nn.Linear(784, 300))
seq_model.add_module("bn1", nn.BatchNorm1d(300))
seq_model.add_module("relu1", nn.ReLU())
seq_model.add_module("out", nn.Linear(300, 10))

方法 3:继承nn.Module+ 使用模型容器

适用场景:网络结构中等复杂(含多个子模块),需用容器整合子模块,提升代码可读性。PyTorch 提供三种核心模型容器,功能各有侧重:

(1)nn.Sequential容器
  • 作用:线性整合子模块(如将 “线性层 + 归一化层” 打包为一个子模块)。
  • 示例
    class Model_Lay(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Model_Lay, self).__init__()# 用Sequential打包子模块self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))self.out = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = nn.Flatten()(x)x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))return F.softmax(self.out(x), dim=1)
    
(2)nn.ModuleList容器
  • 作用:以 “列表” 形式存储子模块,支持索引访问,适合动态调整子模块数量。
  • 特点:仅存储模块,不定义前向传播顺序,需在forward中循环调用。
  • 示例
    class Model_Lst(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Model_Lst, self).__init__()# 用ModuleList存储层(列表形式)self.layers = nn.ModuleList([nn.Flatten(),nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1),nn.ReLU(),nn.Linear(n_hidden_1, n_hidden_2),nn.Linear(n_hidden_2, out_dim)])def forward(self, x):# 循环调用ModuleList中的层for layer in self.layers:x = layer(x)return x
    
(3)nn.ModuleDict容器
  • 作用:以 “字典” 形式存储子模块(键 = 层名,值 = 层实例),支持按名称访问。
  • 特点:需手动指定前向传播的层顺序(通过列表定义键的顺序)。
  • 示例
    class Model_Dict(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Model_Dict, self).__init__()# 用ModuleDict存储层(字典形式)self.layers_dict = nn.ModuleDict({"flatten": nn.Flatten(),"linear1": nn.Linear(in_dim, n_hidden_1),"bn1": nn.BatchNorm1d(n_hidden_1),"relu": nn.ReLU(),"out": nn.Linear(n_hidden_1, out_dim)})def forward(self, x):# 手动定义层的执行顺序order = ["flatten", "linear1", "bn1", "relu", "out"]for layer_name in order:x = self.layers_dict[layer_name](x)return x
    

四、自定义网络模块(以 ResNet 残差块为例)

对于复杂网络(如 ResNet),需自定义核心模块(如残差块),再组合成完整网络。PyTorch 通过继承nn.Module实现自定义模块。

1. 两种核心残差块

残差块的核心是 “跳连(Shortcut Connection)”,即输入直接与模块输出相加,解决深层网络梯度消失问题。根据输入 / 输出形状是否一致,分为两类:

(1)正常残差块(RestNetBasicBlock
  • 适用场景:输入与输出的通道数、分辨率一致,可直接跳连。
  • 结构Conv2d(3×3) → BatchNorm2d → ReLU → Conv2d(3×3) → BatchNorm2d → 跳连 → ReLU
  • 代码
    class RestNetBasicBlock(nn.Module):def __init__(self, in_channels, out_channels, stride):super(RestNetBasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += x  # 直接跳连(输入与输出形状一致)return F.relu(out)
    
(2)下采样残差块(RestNetDownBlock
  • 适用场景:输入与输出的通道数 / 分辨率不一致,需用 1×1 卷积调整输入形状后再跳连。
  • 结构:在正常残差块基础上,添加 “1×1 Conv → BatchNorm2d” 的 shortcut 分支。
  • 代码
    class RestNetDownBlock(nn.Module):def __init__(self, in_channels, out_channels, stride):super(RestNetDownBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)self.bn2 = nn.BatchNorm2d(out_channels)# 1×1卷积调整输入形状(通道数、分辨率)self.extra = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),nn.BatchNorm2d(out_channels))def forward(self, x):extra_x = self.extra(x)  # 调整输入形状out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += extra_x  # 调整后跳连return F.relu(out)
    

2. 组合残差块构建 ResNet18

通过堆叠上述残差块,可构建经典的 ResNet18 网络:

class RestNet18(nn.Module):def __init__(self):super(RestNet18, self).__init__()# 初始卷积+池化self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 堆叠残差块(4个层组)self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1), RestNetBasicBlock(64, 64, 1))self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]), RestNetBasicBlock(128, 128, 1))self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]), RestNetBasicBlock(256, 256, 1))self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]), RestNetBasicBlock(512, 512, 1))# 自适应平均池化+全连接层self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))self.fc = nn.Linear(512, 10)  # 10分类任务def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)  # 展平x = self.fc(x)return x

五、PyTorch 模型训练的标准流程

模型构建完成后,需遵循固定流程进行训练,核心步骤共 6 步:

  1. 加载预处理数据集读取训练 / 测试数据,进行预处理(如归一化、数据增强),并通过DataLoader实现批量加载。示例:train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

  2. 定义损失函数根据任务类型选择(分类任务用nn.CrossEntropyLoss,回归任务用nn.MSELoss)。示例:criterion = nn.CrossEntropyLoss()

  3. 定义优化器选择优化算法并传入模型参数,设置学习率等超参数。示例:optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

  4. 循环训练模型迭代训练数据,执行 “前向传播→计算损失→反向传播→参数更新”:

    model.train()  # 切换训练模式(启用Dropout、BN更新)
    for epoch in range(10):  # 迭代轮次for inputs, labels in train_loader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播(求梯度)optimizer.step()  # 更新参数
    
  5. 循环测试 / 验证模型每轮训练后,在验证集 / 测试集上评估模型性能(如准确率),避免过拟合:

    model.eval()  # 切换测试模式(禁用Dropout、固定BN)
    with torch.no_grad():  # 禁用梯度计算,节省资源correct = 0total = 0for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"测试准确率: {100 * correct / total}%")
    
  6. 可视化结果通过工具(如 Matplotlib、TensorBoard)可视化训练损失曲线、准确率曲线,分析模型训练趋势。

http://www.dtcms.com/a/395979.html

相关文章:

  • 豆包Seedream 4.0:全面测评、玩法探索与Prompt解读
  • STM32_02_GPIO
  • Flink SlotSharingGroup 机制详解
  • Final Cut Pro X fcpx音视频剪辑编辑(Mac中文)
  • 【LeetCode_88】合并两个有序数组
  • PromptPilot 发布:AI 提示词工程化新利器,首月零元体验
  • MySQL-详解数据库中的触发器
  • JVM调优实战及常量池详解
  • 字典树(Trie)
  • AI浏览器概述:Browser Use、Computer Use、Fellou
  • 「docker」三、3分钟快速安装docker
  • Altium Designer(AD)自定义PCB形状
  • 基于ZYNQ的创世SD NAND卡读写TXT文本实验
  • 文心快码入选2025人工智能AI4SE“银弹”标杆案例
  • 什么是SDN(Software Defined Netwok)
  • GitLab-如何基于现有项目仓库,复制出新的项目仓库
  • 本科大二第三周学习周报
  • 三、自定义Button模板触发器(纯XAML)
  • tar 将多个文件或目录打包成一个单独的归档文件
  • 2025新版 WSL2 + Docker Desktop 下载安装详细全流程指南 实现容器化管理,让开发效率起飞
  • 【LangChain4j】大模型实战-SpringBoot(阿里云百炼控制台)
  • Spring Security / Authorization Server 核心类中英文对照表
  • SqlHelper自定义的Sql工具类
  • 每周读书与学习->初识JMeter 元件(二)
  • 西门子 S7-200 SMART PLC 实操案例:中断程序的灵活应用定时中断实现模拟量滤波(上)
  • 测试分类(1)
  • 广州创科——湖北房县汪家河水库除险加固信息化工程(续集)
  • QT(5)
  • 仓颉语言宏(Cangjie Macros)全面解析:从基础到实战
  • linux RAID存储技术