
NLF loss study notes

Contents

Projecting 3D to 2D for further loss computation

reconstruct_absolute

1. Function overview

2. Parameter details

3. Comparison of the two reconstruction modes


Projecting 3D to 2D for further loss computation

compute_loss_with_3d_gt

    def compute_loss_with_3d_gt(self, inps, preds):
        losses = EasyDict()

        if inps.point_validity is None:
            inps.point_validity = tf.ones_like(preds.coords3d_abs[..., 0], dtype=tf.bool)

        diff = inps.coords3d_true - preds.coords3d_abs

        # CENTER-RELATIVE 3D LOSS
        # We now compute a "center-relative" error, which is either root-relative
        # (if there is a root joint present), or mean-relative (i.e. the mean is subtracted).
        meanrel_diff = tfu3d.center_relative_pose(
            diff, joint_validity_mask=inps.point_validity, center_is_mean=True)

        # root_index is a [batch_size] int tensor that holds which one is the root
        # diff is [batch_size, joint_count, 3]
        # we now need to select the root joint from each batch element
        if inps.root_index.shape.ndims == 0:
            inps.root_index = tf.fill(tf.shape(diff)[:1], inps.root_index)

        # diff has shape N,P,3 for batch, point, coord
        # and root_index has shape N
        # and we want to select the root joint from each batch element
        sanitized_root_index = tf.where(
            inps.root_index == -1, tf.zeros_like(inps.root_index), inps.root_index)
        root_diff = tf.expand_dims(tf.gather_nd(
            diff, tf.stack([tf.range(tf.shape(diff)[0]), sanitized_root_index], axis=1)),
            axis=1)
        rootrel_diff = diff - root_diff

        # Some elements of the batch do not have a root joint, which is marked
        # with -1 as the root_index.
        center_relative_diff = tf.where(
            inps.root_index[:, tf.newaxis, tf.newaxis] == -1, meanrel_diff, rootrel_diff)
        losses.loss3d = tfu.reduce_mean_masked(
            self.my_norm(center_relative_diff, preds.uncert), inps.point_validity)

        # ABSOLUTE 3D LOSS (camera-space)
        absdiff = tf.abs(diff)

        # Since the depth error will naturally scale linearly with distance, we scale
        # the z-error down to the level that we would get if the person was 5 m away.
        scale_factor_for_far = tf.minimum(
            np.float32(1), 5 / tf.abs(inps.coords3d_true[..., 2:]))
        absdiff_scaled = tf.concat(
            [absdiff[..., :2], absdiff[..., 2:] * scale_factor_for_far], axis=-1)

        # There are numerical difficulties for points too close to the camera, so we only
        # apply the absolute loss for points at least 30 cm away from the camera.
        is_far_enough = inps.coords3d_true[..., 2] > 0.3
        is_valid_and_far_enough = tf.logical_and(inps.point_validity, is_far_enough)

        # To make things simpler, we estimate one uncertainty and automatically
        # apply a factor of 4 to get the uncertainty for the absolute prediction;
        # this is just an approximation, but it works well enough.
        # The uncertainty does not need to be perfect, it merely serves as a
        # self-gating mechanism, and the actual value of it is less important
        # compared to the relative values between different points.
        losses.loss3d_abs = tfu.reduce_mean_masked(
            self.my_norm(absdiff_scaled, preds.uncert * 4.), is_valid_and_far_enough)

        # 2D PROJECTION LOSS (pixel-space)
        # We also compute a loss in pixel space to encourage good image-alignment
        # in the model.
        coords2d_pred = tfu3d.project_pose(preds.coords3d_abs, inps.intrinsics)
        coords2d_true = tfu3d.project_pose(inps.coords3d_true, inps.intrinsics)

        # Balance factor which considers the 2D image size equivalent to the 3D box
        # size of the volumetric heatmap. This is just a factor to get a rough
        # ballpark. It could be tuned further.
        scale_2d = 1 / FLAGS.proc_side * FLAGS.box_size_m

        # We only use the 2D loss for points that are in front of the camera and aren't
        # very far out of the field of view. It's not a problem that the point is outside
        # to a certain extent, because this will provide training signal to move points
        # which are outside the image, toward the image border. Therefore those point
        # predictions will gather up near the border and we can mask them out when doing
        # the absolute reconstruction.
        is_in_fov_pred = tf.logical_and(
            tfu3d.is_within_fov(coords2d_pred, border_factor=-20 * (FLAGS.proc_side / 256)),
            preds.coords3d_abs[..., 2] > 0.001)
        is_near_fov_true = tf.logical_and(
            tfu3d.is_within_fov(coords2d_true, border_factor=-20 * (FLAGS.proc_side / 256)),
            inps.coords3d_true[..., 2] > 0.001)
        losses.loss2d = tfu.reduce_mean_masked(
            self.my_norm((coords2d_true - coords2d_pred) * scale_2d, preds.uncert),
            tf.logical_and(
                is_valid_and_far_enough,
                tf.logical_and(is_in_fov_pred, is_near_fov_true)))

        return losses, tf.add_n([
            losses.loss3d,
            losses.loss2d,
            FLAGS.absloss_factor * self.stop_grad_before_step(
                losses.loss3d_abs, FLAGS.absloss_start_step)])
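The center-relative selection in the code above (use the root-relative error when a root joint exists, fall back to the mean-relative error when root_index is -1) can be mirrored in a few lines of NumPy. This is a minimal sketch with made-up toy values, not the library code:

```python
import numpy as np

# diff: per-point 3D errors, shape [batch=2, points=2, 3] (toy values)
diff = np.array([[[1., 1., 1.],
                  [3., 3., 3.]],
                 [[2., 0., 0.],
                  [4., 0., 0.]]])
root_index = np.array([0, -1])  # batch 1 has no root joint

# Mean-relative: subtract the per-sample mean over points.
meanrel = diff - diff.mean(axis=1, keepdims=True)

# Root-relative: subtract the root joint's error (sanitize -1 to a
# valid index first, exactly as the tf.where in the loss does).
safe_root = np.where(root_index == -1, 0, root_index)
root_diff = diff[np.arange(len(diff)), safe_root][:, None]
rootrel = diff - root_diff

# Pick root-relative where a root exists, mean-relative otherwise.
center_rel = np.where((root_index == -1)[:, None, None], meanrel, rootrel)
```

In batch element 0 the root joint's error becomes exactly zero after centering; in batch element 1 the centered errors sum to zero across points, since the mean was subtracted.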

reconstruct_absolute

    def adjusted_train_counter(self):
        return self.train_counter // FLAGS.grad_accum_steps

    def reconstruct_absolute(
            self, head2d, head3d, intrinsics, mix_3d_inside_fov, point_validity_mask=None):
        return tf.cond(
            self.adjusted_train_counter() < 500,
            lambda: tfu3d.reconstruct_absolute(
                head2d, head3d, intrinsics, mix_3d_inside_fov=mix_3d_inside_fov,
                weak_perspective=True, point_validity_mask=point_validity_mask,
                border_factor1=1, border_factor2=0.55, mix_based_on_3d=False),
            lambda: tfu3d.reconstruct_absolute(
                head2d, head3d, intrinsics, mix_3d_inside_fov=mix_3d_inside_fov,
                weak_perspective=False, point_validity_mask=point_validity_mask,
                border_factor1=1, border_factor2=0.55, mix_based_on_3d=False))

1. Function overview

This function selects between two different 3D reconstruction strategies based on the current training step (adjusted_train_counter):

  • Early training (first 500 adjusted steps): use weak-perspective projection, simplifying the computation to stabilize training.

  • Later training (after step 500): switch to full perspective projection (weak_perspective=False) for more accurate reconstruction.
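Note that the counter is divided by the gradient-accumulation steps, so the switch is defined in terms of effective optimizer steps, not raw forward passes. A tiny sketch of that schedule (the grad_accum_steps value here is hypothetical, not taken from the source configuration):

```python
# Hypothetical configuration for illustration only.
grad_accum_steps = 2

def adjusted_train_counter(train_counter):
    # Effective optimizer steps, as in the class method above.
    return train_counter // grad_accum_steps

def use_weak_perspective(train_counter):
    # Weak perspective for the first 500 adjusted steps.
    return adjusted_train_counter(train_counter) < 500

# With 2 accumulation steps, the switch to full perspective happens
# at raw training step 1000 (adjusted step 500).
```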


2. Parameter details

| Parameter | Type / range | Description |
| --- | --- | --- |
| head2d | Tensor | 2D coordinates predicted by the network (pixel space) |
| head3d | Tensor | 3D coordinates predicted by the network (offsets relative to the root joint; may not be aligned with the absolute coordinate frame) |
| intrinsics | Tensor | Camera intrinsic matrix (used to project from 3D to 2D) |
| mix_3d_inside_fov | float in [0, 1] | Weight given to the 3D prediction for points inside the field of view (FOV), blended with the 2D back-projection result |
| point_validity_mask | Tensor (bool) | Marks which points are valid (e.g. to filter out occluded points or outliers) |
| weak_perspective | bool | Whether to use weak-perspective projection (True: ignore depth variation; False: use full perspective projection) |
| border_factor1/2 | float | Control the margin around the field-of-view border (used to decide whether a point counts as inside the image) |
| mix_based_on_3d | bool | Whether the blending decision is based on the 3D coordinates (if False, it may be based on 2D confidence) |
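The intrinsics parameter is what turns camera-space 3D points into pixels. As a reference point for the table, here is a minimal NumPy sketch of standard pinhole projection; this is an assumption about what a helper like tfu3d.project_pose computes, not the library's actual code:

```python
import numpy as np

def project_pose(coords3d, intrinsics):
    """Pinhole projection sketch (assumed behavior, not library code).

    coords3d:   [N, P, 3] camera-space points
    intrinsics: [N, 3, 3] camera matrices
    returns:    [N, P, 2] pixel coordinates
    """
    proj = np.einsum('nij,npj->npi', intrinsics, coords3d)
    return proj[..., :2] / proj[..., 2:3]  # divide by depth

# Example camera: focal length 1000 px, principal point (128, 128).
K = np.array([[[1000., 0., 128.],
               [0., 1000., 128.],
               [0., 0., 1.]]])
points = np.array([[[0.0, 0.0, 2.0],      # on the optical axis
                    [0.1, -0.2, 2.0]]])   # off-axis point
uv = project_pose(points, K)
# The on-axis point lands exactly on the principal point (128, 128).
```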

3. Comparison of the two reconstruction modes

| Property | Early training (weak_perspective=True) | Later training (weak_perspective=False) |
| --- | --- | --- |
| Projection model | Weak perspective (assumes depth variation is negligible) | Full perspective (accounts for depth variation) |
| Computational cost | Low (suits fast convergence early in training) | Higher (suits fine-grained optimization) |
| Use case | Rough pose alignment in the initial phase | High-precision reconstruction (e.g. refining joint details) |
| Stability | More robust to noise and poor initial values | Depends on accurate initial predictions |
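The first table row is easy to see numerically. Under full perspective each point is divided by its own depth, while weak perspective divides all points by a single shared reference depth; the sketch below assumes the mean depth is used as that reference, which is an illustrative simplification rather than the library's exact formula:

```python
import numpy as np

f = 1000.0  # focal length in pixels (made-up value)

# Two points with the same lateral position but different depths,
# e.g. a hand in front of the torso.
points = np.array([[0.1, 0.0, 1.8],
                   [0.1, 0.0, 2.2]])

full = f * points[:, :2] / points[:, 2:3]       # per-point depth
weak = f * points[:, :2] / points[:, 2].mean()  # one shared depth (2.0)

# Full perspective maps the two x-coordinates to different pixels
# (~55.6 px vs ~45.5 px); weak perspective maps both to 50 px,
# ignoring the depth variation within the body.
```

This is why weak perspective is a robust approximation early in training (small depth errors cannot blow up the projection) but limits accuracy once the network's depth predictions become reliable.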
