pytorch学习笔记-模型的保存与加载(自定义模型、网络模型)
博主最近勤奋更新的原因是一来之前攒了一些囤的,二来是终于要学完了一鼓作气啊啊啊啊
这一节写一下模型保存&加载,推荐方式2,方式1了解一下就ok
现有的网络模型的保存与加载
先要引入现有的网络模型
import torch
import torchvision
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)
保存方式1&加载方式1
保存方式:
#保存方式1
torch.save(vgg16,"vgg16_method1.pth")
加载方式:
注意这里不写weights_only会提示:
(1) In PyTorch 2.6, we changed the default value of the weights_only
argument in torch.load
from False
to True
. Re-running torch.load
with weights_only
set to False
will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.懒得翻译了大概看一眼吧
(2)xxxxx…就是建议你用方式2写
#加载方式1
model = torch.load("vgg16_method1.pth",weights_only=False)
# print(model)
保存方式2&加载方式2
保存方式:
#方式2,以字典形式存储
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
加载方式:
注意点就是要先定义一个模型,然后再把参数导入到模型中
#方式2存储加载
#要先定义一个模型,然后再把参数导入到模型中
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
自定义网络模型的保存与加载
假设你在model_save.py文件中定义了这样一个model:
class MyModule(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3,64,kernel_size=3)def forward(self,x):x = self.conv1(x)return xmy_module = MyModule()
保存方式1&加载方式1
保存方式:
#保存方式1
torch.save(my_module,"my_module_method1.pth")
加载方式:
注意一下就是如果你在model_load.py中如果没有对应的网络结构,会加载失败,因此需要引入自定义的模型,不用实例化,可选方式有1.引入包(推荐)2.把网络结构复制过来
#加载自定义的模型
#要求引入自定义的模型,不用实例化
#可选方式可以引入包(推荐)或者把网络结构复制过来
model = torch.load("my_module_method1.pth",weights_only=False)
# print(model)
保存方式2&加载方式2
保存方式:
#自定义模型存储2
torch.save(my_module.state_dict(),"my_module_method2.pth")
加载方式:
和加载现有的网络模型差不多,都是要先定义一个模型,然后再把参数导入到模型中
model2 = MyModule()
model2.load_state_dict(torch.load("my_module_method2.pth"))
print(model2)