当前位置: 首页 > news >正文

Python训练Day35

@浙大疏锦行

  1. 三种不同的模型可视化方法:推荐torchinfo打印summary+权重分布可视化
  2. 进度条功能:手动和自动写法,让打印结果更加美观
  3. 推理的写法:评估模式

一、模型结构可视化

理解一个深度学习网络最重要的2点:

1. 了解损失如何定义的,知道损失从何而来----把抽象的任务通过损失函数量化出来

2. 了解参数总量,即知道每一层的设计才能退出---层设计决定参数总量

为了了解参数总量,我们需要知道层设计,以及每一层参数的数量。下面介绍几个层可视化工具:

1.1 nn.model自带的方法

#  nn.Module 的内置功能,直接输出模型结构
print(model)

这是最基础、最简单的方法,直接打印模型对象,它会输出模型的结构,显示模型中各个层的名称和参数信息

1.2 torchsummary库的summary方法

# pip install torchsummary -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchsummary import summary
# 打印模型摘要,可以放置在模型定义后面
summary(model, input_size=(4,))

    该方法不显示输入层的尺寸,因为输入的神经网是自己设置的,所以不需要显示输入层的尺寸。但是在使用该方法时,input_size=(4,) 参数是必需的,因为 PyTorch 需要知道输入数据的形状才能推断模型各层的输出形状和参数数量。

    这是因为PyTorch 的模型在定义时是动态的,它不会预先知道输入数据的具体形状nn.Linear(4, 10) 只定义了 “输入维度是 4,输出维度是 10”,但不知道输入的批量大小和其他维度,比如卷积层需要知道输入的通道数、高度、宽度等信息。----并非所有输入数据都是结构化数据。因此,要生成模型摘要(如每层的输出形状、参数数量),必须提供一个示例输入形状,让 PyTorch “运行” 一次模型,从而推断出各层的信息。

summary 函数的核心逻辑是:

1. 创建一个与 input_size 形状匹配的虚拟输入张量(通常填充零)

2. 将虚拟输入传递给模型,执行一次前向传播(但不计算梯度)

3. 记录每一层的输入和输出形状,以及参数数量

4. 生成可读的摘要报告

1.3 torchinfo库的summary方法

torchinfo 是提供比 torchsummary 更详细的模型摘要信息,包括每层的输入输出形状、参数数量、计算量等。

# pip install torchinfo -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchinfo import summary
summary(model, input_size=(4, ))

二、 进度条功能

     介绍下tqdm这个库,他非常适合用在循环中观察进度。尤其在深度学习这种训练是循环的场景中。他最核心的逻辑如下

1. 创建一个进度条对象,并传入总迭代次数。一般用with语句创建对象,这样对象会在with语句结束后自动销毁,保证资源释放。with是常见的上下文管理器,这样的使用方式还有用with打开文件,结束后会自动关闭文件。

2. 更新进度条,通过pbar.update(n)指定每次前进的步数n(适用于非固定步长的循环)。

2.1 手动更新

from tqdm import tqdm  # 先导入tqdm库
import time  # 用于模拟耗时操作# 创建一个总步数为10的进度条
with tqdm(total=10) as pbar:  # pbar是进度条对象的变量名# pbar 是 progress bar(进度条)的缩写,约定俗成的命名习惯。for i in range(10):  # 循环10次(对应进度条的10步)time.sleep(0.5)  # 模拟每次循环耗时0.5秒pbar.update(1)  # 每次循环后,进度条前进1步
from tqdm import tqdm
import time# 创建进度条时添加描述(desc)和单位(unit)
with tqdm(total=5, desc="下载文件", unit="个") as pbar:# 进度条这个对象,可以设置描述和单位# desc是描述,在左侧显示# unit是单位,在进度条右侧显示for i in range(5):time.sleep(1)pbar.update(1)  # 每次循环进度+1

unit 参数的核心作用是明确进度条中每个进度单位的含义,使可视化信息更具可读性。在深度学习训练中,常用的单位包括:

- epoch:训练轮次(遍历整个数据集一次)。

- batch:批次(每次梯度更新处理的样本组)。

- sample:样本(单个数据点)

2.2 自动更新

from tqdm import tqdm
import time# 直接将range(3)传给tqdm,自动生成进度条
# 这个写法我觉得是有点神奇的,直接可以给这个对象内部传入一个可迭代对象,然后自动生成进度条
for i in tqdm(range(3), desc="处理任务", unit="epoch"):time.sleep(1)

for i in tqdm(range(3), desc="处理任务", unit="个")这个写法则不需要在循环中调用update()方法,更加简洁

三、 模型的推理

测试这个词在大模型领域叫做推理(inference),意味着把数据输入到训练好的模型的过程。

# 在测试集上评估模型,此时model内部已经是训练好的参数了
# 评估模型
model.eval() # 设置模型为评估模式
with torch.no_grad(): # torch.no_grad()的作用是禁用梯度计算,可以提高模型推理速度outputs = model(X_test)  # 对测试数据进行前向传播,获得预测结果_, predicted = torch.max(outputs, 1) # torch.max(outputs, 1)返回每行的最大值和对应的索引#这个函数返回2个值,分别是最大值和对应索引,参数1是在第1维度(行)上找最大值,_ 是Python的约定,表示忽略这个返回值,所以这个写法是找到每一行最大值的下标# 此时outputs是一个tensor,p每一行是一个样本,每一行有3个值,分别是属于3个类别的概率,取最大值的下标就是预测的类别# predicted == y_test判断预测值和真实值是否相等,返回一个tensor,1表示相等,0表示不等,然后求和,再除以y_test.size(0)得到准确率# 因为这个时候数据是tensor,所以需要用item()方法将tensor转化为Python的标量# 之所以不用sklearn的accuracy_score函数,是因为这个函数是在CPU上运行的,需要将数据转移到CPU上,这样会慢一些# size(0)获取第0维的长度,即样本数量correct = (predicted == y_test).sum().item() # 计算预测正确的样本数accuracy = correct / y_test.size(0)print(f'测试集准确率: {accuracy * 100:.2f}%')

    模型的评估模式简单来说就是评估阶段会关闭一些训练相关的操作和策略 ,比如更新参数 正则化等操作,确保模型输出结果的稳定性和一致性。

为什么评估模式不关闭梯度计算,推理不是不需要更新参数么?

    主要还是因为在某些场景下,评估阶段可能需要计算梯度(虽然不更新参数)。例如:计算梯度用于可视化(如 CAM 热力图,主要用于cnn相关)。所以为了避免这种需求不被满足,还是需要手动关闭梯度计算。

http://www.dtcms.com/a/318426.html

相关文章:

  • Python在生物计算与医疗健康领域的应用(2025深度解析)
  • 局域网内某服务器访问其他服务器虚拟机内相关服务配置
  • 无人机遥控器舵量技术解析
  • 线上Linux服务器的优化设置、系统安全与网络安全策略
  • Android14的QS面板的加载解析
  • 云平台托管集群:EKS、GKE、AKS 深度解析与选型指南-第四章
  • k8s 网络插件 flannel calico
  • 第14届蓝桥杯Scratch选拔赛初级及中级(STEMA)真题2023年1月15日
  • 链式数据结构
  • LangChain4j实战
  • 深入解析系统调试利器:strace 从入门到精通
  • Linux——(16)深入理解程序运行的基石
  • 12. SELinux 加固 Linux 安全
  • react 流式布局(图片宽高都不固定)的方案及思路
  • npm run dev npm run build
  • Activiti7 调用子流程的配置和处理
  • 【Day 17】Linux-SSH远程连接
  • TMS320F2837xD的CLA加速器开发手册
  • mobaxterm怎么复制全局内容
  • ABP VNext + SQL Server Temporal Tables:审计与时序数据管理
  • 串口通信 day48
  • 华清远见25072班C语言学习day3
  • EXCEL-业绩、目标、达成、同比、环比一图呈现
  • Etcd,真的需要集群部署吗?
  • 消防通道占用识别误报率↓79%!陌讯动态融合算法实战优化
  • 模 板 方 法 模 式
  • 人大金仓数据库逻辑备份与恢复命令
  • PostgreSQL报错“maximum number of prepared transactions reached”原因及高效解决方案解析
  • 百货零售行业数字化蓝图整体规划方案(165页PPT)满分可编辑PPT
  • 构建语义搜索引擎:Weaviate的实践与探索