当前位置: 首页 > news >正文

PyTorch 模型文件介绍

在 PyTorch 中,保存和加载模型是训练流程中的关键环节。主要涉及以下几种文件类型和概念:

1. .pt/ .pth文件 (最常见)

  • 本质:​​ 这些是 PyTorch 使用 Python 的 pickle模块序列化 Python 对象后保存的文件。扩展名 .pt.pth是约定俗成的,PyTorch 本身没有强制要求,但强烈推荐使用它们。

  • 保存内容:​

    • 模型的状态字典 (state_dict):​​ 这是最推荐的保存方式。state_dict是一个 Python 字典对象,它将模型的每一层映射到其可学习参数(权重 weight、偏置 bias等)的 Tensor。它不包含模型的结构定义(类代码)。

      torch.save(model.state_dict(), 'model_state_dict.pt')
    • 整个模型:​​ 可以直接保存整个模型对象(包括结构和参数)。这种方式不推荐,因为它依赖于特定的 Python 环境、类定义和文件路径,导致代码难以移植且可能在不同环境或 PyTorch 版本中出错。

      torch.save(model, 'entire_model.pt') # 不推荐
  • 加载方式:​

    • 加载 state_dict:​​ 需要先实例化模型结构(类),然后将 state_dict加载到该实例中。

      model = MyModelClass(*args, **kwargs) # 1. 创建相同结构的模型实例
      model.load_state_dict(torch.load('model_state_dict.pt')) # 2. 加载参数
      model.eval() # 3. 设置为评估模式(影响 dropout, batchnorm 等层)
    • 加载整个模型:​​ 直接加载即可,但存在上述限制。

      model = torch.load('entire_model.pt') # 不推荐
      model.eval()
  • 优点 (state_dict方式):​

    • 文件较小(只保存参数)。

    • 代码更灵活、可移植。模型类定义可以独立修改(只要结构匹配),方便在不同项目或脚本间共享参数。

    • 是保存和加载模型的标准做法

  • 缺点 (整个模型方式):​

    • 文件较大(包含结构信息)。

    • 严重依赖保存时的具体环境(类定义、导入路径等),难以复用。

    • 在不同 PyTorch 版本间可能不兼容。

  • 警告:​pickle模块可能存在安全风险。​只加载你信任的来源的 .pt/.pth文件!​

2. .zip文件 (TorchScript)

  • 本质:​​ 当使用 torch.jit.save()保存 TorchScript 模型时,默认生成一个 .zip文件(虽然也可以指定 .pt.pth,但 .zip是标准输出)。

  • 保存内容:​​ TorchScript 是一种 PyTorch 模型的表示形式,它可以在脱离 Python 环境的情况下被高性能 C++ 运行时(torch::jit::load)或 Python 运行时(torch.jit.load)加载和执行。它包含了模型的结构(计算图)​​ 和参数

  • 生成方式:​

    • 追踪 (torch.jit.trace):​​ 用一个示例输入“运行”模型,记录执行的操作。适用于没有控制流(如 if/for)的模型。

      example_input = torch.rand(1, 3, 224, 224)
      traced_model = torch.jit.trace(model, example_input)
      torch.jit.save(traced_model, 'traced_model.zip')
    • 脚本化 (torch.jit.script):​​ 直接解析模型代码(或部分代码)生成 TorchScript。适用于包含控制流的模型。

      scripted_model = torch.jit.script(model)
      torch.jit.save(scripted_model, 'scripted_model.zip')
  • 加载方式:​

    # Python 中加载
    loaded_model = torch.jit.load('traced_model.zip')
    loaded_model.eval()
    output = loaded_model(torch.rand(1, 3, 224, 224))
    // C++ 中加载 (示例)
    #include <torch/script.h>
    torch::jit::script::Module module;
    module = torch::jit::load("traced_model.zip");
    module.eval();
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 3, 224, 224}));
    at::Tensor output = module.forward(inputs).toTensor();
  • 优点:​

    • 跨平台/语言:​​ 可以在 Python 和 C++ 中加载运行。

    • 独立于 Python:​​ 不需要原始的 Python 模型类定义即可运行(在 C++ 中尤其重要)。

    • 序列化优化:​​ 针对部署进行了优化。

    • 模型保护:​​ 原始 Python 代码不是必需的(虽然可以反编译,但增加了难度)。

  • 缺点:​

    • 生成过程可能复杂(尤其对于动态模型)。

    • 可能不完全支持所有 Python 特性(需要调整模型代码)。

    • 调试 TorchScript 有时比调试纯 Python 模型困难。

3. .onnx文件 (ONNX 格式)

  • 本质:​​ Open Neural Network Exchange (ONNX) 是一种开放标准的格式,用于表示深度学习模型。它定义了一个通用的计算图表示。

  • 保存内容:​​ 模型的结构(计算图)​​ 和参数

  • 生成方式:​​ 使用 PyTorch 的 torch.onnx.export函数将 PyTorch 模型转换为 ONNX 格式。

    torch.onnx.export(model,               # 要转换的模型torch.rand(1, 3, 224, 224), # 示例输入"model.onnx",        # 输出文件名input_names=["input"], # 输入节点名称output_names=["output"], # 输出节点名称opset_version=11)    # ONNX 算子集版本
  • 加载方式:​​ ONNX 模型本身不能直接在 PyTorch 中运行(除非使用 ONNX Runtime 的 PyTorch 绑定)。它主要用于:

    • 导入到其他支持 ONNX 的深度学习框架(如 TensorFlow, MXNet, Caffe2)。

    • 使用专门的 ONNX 运行时进行推理(如 ONNX Runtime, TensorRT),这些运行时通常针对不同硬件做了高度优化。

    • 使用工具进行模型可视化、优化或格式转换。

  • 优点:​

    • 框架互操作性:​​ 实现不同深度学习框架之间模型的转换和共享。

    • 硬件供应商支持:​​ 许多硬件加速器(如 NVIDIA TensorRT, Intel OpenVINO)优先支持或优化 ONNX 模型。

    • 标准化:​​ 统一的模型表示格式。

  • 缺点:​

    • 转换过程可能存在精度损失或算子不支持的问题(需要检查转换日志)。

    • ONNX 标准本身在不断发展,不同版本间可能有兼容性问题。

    • 在 PyTorch 中不能直接加载运行 ONNX 模型进行训练或微调(主要用于推理或迁移到其他框架)。

总结与推荐

  1. 日常训练/研究 (PyTorch 环境内):​

    • 保存:​​ 使用 torch.save(model.state_dict(), 'model.pt')。这是最标准、最灵活的方式。

    • 加载:​​ 实例化模型结构 + model.load_state_dict(torch.load('model.pt'))+ model.eval()

  2. 生产部署 (脱离 Python 或 C++ 环境):​

    • 首选:​​ ​TorchScript (.zip.pt)​。它是 PyTorch 官方的部署方案,支持 Python 和 C++,优化良好。

    • 备选/互操作:​​ ​ONNX (.onnx)​。当目标平台(如特定硬件加速器)或框架(如 TensorFlow Serving)对 ONNX 有更好支持时使用。需要额外的运行时(ONNX Runtime, TensorRT 等)。

  3. 避免:​​ 保存整个模型对象 (torch.save(model, ...)),除非有非常特殊且理解其风险的原因。

选择哪种格式取决于你的具体需求:在 PyTorch 内部继续工作就用 state_dict;需要部署到非 Python 环境或 C++ 就用 TorchScript;需要与其他框架或特定硬件加速器交互就用 ONNX。


文章转载自:

http://2IDrBqpQ.wyjhq.cn
http://nSSnrLcL.wyjhq.cn
http://Q12BlSlQ.wyjhq.cn
http://yaLnpear.wyjhq.cn
http://Gs8EmQrV.wyjhq.cn
http://2BRqH3ov.wyjhq.cn
http://TA9cpjyi.wyjhq.cn
http://2EmMRb3N.wyjhq.cn
http://9u5JWiTT.wyjhq.cn
http://nB0Kdy12.wyjhq.cn
http://tWo0QY7U.wyjhq.cn
http://CXx2I6HK.wyjhq.cn
http://GMhjWU9O.wyjhq.cn
http://eUaijumw.wyjhq.cn
http://4JbSEsqs.wyjhq.cn
http://oiFzazRV.wyjhq.cn
http://VgPq2FZw.wyjhq.cn
http://IYixRUcF.wyjhq.cn
http://qHRAroqj.wyjhq.cn
http://fCv8MFZA.wyjhq.cn
http://WhTxrOkd.wyjhq.cn
http://pGxUCXZl.wyjhq.cn
http://3WWI8m3Q.wyjhq.cn
http://uOHRSEx9.wyjhq.cn
http://ykEvXRlC.wyjhq.cn
http://Xlvlw3vz.wyjhq.cn
http://pcYzVH7d.wyjhq.cn
http://ixDqSBB5.wyjhq.cn
http://ifEnxU5P.wyjhq.cn
http://dGrtZRFn.wyjhq.cn
http://www.dtcms.com/a/370249.html

相关文章:

  • Valgrind检测内存泄漏入门指南
  • echarts实现点击图表添加标记
  • Python带状态生成器完全指南:从基础到高并发系统设计
  • python入门常用知识
  • 【算法】92.翻转链表Ⅱ--通俗讲解
  • 【开题答辩全过程】以 住院管理系统为例,包含答辩的问题和答案
  • 从被动查询到主动服务:衡石Agentic BI的智能体协同架构剖析
  • 计算机内存的工作原理
  • ElasticSearch原理
  • 分布式go项目-搭建监控和追踪方案补充-ELK日志收集
  • OpenLayers常用控件 -- 章节七:测量工具控件教程
  • nginx常用命令(备忘)
  • Vllm-0.10.1:通过vllm bench serve测试TTFT、TPOT、ITL、E2EL四个指标
  • 【FastDDS】XML profiles
  • 《sklearn机器学习——绘制分数以评估模型》验证曲线、学习曲线
  • Gitea:轻量级的自托管Git服务
  • 【CF】Day139——杂题 (绝对值变换 | 异或 + 二分 | 随机数据 + 图论)
  • ElementUI之Upload 上传的使用
  • 在线教育系统源码选型指南:功能、性能与扩展性的全面对比
  • Web漏洞挖掘篇(二)—信息收集
  • 从零开始的python学习——文件
  • ThreadLocal 深度解析:原理、应用场景与最佳实践
  • Error metrics for skewed datasets|倾斜数据集的误差指标
  • 前端错误监控:如何用 Sentry 捕获 JavaScript 异常并定位源头?
  • 9.6 前缀和
  • 快捷:常见ocr学术数据集预处理版本汇总(适配mmocr)
  • Linux系统检测硬盘失败解救方法
  • 内网后渗透攻击--linux系统(横向移动)
  • 【软考架构】第二章 计算机系统基础知识:计算机网络
  • equals 定义不一致导致list contains错误