当前位置: 首页 > news >正文

【3DV 进阶-4】VecSet 论文+代码对照理解

  • 【3DV 进阶-1】Hunyuan3D2.1 训练代码详细理解上-模型调用流程

  • 【3DV 进阶-2】Hunyuan3D2.1 训练代码详细理解下-数据读取流程

  • 【3DV 进阶-3】Hunyuan3D2.1 训练代码详细理解之-Flow matching 训练 loss 详解

  • 本文介绍 3DShape2VecSet [TOG 2023], 该篇论文中提出的 VecSet 是目前主流图生3D base model (包括 Hunyuan3D2.1, TripoSG 等)采用的表示。

项目概览

  • 3DShape2VecSet 是一个 3D 形状表示与生成框架,先把点云/mesh 的表面样本编码成一组向量(VecSet),再在该集合空间里训练条件扩散模型,用于生成新形状。
  • 项目主体训练流程分两阶段:阶段一训练 KL 自编码器,将 2048 个表面点压缩成 512 条 latent 向量;阶段二在 latent 向量集合上训练 EDM 风格的条件扩散模型,类别条件来自 ShapeNet 标签。

VecSet 表示本质

  • VecSet 的核心思想是:把输入点云视为一个无序集合,使用最远点采样 (FPS) 选出 num_latents 个代表点,再经交叉注意力从完整点云聚合特征,获得一个“向量集合”形式的潜空间表示。
  • 这个潜空间通过多层自注意力进一步建模,使得每个 latent token 都携带局部+全局几何信息,既适合重建网络,也便于扩散模型直接作用于“集合”而非序列。
244-258: /3DShape2VecSet/models_ae.pyidx = fps(pos, batch, ratio=ratio)                    # pos: (B*N,3) 展平后的所有点;batch: (B*N) 指示每个点所属样本;ratio=M/N 目标采样比例 → 返回最远点采样后的全局索引
sampled_pc = pos[idx]                                   
# 利用 idx 在 pos 中取样,得到下采样后的点云;形状 (B*M,3)
sampled_pc = sampled_pc.view(B, -1, 3)                  
# 还原 batch 维度,得到 (B, M, 3) 的代表点集合
sampled_pc_embeddings = self.point_embed(sampled_pc)    
# 对采样点做位置编码 + MLP,输出 (B, M, dim)
pc_embeddings = self.point_embed(pc)                    
# 对完整点云 (B, N, 3) 做同样编码,得到 (B, N, dim)
x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings
# 以采样点嵌入作为查询 Q,完整点嵌入作为键值 K/V 做跨注意力,输出 (B, M, dim),并与原采样嵌入残差相加
x = cross_ff(x) + x                                     
# 将 cross-attn 输出送入前馈网络 (B, M, dim),再做一次残差,得到最终 VecSet latent 表示

训练与生成流程

  • 阶段一(变分自编码器)
    • KLAutoEncoder.encode 输出 latent 分布的均值、方差并采样 latent set,同时返回 KL 损失;解码时再对给定查询点(如体素网格或采样点)做 cross-attn,预测占据率,从而重建原始几何。
    • 训练时同时优化体素占据和近表面样本的 BCE 损失,并以小权重加入 KL 项。
368-401:/3DShape2VecSet/models_ae.pymean = self.mean_fc(x)logvar = self.logvar_fc(x)posterior = DiagonalGaussianDistribution(mean, logvar)x = posterior.sample()kl = posterior.kl()...o = self.decode(x, queries).squeeze(-1)return {'logits': o, 'kl': kl}
49-66:/3DShape2VecSet/engine_ae.pyoutputs = model(surface, points)if 'kl' in outputs:loss_kl = outputs['kl']loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]outputs = outputs['logits']loss_vol = criterion(outputs[:, :1024], labels[:, :1024])loss_near = criterion(outputs[:, 1024:], labels[:, 1024:])loss = loss_vol + 0.1 * loss_near + kl_weight * loss_kl
  • 阶段二(类别条件扩散)
    • 先用冻结的 autoencoder 编码器提取 latent set,再把它连同类别标签送入 EDMPrecond 网络,按 EDM loss 做 score matching。
    • 扩散网络是多层 Transformer,对 latent token 序列进行自注意力,并通过类别嵌入调制层归一化,从而学会在噪声条件下预测干净 latent。
48:54:/3DShape2VecSet/engine_class_cond.pywith torch.cuda.amp.autocast(enabled=False):with torch.no_grad():_, x = ae.encode(surface)loss = criterion(model, x, categories)
484:533:/3DShape2VecSet/models_class_cond.pyself.model = LatentArrayTransformer(in_channels=channels, t_channels=256, n_heads=n_heads, d_head=d_head, depth=depth)self.category_emb = nn.Embedding(55, n_heads * d_head)...c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()c_noise = sigma.log() / 4F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), cond=cond_emb, **model_kwargs)D_x = c_skip * x + c_out * F_x.to(torch.float32)
  • 采样与重建
    • 采样时给定类别 ID 和随机种子,扩散模型生成新的 latent set,再由 AE 解码器对体素网格查询,得到占据率体积,最后通过 marching cubes 导出 mesh。
64:75:/3DShape2VecSet/sample_class_cond.pylogits = ae.decode(sampled_array[j:j+1], grid)volume = logits.view(density+1, density+1, density+1).permute(1, 0, 2).cpu().numpy()verts, faces = mcubes.marching_cubes(volume, 0)m = trimesh.Trimesh(verts, faces)m.export('class_cond_obj/{}/{:02d}-{:05d}.obj'.format(args.dm, category_id, i*iters+j))

关键代码片段

  • VecSet 编码:models_ae.AutoEncoder.encode / KLAutoEncoder.encode(见上文片段)。
  • 变分采样与解码:models_ae.KLAutoEncoder(见上文两段)。
  • 自编码器训练循环:engine_ae.train_one_epoch
  • 扩散训练入口:engine_class_cond.train_one_epoch
  • 扩散网络结构:models_class_cond.EDMPrecond
  • 生成与网格重建脚本:sample_class_cond.py

(1)最远点采样 (FPS) 是什么?

  • 最远点采样用于从原始点云中等距挑选 num_latents 个代表点,保证覆盖形状空间,常用于点云下采样。
  • 代码使用了 torch_cluster.fps,在 models_ae.pyencode() 里把输入点云拉平成 (B*N,3) 后调用 fps,返回代表点索引,再 reshape 成 (B,num_latents,3)
242:247:/3DShape2VecSet/models_ae.pyidx = fps(pos, batch, ratio=ratio)sampled_pc = pos[idx]sampled_pc = sampled_pc.view(B, -1, 3)

(2)交叉注意力如何处理 num_latents 与原始点云?Q/K/V 怎么定义?

- sampled_pc_embeddings(代表点)作为查询 Qpc_embeddings全部点)作为上下文,提供 K 和 V

  • Attention.forwardself.to_q(x) 生成 Q,self.to_kv(context) 一次性输出 K、V;然后做缩放点积注意力、softmax、聚合。
  • 相关逻辑在 models_ae.pyAttention 类与 AutoEncoder.encode 中的调用。
70:107:/3DShape2VecSet/models_ae.pyq = self.to_q(x)context = default(context, x)k, v = self.to_kv(context).chunk(2, dim = -1)...sim = einsum('b i d, b j d -> b i j', q, k) * self.scaleattn = sim.softmax(dim = -1)out = einsum('b i j, b j d -> b i d', attn, v)
        sampled_pc_embeddings = self.point_embed(sampled_pc)pc_embeddings = self.point_embed(pc)x = cross_attn(sampled_pc_embeddings, context = pc_embeddings, mask = None) + sampled_pc_embeddings

(3)“向量集合”潜空间表示是什么?代码位置?

  • 交叉注意力后的 x 形状为 (B, num_latents, dim),每一行对应一个 latent 向量,它们组合成 VecSet。
  • 这个集合随后通过多层自注意力(self.layers)进一步处理,再返回给解码器重建;因此 x 就是我们说的“向量集合”表示。
  • 关键赋值和后续处理在 encode()decode() 方法里。
        x = cross_attn(sampled_pc_embeddings, context = pc_embeddings, mask = None) + sampled_pc_embeddingsx = cross_ff(x) + xreturn x
...for self_attn, self_ff in self.layers:x = self_attn(x) + xx = self_ff(x) + x

这样即可把点云编码成一个无序 latent 向量集合,再在这一集合空间里做后续生成或重建。

http://www.dtcms.com/a/560672.html

相关文章:

  • Oracle实用参考(13)——Oracle for Linux (RAC)到Oracle for Linux(单实例)间OGG单向复制环境搭建(2)
  • 前端开发 网站建设头像logo图片在线制作免费
  • 电话语音接入扣子介绍
  • Go分布式追踪实战:从理论到OpenTelemetry集成|Go语言进阶(15)
  • Vue-理解 vuex
  • 【Android】View滑动的实现
  • 广西南宁网站优化急切网头像在线制作图片
  • 创建对象中的单例模式
  • AI革新汽车安全软件开发
  • 单例模式并使用多线程方式验证
  • 小梦音乐下载器(高品质MP3下载) 中文绿色版
  • 网站群发推广软件wordpress页面显示文章
  • Redis大Key调优指针
  • Redis BigKey场景实战
  • Vue消息订阅与发布
  • 12306网站建设超30亿个人网站做贷款广告
  • 《Streamlit 交互式 Web 应用开发》总结测试题
  • 大连 网站制作黑龙江做网站
  • ASP.NET Core 9 Web Api 启用 Swagger
  • Web APIs学习第三天:事件
  • UVa 1597 Searching the Web
  • 5分钟读懂MySQL+Redis双写一致性实现流程
  • 从零开始构建PDF文档生成器(二)- 添加页眉页脚
  • PostgreSQL 中 pg_stat_database 视图的 tup_returned 字段详解
  • 网络原理--HTTP
  • 网站开发宣传标语2017做网站还赚钱吗
  • 海南网站建设公司哪家好wordpress 有点慢
  • Flutter 存储管理:从基础到进阶的完整指南
  • 鸿蒙Flutter三方库适配指南:09.版本升级适配
  • AutoAnalyze智能数据分析助手开源项目