2025-05-31 Python深度学习9——网络模型的加载与保存
文章目录
- 1 使用现有网络
- 2 修改网络结构
- 2.1 添加新层
- 2.2 替换现有层
- 3 保存网络模型
- 3.1 完整保存
- 3.2 参数保存(推荐)
- 4 加载网络模型
- 4.1 加载完整模型文件
- 4.2 加载参数文件
- 5 Checkpoint
- 5.1 保存 Checkpoint
- 5.2 加载 Checkpoint
本文环境:
- Pycharm 2025.1
- Python 3.12.9
- Pytorch 2.6.0+cu124
PyTorch 通过torchvision.models
提供预训练模型(如 VGG16)。
网址链接:https://docs.pytorch.org/vision/stable/models.html。
1 使用现有网络
以 VGG16 为例,进入网址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

方法一:使用随机初始化权重
将 weights 设置为 None,从 0 开始训练自己的网络。
vgg16_false = torchvision.models.vgg16(weights=None) # 权重随机初始化
方法二:加载预训练权重
也可以使用预训练好的网络参数,加载后可直接使用网络。
这将从官网上下载已训练好的模型文件。
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
可打印网络查看其模型结构:
print(vgg16_true)


2 修改网络结构
2.1 添加新层
使用add_module
在分类器(classifier
)后追加全连接层:
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

2.2 替换现有层
直接修改分类器的最后一层(如适配 CIFAR10 的 10 分类任务):
vgg16_false.classifier[6] = nn.Linear(4096, 10) # 替换第6层

3 保存网络模型
使用torch.save()
方法保存网络模型。文件扩展名推荐使用.pt
或.pth
。
3.1 完整保存
将模型类和参数一并保存到文件中。
torch.save(vgg16, 'vgg16_method1.pth') # 包含模型类和参数
- 优点:加载时无需重新定义模型结构。
- 缺点:文件较大,且依赖原始代码环境(见 4.1 节)。
3.2 参数保存(推荐)
仅保存参数字典到文件中。
torch.save(vgg16.state_dict(), 'vgg16_method2.pth') # 仅保存参数字典
- 优点:文件小,灵活性强,适合生产部署。
示例
import torch
import torchvision.models
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)# 保存方式 1,模型结构 + 模型参数
torch.save(vgg16, 'vgg16_method1.pth')# 保存方式 2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')
4 加载网络模型
使用torch.load()
方法加载网络模型。
4.1 加载完整模型文件
加载完整模型时,需将 weights_only 参数设置为 False。
model = torch.load('vgg16_method1.pth', weights_only=False) # 需确保模型类已定义
模型打印结果如下:
print(model)

注意
若保存自定义模型,加载时必须确保环境中也有该模型的定义,否则会出现报错。
model_save.py
# model_save.pyimport torch from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):return self.conv1(x)model = MyModel() torch.save(model, 'my_model_method1.pth')
model_load.py
import torchmodel = torch.load('my_model_method1.pth', weights_only=False) # 报错,找不到 MyModel 的定义
先运行 model_save.py,再运行 model_load.py,则会出现以下报错:
![]()
4.2 加载参数文件
首先,使用torch.load()
方法加载网络模型。
使用模型时,需先创建匹配的网络结构,再使用model.load_state_dict()
加载参数数据。
vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict) # 需结构匹配
模型打印结果是参数字典:
print(model_dict)

注意
模型保存时若在 GPU 上,加载时需指定 map_location 为 cup。
torch.load('model.pth', map_location=torch.device('cpu'))
将参数加载到模型后,手动迁移到 GPU:
model = MyModel() model.load_state_dict(model_dict) model.to('cuda:0')
5 Checkpoint
使用 Checkpoint 可以在训练过程中定期保存模型的状态,以便在中断后可以恢复训练,或者在测试时使用最终的模型。文件扩展名推荐使用.tar
。
5.1 保存 Checkpoint
要保存一个模型的 Checkpoint,通常需要保存以下数据:
- 模型的 state_dict(状态字典);
- 优化器的状态;
- 额外的信息,如 epoch 等。
import torch# 假设 model 是你的模型,optimizer 是你的优化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
}# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')
5.2 加载 Checkpoint
加载 Checkpoint,首先需要加载文件,然后将其内容恢复到模型和优化器的状态中。
# 假设 model 和 optimizer 是你的模型和优化器实例
checkpoint = torch.load('checkpoint.tar')model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果需要,可以继续训练
model.train() # 确保模型处于训练模式