Triton Linalg - WrapFuncBodyWithSingleBlockPass

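Functionality

WrapFuncBodyWithSingleBlockPass takes a function whose body contains multiple blocks and wraps that body in an scf.execute_region, so that the function itself ends up with a single block.

Example:

Before the transformation: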

tt.func @foo(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
  cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
  tt.return %arg1 : i32
^bb2:
  tt.return %arg2 : i32
}

After the transformation:

tt.func @foo(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
  %0 = scf.execute_region -> i32 {
    cf.cond_br %arg0, ^bb1, ^bb2
  ^bb1:
    cf.br ^bb3(%arg1 : i32)
  ^bb2:
    cf.br ^bb3(%arg2 : i32)
  ^bb3(%1 : i32):
    scf.yield %1 : i32
  }
  tt.return %0 : i32
}

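WrapFuncBodyWithSingleBlockPass simplifies the function's control flow and resolves the inlining failure that a multi-block callee causes when a call inside an scf.for loop is inlined.

Consider code that calls foo: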

scf.for %iv = %0 to %10 step %1 {
  %result = call @foo(%arg0, %arg1, %iv) : (i1, i32, i32) -> i32
}

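If the inliner needs to inline foo here, the original multi-block foo would yield incorrect or hard-to-optimize IR, because the body of scf.for must consist of a single block.

Implementation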

Pass.td

The TableGen part of WrapFuncBodyWithSingleBlock lives in triton-linalg/triton-linalg/include/triton-linalg/Dialect/Triton/Transforms/Passes.td:

  • WrapFuncBodyWithSingleBlock is a pass that runs on ModuleOp.
  • Its constructor is createWrapFuncBodyWithSingleBlockPass.

def WrapFuncBodyWithSingleBlock : Pass<"wrap-func-body-with-single-block", "::mlir::ModuleOp"> {
  let summary = "Wrap function body with single block";
  let description = [{
    This pass wraps function body into a block by moving body to a `scf.execute_region`.
  }];
  let constructor = "mlir::triton::createWrapFuncBodyWithSingleBlockPass()";
  let dependentDialects = ["triton::TritonDialect", "scf::SCFDialect"];
}

Pass.cpp

The C++ part of WrapFuncBodyWithSingleBlock lives in triton-linalg/triton-linalg/lib/Dialect/Triton/Transforms/WrapFuncBodyWithSingleBlock.cpp

runOnOperation

runOnOperation walks every operation in the module that implements FunctionOpInterface (that is, every function) and calls encapsulateMultiBlock on it.

struct WrapFuncBodyWithSingleBlockPass
    : public WrapFuncBodyWithSingleBlockBase<WrapFuncBodyWithSingleBlockPass> {
  void runOnOperation() override {
    getOperation()->walk(
        [&](FunctionOpInterface func) { encapsulateMultiBlock(func); });
  }
};

encapsulateMultiBlock

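Creating the block

First, a new block, newBlock, is created with the same argument types as the function's entry block and inserted before the function's first block, so that it becomes the new entry block: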

static void encapsulateMultiBlock(FunctionOpInterface funcOp) {
  ...
  auto &entryBlock = body.front();
  auto blockArgTypes = entryBlock.getArgumentTypes();
  SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
  Block *newBlock = builder.createBlock(&body, body.begin(), blockArgTypes,
                                        blockArgLocs);
  ...
}

For the foo example above (named wrap_multi_block_triton_func in the dumps), the IR now looks like this:

"tt.func"() <{function_type = (i1, i32, i32) -> i32, sym_name = "wrap_multi_block_triton_func"}> ({
^bb0(%arg0: i1, %arg1: i32, %arg2: i32):
^bb1(%0: i1, %1: i32, %2: i32):  // no predecessors
  "cf.cond_br"(%0)[^bb2, ^bb3] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (i1) -> ()
^bb2:  // pred: ^bb1
  "tt.return"(%1) : (i32) -> ()
^bb3:  // pred: ^bb1
  "tt.return"(%2) : (i32) -> ()
}) : () -> ()

Inserting scf.execute_region

Next, an scf.execute_region is inserted at the start of newBlock's operation list, with the same result types as the function:

static void encapsulateMultiBlock(FunctionOpInterface funcOp) {
  ...
  builder.setInsertionPointToStart(newBlock);
  FunctionType funcType = cast<FunctionType>(funcOp.getFunctionType());
  auto containerOp =
      builder.create<scf::ExecuteRegionOp>(loc, funcType.getResults());
  ...
}

The foo example now becomes:

"tt.func"() <{function_type = (i1, i32, i32) -> i32, sym_name = "wrap_multi_block_triton_func"}> ({
^bb0(%arg0: i1, %arg1: i32, %arg2: i32):
  %0 = "scf.execute_region"() ({
  }) : () -> i32
^bb1(%1: i1, %2: i32, %3: i32):  // no predecessors
  "cf.cond_br"(%1)[^bb2, ^bb3] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (i1) -> ()
^bb2:  // pred: ^bb1
  "tt.return"(%2) : (i32) -> ()
^bb3:  // pred: ^bb1
  "tt.return"(%3) : (i32) -> ()
}) : () -> ()
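
Moving the blocks

The blocks of the original function body are then spliced onto the end of the scf.execute_region's region, skipping the first one (newBlock). The blocks are moved, not copied: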

static void encapsulateMultiBlock(FunctionOpInterface funcOp) {
  ...
  auto &containerRegion = containerOp.getRegion();
  auto &blocksToMove = body.getBlocks();
  containerRegion.getBlocks().splice(containerRegion.end(), blocksToMove,
                                     std::next(blocksToMove.begin()),
                                     blocksToMove.end());
  ...
}

The foo example now becomes:

"tt.func"() <{function_type = (i1, i32, i32) -> i32, sym_name = "wrap_multi_block_triton_func"}> ({
^bb0(%arg0: i1, %arg1: i32, %arg2: i32):
  %0 = "scf.execute_region"() ({
  ^bb0(%arg3: i1, %arg4: i32, %arg5: i32):
    "cf.cond_br"(%arg3)[^bb1, ^bb2] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (i1) -> ()
  ^bb1:  // pred: ^bb0
    "tt.return"(%arg4) : (i32) -> ()
  ^bb2:  // pred: ^bb0
    "tt.return"(%arg5) : (i32) -> ()
  }) : () -> i32
}) : () -> ()
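
The splice call may be easier to follow on a plain std::list, whose splice has the same "relink a range of nodes" behaviour. The sketch below is only a toy model under that assumption: BlockList and moveAllButFirst are hypothetical names, and blocks are reduced to labels.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>

// Toy stand-in for a block list: each block is just a label.
using BlockList = std::list<std::string>;

// Moves every block except the first from `body` to the end of `container`.
// The list nodes are relinked, so nothing is copied and `body` keeps only
// its first element.
void moveAllButFirst(BlockList &body, BlockList &container) {
  assert(!body.empty() && "body must still hold the block that stays behind");
  container.splice(container.end(), body, std::next(body.begin()), body.end());
}
```

Calling moveAllButFirst on a body holding newBlock followed by the three original blocks leaves only newBlock in the body and the other three, in their original order, in the container.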
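
Replacing block arguments

The arguments of the entry block inside scf.execute_region are removed, and every use of them is redirected to the corresponding argument of newBlock (by design, the region of scf.execute_region takes no block arguments):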

  auto &containerEntryBlock = containerRegion.getBlocks().front();
  for (auto &arg : containerEntryBlock.getArguments()) {
    for (OpOperand &use : llvm::make_early_inc_range(arg.getUses())) {
      auto newValue = newBlock->getArgument(arg.getArgNumber());
      use.getOwner()->setOperand(use.getOperandNumber(), newValue);
    }
  }
  containerEntryBlock.eraseArguments(0, containerEntryBlock.getNumArguments());

The foo example now becomes:

"tt.func"() <{function_type = (i1, i32, i32) -> i32, sym_name = "wrap_multi_block_triton_func"}> ({
^bb0(%arg0: i1, %arg1: i32, %arg2: i32):
  %0 = "scf.execute_region"() ({
    "cf.cond_br"(%arg0)[^bb1, ^bb2] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (i1) -> ()
  ^bb1:  // pred: ^bb0
    "tt.return"(%arg1) : (i32) -> ()
  ^bb2:  // pred: ^bb0
    "tt.return"(%arg2) : (i32) -> ()
  }) : () -> i32
}) : () -> ()
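
The effect of the use-replacement loop can be illustrated with a small model in which values are plain names and operations list the names they read. This is a sketch under those assumptions, not the pass's code: ToyOp and forwardArguments are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy operation: a name plus the names of the values it reads.
struct ToyOp {
  std::string name;
  std::vector<std::string> operands;
};

// Redirects every operand that names oldArgs[i] to newArgs[i], then drops
// the old arguments, mirroring the setOperand loop and eraseArguments call.
void forwardArguments(std::vector<ToyOp> &ops,
                      std::vector<std::string> &oldArgs,
                      const std::vector<std::string> &newArgs) {
  assert(oldArgs.size() == newArgs.size());
  for (ToyOp &op : ops) {
    for (std::string &operand : op.operands) {
      for (std::size_t i = 0; i < oldArgs.size(); ++i) {
        if (operand == oldArgs[i]) {
          operand = newArgs[i];
          break;  // each operand is rewritten at most once
        }
      }
    }
  }
  oldArgs.clear();
}
```

In MLIR itself, Value::replaceAllUsesWith expresses the same redirection in a single call; the source does not say why the pass spells the loop out by hand.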
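
Inserting the yield block

A yieldBlock is appended to the blocks held by scf.execute_region; its argument types match the operand types of the ReturnLike ops in those blocks (returnOps is the list collected by the walk shown in the next step):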

  auto &yieldBlock = containerRegion.emplaceBlock();
  auto operandTypes = returnOps.front()->getOperandTypes();
  auto unknownLoc = UnknownLoc::get(ctx);
  SmallVector<Location> unknownLocs(operandTypes.size(), unknownLoc);
  yieldBlock.addArguments(operandTypes, unknownLocs);
  builder.setInsertionPointToEnd(&yieldBlock);
  builder.create<scf::YieldOp>(unknownLoc, yieldBlock.getArguments());

The foo example now becomes:

"tt.func"() <{function_type = (i1, i32, i32) -> i32, sym_name = "wrap_multi_block_triton_func"}> ({
^bb0(%arg0: i1, %arg1: i32, %arg2: i32):
  %0 = "scf.execute_region"() ({
    "cf.cond_br"(%arg0)[^bb1, ^bb2] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (i1) -> ()
  ^bb1:  // pred: ^bb0
    "tt.return"(%arg1) : (i32) -> ()
  ^bb2:  // pred: ^bb0
    "tt.return"(%arg2) : (i32) -> ()
  ^bb3(%1: i32):  // no predecessors
    "scf.yield"(%1) : (i32) -> ()
  }) : () -> i32
}) : () -> ()
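
Inserting the branches

Before each ReturnLike op directly contained in scf.execute_region, a branch is inserted that jumps to the yield block added above (bufferization of scf.execute_region supports only a single scf.yield):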

  SmallVector<Operation *> returnOps;
  containerOp.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::ReturnLike>() &&
        op->getParentOp() == containerOp.getOperation())
      returnOps.push_back(op);
  });
  for (auto *returnOp : returnOps) {
    builder.setInsertionPoint(returnOp);
    builder.create<cf::BranchOp>(returnOp->getLoc(), &yieldBlock,
                                 returnOp->getOperands());
  }

The foo example now becomes:

"tt.func"() <{function_type = (i1, i32, i32) -> i32, sym_name = "wrap_multi_block_triton_func"}> ({
^bb0(%arg0: i1, %arg1: i32, %arg2: i32):
  %0 = "scf.execute_region"() ({
    "cf.cond_br"(%arg0)[^bb1, ^bb2] <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (i1) -> ()
  ^bb1:  // pred: ^bb0
    "cf.br"(%arg1)[^bb3] : (i32) -> ()
    "tt.return"(%arg1) : (i32) -> ()
  ^bb2:  // pred: ^bb0
    "cf.br"(%arg2)[^bb3] : (i32) -> ()
    "tt.return"(%arg2) : (i32) -> ()
  ^bb3(%1: i32):  // 2 preds: ^bb1, ^bb2
    "scf.yield"(%1) : (i32) -> ()
  }) : () -> i32
}) : () -> ()
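
The parent check in the walk matters because a walk also visits operations nested deeper, for instance the terminator of a region-holding op inside the function body. The sketch below models that filter on a toy operation tree; ToyOp, collectDirectReturns and the nested scf.if in the usage note are assumptions made for illustration only.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy operation: a name, a return-like flag, and the ops nested in its regions.
struct ToyOp {
  std::string name;
  bool returnLike = false;
  std::vector<ToyOp> nested;
};

// Visits `op` and everything nested in it, recording return-like ops only
// when their immediate parent is `container`.
void collectDirectReturns(const ToyOp &op, const ToyOp *parent,
                          const ToyOp &container,
                          std::vector<const ToyOp *> &out) {
  if (op.returnLike && parent == &container)
    out.push_back(&op);
  for (const ToyOp &child : op.nested)
    collectDirectReturns(child, &op, container, out);
}
```

Starting the traversal with collectDirectReturns(container, nullptr, container, out) on a container that holds an scf.if (with its own scf.yield) and a tt.return records only the tt.return, so the nested terminator is left alone.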
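
Removing the ReturnLike ops

The last collected ReturnLike op (any one would do) is kept: its operands are set to the results of scf.execute_region and it is moved to the end of newBlock as that block's terminator. All other ReturnLike ops are erased: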

  Operation *terminator = returnOps.pop_back_val();
  for (auto *returnOp : returnOps)
    returnOp->erase();
  terminator->setOperands(containerOp.getResults());
  builder.setInsertionPointToEnd(&body.back());
  terminator->remove();
  builder.insert(terminator);

The foo example now becomes:

tt.func @wrap_multi_block_triton_func(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
  %0 = scf.execute_region -> i32 {
    cf.cond_br %arg0, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    cf.br ^bb3(%arg1 : i32)
  ^bb2:  // pred: ^bb0
    cf.br ^bb3(%arg2 : i32)
  ^bb3(%1: i32):  // 2 preds: ^bb1, ^bb2
    scf.yield %1 : i32
  }
  tt.return %0 : i32
}

This completes the work of wrapping all blocks of the original function in scf.execute_region; the new function contains a single block, newBlock.
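
To see that routing every return through one yield block preserves the result, the rewrite can be replayed on a toy control-flow graph. The model below is a simplification under stated assumptions: blocks carry only a terminator, values are indices into the function arguments, and ToyBlock, wrapReturns and run are hypothetical names, not part of triton-linalg.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Terminator kinds of a toy block; blocks hold no other operations.
enum class Kind { CondBr, Br, Return, Yield };

struct ToyBlock {
  Kind kind;
  std::size_t target = 0;      // CondBr: taken when the condition holds; Br: destination
  std::size_t elseTarget = 0;  // CondBr only: taken otherwise
  std::size_t arg = 0;         // Return/Br: index of the function argument passed on
};

// Turns every Return into a Br to a freshly appended Yield block.
std::vector<ToyBlock> wrapReturns(std::vector<ToyBlock> blocks) {
  const std::size_t yieldIndex = blocks.size();
  for (ToyBlock &block : blocks) {
    if (block.kind == Kind::Return) {
      block.kind = Kind::Br;
      block.target = yieldIndex;
    }
  }
  blocks.push_back(ToyBlock{Kind::Yield});
  return blocks;
}

// Executes the blocks starting at block 0 and returns the produced value.
int run(const std::vector<ToyBlock> &blocks, bool cond,
        const std::vector<int> &args) {
  std::size_t current = 0;
  int carried = 0;  // value bound to the Yield block's argument
  while (true) {
    assert(current < blocks.size());
    const ToyBlock &block = blocks[current];
    switch (block.kind) {
    case Kind::CondBr:
      current = cond ? block.target : block.elseTarget;
      break;
    case Kind::Br:
      carried = args[block.arg];
      current = block.target;
      break;
    case Kind::Return:
      return args[block.arg];
    case Kind::Yield:
      return carried;  // stands for scf.yield followed by the outer return
    }
  }
}
```

Encoding foo as a CondBr block followed by two Return blocks and comparing run on the original and on wrapReturns of it gives the same value for both settings of the condition, while the wrapped version has exactly one exit point.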

