当前位置: 首页 > news >正文

Relay算子注册

TVM 卷积算子注册代码深度解析

源码位置:src/relay/op/nn/convolution.cc
作用:conv2d等算子的注册实现

TVM_REGISTER_NODE_TYPE(Conv2DAttrs);TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d").set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,Array<IndexExpr> dilation, int groups, IndexExpr channels,Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,String out_layout, DataType out_dtype, bool is_depthwise) {return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,kernel_size, data_layout, kernel_layout, out_layout, out_dtype,"nn.conv2d", is_depthwise);});RELAY_REGISTER_OP("nn.conv2d").describe(R"code(2D convolution layer (e.g. spatial convolution over images).This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.- **data**: This depends on the `layout` parameter. Input is 4D array of shape(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.)code" TVM_ADD_FILELINE).set_attrs_type<Conv2DAttrs>().set_num_inputs(2).add_argument("data", "Tensor", "The input tensor.").add_argument("weight", "Tensor", "The weight tensor.").set_support_level(2).add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>).set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

  这段代码是 TVM 中 2D 卷积算子 (nn.conv2d) 的完整注册实现,包含了类型系统注册、算子构造、属性定义和功能描述等多个关键部分。下面我将分层详细解析每一部分的含义和作用。


1. 属性类型注册 (TVM_REGISTER_NODE_TYPE)

TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
  • 作用:将 Conv2DAttrs 结构体注册到 TVM 的类型系统中,使其可以被序列化和反序列化
  • 必要性
    • 允许 Relay 在优化过程中保存和恢复卷积算子的属性
    • 支持 AutoTVM 等调优工具修改属性参数
  • 底层机制:生成该类型的类型反射信息,包括字段名称、类型等

2. 全局算子构造器注册 (TVM_REGISTER_GLOBAL)

TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d").set_body_typed([](Expr data, Expr weight, ...) {return MakeConv<Conv2DAttrs>(data, weight, ...);});
  • 功能:注册一个全局可调用的算子构造函数
  • 关键点
    • "relay.op.nn._make.conv2d" 是 Python 前端 nn.conv2d() 调用的 C++ 入口
    • set_body_typed 指定 lambda 函数处理参数并返回算子表达式
    • MakeConv 模板函数统一处理卷积类算子的构造逻辑
  • 调用链路
    # Python调用
    relay.nn.conv2d(data, weight, strides=(1,1), ...)# 映射到C++
    relay::op::nn::_make::conv2d(data, weight, strides, ...)
    

3. 算子描述与注册 (RELAY_REGISTER_OP)

(1) 基础信息描述

.describe(R"code(2D convolution layer...)code" TVM_ADD_FILELINE)
  • 作用:提供算子的文档字符串
  • TVM_ADD_FILELINE:宏自动添加定义位置信息(文件+行号),便于调试

(2) 属性系统配置

.set_attrs_type<Conv2DAttrs>()
  • 绑定关系:声明该算子使用 Conv2DAttrs 作为属性容器
  • 效果:所有 nn.conv2d 的属性(strides/padding等)都会存储为 Conv2DAttrs 实例

(3) 输入输出规范

.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
配置项含义
set_num_inputs(2)声明需要2个输入(data和weight)
add_argument定义每个输入的名称、类型和描述

(4) 类型关系函数

.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
  • 作用:注册类型推断函数 Conv2DRel
  • 功能:根据输入类型推断输出类型,例如:
    • 输入: (Tensor[(1,3,224,224), float32], Tensor[(64,3,7,7), float32])
    • 输出: Tensor[(1,64,218,218), float32] (考虑stride=1, padding=0)

(5) 布局推断

.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
  • 核心功能:自动处理布局转换(如 NHWCNCHW
  • 典型实现
    Array<Layout> ConvInferCorrectLayout(...) {return {{data_layout, kernel_layout}, {out_layout}};
    }
    

(6) 支持级别

.set_support_level(2)
级别含义
0内部使用算子
1基础算子(如add)
2高级算子(如conv2d)
3实验性算子

4. 关键设计模式解析

(1) 属性与计算分离

使用
Conv2DAttrs
+Array strides
+Array padding
+String data_layout
...
Conv2DOp
+MakeConv()
+Conv2DRel()
  • 计算逻辑MakeConv)只依赖抽象属性接口
  • 属性变更(如修改strides)不影响核心计算实现

(2) 多前端统一入口

# Python前端统一通过_make调用
relay.op.nn._make.conv2d(...)# 映射到C++注册的构造器
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")

(3) 类型安全的接口

  • set_body_typed 使用模板自动检查参数类型
  • add_type_rel 确保输入输出类型匹配

5. 完整调用链路示例

Python RelayFrontend OpRegistry TOPI relay.nn.conv2d(data, weight, strides=(1,1)) 查找"relay.op.nn._make.conv2d" 调用Conv2DAttrs绑定的计算逻辑 返回计算表达式 返回CallNode 返回Relay表达式 Python RelayFrontend OpRegistry TOPI

6. 扩展开发建议

添加新属性字段

  1. 修改 Conv2DAttrs 结构体:
    struct Conv2DAttrs {// 新增字段bool new_feature_enabled;// 注册字段TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {TVM_ATTR_FIELD(new_feature_enabled).set_default(false);}
    };
    
  2. MakeConv 中处理新属性

自定义类型推断

// 实现新的类型关系函数
bool CustomConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs) {// 自定义类型检查逻辑
}
// 注册时替换
.add_type_rel("Conv2D", CustomConv2DRel)

这种注册机制体现了 TVM 的核心设计哲学:声明式接口可扩展实现的分离,使得算子开发既规范又灵活。

相关文章:

  • 7.9/Q1,Charls最新文章解读
  • Dagger中编译import报找不到ProvideClientFactory,initialize中ProvideClientFactory爆红
  • 猿人学刷题系列(第一届比赛)——第一题
  • 技术对暴力的削弱
  • 【C/C++】构造函数与析构函数
  • 强化学习+多模态 从理论到实战
  • Python Cookbook-7.4 对类和实例使用 cPickle 模块
  • 论软件的可靠性设计
  • 排序算法——堆排序
  • 【PPT制作利器】DeepSeek + Kimi生成一个初始的PPT文件
  • 椭球面长度计算的两种公式及投影选择
  • MySQL 窗口函数入门到精通
  • Coding Practice,48天强训(30)
  • 泰迪杯特等奖案例学习资料:基于卷积神经网络与集成学习的网络问政平台留言文本挖掘与分析
  • 网页截图指南
  • 存储系列知识
  • k8s node 报IPVS no destination available
  • Vue3+ Vite + Element-Plus + TypeScript 从0到1搭建
  • 卡特兰数--
  • 25_05_02Linux架构篇、第1章_03安装部署nginx
  • 超燃!走过莫斯科街头的“中国排面”
  • 金融监管总局:做好2025年小微企业金融服务工作
  • 涉个人信息收集使用问题,15款App和16款SDK被通报
  • IPO周报|节后首批3只新股本周申购,色谱设备龙头来了
  • 夹缝中的责编看行业:长视频之殇,漫长周期
  • 许昌市场监管部门对胖东来玉石开展日常检查:平均毛利率不超20%