当前位置: 首页 > news >正文

PyTorch Tensor 操作入门:转换、运算、维度变换

目录

1. Tensor 与 NumPy 数组的转换

1.1 Tensor 转换为 NumPy 数组

1.2 NumPy 数组转换为 Tensor

1.3 获取单个元素的值

2. Tensor 的基本运算

2.1 生成新 Tensor 的运算

2.2 覆盖原 Tensor 的运算

2.3 阿达玛积(逐元素乘法)

2.4 矩阵乘法

3. Tensor 的形状变换

3.1 view() 方法

3.2 reshape() 方法

4. 维度变换

4.1 transpose() 方法

4.2 permute() 方法

5. 完整代码示例

6. 总结


在深度学习中,PyTorch 的 Tensor 是核心数据结构,它类似于 NumPy 的数组,但可以在 GPU 上高效运行。除了创建 Tensor,PyTorch 还提供了丰富的操作方法,包括 Tensor 与 NumPy 数组的转换、基本运算、维度变换等。今天,我们就通过一个简单的代码示例,学习这些基本操作。

1. Tensor 与 NumPy 数组的转换

PyTorch 提供了非常方便的接口,用于在 Tensor 和 NumPy 数组之间进行转换。这在实际应用中非常有用,因为 NumPy 是 Python 中处理数组的标准库。

1.1 Tensor 转换为 NumPy 数组

t1 = torch.tensor([1, 2, 3, 4, 5])
n1 = t1.numpy()
print(n1)
  • t1.numpy():将 Tensor 转换为 NumPy 数组。注意,这种转换是浅拷贝,即 NumPy 数组和 Tensor 共享内存。

1.2 NumPy 数组转换为 Tensor

t2 = torch.tensor(n1)
print(t2)
  • torch.tensor(n1):将 NumPy 数组转换为 Tensor。这种转换是深拷贝,即生成一个新的 Tensor,不共享内存。

t3 = torch.from_numpy(n1)
print(t3)
  • torch.from_numpy(n1):将 NumPy 数组转换为 Tensor。这种转换是浅拷贝,即 Tensor 和 NumPy 数组共享内存。

1.3 获取单个元素的值

t4 = torch.tensor([18])
print(t4.item())
  • t4.item():当 Tensor 只有一个元素时,可以使用 item() 获取该元素的值。

2. Tensor 的基本运算

PyTorch 提供了丰富的运算操作,包括加法、减法、乘法和除法。这些运算可以分为两类:生成新 Tensor 的操作和覆盖原 Tensor 的操作。

2.1 生成新 Tensor 的运算

t1 = torch.randint(1, 10, (3, 2))
print(t1.add(1))
  • t1.add(1):对 t1 的每个元素加 1,结果生成一个新的 Tensor。

2.2 覆盖原 Tensor 的运算

print(t1.add_(1))
  • t1.add_(1):对 t1 的每个元素加 1,结果覆盖原 Tensor。

2.3 阿达玛积(逐元素乘法)

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = t1 * t2
print(t3)
  • t1 * t2:逐元素乘法,即对应位置的元素相乘。

2.4 矩阵乘法

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])
t3 = torch.matmul(t1, t2)
print(t3)
  • torch.matmul(t1, t2):矩阵乘法,符合矩阵乘法的规则。

3. Tensor 的形状变换

在深度学习中,经常需要对 Tensor 的形状进行变换,例如在卷积神经网络中调整输入数据的维度。PyTorch 提供了 view()reshape() 方法来实现这一点。

3.1 view() 方法

t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = t1.view(3, 2)
print(t2)
  • t1.view(3, 2):将 Tensor 的形状从 (2, 3) 变为 (3, 2)。注意,view() 要求 Tensor 的内存是连续的。

3.2 reshape() 方法

t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t3 = t1.reshape(3, 2)
print(t3)
  • t1.reshape(3, 2):与 view() 类似,但 reshape() 不要求内存是连续的。

4. 维度变换

在处理多维数据时,经常需要对 Tensor 的维度进行变换,例如在处理图像数据时交换通道维度。

4.1 transpose() 方法

t1 = torch.randint(1, 20, (3, 4, 5))
t2 = torch.transpose(t1, 0, 1)
print(t2)
  • torch.transpose(t1, 0, 1):交换 Tensor 的第 0 维和第 1 维。

4.2 permute() 方法

t3 = t1.permute(1, 0, 2)
print(t3)
  • t1.permute(1, 0, 2):可以同时交换多个维度,非常灵活。

5. 完整代码示例

import torchdef test01():t1 = torch.tensor([1,2,3,4,5])# numpy():将tensor转换为numpy数组,浅拷贝:如果要深拷贝,需要使用copy()# tensor():将numpy数组转换为tensor,深拷贝# from_numpy():将numpy数组转换为tensor,浅拷贝n1 = t1.numpy()print(n1)t2 = torch.tensor(n1)print(t2)t3 = torch.from_numpy(n1)print(t3)# item():当tensor只有一个元素时,使用item()获取该元素的值# t4 = torch.tensor(18)t4 = torch.tensor([18])print(t4)print(t4.item())# t5 = torch.tensor([18],device='cuda')# print(t5.item())def test02():torch.manual_seed(0)# tensor运算# add, sub , mul, div等,计算结果会生成新的tensor# add_, sub_, mul_, div_等,计算结果会覆盖原来的tensort1 = torch.randint(1 , 10, (3, 2))print(t1)print(t1.add(1))print(t1)print(t1.add_(1))print(t1)'''
阿达码积:两个矩阵对应位置相乘,得到一个新的矩阵
Cij = Aij * Bij
运算符号: mul或者*
矩阵运算:(m,p) * (p,n) = (m,n)
'''
def test03():t1 = torch.tensor([1,2],[3,4])t2 = torch.tensor([5,6],[7,8])t3 = t1 * t2print(t3)'''
view():改变tensor的形状,不改变tensor的数据,内存是连续的
reshape():改变tensor的形状,不改变tensor的数据,内存不连续
'''def test04():t1 = torch.tensor([1,2,3],[4,5,6])print(t1.is_contiguous())t2 = t1.view(3, 2)print(t2.is_contiguous())t3 = t1.t()print(t3)print(t3.is_contiguous())t4 = t3.view(2, 3)print(t4.is_contiguous())'''
维度变换
transpose():转置,交换张量的两个维度, 只能交换两个维度
permute(input,dims):维度变换,可以交换多个维度
'''
def test05():t1 = torch.randint(1, 20, (3, 4, 5))print(t1)t2 = torch.transpose(t1, 0, 1)print(t2)print(t2.is_contiguous())t3 = t1.permute(t1, (1, 0, 2))print(t3)print(t3.shape)if __name__ == '__main__':# test01()# test02()# test03()# test04()test05()

6. 总结

通过这篇文章,我们学习了 PyTorch 中 Tensor 的基本操作,包括:

  • 如何在 Tensor 和 NumPy 数组之间进行转换。

  • 如何进行基本的数学运算。

  • 如何改变 Tensor 的形状。

  • 如何对 Tensor 的维度进行变换。

这些操作是深度学习的基础,希望这篇文章能帮助你更好地理解和使用 PyTorch!

http://www.dtcms.com/a/271863.html

相关文章:

  • 【NLP入门系列六】Word2Vec模型简介,与以《人民的名义》小说原文实践
  • IPv4和IPv6双栈配置
  • 【K8S】Kubernetes 使用 Ingress-Nginx 基于 Cookie 实现会话保持的负载均衡
  • HCIA第一次实验报告:静态路由综合实验
  • day11-微服务面试篇
  • C++11 std::is_sorted 和 std::is_sorted_until 原理解析
  • CentOs 7 MySql8.0.23之前的版本主从复制
  • 无缝矩阵与普通矩阵的对比分析
  • 中老年人的陪伴,猫咪与机器人玩具有什么区别?
  • Java 与 MySQL 性能优化:MySQL连接池参数优化与性能提升
  • MySQL(127)如何解决主从同步失败问题?
  • adb 简介与常用命令
  • 分布式ID 与自增区别
  • 虚拟储能与分布式光伏协同优化:新型电力系统的灵活性解决方案
  • 异步I/O库:libuv、libev、libevent与libeio
  • 从0到1:Python与DeepSeek的深度融合指南
  • jupyter 和 kernel 之间的关系
  • .net服务器Kestrel 与反向代理
  • 【TCP/IP】11. IP 组播
  • 【C语言】学习过程教训与经验杂谈:思想准备、知识回顾(六)
  • 【博主亲测可用】PS2025最新版:Adobe Photoshop 2025 v26.8.1 激活版(附安装教程)
  • Apache Dubbo实战:JavaSDK使用
  • 前端面试十一之TS
  • 服务器重装后如何“复活”旧硬盘上的 Anaconda 环境?—— 一次完整的排错与恢复记录
  • 计算机学科专业基础综合(408)四门核心课程的知识点总结
  • 微信小程序101~110
  • 以太网基础⑤UDP 协议原理与 FPGA 实现
  • 2025年7月9日学习笔记——模式识别与机器学习——fisher线性回归、感知器、最小二乘法、最小误差判别算法、罗杰斯特回归算法——线性分类器
  • 【TCP/IP】1. 概述
  • AI赋能生活:深度解析与技术洞察