当前位置: 首页 > news >正文

pytorch小记(二十四):PyTorch 中的 `torch.full` 全面指南

pytorch小记(二十四):PyTorch 中的 `torch.full` 全面指南

    • PyTorch 中的 `torch.full` 全面指南
      • 一、接口定义
      • 二、参数详解
      • 三、常见使用场景
      • 四、具体示例与输出
      • 五、关键字参数设计原理
    • 总结


PyTorch 中的 torch.full 全面指南

在深度学习中,有时需要创建一个所有元素都相同的张量,例如作为常数初始值或掩码。PyTorch 提供了 torch.full 接口,功能灵活且参数丰富。下面我们按模块逐一展开。


一、接口定义

torch.full(*sizes,fill_value,out=None,dtype=None,layout=torch.strided,device=None,requires_grad=False) → Tensor
  • 返回值:形状由 *sizes 决定,所有位置都填 fill_value 的张量。

  • 等价签名

    torch.full(size: Tuple[int, ...],fill_value: Number,out: Tensor = None,dtype: torch.dtype = None,layout: torch.layout = torch.strided,device: torch.device = None,requires_grad: bool = False) → Tensor
    

二、参数详解

参数说明示例
*sizes张量形状:多个位置参数(如 2,3,4),或一个整型元组 (2,3)torch.full(2,3,4, fill_value=5)
fill_value要填充的值,常见为标量(intfloatboolfill_value=7
out可选,已有张量写入结果,in-place;keyword-onlytorch.full((2,2), 9, out=my_tensor)
dtype输出数据类型;若与 fill_value 类型不匹配,会做转换dtype=torch.float64
layout存储布局,默认为 torch.strided(稠密张量)layout=torch.strided
device输出张量设备,如 "cpu""cuda:0"device='cuda:0'
requires_grad是否开启梯度追踪(常用于可学习参数)requires_grad=True

注意

  • outdtypelayoutdevicerequires_grad 均为 关键字参数,必须以 key=value 形式传入,否则会被误认为是形状(sizes)的一部分。

三、常见使用场景

  1. 创建常数张量
    作为偏置、掩码或特殊标志值:

    bias = torch.full((batch_size, num_features), 0.1)
    mask = torch.full((H, W), True, dtype=torch.bool)
    
  2. 初始化权重
    固定常数初始化:

    self.weight = torch.full((out_channels, in_channels), fill_value=0.01, requires_grad=True)
    
  3. 占位符
    在复杂流程中预分配内存:

    out = torch.empty(3, 3)
    const = torch.full((3,3), 5, out=out)  # 直接写入 out
    
  4. 与其他 API 配合

    # full_like:沿用现有张量形状
    ref = torch.zeros(2,4)
    filled = torch.full_like(ref, 3.14)  # 结果 shape=(2,4), dtype=float32
    

四、具体示例与输出

以下示例固定随机种子(对 full 无影响,仅为演示一致性),并展示每步输出。

import torch
torch.manual_seed(0)# 示例 1:最基础的 (2,3) 常数张量
a = torch.full(2, 7)  
# 等价于 torch.full((2,), 7)
print("a:", a)  
# 输出:
# a: tensor([7, 7])# 示例 2:二维常数(位置参数 vs tuple)
b = torch.full(2, 3, fill_value=-1)  # shape=(2,), fill_value=-1
print("\nb:", b)
# b: tensor([-1, -1])c = torch.full((2,3), 5)
print("\nc:", c)
# c:
# tensor([[5, 5, 5],
#         [5, 5, 5]])d = torch.full(2, 3, fill_value=9)  # 填充 9
print("\nd:", d)
# d: tensor([9, 9])# 示例 3:指定 dtype 和 device
e = torch.full((2,2), 3.14, dtype=torch.float64)
print("\ne:", e, "\ne.dtype =", e.dtype)
# e:
# tensor([[3.1400, 3.1400],
#         [3.1400, 3.1400]], dtype=torch.float64)# (假设有 GPU 环境)
# f = torch.full((1,3), 0, device='cuda:0')
# print("\nf.device =", f.device)# 示例 4:使用 out 关键字
out = torch.empty(2,2)
torch.full((2,2), 42, out=out)
print("\nout(after full):", out)
# out:
# tensor([[42, 42],
#         [42, 42]])# 示例 5:requires_grad=True
g = torch.full((3,), 1.0, requires_grad=True)
print("\ng:", g, "; requires_grad =", g.requires_grad)
# g: tensor([1., 1., 1.], requires_grad=True)

五、关键字参数设计原理

在 Python 里,当函数签名中出现 *sizes 时,所有位置参数都会被收集到 sizes 这个元组里,作为张量的形状。如果把 outdtype 等也当位置参数传入,就会被当作形状维度导致类型错误或逻辑混乱。因此,PyTorch 将它们设计成 keyword-only arguments,只允许 out=…dtype=…device=… 等形式出现,保证了接口的清晰与安全。


总结

  • 接口灵活torch.full(2,3, fill_value=val)torch.full((2,3), val) 都可;
  • 关键字参数outdtypelayoutdevicerequires_grad 强制使用 key=value 形式;
  • 常见场景:常数张量、权重初始化、占位符分配等;
  • 示例丰富:提供了形状、数据类型、设备、in-place 输出、梯度追踪等全方位示例。

相关文章:

  • 每日算法刷题Day11 5.20:leetcode不定长滑动窗口求最长/最大6道题,结束不定长滑动窗口求最长/最大,用时1h20min
  • python-leetcode 69.最小栈
  • YOLO中model.predict方法返回内容Results详解
  • WSL虚拟机整体迁移教程(如何将WSL从C盘迁移到其他盘)
  • 物流项目第四期(运费模板列表实现)
  • 战略游戏--树形dp
  • 《初入苍穹:大一新手的编程成长之旅》
  • ACS ANM突破:微波一步法合成多孔吸波材料——焦耳加热技术如何赋能材料创新?
  • JAVASE查漏补缺
  • 无人机精准降落与避障模块技术解析
  • Java 01简单集合
  • HarmonyOS5云服务技术分享--ArkTS开发函数
  • 【深入理解索引扩展—1】提升智能检索系统召回质量的3大利器
  • 软考软件测评师——系统安全设计(防火墙技术)
  • SpringBoot(三)--- 数据库基础
  • vitepress项目添加百度统计或者google统计方式
  • 星闪开发之buttondemo烧录后无效果思路
  • 初识Linux 进程:进程创建、终止与进程地址空间
  • 软考软件评测师——基于风险的测试技术
  • protobuf原理和使用
  • 阳朔兴坪镇:在建乾元桥“垮塌”是谣言,系降雨导致工程挡土墙倾斜
  • 美国前驻华大使携美大学生拜访中联部、外交部
  • 每一笔都是对的!再读周碧初画作有感
  • 人民日报评论员观察:稳企业,全力以赴纾困解难
  • 受贿2.61亿余元,陕西省政协原主席韩勇一审被判死缓
  • 鸿蒙电脑正式发布,国产操作系统在个人电脑领域实现重要突破