当前位置: 首页 > news >正文

TorchInductor - Introduction

PyTorch 2.x通过TorchDynamo通过Python Bytecode的动态变换实现了图捕获功能,需要搭配一个Compiler Backend完成图编译。

Pytorch尝试集成了多个后端,并使用一个轻量级的autotuner来选择最优的后端图编译结果。这个解决方案存在2个问题:

  • 这些后端的Execution Model和Pytorch差异较大,引入了许多转换步骤,导致性能损失较大。
  • 大多数后端仅支持推理,一些设计决策使得支持训练非常困难。

因此Pytorch需要原生的Compiler Backend:TorchInductor。

整体设计

TorchInductor的整体设计思路是一个轻量级、易于扩展和实验的框架,用于将PyTorch的表示符号化映射到Compiler Backend:

  • 完整表达Pytorch:torch.Tensor -> TensorBox,torch.Storage -> StorageBox 等。
  • 通过symbolically strided tensor表达各种View:Reshape/Transpose/Slice等。
  • 支持训练
  • 支持多后端
  • 支持高层级优化:如Memory Planning。

TorchInductor的初始设计支持2种不同的target:

  • Triton:一种新型的编程语言,开发效率高于CUDA,同时性能可以媲美CUDA的库(如cuDNN),支持NVIDIA/AMD GPU等。
  • C++/OpenMP:一种被广泛使用的编写高性能并行Kernel的API,支持CPU。

TorchInductor设计上优先考虑对Pytorch支持的完整性,包括:

  • aliasing/mutation/views
  • scatter (间接写)
  • gather (间接读)
  • pooling/windows/reductions
  • masked/conditional execution(如padding)
  • template epilogue fusions
  • tiling
  • horizontal/vertical fusions

TorchInductor使用SymPy库来支持动态形状(dynamic shapes)和步长(strides):

使用SymPy符号化tensor的shape,并在整个程序中传播。

load和store直接通过SymPy的索引公式来表达。

通过guards来提升已编译好的子图的使用前提,在guards fail时,触发子图的重编译。

TorchInductor  IR

TorchInductor IR使用了一种define-by-run的loop-level IR。大部分IR是Python里的Callable,输入是SymPy Expression。基于这种IR做分析或代码生成的实现方式是改变ops.* 的实现,并运行IR。

例如对x.permute(1, 0) + x[2, :] 的IR:

def inner_fn(index: List[sympy.Expr]):i1, i0 = indextmp0 = ops.load("x", i1 + i0*size1)tmp1 = ops.load("x", 2*size1 + i0)return ops.add(tmp0, tmp1)torchinductor.ir.Pointwise(device=torch.device("cuda"),dtype=torch.float32,inner_fn=inner_fn,ranges=[size0, size1],
)

TODO:待补充

参考:

TorchInductor: a PyTorch-native Compiler with Define-by-Run IR and Symbolic Shapes - compiler - PyTorch Developer Mailing List

http://www.dtcms.com/a/343546.html

相关文章:

  • 50 C++ STL模板库-算法库 algorithm
  • 使用C++17标准 手写一个vector
  • Python核心技术开发指南(001)——Python简介
  • 基于单片机教室照明灯控制系统
  • 数据结构:生成 (Generating) 一棵 AVL 树
  • 域名污染怎么清洗?域名污染如何处理?
  • 8.21作业
  • 【运维进阶】if 条件语句的知识与实践
  • AI设计师-标小智旗下AI在线设计平台
  • 洛谷 P4942 小凯的数字-普及-
  • Hybrid laser 是什么?
  • BFS算法C++实现(邻接表存储)
  • 最爱--中岛美雪
  • 8 月 20 日科技新动态:多领域创新成果涌现
  • 【typenum】 19 类型相同检查(type_operators.rs片段)
  • Esp32基础(⑩超声波测距模块)
  • Pycharm SSH连接
  • Wireshark数据包波形绘制异常
  • [RestGPT] docs | RestBench评估 | 配置与环境
  • 【51单片机】【protues仿真】基于51单片机16键电子琴系统
  • 【GPT入门】第51课 Conda环境迁移教程:将xxzh环境从默认路径迁移到指定目录
  • OpenAI 开源模型 gpt-oss 是在合成数据上训练的吗?一些合理推测
  • Mysql事务特性
  • python实现根据接口返回数据生成报告和图表
  • (第二十期下)超链接的更多分类
  • 医疗元宇宙:破解医疗困局与数字化变革路径
  • gRPC 服务发现选型对比
  • 基于STM32单片机的二维码识别物联网OneNet云仓库系统
  • 最小生成树的普利姆算法和克鲁斯卡尔算法
  • ABP vNext 速率限制在多租户场景落地