当前位置: 首页 > news >正文

Ascend的aclgraph(五)PrimTorch TorchInductor

1 PrimTorch

参照文章:Torch.compile()流程解析——4. PrimTorch & TorchInductor

在上一篇Ascend的aclgraph(四)AOT Autograd构建joint graph的时候提及过op执行的时候,通过ProxyTorchDispatchModetorch_dispatch对op进行decompose,具体流程是:

  1. 调用maybe_handle_decomp()函数在CURRENT_DECOMPOSITION_TABLE(一个Aten op映射表)中查找op对应的函数实现并返回,若未实现则进入b;

  2. 若不是则调用decompose()函数继续进行拆解,decompose()实现逻辑如下:

# decompose()函数实现
# path:/torch/_ops.py
def decompose(self, *args, **kwargs):dk = torch._C.DispatchKey.CompositeImplicitAutogradif dk in self.py_kernels:# NB: This branch is not too necessary anymore, because we can# apply Python CompositeImplicitAutograd *before* tracing# using Python dispatcher (also taking advantage of the autograd# formula).  But it's included for completenessreturn self.py_kernels[dk](*args, **kwargs)elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):return self._op_dk(dk, *args, **kwargs)else:return NotImplemented

从而实现high level op一步步拆解到Aten op的过程。总的来说,PrimTorch是一种规定,将所有的op拆解为一个约定的op规范集合,并作为开发者和硬件厂商之间的一种中间桥梁,Pytorch前端将op拆解映射到PrimTorch,而硬件厂商针对这些特定的op进行优化即可
在这里插入图片描述

2 TorchInductor

TorchInductor 是 PyTorch 的一个高性能编译后端,专注于将优化后的计算图转换为高效的、针对特定硬件(如 CPU、GPU)的内核代码。它利用多种优化技术,包括内存优化、并行化和低层次的代码生成,以最大化计算性能。
aot_dispatch_autograd()函数在拿到前反向的FX Graph后,分别调用fw_compilerbw_compiler对前反向图进行编译,这里的fw_compilerbw_compiler可以是不同的compiler(npu就是自定义的),在inductor的默认实现中调用的是compile_fx_inner,而其中的核心函数是fx_codegen_and_compile(),负责对FX Graph进行图优化、Triton内核代码生成等

TorchInductor的核心实现逻辑如下,感兴趣的小伙伴也可以看看后面的代码解析部分
在这里插入图片描述

2.1 TorchInductor函数调用

fx_codegen_and_compile()中比较重要的三个函数是:

  1. _recursive_post_grad_passes:负责对计算图进一步的优化,包括:

    • group_batch_fusion_passes:对batch_linear、batch_relu、batch_sigmoid等归一化操作进行算子融合,根据融合规则,然后以BFS的方式查找符合规则的op进行融合。
    • remove_noop_ops:移除图形中本质上是 aten.clone 和aten.alias的操作。
    • fuse_ddp_communication:对ddp通信的部分操作调用合并函数进行融合。
    • decompose_auto_functionalized:对high-level op进一步进行拆解(因为前面进行算子融合那些操作可能会引入新的high level op所以这里再操作一遍),将高层次的操作逐步转换为更低层次的实现。
  2. GraphLowering:把FX Graph进一步降为Inductor IR,即前面的计算图被进一步转换为低层次的中间表示。这一表示更加接近最终的机器代码,并且适合进一步的代码生成和优化。

  3. GraphLowering.compile_to_fn():负责对前面生成的IR表示转换为针对目标硬件低层次代码,GPU上会生成Triton,CPU上会生成OpenMP/C++,npu上是?后续解答,同时可能会利用 SIMD 指令和多线程并行化来加速计算,是inductor中一个核心的实现。

2.2 compile_to_fn()——内核代码生成

compile_to_fn()在Scheduler类中实现内核代码编译的核心功能。而Scheduler的两个函数值得关注:

  1. Scheduler.init():实现算子融合等优化,基本流程为:

    • compute_dependencies():分析op之间的依赖关系;
    • fuse_nodes():合并节点,核心逻辑是通过get_possible_fusions获取可融合算子组合(这里只是先选出可融合的,因为可能op之间有交集,所以并未直接执行融合,而是筛出可融合的组合并排序再进行按序融合),然后再调用can_fuse()进一步检查是否可融合,最后进行融合,其中两个重要的函数是can_fuse()检查两个op融合是否合法,score_fusion()对给定的融合op排一个优先级(当融合op组合冲突时以排序分数高的先融合,排序得分基于<1>节省的内存操作的估计,<2> 尽量保持原始操作顺序);
  2. Scheduler.codegen():(比较复杂,目前没看懂怎么生成的。。。)

    • codegen_extern_call():是对部分kernel决策进行就地更改并记录决策(没看明白什么操作)
    • self.get_backend(device).codegen_node(node):根据device调用codegen_node生成针对目标硬件的内核代码,如在/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/cuda_combined_scheduling.py::codegen_node中实现了生成Triton内核代码。

回到compile_to_module(),将前面生成内核代码以.py文件方式(triton实现)保存到PyCodeCache中,最后调用PyCodeCache.load_by_key_path()获得编译后的module(这个module包含triton代码的临时文件路径),返回到fx_codegen_and_compile()函数中将进一步封装成CompiledFxGraph

最后回到compile_fx_inner函数中,若支持cudagraph还会对编译后的图进行cudagraph编译优化(torch.compile的recude-overhead模式下会自动添加 CUDA Graph 来减小运行时开销)。读到这里,以为aclgraph中的recude-overhead是专为npu添加,原来是来自于这里。

cudagraph优化具体流程为:

  1. has_incompatible_cudagraph_ops():检查是否存在与cudagraph不兼容的op
  2. cudagraphify():将子图转为cudagraph进行优化

到此完成TorchInductor的编译部分,返回一个Triton内核代码实现的CompiledFxGraph,最后一路返回到compile_fx()即inductor的入口,又回到了调用call_user_compiler处,继续进行后续的操作(Ascend的aclgraph(三)TorchDynamo)完成整个流程。

# Inductor核心函数实现
def fx_codegen_and_compile(gm: torch.fx.GraphModule,example_inputs: List[torch.Tensor],cudagraphs: Optional[BoxedBool] = None,static_input_idxs: Optional[List[int]] = None,is_backward: bool = False,graph_id: Optional[int] = None,cpp_wrapper: bool = False,aot_mode: bool = False,is_inference: bool = False,# Use a dict with None value rather than a set for deterministic# iteration order just in case.user_visible_outputs: Optional[Dict[str, None]] = None,layout_opt: Optional[bool] = None,extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
) -> Union[CompiledFxGraph, str]:# 省略中间...V.debug.fx_graph(gm, example_inputs)shape_env = _shape_env_from_inputs(example_inputs)view_to_reshape(gm)with torch.no_grad():fake_mode = fake_tensor_prop(gm, example_inputs)with V.set_fake_mode(fake_mode):# has some issues with memory in training_recursive_post_grad_passes(gm, is_inference=is_inference)    # 优化计算图,包括group_batch_fusion、remove_noop_ops(拷贝别名处理)、fuse_ddp_communication等V.debug.fx_graph_transformed(gm, example_inputs)post_grad_graphs_log.debug("%s",lazy_format_graph_code("AFTER POST GRAD", gm, include_stride=True, include_device=True),)trace_structured("inductor_post_grad_graph",payload_fn=lambda: gm.print_readable(print_output=False, include_stride=True, include_device=True),)if config.is_fbcode():log_optimus_to_scuba(extra_logging={"pt2_configs": str(get_patched_config_dict())})with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(example_inputs):const_output_index = Noneconst_graph = Noneconst_code = Noneif aot_mode and config.aot_inductor.use_runtime_constant_folding:const_gm, const_output_index = split_const_gm(gm)const_graph = GraphLowering(const_gm,example_inputs=[],shape_env=shape_env,graph_id=graph_id,cpp_wrapper=cpp_wrapper,aot_mode=aot_mode,user_visible_outputs=user_visible_outputs,extern_node_serializer=extern_node_serializer,is_inference=is_inference,is_const_graph=True,)with V.set_graph_handler(const_graph):assert cpp_wrapper, "AOT mode only supports C++ wrapper"const_graph.run()const_code, _ = const_graph.codegen_with_cpp_wrapper()# 降为inductor IR以进一步的优化graph = GraphLowering(gm,# example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.# For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,# we currently use fake tensors and defake them later.example_inputs=example_inputs,shape_env=shape_env,graph_id=graph_id,cpp_wrapper=cpp_wrapper,aot_mode=aot_mode,user_visible_outputs=user_visible_outputs,extern_node_serializer=extern_node_serializer,is_inference=is_inference,const_output_index=const_output_index,const_code=const_code,const_module=const_graph,)metrics_helper = metrics.CachedMetricsHelper()with V.set_graph_handler(graph):graph.run(*example_inputs)output_strides: List[Optional[Tuple[int, ...]]] = []if graph.graph_outputs is not None:# We'll put the output strides in the compiled graph so we# can later return them to the caller via TracingContextfor out in graph.graph_outputs:if (hasattr(out, "layout")and len(free_unbacked_symbols(out.layout.stride)) == 0):output_strides.append(tuple(V.graph.sizevars.size_hint(s) for s in out.layout.stride))else:output_strides.append(None)_check_triton_bf16_support(graph)compiled_fn = graph.compile_to_fn()    # 生成对应的后端内核代码,GPU为Triton,CPU为C++/OpenMP# 省略中间代码...# 将编译后代码封装成CompiledFxGraph并返回compiled_graph = CompiledFxGraph(compiled_fn,graph,output_strides,V.graph.disable_cudagraphs_reason,metrics_helper.get_deltas(),)return compiled_graph

到此梳理完了torch.compile()函数的整体流程,解析了从TorchDynamo捕获计算图、再到AOTAutograd捕获前反向计算图并进行算子decompose、以及最后在TorchInductor中完成算子融合和kernel代码生成的实现逻辑,后续再对其中的部分实现细节进行深入分析。

相关文章:

  • 网页Web端无人机直播RTSP视频流,无需服务器转码,延迟300毫秒
  • Dagster Pipes系列-1:调用外部Python脚本
  • 按钮导航组件 | 纯血鸿蒙组件库AUI
  • 基于STM32、HAL库的DPS368XTSA1气压传感器 驱动程序设计
  • Java高频面试之并发编程-16
  • 设置环境变量启动jar报
  • 基于SpringBoot的蜗牛兼职网设计与实现|源码+数据库+开发说明文档
  • Qt Creator 配置 Android 编译环境
  • 火山RTC 6 自定义视频
  • 深入解析MySQL联合查询(UNION):案例与实战技巧
  • 区块链技术构建电子发票平台“税链”
  • JVM之垃圾回收器
  • 开源 RPA 工具深度解析与官网指引
  • 【Git】GitHub上传图片遇到的问题
  • Spark,序列化反序列化
  • C# 基础 try-catch代码块
  • 「华为」人形机器人赛道投资首秀!
  • 单片机学习Day08--相邻流水灯
  • 【落羽的落羽 C++】stack和queue、deque、priority_queue、仿函数
  • MySQL 8.0安装(压缩包方式)
  • 朝着解决问题的正确方向迈进——中美经贸高层会谈牵动世界目光
  • 寒武纪陈天石:公司的产品力获得了行业客户广泛认可,市场有望迎来新增量需求
  • 行知读书会|换一个角度看见社会
  • 上海建筑领域绿色发展2025年工作要点发布
  • 打击网络谣言、共建清朗家园,中国互联网联合辟谣平台2025年4月辟谣榜
  • 巴基斯坦对印度发起网络攻击,致其约70%电网瘫痪