torch.meshgrid()
前言:
在看b站up 霹雳W的mask rcnn课程, 训练数据遇到报错,发现一些函数使用版本已经更新,故记录一下。
def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
This is helpful when you want to visualize data over some
range of inputs. See below for a plotting example.
Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
the output :math:`G_i` is constructed by expanding :math:`T_i`
to the result shape.
.. note::
0D inputs are treated equivalently to 1D inputs of a
single element.
.. warning::
`torch.meshgrid(*tensors)` currently has the same behavior
as calling `numpy.meshgrid(*arrays, indexing='ij')`.
In the future `torch.meshgrid` will transition to
`indexing='xy'` as the default.
https://github.com/pytorch/pytorch/issues/50276 tracks
this issue with the goal of migrating to NumPy's behavior.
.. seealso::
:func:`torch.cartesian_prod` has the same effect but it
collects the data in a tensor of vectors.
Args:
tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
treated as tensors of size :math:`(1,)` automatically
indexing: (str, optional): the indexing mode, either "xy"
or "ij", defaults to "ij". See warning for future changes.
If "xy" is selected, the first dimension corresponds
to the cardinality of the second input and the second
dimension corresponds to the cardinality of the first
input.
If "ij" is selected, the dimensions are in the same
order as the cardinality of the inputs.
Returns:
seq (sequence of Tensors): If the input has :math:`N`
tensors of size :math:`S_0 \ldots S_{N-1}``, then the
output will also have :math:`N` tensors, where each tensor
is of shape :math:`(S_0, ..., S_{N-1})`.
Example::
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([4, 5, 6])
Observe the element-wise pairings across the grid, (1, 4),
(1, 5), ..., (3, 6). This is the same thing as the
cartesian product.
>>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
>>> grid_x
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
>>> grid_y
tensor([[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
This correspondence can be seen when these grids are
stacked properly.
>>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
... torch.cartesian_prod(x, y))
True
`torch.meshgrid` is commonly used to produce a grid for
plotting.
>>> # xdoctest: +REQUIRES(module:matplotlib)
>>> import matplotlib.pyplot as plt
>>> xs = torch.linspace(-5, 5, steps=100)
>>> ys = torch.linspace(-5, 5, steps=100)
>>> x, y = torch.meshgrid(xs, ys, indexing='xy')
>>> z = torch.sin(torch.sqrt(x * x + y * y))
>>> ax = plt.axes(projection='3d')
>>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
>>> plt.show()
.. image:: ../_static/img/meshgrid.png
:width: 512
"""
return _meshgrid(*tensors, indexing=indexing)
解释
`torch.meshgrid` 函数用于创建坐标网格,它接受一个或多个一维张量作为输入,并返回一个网格张量列表。这个函数在处理多维数据时非常有用,特别是在可视化和数据采样等场景中。
### 参数说明
- **`tensors`**: 一个或多个一维张量,表示需要生成网格的坐标轴。
- **`indexing`**: 一个可选的字符串参数,用于指定索引模式,可以是 `"xy"` 或 `"ij"`。默认值为 `"ij"`。
- **`"xy"`**: 表示笛卡尔坐标索引顺序(列优先)。
- **`"ij"`**: 表示矩阵索引顺序(行优先)。
### 返回值
返回一个包含网格张量的元组,每个张量的形状为 `(S_0, S_1, ..., S_{N-1})`,其中 `S_i` 是输入张量的长度。
### 示例代码
```python
import torch
# 创建两个一维张量
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
# 使用默认的索引模式 "ij"
grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
print("grid_x:")
print(grid_x)
print("grid_y:")
print(grid_y)
```
### 输出结果
```
grid_x:
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
grid_y:
tensor([[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
```
### 注意事项
- **`indexing` 参数的默认值**:在未来的 PyTorch 版本中,`indexing` 参数的默认值将从 `"ij"` 更改为 `"xy"`。因此,建议显式指定 `indexing` 参数,以确保代码的兼容性。
- **与 NumPy 的区别**:`torch.meshgrid` 的行为与 NumPy 的 `np.meshgrid` 有所不同。NumPy 默认使用 `"xy"` 模式,而 PyTorch 默认使用 `"ij"` 模式。
### GitHub 问题
在 GitHub 问题中提到,`torch.meshgrid` 的行为与 NumPy 的 `np.meshgrid` 不一致。PyTorch 计划在未来版本中将默认的 `indexing` 参数更改为 `"xy"`,以与 NumPy 的行为保持一致。这个更改将影响依赖于当前默认行为的代码,建议开发者在代码中显式指定 `indexing` 参数,以避免未来的兼容性问题。
遇见报错:
up发布的代码:
rpn_function.py 142-147行
# 计算预测特征矩阵上每个点对应原图上的坐标(anchors模板的坐标偏移量)
# torch.meshgrid函数分别传入行坐标和列坐标,生成网格行坐标矩阵和网格列坐标矩阵
# shape: [grid_height, grid_width]
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
日志中出现以下警告:
UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.
原因:PyTorch 1.10+ 对 torch.meshgrid
的接口进行了修改,需显式指定 indexing
参数('ij'
或 'xy'
)。
解决方案:
-
定位代码中调用
meshgrid
的位置:-
此警告通常出现在 Anchor 生成或 RoI 对齐等代码中。
-
示例修改:
# 原始代码(可能引发警告) grid_x, grid_y = torch.meshgrid(x, y) # 修改后(添加 indexing='ij') grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
-
参考资料:
1.b站up
霹雳吧啦Wz
2.kimi,deepseek