当前位置: 首页 > news >正文

Diffusion Policy Visuomotor Policy Learning via Action Diffusion官方项目解读(二)(4)

运行官方代码库中提供的Colab代码:vision-based environment(二)(4)

    • 十六、函数`unnormalize_data`,继承自`torch.utils.data.Dataset`
      • 十六.1 `def __init__()`
      • 十六.2 `def __len__ ()`
      • 十六.3 `def __getitem__()`
      • 总体说明
    • 十七、Dataset Demo
      • 十七.1 下载数据部分
      • 十七.2 设置参数部分
      • 十七.3 创建数据集
      • 十七.4 创建 DataLoader
      • 十七.5 可视化批次数据
      • 总体说明

官方项目地址:https://diffusion-policy.cs.columbia.edu/
Colab代码:vision-based environment


十六、函数unnormalize_data,继承自torch.utils.data.Dataset

# dataset
class PushTImageDataset(torch.utils.data.Dataset):
  • 作用:定义一个名为 PushTImageDataset 的类,该类继承自 torch.utils.data.Dataset
  • 意义:使数据集对象符合 PyTorch 数据加载接口(需要实现 lengetitem 方法),便于后续用 DataLoader 加载。
  • 示例:创建后可以通过 len(dataset) 获取样本数,并通过 dataset[i] 访问第 i 个样本。

十六.1 def __init__()

    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int):
  • 作用:定义构造函数,接收 4 个参数:
    • dataset_path: 字符串,表示数据集所在的路径(例如:“data/dataset.zarr”)。
    • pred_horizon: 整数,预测(或采样)序列的总长度,例如 30 表示后续采样序列长度为 30 个时间步。
    • obs_horizon: 整数,观察序列的长度,例如 20 表示在预测序列中,前 20 个时间步作为观察输入。
    • action_horizon: 整数,动作序列的长度,例如 10 表示后 10 个时间步作为动作目标。
  • 意义:通过这些参数确定每个样本序列的构造方式,包括如何做 padding(见后续 create_sample_indices 调用)。
        # read from zarr dataset
        dataset_root = zarr.open(dataset_path, 'r')
  • 作用:注释,说明接下来将从 zarr 数据集中读取数据。调用 zarr.open() 方法以只读模式 (‘r’) 打开数据集。
  • 示例:如果 dataset_path 为 “data/dataset.zarr”,则 dataset_root 是一个 zarr 数组/组对象,用于后续访问数据。
  • 意义:提示数据存储格式为 zarr,便于理解后续数据加载操作。获取数据集的根对象,后续按键读取各项数据。
        # float32, [0,1], (N,96,96,3)
        train_image_data = dataset_root['data']['img'][:]
  • 作用:注释,说明接下来读取的图像数据为 float32 类型、数值归一化在 [0,1] 范围,形状为 (N,96,96,3)(N 个样本,96×96 大小,3 个通道)。从 zarr 数据集中读取图像数据,读取整个数组。
  • 示例:若数据集中的 “data/img” 存储 5000 张图像,则 train_image_data 的 shape 为 (5000, 96, 96, 3)。
  • 意义:为后续处理(例如通道转换)提供数据信息。加载训练用图像数据,后续会调整通道顺序以符合 PyTorch 要求。
        train_image_data = np.moveaxis(train_image_data, -1,1)
        # (N,3,96,96)
  • 作用:调用 np.moveaxis 将图像数据的最后一维(通道)移动到第 1 个维度。
  • 示例:原 shape 为 (5000,96,96,3) 经此操作后变为 (5000,3,96,96)。
  • 意义:PyTorch 通常要求图像数据通道顺序为 (N, C, H, W),因此做此转换。
        # (N, D)
        train_data = {
            # first two dims of state vector are agent (i.e. gripper) locations
            'agent_pos': dataset_root['data']['state'][:,:2],
            'action': dataset_root['data']['action'][:]
        }
  • 作用:注释,说明接下来读取的其他数据(状态、动作)形状为 (N, D)。构造一个字典 train_data,其中:
    • 'agent_pos' 对应从 “data/state” 读取数据的前两列(即 agent 或 gripper 的位置)。
    • 'action' 对应从 “data/action” 读取的全部动作数据。
  • 示例
    • 若 “data/state” 的 shape 为 (5000, 5),则 dataset_root['data']['state'][:,:2] 的 shape 为 (5000,2);
    • 若 “data/action” 的 shape 为 (5000, 3),则整个动作数组 shape 为 (5000,3).
  • 意义:提示后续数据维度,便于理解。将不同类型的数据分别存放在字典中,便于后续统一归一化处理。
        episode_ends = dataset_root['meta']['episode_ends'][:]
  • 作用:从数据集 meta 部分读取 episode_ends 数组,记录每个 episode 的结束索引。
  • 示例:若数据集中有 50 个 episode,episode_ends 可能为形如 [100, 230, 350, …, 5000] 的数组。
  • 意义:后续利用 create_sample_indices 根据 episode 结束位置生成样本索引,保证不跨 episode 采样。
        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)
  • 作用:说明接下来调用的 create_sample_indices 用于计算每个状态-动作序列的开始和结束索引,并处理填充问题。调用前面定义的 create_sample_indices 函数,生成样本序列的索引数组。
  • 参数说明与示例
    • episode_ends: 前面读取的 episode_ends 数组。
    • sequence_length: 设为 pred_horizon,例如 pred_horizon=30。
    • pad_before: 计算为 obs_horizon-1;例如 obs_horizon=20,则 pad_before=19。
    • pad_after: 计算为 action_horizon-1;例如 action_horizon=10,则 pad_after=9。
  • 意义:为后续采样准备索引信息。根据各个 episode 的数据范围和要求的预测、观察、动作长度,生成每个样本在数据缓冲区中的起止索引及在目标样本中应填充的位置。
        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
  • 作用:注释,说明接下来对 train_data 中的各个数据项计算统计信息,并归一化到 [-1,1] 范围。初始化两个空字典:
    • stats 用于保存每个数据项(如 ‘agent_pos’, ‘action’)的最小和最大值。
    • normalized_train_data 用于保存归一化后的数据。
  • 意义:归一化有助于模型训练的数值稳定性。
  • 示例:初始时 stats = {},normalized_train_data = {}。
        for key, data in train_data.items():
  • 作用:遍历 train_data 中的每个键及对应的数据数组。
  • 示例
    • 第一次循环:key = ‘agent_pos’,data 的 shape 可能为 (5000,2);
    • 第二次循环:key = ‘action’,data 的 shape 可能为 (5000,3).
            stats[key] = get_data_stats(data)
  • 作用:调用 get_data_stats 函数对当前 data 计算每个特征的最小值和最大值,并存入 stats 字典中。
  • 示例
    • 对于 ‘agent_pos’,假设 get_data_stats 返回 {‘min’: [0, 0], ‘max’: [512, 512]}(如果 agent 位置在 0 到 512 之间)。
  • 意义:为后续归一化做准备。
            normalized_train_data[key] = normalize_data(data, stats[key])
  • 作用:调用 normalize_data 函数,用 stats[key] 对当前 data 进行归一化,结果存入 normalized_train_data 字典。
  • 示例
    • 对 ‘agent_pos’,归一化后数值映射到 [-1,1] 范围,例如 [256,400] 可能变为 [0, -0.2157](具体数值取决于 min/max)。
  • 意义:标准化数据尺度,使得不同特征具有相似的数值范围,便于模型训练。
        # images are already normalized
        normalized_train_data['image'] = train_image_data
  • 作用:注释,说明图像数据 train_image_data 已归一化到 [0,1](原数据中如此存储)。将图像数据直接添加到 normalized_train_data 字典中,键为 ‘image’。
  • 示例
    • 若 train_image_data 的 shape 为 (5000,3,96,96) 且数值均在 [0,1],则直接存入。
  • 意义:图像数据不需要再经过 normalize_data 处理。整合所有数据到一个字典中,方便后续统一采样。
        self.indices = indices
  • 作用:将生成的 indices 数组保存到实例变量 self.indices 中。
  • 意义:供 lengetitem 方法使用,确定数据集样本个数与每个样本在数据缓冲区中的位置。
        self.stats = stats
  • 作用:将计算得到的统计信息保存到实例变量 self.stats 中。
  • 意义:可能在后续需要反归一化或其他数据分析时使用。
        self.normalized_train_data = normalized_train_data
  • 作用:将归一化后的数据字典保存到实例变量 self.normalized_train_data 中。
  • 意义:后续 getitem 方法根据这些数据进行采样。
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon
  • 作用:将构造函数传入的 pred_horizon、action_horizon、obs_horizon 分别保存到实例变量中。
  • 示例
    • pred_horizon=30、action_horizon=10、obs_horizon=20,则分别保存这些数值。
  • 意义:这些参数决定了每个样本序列的长度和各部分的划分,后续采样和数据截取都会用到。

十六.2 def __len__ ()

    def __len__(self):
        return len(self.indices)
  • 作用:实现 Dataset 的 len 方法,返回样本总数。
  • 示例
    • 若 self.indices 的 shape 为 (1000,4),则返回 1000。
  • 意义:使得 DataLoader 能够知道数据集包含多少个样本。

十六.3 def __getitem__()

    def __getitem__(self, idx):
  • 作用:定义 getitem 方法,根据索引 idx 返回对应的样本数据。
  • 输入
    • idx:样本索引,例如 5。
  • 意义:满足 PyTorch Dataset 接口,支持索引访问。
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]
  • 作用:注释,说明接下来根据 idx 获取该样本对应的缓冲区索引信息。从 self.indices 数组中取出第 idx 行的 4 个数值,分别赋给 buffer_start_idx、buffer_end_idx、sample_start_idx、sample_end_idx。
  • 示例
    • 若 self.indices[5] = [100, 120, 3, 20],则:
      • buffer_start_idx = 100
      • buffer_end_idx = 120
      • sample_start_idx = 3
      • sample_end_idx = 20
  • 意义:这些索引确定了如何从原始数据中抽取连续序列,并在定长样本中插入实际数据的位置。
        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )
  • 作用:注释,说明接下来将根据这些索引从归一化数据中抽取样本序列。调用前面定义的 sample_sequence 函数,从 normalized_train_data 中抽取一个定长样本。
  • 参数说明
    • train_data:归一化后的数据字典,包含 ‘image’, ‘agent_pos’ 等。
    • sequence_length: 使用 self.pred_horizon(例如 30)。
    • 其它参数由 indices 得到,例如前面示例 [100,120,3,20]。
  • 示例
    • 对于 ‘agent_pos’ 数据,sample_sequence 会抽取原始数组中索引 100 到 120 之间的数据,并将其放入一个长度为 30 的数组中,将前 3 个位置用 sample[0] 填充,后面部分用 sample[-1] 填充。
  • 意义:确保每个样本的长度固定为 pred_horizon,同时处理序列边界缺失的情况(padding)。
        # discard unused observations
        nsample['image'] = nsample['image'][:self.obs_horizon,:]
  • 作用:说明后续将丢弃样本中未使用的观察部分,保留 obs_horizon 长度。对 nsample 字典中的 ‘image’ 数据进行截取,仅保留前 self.obs_horizon 行。
  • 示例
    • 若 nsample[‘image’] 的 shape 为 (30, 3, 96, 96)(30 个时间步),且 obs_horizon=20,则截取后 nsample[‘image’] 变为 (20, 3, 96, 96)。
  • 意义:通常,样本中既有观察部分也有动作部分,而此处只保留观察部分用于模型输入。将观察部分限定为前 obs_horizon 个时间步,其余部分可能属于预测或动作目标,不用于输入模型。
        nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:]
  • 作用:同理,对 ‘agent_pos’ 数据进行截取,仅保留前 obs_horizon 行。
  • 示例
    • 若 nsample[‘agent_pos’] 原 shape 为 (30, 2) 且 obs_horizon=20,则截取后为 (20,2)。
  • 意义:确保观测数据一致,仅取前部时间步。
        return nsample
  • 作用:返回处理好的样本 nsample,该样本是一个字典,包含 ‘image’ 和 ‘agent_pos’ 两个键,对应的数组均已截取为 obs_horizon 长度。
  • 意义getitem 方法输出的样本将用于模型训练或评估,满足固定时间步长要求。

总体说明

  • 大函数和大类的意义
    • PushTImageDataset 类
      • 作用:读取并处理来自 zarr 数据集的图像、状态和动作数据;生成标准化的、固定长度的训练样本。
      • 输入:数据集路径、预测序列长度(pred_horizon)、观察序列长度(obs_horizon)、动作序列长度(action_horizon)。
      • 输出:每个样本是一个字典,主要包含经过归一化并截取的图像数据(形状为 (obs_horizon, 3, 96, 96))和 agent 位置数据(形状为 (obs_horizon,2))。
    • 构造函数 (init)
      • 依次完成数据集的加载、图像数据通道转换、状态和动作数据读取、episode 划分、样本索引生成、各数据归一化及保存所有必要的参数。
    • len 方法:返回数据集样本总数,依据生成的 indices 数组长度。
    • getitem 方法:根据给定的索引,从归一化数据中抽取定长样本序列,并截取出观察部分返回。
  • 设计原因
    • 数据预处理:将原始数据归一化、转换通道顺序,便于神经网络训练。
    • 固定长度采样:通过 create_sample_indices 和 sample_sequence 的配合,实现跨 episode 数据抽样时的边界填充,确保每个样本长度一致。
    • 模型输入组织:将图像和位置信息组合成字典输出,便于后续多模态模型使用。

下面逐行详细解释这段代码的运作方式,包括具体数值例子、每行的细节说明以及整体设计思路。整个代码主要完成以下任务:

  1. 从 Google Drive 下载演示数据(如果本地没有)。
  2. 设置采样参数,说明观察、动作、预测的时间步数。
  3. 利用指定数据创建一个 PushTImageDataset 数据集实例,并保存数据统计信息。
  4. 利用该数据集创建一个 PyTorch DataLoader,用于批量加载数据。
  5. 从 DataLoader 中获取一个批次数据,并打印各部分数据的形状。

下面逐行说明每一行代码的作用。


十七、Dataset Demo

十七.1 下载数据部分

# download demonstration data from Google Drive
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
  • 作用:注释,说明接下来代码用于从 Google Drive 下载演示数据。将字符串 "pusht_cchi_v7_replay.zarr.zip" 赋值给变量 dataset_path,表示本地数据集文件名。
  • 意义:提示用户数据来源,便于理解数据加载流程。
  • 示例:最终本地文件应命名为 pusht_cchi_v7_replay.zarr.zip
if not os.path.isfile(dataset_path):
  • 作用:检查当前工作目录下是否存在名为 pusht_cchi_v7_replay.zarr.zip 的文件。
  • 示例:如果文件不存在,则条件为 True;例如,若当前目录没有该文件,则进入 if 语句内部。
  • 意义:避免重复下载数据,只有当文件不存在时才进行下载。
    id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
  • 作用:将字符串 "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t" 赋值给变量 id,表示 Google Drive 中文件的标识符。
  • 意义:通过这个 ID,可以从 Google Drive 下载对应的数据文件。
    gdown.download(id=id, output=dataset_path, quiet=False)
  • 作用:调用 gdown.download() 方法下载文件。
  • 参数说明
    • id=id:使用前面指定的文件 ID。
    • output=dataset_path:下载后保存到文件名 pusht_cchi_v7_replay.zarr.zip
    • quiet=False:不静默下载,显示下载进度。
  • 示例:下载进度会显示在终端,并将数据保存到本地文件 pusht_cchi_v7_replay.zarr.zip

十七.2 设置参数部分

# parameters
pred_horizon = 16
  • 作用:将预测序列长度设为 16,赋值给变量 pred_horizon
  • 示例:每个样本中将包含 16 个时间步的数据用于预测。
  • 意义:定义模型预测未来动作或状态的时间步数。
obs_horizon = 2
  • 作用:将观察序列长度设为 2,赋值给变量 obs_horizon
  • 示例:每个样本中用于作为输入观察的数据为前 2 个时间步。
  • 意义:模型仅使用前 2 个时间步的观测作为输入信息。
action_horizon = 8
  • 作用:将动作序列长度设为 8,赋值给变量 action_horizon
  • 示例:每个样本中,后续 8 个时间步对应的动作作为目标输出。
  • 意义:规定动作输出的时间步数。
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16
  • 作用:注释示意图,直观表示各部分的时间步分布:
    • 上面两格代表观察部分(2 步)。
    • 中间 8 格代表实际执行的动作。
    • 最下面 16 格代表完整预测序列。
  • 意义:帮助理解 pred_horizon、obs_horizon 和 action_horizon 之间的关系。

十七.3 创建数据集

# create dataset from file
dataset = PushTImageDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
  • 作用:调用 PushTImageDataset 构造函数,创建数据集对象,并传入之前设置的参数:
    • dataset_path:文件路径 “pusht_cchi_v7_replay.zarr.zip”。
    • pred_horizon:16。
    • obs_horizon:2。
    • action_horizon:8。
  • 意义:数据集对象将读取 zarr 文件,处理图像、状态、动作数据,并生成定长样本序列的索引和归一化数据。
  • 示例:创建后,dataset 内部保存了归一化数据、样本索引等信息,可通过 dataset.stats 获取数据统计信息。
# save training data statistics (min, max) for each dim
stats = dataset.stats
  • 作用:将数据集中的统计信息(min, max 等)保存到变量 stats。
  • 示例:stats 可能为字典,例如
    {'agent_pos': {'min': [0,0], 'max': [512,512]}, 'action': {'min': [...], 'max': [...]}}.
  • 意义:后续模型或分析时可参考这些统计信息。

十七.4 创建 DataLoader

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)
  • 作用:调用 torch.utils.data.DataLoader 创建数据加载器,参数说明:
    • dataset: 使用前面创建的 PushTImageDataset 对象。
    • batch_size=64: 每个批次加载 64 个样本。
    • num_workers=4: 使用 4 个子进程并行加载数据。
    • shuffle=True: 每个 epoch 数据随机打乱。
    • pin_memory=True: 固定内存,加快 CPU 到 GPU 数据传输。
    • persistent_workers=True: 保持 worker 进程在多个 epoch 间不退出,提高效率。
  • 示例:加载器在每个 epoch 中,每次输出一个批次,形状可能为:
    • batch[‘image’]: (64, 2, 3, 96, 96)(注意 getitem 可能返回截取后的 obs_horizon 部分)。
    • batch[‘agent_pos’]: (64, 2, 2) 等(具体形状取决于 getitem 实现)。
  • 意义:构建高效数据读取管道,为训练提供批次数据。

十七.5 可视化批次数据

# visualize data in batch
batch = next(iter(dataloader))
  • 作用:从 dataloader 中获取一个批次数据:
    • iter(dataloader) 创建迭代器,next() 获取第一个批次。
  • 示例:假设批次大小为 64,则 batch 中每个键对应数组第一维大小为 64。
  • 意义:检查 DataLoader 输出是否符合预期。
print("batch['image'].shape:", batch['image'].shape)
  • 作用:打印批次中 ‘image’ 数据的形状。
  • 示例
    • 例如,输出:batch['image'].shape: torch.Size([64, 2, 3, 96, 96])(若 getitem 截取 obs_horizon=2 时间步)。
  • 意义:验证图像数据尺寸是否正确。
print("batch['agent_pos'].shape:", batch['agent_pos'].shape)
  • 作用:打印批次中 ‘agent_pos’ 数据的形状。
  • 示例
    • 例如,输出:batch['agent_pos'].shape: torch.Size([64, 2, 2])(2 个时间步,每个时间步 2 维位置)。
  • 意义:检查 agent 位置数据尺寸是否符合预期。
print("batch['action'].shape", batch['action'].shape)
  • 作用:打印批次中 ‘action’ 数据的形状。
  • 示例
    • 例如,输出:batch['action'].shape torch.Size([64, 16, 3])(假设动作数据原本在样本中长度为 pred_horizon=16,每个时间步 3 维)。
  • 意义:验证动作数据尺寸,确保数据集中 ‘action’ 键的数据正常加载。

总体说明

  • 大函数和大类的意义

    • 数据下载部分:通过检查本地文件是否存在,再利用 gdown 从 Google Drive 下载演示数据,保证数据可用性。
    • 参数设置:通过 pred_horizon、obs_horizon 和 action_horizon 定义采样序列的各个时间步部分,便于构建训练样本。
    • 数据集创建:调用 PushTImageDataset 构造函数,读取 zarr 数据集,处理图像(转换通道顺序)、状态和动作数据,生成归一化数据及样本索引。
    • DataLoader 构造:使用 PyTorch DataLoader 封装数据集,实现批量数据加载和并行加速。
    • 数据可视化:通过取一个批次并打印各部分数据形状,验证数据加载、预处理、批次构建是否正确。
  • 输入与输出

    • 输入:数据集文件路径、采样参数(pred_horizon、obs_horizon、action_horizon)。
    • 输出:一个 DataLoader 对象,每个批次返回一个字典,包含 ‘image’、‘agent_pos’、‘action’ 数据,形状按预设参数整理。
  • 设计原因

    • 通过归一化、固定长度采样、批量加载等预处理,使得训练数据符合模型输入要求,且能高效供训练使用。
    • 参数化设计方便调整时间步长度等设置,满足不同实验需求。
    • DataLoader 的多进程和 pin_memory 设置有助于加速 CPU-GPU 数据传输,提高训练效率。

相关文章:

  • C++动态内存管理完全指南:从基础到现代最佳实践
  • Windows系统本地化部署DeepSeek+Open-WebUi
  • OpenBMC:BmcWeb 处理http请求4 处理路由对象
  • nginx管理nacos集群地址
  • mlir-tblgen 的应用渐进式示例
  • JS location对象
  • 【SQL】子查询详解(附例题)
  • 【C++DFS 马拉车】3327. 判断 DFS 字符串是否是回文串|2454
  • RFID手持机读写器功能模块硬件定制专属方案
  • Python 之 Pandas 常用操作
  • 从零开始学习Python游戏编程14-随机数1
  • 【AI】高效地使用 AI 模型的 Prompt(提示词)
  • 面试题汇总06-场景题线上问题排查难点亮点
  • Linux网络编程——https的协议及其加密解密方式
  • 面试题ing
  • 智谛达科技:以创新为翼,翱翔AI人形机器人蓝海
  • 企业如何解决供应商风控难题?
  • 保安员考试考哪些内容呢?
  • 51.评论日记
  • 【Vue-路由案例】面经基础版
  • 北京网站建设方案/如何写营销软文
  • 网站做的自适应体验差/品牌策划是做什么的
  • asp学校网站系统/网络营销课程设计
  • 铜陵app网站做招聘/优化网站
  • vps如何做网站步骤/国内网络营销公司排名
  • 网站建设价格a去找真甲先生/百度首页纯净版怎么设置