pytorch 中meshgrid()函数详解
说明:
torch.meshgrid 将多个一维向量扩展成多维网格,使得你可以方便地获取每个网格点的坐标
函数输入:
输入多个数据类型相同的一维tensor
函数输出:
输出多个tensor,tensor的数量为输入一维向量的个数,(以两个为例,tensor行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数)
类比理解
假设你有一个 3×3 的网格:
(0,0) (1,0) (2,0)
(0,1) (1,1) (2,1)
x的坐标为[0,1,2]
y的坐标为[0,1]
torch.meshgrid可以为你生成两个二维张量,一个存储每个点的x坐标,另一个存储每个点的y坐标
import torchx = torch.tensor([0, 1, 2])
y = torch.tensor([0, 1])X, Y = torch.meshgrid(x, y, indexing='ij')print("X (x坐标网格):")
print(X)
# 输出:
# tensor([[0, 0],
# [1, 1],
# [2, 2]])print("Y (y坐标网格):")
print(Y)
# 输出:
# tensor([[0, 1],
# [0, 1],
# [0, 1]])
解释:
X[i,j] 表示第 (i,j) 位置的 x 坐标
Y[i,j] 表示第 (i,j) 位置的 y 坐标
所有点为:(X[i,j], Y[i,j])