人工智能-python-深度学习-tensor基操
文章目录
- Tensor常见操作
- 1. 获取元素值
- 1.1 获取单个元素的值
- 1.2 获取指定维度的元素
- 2. 元素值运算
- 2.1 加法和减法
- 2.2 乘法和除法
- 2.3 指数和对数运算
- 3. 阿达玛积
- 4. Tensor相乘
- 4.1 矩阵乘法
- 4.2 点积
- 5. 形状操作
- 5.1 查看形状
- 5.2 重塑形状
- 5.2.1内存连续性
- 5.3 转置
- 5.4 transpose
- 5.5 permute
- 5.6 升维和降维
- 5.5.1 squeeze降维
- 5.5.2 unsqueeze升维
- 6. 广播机制
- 6.1 广播机制规则
- 6.2 广播案例
- ✨ Tensor操作核心总结
- 元素操作
- 维度掌控
- 智能广播
- 💎 实践建议
Tensor常见操作
1. 获取元素值
1.1 获取单个元素的值
可以通过索引获取Tensor中的单个元素。例如,获取Tensor中第一个元素的值:
import torch
tensor = torch.tensor([[1, 2], [3, 4]])
element = tensor[0, 0] # 获取第一个元素
print(element) # 输出:1
1.2 获取指定维度的元素
可以通过torch.index_select
或者使用切片获取指定维度的元素:
row = torch.index_select(tensor, 0, torch.tensor([0])) # 获取第一行
print(row) # 输出:tensor([[1, 2]])
2. 元素值运算
2.1 加法和减法
Tensor之间可以直接进行加法和减法运算,支持广播机制:
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
result_add = tensor1 + tensor2
result_sub = tensor1 - tensor2
print(result_add) # 输出:tensor([5, 7, 9])
print(result_sub) # 输出:tensor([-3, -3, -3])
2.2 乘法和除法
除了加法和减法,还可以进行乘法和除法:
result_mul = tensor1 * tensor2 # 元素级乘法
result_div = tensor1 / tensor2 # 元素级除法
print(result_mul) # 输出:tensor([4, 10, 18])
print(result_div) # 输出:tensor([0.2500, 0.4000, 0.5000])
2.3 指数和对数运算
可以进行指数运算和对数运算:
tensor = torch.tensor([1.0, 2.0, 3.0])
result_exp = torch.exp(tensor) # 求指数
result_log = torch.log(tensor) # 求自然对数
print(result_exp) # 输出:tensor([ 2.7183, 7.3891, 20.0855])
print(result_log) # 输出:tensor([0.0000, 0.6931, 1.0986])
3. 阿达玛积
阿达玛积(Hadamard Product)是指两个相同形状的Tensor之间的逐元素乘积,也叫做元素级乘法:
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
hadamard_product = tensor1 * tensor2
print(hadamard_product)
# 输出:
# tensor([[ 5, 12],
# [21, 32]])
4. Tensor相乘
4.1 矩阵乘法
矩阵乘法可以使用torch.matmul
或@
运算符:
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
result_matmul = torch.matmul(tensor1, tensor2)
print(result_matmul)
# 输出:
# tensor([[19, 22],
# [43, 50]])
4.2 点积
点积(内积)是通过torch.dot
或torch.mm
进行:
tensor1 = torch.tensor([1, 2])
tensor2 = torch.tensor([3, 4])
result_dot = torch.dot(tensor1, tensor2)
print(result_dot) # 输出:11
5. 形状操作
5.1 查看形状
通过.shape
属性可以查看Tensor的形状:
tensor = torch.randn(3, 4)
print(tensor.shape) # 输出:torch.Size([3, 4])
5.2 重塑形状
使用torch.view
或torch.reshape
改变Tensor的形状:
tensor = torch.randn(6)
reshaped_tensor = tensor.view(2, 3)
print(reshaped_tensor)
5.2.1内存连续性
张量的内存布局决定了其元素在内存中的存储顺序。对于多维张量,内存布局通常按照最后一个维度优先的顺序存储,即先存列,后存行。列如对于一个张量A,其形状为(m,n),其内存布局是先存储第0行的所有列元素,然后是第1行的所有列元素,以此内推。
5.3 转置
使用.t()
方法可以进行转置操作:
tensor = torch.randn(3, 4)
transposed_tensor = tensor.t()
print(transposed_tensor)
5.4 transpose
transpose 用于交换张量的两个维度,注意,是2个维度,它返回的是原张量的视图。
torch.transpose(input, dim0, dim1)
参数
- input: 输入的张量。
- dim0: 要交换的第一个维度。
- dim1: 要交换的第二个维度。
import torchdef test003():data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)# 使用transpose进行形状变换transpose_data = torch.transpose(data,0,1)# transpose_data = data.transpose(0, 1)print(transpose_data, transpose_data.shape)if __name__ == "__main__":test003()
5.5 permute
它通过重新排列张量的维度来返回一个新的张量,不改变张量的数据,只改变维度的顺序。
torch.permute(input, dims)
参数
- input: 输入的张量。
- dims: 一个整数元组,表示新的维度顺序。
import torchdef test004():data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)# 使用permute进行多维度形状变换permute_data = data.permute(1, 2, 0)print(permute_data, permute_data.shape)if __name__ == "__main__":test004()
和 transpose 一样,permute 返回新张量,原张量不变。
重排后的张量可能是非连续的(is_contiguous() 返回 False),必要时需调用 .contiguous():
y = x.permute(2, 1, 0).contiguous()
维度顺序必须合法:dims 中的维度顺序必须包含所有原始维度,且不能重复或遗漏。例如,对于一个形状为 (2, 3, 4) 的张量,dims=(2, 0, 1) 是合法的,但 dims=(0, 1) 或 dims=(0, 1, 2, 3) 是非法的。
与 transpose() 的对比
特性 | permute() | transpose() |
---|---|---|
功能 | 可以同时调整多个维度的顺序 | 只能交换两个维度的顺序 |
灵活性 | 更灵活 | 较简单 |
使用场景 | 适用于多维张量 | 适用于简单的维度交换 |
5.6 升维和降维
在后续的网络训练学习中,升维和降维是常用操作,需要掌握。
-
unsqueeze:用于在指定位置插入一个大小为 1 的新维度。
-
squeeze:用于移除所有大小为 1 的维度,或者移除指定维度的大小为 1 的维度。
5.5.1 squeeze降维
torch.squeeze(input, dim=None)
参数
- input: 输入的张量。
- dim (可选): 指定要移除的维度。如果指定了 dim,则只移除该维度(前提是该维度大小为 1);如果不指定,则移除所有大小为 1 的维度。
import torchdef test006():data = torch.randint(0, 10, (1, 4, 5, 1))print(data, data.shape)# 进行降维操作data1 = data.squeeze(0).squeeze(-1)print(data.shape)# 移除所有大小为 1 的维度data2 = torch.squeeze(data)# 尝试移除第 1 维(大小为 3,不为 1,不会报错,张量保持不变。)data3 = torch.squeeze(data, dim=1)print("尝试移除第 1 维后的形状:", data3.shape)if __name__ == "__main__":test006()
5.5.2 unsqueeze升维
torch.unsqueeze(input, dim)
参数
- input: 输入的张量。
- dim: 指定要增加维度的位置(从 0 开始索引)。
import torchdef test007():data = torch.randint(0, 10, (32, 32, 3))print(data.shape)# 升维操作data = data.unsqueeze(0)print(data.shape)if __name__ == "__main__":test007()
6. 广播机制
6.1 广播机制规则
广播机制是指当操作两个不同形状的Tensor时,PyTorch会自动调整Tensor的形状,使它们兼容进行逐元素运算。广播机制遵循以下规则:
- 从后向前对齐:两个Tensor的形状从最右边开始对齐,逐维度检查。
- 如果两个维度不同,则大小为1的维度会被扩展:如果一个Tensor在某维度上没有对应的元素(即该维度的大小为1),它会将该维度的元素沿该维度进行广播。
6.2 广播案例
例如,两个Tensor的形状分别是 (3, 1)
和 (3, 4)
,广播机制会将形状为 (3, 1)
的Tensor沿第二维度复制,直到它和 (3, 4)
形状兼容:
tensor1 = torch.randn(3, 1)
tensor2 = torch.randn(3, 4)
result = tensor1 + tensor2
print(result.shape) # 输出:torch.Size([3, 4])
✨ Tensor操作核心总结
PyTorch中的Tensor操作是深度学习开发的基石,其设计兼顾高效性与灵活性。以下是对关键操作的精华提炼:
🔍 核心要点
元素操作
索引切片:支持多维索引(tensor[i,j])和函数索引(index_select)
数学运算:原生支持±*/和广播机制,专有函数实现exp/log等高级运算
乘积类型:区分元素级(阿达玛积 *)与矩阵级(@/matmul)运算
维度掌控
透视结构:.shape查看维度,view/reshape改变形状(注意内存连续性)
维度变换:
transpose:交换两个维度
permute:自由重排所有维度顺序(需配合.contiguous()保证内存连续)
维度压缩:
squeeze:消除大小为1的维度
unsqueeze:在指定位置插入新维度
智能广播
自动扩展:从右向左对齐维度,将大小为1的维度复制扩展(如(3,1)+(3,4)→(3,4))
规则优先:维数不等时前置补1,维度大小需为1或相等
💎 实践建议
形状操作首选:permute处理复杂维度变换,view用于连续内存重塑
广播验证:使用tensor.expand_as()显式检查广播兼容性
维度控制:unsqueeze/squeeze特别适用于网络输入/输出适配
掌握这些操作,Tensor将成为你手中灵活的数据魔方!🚀