当前位置: 首页 > news >正文

pytorch学习笔记-模型的保存与加载(自定义模型、网络模型)

博主最近勤奋更新的原因是一来之前攒了一些囤的,二来是终于要学完了一鼓作气啊啊啊啊

这一节写一下模型保存&加载,推荐方式2,方式1了解一下就ok

现有的网络模型的保存与加载

先要引入现有的网络模型

import torch
import torchvision
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)

保存方式1&加载方式1

保存方式:
#保存方式1
torch.save(vgg16,"vgg16_method1.pth")
加载方式:

注意这里不写weights_only会提示:
(1) In PyTorch 2.6, we changed the default value of the weights_only argument in torch.load from False to True. Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.懒得翻译了大概看一眼吧
(2)xxxxx…就是建议你用方式2写

#加载方式1
model = torch.load("vgg16_method1.pth",weights_only=False)
# print(model)

保存方式2&加载方式2

保存方式:
#方式2,以字典形式存储
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
加载方式:

注意点就是要先定义一个模型,然后再把参数导入到模型中

#方式2存储加载
#要先定义一个模型,然后再把参数导入到模型中
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))

自定义网络模型的保存与加载

假设你在model_save.py文件中定义了这样一个model:

class MyModule(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3,64,kernel_size=3)def forward(self,x):x = self.conv1(x)return xmy_module = MyModule()

保存方式1&加载方式1

保存方式:
#保存方式1
torch.save(my_module,"my_module_method1.pth")
加载方式:

注意一下就是如果你在model_load.py中如果没有对应的网络结构,会加载失败,因此需要引入自定义的模型,不用实例化,可选方式有1.引入包(推荐)2.把网络结构复制过来

#加载自定义的模型
#要求引入自定义的模型,不用实例化
#可选方式可以引入包(推荐)或者把网络结构复制过来
model = torch.load("my_module_method1.pth",weights_only=False)
# print(model)

保存方式2&加载方式2

保存方式:
#自定义模型存储2
torch.save(my_module.state_dict(),"my_module_method2.pth")
加载方式:

和加载现有的网络模型差不多,都是要先定义一个模型,然后再把参数导入到模型中

model2 = MyModule()
model2.load_state_dict(torch.load("my_module_method2.pth"))
print(model2)
http://www.dtcms.com/a/333630.html

相关文章:

  • 大白话解析 Solidity 中的防重放参数
  • USENIX Security ‘24 Fall Accepted Papers (1)
  • 归并排序和统计排序
  • 用matlab实现的svdd算法
  • 2025年机械制造、机器人与计算机工程国际会议(MMRCE 2025)
  • gnu arm toolchain中的arm-none-eabi-gdb.exe的使用方法?
  • C#WPF实战出真汁05--左侧导航
  • 日常反思总结
  • 异步开发:协程、线程、Unitask
  • 线性代数 · 直观理解矩阵 | 空间变换 / 特征值 / 特征向量
  • 树莓派开机音乐
  • 模板引用(Template Refs)全解析2
  • CVE-2025-8088复现
  • 汽车行业 AI 视觉检测方案(二):守护车身密封质量
  • 【总结】Python多线程
  • 华清远见25072班C语言学习day10
  • 342. 4的幂
  • 自定义数据集(pytorchhuggingface)
  • 附046.集群管理-EFK日志解决方案-Filebeat
  • 考研复习-计算机组成原理-第七章-IO
  • NumPy基础入门
  • 第40周——GAN入门
  • 详解区块链技术及主流区块链框架对比
  • PSME2通过IL-6/STAT3信号轴调控自噬
  • 【机器学习】核心分类及详细介绍
  • 控制块在SharedPtr中的作用(C++)
  • 【秋招笔试】2025.08.15饿了么秋招机考-第二题
  • 基于MATLAB的机器学习、深度学习实践应用
  • Matlab(5)进阶绘图
  • 后端学习资料 持续更新中