当前位置: 首页 > news >正文

STORM代码阅读笔记

默认的 分辨率是 [160,240] ,基于 Transformer 的方法不能做高分辨率。

Dataloader

输入是 带有 pose 信息的 RGB 图像

## 采样帧数目 = 20
num_max_future_frames = int(self.timespan * fps) 
## 每次间隔多少个时间 timesteps 取一个context image
num_context_timesteps  = 4

按照STORM 原来的 setting, future_frames = 20 context_image 每次间隔4帧,所以是 context_frame_idx = [0,5,10,15], 在 target_frame 包含了 从[0,20]的所有20帧。

以这样20 帧的 image 作为一个基本的 batch, 进行预测: 进入 model

所以,输入网络的 context_image 对应的 shape (1,4,3,160,240) 输入4个时刻帧的 frame, 每一个 frame 有 3个相机;对应的 context_camtoworlds shape (1,4,3,4,4)

Network

输入网络的 有3个 input: context_image, ray 和 time 的信息

  • context_image: (1,4,3,3,160,240)
  • Ray embedding (1,4,3,6,160,240)
  • time_embedding (1,4,3)
  • 将 image 和 ray_embedding 进行 concat 操作, 得到 x:(12,9,160,240)
 x = rearrange(x, "b t v c h w -> (b t v) c h w")
plucker_embeds = rearrange(plucker_embeds, "b t v h w c-> (b t v) c h w")
x = torch.cat([x, plucker_embeds], dim=1) ## (12,9,160,240)

然后经过3个 embedding , 将这些 feature 映射成为 token:

x = self.patch_embed(x)  # (b t v) h w c2x = self._pos_embed(x)  # (b t v) (h w) c2x = self._time_embed(x, time, num_views=v)

得到 x.shape (12,600,768), 表示一共有12张图像,每个图象 是 600 个 token, 每个 token 的 channel 是768. 然后将这些 token concat 在一起 得到了 (7200,768) 的 feature;

给得到的 token 分别加上可学习的 motion_token, affine_token 和 sky_token. 连接方式都是 concat
这样得到的 feature 为 (7220,768)的 feature

if self.num_motion_tokens > 0:motion_tokens = repeat(self.motion_tokens, "1 k d -> b k d", b=x.shape[0])x = torch.cat([motion_tokens, x], dim=-2)
if self.use_affine_token:affine_token = repeat(self.affine_token, "1 k d -> b k d", b=b)x = torch.cat([affine_token, x], dim=-2)
if self.use_sky_token:sky_token = repeat(self.sky_token, "1 1 d -> b 1 d", b=x.shape[0])x = torch.cat([sky_token, x], dim=-2)
  • 使用 Transformer 进行学习, 得到的 feature 维度不变。:
 x = self.transformer(x)x = self.norm(x) ## shape(7220,768)

运行完之后,可以将学习到的 token提取出来:

if self.use_sky_token:sky_token = x[:, :1] ## (1,1,768)x = x[:, 1:]if self.use_affine_token:affine_tokens = x[:, : self.num_cams] ## (1,3,768)x = x[:, self.num_cams :]if self.num_motion_tokens > 0:motion_tokens = x[:, : self.num_motion_tokens]  ## (1,16,768)x = x[:, self.num_motion_tokens :]

在 Transformer 内部,没有上采样层,也可以实现 这种 per-pixel feature 的学习。
对于 x 进行 GS 的预测,得到 pixel_align 的高斯。 对于每个 patch, 得到的 feature 是 (12,600,768), 通过一个CNN,虽然通道数没有变 (12,600,768), 但是之前 768 可以理解为全局的 语义, 之后的 768 为 一个patch 内部不同像素的语义,他们共享着 全局的 语义信息,但是每个pixel 却又不一样。 通过下面的 unpatchify 函数将将一个patch 的语义拆成 per-pixel 的语义,将每个768维token展开为8×8像素。

 b, t, v, h, w, _ = origins.shape## x_shape: (12,600,768)x = rearrange(x, "b (t v hw) c -> (b t v) hw c", t=t, v=v)## gs_params_shape: (12,600,768),这一步虽然通道没变,但其实是将一个 token 的全局 语义,映射成## token 内部的像素级别的语义gs_params = self.gs_pred(x)## gs_params_shape: (12,12,160,240)### 关键步骤:unpatchify将每个768维token展开为8×8像素gs_params = self.unpatchify(gs_params, hw=(h, w), patch_size=self.unpatch_size)

根据 token展开的 per-pixel feature, 进行3DGS 的属性预测

gs_params = rearrange(gs_params, "(b t v) c h w -> b t v h w c", t=t, v=v)
depth, scales, quats, opacitys, colors = gs_params.split([1, 3, 4, 1, self.gs_dim], dim=-1)
scales = self.scale_act_fn(scales)
opacitys = self.opacity_act_fn(opacitys)
depths = self.depth_act_fn(depth)
colors = self.rgb_act_fn(colors)
means = origins + directions * depths

除了3DGS 的一半属性之外, storm 还额外预测了其他的运动属性,包括:

其中: x: (1,7200,768) 代表 image_token, motion_tokens 是(1,16,768)代表 motion_token. 处理的大致思路是 motion_token 作为 query, 然后 image_token 映射的feature 作为 key, 去结合计算每一个 高斯的 moition_weightsmoition_bases

gs_params = self.forward_motion_predictor(x, motion_tokens, gs_params)
其中:
forward_flow = torch.einsum("b t v h w k, b k c -> b t v h w c", motion_weights, motion_bases)

moition_bases: shape: [1,16,3]
moition_weights: shape: [1,4,3,160,240,16]
forward_flow: shape: [1,4,3,160,240,3]: 是 weights 和bases 结合的结果

GS_param Rendering

  • 取出高斯的各项属性,尤其是 means 和 速度 forward_v: STORM 假设 在这 20帧是出于 匀速直线运动, 其速度时不变的,可能并不合理。我们的方法直接预测 BBX,可能更为准确。
means = rearrange(gs_params["means"], "b t v h w c -> b (t v h w) c")
scales = rearrange(gs_params["scales"], "b t v h w c -> b (t v h w) c")
quats = rearrange(gs_params["quats"], "b t v h w c -> b (t v h w) c")
opacities = rearrange(gs_params["opacities"], "b t v h w -> b (t v h w)")
colors = rearrange(gs_params["colors"], "b t v h w c -> b (t v h w) c")
forward_v = rearrange(gs_params["forward_flow"], "b t v h w c -> b (t v h w) c")

这里得到的 高斯的 mean 是全部由 context_image 得到的, shape (46800,3), 但这其实是 4个 时刻context_frame_idx = [0,5,10,15], 得到的高斯,并不处于同一时间刻度。
通过比较 target_timecontext_time 之间的插值,去得到每一个 target_time 的 3D Gaussian 的坐标means_batched

  if tgt_time.ndim == 3:tdiff_forward = tgt_time.unsqueeze(2) - ctx_time.unsqueeze(1)tdiff_forward = tdiff_forward.view(b * tgt_t, t * v, 1)tdiff_forward_batched = tdiff_forward.repeat_interleave(h * w, dim=1)else:tdiff_forward = tgt_time.unsqueeze(-1) - ctx_time.unsqueeze(-2)tdiff_forward = tdiff_forward.view(b * tgt_t, t, 1)tdiff_forward_batched = tdiff_forward.repeat_interleave(v * h * w, dim=1)forward_translation = forward_v_batched * tdiff_forward_batchedmeans_batched = means_batched + forward_translation ## (20,460800,3) 

使用 gsplatbatch_rasterization 函数:

  rendered_color, rendered_alpha, _ = rasterization(means=means_batched.float(),  ## (20,460800,3)quats=quats_batched.float(),scales=scales_batched.float(),opacities=opacities_batched.float(),colors=colors_batched.float(),viewmats=viewmats_batched,  ## (20,3,4,4)Ks=Ks_batched,  ## (20,3,3,3)width=tgt_w,height=tgt_h,render_mode="RGB+ED",near_plane=self.near,far_plane=self.far,packed=False,radius_clip=radius_clip,)
http://www.dtcms.com/a/307579.html

相关文章:

  • 邢台市某区人民医院智慧康养平台建设项目案例研究
  • Mac安装Navicat教程Navicat Premium for Mac v17.1.9 Mac安装navicat【亲测】
  • 【ARM】PK51关于内存模式的解析与选择
  • c++:设计模式训练
  • 两款免费数据恢复软件介绍,Win/Mac均可用
  • 【javascript】new.target 学习笔记
  • 揭秘动态测试:软件质量的实战防线
  • List和 ObservableCollection 的区别
  • 【worklist】worklist的hl7、dicom是什么关系
  • 原生安卓与flutter混编的实现
  • 如何使用一台电脑adb调试多个Android设备
  • AI 如何评价股票:三七互娱(SZ:002555),巨人网络(SZ:002558)
  • 解决:MATLAB 已经画好了Figure,想在不重新绘图的情况下去掉坐标轴刻度线
  • Java 大视界 -- Java 大数据在智能医疗远程健康监测与疾病预防预警中的应用(374)
  • 《以终为始,不辩过程》
  • cartographer 概率栅格地图
  • JVM面试通关指南:内存区域、类加载器、双亲委派与GC算法全解析
  • 一万字讲解Java中的IO流——包含底层原理
  • GCC/G++ + Makefile/make 使用
  • Visual Studio调试技巧与函数递归详解
  • “0 成本开跨境店” 噱头下的优哩哩:商业模式深度剖析
  • 5G 单兵终端 + 无人机:消防应急场景的 “空 - 地” 救援协同体系
  • 【可用有效】Axure RP 9 授权码
  • imx6ull-驱动开发篇5——新字符设备驱动实验
  • springcloud04——网关gateway、熔断器 sentinel
  • cas自定义返回信息和自定义认证
  • 考研408_数据结构笔记(第三章栈、队列和数组)
  • 解构衡石嵌入式BI:统一语义层与API网关的原子化封装架构
  • Vue 中使用 Dexie.js
  • 城市客运安全员证考试难不难?如何高效备考