torch-xla dynamic shape: analyzing the MHLO implementation via torch.nonzero
PyTorch APIs:
- torch.where
- torch.nonzero

torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
Pay particular attention to the as_tuple parameter of torch.nonzero: with as_tuple=False (the default) it returns a single 2-D tensor of shape (num_nonzero, ndim), while with as_tuple=True it returns a tuple of ndim 1-D index tensors, one per input dimension.
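A minimal plain-PyTorch example (no XLA involved) showing both forms on the same input used in the script below:

import torch

a = torch.tensor([[1, 0, 0, 5, 0, 6]])

# as_tuple=False (default): a single 2-D tensor of shape (num_nonzero, ndim)
print(torch.nonzero(a))
# tensor([[0, 0],
#         [0, 3],
#         [0, 5]])

# as_tuple=True: a tuple of ndim 1-D index tensors, one per dimension
print(torch.nonzero(a, as_tuple=True))
# (tensor([0, 0, 0]), tensor([0, 3, 5]))

# single-argument torch.where is equivalent to the as_tuple=True form
print(torch.where(a != 0))
# (tensor([0, 0, 0]), tensor([0, 3, 5]))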
Understanding the MHLO algorithm:

Python script (adapted from xla/issues/4432):
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["PJRT_DEVICE"] = "XPU"
# os.environ["PJRT_DEVICE"] = "GPU"
# enable the experimental dynamic-shape lowering of nonzero
os.environ["XLA_EXPERIMENTAL"] = "nonzero"

import torch
import torch_xla  # needed for torch_xla._XLAC below
import torch_xla.core.xla_model as xm

a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=xm.xla_device())
a2 = torch.nonzero(a1)
"""
IR {
  %0 = s64[1,6]{1,0} xla::device_data(), xla_shape=s64[1,6]{1,0}, device=TX8:0
  %1 = (s32[<=6,2]{1,0}, s32[]) aten::nonzero(%0), num_outputs=2, xla_shape=(s32[<=6,2]{1,0}, s32[]), ROOT=0
}
"""
print(torch_xla._XLAC._get_xla_tensors_text([a2]))
print(f'{a2.shape=}')  # a2.shape=torch.Size([<=6, 2])
print('a2=', a2)       # the nonzero indices: [[0, 0], [0, 3], [0, 5]]
Run output: the lazy-tensor IR is reproduced in the docstring above.

MHLO code:
module @SyncTensorsGraph.40 {
// %arg0: [[1, 0, 0, 5, 0, 6]]
func.func @main(%arg0: tensor<1x6xi64>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>> {
%0 = mhlo.constant dense<0> : tensor<i64>
%1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i64>) -> tensor<1x6xi64>
%2 = mhlo.compare NE, %arg0, %1 : (tensor<1x6xi64>, tensor<1x6xi64>) -> tensor<1x6xi1>
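// %2: [[true, false, false, true, false, true]]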
%3 = mhlo.convert(%2) : (tensor<1x6xi1>) -> tensor<1x6xi32>
// %4: [1, 0, 0, 1, 0, 1]
%4 = mhlo.reshape %3 : (tensor<1x6xi32>) -> tensor<6xi32>
%5 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x6xi32>
// %6: [0, 0, 0, 0, 0, 0], row indices
%6 = mhlo.reshape %5 : (tensor<1x6xi32>) -> tensor<6xi32>
%7 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x6xi32>
// %8: [0, 1, 2, 3, 4, 5], col indices
%8 = mhlo.reshape %7 : (tensor<1x6xi32>) -> tensor<6xi32>
"""
对所有的operand(%4, %6, %8)分别排序得到3个排序后的tensor,且排序条件一致如下
这里3个operand,对应block中有6个element,对应关系为:一个operand对应2个arg
%arg1和%arg2为%4中的元素,所以这里操作的结果为:将%6和%8按照%4中元素的大小关系排序
即:将第0,3,5位置的元素排到最前。排序后的结果为:
%4'=%9#0 = [1, 1, 1, 0, 0, 0]
%6'=%9#1 = [0, 0, 0, 0, 0, 0] 即%10
%8'=%9#2 = [0, 3, 5, 1, 2, 4] 即%11
"""
%9:3 = "mhlo.sort"(%4, %6, %8) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
// %20 is always true, so the select in %22 always yields %21
%20 = mhlo.constant dense<true> : tensor<i1>
// sort predicate: %arg1 > %arg2 (descending: nonzero flags sort to the front)
%21 = mhlo.compare GT, %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
%22 = "mhlo.select"(%20, %21, %20) : (tensor<i1>, tensor<i1>, tensor<i1>) -> tensor<i1>
mhlo.return %22 : tensor<i1>
}) {dimension = 0 : i64, is_stable = true} : (tensor<6xi32>, tensor<6xi32>, tensor<6xi32>) -> (tensor<6xi32>, tensor<6xi32>, tensor<6xi32>)
%10 = mhlo.reshape %9#1 : (tensor<6xi32>) -> tensor<6x1xi32>
%11 = mhlo.reshape %9#2 : (tensor<6xi32>) -> tensor<6x1xi32>
// concatenating %10 and %11 gives all (row, col) index pairs, with the indices of the nonzero elements first
%12 = "mhlo.concatenate"(%10, %11) {dimension = 1 : i64} : (tensor<6x1xi32>, tensor<6x1xi32>) -> tensor<6x2xi32>
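// %12: [[0, 0], [0, 3], [0, 5], [0, 1], [0, 2], [0, 4]]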
// count the nonzero elements: compare the mask against 0, then reduce-sum
%13 = mhlo.constant dense<0> : tensor<i32>
%14 = "mhlo.broadcast_in_dim"(%13) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<6xi32>
%15 = mhlo.compare GT, %4, %14 : (tensor<6xi32>, tensor<6xi32>) -> tensor<6xi1>
%16 = mhlo.convert(%15) : (tensor<6xi1>) -> tensor<6xi32>
%17 = mhlo.reduce(%16 init: %13) across dimensions = [0] : (tensor<6xi32>, tensor<i32>) -> tensor<i32>
reducer(%arg1: tensor<i32>, %arg2: tensor<i32>) {
%20 = mhlo.add %arg1, %arg2 : tensor<i32>
mhlo.return %20 : tensor<i32>
}
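// %17 = 3: the number of nonzero elements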
// dimension 0 becomes dynamic; its runtime size is set to the count in %17
%18 = "mhlo.set_dimension_size"(%12, %17) {dimension = 0 : i64} : (tensor<6x2xi32>, tensor<i32>) -> tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>
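// %18 thus has logical shape 3x2: [[0, 0], [0, 3], [0, 5]]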
%19 = "mhlo.tuple"(%18) {xla_shape = "(s32[<=6,2]{1,0})"} : (tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>>
return %19 : tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>>
}
}
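To make the data flow concrete, here is a small NumPy sketch that replays the MHLO graph step by step. It is an illustration under the assumption of a 2-D input, not torch-xla's actual lowering code; the function name nonzero_mhlo_style is made up for this post, and the comments reference the MHLO value numbers above.

import numpy as np

def nonzero_mhlo_style(arg0: np.ndarray) -> np.ndarray:
    """Emulate the MHLO graph above for a 2-D input (here 1x6)."""
    rows, cols = arg0.shape

    # %2/%3/%4: nonzero mask as int32, flattened: [1, 0, 0, 1, 0, 1]
    mask = (arg0 != 0).astype(np.int32).reshape(-1)

    # %5/%6: iota along dim 0, flattened -> row indices [0, 0, 0, 0, 0, 0]
    row_idx = np.broadcast_to(np.arange(rows, dtype=np.int32)[:, None], (rows, cols)).reshape(-1)
    # %7/%8: iota along dim 1, flattened -> col indices [0, 1, 2, 3, 4, 5]
    col_idx = np.broadcast_to(np.arange(cols, dtype=np.int32)[None, :], (rows, cols)).reshape(-1)

    # %9: stable descending sort keyed on the mask; row/col indices are
    # permuted the same way, so nonzero positions move to the front
    order = np.argsort(-mask, kind="stable")  # [0, 3, 5, 1, 2, 4]
    row_sorted, col_sorted = row_idx[order], col_idx[order]

    # %10/%11/%12: concatenate into a (rows*cols, 2) index tensor
    indices = np.stack([row_sorted, col_sorted], axis=1)

    # %13..%17: reduce-sum of the mask = number of nonzero elements
    count = int((mask > 0).sum())

    # %18: set_dimension_size -- only the first `count` rows are valid
    return indices[:count]

a1 = np.array([[1, 0, 0, 5, 0, 6]])
print(nonzero_mhlo_style(a1))
# [[0 0]
#  [0 3]
#  [0 5]]

Note that the stable sort is what preserves the original scan order among the nonzero positions, which is why the resulting indices come out in the same row-major order as eager torch.nonzero.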