当前位置: 首页 > news >正文

torch.meshgrid()

前言:

在看b站up 霹雳W的mask rcnn课程, 训练数据遇到报错,发现一些函数使用版本已经更新,故记录一下。

   def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
        r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.

        This is helpful when you want to visualize data over some
        range of inputs. See below for a plotting example.

        Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
        inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
        this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
        G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
        the output :math:`G_i` is constructed by expanding :math:`T_i`
        to the result shape.

        .. note::
            0D inputs are treated equivalently to 1D inputs of a
            single element.

        .. warning::
            `torch.meshgrid(*tensors)` currently has the same behavior
            as calling `numpy.meshgrid(*arrays, indexing='ij')`.

            In the future `torch.meshgrid` will transition to
            `indexing='xy'` as the default.

            https://github.com/pytorch/pytorch/issues/50276 tracks
            this issue with the goal of migrating to NumPy's behavior.

        .. seealso::

            :func:`torch.cartesian_prod` has the same effect but it
            collects the data in a tensor of vectors.

        Args:
            tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
                treated as tensors of size :math:`(1,)` automatically

            indexing: (str, optional): the indexing mode, either "xy"
                or "ij", defaults to "ij". See warning for future changes.

                If "xy" is selected, the first dimension corresponds
                to the cardinality of the second input and the second
                dimension corresponds to the cardinality of the first
                input.

                If "ij" is selected, the dimensions are in the same
                order as the cardinality of the inputs.

        Returns:
            seq (sequence of Tensors): If the input has :math:`N`
            tensors of size :math:`S_0 \ldots S_{N-1}``, then the
            output will also have :math:`N` tensors, where each tensor
            is of shape :math:`(S_0, ..., S_{N-1})`.

        Example::

            >>> x = torch.tensor([1, 2, 3])
            >>> y = torch.tensor([4, 5, 6])

            Observe the element-wise pairings across the grid, (1, 4),
            (1, 5), ..., (3, 6). This is the same thing as the
            cartesian product.
            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
            >>> grid_x
            tensor([[1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3]])
            >>> grid_y
            tensor([[4, 5, 6],
                    [4, 5, 6],
                    [4, 5, 6]])

            This correspondence can be seen when these grids are
            stacked properly.
            >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
            ...             torch.cartesian_prod(x, y))
            True

            `torch.meshgrid` is commonly used to produce a grid for
            plotting.
            >>> # xdoctest: +REQUIRES(module:matplotlib)
            >>> import matplotlib.pyplot as plt
            >>> xs = torch.linspace(-5, 5, steps=100)
            >>> ys = torch.linspace(-5, 5, steps=100)
            >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
            >>> z = torch.sin(torch.sqrt(x * x + y * y))
            >>> ax = plt.axes(projection='3d')
            >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
            >>> plt.show()

        .. image:: ../_static/img/meshgrid.png
            :width: 512

        """
        return _meshgrid(*tensors, indexing=indexing)

解释

`torch.meshgrid` 函数用于创建坐标网格,它接受一个或多个一维张量作为输入,并返回一个网格张量列表。这个函数在处理多维数据时非常有用,特别是在可视化和数据采样等场景中。

### 参数说明
- **`tensors`**: 一个或多个一维张量,表示需要生成网格的坐标轴。
- **`indexing`**: 一个可选的字符串参数,用于指定索引模式,可以是 `"xy"` 或 `"ij"`。默认值为 `"ij"`。
  - **`"xy"`**: 表示笛卡尔坐标索引顺序(列优先)。
  - **`"ij"`**: 表示矩阵索引顺序(行优先)。

### 返回值
返回一个包含网格张量的元组,每个张量的形状为 `(S_0, S_1, ..., S_{N-1})`,其中 `S_i` 是输入张量的长度。

### 示例代码
```python
import torch

# 创建两个一维张量
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# 使用默认的索引模式 "ij"
grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')

print("grid_x:")
print(grid_x)
print("grid_y:")
print(grid_y)
```

### 输出结果
```
grid_x:
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
grid_y:
tensor([[4, 5, 6],
        [4, 5, 6],
        [4, 5, 6]])
```

### 注意事项
- **`indexing` 参数的默认值**:在未来的 PyTorch 版本中,`indexing` 参数的默认值将从 `"ij"` 更改为 `"xy"`。因此,建议显式指定 `indexing` 参数,以确保代码的兼容性。
- **与 NumPy 的区别**:`torch.meshgrid` 的行为与 NumPy 的 `np.meshgrid` 有所不同。NumPy 默认使用 `"xy"` 模式,而 PyTorch 默认使用 `"ij"` 模式。

### GitHub 问题
在 GitHub 问题中提到,`torch.meshgrid` 的行为与 NumPy 的 `np.meshgrid` 不一致。PyTorch 计划在未来版本中将默认的 `indexing` 参数更改为 `"xy"`,以与 NumPy 的行为保持一致。这个更改将影响依赖于当前默认行为的代码,建议开发者在代码中显式指定 `indexing` 参数,以避免未来的兼容性问题。

遇见报错:

up发布的代码:

rpn_function.py   142-147行

# 计算预测特征矩阵上每个点对应原图上的坐标(anchors模板的坐标偏移量)
            # torch.meshgrid函数分别传入行坐标和列坐标,生成网格行坐标矩阵和网格列坐标矩阵
            # shape: [grid_height, grid_width]
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)

日志中出现以下警告:

UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.

原因:PyTorch 1.10+ 对 torch.meshgrid 的接口进行了修改,需显式指定 indexing 参数('ij' 或 'xy')。

解决方案

  1. 定位代码中调用 meshgrid 的位置

    • 此警告通常出现在 Anchor 生成或 RoI 对齐等代码中。

    • 示例修改:

      # 原始代码(可能引发警告)
      grid_x, grid_y = torch.meshgrid(x, y)
      
      # 修改后(添加 indexing='ij')
      grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')

       

参考资料:

1.b站up 

霹雳吧啦Wz

2.kimi,deepseek

http://www.dtcms.com/a/122988.html

相关文章:

  • 【OCR】总结目前流行的主要的OCR工具
  • Jenkins安装流程
  • 联邦学习研读笔记
  • printf
  • 【NLP 面经 9、逐层分解Transformer】
  • 第十一章 Python语言-高阶技巧(终章)
  • Dubbo(44)如何排查Dubbo的服务依赖问题?
  • 17. git pull
  • 6、nRF52xx蓝牙学习(nrf_gpiote.c库函数学习)
  • 基于 AI智能体、大模型、RAG、Agent 等技术构建公司内部闭环智能问答系统的详细方案,结合 Spring Boot + Vue 管理系统 的改造思路
  • Http代理服务器选型与搭建
  • Starrocks的Bitmap索引和Bloom filter索引以及全局字典
  • 基于微信小程序的志愿服务系统的设计与实现
  • 数字图像处理作业3
  • fuse-python使用fuse来挂载fs
  • 汽车软件开发常用的建模工具汇总
  • Joomla 常用模块 - 在线用户与Joomla 常用模块 - 自定义HTML模块
  • [leetcode]判断质数
  • 关于C++日志库spdlog
  • JS 函数提升
  • 蓝桥杯十一届C++B组真题题解
  • 革新电销流程,数企云外呼开启便捷 “直通车”
  • 各种场景的ARP攻击描述笔记(超详细)
  • stream流Collectors.toMap(),key值重复问题
  • Bootstrap Table动态修改列标题
  • C++中命名空间namespace|头文件h文件|源文件cpp文件详解
  • pyecharts常用图形
  • Mysql索引(二)
  • 8.第二阶段x64游戏实战-string类
  • UE学习记录part15