torch.zeros()用法简介
torch.zeros()
是PyTorch中用于创建全零张量的核心函数,其功能和使用方法如下:
1. 基本语法
torch.zeros(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
参数说明:
*size
:定义张量形状的整数序列(如(3,4)
或3,4
)。dtype
:指定数据类型(如torch.float32
、torch.int64
),默认为torch.float32
。device
:指定存储设备(CPU/GPU)。requires_grad
:是否启用梯度计算(默认为False
)。
2. 典型示例
- 创建3×4的浮点型零矩阵:
x = torch.zeros(3, 4) # 输出为3行4列的全零张量
- 指定数据类型为整数:
y = torch.zeros(2, 3, dtype=torch.int32) # 生成整型零张量
在GPU上创建张量:
-
z = torch.zeros(5, device='cuda') # 生成GPU上的零向量
3. 与torch.empty()
的区别
torch.zeros()
会显式初始化所有元素为0,而torch.empty()
仅分配内存,内容未初始化(可能含随机值)。
4. 应用场景
- 初始化模型参数或缓冲区。
- 作为累加器或占位张量使用。
通过灵活调整参数,可满足不同维度和数据类型的零张量需求。