当前位置: 首页 > news >正文

2025-05-31 Python深度学习9——网络模型的加载与保存

文章目录

  • 1 使用现有网络
  • 2 修改网络结构
    • 2.1 添加新层
    • 2.2 替换现有层
  • 3 保存网络模型
    • 3.1 完整保存
    • 3.2 参数保存(推荐)
  • 4 加载网络模型
    • 4.1 加载完整模型文件
    • 4.2 加载参数文件
  • 5 Checkpoint
    • 5.1 保存 Checkpoint
    • 5.2 加载 Checkpoint

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ PyTorch 通过torchvision.models提供预训练模型(如 VGG16)。

​ 网址链接:https://docs.pytorch.org/vision/stable/models.html。

1 使用现有网络

​ 以 VGG16 为例,进入网址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

image-20250531103635500

方法一:使用随机初始化权重

​ 将 weights 设置为 None,从 0 开始训练自己的网络。

vgg16_false = torchvision.models.vgg16(weights=None)  # 权重随机初始化

方法二:加载预训练权重

​ 也可以使用预训练好的网络参数,加载后可直接使用网络。
这将从官网上下载已训练好的模型文件。

vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

​ 可打印网络查看其模型结构:

print(vgg16_true)
image-20250531104433678
...
image-20250531104447912

2 修改网络结构

2.1 添加新层

​ 使用add_module在分类器(classifier)后追加全连接层:

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
image-20250531104536554

2.2 替换现有层

​ 直接修改分类器的最后一层(如适配 CIFAR10 的 10 分类任务):

vgg16_false.classifier[6] = nn.Linear(4096, 10)  # 替换第6层
image-20250531104551228

3 保存网络模型

​ 使用torch.save()方法保存网络模型。文件扩展名推荐使用.pt.pth

3.1 完整保存

​ 将模型类和参数一并保存到文件中。

torch.save(vgg16, 'vgg16_method1.pth')  # 包含模型类和参数
  • 优点:加载时无需重新定义模型结构。
  • 缺点:文件较大,且依赖原始代码环境(见 4.1 节)。

3.2 参数保存(推荐)

​ 仅保存参数字典到文件中。

torch.save(vgg16.state_dict(), 'vgg16_method2.pth')  # 仅保存参数字典
  • 优点:文件小,灵活性强,适合生产部署。

示例

import torch
import torchvision.models
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)# 保存方式 1,模型结构 + 模型参数
torch.save(vgg16, 'vgg16_method1.pth')# 保存方式 2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

4 加载网络模型

​ 使用torch.load()方法加载网络模型。

4.1 加载完整模型文件

​ 加载完整模型时,需将 weights_only 参数设置为 False。

model = torch.load('vgg16_method1.pth', weights_only=False)  # 需确保模型类已定义

​ 模型打印结果如下:

print(model)
image-20250531111142517

注意

​ 若保存自定义模型,加载时必须确保环境中也有该模型的定义,否则会出现报错。

  • model_save.py

    # model_save.pyimport torch
    from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):return self.conv1(x)model = MyModel()
    torch.save(model, 'my_model_method1.pth')
    
  • model_load.py

    import torchmodel = torch.load('my_model_method1.pth', weights_only=False)  # 报错,找不到 MyModel 的定义
    

    先运行 model_save.py,再运行 model_load.py,则会出现以下报错:

image-20250531110244566

4.2 加载参数文件

​ 首先,使用torch.load()方法加载网络模型。

​ 使用模型时,需先创建匹配的网络结构,再使用model.load_state_dict()加载参数数据。

vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict)  # 需结构匹配

​ 模型打印结果是参数字典:

print(model_dict)
image-20250531111411199

注意

​ 模型保存时若在 GPU 上,加载时需指定 map_location 为 cup。

torch.load('model.pth', map_location=torch.device('cpu'))

​ 将参数加载到模型后,手动迁移到 GPU:

model = MyModel()
model.load_state_dict(model_dict)
model.to('cuda:0')

5 Checkpoint

​ 使用 Checkpoint 可以在训练过程中定期保存模型的状态,以便在中断后可以恢复训练,或者在测试时使用最终的模型。文件扩展名推荐使用.tar

5.1 保存 Checkpoint

​ 要保存一个模型的 Checkpoint,通常需要保存以下数据:

  • 模型的 state_dict(状态字典);
  • 优化器的状态;
  • 额外的信息,如 epoch 等。
import torch# 假设 model 是你的模型,optimizer 是你的优化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
}# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')

5.2 加载 Checkpoint

​ 加载 Checkpoint,首先需要加载文件,然后将其内容恢复到模型和优化器的状态中。

# 假设 model 和 optimizer 是你的模型和优化器实例
checkpoint = torch.load('checkpoint.tar')model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果需要,可以继续训练
model.train()  # 确保模型处于训练模式

相关文章:

  • Mybatis-Plus简单介绍
  • 深入探讨redis:主从复制
  • Flutter - 原生交互 - 相机Camera - 01
  • 快速掌握 GO 之 RabbitMQ
  • 【iOS】方法交换
  • c/c++的opencv车牌识别
  • MATLAB实现井字棋
  • 可灵2.1 vs Veo 3:AI视频生成谁更胜一筹?
  • matlab/simulink TLC语法基础练习实例
  • Java 数据处理 - 数值转不同进制的字符串(数值转十进制字符串、数值转二进制字符串、数值转八进制字符串、数值转十六进制字符串)
  • C++23 已移除特性解析
  • CQF预备知识:一、微积分 -- 1.8.1 链式法则 I 详解
  • 电子电路:怎么理解时钟脉冲上升沿这句话?
  • PostgreSQL性能监控双雄:深入解析pg_stat_statements与pg_statsinfo
  • 深度学习驱动的超高清图修复技术——综述
  • 【数据结构】图的存储(邻接矩阵与邻接表)
  • LeetCode 1010. 总持续时间可被 60 整除的歌曲
  • 力扣HOT100之动态规划:300. 最长递增子序列
  • Vue-Router简版手写实现
  • go|context源码解析
  • wordpress定制主题/广州百度seo代理
  • 音乐网站开发毕业论文/数字营销是干啥的
  • wordpress 用户中心主题/seo深圳网络推广
  • 山东大学青岛校区建设指挥部网站/软文推广名词解释
  • 做网站是怎么收费的是按点击率/网站seo优化多少钱
  • flask做视频网站/个人博客网站设计毕业论文