当前位置: 首页 > news >正文

PyTorch系列教程:高效保存和加载PyTorch模型

PyTorch是一个强大的深度学习库,广泛用于构建和训练神经网络。PyTorch的优点之一是它提供了简单有效的函数来保存和加载模型。这对于恢复训练、共享模型或部署模型进行推理等任务特别有用。

保存PyTorch模型

torch.save() 函数用于序列化模型并将其保存到磁盘。这个过程很简单,但是对torch.save()的特性有一个很好的理解将有助于你有效地管理你保存的模型。
在这里插入图片描述

保存PyTorch模型的通用语法包括两部分:模型的状态字典和推荐的文件格式,通常以.pt或.pth作为扩展名。

import torch
import torch.nn as nn
import torch.optim as optim

# Example model
eclass SampleModel(nn.Module):
    def __init__(self):
        super(SampleModel, self).__init__()
        self.layer = nn.Linear(10, 2)

    def forward(self, x):
        return self.layer(x)

model = SampleModel()

# Define a random optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Specify the path for saving the model
PATH = "./model.pth"

# Save the model's state_dict (recommended approach)
torch.save(model.state_dict(), PATH)

这里,state_dict是一个Python字典对象,它存储模型的所有参数和持久缓冲区(例如,运行任何批处理规范层的平均值)。

加载PyTorch模型

torch.load() 函数用于加载已保存的PyTorch模型。要记住的关键一点是,将state_dict 加载到模型中需要使用与保存时相同的体系结构实例化模型类。

# Initialize the model again
model = SampleModel()

# Load the weights from the saved file into the model's state_dict
model.load_state_dict(torch.load(PATH))

# Remember to set the model to evaluation mode after loading
model.eval()

在此代码片段中,将权重加载到模型中之后,调用model.eval()将模型设置为求值模式至关重要。如果模型包含诸如dropout层或批处理规范化层之类的层,这一点尤其重要,因为它们在评估和训练期间的行为不同。

保存完整模型

如果你不仅想保存state_dict,还想保存整个模型,包括体系结构,可以使用torch.save()。但是,一定要小心,因为保存整个模型会创建一个很大的文件大小,并引入依赖关系

# This saves the entire model including the architecture
torch.save(model, "./full_model.pth")

# To load the full model:
model = torch.load("./full_model.pth")
model.eval()

如果你计划重构代码并仍然访问模型权重,则通常不建议使用这种方法,因为加载将使用保存时的代码。与state_dict方法相比,它提供的灵活性较差。

最后总结

在PyTorch中保存和加载模型可以很简单,但需要了解何时使用状态字典而不是整个模型。当你希望干净地存储参数值并保持代码效率时,请使用state_dicts进行保存和加载。只有在必要的情况下保留体系结构时才选择保存整个模型,并认识到所涉及的权衡。无论你是在训练阶段之间进行转换,还是在进行推理部署,PyTorch 的序列化和反序列化都能为你提供所需的必备工具。

相关文章:

  • Redis中常见的问题
  • 蓝牙基础知识学习补充
  • 前端工程化之前端工程化详解 包管理工具
  • 深度学习多模态人脸情绪识别:从理论到实践
  • 卷积神经网络(CNN)的主要架构
  • 数据库的基本知识
  • pytest+allure+jenkins
  • 力扣 11.盛水最多的容器(双指针)
  • matlab 八自由度汽车垂向动力学参数优化带座椅
  • ​【C++设计模式】第二十一篇:模板方法模式(Template Method)
  • Docker命令笔记
  • 网页制作14-Javascipt时间特效の显示动态日期
  • HTB 学习笔记 【中/英】《Web 应用 - 布局》P2
  • JavaCV
  • java集合框架的List 接口提供了两种主要的访问元素的方式:迭代器(Iterator)和索引访问,优缺点对比
  • 《C++:无可替代的编程传奇》:此文为AI自动生成
  • elementui table 自动滚动 纯js实现
  • 【fNIRS可视化学习1】基于NIRS-SPM进行光极可视化并计算通道坐标
  • ubuntu系统下添加pycharm到快捷启动栏方法
  • 【漫话机器学习系列】134.基于半径的最近邻分类器(Radius-Based Nearest Neighbor Classifier)
  • 爱德华多·阿拉纳宣誓就任秘鲁新总理
  • 特朗普访问卡塔尔,两国签署多项合作协议
  • 呼吸医学专家杜晓华博士逝世,终年50岁
  • 市场监管总局召开平台企业支持个体工商户发展座谈会
  • 外交部:中方对美芬太尼反制仍然有效
  • 俄副外长:俄美两国将举行双边谈判