当前位置: 首页 > news >正文

Relay算子注册(在pytorch.py端调用)

1. Relay算子注册 (C++层)

(a) 算子属性注册

路径: src/relay/op/nn/nn.cc

RELAY_REGISTER_OP("hardswish").set_num_inputs(1).add_argument("data", "Tensor", "Input tensor.").set_support_level(3).add_type_rel("Identity", Identity);
(b) 调用节点构造

路径: src/relay/op/nn/activation.cc

TVM_REGISTER_GLOBAL("relay.op._make.hardswish").set_body_typed([](Expr data) {static const Op& op = Op::Get("hardswish");return Call(op, {data}, Attrs(), {});});

2. TOPI计算实现 (C++层)

© TOPI注册入口

路径: src/topi/elemwise.cc

TVM_REGISTER_GLOBAL("topi.hardswish").set_body([](TVMArgs args, TVMRetValue* rv) {*rv = hardswish(args[0]);});
(d) 数学内核实现

路径: include/tvm/topi/nn.h

inline Tensor hardswish(const Tensor& x, std::string name = "T_hardswish") {auto three = make_const(x->dtype, 3);auto six = make_const(x->dtype, 6);return compute(x->shape,[&](const Array<Var>& i) {return x(i) * max(min(x(i) + three, six), 0) / six;},name, kElementWise);
}

3. Python接口层

(e) Relay Python API

路径: python/tvm/relay/op/nn/_nn.py

def hardswish(data):return _make.hardswish(data)
(f) TOPI通用接口

路径: python/tvm/topi/nn.py

@tvm.target.generic_func
def hardswish(x):return cpp.hardswish(x)

4. 计算调度注册

(g) Compute注册

路径: python/tvm/relay/op/strategy/generic.py

@register_compute("hardswish")
def hardswish_compute(attrs, inputs, out_type):return [topi.hardswish(inputs[0])]
(h) 调度策略

路径: `python/tvm/relay/op/op.py**

register_broadcast_schedule("hardswish")
register_shape_func("hardswish", False, elemwise_shape_func)

5. 硬件专用实现

(i) NPU支持声明

路径: `src/relay/backend/contrib/npu/src/op_map.cc**

const std::vector<std::string> _NPU_OP = {...,"hardswish"  // 添加算子名
};
(j) NPU内核实现

路径: `python/tvm/relay/backend/contrib/npu/ops.py**

def custom_hardswish(x):x1 = custom_add(x, te.extern_scalar_value(3.0))x2 = custom_relu(x1)return npu_hardwish(x2, ...)
(k) NPU策略注册

路径: `python/tvm/relay/op/strategy/npu.py**

@hardswish.register("npu")
def hardswish_npu(x):return npu_api.custom_hardswish(x)

6. 前端框架对接

(l) PyTorch转换器

路径: `python/tvm/relay/frontend/pytorch.py**

def _hardswish():def _impl(inputs, input_types):return _op.hardswish(inputs[0])return _impl

关键文件路径总结

功能模块关键路径
Relay核心注册src/relay/op/nn/{nn.cc, activation.cc}
TOPI计算{include,src}/topi/{nn.h, elemwise.cc}
Python接口python/tvm/{relay/op/nn/_nn.py, topi/nn.py}
策略注册python/tvm/relay/op/strategy/{generic.py, npu.py}
硬件后端src/relay/backend/contrib/npu/
前端对接python/tvm/relay/frontend/pytorch.py

开发流程示意图

Relay注册
TOPI实现
Python接口
硬件后端
前端框架

通过这种清晰的路径划分,TVM实现了:

  1. 模块化开发:各层级代码物理隔离
  2. 可扩展性:新增硬件只需在对应目录添加实现
  3. 维护性:相关功能的代码集中存放

相关文章:

  • 项目中为什么选择RabbitMQ
  • Ubuntu 22.04 安装配置远程桌面环境指南
  • Android 中解决 annotations 库多版本冲突问题
  • 从零搭建体育比分网站完整步骤
  • 高等数学第六章---定积分(§6.1元素法6.2定积分在几何上的应用1)
  • 【C++游戏引擎开发】第30篇:物理引擎(Bullet)—软体动力学系统
  • 【Linuc】深入理解 Linux 文件权限
  • 【MySQL】-- 数据库约束
  • SPP 和 yolo 中的SPP
  • 栈与队列详解及模拟实现
  • spring cloud gateway(网关)简介
  • 【HTML5】显示-隐藏法 实现网页轮播图效果
  • 路线 北大国际医院
  • Deepseek流式操作与用户行为数据分析day01
  • MySQL中MVCC指什么?
  • SQL大场笔试真题
  • 笔记本外接显示器检测不到hdmi信号
  • RabbitMq(尚硅谷)
  • 基于docker使用showdoc搭建API开发文档服务器
  • python + whisper 读取蓝牙耳机, 转为文字
  • 网民反映“潜水时遭遇服务质量不佳”,三亚开展核查调查
  • 上海飞银川客机触地复飞后备降西安,亲历者:不少乘客都吐了
  • 新闻分析:电影关税能“让好莱坞再次伟大”?
  • 胖东来关闭官网内容清空?工作人员:后台维护升级
  • “高校领域突出问题系统整治”已启动,聚焦招生、基建、师德等重点
  • 宋涛就许历农逝世向其家属致唁电