当前位置: 首页 > news >正文

PyTorch 张量(Tensor)详解:从基础到实战

1. 引言

在深度学习和科学计算领域,张量(Tensor) 是最基础的数据结构。PyTorch 作为当前最流行的深度学习框架之一,其核心计算单元就是张量。与 NumPy 的 ndarray 类似,PyTorch 张量支持高效的数值计算,但额外提供了 GPU 加速 和 自动微分(Autograd) 功能,使其成为构建和训练神经网络的理想选择。

本文将全面介绍 PyTorch 张量的核心概念、基本操作、高级特性及实际应用,帮助读者掌握张量的使用方法,并理解其在深度学习中的作用。

2. 什么是张量?

张量是多维数组的泛化,可以表示不同维度的数据:

  • 0D 张量(标量):单个数值,如 torch.tensor(5)

  • 1D 张量(向量):一维数组,如 torch.tensor([1, 2, 3])

  • 2D 张量(矩阵):二维数组,如 torch.tensor([[1, 2], [3, 4]])

  • 3D+ 张量(高阶张量):如 RGB 图像(3D)、视频数据(4D)等

PyTorch 张量的主要特点:

  1. 支持 GPU 加速:可无缝切换 CPU/GPU 计算。

  2. 自动微分:用于神经网络的反向传播。

  3. 动态计算图:更灵活的模型构建方式(与 TensorFlow 1.x 的静态计算图不同)。

3. 张量的创建与初始化

3.1 从 Python 列表或 NumPy 数组创建

import torch
import numpy as np# 从列表创建
t1 = torch.tensor([1, 2, 3])  # 1D 张量
t2 = torch.tensor([[1, 2], [3, 4]])  # 2D 张量# 从 NumPy 数组创建
arr = np.array([1, 2, 3])
t3 = torch.from_numpy(arr)  # 共享内存(修改一个会影响另一个)

3.2 特殊初始化方法

# 全零张量
zeros = torch.zeros(2, 3)  # 2x3 的零矩阵# 全一张量
ones = torch.ones(2)  # [1., 1.]# 随机张量
rand_uniform = torch.rand(2, 2)  # 0~1 均匀分布
rand_normal = torch.randn(2, 2)  # 标准正态分布# 类似现有张量的形状
x = torch.tensor([[1, 2], [3, 4]])
x_like = torch.rand_like(x)  # 形状与 x 相同,值随机

4. 张量的基本属性

每个 PyTorch 张量都有以下关键属性:

x = torch.rand(2, 3, dtype=torch.float32, device="cuda")print(x.shape)      # 形状: torch.Size([2, 3])
print(x.dtype)      # 数据类型: torch.float32
print(x.device)     # 存储设备: cpu / cuda
print(x.requires_grad)  # 是否启用梯度计算(用于 Autograd)

4.1 数据类型(dtype)

PyTorch 支持多种数据类型:

  • torch.float32(默认)

  • torch.int64

  • torch.bool(布尔张量)

可以通过 .to() 方法转换:

x = torch.tensor([1, 2], dtype=torch.float32)
y = x.to(torch.int64)  # 转换为整型

4.2 设备(CPU/GPU)

PyTorch 允许张量在 CPU 或 GPU 上运行:

if torch.cuda.is_available():device = torch.device("cuda")x = x.to(device)  # 移动到 GPUy = y.to("cuda")  # 简写方式

5. 张量的基本运算

5.1 算术运算

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])# 加法
c = a + b  # 等价于 torch.add(a, b)# 乘法(逐元素)
d = a * b  # [3, 8]# 矩阵乘法
mat_a = torch.rand(2, 3)
mat_b = torch.rand(3, 2)
mat_c = torch.matmul(mat_a, mat_b)  # 或 mat_a @ mat_b

5.2 形状操作

x = torch.rand(4, 4)# 改变形状(类似 NumPy 的 reshape)
y = x.view(16)  # 展平为一维张量
z = x.view(2, 8)  # 调整为 2x8# 转置
x_t = x.permute(1, 0)  # 行列交换# 扩维 / 压缩
x_expanded = x.unsqueeze(0)  # 增加一个维度(1x4x4)
x_squeezed = x_expanded.squeeze()  # 去除大小为1的维度

5.3 索引与切片

x = torch.rand(3, 4)# 取第一行
row = x[0, :]# 取前两列
cols = x[:, :2]# 布尔索引
mask = x > 0.5
filtered = x[mask]  # 返回满足条件的元素

6. 自动微分(Autograd)

PyTorch 的 autograd 模块支持自动计算梯度,适用于反向传播:

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x  # 计算图构建
y.backward()  # 反向传播
print(x.grad)  # dy/dx = 2x + 3 → 7.0

6.1 禁用梯度计算

with torch.no_grad():y = x * 2  # 不记录梯度

7. 张量与 NumPy 的互操作

PyTorch 张量可以无缝转换为 NumPy 数组:

# Tensor → NumPy
a = torch.rand(2, 2)
b = a.numpy()  # 共享内存(修改一个会影响另一个)# NumPy → Tensor
c = np.array([1, 2])
d = torch.from_numpy(c)  # 共享内存

8. 实际应用示例

8.1 线性回归(手动实现)

# 数据准备
X = torch.rand(100, 1)
y = 3 * X + 2 + 0.1 * torch.randn(100, 1)# 初始化参数
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 训练
lr = 0.01
for epoch in range(100):y_pred = w * X + bloss = ((y_pred - y) ** 2).mean()loss.backward()  # 计算梯度with torch.no_grad():w -= lr * w.gradb -= lr * b.gradw.grad.zero_()b.grad.zero_()print(f"w: {w.item()}, b: {b.item()}")

8.2 张量在 CNN 中的应用

import torch.nn as nn# 模拟输入(batch_size=1, channels=3, height=32, width=32)
input_tensor = torch.rand(1, 3, 32, 32)# 定义一个简单的 CNN
model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3),nn.ReLU(),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(16 * 15 * 15, 10)  # 假设输出 10 类
)output = model(input_tensor)
print(output.shape)  # torch.Size([1, 10])

9. 总结

PyTorch 张量是深度学习的基础数据结构,支持:

  • 多维数组计算(类似 NumPy)

  • GPU 加速(大幅提升计算速度)

  • 自动微分(简化神经网络训练)

  • 动态计算图(灵活调试模型)

掌握张量的基本操作是学习 PyTorch 的关键步骤。建议读者通过官方文档和实际项目加深理解,逐步掌握张量的高级用法(如广播机制、高级索引等)。

http://www.dtcms.com/a/359738.html

相关文章:

  • 1.9 初始Memory Profiler Package
  • 面试 八股文 经典题目 - HTTPS部分(一)
  • Qt组件布局的经验
  • 深度学习数据加载实战:从 PyTorch Dataset 到食品图像分类全流程解析
  • 实现需求精准预测、运输路径优化及库存高效管理的智慧物流开源了
  • 利用 Java 爬虫获取淘宝拍立淘 API 接口数据的实战指南
  • 图片格式转换v2_tif转png tif转jpg png转tif
  • mysql深度分页
  • JVM的四大组件是什么?
  • 【贪心算法】day5
  • 暄桐林曦老师关于静坐常见问题的QA
  • 矩阵待办ios app Tech Support
  • 好用的电脑软件、工具推荐和记录
  • Labview使用modbus或S7与PLC通信
  • 微服务01
  • Java与分布式系统的集成与实现:从基础到应用!
  • 从 JDK 8 到 JDK 17
  • 【Python语法基础学习笔记】函数定义与使用
  • Spring Security 6.x 功能概览与代码示例
  • 【四位加密】2022-10-25
  • 电感值过大过小会影响什么
  • 基于VS平台的QT开发全流程指南
  • 杂谈:大模型与垂直场景融合的技术趋势
  • 线程池八股文
  • 语义分析:从读懂到理解的深度跨越
  • Python基础:函数
  • Visual Studio Code中launch.json的解析笔记
  • 【Canvas与旗帜】哥伦比亚旗圆饼
  • 【芯片测试篇】:LIN总线
  • 人工智能-python-深度学习-