当前位置: 首页 > wzjs >正文

北京沙河教做网站的长春什么时候解封

北京沙河教做网站的,长春什么时候解封,最新新闻热点事件2021年9月,做网站要哪些技术需要使用到的函数 在 PyTorch 中,torch.save() 和 torch.load() 是用于保存和加载模型的核心函数。 torch.save() 函数 主要用途:将模型或模型的状态字典(state_dict)保存到文件中。 语法: torch.save(obj, f, pi…

需要使用到的函数

在 PyTorch 中,torch.save()torch.load() 是用于保存和加载模型的核心函数。

torch.save() 函数

  • 主要用途:将模型或模型的状态字典(state_dict)保存到文件中。

  • 语法

torch.save(obj, f, pickle_module=pickle, pickle_protocol=None, _use_new_zipfile_serialization=True)
  • obj: 要保存的对象,可以是整个模型(nn.Module)或模型的状态字典(state_dict)。

  • f: 保存文件的路径。可以是一个字符串路径(如 'model.pth''model.pkl')或一个打开的文件对象。

  • pickle_module: 默认是 pickle,用于序列化对象。你可以使用其他兼容的序列化模块。

  • pickle_protocol: pickle 协议版本。默认值为 None,表示使用最高可用协议版本。

  • _use_new_zipfile_serialization: 默认值为 True,控制是否使用新的序列化格式(推荐使用)。

# 保存整个模型
torch.save(model, 'model.pth')# 保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')

torch.load() 函数

  • 主要用途:从文件中加载保存的模型或模型的状态字典。

  • 语法

torch.load(f, map_location=None, pickle_module=pickle)
  • f: 要加载的文件路径。可以是一个字符串路径或一个打开的文件对象。

  • map_location: 控制如何将存储位置映射到当前设备。例如,map_location='cuda:0' 表示将模型加载到 GPU 上。

  • pickle_module: 默认是 pickle,用于反序列化对象。

# 加载整个模型
model = torch.load('model.pth', map_location='cpu')  # 加载到 CPU# 加载模型的状态字典
model_state_dict = torch.load('model_state_dict.pth', map_location='cuda:0')  # 加载到 GPU

加载状态字典到模型

  • 加载状态字典后,通常需要将其加载到一个已经实例化的模型中。可以使用 model.load_state_dict() 方法:

  • 语法

model.load_state_dict(state_dict, strict=True)
  • state_dict: 从文件中加载的模型状态字典。

  • strict: 默认为 True,表示严格加载状态字典中的所有键。如果设置为 False,可以忽略不匹配的键。

# 实例化模型
model = SimpleModel()# 加载状态字典
model_state_dict = torch.load('model_state_dict.pth')# 将状态字典加载到模型中
model.load_state_dict(model_state_dict)

 注意事项

  • 设备映射:使用 torch.load() 时,可以指定 map_location 参数来控制模型加载到的设备(如 CPU 或 GPU)。

  • 自定义类:保存和加载整个模型时,需要确保自定义的模型类在加载代码中已经定义,否则会报错。

  • 兼容性torch.save()torch.load() 使用 pickle 序列化,可能会受到 Python 版本和 PyTorch 版本的影响。建议使用相同版本的 PyTorch 和 Python 进行保存和加载。

  • 推荐使用状态字典:保存和加载状态字典(state_dict)比保存整个模型更灵活和可移植。这样可以避免保存自定义类的依赖关系。

通过以上方法,你可以灵活地保存和加载 PyTorch 模型,无论是 .pth 还是 .pkl 格式,都可以根据需要选择合适的保存方式。

保存和读取.pth格式的预训练模型

保存

import torch
import torch.nn as nn# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleModel()# 假设已经训练了模型,这里只是演示保存
# 保存整个模型
torch.save(model, 'model.pth')
# 或者只保存模型的参数
torch.save(model.state_dict(), 'model_state_dict.pth')

读取

# 如果保存的是整个模型
loaded_model = torch.load('model.pth')
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
model_load.load_state_dict(torch.load('model_state_dict.pth'))
###########################################################################
# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载预训练模型
model = SimpleModel()
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))# 将模型转移到 GPU
model.to(device)# 示例输入数据
input_data = torch.randn(1, 10).to(device)  # 确保输入数据也在 GPU 上# 前向传播
output = model(input_data)
print(output)

在使用 model.load_state_dict(torch.load('model_state_dict.pth', map_location=device)) 读取模型时,已经指定了 map_location=device,这确保了模型的参数(张量)被加载到指定的设备上。但是,是否还需要调用 model.to(device) 取决于具体的情况。

详细分析

  1. map_location=device 的作用

    • map_location=device 参数用于指定加载的张量应该被放置到哪个设备上。当你加载模型的状态字典时,这个参数确保所有张量(如模型的权重和偏置)被加载到指定的设备(CPU 或 GPU)。

    • 这个参数主要用于处理加载时的设备映射,特别是在加载存储在不同设备上的模型时(例如,从 GPU 上保存的模型加载到 CPU 上或反之)。

2 .model.to(device) 的作用

  • model.to(device) 用于将整个模型(包括模型的参数、缓冲区等)转移到指定的设备上。这是一个递归操作,会遍历模型的所有子模块并将其转移到目标设备。

  • 如果模型在加载时已经将所有张量加载到了正确的设备上(通过 map_location=device),那么调用 model.to(device) 是冗余的,但它不会产生负面影响。

具体情况分析

  • 加载到 CPU: -你在 如果 CPU 上加载模型,并且使用 map_location='cpu',那么模型的张量已经被加载到 CPU 上。在这种情况下,调用 model.to('cpu') 是不必要的,因为模型已经在 CPU 上了。

  • 加载到 GPU

    • 如果你在 GPU 上加载模型,并且使用 map_location='cuda'map_location=device(其中 device 是 GPU),那么模型的张量已经被加载到 GPU 上。但是,模型对象本身(如模型的结构)可能仍然在 CPU 上。

    • 此,调用 model.to(device) 可以确保模型的所有部分(包括模型的结构和参数)都正确地在 GPU 上。

推荐做法

为了确保模型及其所有组成部分都在正确的设备上,建议在加载模型后调用 model.to(device)。这样可以避免潜在的设备不一致问题。

保存和读取.pkl格式的预训练模型

保存

import torch
import torch.nn as nn# 义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleModel()# 保存整个模型
with open('model.pkl', 'wb') as f:torch.save(model, f)
# 或者只保存模型的参数
with open('model_state_dict.pkl', 'wb') as f:torch.save(model.state_dict(), f)

读取

# 如果保存的是整个模型
with open('model.pkl', 'rb') as f:loaded_model = torch.load(f)
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
with open('model_state_dict.pkl', 'rb') as f:model_load.load_state_dict(torch.load(f))

两种格式的区别

  • pth 格式

    • 是 PyTorch 推荐的模型保存格式。它使用 Python 的 pickle 模块来序列化模型对象。对于模型的存储来说,它能够较好地保存和加载模型的结构以及参数。当你想要完整地保存和恢复一个模型的训练状态(包括模型结构、参数、优化器等时),使用.pth 格式很方便。

  • pkl 格式

    • 本质上也是使用 pickle 序列化对象。它是一种通用的 Python 对象序列化格式。在 PyTorch 的早期版本中,pkl 格式被广泛用于保存模型。但是使用 pkl 格式时,可能会受到 Python 版本的限制。因为不同 Python 版本之间,pickle 序列化后的对象在反序列化时可能会出现兼容性问题。例如,你在 Python 3.7 环境下用 pickle 保存了一个模型,然后在 Python 3.8 环境下尝试加载时,可能会因为 pickle 协议版本或者对象结构差异等原因导致加载失败。而.pth 格式会更好地处理这些兼容性问题。

注意事项

  • 当保存整个模型时,如果自定义了模型类,加载模型时也需要提供相同的自定义类定义。否则加载时会出现错误,因为无法识别自定义类的结构。

  • 如果只保存模型参数(state_dict),在加载时必须先实例化一个与保存时相同的模型结构,然后将保存的参数加载到这个结构中。这样可以避免保存自定义类的依赖关系,增加模型的可移植性,但前提是你要清楚地知道模型的结构。


文章转载自:

http://zD0tIEcK.qzcLh.cn
http://VSzsaGPO.qzcLh.cn
http://mIPGtAg9.qzcLh.cn
http://osYmeMDt.qzcLh.cn
http://95tgRyiV.qzcLh.cn
http://yOZl4o2a.qzcLh.cn
http://ETRxPv90.qzcLh.cn
http://GwNV8N9U.qzcLh.cn
http://ZwbjqsPD.qzcLh.cn
http://peA7QFw0.qzcLh.cn
http://biO1FRHY.qzcLh.cn
http://g4GAaoBT.qzcLh.cn
http://MraP94cX.qzcLh.cn
http://a4ZYGVTw.qzcLh.cn
http://ZENFrQ0m.qzcLh.cn
http://hDSE2V41.qzcLh.cn
http://MKq4I8eK.qzcLh.cn
http://kO68FVWD.qzcLh.cn
http://WpmEOZfi.qzcLh.cn
http://GFdwBX30.qzcLh.cn
http://MbBC3cXK.qzcLh.cn
http://My6dSxc9.qzcLh.cn
http://hr38fH9o.qzcLh.cn
http://8IiceI9Y.qzcLh.cn
http://GRZpXpGF.qzcLh.cn
http://StngKynZ.qzcLh.cn
http://wCgpiyJT.qzcLh.cn
http://8cydH9zS.qzcLh.cn
http://0yMsSWPe.qzcLh.cn
http://00gYD8zx.qzcLh.cn
http://www.dtcms.com/wzjs/743920.html

相关文章:

  • 网站制作技术有哪些蚂蚁币是什么网站建设
  • 网站建设好学吗长沙大型网站设计公司
  • 山东建站北京网站名称注册证书
  • 网站侧面的虚浮代码专业做互联网招聘的网站
  • 可以做软件外包项目的网站中信建设有限责任公司薛松
  • 与网站建设关系密切的知识点一般做公司网站需要哪几点
  • php部署网站番禺人才网入库考试
  • 如何更好的建设和维护网站wordpress数据库详解
  • 个人网站做cpa建设部网站官网查询
  • 茶山网站仿做易企秀h5制作官网
  • 图书网站开发数据库的建立怎么提高网站百度权重
  • 阿里云服务器发布网站网站文字广告代码
  • wordpress建站论坛阿里巴巴网站被关闭了要怎么做
  • 电商网站开发主要的三个软件西安seo引擎搜索优化
  • 游戏网站怎么制作郑州推广优化公司
  • 手机怎样创建网站上海营业执照查询网上查询
  • 接网站做项目赚钱吗网站中如何做图片轮播
  • 在线考试系统网站模板做谷歌推广一定要网站吗
  • 酒泉市住房和城乡建设局网站工程建设标准
  • 长沙好的设计公司百度seo搜索引擎优化厂家
  • 机械设备asp企业网站源码下载wordpress plugins权限
  • 沙县建设局网站长春网站排名优化价格
  • 深圳快速网站制甘肃兰州地震最新消息
  • 松江叶榭网站建设化妆品公司网站建设方案
  • 网站免费建设百度自助建站官网
  • wordpress的psd网站优化能发外链的gvm网站大全
  • 宫廷计有哪些网站开发的有诗意的设计公司名字
  • h5网站建设图标外贸专业网站建设
  • 新网站如何做搜索引擎收录网页制作基本代码
  • 做私房蛋糕在哪些网站写东西网站建设开发详细步骤流程