当前位置: 首页 > news >正文

Pytorch中torch.where()函数详解和实战示例

torch.where() 是 PyTorch 中非常常用的一个函数,功能类似于 NumPy 的 where,用于条件筛选或三元选择操作。在深度学习训练、掩码操作、损失函数处理等场景中非常常见。


一、基本语法

torch.where(condition, x, y)
  • condition:一个布尔张量(torch.bool 类型),和 xy 的 shape 必须可广播。
  • x:满足条件时取的值。
  • y:不满足条件时取的值。

二、功能说明

  • 如果只传入一个参数 conditiontorch.where(condition) 将返回 非零元素的坐标(类似 nonzero())。

  • 如果传入三个参数 condition, x, y,则类似于三元表达式:

    result = x if condition else y
    

三、示例详解

示例 1:三元选择(条件替换)

import torcha = torch.tensor([1, 2, 3, 4])
b = torch.tensor([10, 20, 30, 40])
cond = torch.tensor([True, False, True, False])out = torch.where(cond, a, b)
print(out)  # tensor([ 1, 20,  3, 40])

解释:满足 cond=True 的地方取 a,否则取 b


示例 2:只有 condition 参数,返回索引

x = torch.tensor([[0, 1], [2, 0]])
pos = torch.where(x > 0)print(pos)
# 输出: (tensor([0, 1]), tensor([1, 0]))
# 表示非零位置是 [0,1] 和 [1,0]

如果你希望将这些坐标转换成可访问的形式:

coordinates = list(zip(pos[0].tolist(), pos[1].tolist()))
# [(0, 1), (1, 0)]

示例 3:广播行为支持

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[100]])
mask = torch.tensor([[True, False], [False, True]])out = torch.where(mask, a, b)
print(out)
# tensor([[  1, 100],
#         [100,   4]])

示例 4:用于神经网络中的掩码操作(常见)

logits = torch.tensor([0.2, -0.5, 0.7, -1.0])
mask = logits > 0
output = torch.where(mask, logits, torch.zeros_like(logits))
print(output)  # tensor([0.2000, 0.0000, 0.7000, 0.0000])

示例 5:替代负无穷值(比如处理 log(0) 的场景)

eps = 1e-6
x = torch.tensor([0.5, 0.0, 1.0])
safe_x = torch.where(x > 0, x, torch.tensor(eps))
logx = torch.log(safe_x)
print(logx)

四、常见用途总结

场景用法示例
条件替换torch.where(mask, a, b)
去除负值/NaN/0torch.where(x > 0, x, eps)
多分类掩码处理torch.where(onehot_mask, pred, 0)
找到满足条件的索引idxs = torch.where(x > 0.5)
广播与标量搭配torch.where(mask, x, torch.tensor(0))

六、实战示例

下面是 torch.where()分类问题损失函数图像掩码 场景下的实战用法示例和解释,非常适合深度学习任务中使用。


1. 分类问题中使用 torch.where(二分类/多分类掩码)

场景:筛选预测为正类的样本做统计
import torch# 假设 logits 是模型输出,labels 是真实标签
logits = torch.tensor([0.9, 0.3, 0.8, 0.1])
labels = torch.tensor([1, 0, 1, 0])# 二分类掩码
positive_mask = labels == 1# 只保留正样本对应的预测概率
positive_preds = torch.where(positive_mask, logits, torch.tensor(0.0))
print(positive_preds)
# tensor([0.9000, 0.0000, 0.8000, 0.0000])

2. 自定义损失函数中的 torch.where

示例:加权 BCE Loss,给予正类更高的权重
def weighted_bce_loss(pred, target, pos_weight=2.0):bce_loss = -(target * torch.log(pred + 1e-6) + (1 - target) * torch.log(1 - pred + 1e-6))weights = torch.where(target == 1, torch.tensor(pos_weight), torch.tensor(1.0))weighted_loss = weights * bce_lossreturn weighted_loss.mean()# 示例输入
pred = torch.tensor([0.9, 0.2, 0.7, 0.1])  # 模型预测
target = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 标签loss = weighted_bce_loss(pred, target)
print(loss)

使用 torch.where 为正样本动态加权。


3. 图像处理中的 torch.where(掩码操作)

示例:用掩码提取前景像素或设定背景为固定值
import torch# 假设灰度图像:H×W
image = torch.tensor([[100, 120, 130],[90, 0, 50],[255, 200, 180]
], dtype=torch.float32)# 二值掩码(比如图像分割输出)
mask = image > 100# 将背景像素设为0(即屏蔽背景)
masked_image = torch.where(mask, image, torch.tensor(0.0))
print(masked_image)

输出:

tensor([[  0., 120., 130.],[  0.,   0.,   0.],[255., 200., 180.]])

避免除以 0:
denominator = torch.where(denom != 0, denom, torch.tensor(1e-6))
替代 NaN 或 Inf:
x = torch.tensor([1.0, float('nan'), 2.0, float('inf')])
cleaned = torch.where(torch.isfinite(x), x, torch.tensor(0.0))

4.语义分割 中的应用示例:

  • 忽略某些像素(如 ignore index);
  • 可视化前景掩码;
  • 动态计算某些类的准确率、IoU;
  • 针对背景与前景的不同加权 loss;
  • 将预测 mask 显示成彩色图像。

1. 忽略标签为 255 的像素(ignore_index

import torchpred = torch.tensor([[1, 2, 0],[0, 1, 2]
])
target = torch.tensor([[1, 255, 0],[0, 1, 255]
])# 忽略标签为 255 的像素
mask = target != 255
correct = torch.where(mask, pred == target, torch.tensor(False))# 精度统计(不含 ignore)
acc = correct.sum().float() / mask.sum()
print("Accuracy (excluding ignore_index):", acc.item())

** 2. 前景/背景 mask 处理(比如用于 loss)**

# 假设标签中 0 为背景,1 为前景
label = torch.tensor([[0, 0, 1],[1, 1, 0]
])is_foreground = torch.where(label == 1, torch.tensor(1.0), torch.tensor(0.0))
print(is_foreground)
# tensor([[0., 0., 1.],
#         [1., 1., 0.]])

你可以用这个 mask 计算前景区域的 loss 或平均值。


** 3. 可视化分割结果(转彩色)**

import torch# 假设预测结果为标签图(整数类 id)
label_map = torch.tensor([[0, 1, 2],[1, 2, 0]
], dtype=torch.int64)# 假设有 3 个类,对应 RGB 颜色如下
colors = torch.tensor([[0, 0, 0],       # class 0 -> 黑色[255, 0, 0],     # class 1 -> 红色[0, 255, 0],     # class 2 -> 绿色
], dtype=torch.uint8)# 转换为彩色图像
color_image = colors[label_map]
print(color_image.shape)  # torch.Size([2, 3, 3]),对应 H x W x C

如果你要保存图像(使用 OpenCV):

import cv2
cv2.imwrite("seg_output.png", color_image.numpy())

** 4. 类别不平衡:前景加权 loss**

# logits 为 (N, C, H, W),labels 为 (N, H, W)
def weighted_cross_entropy(logits, labels, fg_weight=5.0, ignore_index=255):N, C, H, W = logits.shapelogits_flat = logits.permute(0, 2, 3, 1).reshape(-1, C)labels_flat = labels.reshape(-1)# 计算 lossloss = torch.nn.functional.cross_entropy(logits_flat, labels_flat,reduction='none', ignore_index=ignore_index)# 生成权重:前景类别加权weights = torch.ones_like(labels_flat, dtype=torch.float32)weights = torch.where(labels_flat != 0, torch.tensor(fg_weight), torch.tensor(1.0))weights = torch.where(labels_flat == ignore_index, torch.tensor(0.0), weights)loss = loss * weightsreturn loss.sum() / (weights.sum() + 1e-6)

小结:语义分割中常见的 torch.where() 用法

任务示例
忽略 ignore_index 像素mask = label != 255
筛选前景像素fg_mask = torch.where(label == 1, 1.0, 0.0)
自定义 loss 权重weights = torch.where(label == 1, 5.0, 1.0)
彩色可视化分割图colors[label_map]
分割输出中统计精度correct = torch.where(mask, pred == target, 0)

五、补充说明

  • condition 必须是 bool 类型张量。
  • xy 的形状需 可广播
  • torch.where 是支持 GPU 的(放在 cuda() 后依然生效)。

http://www.dtcms.com/a/265175.html

相关文章:

  • AIGC自我介绍笔记
  • Redis基础(1):NoSQL认识
  • sqlmap学习笔记ing(3.[MoeCTF 2022]Sqlmap_boy,cookie的作用)
  • UniApp完美对接RuoYi框架开发企业级应用
  • 基于 ethers.js 的区块链事件处理与钱包管理
  • UI前端大数据可视化实战技巧:动态数据加载与刷新策略
  • 【AI智能体】Coze 搭建个人旅游规划助手实战详解
  • 【Rancher Server + Kubernets】- Nginx-ingress日志持久化至宿主机
  • Pillow 安装使用教程
  • AI之Tool:Glean的简介、安装和使用方法、案例应用之详细攻略
  • 监测检测一体化项目实践——整体功能规划
  • uniapp实现图片预览,懒加载,下拉刷新等
  • 基于 TOF 图像高频信息恢复 RGB 图像的原理、应用与实现
  • 重要版本:无需关闭UAC通知的TOS无线USB助手1.0.4,它来了(2025-07-02)
  • 操作系统考试大题-处理机调度算法-详解-1
  • 2025-暑期训练二
  • 通过具有一致性嵌入的大语言模型实现端到端乳腺癌放射治疗计划制定|文献速递-最新论文分享
  • AlpineLinux安装部署zabbix
  • 进程概念以及相关函数
  • 进程(起个开头,复习的一天)day26
  • 轻松上手:使用Nginx实现高效负载均衡
  • 应用密码学纲要
  • 怎样理解:source ~/.bash_profile
  • 决策树(Decision tree)算法详解(ID3、C4.5、CART)
  • 在线学堂-3.媒资管理模块(二)
  • 软件反调试(2)- 基于窗口列表的检测
  • 外侧三兵策略
  • 睿抗省赛2023
  • 【通识】机器学习相关
  • YOLOv11剪枝与量化(二)通道剪枝技术原理