Relay算子注册
TVM 卷积算子注册代码深度解析
源码位置:src/relay/op/nn/convolution.cc
作用:conv2d等算子的注册实现
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d").set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,Array<IndexExpr> dilation, int groups, IndexExpr channels,Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,String out_layout, DataType out_dtype, bool is_depthwise) {return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,kernel_size, data_layout, kernel_layout, out_layout, out_dtype,"nn.conv2d", is_depthwise);});RELAY_REGISTER_OP("nn.conv2d").describe(R"code(2D convolution layer (e.g. spatial convolution over images).This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.- **data**: This depends on the `layout` parameter. Input is 4D array of shape(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**: This depends on the `layout` parameter. Output is 4D array of shape(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.)code" TVM_ADD_FILELINE).set_attrs_type<Conv2DAttrs>().set_num_inputs(2).add_argument("data", "Tensor", "The input tensor.").add_argument("weight", "Tensor", "The weight tensor.").set_support_level(2).add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>).set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
这段代码是 TVM 中 2D 卷积算子 (nn.conv2d
) 的完整注册实现,包含了类型系统注册、算子构造、属性定义和功能描述等多个关键部分。下面我将分层详细解析每一部分的含义和作用。
1. 属性类型注册 (TVM_REGISTER_NODE_TYPE
)
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
- 作用:将
Conv2DAttrs
结构体注册到 TVM 的类型系统中,使其可以被序列化和反序列化 - 必要性:
- 允许 Relay 在优化过程中保存和恢复卷积算子的属性
- 支持 AutoTVM 等调优工具修改属性参数
- 底层机制:生成该类型的类型反射信息,包括字段名称、类型等
2. 全局算子构造器注册 (TVM_REGISTER_GLOBAL
)
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d").set_body_typed([](Expr data, Expr weight, ...) {return MakeConv<Conv2DAttrs>(data, weight, ...);});
- 功能:注册一个全局可调用的算子构造函数
- 关键点:
"relay.op.nn._make.conv2d"
是 Python 前端nn.conv2d()
调用的 C++ 入口set_body_typed
指定 lambda 函数处理参数并返回算子表达式MakeConv
模板函数统一处理卷积类算子的构造逻辑
- 调用链路:
# Python调用 relay.nn.conv2d(data, weight, strides=(1,1), ...)↓ # 映射到C++ relay::op::nn::_make::conv2d(data, weight, strides, ...)
3. 算子描述与注册 (RELAY_REGISTER_OP
)
(1) 基础信息描述
.describe(R"code(2D convolution layer...)code" TVM_ADD_FILELINE)
- 作用:提供算子的文档字符串
- TVM_ADD_FILELINE:宏自动添加定义位置信息(文件+行号),便于调试
(2) 属性系统配置
.set_attrs_type<Conv2DAttrs>()
- 绑定关系:声明该算子使用
Conv2DAttrs
作为属性容器 - 效果:所有
nn.conv2d
的属性(strides/padding等)都会存储为Conv2DAttrs
实例
(3) 输入输出规范
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
配置项 | 含义 |
---|---|
set_num_inputs(2) | 声明需要2个输入(data和weight) |
add_argument | 定义每个输入的名称、类型和描述 |
(4) 类型关系函数
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
- 作用:注册类型推断函数
Conv2DRel
- 功能:根据输入类型推断输出类型,例如:
- 输入:
(Tensor[(1,3,224,224), float32], Tensor[(64,3,7,7), float32])
- 输出:
Tensor[(1,64,218,218), float32]
(考虑stride=1, padding=0)
- 输入:
(5) 布局推断
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
- 核心功能:自动处理布局转换(如
NHWC
↔NCHW
) - 典型实现:
Array<Layout> ConvInferCorrectLayout(...) {return {{data_layout, kernel_layout}, {out_layout}}; }
(6) 支持级别
.set_support_level(2)
级别 | 含义 |
---|---|
0 | 内部使用算子 |
1 | 基础算子(如add) |
2 | 高级算子(如conv2d) |
3 | 实验性算子 |
4. 关键设计模式解析
(1) 属性与计算分离
- 计算逻辑(
MakeConv
)只依赖抽象属性接口 - 属性变更(如修改strides)不影响核心计算实现
(2) 多前端统一入口
# Python前端统一通过_make调用
relay.op.nn._make.conv2d(...)↓
# 映射到C++注册的构造器
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")
(3) 类型安全的接口
set_body_typed
使用模板自动检查参数类型add_type_rel
确保输入输出类型匹配
5. 完整调用链路示例
6. 扩展开发建议
添加新属性字段
- 修改
Conv2DAttrs
结构体:struct Conv2DAttrs {// 新增字段bool new_feature_enabled;// 注册字段TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {TVM_ATTR_FIELD(new_feature_enabled).set_default(false);} };
- 在
MakeConv
中处理新属性
自定义类型推断
// 实现新的类型关系函数
bool CustomConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs) {// 自定义类型检查逻辑
}
// 注册时替换
.add_type_rel("Conv2D", CustomConv2DRel)
这种注册机制体现了 TVM 的核心设计哲学:声明式接口与可扩展实现的分离,使得算子开发既规范又灵活。