pytorch底层原理学习--JIT与torchscript
文章目录
- 0 目的
- 1 TorchScript
- 1.1 语言特性的限定性
- 1.2 设计目的:模型表达的专注性
- 2pytorch JIT(Just-in-time compilation)
- 2.1pytorch JIT定义
- 2.1pytorch JIT整个过程:
- 1. 前端转换层:生成静态计算图
- 2. 中间表示层(IR):静态计算图
- 3. 优化与编译层
- 4. 执行层
- 3与pt文件的关系
- 3.1核心概念定义与层级关系
- 3.2pt文件构成与pth差异
0 目的
在部署时候需要静态图,静态图在pytorch中也称为Script mode。为了script模式下的计算图表达与优化,引入
两个工具torchscript和pytorch JIT,一个工具将pytroch动态图转为静态图,另一个工具进行静态图优化。
1 TorchScript
TorchScript是Python静态子集。TorchScript和完整Python语言之间最大的区别在于,TorchScript只支持表达神经网络模型所需的一小部分静态类型。TorchScript可以看成一种新的编程语言,设计的目的是为了脱离pytorch,python环境,作为python和其他语言(如c++)的一种中间桥接工具,方便部署。将pytorch写的模型转换成TorchScript语言后的代码,称为中间表示(IR),也称为TorchScript IR,之前在计算图中讨论过,模型可以用计算图(有向无环图)表示,因此TorchScript IR也就是计算图的中间表示。
1.1 语言特性的限定性
- 静态类型约束:
TorchScript 仅支持 Python 中与神经网络模型表达直接相关的有限类型(如Tensor
、Tuple
、int
、List
等),而舍弃了动态类型(如typing.Any
)和复杂控制流。
示例:变量必须声明单一静态类型,禁止运行时类型变更。 - 语法子集化:
仅保留if
/for
等基础控制语句,且需符合静态图编译要求(如循环次数需在编译时可推断)。其他 Python 特性(如动态类继承、反射)被排除。
1.2 设计目的:模型表达的专注性
TorchScript 的语法设计聚焦于高效描述神经网络计算图,而非通用编程。这使其成为:
- 模型序列化载体:脱离 Python 环境后仍可完整表示计算逻辑。
- 编译器友好格式:静态类型便于优化器分析数据依赖与内存布局。
2pytorch JIT(Just-in-time compilation)
2.1pytorch JIT定义
JIT编译器在模型运行时(而非训练时)对代码进行即时编译与优化。在pytorch中JIT编译器它不会将编译的过程一口气完成,而是先对代码进行一些处理,存储成某种序列化表示(比如计算图);然后在实际的运行时环境中,通过 profiling 的方式,进行针对环境的优化并执行代码。
pytorch JIT就是为了解决部署而诞生的工具。包括代码的追踪及解析、中间表示的生成、模型优化、序列化等各种功能,可以说是覆盖了模型部署的方方面面。一方面使用TorchScript 作为python代码的另一种表现形式,一方面对TorchScript IR进行优化。
其核心目标包括:
- 性能提升:通过算子融合、内存复用等优化手段加速推理(部分场景性能提升50%)。
- 部署解耦:脱离Python依赖,支持C++/移动端等非Python环境。
- 硬件适配:针对不同后端(CPU/GPU/TPU)生成优化机器码。
2.1pytorch JIT整个过程:
TorchScript与PyTorch JIT的依赖关系
- TorchScript是JIT的前提:
动态图模型必须首先转换为TorchScript IR,才能被JIT编译器优化。 - JIT赋予TorchScript执行能力:
TorchScript IR需通过JIT编译为机器码,否则仅为静态数据结构。
1. 前端转换层:生成静态计算图
转换方式 | 原理 | 适用场景 |
---|---|---|
Tracing | 记录示例输入下的张量操作轨迹,生成IR(无法捕获分支/循环) | 无控制流的简单模型(如CNN) |
Scripting | 解析Python源码,直接编译为TorchScript(支持条件分支) | 含动态逻辑的模型(如RNN) |
-
1. 追踪模式(Tracing)
-
原理:
输入示例数据(如dummy_input
),记录模型前向传播的算子调用序列 → 生成线性计算图dummy_input = torch.rand(1, 3, 224, 224) jit_model = torch.jit.trace(model, dummy_input) # 生成IR
-
局限性:
无法捕获条件分支(如if x>0
)或循环(如for i in range(n)
),仅适合无控制流模型 -
2. 脚本模式(Scripting)
-
原理:
直接解析Python源码 → 词法分析(Lexer)→ 语法树(AST)→ 语义分析 → 生成带控制流的IRclass DynamicModel(nn.Module):def forward(self, x):if x.sum() > 0: return x * 2 # 分支逻辑可被保留else: return x / 2 jit_model = torch.jit.script(DynamicModel()) # 直接编译源码
2. 中间表示层(IR):静态计算图
IR数据结构
基于有向无环图(DAG,这个词在计算图中出现过):
- Graph:顶级容器,表示一个函数(如
forward()
) - Block:基本块(Basic Block),包含有序的Node序列
- Node:算子节点(如
aten::conv2d
),含输入/输出Value - Value:数据流边,具有静态类型(如
Tensor
/int
) - 属性(Attributes) :存储常量(如权重张量)
- 数据结构:,包含
Graph
(函数)、Node
(算子)、Value
(数据流)。 - 关键特性:
- 类型静态化:所有变量需明确定义类型(如
Tensor
/int
),舍弃Python动态类型。 - 控制流显式化:将
if
/for
转换为图节点(如prim::If
)。
- 类型静态化:所有变量需明确定义类型(如
3. 优化与编译层
- 图优化(Graph Optimization):
- 算子融合:合并相邻算子(如
conv2d + relu → conv2d_relu
),减少内核启动开销。 - 常量折叠:预计算静态表达式(如
a=2; b=3; c=a*b
→c=6
)。
- 算子融合:合并相邻算子(如
- 硬件后端适配:
- NVFuser:默认GPU优化器,针对NVIDIA显卡生成高效CUDA核。
- CPU优化:利用OpenMP加速并行计算。
4. 执行层
- 轻量级解释器:执行优化后的IR,无全局锁(GIL),支持多线程并发。
- 运行时剖析(Profiling) :动态收集执行数据,反馈至编译器迭代优化(如热点代码重编译)。
3与pt文件的关系
将jit编译优化后的模型进行保存,下一步就可以在C++上进行部署了。
序列化(Serialization) 是将程序中的对象(如模型参数、计算图、张量等)转换为可存储或传输的标准化格式(如字节流、文件)的过程,而 反序列化(Deserialization) 则是将存储的格式还原为内存中的对象。因此保存模型的save函数执行的就是序列化。
torch.jit.save('model.pt')
3.1核心概念定义与层级关系
PyTorch JIT保存的.pt
文件、PyTorch JIT编译器与TorchScript三者构成模型部署的核心技术栈,其关系可通过以下分层架构与技术流程详解:
组件 | 本质 | 角色定位 |
---|---|---|
TorchScript | Python的静态类型子集(IR中间表示) | 模型表达层:定义可编译的模型结构 |
PyTorch JIT | 运行时编译器(Just-In-Time Compiler) | 优化执行层:将IR编译为高效机器码 |
.pt文件 | TorchScript模块的序列化格式(ZIP归档) | 持久化层:存储模型结构与参数 |
三者关系可概括为:
**TorchScript提供标准化模型表示 → PyTorch JIT进行运行时优化编译 → .pt文件实现跨平台持久化
3.2pt文件构成与pth差异
-
文件结构(ZIP归档格式):
model.pt ├── code/ # 优化后的TorchScript IR(计算图) ├── data.pkl # 模型权重(张量数据) ├── constants.pkl # 嵌入的常量(如超参数) └── version # 格式版本号
-
序列化方法:
# 保存到磁盘文件 torch.jit.save(traced_model, "model.pt") # 或 traced_model.save("model.pt")# 保存到内存缓冲区(适用于网络传输) buffer = io.BytesIO() torch.jit.save(traced_model, buffer)
pt文件 vs 普通PyTorch模型文件
特性 | JIT生成的.pt文件 | torch.save()保存的.pth文件 |
---|---|---|
内容 | 完整计算图 + 参数 + 优化后的IR | 仅参数(state_dict)或Python类引用 |
可移植性 | 脱离Python环境(支持C++/移动端) | 依赖原始Python模型类定义 |
执行引擎 | JIT编译器优化后的本地代码 | Python解释器执行 |
反编译风险 | 代码以IR存储,难以还原原始Python逻辑 | 可直接查看模型类代码 |
PyTorch JIT保存的.pt
文件、PyTorch JIT编译器与TorchScript,三者协同构成PyTorch生产部署的核心基础设施,覆盖从研发到落地的完整生命周期。
参考
[1](TorchScript — PyTorch 2.7 documentation)
[3](PyTorch JIT and TorchScript. A path to production for PyTorch models | by Abhishek Sharma | TDS Archive | Medium)
[4]((8 封私信 / 5 条消息) TorchScript 解读(一):初识 TorchScript - 知乎)
[5]((8 封私信 / 5 条消息) PyTorch系列「一」PyTorch JIT —— trace/ script的代码组织和优化方法 - 知乎)
[6](PyTorch JIT | Chenglu’s Log)
[7](PyTorch Architecture | harleyszhang/llm_note | DeepWiki)
[8](TorchScript for Deployment — PyTorch Tutorials 2.7.0+cu126 documentation)
[9](Loading a TorchScript Model in C++ — PyTorch Tutorials 2.7.0+cu126 documentation)
[10](Introduction to TorchScript — PyTorch Tutorials 2.7.0+cu126 documentation)
[11]((8 封私信 / 5 条消息) 什么是torch.jit - 知乎)
[12]((8 封私信 / 5 条消息) 一文带你使用即时编译(JIT)提高 PyTorch 模型推理性能! - 知乎)
[13](TorchScript 解读(二):Torch jit tracer 实现解析 - OpenMMLab的文章 - 知乎
https://zhuanlan.zhihu.com/p/489090393)
[14](Pytorch代码部署:总结使用JIT将PyTorch模型转换为TorchScript格式踩过的那些坑 - Ta没有名字的文章 - 知乎
https://zhuanlan.zhihu.com/p/662228796)
[15](TorchScript JIT & IR - 灵丹的文章 - 知乎
https://zhuanlan.zhihu.com/p/543952666)
[16](TorchScript的简介 - PyTorch官方教程中文版)
《深度学习编译器设计第五章:中间表示》
《PRINCIPLED OPTIMIZATION OF DYNAMIC NEURAL NETWORKS》. JARED ROESCH