当前位置: 首页 > news >正文

pytorch底层原理学习--JIT与torchscript

文章目录

    • 0 目的
    • 1 TorchScript
      • 1.1 语言特性的限定性
      • 1.2 设计目的:模型表达的专注性
    • 2pytorch JIT(Just-in-time compilation)
      • 2.1pytorch JIT定义
      • 2.1pytorch JIT整个过程:
        • 1. 前端转换层:生成静态计算图
        • 2. 中间表示层(IR):静态计算图
        • 3. 优化与编译层
        • 4. 执行层
    • 3与pt文件的关系
      • 3.1核心概念定义与层级关系
      • 3.2pt文件构成与pth差异

0 目的

在部署时候需要静态图,静态图在pytorch中也称为Script mode。为了script模式下的计算图表达与优化,引入

两个工具torchscript和pytorch JIT,一个工具将pytroch动态图转为静态图,另一个工具进行静态图优化。


1 TorchScript

TorchScript是Python静态子集。TorchScript和完整Python语言之间最大的区别在于,TorchScript只支持表达神经网络模型所需的一小部分静态类型。TorchScript可以看成一种新的编程语言,设计的目的是为了脱离pytorch,python环境,作为python和其他语言(如c++)的一种中间桥接工具,方便部署。将pytorch写的模型转换成TorchScript语言后的代码,称为中间表示(IR),也称为TorchScript IR,之前在计算图中讨论过,模型可以用计算图(有向无环图)表示,因此TorchScript IR也就是计算图的中间表示。

1.1 语言特性的限定性

  • 静态类型约束
    TorchScript 仅支持 Python 中与神经网络模型表达直接相关的有限类型(如 TensorTupleintList 等),而舍弃了动态类型(如 typing.Any)和复杂控制流。
    示例:变量必须声明单一静态类型,禁止运行时类型变更。
  • 语法子集化
    仅保留 if/for 等基础控制语句,且需符合静态图编译要求(如循环次数需在编译时可推断)。其他 Python 特性(如动态类继承、反射)被排除。

1.2 设计目的:模型表达的专注性

TorchScript 的语法设计聚焦于高效描述神经网络计算图,而非通用编程。这使其成为:

  • 模型序列化载体:脱离 Python 环境后仍可完整表示计算逻辑。
  • 编译器友好格式:静态类型便于优化器分析数据依赖与内存布局。

2pytorch JIT(Just-in-time compilation)

2.1pytorch JIT定义

JIT编译器在模型运行时(而非训练时)对代码进行即时编译与优化。在pytorch中JIT编译器它不会将编译的过程一口气完成,而是先对代码进行一些处理,存储成某种序列化表示(比如计算图);然后在实际的运行时环境中,通过 profiling 的方式,进行针对环境的优化并执行代码。

在这里插入图片描述

pytorch JIT就是为了解决部署而诞生的工具。包括代码的追踪及解析、中间表示的生成、模型优化、序列化等各种功能,可以说是覆盖了模型部署的方方面面。一方面使用TorchScript 作为python代码的另一种表现形式,一方面对TorchScript IR进行优化。

其核心目标包括:

  • 性能提升:通过算子融合、内存复用等优化手段加速推理(部分场景性能提升50%)。
  • 部署解耦:脱离Python依赖,支持C++/移动端等非Python环境。
  • 硬件适配:针对不同后端(CPU/GPU/TPU)生成优化机器码。

2.1pytorch JIT整个过程:

在这里插入图片描述

TorchScript与PyTorch JIT的依赖关系

  • TorchScript是JIT的前提
    动态图模型必须首先转换为TorchScript IR,才能被JIT编译器优化。
  • JIT赋予TorchScript执行能力
    TorchScript IR需通过JIT编译为机器码,否则仅为静态数据结构。
1. 前端转换层:生成静态计算图
转换方式原理适用场景
Tracing记录示例输入下的张量操作轨迹,生成IR(无法捕获分支/循环)无控制流的简单模型(如CNN)
Scripting解析Python源码,直接编译为TorchScript(支持条件分支)含动态逻辑的模型(如RNN)
  • 1. 追踪模式(Tracing)

  • 原理
    输入示例数据(如dummy_input),记录模型前向传播的算子调用序列 → 生成线性计算图

    dummy_input = torch.rand(1, 3, 224, 224)
    jit_model = torch.jit.trace(model, dummy_input)  # 生成IR
    
  • 局限性
    无法捕获条件分支(如if x>0)或循环(如for i in range(n)),仅适合无控制流模型

  • 2. 脚本模式(Scripting)

  • 原理
    直接解析Python源码 → 词法分析(Lexer)→ 语法树(AST)→ 语义分析 → 生成带控制流的IR

    class DynamicModel(nn.Module):def forward(self, x):if x.sum() > 0: return x * 2  # 分支逻辑可被保留else: return x / 2
    jit_model = torch.jit.script(DynamicModel())  # 直接编译源码
    
2. 中间表示层(IR):静态计算图

IR数据结构

基于有向无环图(DAG,这个词在计算图中出现过):

  • Graph:顶级容器,表示一个函数(如forward()
  • Block:基本块(Basic Block),包含有序的Node序列
  • Node:算子节点(如aten::conv2d),含输入/输出Value
  • Value:数据流边,具有静态类型(如Tensor/int
  • 属性(Attributes) :存储常量(如权重张量)
表示算子
输入/输出
管理节点集合
Graph
Node
Block
Value
Operator
Tensor
  • 数据结构:,包含Graph(函数)、Node(算子)、Value(数据流)。
  • 关键特性
    • 类型静态化:所有变量需明确定义类型(如Tensor/int),舍弃Python动态类型。
    • 控制流显式化:将if/for转换为图节点(如prim::If)。
3. 优化与编译层
  • 图优化(Graph Optimization)
    • 算子融合:合并相邻算子(如conv2d + relu → conv2d_relu),减少内核启动开销。
    • 常量折叠:预计算静态表达式(如a=2; b=3; c=a*bc=6)。
  • 硬件后端适配
    • NVFuser:默认GPU优化器,针对NVIDIA显卡生成高效CUDA核。
    • CPU优化:利用OpenMP加速并行计算。
4. 执行层
  • 轻量级解释器:执行优化后的IR,无全局锁(GIL),支持多线程并发。
  • 运行时剖析(Profiling) :动态收集执行数据,反馈至编译器迭代优化(如热点代码重编译)。

3与pt文件的关系

将jit编译优化后的模型进行保存,下一步就可以在C++上进行部署了。

序列化(Serialization) 是将程序中的对象(如模型参数、计算图、张量等)转换为可存储或传输的标准化格式(如字节流、文件)的过程,而 反序列化(Deserialization) 则是将存储的格式还原为内存中的对象。因此保存模型的save函数执行的就是序列化。

torch.jit.save('model.pt')

3.1核心概念定义与层级关系

PyTorch JIT保存的.pt文件、PyTorch JIT编译器与TorchScript三者构成模型部署的核心技术栈,其关系可通过以下分层架构与技术流程详解:

组件本质角色定位
TorchScriptPython的静态类型子集(IR中间表示)模型表达层:定义可编译的模型结构
PyTorch JIT运行时编译器(Just-In-Time Compiler)优化执行层:将IR编译为高效机器码
.pt文件TorchScript模块的序列化格式(ZIP归档)持久化层:存储模型结构与参数

三者关系可概括为:
**TorchScript提供标准化模型表示 → PyTorch JIT进行运行时优化编译 → .pt文件实现跨平台持久化

3.2pt文件构成与pth差异

  • 文件结构(ZIP归档格式):

    model.pt
    ├── code/         # 优化后的TorchScript IR(计算图)
    ├── data.pkl      # 模型权重(张量数据)
    ├── constants.pkl # 嵌入的常量(如超参数)
    └── version       # 格式版本号
    
  • 序列化方法

    # 保存到磁盘文件
    torch.jit.save(traced_model, "model.pt")  # 或 traced_model.save("model.pt")# 保存到内存缓冲区(适用于网络传输)
    buffer = io.BytesIO()
    torch.jit.save(traced_model, buffer)
    

pt文件 vs 普通PyTorch模型文件

特性JIT生成的.pt文件torch.save()保存的.pth文件
内容完整计算图 + 参数 + 优化后的IR仅参数(state_dict)或Python类引用
可移植性脱离Python环境(支持C++/移动端)依赖原始Python模型类定义
执行引擎JIT编译器优化后的本地代码Python解释器执行
反编译风险代码以IR存储,难以还原原始Python逻辑可直接查看模型类代码

PyTorch JIT保存的.pt文件、PyTorch JIT编译器与TorchScript,三者协同构成PyTorch生产部署的核心基础设施,覆盖从研发到落地的完整生命周期。

参考

[1](TorchScript — PyTorch 2.7 documentation)

[3](PyTorch JIT and TorchScript. A path to production for PyTorch models | by Abhishek Sharma | TDS Archive | Medium)

[4]((8 封私信 / 5 条消息) TorchScript 解读(一):初识 TorchScript - 知乎)

[5]((8 封私信 / 5 条消息) PyTorch系列「一」PyTorch JIT —— trace/ script的代码组织和优化方法 - 知乎)

[6](PyTorch JIT | Chenglu’s Log)

[7](PyTorch Architecture | harleyszhang/llm_note | DeepWiki)

[8](TorchScript for Deployment — PyTorch Tutorials 2.7.0+cu126 documentation)

[9](Loading a TorchScript Model in C++ — PyTorch Tutorials 2.7.0+cu126 documentation)

[10](Introduction to TorchScript — PyTorch Tutorials 2.7.0+cu126 documentation)

[11]((8 封私信 / 5 条消息) 什么是torch.jit - 知乎)

[12]((8 封私信 / 5 条消息) 一文带你使用即时编译(JIT)提高 PyTorch 模型推理性能! - 知乎)

[13](TorchScript 解读(二):Torch jit tracer 实现解析 - OpenMMLab的文章 - 知乎
https://zhuanlan.zhihu.com/p/489090393)

[14](Pytorch代码部署:总结使用JIT将PyTorch模型转换为TorchScript格式踩过的那些坑 - Ta没有名字的文章 - 知乎
https://zhuanlan.zhihu.com/p/662228796)

[15](TorchScript JIT & IR - 灵丹的文章 - 知乎
https://zhuanlan.zhihu.com/p/543952666)

[16](TorchScript的简介 - PyTorch官方教程中文版)
《深度学习编译器设计第五章:中间表示》
《PRINCIPLED OPTIMIZATION OF DYNAMIC NEURAL NETWORKS》. JARED ROESCH

http://www.dtcms.com/a/264680.html

相关文章:

  • 开机自动后台运行,在Windows服务中托管ASP.NET Core
  • 企业培训笔记:SpringBoot+MyBatis项目中实现分页查询
  • GraphPrompts:图神经网络领域的提示工程范式革新者
  • 学习笔记(28):随机噪声的原理、作用及代码实现详解
  • CC - Link IE转EtherCAT:石油石化软启动器的“最佳搭子”
  • 电商项目实例:基于Python京东商品API接口数据采集
  • 跨越传统界限:ChatGPT+ENVI/Python/GEE集成实战,覆盖无人机遥感、深度学习、洪水监测、矿物识别填图、土壤含水量评估等
  • 【Web前端】优化轮播图展示(源代码)
  • MDK(Keil MDK)工具链
  • cmake find_package
  • C++ 创建动态库及两种方法调用动态库
  • DINO 浅析
  • 医学+AI教育实践!南医大探索数据挖掘人才培养,清华指导发布AI教育白皮书
  • HarmonyOS应用开发高级认证知识点梳理 (四)状态管理V2应用级状态
  • AutoGen-AgentChat-1-整体了解
  • NestJS 系列教程(一):认识 NestJS 与项目初始化
  • RabbitMQ 高级特性之持久性
  • OpenCV仿射变换详解
  • 【飞算JavaAI】智能开发助手赋能Java领域,飞算JavaAI全方位解析
  • 红海云签约东莞科创金融集团,科创金融行业人力资源数字化
  • 论文阅读笔记——VGGT: Visual Geometry Grounded Transformer
  • 50天50个小项目 (Vue3 + Tailwindcss V4) ✨ | ButtonRippleEffect(按钮涟漪效果)
  • 基于[coze][dify]搭建一个智能体工作流,使用第三方插件抓取热门视频数据,自动存入在线表格
  • Node.js-http模块
  • mac Maven配置报错The JAVA_HOME environment variable is not defined correctly的解决方法
  • 21、企业行政办公(OA)数字化转型:系统如何重塑企业高效运营新范式
  • Android Native 之 inputflinger进程分析
  • 硬件选型与组网规划S7-300以太网模块适配性与网络架构搭建
  • 学习笔记(27):线性回归基础与实战:从原理到应用的简易入门
  • 利器:NPM和YARN及其他