当前位置: 首页 > news >正文

使用pytorch保存和加载预训练的模型方法

需要使用到的函数

在 PyTorch 中,torch.save()torch.load() 是用于保存和加载模型的核心函数。

torch.save() 函数

  • 主要用途:将模型或模型的状态字典(state_dict)保存到文件中。

  • 语法

torch.save(obj, f, pickle_module=pickle, pickle_protocol=None, _use_new_zipfile_serialization=True)
  • obj: 要保存的对象,可以是整个模型(nn.Module)或模型的状态字典(state_dict)。

  • f: 保存文件的路径。可以是一个字符串路径(如 'model.pth''model.pkl')或一个打开的文件对象。

  • pickle_module: 默认是 pickle,用于序列化对象。你可以使用其他兼容的序列化模块。

  • pickle_protocol: pickle 协议版本。默认值为 None,表示使用最高可用协议版本。

  • _use_new_zipfile_serialization: 默认值为 True,控制是否使用新的序列化格式(推荐使用)。

# 保存整个模型
torch.save(model, 'model.pth')# 保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')

torch.load() 函数

  • 主要用途:从文件中加载保存的模型或模型的状态字典。

  • 语法

torch.load(f, map_location=None, pickle_module=pickle)
  • f: 要加载的文件路径。可以是一个字符串路径或一个打开的文件对象。

  • map_location: 控制如何将存储位置映射到当前设备。例如,map_location='cuda:0' 表示将模型加载到 GPU 上。

  • pickle_module: 默认是 pickle,用于反序列化对象。

# 加载整个模型
model = torch.load('model.pth', map_location='cpu')  # 加载到 CPU# 加载模型的状态字典
model_state_dict = torch.load('model_state_dict.pth', map_location='cuda:0')  # 加载到 GPU

加载状态字典到模型

  • 加载状态字典后,通常需要将其加载到一个已经实例化的模型中。可以使用 model.load_state_dict() 方法:

  • 语法

model.load_state_dict(state_dict, strict=True)
  • state_dict: 从文件中加载的模型状态字典。

  • strict: 默认为 True,表示严格加载状态字典中的所有键。如果设置为 False,可以忽略不匹配的键。

# 实例化模型
model = SimpleModel()# 加载状态字典
model_state_dict = torch.load('model_state_dict.pth')# 将状态字典加载到模型中
model.load_state_dict(model_state_dict)

 注意事项

  • 设备映射:使用 torch.load() 时,可以指定 map_location 参数来控制模型加载到的设备(如 CPU 或 GPU)。

  • 自定义类:保存和加载整个模型时,需要确保自定义的模型类在加载代码中已经定义,否则会报错。

  • 兼容性torch.save()torch.load() 使用 pickle 序列化,可能会受到 Python 版本和 PyTorch 版本的影响。建议使用相同版本的 PyTorch 和 Python 进行保存和加载。

  • 推荐使用状态字典:保存和加载状态字典(state_dict)比保存整个模型更灵活和可移植。这样可以避免保存自定义类的依赖关系。

通过以上方法,你可以灵活地保存和加载 PyTorch 模型,无论是 .pth 还是 .pkl 格式,都可以根据需要选择合适的保存方式。

保存和读取.pth格式的预训练模型

保存

import torch
import torch.nn as nn# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleModel()# 假设已经训练了模型,这里只是演示保存
# 保存整个模型
torch.save(model, 'model.pth')
# 或者只保存模型的参数
torch.save(model.state_dict(), 'model_state_dict.pth')

读取

# 如果保存的是整个模型
loaded_model = torch.load('model.pth')
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
model_load.load_state_dict(torch.load('model_state_dict.pth'))
###########################################################################
# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载预训练模型
model = SimpleModel()
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))# 将模型转移到 GPU
model.to(device)# 示例输入数据
input_data = torch.randn(1, 10).to(device)  # 确保输入数据也在 GPU 上# 前向传播
output = model(input_data)
print(output)

在使用 model.load_state_dict(torch.load('model_state_dict.pth', map_location=device)) 读取模型时,已经指定了 map_location=device,这确保了模型的参数(张量)被加载到指定的设备上。但是,是否还需要调用 model.to(device) 取决于具体的情况。

详细分析

  1. map_location=device 的作用

    • map_location=device 参数用于指定加载的张量应该被放置到哪个设备上。当你加载模型的状态字典时,这个参数确保所有张量(如模型的权重和偏置)被加载到指定的设备(CPU 或 GPU)。

    • 这个参数主要用于处理加载时的设备映射,特别是在加载存储在不同设备上的模型时(例如,从 GPU 上保存的模型加载到 CPU 上或反之)。

2 .model.to(device) 的作用

  • model.to(device) 用于将整个模型(包括模型的参数、缓冲区等)转移到指定的设备上。这是一个递归操作,会遍历模型的所有子模块并将其转移到目标设备。

  • 如果模型在加载时已经将所有张量加载到了正确的设备上(通过 map_location=device),那么调用 model.to(device) 是冗余的,但它不会产生负面影响。

具体情况分析

  • 加载到 CPU: -你在 如果 CPU 上加载模型,并且使用 map_location='cpu',那么模型的张量已经被加载到 CPU 上。在这种情况下,调用 model.to('cpu') 是不必要的,因为模型已经在 CPU 上了。

  • 加载到 GPU

    • 如果你在 GPU 上加载模型,并且使用 map_location='cuda'map_location=device(其中 device 是 GPU),那么模型的张量已经被加载到 GPU 上。但是,模型对象本身(如模型的结构)可能仍然在 CPU 上。

    • 此,调用 model.to(device) 可以确保模型的所有部分(包括模型的结构和参数)都正确地在 GPU 上。

推荐做法

为了确保模型及其所有组成部分都在正确的设备上,建议在加载模型后调用 model.to(device)。这样可以避免潜在的设备不一致问题。

保存和读取.pkl格式的预训练模型

保存

import torch
import torch.nn as nn# 义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleModel()# 保存整个模型
with open('model.pkl', 'wb') as f:torch.save(model, f)
# 或者只保存模型的参数
with open('model_state_dict.pkl', 'wb') as f:torch.save(model.state_dict(), f)

读取

# 如果保存的是整个模型
with open('model.pkl', 'rb') as f:loaded_model = torch.load(f)
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
with open('model_state_dict.pkl', 'rb') as f:model_load.load_state_dict(torch.load(f))

两种格式的区别

  • pth 格式

    • 是 PyTorch 推荐的模型保存格式。它使用 Python 的 pickle 模块来序列化模型对象。对于模型的存储来说,它能够较好地保存和加载模型的结构以及参数。当你想要完整地保存和恢复一个模型的训练状态(包括模型结构、参数、优化器等时),使用.pth 格式很方便。

  • pkl 格式

    • 本质上也是使用 pickle 序列化对象。它是一种通用的 Python 对象序列化格式。在 PyTorch 的早期版本中,pkl 格式被广泛用于保存模型。但是使用 pkl 格式时,可能会受到 Python 版本的限制。因为不同 Python 版本之间,pickle 序列化后的对象在反序列化时可能会出现兼容性问题。例如,你在 Python 3.7 环境下用 pickle 保存了一个模型,然后在 Python 3.8 环境下尝试加载时,可能会因为 pickle 协议版本或者对象结构差异等原因导致加载失败。而.pth 格式会更好地处理这些兼容性问题。

注意事项

  • 当保存整个模型时,如果自定义了模型类,加载模型时也需要提供相同的自定义类定义。否则加载时会出现错误,因为无法识别自定义类的结构。

  • 如果只保存模型参数(state_dict),在加载时必须先实例化一个与保存时相同的模型结构,然后将保存的参数加载到这个结构中。这样可以避免保存自定义类的依赖关系,增加模型的可移植性,但前提是你要清楚地知道模型的结构。

相关文章:

  • 基于Transformer的多资产收益预测模型实战(附PyTorch实现与避坑指南)
  • OpenHarmony平台驱动开发(九),MIPI DSI
  • 如何使用npm下载指定版本的cli工具
  • 【MySQL】存储引擎 - MyISAM详解
  • FPGA_Verilog实现QSPI驱动,完成FLASH程序固化
  • [ctfshow web入门] web57
  • 到达最后一个房间的最少时间II 类似棋盘转移规律查找
  • QTDesinger如何给label加边框
  • Java后端程序员学习前端之JavaScript
  • k8s的pod挂载共享内存
  • Mysql-OCP PPT课程讲解并翻译
  • 数据结构 - 9( 位图 布隆过滤器 并查集 LRUCache 6000 字详解 )
  • 9. 从《蜀道难》学CSS基础:三种选择器的实战解析
  • 分贝计在评估噪音对学习的影响中起着至关重要作用
  • android-ndk开发(10): use of undeclared identifier ‘pthread_getname_np‘
  • exo:打造家用设备AI集群的开源解决方案
  • 基于Flink的用户画像 OLAP 实时数仓统计分析
  • Android NDK版本迭代与FFmpeg交叉编译完全指南
  • CTF - PWN之ORW记录
  • 手写 vue 源码 ===:自定义调度器、递归调用规避与深度代理
  • 巴基斯坦称对印度发起军事行动
  • 欧洲理事会前主席米歇尔受聘中欧国际工商学院特聘教授,上海市市长龚正会见
  • 国常会:研究深化国家级经济技术开发区改革创新有关举措等
  • 上汽享道出行完成13亿元C轮融资,已启动港股IPO计划
  • AI药企英矽智能第三次递表港交所:去年亏损超1700万美元,收入多数来自对外授权
  • 习近平出席俄罗斯纪念苏联伟大卫国战争胜利80周年庆典