当前位置: 首页 > news >正文

第6节 torch.nn.Module

Containers 包含6个模块:ModuleSequentialModuleListModuleDict\ParameterListParameterDict

6.1 torch.nn.Module介绍

        torch.nn.Module是 PyTorch 中构建神经网络的基础类,所有的神经网络模块都应该继承这个类。它提供了一种便捷的方式来组织和管理网络中的各个组件,包括层、参数等,同时还内置了许多用于模型训练和推理的功能。

官网:torch.nn — PyTorch 1.8.1 documentation

核心功能

(1)、网络构建:通过继承torch.nn.Module类,我们可以自定义自己的神经网络结构。在__init__方法中定义网络的各个层,在forward方法中定义数据的前向传播过程。

(2)、参数管理:torch.nn.Module会自动跟踪和管理网络中的参数(如权重和偏置)。我们可以通过parameters()方法获取网络的所有参数,方便进行优化器的配置和参数的更新。

(3)、设备转换:可以使用to()方法将模型转移到指定的设备(如 CPU 或 GPU)上,以利用不同设备的计算能力。​

(4)、状态切换:提供了train()和eval()方法来切换模型的训练和评估状态。在训练状态下,一些具有随机性的层(如 Dropout、BatchNorm)会正常工作;在评估状态下,这些层会采用确定性的行为。

6.2 torch.nn.Module常用方法

        __init__(self):构造函数,用于初始化网络的各个层和参数。在自定义网络时,需要在该方法中调用super().__init__()来初始化父类。​

        forward(self, x):前向传播方法,定义了数据在网络中的流动过程。当对模型进行调用时(如model(x)),实际上是调用了该方法。​

        parameters(self):返回一个迭代器,包含网络中的所有可学习参数。​

        named_parameters(self):返回一个迭代器,包含网络中参数的名称和对应的参数值。​

        to(self, device):将模型转移到指定的设备上。例如,model.to('cuda')将模型转移到 GPU 上。​

        train(self, mode=True):将模型设置为训练模式。​

        eval(self):将模型设置为评估模式,相当于train(mode=False)。​

        save_state_dict(self, path):保存模型的参数状态字典到指定路径。​

        load_state_dict(self, state_dict):从参数状态字典中加载模型的参数。

6.3 程序演示

6.3.1 官网提供的例子

import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):   #搭建的神经网络 Model继承了 Module类(父类)def __init__(self):   #初始化函数super(Model, self).__init__()   #必须要这一步,调用父类的初始化函数self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):   #前向传播(为输入和输出中间的处理过程),x为输入x = F.relu(self.conv1(x))   #conv为卷积,relu为非线性处理return F.relu(self.conv2(x))

注意:前向传播 forward(在所有子类中进行重写)

6.3.2 自定义Model

import torch
from torch import nn# 定义一个自定义模型类Custom_Model,继承自nn.Module
# 所有的神经网络模型都应该继承nn.Module,以利用其提供的参数管理、设备转换等功能
class Custom_Model(nn.Module):# 构造函数,用于初始化模型的层和参数def __init__(self):# 调用父类nn.Module的构造函数,确保模型能够正确初始化super().__init__()# 前向传播方法,定义数据在模型中的流动和计算过程# 当对模型实例传入输入数据时,会自动调用该方法def forward(self, input):# 定义模型的计算逻辑:输入数据加1output = input + 1# 返回计算结果return outputCustom_Model = Custom_Model()
# 创建一个张量x,值为1.0,作为模型的输入数据
x = torch.tensor(1.0)
# 将输入数据x传入模型,模型会自动调用forward方法进行计算,得到输出结果
output = Custom_Model(x)
# 打印输出结果,此时输出应为2.0(1.0 + 1)
print(output)

http://www.dtcms.com/a/328759.html

相关文章:

  • 熬夜面膜赛道跑出的新物种
  • Spring Boot初级概念及自动配置原理
  • 【递归、搜索与回溯算法】综合练习
  • 系统分析师-数据库系统-并发控制数据库安全
  • 使用 UDP 套接字实现客户端 - 服务器通信:完整指南
  • HiSmartPerf使用WIFI方式连接Android机显示当前设备0.0.0.0无法ping通!设备和电脑连接同一网络,将设备保持亮屏重新尝试
  • 【android bluetooth 协议分析 05】【蓝牙连接详解3】【app侧该如何知道蓝牙设备的acl状态】
  • 【KO】Android 面试高频词
  • 从内核数据结构的角度理解socket
  • Android Activity 的对话框(Dialog)样式
  • RxJava 在 Android 中的深入解析:使用、原理与最佳实践
  • 基于Apache Flink的实时数据处理架构设计与高可用性实战经验分享
  • 【cs336学习笔记】[第5课]详解GPU架构,性能优化
  • 深入 Linux 线程:从内核实现到用户态实践,解锁线程创建、同步、调度与性能优化的完整指南
  • iscc2025区域赛wp
  • 服务器通过生成公钥和私钥安全登录
  • Android 在 2020-2025 都做哪些更新?
  • 如何提供对外访问的IP(内网穿透工具)
  • 【Android】ChatRoom App 技术分析
  • OpenAI 回应“ChatGPT 用多了会变傻”
  • Control Center 安卓版:个性化手机控制中心
  • ClickHouse从入门到企业级实战全解析课程简介
  • 1688商品数据抓取:Python爬虫+动态页面解析
  • 基于elk实现分布式日志
  • Windows11 运行IsaacSim GPU Vulkan崩溃
  • 三极管的基极为什么需要下拉电阻
  • Pycharm选好的env有包,但是IDE环境显示无包
  • Excel多级数据结构导入导出工具
  • Nuxt 3 跨域问题完整解决方案(开发 + 生产环境)
  • Appium-移动端自动测试框架详解