
torch-xla dynamic shapes: analyzing the mhlo implementation through torch.nonzero

PyTorch APIs:

  • torch.where
  • torch.nonzero

torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).

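Pay special attention to the as_tuple parameter of torch.nonzero. With as_tuple=False (the default), it returns a 2-D tensor of shape (z, n) with one row per nonzero element, where z is the number of nonzero elements and n is the number of input dimensions. With as_tuple=True, it returns a tuple of n 1-D tensors of length z, one per dimension.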
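To make the two output layouts concrete, here is a pure-Python sketch of the semantics only. It does not call PyTorch. The function names `nonzero_2d` and `nonzero_2d_as_tuple` are hypothetical, and the input is assumed to be a 2-D nested list. The tuple form is also what `torch.where(condition)` returns.

```python
def nonzero_2d(x):
    """Row-major [row, col] indices of the nonzero elements of a 2-D nested list.

    Mirrors the layout of torch.nonzero(x) with as_tuple=False: shape (z, 2).
    """
    return [[i, j] for i, row in enumerate(x) for j, v in enumerate(row) if v != 0]


def nonzero_2d_as_tuple(x):
    """Mirrors the layout of torch.nonzero(x, as_tuple=True) / torch.where(x).

    Returns one index list of length z per input dimension.
    """
    idx = nonzero_2d(x)
    rows = [i for i, _ in idx]
    cols = [j for _, j in idx]
    return rows, cols


a1 = [[1, 0, 0, 5, 0, 6]]
print(nonzero_2d(a1))           # [[0, 0], [0, 3], [0, 5]]
print(nonzero_2d_as_tuple(a1))  # ([0, 0, 0], [0, 3, 5])
```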

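Understanding the mhlo algorithm (as seen in the mhlo code later in this post):
  • Compare the input with 0 (compare NE) and convert the result into a 0/1 mask.
  • Use iota to generate the index of every element along each dimension.
  • Stable-sort the mask in descending order and carry the index tensors along with it. This moves the indices of the nonzero elements to the front in their original order.
  • Concatenate the per-dimension indices into a [numel, ndim] tensor.
  • Count the nonzero elements with a reduce sum. Then use set_dimension_size to make dimension 0 dynamic, keeping numel as its static upper bound.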
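The lowering in the mhlo code of this post (compare with 0, iota, stable sort, concatenate, reduce-sum) can be emulated in plain Python to check the idea on small inputs. This is a sketch under two assumptions: the input is 2-D and held as nested lists, and Python's stable `sorted` stands in for `mhlo.sort` with `is_stable = true`. The name `nonzero_like_mhlo` is hypothetical.

```python
def nonzero_like_mhlo(x):
    """Emulate the mhlo lowering of nonzero for a 2-D nested list.

    Returns (buffer, count). buffer has numel rows (the static bound), and
    only its first `count` rows are the indices of nonzero elements.
    """
    n_rows, n_cols = len(x), len(x[0])
    # compare NE 0 + convert, then reshape to 1-D in row-major order
    mask = [1 if v != 0 else 0 for row in x for v in row]
    # iota along dimension 0 and dimension 1, each reshaped to 1-D
    row_idx = [i for i in range(n_rows) for _ in range(n_cols)]
    col_idx = [j for _ in range(n_rows) for j in range(n_cols)]
    # stable sort keyed on the mask in descending order; the index
    # operands follow the same permutation
    order = sorted(range(len(mask)), key=lambda k: -mask[k])
    # concatenate the permuted index columns into a [numel, 2] buffer
    buffer = [[row_idx[k], col_idx[k]] for k in order]
    # reduce-sum of (mask > 0) gives the runtime size of dimension 0
    count = sum(1 for m in mask if m > 0)
    return buffer, count


buffer, count = nonzero_like_mhlo([[1, 0, 0, 5, 0, 6]])
print(buffer)          # [[0, 0], [0, 3], [0, 5], [0, 1], [0, 2], [0, 4]]
print(buffer[:count])  # [[0, 0], [0, 3], [0, 5]]
```

The rows after `count` are not garbage. They are the indices of the zero elements, pushed to the back by the sort. This is why only the dynamic size, not the buffer contents, tells how many rows are valid.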

Python script

Script source: xla/issues/4432

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["PJRT_DEVICE"] = "XPU"  # backend used by the author (the IR below shows device=TX8:0)
# os.environ["PJRT_DEVICE"] = "GPU"
os.environ["XLA_EXPERIMENTAL"] = "nonzero"  # enable the dynamic-shape lowering of nonzero

import torch
import torch_xla  # needed for torch_xla._XLAC below
import torch_xla.core.xla_model as xm

a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=xm.xla_device())
a2 = torch.nonzero(a1)

# Dump the lazy IR of a2. Output on the author's device:
# IR {
#   %0 = s64[1,6]{1,0} xla::device_data(), xla_shape=s64[1,6]{1,0}, device=TX8:0
#   %1 = (s32[<=6,2]{1,0}, s32[]) aten::nonzero(%0), num_outputs=2, xla_shape=(s32[<=6,2]{1,0}, s32[]), ROOT=0
# }
print(torch_xla._XLAC._get_xla_tensors_text([a2]))
print(f'{a2.shape=}')  # a2.shape=torch.Size([<=6, 2])
print('a2=', a2)

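Output (shown as a screenshot in the original post): the nonzero elements of a1 are at [0, 0], [0, 3] and [0, 5]. At runtime, a2 therefore holds 3 index rows, while its static bound is 6.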
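The printed shape torch.Size([<=6, 2]) and the XLA shape s32[<=6,2] describe a bounded dynamic dimension. The buffer is sized for the worst case, and a separate runtime size says how many rows are valid. The sketch below models this, assuming the bound is the input's element count (which matches <=6 for the 1x6 input). `BoundedShape` and `nonzero_shape` are hypothetical names, not torch-xla APIs.

```python
from dataclasses import dataclass
from math import prod


@dataclass
class BoundedShape:
    """Hypothetical model of a 2-D shape whose dim 0 is dynamic with a static bound."""
    bound: int  # static upper bound of dim 0 (rows in the allocated buffer)
    size: int   # runtime size of dim 0 (what set_dimension_size records)
    rank: int   # static size of dim 1 (number of input dimensions)

    def __str__(self):
        return f"[<={self.bound}, {self.rank}]"


def nonzero_shape(input_shape, num_nonzero):
    """Shape of nonzero's result for an input of `input_shape`."""
    bound = prod(input_shape)  # worst case: every element is nonzero
    if not 0 <= num_nonzero <= bound:
        raise ValueError("num_nonzero must lie in [0, bound]")
    return BoundedShape(bound=bound, size=num_nonzero, rank=len(input_shape))


s = nonzero_shape([1, 6], num_nonzero=3)
print(s)       # [<=6, 2]
print(s.size)  # 3
```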

mhlo code

module @SyncTensorsGraph.40 {
  // %arg0: [[1, 0, 0, 5, 0, 6]]
  func.func @main(%arg0: tensor<1x6xi64>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>> {
    %0 = mhlo.constant dense<0> : tensor<i64>
    %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i64>) -> tensor<1x6xi64>
    %2 = mhlo.compare  NE, %arg0, %1 : (tensor<1x6xi64>, tensor<1x6xi64>) -> tensor<1x6xi1>
    %3 = mhlo.convert(%2) : (tensor<1x6xi1>) -> tensor<1x6xi32>
    // %4: [1, 0, 0, 1, 0, 1]
    %4 = mhlo.reshape %3 : (tensor<1x6xi32>) -> tensor<6xi32>
    %5 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x6xi32>
    // %6: [0, 0, 0, 0, 0, 0], row indices
    %6 = mhlo.reshape %5 : (tensor<1x6xi32>) -> tensor<6xi32>
    %7 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x6xi32>
    // %8: [0, 1, 2, 3, 4, 5], col indices
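    %8 = mhlo.reshape %7 : (tensor<1x6xi32>) -> tensor<6xi32>
    // Sort the three operands (%4, %6, %8) together along dimension 0, producing 3 sorted tensors.
    // With 3 operands, the comparator block takes 6 arguments, 2 per operand:
    // (%arg1, %arg2) come from %4, (%arg3, %arg4) from %6, (%arg5, %arg6) from %8.
    // The comparator only compares %arg1 and %arg2, so %6 and %8 are permuted according to the
    // values in %4. The sort is stable and descending, so positions 0, 3 and 5 move to the front.
    // Results after sorting:
    //   %4' = %9#0 = [1, 1, 1, 0, 0, 0]
    //   %6' = %9#1 = [0, 0, 0, 0, 0, 0]  (reshaped into %10)
    //   %8' = %9#2 = [0, 3, 5, 1, 2, 4]  (reshaped into %11)
    %9:3 = "mhlo.sort"(%4, %6, %8) ({
    ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
      // %20 is always true
      %20 = mhlo.constant dense<true> : tensor<i1>
      // comparator: %arg1 > %arg2, i.e. descending order of %4
      %21 = mhlo.compare  GT, %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
      // select(true, %21, true) is just %21
      %22 = "mhlo.select"(%20, %21, %20) : (tensor<i1>, tensor<i1>, tensor<i1>) -> tensor<i1>
      mhlo.return %22 : tensor<i1>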
    }) {dimension = 0 : i64, is_stable = true} : (tensor<6xi32>, tensor<6xi32>, tensor<6xi32>) -> (tensor<6xi32>, tensor<6xi32>, tensor<6xi32>)
    %10 = mhlo.reshape %9#1 : (tensor<6xi32>) -> tensor<6x1xi32>
    %11 = mhlo.reshape %9#2 : (tensor<6xi32>) -> tensor<6x1xi32>
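    // Concatenate %10 and %11: %12 = [[0, 0], [0, 3], [0, 5], [0, 1], [0, 2], [0, 4]]
    // Only the first %17 rows (3 here) are the indices of the nonzero elements
    %12 = "mhlo.concatenate"(%10, %11) {dimension = 1 : i64} : (tensor<6x1xi32>, tensor<6x1xi32>) -> tensor<6x2xi32>
    // Count the nonzero elements with a reduce sum
    %13 = mhlo.constant dense<0> : tensor<i32>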
    %14 = "mhlo.broadcast_in_dim"(%13) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<6xi32>
    %15 = mhlo.compare  GT, %4, %14 : (tensor<6xi32>, tensor<6xi32>) -> tensor<6xi1>
    %16 = mhlo.convert(%15) : (tensor<6xi1>) -> tensor<6xi32>
    %17 = mhlo.reduce(%16 init: %13) across dimensions = [0] : (tensor<6xi32>, tensor<i32>) -> tensor<i32>
     reducer(%arg1: tensor<i32>, %arg2: tensor<i32>)  {
      %20 = mhlo.add %arg1, %arg2 : tensor<i32>
      mhlo.return %20 : tensor<i32>
    }
    // Dimension 0 becomes dynamic: its runtime size is the count %17 (3 here), and its bound stays 6
    %18 = "mhlo.set_dimension_size"(%12, %17) {dimension = 0 : i64} : (tensor<6x2xi32>, tensor<i32>) -> tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>
    %19 = "mhlo.tuple"(%18) {xla_shape = "(s32[<=6,2]{1,0})"} : (tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>>
    return %19 : tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>>
  }
}
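A multi-operand `mhlo.sort` is easy to misread. The following Python sketch mirrors the semantics described in the comments of the dump above: all operands are permuted together, and the comparator receives a (lhs, rhs) pair from every operand, ordered (lhs_0, rhs_0, lhs_1, rhs_1, ...). It is a sketch that models only 1-D operands and a stable sort. `mhlo_sort` and `gt_on_first` are hypothetical names.

```python
from functools import cmp_to_key


def mhlo_sort(operands, comparator):
    """Stable sort of several equal-length 1-D lists by one shared permutation.

    comparator(*args) gets 2 * len(operands) scalars ordered as
    (lhs_0, rhs_0, lhs_1, rhs_1, ...) and returns True if lhs sorts before rhs.
    """
    n = len(operands[0])

    def args_for(i, j):
        args = []
        for op in operands:
            args.extend((op[i], op[j]))
        return args

    def cmp(i, j):
        if comparator(*args_for(i, j)):
            return -1
        if comparator(*args_for(j, i)):
            return 1
        return 0  # equal keys keep their order because Python's sort is stable

    perm = sorted(range(n), key=cmp_to_key(cmp))
    return [[op[k] for k in perm] for op in operands]


def gt_on_first(lhs0, rhs0, *_unused):
    """Comparator from the dump: only %arg1 > %arg2 (the values of %4) matters."""
    return lhs0 > rhs0


mask = [1, 0, 0, 1, 0, 1]  # %4
rows = [0, 0, 0, 0, 0, 0]  # %6
cols = [0, 1, 2, 3, 4, 5]  # %8
result = mhlo_sort([mask, rows, cols], gt_on_first)
print(result)  # [[1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 3, 5, 1, 2, 4]]
```

Without stability, the zero-mask rows (and, more importantly, the nonzero rows) could come out in any order. The nonzero indices would then no longer be in row-major order, which torch.nonzero guarantees.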
