第6节 torch.nn.Module
Containers 包含6个模块:Module、Sequential、ModuleList、ModuleDict\ParameterList、ParameterDict
6.1 torch.nn.Module介绍
torch.nn.Module是 PyTorch 中构建神经网络的基础类,所有的神经网络模块都应该继承这个类。它提供了一种便捷的方式来组织和管理网络中的各个组件,包括层、参数等,同时还内置了许多用于模型训练和推理的功能。
官网:torch.nn — PyTorch 1.8.1 documentation
核心功能:
(1)、网络构建:通过继承torch.nn.Module类,我们可以自定义自己的神经网络结构。在__init__方法中定义网络的各个层,在forward方法中定义数据的前向传播过程。
(2)、参数管理:torch.nn.Module会自动跟踪和管理网络中的参数(如权重和偏置)。我们可以通过parameters()方法获取网络的所有参数,方便进行优化器的配置和参数的更新。
(3)、设备转换:可以使用to()方法将模型转移到指定的设备(如 CPU 或 GPU)上,以利用不同设备的计算能力。
(4)、状态切换:提供了train()和eval()方法来切换模型的训练和评估状态。在训练状态下,一些具有随机性的层(如 Dropout、BatchNorm)会正常工作;在评估状态下,这些层会采用确定性的行为。
6.2 torch.nn.Module常用方法
__init__(self):构造函数,用于初始化网络的各个层和参数。在自定义网络时,需要在该方法中调用super().__init__()来初始化父类。
forward(self, x):前向传播方法,定义了数据在网络中的流动过程。当对模型进行调用时(如model(x)),实际上是调用了该方法。
parameters(self):返回一个迭代器,包含网络中的所有可学习参数。
named_parameters(self):返回一个迭代器,包含网络中参数的名称和对应的参数值。
to(self, device):将模型转移到指定的设备上。例如,model.to('cuda')将模型转移到 GPU 上。
train(self, mode=True):将模型设置为训练模式。
eval(self):将模型设置为评估模式,相当于train(mode=False)。
save_state_dict(self, path):保存模型的参数状态字典到指定路径。
load_state_dict(self, state_dict):从参数状态字典中加载模型的参数。
6.3 程序演示
6.3.1 官网提供的例子
import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module): #搭建的神经网络 Model继承了 Module类(父类)def __init__(self): #初始化函数super(Model, self).__init__() #必须要这一步,调用父类的初始化函数self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x): #前向传播(为输入和输出中间的处理过程),x为输入x = F.relu(self.conv1(x)) #conv为卷积,relu为非线性处理return F.relu(self.conv2(x))
注意:前向传播 forward(在所有子类中进行重写)
6.3.2 自定义Model
import torch
from torch import nn# 定义一个自定义模型类Custom_Model,继承自nn.Module
# 所有的神经网络模型都应该继承nn.Module,以利用其提供的参数管理、设备转换等功能
class Custom_Model(nn.Module):# 构造函数,用于初始化模型的层和参数def __init__(self):# 调用父类nn.Module的构造函数,确保模型能够正确初始化super().__init__()# 前向传播方法,定义数据在模型中的流动和计算过程# 当对模型实例传入输入数据时,会自动调用该方法def forward(self, input):# 定义模型的计算逻辑:输入数据加1output = input + 1# 返回计算结果return outputCustom_Model = Custom_Model()
# 创建一个张量x,值为1.0,作为模型的输入数据
x = torch.tensor(1.0)
# 将输入数据x传入模型,模型会自动调用forward方法进行计算,得到输出结果
output = Custom_Model(x)
# 打印输出结果,此时输出应为2.0(1.0 + 1)
print(output)