当前位置: 首页 > wzjs >正文

中山建网站最好的公司中信建设有限责任公司项目人员配置

中山建网站最好的公司,中信建设有限责任公司项目人员配置,彩票网站维护需要几天,多用户网上商城torch.where() 是 PyTorch 中非常常用的一个函数,功能类似于 NumPy 的 where,用于条件筛选或三元选择操作。在深度学习训练、掩码操作、损失函数处理等场景中非常常见。 一、基本语法 torch.where(condition, x, y)condition:一个布尔张量&a…

torch.where() 是 PyTorch 中非常常用的一个函数,功能类似于 NumPy 的 where,用于条件筛选或三元选择操作。在深度学习训练、掩码操作、损失函数处理等场景中非常常见。


一、基本语法

torch.where(condition, x, y)
  • condition:一个布尔张量(torch.bool 类型),和 xy 的 shape 必须可广播。
  • x:满足条件时取的值。
  • y:不满足条件时取的值。

二、功能说明

  • 如果只传入一个参数 conditiontorch.where(condition) 将返回 非零元素的坐标(类似 nonzero())。

  • 如果传入三个参数 condition, x, y,则类似于三元表达式:

    result = x if condition else y
    

三、示例详解

示例 1:三元选择(条件替换)

import torcha = torch.tensor([1, 2, 3, 4])
b = torch.tensor([10, 20, 30, 40])
cond = torch.tensor([True, False, True, False])out = torch.where(cond, a, b)
print(out)  # tensor([ 1, 20,  3, 40])

解释:满足 cond=True 的地方取 a,否则取 b


示例 2:只有 condition 参数,返回索引

x = torch.tensor([[0, 1], [2, 0]])
pos = torch.where(x > 0)print(pos)
# 输出: (tensor([0, 1]), tensor([1, 0]))
# 表示非零位置是 [0,1] 和 [1,0]

如果你希望将这些坐标转换成可访问的形式:

coordinates = list(zip(pos[0].tolist(), pos[1].tolist()))
# [(0, 1), (1, 0)]

示例 3:广播行为支持

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[100]])
mask = torch.tensor([[True, False], [False, True]])out = torch.where(mask, a, b)
print(out)
# tensor([[  1, 100],
#         [100,   4]])

示例 4:用于神经网络中的掩码操作(常见)

logits = torch.tensor([0.2, -0.5, 0.7, -1.0])
mask = logits > 0
output = torch.where(mask, logits, torch.zeros_like(logits))
print(output)  # tensor([0.2000, 0.0000, 0.7000, 0.0000])

示例 5:替代负无穷值(比如处理 log(0) 的场景)

eps = 1e-6
x = torch.tensor([0.5, 0.0, 1.0])
safe_x = torch.where(x > 0, x, torch.tensor(eps))
logx = torch.log(safe_x)
print(logx)

四、常见用途总结

场景用法示例
条件替换torch.where(mask, a, b)
去除负值/NaN/0torch.where(x > 0, x, eps)
多分类掩码处理torch.where(onehot_mask, pred, 0)
找到满足条件的索引idxs = torch.where(x > 0.5)
广播与标量搭配torch.where(mask, x, torch.tensor(0))

六、实战示例

下面是 torch.where()分类问题损失函数图像掩码 场景下的实战用法示例和解释,非常适合深度学习任务中使用。


1. 分类问题中使用 torch.where(二分类/多分类掩码)

场景:筛选预测为正类的样本做统计
import torch# 假设 logits 是模型输出,labels 是真实标签
logits = torch.tensor([0.9, 0.3, 0.8, 0.1])
labels = torch.tensor([1, 0, 1, 0])# 二分类掩码
positive_mask = labels == 1# 只保留正样本对应的预测概率
positive_preds = torch.where(positive_mask, logits, torch.tensor(0.0))
print(positive_preds)
# tensor([0.9000, 0.0000, 0.8000, 0.0000])

2. 自定义损失函数中的 torch.where

示例:加权 BCE Loss,给予正类更高的权重
def weighted_bce_loss(pred, target, pos_weight=2.0):bce_loss = -(target * torch.log(pred + 1e-6) + (1 - target) * torch.log(1 - pred + 1e-6))weights = torch.where(target == 1, torch.tensor(pos_weight), torch.tensor(1.0))weighted_loss = weights * bce_lossreturn weighted_loss.mean()# 示例输入
pred = torch.tensor([0.9, 0.2, 0.7, 0.1])  # 模型预测
target = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 标签loss = weighted_bce_loss(pred, target)
print(loss)

使用 torch.where 为正样本动态加权。


3. 图像处理中的 torch.where(掩码操作)

示例:用掩码提取前景像素或设定背景为固定值
import torch# 假设灰度图像:H×W
image = torch.tensor([[100, 120, 130],[90, 0, 50],[255, 200, 180]
], dtype=torch.float32)# 二值掩码(比如图像分割输出)
mask = image > 100# 将背景像素设为0(即屏蔽背景)
masked_image = torch.where(mask, image, torch.tensor(0.0))
print(masked_image)

输出:

tensor([[  0., 120., 130.],[  0.,   0.,   0.],[255., 200., 180.]])

避免除以 0:
denominator = torch.where(denom != 0, denom, torch.tensor(1e-6))
替代 NaN 或 Inf:
x = torch.tensor([1.0, float('nan'), 2.0, float('inf')])
cleaned = torch.where(torch.isfinite(x), x, torch.tensor(0.0))

4.语义分割 中的应用示例:

  • 忽略某些像素(如 ignore index);
  • 可视化前景掩码;
  • 动态计算某些类的准确率、IoU;
  • 针对背景与前景的不同加权 loss;
  • 将预测 mask 显示成彩色图像。

1. 忽略标签为 255 的像素(ignore_index

import torchpred = torch.tensor([[1, 2, 0],[0, 1, 2]
])
target = torch.tensor([[1, 255, 0],[0, 1, 255]
])# 忽略标签为 255 的像素
mask = target != 255
correct = torch.where(mask, pred == target, torch.tensor(False))# 精度统计(不含 ignore)
acc = correct.sum().float() / mask.sum()
print("Accuracy (excluding ignore_index):", acc.item())

** 2. 前景/背景 mask 处理(比如用于 loss)**

# 假设标签中 0 为背景,1 为前景
label = torch.tensor([[0, 0, 1],[1, 1, 0]
])is_foreground = torch.where(label == 1, torch.tensor(1.0), torch.tensor(0.0))
print(is_foreground)
# tensor([[0., 0., 1.],
#         [1., 1., 0.]])

你可以用这个 mask 计算前景区域的 loss 或平均值。


** 3. 可视化分割结果(转彩色)**

import torch# 假设预测结果为标签图(整数类 id)
label_map = torch.tensor([[0, 1, 2],[1, 2, 0]
], dtype=torch.int64)# 假设有 3 个类,对应 RGB 颜色如下
colors = torch.tensor([[0, 0, 0],       # class 0 -> 黑色[255, 0, 0],     # class 1 -> 红色[0, 255, 0],     # class 2 -> 绿色
], dtype=torch.uint8)# 转换为彩色图像
color_image = colors[label_map]
print(color_image.shape)  # torch.Size([2, 3, 3]),对应 H x W x C

如果你要保存图像(使用 OpenCV):

import cv2
cv2.imwrite("seg_output.png", color_image.numpy())

** 4. 类别不平衡:前景加权 loss**

# logits 为 (N, C, H, W),labels 为 (N, H, W)
def weighted_cross_entropy(logits, labels, fg_weight=5.0, ignore_index=255):N, C, H, W = logits.shapelogits_flat = logits.permute(0, 2, 3, 1).reshape(-1, C)labels_flat = labels.reshape(-1)# 计算 lossloss = torch.nn.functional.cross_entropy(logits_flat, labels_flat,reduction='none', ignore_index=ignore_index)# 生成权重:前景类别加权weights = torch.ones_like(labels_flat, dtype=torch.float32)weights = torch.where(labels_flat != 0, torch.tensor(fg_weight), torch.tensor(1.0))weights = torch.where(labels_flat == ignore_index, torch.tensor(0.0), weights)loss = loss * weightsreturn loss.sum() / (weights.sum() + 1e-6)

小结:语义分割中常见的 torch.where() 用法

任务示例
忽略 ignore_index 像素mask = label != 255
筛选前景像素fg_mask = torch.where(label == 1, 1.0, 0.0)
自定义 loss 权重weights = torch.where(label == 1, 5.0, 1.0)
彩色可视化分割图colors[label_map]
分割输出中统计精度correct = torch.where(mask, pred == target, 0)

五、补充说明

  • condition 必须是 bool 类型张量。
  • xy 的形状需 可广播
  • torch.where 是支持 GPU 的(放在 cuda() 后依然生效)。

http://www.dtcms.com/wzjs/548131.html

相关文章:

  • 哈尔滨php网站开发公司wordpress个人简历模板
  • 如何申请一个免费的网站空间秦皇岛外贸网站建设
  • 在线网站制作系统安徽营销型网站建设
  • 短视频素材哪里找徐州低价seo
  • 策划公司介绍百度网站怎么优化排名
  • 网站建设需要哪些知识wordpress查询文章分类列表
  • 恩施建站建设雅奇小蘑菇做网站好不好用
  • 公司网站 制作chinacd wordpress第三性
  • 手机号网站源码外贸网站制作公司
  • php做视频网站微信小程序案例源码
  • 德阳建设厅官方网站网站开发如何引用函数
  • 展示型网站建设的标准如何做网络推广人员
  • 河南郑州网站关键词排名助手华秋商城
  • 湖南建设厅网站即墨区城乡建设局网站
  • 花钱做网站不给源码wordpress克隆他人的网站
  • wiki网站开发工具阿里巴巴对外做网站吗
  • 酒店网站建设报价详情厦门建设网官方网站
  • 株洲建设公司网站h5网页模板下载
  • 中国产品网免费网站影响网站收录的因素
  • 北京网站建设公司新闻云虚拟机
  • 网站建设全部流程图江门鹤山最新消息新闻
  • 做网站合同封面做装修公司网站费用
  • 做滤芯的网站重庆网站建设 菠拿拿
  • 全国各地网站开发外包设计说明英文翻译
  • 关于网站设计的书籍保定网站制作哪家好建设
  • 怎么做网站后门网站轮播图怎么保存
  • 如何制作淘宝客网站wordpress免费问答模板
  • 苏州网站建设排行网站建设站长相关专业
  • 网站常用字体免费云服务器试用7天
  • 台州市临海建设局网站安徽工业大学两学一做网站