TVM | Relay
在 TVM Relay 的设计中,“数据结构类”“容器类”“工具类” 三类模块形成了一套完整的 “模型表示→处理→执行” 体系。它们既各司其职,又通过紧密协作支撑起从前端模型解析(如 ONNX 转换)到后端高效执行的全流程。以下从三类模块的核心定位、组件细节及协同关系展开分析:
一、数据结构类:模型的 “原子组件”,定义计算的基础规则
数据结构类是 Relay 描述计算逻辑的 “最小单元”,规定了 “计算什么”“数据是什么形态”“如何引用组件”,是构建模型的基础。
relay.Expr
:计算逻辑的最小表达式
- 定位:所有计算操作的 “原子描述”,是构建复杂计算的基础积木。
- 具体形态:
Var
:表示变量(如输入数据x
、中间结果),是计算的 “载体”;Const
:表示常量(如权重、偏置),存储固定值;Call
:表示算子调用(如relay.nn.conv2d(x, w)
),将算子(Op
)与输入(Var
/Const
)绑定,描述一次具体计算;- 控制流表达式(
If
/Loop
):支持动态逻辑(如 “若输入大于阈值则执行 A 分支,否则执行 B 分支”)。
- 特性:不可变(纯函数式),确保计算逻辑的可追溯性,便于静态分析(如优化工具可安全地修改表达式而不引发副作用)。
relay.Type
:数据的 “形态说明书”
- 定位:定义
Expr
的数据类型和形状,确保计算的合法性(如 “卷积输入必须是 4 维张量”),同时为优化提供依据(如 “根据输出形状调整内存分配”)。 - 核心类型:
TensorType(shape, dtype)
:描述张量的形状(如(1, 3, 224, 224)
)和数据类型(如float32
);FuncType
:描述Function
的输入输出类型(如 “输入为 2 个 3 维张量,输出为 1 个 2 维张量”);TupleType
:描述多输出的类型(如模型同时输出特征和损失时,明确每个输出的形态)。
- 作用:通过静态类型约束避免计算错误(如 “不能对整数张量调用 ReLU”),同时为形状优化(如自动调整卷积核布局)提供关键信息。
relay.Op
:预定义的 “计算单元”
- 定位:封装基础计算逻辑(如卷积、池化、激活),是
Call
表达式的 “核心功能体”。 - 特性:
- 每个
Op
有唯一标识(如nn.conv2d
)、属性(如strides
/pads
)和后端实现(C++ 代码定义计算逻辑和硬件调度); - 支持扩展:通过
relay.op.register_op
注册自定义算子(需实现计算逻辑和硬件适配)。
- 每个
- 与
Expr
的关系:Call
表达式通过引用Op
来 “调用” 具体计算(如Call(op=Op("nn.relu"), args=[x])
表示对x
应用 ReLU 激活)。
relay.GlobalVar
:跨函数的 “唯一标识”
- 定位:为
Function
提供全局唯一的 “符号名称”,实现函数间的引用(如主函数调用子函数)。 - 示例:若子函数
res_block
通过gv = GlobalVar("res_block")
标识,则主函数可直接用gv(x)
调用该子函数,无需关心其具体实现位置。 - 作用:解决复杂模型中函数复用和依赖管理的问题(如 ResNet 中残差块被多次调用,通过
GlobalVar
统一标识)。
二、容器类:模型的 “组织框架”,整合计算逻辑为可处理单元
容器类的作用是将数据结构类的 “原子组件” 组装成 “可管理、可优化、可执行” 的整体,如同建筑中的 “房间和楼层”,将砖瓦组织为完整的建筑。
relay.Function
:封装完整计算逻辑的 “函数单元”
- 定位:将
Expr
组成的计算逻辑(如 “输入→卷积→激活→输出”)封装为带类型的可复用函数,是优化和代码生成的最小单位。 - 核心结构:
params
:输入参数列表(Var
),每个参数绑定Type
(如x: TensorType(...)
);body
:计算体(由Expr
组成的表达式树,描述输入到输出的映射);ret_type
:返回值类型(由Type
系统自动推断或手动指定)。
- 特性:
- 支持模块化:一个复杂模型可拆分为多个
Function
(如特征提取函数、分类头函数); - 可被
GlobalVar
标识,供其他函数引用(如主函数调用子函数)。
- 支持模块化:一个复杂模型可拆分为多个
relay.Module
:管理多个Function
的 “模型容器”
- 定位:整合模型中所有
Function
(主函数 + 子函数),是模型整体优化、序列化和编译的顶层单位。 - 核心结构:
functions
:GlobalVar
到Function
的映射表(如{gv_main: main_func, gv_res: res_func}
);metadata
:模型元信息(如输入输出名称、版本号)。
- 特性:
- 解决函数依赖:通过
GlobalVar
自动解析函数间的引用关系(如主函数调用的子函数必须在Module
中存在); - 支持序列化:可保存为文件(如
.json
),加载时完整恢复所有函数和依赖,便于部署。
- 解决函数依赖:通过
三、工具类:模型的 “处理流水线”,连接表示与执行
工具类是操作容器类和数据结构类的 “引擎”,负责分析模型特性、优化计算逻辑、生成硬件可执行代码,如同建筑中的 “施工设备”,将设计图纸(模型表示)转化为实际建筑(可运行代码)。
relay.analysis
:模型的 “体检工具”
- 定位:对
Module
/Function
进行静态分析,提取关键信息(如类型、依赖、计算成本),为优化提供依据。 - 核心功能:
type_infer
:自动推断Expr
和Function
的类型(如根据输入形状和卷积参数,推导输出形状);free_vars
:分析表达式中未绑定的变量(避免计算逻辑漏洞);cost_model
:估算算子 / 函数的计算量和内存访问成本(指导优化策略选择)。
relay.transform
:模型的 “优化工具”
- 定位:通过 “优化 Pass” 修改
Module
/Function
的结构,提升执行效率(如减少计算量、降低内存访问)。 - 核心 Pass:
FuseOps
:算子融合(将连续的小算子 “合并” 为一个大算子,减少中间结果的内存读写);ConstantFold
:常量折叠(提前计算可确定的常量表达式,如2+3
直接替换为5
,减少运行时计算);AlterOpLayout
:布局转换(根据硬件特性调整张量布局,如 NHWC→NCHW,适配 GPU 的内存访问模式)。
- 工作方式:通过
Sequential
组合多个 Pass,按顺序作用于Module
(如先折叠常量,再融合算子,最后调整布局)。
from tvm.relay.transform import Sequential, FuseOps, ConstantFoldpasses = Sequential([FuseOps(), ConstantFold()])
optimized_mod = passes(mod) # 对模块执行优化
- 与其他模块的关系:Pass 通过分析
Expr
结构(如Call
调用的Op
)、利用Type
信息(如形状),对Function
进行修改,最终更新Module
。
relay.build_module
:模型的 “编译工具”
- 定位:将优化后的
Module
转换为目标硬件(CPU/GPU/FPGA)可执行的代码,是连接 Relay IR 与后端执行的桥梁。 - 核心流程:
- 接收优化后的
Module
、目标硬件(如target="llvm"
)和参数(权重); - 调用对应后端的代码生成器(如 LLVM 生成 CPU 代码,CUDA 生成 GPU Kernel);
- 输出可执行模块(
tvm.runtime.Module
),包含机器码和参数加载逻辑。
- 接收优化后的
relay.param_dict
:参数管理
- 作用:管理模型的参数(权重、偏置等常量),通常与
Module
配合使用。 - 特性:参数以字典形式存储(
{name: ndarray}
),编译时需传入relay.build
,以便后端将参数嵌入可执行代码或单独加载
relay.prelude
:标准库
- 作用:提供预定义的辅助函数(如基础数学运算、数据结构操作),类似编程语言的标准库。
- 示例:
relay.prelude.Prelude()
包含add
、mul
等基础函数,可直接在Function
中引用,简化复杂计算的构建。
四、三类模块的协同关系:从模型构建到执行的全流程
三类模块通过 “构建→分析→优化→编译” 的流程紧密协作,形成完整的模型处理链路:
模型构建阶段:
- 数据结构类提供基础组件:
Op
定义计算单元,Var
/Const
提供数据载体,Call
组合两者形成Expr
(计算逻辑); - 容器类整合组件:
Function
将Expr
封装为带类型的函数,GlobalVar
为函数命名,Module
整合所有函数形成完整模型。
模型优化阶段:
- 工具类中的
analysis
分析Module
:通过类型推断确定Function
的输入输出形状,通过依赖分析明确函数间关系; transform
基于分析结果优化Module
:修改Function
的body
(如融合Call
算子),更新Expr
结构以提升效率。
模型执行阶段:
- 工具类中的
build_module
将优化后的Module
编译为硬件代码:根据Function
的body
(Expr
逻辑)和Type
信息(形状 / 类型),生成适配目标硬件的可执行代码; - 最终在硬件上加载代码和参数,执行计算。