自动生成 te.extern 接口:TVM 中第三方库的智能化接入方案
te.extern
和 tir.call_packed
的详细解析
def concat(x, axis=0):assert (x[0].ndim == 1 and axis in [0]) or (x[0].ndim == 2 and axis in [0, 1]) or (x[0].ndim == 3 and axis in [0,1,2]) or (x[0].ndim == 4 and axis in [0,1,2,3]), "concat x ndim:{0} axis:{1}".format(x[0].ndim, axis)assert len(x) == 2, "concat len(x)=%d" % len(x)assert np.all([x[0].shape[i].value == x[1].shape[i].value for i inrange(x[0].ndim) if i != axis])out_shape = concat_outshape(x, axis)# assert np.prod(out_shape) <= 2 ** 26, "concat out shape > 2**26, please check inputs shape"return te.extern(out_shape, x,lambda ins, outs: tir.call_packed("tvm.contrib.xx.Concat", axis,ins[0], outs[0]), name="Concat", dtype=x[0].dtype)
这两个函数是 TVM 中用于集成外部自定义算子的关键接口,代码展示了一个自定义的 concat
操作实现。下面我将分别解释它们的作用和协作机制:
1. te.extern
:声明外部计算
功能
- 定义外部实现的算子,允许将非TVM原生实现的操作(如C++/CUDA代码)集成到TVM计算图中
- 不描述具体计算过程,只声明输入输出关系和调用方式
参数解析
te.extern(out_shape, # 输出张量的形状x, # 输入张量列表(这里是包含2个张量的列表)lambda ins, outs: # 计算描述函数tir.call_packed("tvm.contrib.xx.Concat", axis, ins[0], outs[0]),name="Concat", # 算子名称(用于调试)dtype=x[0].dtype # 输出数据类型(与输入一致)
)
在代码中的作用
- 声明了一个名为
Concat
的外部算子 - 输出形状通过
concat_outshape(x, axis)
计算得出 - 实际计算委托给名为
"tvm.contrib.xx.Concat"
的外部函数处理
2. tir.call_packed
:调用外部函数
功能
- 生成调用外部函数的IR节点,最终会被编译为对目标函数的调用
- 是TVM运行时(Runtime)与外部代码交互的核心机制
参数解析
tir.call_packed("tvm.contrib.xx.Concat", # 注册的函数名(需在C++/CUDA中实现)axis, # 拼接轴参数ins[0], # 第一个输入张量outs[0] # 输出张量(按引用传递)
)
关键特性
特性 | 说明 |
---|---|
跨语言调用 | 可通过TVM运行时调用C++/CUDA/Python等实现的函数 |
内存管理 | outs[0] 由调用者预分配,函数直接写入结果 |
类型擦除 | 参数通过TVM的PackedFunc 机制动态传递 |
3. 完整工作流程
sequenceDiagramparticipant Pythonparticipant TE(Tensor Expression)participant TIRparticipant RuntimePython->>TE: 调用te.extern()TE->>TIR: 生成包含call_packed的IRTIR->>Runtime: 编译为可执行代码Runtime->>C++: 调用已注册的"tvm.contrib.xx.Concat"C++-->>Runtime: 写入结果到outs[0]Runtime-->>Python: 返回计算完成的张量
4. 为什么需要这种设计?
(1) 集成自定义优化
- 当TVM原生算子性能不足时,可以用手写CUDA/C++代码替代
- 示例:
xx.Concat
可能使用了特殊的内存访问优化
(2) 逐步迁移策略
(3) 支持黑盒算子
- 集成第三方库(如cuDNN、OneDNN)的已有实现
- 示例:
te.extern(..., lambda ins, outs: tir.call_packed("tvm.contrib.cudnn.conv2d", ...))
5. 实现外部函数的步骤
(1) C++端注册函数
// 在C++中实现并注册
TVM_REGISTER_GLOBAL("tvm.contrib.xx.Concat")
.set_body([](TVMArgs args, TVMRetValue* rv) {int axis = args[0];DLTensor* input = args[1];DLTensor* output = args[2];// 实际拼接逻辑...
});
(2) 编译为共享库
g++ -shared -fPIC concat.cc -o libxxconcat.so
(3) Python端加载
tvm.runtime.load_module("libxxconcat.so")
6. 注意事项
- 内存布局一致性:
- 确保外部函数处理的数据布局(如NCHW/NHWC)与TVM一致
- 线程安全:
- 若算子会被多线程调用,需自行加锁或设计无锁算法
- 性能分析:
- 使用
tvm.runtime.profiler
对比原生实现的性能
- 使用
总结
te.extern
:TVM中声明外部算子的前端接口tir.call_packed
:生成调用外部代码的IR节点- 典型应用场景:
- 集成高度优化的手写内核
- 调用第三方加速库
- 逐步将黑盒算子迁移到TVM原生实现
代码展示了一个典型的安全检查+外部实现的concat操作,这种模式在需要结合TVM自动调度与手动优化时非常有用。