pytorch小记(二十四):PyTorch 中的 `torch.full` 全面指南
pytorch小记(二十四):PyTorch 中的 `torch.full` 全面指南
- PyTorch 中的 `torch.full` 全面指南
- 一、接口定义
- 二、参数详解
- 三、常见使用场景
- 四、具体示例与输出
- 五、关键字参数设计原理
- 总结
PyTorch 中的 torch.full
全面指南
在深度学习中,有时需要创建一个所有元素都相同的张量,例如作为常数初始值或掩码。PyTorch 提供了 torch.full
接口,功能灵活且参数丰富。下面我们按模块逐一展开。
一、接口定义
torch.full(*sizes,fill_value,out=None,dtype=None,layout=torch.strided,device=None,requires_grad=False) → Tensor
-
返回值:形状由
*sizes
决定,所有位置都填fill_value
的张量。 -
等价签名:
torch.full(size: Tuple[int, ...],fill_value: Number,out: Tensor = None,dtype: torch.dtype = None,layout: torch.layout = torch.strided,device: torch.device = None,requires_grad: bool = False) → Tensor
二、参数详解
参数 | 说明 | 示例 |
---|---|---|
*sizes | 张量形状:多个位置参数(如 2,3,4 ),或一个整型元组 (2,3) | torch.full(2,3,4, fill_value=5) |
fill_value | 要填充的值,常见为标量(int 、float 、bool ) | fill_value=7 |
out | 可选,已有张量写入结果,in-place;keyword-only | torch.full((2,2), 9, out=my_tensor) |
dtype | 输出数据类型;若与 fill_value 类型不匹配,会做转换 | dtype=torch.float64 |
layout | 存储布局,默认为 torch.strided (稠密张量) | layout=torch.strided |
device | 输出张量设备,如 "cpu" 、"cuda:0" | device='cuda:0' |
requires_grad | 是否开启梯度追踪(常用于可学习参数) | requires_grad=True |
注意:
out
、dtype
、layout
、device
、requires_grad
均为 关键字参数,必须以key=value
形式传入,否则会被误认为是形状(sizes)的一部分。
三、常见使用场景
-
创建常数张量
作为偏置、掩码或特殊标志值:bias = torch.full((batch_size, num_features), 0.1) mask = torch.full((H, W), True, dtype=torch.bool)
-
初始化权重
固定常数初始化:self.weight = torch.full((out_channels, in_channels), fill_value=0.01, requires_grad=True)
-
占位符
在复杂流程中预分配内存:out = torch.empty(3, 3) const = torch.full((3,3), 5, out=out) # 直接写入 out
-
与其他 API 配合
# full_like:沿用现有张量形状 ref = torch.zeros(2,4) filled = torch.full_like(ref, 3.14) # 结果 shape=(2,4), dtype=float32
四、具体示例与输出
以下示例固定随机种子(对 full 无影响,仅为演示一致性),并展示每步输出。
import torch
torch.manual_seed(0)# 示例 1:最基础的 (2,3) 常数张量
a = torch.full(2, 7)
# 等价于 torch.full((2,), 7)
print("a:", a)
# 输出:
# a: tensor([7, 7])# 示例 2:二维常数(位置参数 vs tuple)
b = torch.full(2, 3, fill_value=-1) # shape=(2,), fill_value=-1
print("\nb:", b)
# b: tensor([-1, -1])c = torch.full((2,3), 5)
print("\nc:", c)
# c:
# tensor([[5, 5, 5],
# [5, 5, 5]])d = torch.full(2, 3, fill_value=9) # 填充 9
print("\nd:", d)
# d: tensor([9, 9])# 示例 3:指定 dtype 和 device
e = torch.full((2,2), 3.14, dtype=torch.float64)
print("\ne:", e, "\ne.dtype =", e.dtype)
# e:
# tensor([[3.1400, 3.1400],
# [3.1400, 3.1400]], dtype=torch.float64)# (假设有 GPU 环境)
# f = torch.full((1,3), 0, device='cuda:0')
# print("\nf.device =", f.device)# 示例 4:使用 out 关键字
out = torch.empty(2,2)
torch.full((2,2), 42, out=out)
print("\nout(after full):", out)
# out:
# tensor([[42, 42],
# [42, 42]])# 示例 5:requires_grad=True
g = torch.full((3,), 1.0, requires_grad=True)
print("\ng:", g, "; requires_grad =", g.requires_grad)
# g: tensor([1., 1., 1.], requires_grad=True)
五、关键字参数设计原理
在 Python 里,当函数签名中出现 *sizes
时,所有位置参数都会被收集到 sizes
这个元组里,作为张量的形状。如果把 out
、dtype
等也当位置参数传入,就会被当作形状维度导致类型错误或逻辑混乱。因此,PyTorch 将它们设计成 keyword-only arguments,只允许 out=…
、dtype=…
、device=…
等形式出现,保证了接口的清晰与安全。
总结
- 接口灵活:
torch.full(2,3, fill_value=val)
或torch.full((2,3), val)
都可; - 关键字参数:
out
、dtype
、layout
、device
、requires_grad
强制使用key=value
形式; - 常见场景:常数张量、权重初始化、占位符分配等;
- 示例丰富:提供了形状、数据类型、设备、in-place 输出、梯度追踪等全方位示例。