Pytorch中torch.where()函数详解和实战示例
torch.where()
是 PyTorch 中非常常用的一个函数,功能类似于 NumPy 的 where
,用于条件筛选或三元选择操作。在深度学习训练、掩码操作、损失函数处理等场景中非常常见。
一、基本语法
torch.where(condition, x, y)
- condition:一个布尔张量(
torch.bool
类型),和x
、y
的 shape 必须可广播。 - x:满足条件时取的值。
- y:不满足条件时取的值。
二、功能说明
-
如果只传入一个参数
condition
,torch.where(condition)
将返回 非零元素的坐标(类似nonzero()
)。 -
如果传入三个参数
condition, x, y
,则类似于三元表达式:result = x if condition else y
三、示例详解
示例 1:三元选择(条件替换)
import torcha = torch.tensor([1, 2, 3, 4])
b = torch.tensor([10, 20, 30, 40])
cond = torch.tensor([True, False, True, False])out = torch.where(cond, a, b)
print(out) # tensor([ 1, 20, 3, 40])
解释:满足 cond=True
的地方取 a
,否则取 b
。
示例 2:只有 condition 参数,返回索引
x = torch.tensor([[0, 1], [2, 0]])
pos = torch.where(x > 0)print(pos)
# 输出: (tensor([0, 1]), tensor([1, 0]))
# 表示非零位置是 [0,1] 和 [1,0]
如果你希望将这些坐标转换成可访问的形式:
coordinates = list(zip(pos[0].tolist(), pos[1].tolist()))
# [(0, 1), (1, 0)]
示例 3:广播行为支持
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[100]])
mask = torch.tensor([[True, False], [False, True]])out = torch.where(mask, a, b)
print(out)
# tensor([[ 1, 100],
# [100, 4]])
示例 4:用于神经网络中的掩码操作(常见)
logits = torch.tensor([0.2, -0.5, 0.7, -1.0])
mask = logits > 0
output = torch.where(mask, logits, torch.zeros_like(logits))
print(output) # tensor([0.2000, 0.0000, 0.7000, 0.0000])
示例 5:替代负无穷值(比如处理 log(0) 的场景)
eps = 1e-6
x = torch.tensor([0.5, 0.0, 1.0])
safe_x = torch.where(x > 0, x, torch.tensor(eps))
logx = torch.log(safe_x)
print(logx)
四、常见用途总结
场景 | 用法示例 |
---|---|
条件替换 | torch.where(mask, a, b) |
去除负值/NaN/0 | torch.where(x > 0, x, eps) |
多分类掩码处理 | torch.where(onehot_mask, pred, 0) |
找到满足条件的索引 | idxs = torch.where(x > 0.5) |
广播与标量搭配 | torch.where(mask, x, torch.tensor(0)) |
六、实战示例
下面是 torch.where()
在 分类问题、损失函数 和 图像掩码 场景下的实战用法示例和解释,非常适合深度学习任务中使用。
1. 分类问题中使用 torch.where
(二分类/多分类掩码)
场景:筛选预测为正类的样本做统计
import torch# 假设 logits 是模型输出,labels 是真实标签
logits = torch.tensor([0.9, 0.3, 0.8, 0.1])
labels = torch.tensor([1, 0, 1, 0])# 二分类掩码
positive_mask = labels == 1# 只保留正样本对应的预测概率
positive_preds = torch.where(positive_mask, logits, torch.tensor(0.0))
print(positive_preds)
# tensor([0.9000, 0.0000, 0.8000, 0.0000])
2. 自定义损失函数中的 torch.where
示例:加权 BCE Loss,给予正类更高的权重
def weighted_bce_loss(pred, target, pos_weight=2.0):bce_loss = -(target * torch.log(pred + 1e-6) + (1 - target) * torch.log(1 - pred + 1e-6))weights = torch.where(target == 1, torch.tensor(pos_weight), torch.tensor(1.0))weighted_loss = weights * bce_lossreturn weighted_loss.mean()# 示例输入
pred = torch.tensor([0.9, 0.2, 0.7, 0.1]) # 模型预测
target = torch.tensor([1.0, 0.0, 1.0, 0.0]) # 标签loss = weighted_bce_loss(pred, target)
print(loss)
使用 torch.where
为正样本动态加权。
3. 图像处理中的 torch.where
(掩码操作)
示例:用掩码提取前景像素或设定背景为固定值
import torch# 假设灰度图像:H×W
image = torch.tensor([[100, 120, 130],[90, 0, 50],[255, 200, 180]
], dtype=torch.float32)# 二值掩码(比如图像分割输出)
mask = image > 100# 将背景像素设为0(即屏蔽背景)
masked_image = torch.where(mask, image, torch.tensor(0.0))
print(masked_image)
输出:
tensor([[ 0., 120., 130.],[ 0., 0., 0.],[255., 200., 180.]])
避免除以 0:
denominator = torch.where(denom != 0, denom, torch.tensor(1e-6))
替代 NaN 或 Inf:
x = torch.tensor([1.0, float('nan'), 2.0, float('inf')])
cleaned = torch.where(torch.isfinite(x), x, torch.tensor(0.0))
4.语义分割 中的应用示例:
- 忽略某些像素(如 ignore index);
- 可视化前景掩码;
- 动态计算某些类的准确率、IoU;
- 针对背景与前景的不同加权 loss;
- 将预测 mask 显示成彩色图像。
1. 忽略标签为 255
的像素(ignore_index
)
import torchpred = torch.tensor([[1, 2, 0],[0, 1, 2]
])
target = torch.tensor([[1, 255, 0],[0, 1, 255]
])# 忽略标签为 255 的像素
mask = target != 255
correct = torch.where(mask, pred == target, torch.tensor(False))# 精度统计(不含 ignore)
acc = correct.sum().float() / mask.sum()
print("Accuracy (excluding ignore_index):", acc.item())
** 2. 前景/背景 mask 处理(比如用于 loss)**
# 假设标签中 0 为背景,1 为前景
label = torch.tensor([[0, 0, 1],[1, 1, 0]
])is_foreground = torch.where(label == 1, torch.tensor(1.0), torch.tensor(0.0))
print(is_foreground)
# tensor([[0., 0., 1.],
# [1., 1., 0.]])
你可以用这个 mask 计算前景区域的 loss 或平均值。
** 3. 可视化分割结果(转彩色)**
import torch# 假设预测结果为标签图(整数类 id)
label_map = torch.tensor([[0, 1, 2],[1, 2, 0]
], dtype=torch.int64)# 假设有 3 个类,对应 RGB 颜色如下
colors = torch.tensor([[0, 0, 0], # class 0 -> 黑色[255, 0, 0], # class 1 -> 红色[0, 255, 0], # class 2 -> 绿色
], dtype=torch.uint8)# 转换为彩色图像
color_image = colors[label_map]
print(color_image.shape) # torch.Size([2, 3, 3]),对应 H x W x C
如果你要保存图像(使用 OpenCV):
import cv2
cv2.imwrite("seg_output.png", color_image.numpy())
** 4. 类别不平衡:前景加权 loss**
# logits 为 (N, C, H, W),labels 为 (N, H, W)
def weighted_cross_entropy(logits, labels, fg_weight=5.0, ignore_index=255):N, C, H, W = logits.shapelogits_flat = logits.permute(0, 2, 3, 1).reshape(-1, C)labels_flat = labels.reshape(-1)# 计算 lossloss = torch.nn.functional.cross_entropy(logits_flat, labels_flat,reduction='none', ignore_index=ignore_index)# 生成权重:前景类别加权weights = torch.ones_like(labels_flat, dtype=torch.float32)weights = torch.where(labels_flat != 0, torch.tensor(fg_weight), torch.tensor(1.0))weights = torch.where(labels_flat == ignore_index, torch.tensor(0.0), weights)loss = loss * weightsreturn loss.sum() / (weights.sum() + 1e-6)
小结:语义分割中常见的 torch.where()
用法
任务 | 示例 |
---|---|
忽略 ignore_index 像素 | mask = label != 255 |
筛选前景像素 | fg_mask = torch.where(label == 1, 1.0, 0.0) |
自定义 loss 权重 | weights = torch.where(label == 1, 5.0, 1.0) |
彩色可视化分割图 | colors[label_map] |
分割输出中统计精度 | correct = torch.where(mask, pred == target, 0) |
五、补充说明
condition
必须是bool
类型张量。x
和y
的形状需 可广播。torch.where
是支持 GPU 的(放在cuda()
后依然生效)。