当前位置: 首页 > news >正文

pytorch中不同的mask方法:masked_fill, masked_select, masked_scatter

在 PyTorch 中,masked_fillmasked_selectmasked_scatter 是三种常用的掩码(mask)操作方法,它们通过布尔类型的掩码张量(mask)对原始张量进行条件筛选或修改。以下是它们的详细解释和对比:


1. masked_fill

作用:将原始张量中 maskTrue 的位置用指定值填充,其余位置保持不变。

参数
mask(BoolTensor):与原始张量形状相同的布尔掩码。
value(标量):用于填充的值。

特点
原地操作:会直接修改原始张量(除非使用 masked_fill_ 的 in-place 版本)。
保持形状:输出张量形状与输入张量一致。

示例

import torch

x = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[False, True], [True, False]])

y = x.masked_fill(mask, -1)
print(y)
# 输出:
# tensor([[ 1, -1],
#         [-1,  4]])

典型应用
• 在 Transformer 的注意力机制中,用 -inf 填充 padding 或未来的位置,使 softmax 后概率为 0。
• 数据清洗时屏蔽无效值(如 NaN)。


2. masked_select

作用:从原始张量中提取 maskTrue 的元素,返回一维张量。

参数
mask(BoolTensor):与原始张量形状相同的布尔掩码。

特点
返回一维张量:输出会丢失原始张量的维度信息。
非原地操作:生成新的张量。

示例

x = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[False, True], [True, False]])

y = x.masked_select(mask)
print(y)  # tensor([2, 3])

典型应用
• 提取满足条件的元素(如分类任务中筛选正样本)。
• 统计掩码区域的值(如计算非零元素均值)。


3. masked_scatter

作用:将另一个张量(source)中的值按顺序填充到原始张量中 maskTrue 的位置。

参数
mask(BoolTensor):与原始张量形状相同的布尔掩码。
source(Tensor):提供填充值的源张量。

特点
按顺序填充source 中的值按行优先顺序填充到 maskTrue 的位置。
source 的长度必须 ≥ maskTrue 的数量。

示例

x = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[False, True], [True, False]])
source = torch.tensor([10, 20])

y = x.masked_scatter(mask, source)
print(y)
# 输出:
# tensor([[ 1, 10],
#         [20,  4]])

典型应用
• 动态替换张量中的部分值(如用随机噪声替换特定区域)。
• 批量更新参数时选择性地替换某些位置。


对比总结

方法输入张量形状输出形状是否修改原张量核心功能
masked_fill保留原形状与原张量相同是(可选)用标量填充掩码区域
masked_select保留原形状一维张量提取掩码区域的元素
masked_scatter保留原形状与原张量相同是(可选)用另一张量填充掩码区域

关键注意事项

  1. 掩码形状匹配mask 必须与原始张量形状严格一致,否则会报错。
  2. 数据类型mask 必须是布尔类型(BoolTensor)。
  3. 梯度传播:所有操作均支持自动求导,但填充的值(如 masked_fill 中的 value)需是浮点数以避免类型错误。
  4. 性能:对大规模张量频繁使用这些操作可能影响性能,建议优先使用向量化操作。

选择方法指南

• 需要保持形状并填充标量masked_fill
• 需要提取元素并丢弃形状masked_select
• 需要按顺序替换为另一张量的值masked_scatter

通过合理使用这些方法,可以高效实现条件筛选、数据清洗、动态修改等任务。

相关文章:

  • MySQL 当中的锁
  • 网络运维学习笔记(DeepSeek优化版)026 OSPF vlink(Virtual Link,虚链路)配置详解
  • 深度学习 Deep Learning 第13章 线性因子模型
  • PyQt6实例_批量下载pdf工具_批量pdf网址获取
  • 3.30学习总结 Java包装类+高精度算法+查找算法
  • 开发环境解决Secure Cookie导致302重定向
  • VUE实现框架搭建(纯手写)
  • 【Python爬虫神器】requests库常用操作详解 ,附实战案例
  • RocketMQ - 从消息可靠传输谈高可用
  • Cookie可以存哪些指?
  • 一区严选!挑战5天一篇脂质体组学 DAY1-5
  • Flink介绍——实时计算核心论文之S4论文详解
  • RS232转Profinet网关扫码器在西门子1200plc快速配置
  • MySQL中的CREATE TABLE LIKE和CREATE TABLE SELECT
  • 关于为什么使用redis锁,不使用zk锁的原因
  • LeetCode知识点整理
  • golang 的time包的常用方法
  • 通过 Adobe Acrobat DC 实现 Word 到 PDF 的不可逆转换
  • HTML5和CSS3的一些特性
  • fastdds:传输层端口号计算规则
  • 网站建设哪家/付费推广
  • 怎么给网站做外链邵连虎/网络宣传
  • 动态网站开发过程ppt/企业推广方式
  • 安徽省建设厅网站人员管理/郑州网络推广服务
  • php网站开发教程下载/黄页网络的推广网站有哪些类型
  • 做旅游网站的要求/聊城今日头条最新