当前位置: 首页 > news >正文

人工智能-python-深度学习-tensor基操

文章目录

    • Tensor常见操作
      • 1. 获取元素值
        • 1.1 获取单个元素的值
        • 1.2 获取指定维度的元素
      • 2. 元素值运算
        • 2.1 加法和减法
        • 2.2 乘法和除法
        • 2.3 指数和对数运算
      • 3. 阿达玛积
      • 4. Tensor相乘
        • 4.1 矩阵乘法
        • 4.2 点积
      • 5. 形状操作
        • 5.1 查看形状
        • 5.2 重塑形状
            • 5.2.1内存连续性
        • 5.3 转置
        • 5.4 transpose
        • 5.5 permute
        • 5.6 升维和降维
        • 5.5.1 squeeze降维
        • 5.5.2 unsqueeze升维
      • 6. 广播机制
        • 6.1 广播机制规则
        • 6.2 广播案例
    • ✨ Tensor操作核心总结
      • 元素操作
      • 维度掌控
      • 智能广播
    • 💎 实践建议

Tensor常见操作

1. 获取元素值

1.1 获取单个元素的值

可以通过索引获取Tensor中的单个元素。例如,获取Tensor中第一个元素的值:

import torch
tensor = torch.tensor([[1, 2], [3, 4]])
element = tensor[0, 0]  # 获取第一个元素
print(element)  # 输出:1
1.2 获取指定维度的元素

可以通过torch.index_select或者使用切片获取指定维度的元素:

row = torch.index_select(tensor, 0, torch.tensor([0]))  # 获取第一行
print(row)  # 输出:tensor([[1, 2]])

2. 元素值运算

2.1 加法和减法

Tensor之间可以直接进行加法和减法运算,支持广播机制:

tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
result_add = tensor1 + tensor2
result_sub = tensor1 - tensor2
print(result_add)  # 输出:tensor([5, 7, 9])
print(result_sub)  # 输出:tensor([-3, -3, -3])
2.2 乘法和除法

除了加法和减法,还可以进行乘法和除法:

result_mul = tensor1 * tensor2  # 元素级乘法
result_div = tensor1 / tensor2  # 元素级除法
print(result_mul)  # 输出:tensor([4, 10, 18])
print(result_div)  # 输出:tensor([0.2500, 0.4000, 0.5000])
2.3 指数和对数运算

可以进行指数运算和对数运算:

tensor = torch.tensor([1.0, 2.0, 3.0])
result_exp = torch.exp(tensor)  # 求指数
result_log = torch.log(tensor)  # 求自然对数
print(result_exp)  # 输出:tensor([ 2.7183,  7.3891, 20.0855])
print(result_log)  # 输出:tensor([0.0000, 0.6931, 1.0986])

3. 阿达玛积

阿达玛积(Hadamard Product)是指两个相同形状的Tensor之间的逐元素乘积,也叫做元素级乘法:

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
hadamard_product = tensor1 * tensor2
print(hadamard_product)
# 输出:
# tensor([[ 5, 12],
#         [21, 32]])

4. Tensor相乘

4.1 矩阵乘法

矩阵乘法可以使用torch.matmul@运算符:

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
result_matmul = torch.matmul(tensor1, tensor2)
print(result_matmul)
# 输出:
# tensor([[19, 22],
#         [43, 50]])
4.2 点积

点积(内积)是通过torch.dottorch.mm进行:

tensor1 = torch.tensor([1, 2])
tensor2 = torch.tensor([3, 4])
result_dot = torch.dot(tensor1, tensor2)
print(result_dot)  # 输出:11

5. 形状操作

5.1 查看形状

通过.shape属性可以查看Tensor的形状:

tensor = torch.randn(3, 4)
print(tensor.shape)  # 输出:torch.Size([3, 4])
5.2 重塑形状

使用torch.viewtorch.reshape改变Tensor的形状:

tensor = torch.randn(6)
reshaped_tensor = tensor.view(2, 3)
print(reshaped_tensor)
5.2.1内存连续性

张量的内存布局决定了其元素在内存中的存储顺序。对于多维张量,内存布局通常按照最后一个维度优先的顺序存储,即先存列,后存行。列如对于一个张量A,其形状为(m,n),其内存布局是先存储第0行的所有列元素,然后是第1行的所有列元素,以此内推。

5.3 转置

使用.t()方法可以进行转置操作:

tensor = torch.randn(3, 4)
transposed_tensor = tensor.t()
print(transposed_tensor)
5.4 transpose

transpose 用于交换张量的两个维度,注意,是2个维度,它返回的是原张量的视图。

torch.transpose(input, dim0, dim1)

参数

  • input: 输入的张量。
  • dim0: 要交换的第一个维度。
  • dim1: 要交换的第二个维度。
import torchdef test003():data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)# 使用transpose进行形状变换transpose_data = torch.transpose(data,0,1)# transpose_data = data.transpose(0, 1)print(transpose_data, transpose_data.shape)if __name__ == "__main__":test003()
5.5 permute

它通过重新排列张量的维度来返回一个新的张量,不改变张量的数据,只改变维度的顺序。

torch.permute(input, dims)

参数

  • input: 输入的张量。
  • dims: 一个整数元组,表示新的维度顺序。
import torchdef test004():data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)# 使用permute进行多维度形状变换permute_data = data.permute(1, 2, 0)print(permute_data, permute_data.shape)if __name__ == "__main__":test004()

和 transpose 一样,permute 返回新张量,原张量不变。

重排后的张量可能是非连续的(is_contiguous() 返回 False),必要时需调用 .contiguous():

y = x.permute(2, 1, 0).contiguous()

维度顺序必须合法:dims 中的维度顺序必须包含所有原始维度,且不能重复或遗漏。例如,对于一个形状为 (2, 3, 4) 的张量,dims=(2, 0, 1) 是合法的,但 dims=(0, 1) 或 dims=(0, 1, 2, 3) 是非法的。

与 transpose() 的对比

特性permute()transpose()
功能可以同时调整多个维度的顺序只能交换两个维度的顺序
灵活性更灵活较简单
使用场景适用于多维张量适用于简单的维度交换
5.6 升维和降维

在后续的网络训练学习中,升维和降维是常用操作,需要掌握。

  • unsqueeze:用于在指定位置插入一个大小为 1 的新维度。

  • squeeze:用于移除所有大小为 1 的维度,或者移除指定维度的大小为 1 的维度。

5.5.1 squeeze降维
torch.squeeze(input, dim=None)

参数

  • input: 输入的张量。
  • dim (可选): 指定要移除的维度。如果指定了 dim,则只移除该维度(前提是该维度大小为 1);如果不指定,则移除所有大小为 1 的维度。
import torchdef test006():data = torch.randint(0, 10, (1, 4, 5, 1))print(data, data.shape)# 进行降维操作data1 = data.squeeze(0).squeeze(-1)print(data.shape)# 移除所有大小为 1 的维度data2 = torch.squeeze(data)# 尝试移除第 1 维(大小为 3,不为 1,不会报错,张量保持不变。)data3 = torch.squeeze(data, dim=1)print("尝试移除第 1 维后的形状:", data3.shape)if __name__ == "__main__":test006()
5.5.2 unsqueeze升维
torch.unsqueeze(input, dim)

参数

  • input: 输入的张量。
  • dim: 指定要增加维度的位置(从 0 开始索引)。
import torchdef test007():data = torch.randint(0, 10, (32, 32, 3))print(data.shape)# 升维操作data = data.unsqueeze(0)print(data.shape)if __name__ == "__main__":test007()

6. 广播机制

6.1 广播机制规则

广播机制是指当操作两个不同形状的Tensor时,PyTorch会自动调整Tensor的形状,使它们兼容进行逐元素运算。广播机制遵循以下规则:

  1. 从后向前对齐:两个Tensor的形状从最右边开始对齐,逐维度检查。
  2. 如果两个维度不同,则大小为1的维度会被扩展:如果一个Tensor在某维度上没有对应的元素(即该维度的大小为1),它会将该维度的元素沿该维度进行广播。
6.2 广播案例

例如,两个Tensor的形状分别是 (3, 1)(3, 4),广播机制会将形状为 (3, 1) 的Tensor沿第二维度复制,直到它和 (3, 4) 形状兼容:

tensor1 = torch.randn(3, 1)
tensor2 = torch.randn(3, 4)
result = tensor1 + tensor2
print(result.shape)  # 输出:torch.Size([3, 4])

✨ Tensor操作核心总结

PyTorch中的Tensor操作是深度学习开发的基石,其设计兼顾高效性与灵活性。以下是对关键操作的精华提炼:
🔍 核心要点

元素操作

索引切片:支持多维索引(tensor[i,j])和函数索引(index_select)
数学运算:原生支持±*/和广播机制,专有函数实现exp/log等高级运算
乘积类型:区分元素级(阿达玛积 *)与矩阵级(@/matmul)运算

维度掌控

透视结构:.shape查看维度,view/reshape改变形状(注意内存连续性)
维度变换
transpose:交换两个维度
permute:自由重排所有维度顺序(需配合.contiguous()保证内存连续)
维度压缩
squeeze:消除大小为1的维度
unsqueeze:在指定位置插入新维度

智能广播

自动扩展:从右向左对齐维度,将大小为1的维度复制扩展(如(3,1)+(3,4)→(3,4))
规则优先:维数不等时前置补1,维度大小需为1或相等

💎 实践建议

形状操作首选:permute处理复杂维度变换,view用于连续内存重塑
广播验证:使用tensor.expand_as()显式检查广播兼容性
维度控制
:unsqueeze/squeeze特别适用于网络输入/输出适配
掌握这些操作,Tensor将成为你手中灵活的数据魔方!🚀

http://www.dtcms.com/a/348712.html

相关文章:

  • 数学建模(摸索中……)
  • CUDA安装,pytorch库安装
  • 如何实现模版引擎
  • Shell 学习笔记 - Shell 三剑客篇
  • unity热更新总结
  • 【如何使用Redis实现分布式锁详解讲解】
  • [快乐数](哈希表)
  • 解决编译osgEarth中winsocket2.h找不到头文件问题
  • 基于Spark的热门旅游景点数据分析系统的设计-django+spider
  • Spring Boot测试陷阱:失败测试为何“传染”其他用例?
  • 【追涨抄底关注】副图指标 紫色主力线上行表明资金介入明显 配合价格突破时可靠性更高
  • deepseek连接solidworks设计一台非标设备 (part1)
  • 阿里云ECS服务器搭建ThinkPHP环境
  • 互联网大厂AI/大模型应用开发工程师面试剧本与解析
  • 阿里云云数据库RDS PostgreSQL管控功能使用
  • 基于SpringBoot的婚纱影楼服务预约平台【2026最新】
  • Spring AI 学习笔记(2)
  • GitHub 热榜项目 - 日榜(2025-08-24)
  • Wireshark USRP联合波形捕获(下)
  • windows上如何实现把指定网段的流量转发到指定的端口,有哪些界面化的软件用来配置完成,类似于 Linux中的iptables规则实现
  • 6.1Element UI布局容器
  • 【Luogu】P2602 [ZJOI2010] 数字计数 (数位DP)
  • 基于大模型的对话式推荐系统技术架构设计-- 大数据平台层
  • 07 - spring security基于数据库的账号密码
  • window11无法连接Fortinet SSL VPN
  • Elasticsearch如何确保数据一致性?
  • 『深度编码』操作系统-进程之间的通信方法
  • 记录一下TVT投稿过程
  • 阿里云大模型应用实战:从技术落地到业务提效
  • Dify 从入门到精通(第 53/100 篇):Dify 的分布式架构(进阶篇)