当前位置: 首页 > news >正文

Zarr Dataset (数据集) 的使用

  • 最近在跑, 3D Diffusion Policy, DexGraspVLA, Improved Diffusion Policy的时候发现大家的数据集都是以zarr的格式存储的,遂记录一下zarr作为dataset类的基本信息,方便日后查询;
    • 这个视频介绍了zarr是怎么高效存储的:https://www.youtube.com/watch?v=KiiKvXzhyMs,大约节省15%~50%的空间;
    • 涉及到下面三个的代码我上传到了百度网盘:
      • 链接: https://pan.baidu.com/s/1IQT4bAdP-LuOwSlfmE_49A 提取码: 2025
  • 我的zarr版本是2.12.0;

1. 记录数据

  • Demo (DP3, https://github.com/YanjieZe/3D-Diffusion-Policy)
    • 按照现在的Replay Buffer, episode_ends是必须要有的,而且要存给meta;
#  参考 https://github.com/YanjieZe/3D-Diffusion-Policy;
import zarr 
import numpy as np zarr_root = zarr.group('/mnt/data/workspace/Improved-3D-Diffusion-Policy/notebooks/test_data')
zarr_data = zarr_root.create_group('data')
zarr_meta = zarr_root.create_group('meta')B = 30
img_arrays = np.random.rand(B, 480, 640, 3)
pointcloud_arrays = np.random.rand(B, 512, 3)
depth_arrays = np.random.rand(B, 480, 640) 
action_arrays = np.random.rand(B, 7) 
# 每个episode的起始和终止点;
episodeends_arrays = np.array([B//3, B//3*2, B], dtype=np.int64)   compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=1)
img_chunk_size = (B, img_arrays.shape[1], img_arrays.shape[2], img_arrays.shape[3])
pointcloud_chunk_size = (B, pointcloud_arrays.shape[1], pointcloud_arrays.shape[2])
depth_chunk_size = (B, depth_arrays.shape[1], depth_arrays.shape[2])
action_chunk_size = (B, action_arrays.shape[1])
episodeends_chunk_size = (B,)zarr_data.create_dataset('img', data=img_arrays, chunks=img_chunk_size, dtype=np.float32, overwrite=True, compressor=compressor)
zarr_data.create_dataset('pointcloud', data=pointcloud_arrays, chunks=pointcloud_chunk_size, dtype=np.float32, overwrite=True, compressor=compressor)
zarr_data.create_dataset('depth', data=depth_arrays, chunks=depth_chunk_size, dtype=np.float32, overwrite=True, compressor=compressor)
zarr_data.create_dataset('action', data=action_arrays, chunks=action_chunk_size, dtype=np.float32, compressor=compressor)
zarr_meta.create_dataset('episode_ends', data=episodeends_arrays, chunks=episodeends_chunk_size, dtype=np.int64, compressor=compressor)
  • 存储完是这样的:
├── .zgroup
├── data
│   ├── .zgroup
│   ├── action
│   │   ├── .zarray
│   │   └── 0.0
│   ├── depth
│   │   ├── .zarray
│   │   └── 0.0.0
│   ├── img
│   │   ├── .zarray
│   │   └── 0.0.0.0
│   └── pointcloud
│       ├── .zarray
│       └── 0.0.0
└── meta├── .zgroup└── episode_ends├── .zarray└── 0
  • DexGraspVLA:
    • https://github.com/Psi-Robot/DexGraspVLA/blob/main/controller/common/replay_buffer.py
    • https://github.com/Psi-Robot/DexGraspVLA/blob/main/controller/common/sampler.py
  • Improved Diffusion Policy:
    • https://github.com/YanjieZe/3D-Diffusion-Policy/blob/5207c41ad917684c6f2b9d2900c0d7a593df94ab/3D-Diffusion-Policy/diffusion_policy_3d/common/replay_buffer.py
    • https://github.com/YanjieZe/3D-Diffusion-Policy/blob/5207c41a/3D-Diffusion-Policy/diffusion_policy_3d/common/sampler.py
  1. 加载数据
  • Demo
    • 使用 SequenceSampler 从 ReplayBuffer 中采样序列数据;
    • pad_before和pad_after是针对每个episode而言的,我这里三个episode,每个episode的长度为10, pad一共是前1后7共计18,采样长度是16即每个episode能采样(18+1-16=3)个sample_sequence,最后三个episode就是一共9个sample_sequence;
from replay_buffer import ReplayBuffer  
from sampler import SequenceSampler  # 方法1: 加载到内存(推荐)  
replay_buffer = ReplayBuffer.copy_from_path(  '/mnt/data/workspace/Improved-3D-Diffusion-Policy/notebooks/test_data',   keys=['img', 'pointcloud', 'depth', 'action']  
)  # 访问数据  
# imgs = replay_buffer['img']  # 获取所有动作  # 获取特定 episode 的数据  
episode_0 = replay_buffer.get_episode(0)  
# print(f"Episode 0 action shape: {episode_0['action'].shape}")  # 创建序列采样器  
sampler = SequenceSampler(  replay_buffer=replay_buffer,  sequence_length=16,      # 序列长度, 这个就是时间窗口; pad_before=1,           # 前向填充, 用第0项填;pad_after=7,            # 后向填充,用第-1项填;  episode_mask=None       # 使用所有 episode  
)  # 采样序列  
print(len(sampler)) # (10+8)
sample = sampler.sample_sequence(0)  
print(f"Sampled action shape: {sample['action'].shape}")  # (16, 7)  
print(f"Sampled state shape: {sample['img'].shape}")    # (16, 480, 640, 3)'''
Output:
[32mReplay Buffer: img, shape (30, 480, 640, 3), dtype float32, range 0.00~1.00[0m
[32mReplay Buffer: pointcloud, shape (30, 512, 3), dtype float32, range 0.00~1.00[0m
[32mReplay Buffer: depth, shape (30, 480, 640), dtype float32, range 0.00~1.00[0m
[32mReplay Buffer: action, shape (30, 7), dtype float32, range 0.00~1.00[0m
[32m--------------------------[0m
9
Sampled action shape: (16, 7)
Sampled state shape: (16, 480, 640, 3)
'''
  • DexGraspVLA, Improved Diffusion Policy的replay_buffer和DP3的毫无区别,就只是多了两个try, sampler完全一样的;
http://www.dtcms.com/a/454016.html

相关文章:

  • 淘宝 x5sec 普通滑块 分析
  • 西安网站建设制作公司提供网站建设小程序制作
  • 自建简单计算机CPU——软硬兼施
  • 企业网站模板上一品资源php网站模板修改
  • 线报网站如何做网页模板小偷
  • 小新pro更改网站设置国际新闻报道
  • JAVA算法练习题day31
  • 20g虚拟主机建设网站朔州企业网站建设
  • 大模型-扩散模型(Diffusion Model)原理讲解
  • 反欺诈模型升级:如何从“抓坏人”到“提前阻止坏人作案”?
  • 烟台网站seo服务南昌市市政建设有限公司
  • 系统集成项目管理工程师:【第一章 信息化发展】
  • 网站导航栏最多可以做几个制作手机wap网站工具
  • 离石做网站磁力搜索引擎torrentkitty
  • 上线了做网站要钱wordpress设置注册页面
  • 济南网站建设(选聚搜网络)建设工程报建网站查询
  • 江苏建筑网站建设网站开发行业代码
  • 上海普陀门户网站sem是什么职业岗位
  • 机械网站建设公司推荐高端网约车
  • 抖音私密账号显示IP属地吗?能更改IP么?
  • Sqoop的安装与配置
  • 样式网站商城网站微信支付接口申请流程
  • 量子密钥分发在BFF层的*认证实验
  • 永州市住房和城乡建设厅网站品牌型网站案例
  • MATLAB循环控制:break和continue语句详解
  • 历史网站怎么做wordpress文字怎么做超级链接
  • 水利建设管理司网站广州企业网站建设报价
  • Python美股量化交易填坑记录——3.盈透(Interactive Brokers)证券API接口
  • 网站有访问量 为什么没有询盘淘宝客的网站怎么做的
  • 力扣:9.回文数の题解