当前位置: 首页 > news >正文

torch.tensor 用法

在 PyTorch 中,torch.tensor 是创建张量的核心函数。以下是详细用法指南:


一、基础用法

import torch

# 从 Python 列表/数组创建
data = [1, 2, 3]
tensor = torch.tensor(data)  # 输出:tensor([1, 2, 3])

# 从 NumPy 数组创建(需先转 NumPy)
import numpy as np
numpy_arr = np.array([4, 5, 6])
tensor_from_np = torch.tensor(numpy_arr)  # tensor([4, 5, 6])

二、关键参数详解

torch.tensor(
    data,                    # 输入数据 (list/np.array 等)
    dtype=None,              # 指定数据类型 (torch.float32, torch.int64 等)
    device=None,             # 设备 ("cpu" 或 "cuda:0")
    requires_grad=False,     # 是否跟踪梯度
    pin_memory=False         # 是否固定内存(加速 GPU 传输)
)

三、常用场景示例

1. 指定数据类型
# 创建浮点型张量
float_tensor = torch.tensor([1, 2, 3], dtype=torch.float32)  # tensor([1., 2., 3.])

# 创建布尔型张量
bool_tensor = torch.tensor([0, 1, 0], dtype=torch.bool)      # tensor([False,  True, False])
2. 设备选择 (CPU/GPU)
# 在 GPU 上创建张量 (需 CUDA 可用)
gpu_tensor = torch.tensor([7, 8, 9], device="cuda:0")

# 将 CPU 张量移动到 GPU
cpu_tensor = torch.tensor([10, 11, 12])
gpu_tensor = cpu_tensor.to("cuda:0")
3. 梯度跟踪
# 创建需要计算梯度的张量
x = torch.tensor([3.0], requires_grad=True)
y = x**2
y.backward()  # 自动计算梯度
print(x.grad) # 输出:tensor([6.])

四、与其他创建方式的区别

方法特点示例
torch.tensor()显式复制数据并推断类型torch.tensor([1, 2, 3])
torch.Tensor()类构造函数,默认 dtype=float32torch.Tensor([1, 2, 3]) → float
torch.from_numpy()与 NumPy 共享内存torch.from_numpy(np_array)

五、注意事项

  1. 数据复制行为torch.tensor() 会复制输入数据,若需要共享内存用 torch.from_numpy()
  2. 隐式类型转换
    # 列表包含混合类型时会向上转型
    mixed_data = [1, 2.0]  # 会被转为 float
    
  3. 性能优化:对于大型数据集,优先使用 torch.utils.data.DataLoader 而非多次调用 torch.tensor

六、高级技巧

# 创建未初始化张量(需后续填充数据)
empty_tensor = torch.tensor([], dtype=torch.float32)
empty_tensor.resize_(3, 4)  # 调整形状为 3x4

# 从生成器创建张量
gen = (i*2 for i in range(5))  # 生成器表达式
tensor_from_gen = torch.tensor(list(gen))  # tensor([0, 2, 4, 6, 8])

如果需要创建特殊张量(全零/单位矩阵等),建议使用:

torch.zeros(2, 3)    # 全零矩阵
torch.ones(3, 2)     # 全一矩阵
torch.eye(4)         # 单位矩阵
torch.randn(5, 5)    # 标准正态分布随机数
http://www.dtcms.com/a/98484.html

相关文章:

  • OpenAI API - 快速入门开发
  • 链表(C++)
  • WPF 自定义行为AssociatedObject详解
  • 全包圆玛奇朵样板间亮相,极简咖啡风引领家装新潮流
  • 程序化广告行业(39/89):广告投放的数据分析与优化秘籍
  • 腾讯系AI应用,可以生视频,3D模型...
  • 北森测评的经验
  • 二层框架组合实验
  • linux压缩指令
  • 数据结构与算法:算法分析
  • 轮询和长轮询
  • html5基于Canvas的动态时钟实现详解
  • 论文内可解释性分析
  • 《ZooKeeper Zab协议深度剖析:构建高可用分布式系统的基石》
  • 0101-vite创建react_ts-环境准备-仿低代码平台项目
  • latex笔记
  • 复现文献中的三维重建图像生成,包括训练、推理和可视化
  • StarRocks 存算分离在京东物流的落地实践
  • GOC L2 第四课模运算和周期
  • 软件工程之需求工程(需求获取、分析、验证)
  • Unity顶点优化:UV Splits与Smoothing Splits消除技巧
  • 基于 Python 深度学习 lstm 算法的电影评论情感分析可视化系统(2.0 系统全新升级,已获高分通过)
  • CUDA专题3:为什么GPU能改变计算?深度剖析架构、CUDA®与可扩展编程
  • 软件信息安全性测试工具有哪些?安全性测试报告如何获取?
  • C++ 类型转换
  • java基础以及内存图
  • presto任务优化参数
  • RAG、大模型与智能体的关系
  • Binlog、Redo log、Undo log的区别
  • 【从零实现Json-Rpc框架】- 项目实现 - Dispatcher模块实现篇