当前位置: 首页 > news >正文

【3DV 进阶-2】Hunyuan3D2.1 训练代码详细理解下-数据读取流程

  • 【3D 入门-指标篇上】3D 网格重建评估指标详解与通俗比喻
  • 【3D 入门-指标篇下】 3D重建评估指标对比-附实现代码
  • 【3D 入门-3】常见 3D 格式对比,.glb / .obj / .stl / .ply
  • 【3D 入门-4】trimesh 极速上手之 3D Mesh 数据结构解析(Vertices / Faces)
  • 【3D 入门-5】trimesh 极速上手之 Hunyuan3D-2.1 中的“非水密网格“处理流程
  • 【3D 入门-6】大白话解释 SDF(Signed Distance Field) 和 Marching Cube 算法
  • 【3D 入门-7】理解 SDF(Signed Distance Field) 不是什么?与相对坐标的区别
  • 【3D 入门-8】通过 Hunyuan3D2.1 的代码来理解 SDF 和 marching cubes(上)
  • 【3D 入门-9】通过 Hunyuan3D2.1 的代码来理解 SDF 和 marching cubes(下)
  • 【3DV 进阶-1】Hunyuan3D2.1 训练代码详细理解上-模型调用流程

在深度学习模型训练中,数据读取流程往往是连接原始数据与模型训练的关键桥梁。对于像Hunyuan3D2.1这样复杂的3D生成模型而言,高效、准确的数据读取与预处理尤为重要。本文将深入解析Hunyuan3D2.1的训练数据读取流程,帮助读者理解从数据加载到模型输入的完整链路。

  • 简而言之,总体的调用链为: - main → instantiate_from_config(dataset) → AlignedShapeLatentModule.train_dataloader → DataLoader(AlignedShapeLatentDataset) → IterableDataset.iter → decode(路径/npz) → load_render/采样点云 → transform 组装 sample{surface,image,mask,…} → 送入 Diffuser.forward → cond_stage_model(image,mask) + first_stage_model.encode(surface)

数据模块定位与整体架构

要理解Hunyuan3D2.1的数据读取流程,我们首先从数据模块的实现入手。Hunyuan3D2.1采用了PyTorch Lightning的DataModule设计模式,核心实现位于hy3dshape.data.dit_asl.AlignedShapeLatentModule。通过分析这个类的setuptrain_dataloaderval_dataloader等方法以及底层的Dataset实现,我们可以清晰梳理出batch数据的构成与调用链,进而掌握从main函数到模型forward方法的完整数据流。

训练阶段数据读取调用链

入口:构建DataModule

整个数据流程始于主函数中对DataModule的实例化,代码如下:

  • /path/Hunyuan3D-2.1/hy3dshape/main.py
# Build data modules
data: pl.LightningDataModule = instantiate_from_config(config.dataset)

这行代码从配置文件中读取数据集相关配置,并实例化出对应的DataModule对象,为后续的数据加载做好准备。

DataModule定义与DataLoader创建

  • 代码位置:/path/Hunyuan3D-2.1/hy3dshape/hy3dshape/data/dit_asl.py
  • AlignedShapeLatentModule作为核心的数据模块,定义了训练和验证过程中所需的数据转换和加载逻辑:
class AlignedShapeLatentModule(LightningDataModule):def __init__(..., image_size: int = 224, mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225), ...):...self.train_image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(self.image_size),transforms.Normalize(mean=self.mean, std=self.std)])self.val_image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(self.image_size),transforms.Normalize(mean=self.mean, std=self.std)])

__init__方法中,主要定义了图像的预处理流程,包括转换为张量、调整大小和标准化等操作,这些操作会应用于后续加载的图像数据。

对于训练数据加载器,AlignedShapeLatentModule的实现如下:

def train_dataloader(self):dataset = AlignedShapeLatentDataset(**asl_params)return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,pin_memory=True, drop_last=True, worker_init_fn=worker_init_fn)

验证数据加载器的实现与训练类似:

def val_dataloader(self):dataset = AlignedShapeLatentDataset(** asl_params)return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=self.val_num_workers,pin_memory=True, drop_last=True, worker_init_fn=worker_init_fn)

可以看到,无论是训练还是验证,都使用了AlignedShapeLatentDataset作为数据集,并通过PyTorch的DataLoader进行封装,实现了批量加载、多进程处理等功能。

Dataset实现与数据分片重采样

AlignedShapeLatentDataset是整个数据读取流程的核心,它继承自PyTorch的IterableDataset,适用于处理大型数据集:

class AlignedShapeLatentDataset(torch.utils.data.dataset.IterableDataset):def __init__(..., data_list: str = None, cond_stage_key: str = "image", image_transform=None, ...)

为了高效处理大规模数据,Hunyuan3D2.1采用了数据分片(shard)的方式存储,并通过ResampledShards类实现了对数据分片的重采样:

class ResampledShards(IterableDataset):def __iter__(self):...for _ in range(self.nshards):index = self.rng.randint(0, len(self.datalist) - 1)yield self.datalist[index]

这种重采样机制可以有效避免模型学习到数据的顺序信息,提高模型的泛化能力。

Dataset产出样本流程

AlignedShapeLatentDataset产出样本的过程主要包括三个步骤:解码(decode)、转换(transform)和生成(yield)。

  1. 解码过程:从数据路径中读取相关文件,构建原始样本字典
def decode(self, item):uid = item.split('/')[-1]render_img_paths = [os.path.join(item, f'render_cond/{i:03d}.png') for i in range(24)]surface_npz_path = os.path.join(item, f'geo_data/{uid}_surface.npz')sample = {}sample["image"] = render_img_pathssurface_data = read_npz(surface_npz_path)sample["random_surface"] = surface_data['random_surface']sample["sharpedge_surface"] = surface_data['sharp_surface']return sample
  1. 转换过程:对原始数据进行预处理和特征提取,生成模型所需的输入格式
def transform(self, sample):rng = np.random.default_rng()...image_input, mask_input = self.load_render(sample['image'])surface, geo_points = self.load_surface_sdf_points(rng, random_surface, sharpedge_surface)sample = {"surface": surface, "geo_points": geo_points,"image": image_input, "mask": mask_input,}return sample
  1. 图像加载与处理:从多个渲染图中随机选择一张,并进行必要的预处理
def load_render(self, imgs_path):imgs_choice = self.rng.sample(imgs_path, 1)for image_path in imgs_choice:image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)  # 读取 RGBA...
if self.padding: ...  # 可选裁剪-居中-扩边
if self.image_transform:image = self.image_transform(image)mask = np.stack((mask, mask, mask), axis=-1)mask = self.image_transform(mask)
  1. 表面点云加载与采样:从表面数据中采样点云,并处理相关特征
def load_surface_sdf_points(...):# 采样点云/锐边点,拼接法线/标签,返回 "surface"(含坐标与可选法线/标签), geo_points(占位)
  1. 迭代生成样本:通过__iter__方法将上述过程串联起来,持续生成训练样本
def __iter__(self):for data in ResampledShards(self.data_list):sample = self.decode(data)sample = self.transform(sample)yield sample

模型端如何消费batch数据

经过上述流程生成的batch数据最终会被模型消费,相关代码如下:

with torch.autocast(..., dtype=torch.bfloat16):contexts = self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'))
with torch.autocast(..., dtype=torch.float16):with torch.no_grad():latents = self.first_stage_model.encode(batch[self.first_stage_key], sample_posterior=True)latents = self.z_scale_factor * latents

其中,first_stage_key默认为"surface",因此VAE编码器使用的是batch["surface"]的数据;而条件编码器则使用imagemask(可选text)作为输入。

数据流程总结

一句话调用链

main → instantiate_from_config(dataset) → AlignedShapeLatentModule.train_dataloader → DataLoader(AlignedShapeLatentDataset) → IterableDataset.iter → decode(路径/npz) → load_render/采样点云 → transform 组装 sample{surface,image,mask,…} → 送入 Diffuser.forward → cond_stage_model(image,mask) + first_stage_model.encode(surface)

关键细节

  • 产出的batch数据包含以下键值:surfaceimagemaskgeo_points(占位)。
  • 图像经过Resize和Normalize处理(由DataModule定义的transforms实现),mask也会进行同步的尺寸调整与归一化。
  • 图像选择策略:每个样本目录包含24张渲染图,训练时会从中随机选择1张(可选padding/居中处理)。
  • 采样点云:从random_surfacesharp_surface中各随机采样指定数量的点,拼接后(可选)附上法线与锐边标签。

小结

Hunyuan3D2.1的数据读取流程通过AlignedShapeLatentDataset按目录迭代数据,经过decode和transform两个主要步骤,形成模型所需的各个字段。在训练过程中,模型的条件信息来自imagemask,而几何潜变量则来自surface经过VAE编码后的结果。这种设计既保证了数据处理的高效性,又能为模型提供丰富的多模态输入信息,为3D生成任务奠定了坚实的数据基础。


文章转载自:

http://DEwzwqPV.qpsxz.cn
http://8ftDdgXj.qpsxz.cn
http://oz09JR5L.qpsxz.cn
http://9RTovRPg.qpsxz.cn
http://RnDvudXq.qpsxz.cn
http://b607GMgX.qpsxz.cn
http://Q6sZrOGf.qpsxz.cn
http://qi36jIiE.qpsxz.cn
http://Pj08oHEN.qpsxz.cn
http://r8QYqVcH.qpsxz.cn
http://4F7GSmpi.qpsxz.cn
http://8utMKTS3.qpsxz.cn
http://AkiBUlIE.qpsxz.cn
http://reCS5HTk.qpsxz.cn
http://f5Oeg5dX.qpsxz.cn
http://pSTUA2O3.qpsxz.cn
http://8deT9ivv.qpsxz.cn
http://ZG4tVZan.qpsxz.cn
http://lPe6xU3H.qpsxz.cn
http://wt1enTSi.qpsxz.cn
http://2mp8WOrE.qpsxz.cn
http://pwv8VRbR.qpsxz.cn
http://cedjSrDi.qpsxz.cn
http://a39iaGdl.qpsxz.cn
http://B1Vmk1b7.qpsxz.cn
http://c5fyQN3W.qpsxz.cn
http://ZUyQI0Yi.qpsxz.cn
http://Nj8ZQirx.qpsxz.cn
http://9rJox3Tl.qpsxz.cn
http://Q0rGny81.qpsxz.cn
http://www.dtcms.com/a/374564.html

相关文章:

  • 从零开始的云计算生活——第六十天,志在千里,使用Jenkins部署K8S
  • 平板热点频繁断连?三步彻底解决
  • nand flash的擦除命令使用
  • 《Pod调度失效到Kubernetes调度器的底层逻辑重构》
  • OC-单例模式
  • C语言链表设计及应用
  • 中级统计师-统计法规-第三章 统计法的基本原则
  • 【VR音游】音符轨道系统开发实录与原理解析(OpenXR手势交互)
  • web前端安全-什么是供应链攻击?
  • Saucony索康尼推出全新 WOOOLLY 运动生活羊毛系列 生动无理由,从专业跑步延展运动生活的每一刻
  • 后端(FastAPI)学习笔记(CLASS 2):FastAPI框架
  • Java如何实现一个安全的登录功能?
  • AI中的“预训练”是什么意思
  • 量子文件传输系统:简单高效的文件分享解决方案
  • 基于Springboot + vue实现的乡村生活垃圾治理问题中运输地图
  • 分布式专题——5 大厂Redis高并发缓存架构实战与性能优化
  • 下载 Eclipse Temurin 的 OpenJDK 提示 “无法访问此网站 github.com 的响应时间过长”
  • 从嵌入式状态管理到云原生架构:Apache Flink 的演进与下一代增量计算范式
  • Gradio全解11——Streaming:流式传输的视频应用(2)——Twilio:网络服务提供商
  • 服务器更换jar包,重启后端服务
  • 人形机器人赛道的隐形胜负手:低延迟视频链路如何决定机器人未来
  • 分钟级长视频生成迎来“记忆革命”,7倍成本降低,2.2倍端到端生成速度提升!|斯坦福字节
  • 多张图片生成视频模型技术深度解析
  • electron安装失败
  • Electron+Vite+Vue项目中,如何监听Electron的修改实现和Vue一样的热更新?[特殊字符]
  • IEEE出版,限时早鸟优惠!|2025年智能制造、机器人与自动化国际学术会议 (IMRA 2025)
  • Next.js vs Create React App:2025年该如何选择?
  • React From表单使用Formik和yup进行校验
  • 响应式编程思想与 Reactive Streams 规范
  • [react] react onClick函数的认知陷阱