Diffusion Policy Visuomotor Policy Learning via Action Diffusion官方项目解读(二)(6)
运行官方代码库中提供的Colab代码:vision-based environment(二)(6)
- Vision Encoder
- 二十四、函数`get_resnet`
- 总体说明
- 二十五、函数`replace_submodules`
- 总体说明
- 二十五、函数`replace_bn_with_gn`
- 总体说明
- Network Demo
- 二十六
- 二十六.1 构造视觉编码器
- 二十六.2. 替换归一化层
- 二十六.3 设置各特征维度
- 二十六.4 构造噪声预测网络(ConditionalUnet1D)
- 二十六.5 构建整体网络字典
- 二十六.6 演示部分:前向推理(with torch.no_grad() 部分)
- 二十六.7 配置扩散噪声调度器
- 二十六.8 设备转移
- 总体说明
官方项目地址:https://diffusion-policy.cs.columbia.edu/
Colab代码:vision-based environment
Vision Encoder
二十四、函数get_resnet
def get_resnet(name: str, weights=None, **kwargs) -> nn.Module:
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", None
"""
- 作用:定义一个名为
get_resnet
的函数,用于获取预定义名称的 ResNet 模型。 - 参数:
name: str
:字符串,表示所需 ResNet 版本,例如"resnet18"
、"resnet34"
或"resnet50"
。- 示例:若调用时传入
name="resnet50"
,则表示需要使用 ResNet50 模型。
- 示例:若调用时传入
weights
:指定预训练权重,可传入类似"IMAGENET1K_V1"
的标识,或传入None
表示不加载预训练权重。- 示例:调用时传入
weights="IMAGENET1K_V1"
,表示加载在 ImageNet 数据集上的预训练参数。
- 示例:调用时传入
**kwargs
:其他关键字参数,将传递给 torchvision 模型函数。- 示例:可能传入
num_classes=1000
(不过在这里会被忽略,因为后续移除 fc 层)。
- 示例:可能传入
- 返回类型:
-> nn.Module
表示返回的是一个 PyTorch 模型实例。
# Use standard ResNet implementation from torchvision
func = getattr(torchvision.models, name)
- 作用:使用内置函数
getattr
从torchvision.models
模块中获取属性,其名称由变量name
指定。 - 解释:
- 如果
name
为"resnet18"
,则getattr(torchvision.models, "resnet18")
返回torchvision.models.resnet18
这一函数。
- 如果
- 示例:假设
name="resnet50"
,那么此行后,变量func
将指向torchvision.models.resnet50
。 - 意义:实现动态选择模型类型,无需用 if/else 分支。
resnet = func(weights=weights, **kwargs)
- 作用:通过前面动态获取的函数
func
(例如torchvision.models.resnet50
),调用其构造函数生成一个 ResNet 模型实例。 - 参数:
weights=weights
:将传入的 weights 参数传给模型构造函数,决定是否加载预训练权重。**kwargs
:传入额外参数,可能包括其它网络配置,如num_classes
等。
- 示例:
- 如果传入
weights="IMAGENET1K_V1"
, 则生成的模型将加载 ImageNet 预训练权重。 - 假设调用
get_resnet("resnet50", weights="IMAGENET1K_V1")
,则生成的模型对象存储在resnet
中,为 ResNet50 模型。
- 如果传入
- 意义:使用 torchvision 提供的预训练模型构造函数,快速得到标准 ResNet 模型实例。
# remove the final fully connected layer
# for resnet18, the output dim should be 512
- 作用:说明下面代码将删除 ResNet 模型中的最后全连接层(fc 层)。
- 意义:为了将 ResNet 用作特征提取器,而非进行分类,常常需要去除分类头。
- 示例:
- 对于 ResNet18,全连接层的输出维度原本为 512;移除后,模型最后输出的特征维度为 512(作为特征向量)。
- 用途:去除 fc 层使得模型输出最后一层特征图,适合用作下游任务的特征提取器。
resnet.fc = torch.nn.Identity()
- 作用:将 ResNet 模型的 fc 属性(最后的全连接层)替换为
torch.nn.Identity()
模块,即不执行任何计算,直接输出输入本身。 - 示例:
- 若原始
resnet.fc
是一个线性映射层,将其替换后,对于输入特征(例如形状 (B,512))直接返回相同的 (B,512) 张量。
- 若原始
- 意义:使模型在前向传播时绕过全连接层,直接输出最后的特征表示。这在迁移学习或特征提取任务中常见。
return resnet
- 作用:将修改后的 ResNet 模型(已将 fc 层替换为 Identity)返回。
- 输出示例:
- 如果调用
get_resnet("resnet50", weights="IMAGENET1K_V1")
,最终返回的模型为 ResNet50,其最后输出为 (B,2048) 维特征(对于 ResNet50 而言,其 fc 层前的特征维度为 2048;对于 ResNet18, 通常 fc 层前特征为 512)。
- 如果调用
- 意义:完成函数功能,给使用者提供一个用于特征提取而非分类的 ResNet 模型实例。
总体说明
函数 get_resnet 的作用与意义
- 目的:该函数旨在动态加载 torchvision 模型库中指定名称的 ResNet 模型,并移除最后的全连接层,使其成为一个特征提取器。
- 输入:
name
: 字符串,指示所需 ResNet 模型,如 “resnet18”, “resnet34”, “resnet50”。weights
: 可选参数,用于指定预训练权重,例如 “IMAGENET1K_V1”(在 torchvision 中的常用配置),或设置为 None 表示随机初始化。**kwargs
: 其他传递给模型构造函数的参数,如可能的num_classes
等。
- 输出:
- 返回一个 nn.Module 模型实例,该模型是所选 ResNet 去除了最后全连接层的版本,输出为最后一层特征张量。
- 如对于 resnet18,输出特征维度为 512;对于 resnet50,可能输出 2048 维特征(具体取决于 torchvision 实现)。
- 设计原因:
- 在迁移学习、特征提取或下游任务中,我们通常不需要用于分类的全连接层。将 fc 层置为 Identity 可简化后续使用,并保持 ResNet 的深度特征表示。
- 动态使用 getattr 可以根据传入的 name 选项灵活获取不同版本的 ResNet 模型,而不必写多个分支代码。
二十五、函数replace_submodules
该函数的主要目的是在一个给定的 PyTorch 模型(root_module)的子模块层次结构中,按照一个判断条件(predicate)替换满足条件的模块,并用给定的替换函数(func)的输出代替原模块。这样的工具在修改现有模型(例如替换 BatchNorm 为 GroupNorm 或替换所有某类型的层)时非常有用。
def replace_submodules(
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module]) -> nn.Module:
"""
Replace all submodules selected by the predicate with
the output of func.
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
- 定义了名为
replace_submodules
的函数,返回一个nn.Module
(修改后的模型)。 - 参数说明:
root_module
:一个 PyTorch 模块,即我们要在其内部查找并替换子模块的根模块。例如,一个大型神经网络模型。predicate
:一个函数,其输入为一个模块,返回布尔值。用于判断当前模块是否需要替换。- 示例:可以是
lambda m: isinstance(m, nn.BatchNorm2d)
用来判断是否为 BatchNorm2d 层。
- 示例:可以是
func
:一个函数,输入为一个模块,输出为一个模块。用于将满足条件的模块替换成新模块。- 示例:可以是
lambda m: nn.Identity()
用来将所有 BatchNorm2d 层替换为 Identity 层。
- 示例:可以是
- 意义:使得函数可以动态地在模型中查找并替换满足某个条件的所有子模块。
if predicate(root_module):
return func(root_module)
- 作用:如果根模块自己就满足 predicate(即要替换),直接返回
func(root_module)
替换后的模块,不再往下遍历。 - 具体数值例子:
- 如果 predicate 定义为
lambda m: isinstance(m, nn.BatchNorm2d)
,而 root_module 本身是 BatchNorm2d(假设特殊情况),则 predicate(root_module) 为 True,将直接返回func(root_module)
。
- 如果 predicate 定义为
- 意义:避免在整个树形结构中重复遍历,如果根模块需要替换,则整个模块直接被替换。
bn_list = [k.split('.') for k, m
in root_module.named_modules(remove_duplicate=True)
if predicate(m)]
- 作用:使用
root_module.named_modules(remove_duplicate=True)
遍历所有子模块,取出满足 predicate 的模块的键(即名称),并对名称字符串按句点进行分割,存入列表 bn_list。 - 示例:
- 假设模型中存在模块
layer1.0.bn1
和layer1.1.bn1
满足 predicate,则名称分别为 “layer1.0.bn1” 和 “layer1.1.bn1”,经过k.split('.')
得到列表["layer1", "0", "bn1"]
和["layer1", "1", "bn1"]
,最终 bn_list 为:[ ["layer1", "0", "bn1"], ["layer1", "1", "bn1"] ]
- 假设模型中存在模块
- 意义:将待替换模块的完整路径分解为层次结构,方便后续定位具体模块的父模块和属性名。
for *parent, k in bn_list:
- 作用:对 bn_list 中的每个分割后的路径列表进行遍历,使用 Python 的可变参数解包将除最后一个部分之外的部分存入列表 parent,而最后一个部分存入 k。
- 示例:
- 对于 bn_list 中的 [“layer1”, “0”, “bn1”],解包后:parent = [“layer1”, “0”],k = “bn1”。
- 意义:将路径拆分为父模块路径和当前模块名称,便于定位并替换模块。
parent_module = root_module
- 作用:变量 parent_module 初始赋值为 root_module,用于后续查找子模块的父模块。
- 示例:开始时 parent_module 指向整个模型。
if len(parent) > 0:
parent_module = root_module.get_submodule('.'.join(parent))
- 作用:如果 parent 不为空,则使用 root_module.get_submodule() 方法获取由 parent 路径指定的父模块。
- 具体过程:
- 将 parent 列表用 ‘.’ 连接成字符串,例如 [“layer1”, “0”] → “layer1.0”。
- 调用 get_submodule(“layer1.0”) 得到该父模块。
- 示例:对于前面例子,将 parent_module 设为模型中名为 “layer1.0” 的模块。
- 意义:确定要替换的目标模块的父容器,以便能够通过属性赋值或索引方式替换它。
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
- 作用:
- 如果父模块为
nn.Sequential
(子模块以数字索引存储),则将 k 转换为整数,并取出对应的子模块; - 否则,通过
getattr
从父模块中获取属性名为 k 的子模块。
- 如果父模块为
- 示例:
- 假设 parent_module 是一个 nn.Sequential,其包含多个模块,k 在这种情况下可能为 “0” 或 “1”(字符串形式),转换为 int 后就可以用索引获取;例如 parent_module[0]。
- 若 parent_module 不是 Sequential(例如一个自定义模块),则通过 getattr(parent_module, “bn1”) 取得该模块。
- 意义:确保正确获取被替换的具体模块。
tgt_module = func(src_module)
- 作用:调用传入函数 func,对获取的源模块 src_module 进行转换,并将结果保存为 tgt_module。
- 示例:
- 如果 func 定义为
lambda m: nn.Identity()
,那么 tgt_module 将为 nn.Identity();即目标模块替换为“恒等映射”。 - 假设 src_module 为 BatchNorm2d 模块,经过 func 后,tgt_module 为 Identity 层。
- 如果 func 定义为
- 意义:用新的模块替换原有模块,从而实现模块升级、修改或其他自定义调整。
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
- 作用:
- 如果父模块为 nn.Sequential,则直接使用索引赋值(例如 parent_module[0] = tgt_module);
- 否则使用 setattr 在父模块中设置属性名 k 对应的模块为 tgt_module。
- 示例:
- 对于 nn.Sequential 父模块,若 k 为 “0”(转换为 int 为 0),就执行 parent_module[0] = tgt_module;
- 若父模块为其他模块,则执行 setattr(parent_module, “bn1”, tgt_module)。
- 意义:完成替换操作,使得满足 predicate 的子模块被 func 函数生成的新模块所取代。
# verify that all modules are replaced
bn_list = [k.split('.') for k, m
in root_module.named_modules(remove_duplicate=True)
if predicate(m)]
- 作用:验证是否所有满足 predicate 的模块都已经被替换掉。与前面类似,重新遍历整个模块结构,收集所有仍然满足 predicate 的模块路径列表。
- 意义:检查替换后的模型是否仍有模块满足 predicate(若有,说明替换没有彻底完成)。
- 示例:如果一开始需要替换的模块全部成功替换,则此时 bn_list 应为空列表。
assert len(bn_list) == 0
- 作用:通过断言确保 bn_list 长度为 0,即模型中没有残留仍需替换的模块。
- 意义:保证函数执行后符合预期,若断言失败则表示替换未完全成功。
- 示例:
- 如果 bn_list 非空(例如有 2 个模块未被替换),则 assert 会触发错误,提示开发者检查替换逻辑。
return root_module
- 作用:返回经过所有替换操作的 root_module,即整个模型已更新。
- 意义:调用函数的最终输出,是一个经过全局修改后、满足要求的模型。
- 示例:
- 若最初调用
replace_submodules(root_module, predicate, func)
后,返回的模型中所有满足 predicate 的子模块都已替换为 func 的输出模块。
- 若最初调用
总体说明
函数 replace_submodules 的作用
- 目的:在给定模型 root_module 内查找所有满足 predicate 条件的子模块,然后使用 func 函数生成的新模块替换它们。这样可以批量修改模型的特定部分,常见于模型优化、迁移学习或模型结构定制时。
- 输入:
- 一个根模块(例如一个完整的神经网络模型)。
- 一个判断函数 predicate,用于筛选需要替换的模块。
- 一个替换函数 func,接受一个模块并返回新的模块。
- 输出:
- 返回修改后的模型 root_module,其中所有满足 predicate 条件的模块均已被替换。
- 设计原因:
- 采用动态模块替换可以大幅简化代码,不必手动找到每个需要改变的层。
- 使用 named_modules 和子模块路径管理保证了层次结构中每个目标模块能够正确定位和替换。
- 断言验证(assert)用于确保替换工作彻底完成,是一种良好的健壮性检查。
- 此工具适用于开发过程中对模型进行“微调”,例如替换所有 Dropout 或 BatchNorm 层为自定义实现等场景。
二十五、函数replace_bn_with_gn
def replace_bn_with_gn(
root_module: nn.Module,
features_per_group: int = 16) -> nn.Module:
"""
Relace all BatchNorm layers with GroupNorm.
"""
- 作用:定义一个名为
replace_bn_with_gn
的函数,该函数接受两个参数:root_module
:类型为nn.Module
,表示要处理的模型或网络的根模块;features_per_group
:整型参数,默认为 16,表示在替换时每个 GroupNorm 中期望的通道数。
- 返回值:函数返回修改后的
root_module
,类型依然为nn.Module
。 - 意义:该函数目标是遍历
root_module
内部所有子模块,并将所有的 BatchNorm2d 层替换为 GroupNorm 层,以获得更稳定或适应小批量数据的归一化效果。 - 示例:假设我们有一个包含若干
nn.BatchNorm2d
层的模型,可以调用
替换后模型中的 BatchNorm2d 都变成了 GroupNorm。new_model = replace_bn_with_gn(model, features_per_group=16)
replace_submodules(
root_module=root_module,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // features_per_group,
num_channels=x.num_features)
)
- 作用:调用一个名为
replace_submodules
的函数,目的是在root_module
内部递归地查找满足条件的子模块,并用新的模块替换它们。 - 参数详细说明:
root_module=root_module
:传入最初的模块,作为递归搜索的根节点。predicate=lambda x: isinstance(x, nn.BatchNorm2d)
:- 这是一个匿名函数(lambda),对传入的子模块进行检查,只有当子模块是
nn.BatchNorm2d
的实例时返回 True。 - 示例:如果某层为
nn.BatchNorm2d(64)
, 则isinstance(x, nn.BatchNorm2d)
将返回 True;如果是nn.Conv2d
则返回 False。
- 这是一个匿名函数(lambda),对传入的子模块进行检查,只有当子模块是
func=lambda x: nn.GroupNorm( num_groups=x.num_features // features_per_group, num_channels=x.num_features)
:- 这是另一个匿名函数,用于创建新模块来替换满足 predicate 条件的 BatchNorm2d 层。
- 其中,
x.num_features
表示原 BatchNorm2d 层的特征数。例如,如果原层创建时定义为nn.BatchNorm2d(64)
,那么x.num_features
就是 64。 x.num_features // features_per_group
计算出每个 GroupNorm 的组数:- 若 features_per_group 为 16,则 64 // 16 = 4;即将 64 个通道划分为 4 组,每组 16 个通道。
nn.GroupNorm(num_groups=4, num_channels=64)
将创建一个 GroupNorm 层用于归一化 64 个通道的数据。
- 意义:这一行的调用会在整个
root_module
内部寻找所有 BatchNorm2d 层,并将它们替换成对应的 GroupNorm 层。
- 注意:函数
replace_submodules
是一个已实现的工具函数,用于递归遍历模块并进行替换操作,其内部实现不在此处展示,但核心思想是:对于每个子模块,如果 predicate 返回 True,则用 func(x) 的结果替换原模块;否则递归搜索子模块。
- 示例:假设模型中存在一个
nn.BatchNorm2d(128)
层,且 features_per_group=16,此时:x.num_features
为 128,- 计算
128 // 16 = 8
, - 将该层替换为
nn.GroupNorm(num_groups=8, num_channels=128)
。
return root_module
- 作用:函数结束时返回经过替换操作后的
root_module
。 - 意义:用户调用该函数后获得一个所有 BatchNorm 层均被替换为 GroupNorm 层的模型副本。
- 示例:如果原始
root_module
是模型 A,那么返回值就是 A,但其中所有nn.BatchNorm2d
层均已变为 GroupNorm 层。
总体说明
- 大函数/模块意义:
- 目的:
replace_bn_with_gn
的主要目的是在一个给定的神经网络模块中,将所有 2D 批归一化层(BatchNorm2d)替换成分组归一化层(GroupNorm)。 - 输入:
root_module
:需要转换的模型或模块,它可能包含多个 BatchNorm2d 子模块。features_per_group
:指定每组希望包含的特征数,用于计算 GroupNorm 的组数。
- 输出:
- 返回修改后的模块,其内部所有 BatchNorm2d 均已替换为 GroupNorm 层。
- 设计原因:
- 在某些情况下(例如小批量训练或者跨设备部署时),BatchNorm 的表现可能不如 GroupNorm 稳定。使用 GroupNorm 可以降低对 batch size 的依赖并获得更稳定的归一化效果。
- 通过递归替换的方法,用户无需手动修改每个层,而可以自动遍历整个网络实现替换,提升代码的灵活性与可维护性。
- 目的:
Network Demo
二十六
二十六.1 构造视觉编码器
# construct ResNet18 encoder
# if you have multiple camera views, use seperate encoder weights for each view.
vision_encoder = get_resnet('resnet18')
- 注释1:
- “construct ResNet18 encoder”
说明接下来要构造一个基于 ResNet18 的视觉编码器(encoder)。
- “construct ResNet18 encoder”
- 注释2:
- “if you have multiple camera views, use seperate encoder weights for each view.”
提示当有多个摄像头视角时,应为每个视角使用独立的编码器权重。
- “if you have multiple camera views, use seperate encoder weights for each view.”
- 代码:
vision_encoder = get_resnet('resnet18')
调用辅助函数get_resnet
(假设内部会根据字符串参数选择网络结构),这里传入'resnet18'
,表示构造 ResNet18 模型。- 示例:最终
vision_encoder
将是一个 ResNet18 模型实例,其最后一层输出特征维度通常为 512(参见后文说明)。
- 意义:
通过预训练的 ResNet18 提取图像特征,为后续任务提供视觉信息。
二十六.2. 替换归一化层
# IMPORTANT!
# replace all BatchNorm with GroupNorm to work with EMA
# performance will tank if you forget to do this!
vision_encoder = replace_bn_with_gn(vision_encoder)
- 注释:
- 提醒非常重要:为配合 EMA(Exponential Moving Average)训练策略,需要用 GroupNorm 替换所有 BatchNorm 层,否则性能会大幅下降。
- 代码:
vision_encoder = replace_bn_with_gn(vision_encoder)
调用函数replace_bn_with_gn
(前面实现过,功能是递归遍历模块,将nn.BatchNorm2d
替换为相应的nn.GroupNorm
),传入当前的vision_encoder
。- 示例:如果 ResNet18 中原来有
nn.BatchNorm2d(64)
层,经过替换后,将变为nn.GroupNorm(num_groups=64//16, num_channels=64)
(假设默认 features_per_group=16,即 4 组,每组 16 个通道)。
- 意义:
使用 GroupNorm 对于小批量训练或 EMA 策略更稳定,提升模型训练效果和推理稳定性。
二十六.3 设置各特征维度
# ResNet18 has output dim of 512
vision_feature_dim = 512
# agent_pos is 2 dimensional
lowdim_obs_dim = 2
# observation feature has 514 dims in total per step
obs_dim = vision_feature_dim + lowdim_obs_dim
action_dim = 2
- 第一行注释:
- 提示 ResNet18 的输出特征维度为 512。
- 代码:
vision_feature_dim = 512
将视觉特征维度设置为 512。
- 第二行注释:
- “agent_pos is 2 dimensional”
表示智能体位置信息是 2 维(例如 x,y 坐标)。
- “agent_pos is 2 dimensional”
- 代码:
lowdim_obs_dim = 2
将智能体低维观测(位置)维度设置为 2。
- 第三行注释:
- 说明每个时间步的总观测特征由视觉特征和 agent 位置拼接而成,总维度为 514。
- 代码:
obs_dim = vision_feature_dim + lowdim_obs_dim
计算总观测维度,即 512 + 2 = 514。
- 代码:
action_dim = 2
定义动作维度为 2;通常动作表示为二维向量,如移动方向或目标位置。
- 意义:
明确整个系统在后续网络中各数据流的维度,便于构造条件输入等模块。
二十六.4 构造噪声预测网络(ConditionalUnet1D)
# create network object
noise_pred_net = ConditionalUnet1D(
input_dim=action_dim,
global_cond_dim=obs_dim*obs_horizon
)
- 注释:
- “create network object” 提示接下来创建一个噪声预测网络。
- 代码:
noise_pred_net = ConditionalUnet1D(...)
调用构造函数创建一个 ConditionalUnet1D 模块。
- 参数说明:
input_dim=action_dim
表示输入维度为动作维度,即 2(在本例中)。global_cond_dim=obs_dim*obs_horizon
表示全局条件维度为每个时间步观测特征维度(514)乘以观测时间步数(obs_horizon),例如如果 obs_horizon 设为 2,则 global_cond_dim = 514 * 2 = 1028。
- 意义:
噪声预测网络用于在扩散模型中预测噪声(或残差),条件网络通过全局条件和扩散步信息调制其生成结果。 - 示例:假设 action_dim=2,global_cond_dim=1028,则 ConditionalUnet1D 将按照预先定义的 UNet 结构处理输入噪声序列。
二十六.5 构建整体网络字典
# the final arch has 2 parts
nets = nn.ModuleDict({
'vision_encoder': vision_encoder,
'noise_pred_net': noise_pred_net
})
- 注释:
- “the final arch has 2 parts” 说明整个系统包含视觉编码器和噪声预测网络两大部分。
- 代码:
nets = nn.ModuleDict({...})
创建一个nn.ModuleDict
,将 vision_encoder 和 noise_pred_net 分别作为两个键的值存入字典中。
- 意义:
使用 ModuleDict 可将多个子模块组织在一起,方便统一管理、调用和设备转移。 - 示例:
- 通过
nets['vision_encoder']
可调用视觉编码器,对输入图像进行特征提取; - 通过
nets['noise_pred_net']
调用噪声预测 UNet 模块。
- 通过
二十六.6 演示部分:前向推理(with torch.no_grad() 部分)
# demo
with torch.no_grad():
- 作用:进入
with torch.no_grad()
上下文,关闭自动求导,提高推理效率并节省内存。 - 意义:演示阶段无需梯度计算,只需前向推理。
# example inputs
image = torch.zeros((1, obs_horizon, 3, 96, 96))
- 作用:构造一个示例图像输入,形状为 (1, obs_horizon, 3, 96, 96)。
- 示例:
- 假设 obs_horizon 设为 2,则 image 的形状为 (1, 2, 3, 96, 96);所有像素均为 0。
- 意义:用于模拟网络输入,实际场景中会用真实图像数据。
agent_pos = torch.zeros((1, obs_horizon, 2))
- 作用:构造一个示例智能体位置信息,形状为 (1, obs_horizon, 2)。
- 示例:
- 若 obs_horizon=2,则 agent_pos 的形状为 (1,2,2),所有值均为 0。
- 意义:提供与图像对应的低维观测,便于构造全局条件。
# vision encoder
image_features = nets['vision_encoder'](
image.flatten(end_dim=1))
- 作用:
- 将
image
张量从形状 (B, obs_horizon, 3, 96, 96) 展开(flatten)前两个维度,使其形状变为 (B * obs_horizon, 3, 96, 96)。 - 然后传入 vision_encoder 中提取图像特征。
- 将
- 示例:
- 对于 image 的形状 (1,2,3,96,96),flatten 后形状变为 (2,3,96,96)。
- 经过 ResNet18 编码器后,假定输出形状为 (2,512),因为 ResNet18 的特征维度为 512。
- 意义:将多步(obs_horizon)图像分别通过视觉编码器处理,获得每步的视觉特征。
# (2,512)
image_features = image_features.reshape(*image.shape[:2], -1)
- 作用:将提取出的图像特征重新 reshape 回 (B, obs_horizon, feature_dim)。
- 示例:
- 原 image_features 形状 (2,512),重塑为 (1,2,512)(因为 image.shape[:2] = (1,2))。
- 意义:方便后续与 agent_pos 拼接,构成全局条件。
# (1,2,512)
obs = torch.cat([image_features, agent_pos], dim=-1)
# (1,2,514)
- 作用:在最后一维上拼接图像特征与 agent 位置。
- 示例:
- image_features 的形状 (1,2,512),agent_pos 的形状 (1,2,2),拼接后 obs 的形状为 (1,2,514)。
- 意义:构建每个时间步的整体观测(514 维),与之前计算 obs_dim = 514 对应。确保后续输入的维度正确。
noised_action = torch.randn((1, pred_horizon, action_dim))
- 作用:生成一个噪声化的动作样本,形状为 (1, pred_horizon, action_dim)。
- 示例:
- 如果 pred_horizon=16,action_dim=2,则 noised_action 的形状为 (1,16,2);
- 其中每个元素随机生成自标准正态分布。
- 意义:作为扩散模型中预测噪声的输入样本。
diffusion_iter = torch.zeros((1,))
- 作用:创建一个表示扩散迭代步(timestep)的张量,形状为 (1,)。
- 示例:
- 输出 tensor([0]),表示当前扩散步骤为 0;在实际任务中,该值可能表示当前迭代的编号。
- 意义:用于为噪声预测网络提供扩散过程中的时间步信息。
# the noise prediction network
# takes noisy action, diffusion iteration and observation as input
# predicts the noise added to action
noise = nets['noise_pred_net'](
sample=noised_action,
timestep=diffusion_iter,
global_cond=obs.flatten(start_dim=1))
- 作用:
- 将三个输入传入
noise_pred_net
:sample
:噪声化动作,形状 (1, pred_horizon, action_dim)。timestep
:扩散迭代步,形状 (1,)。global_cond
:全局条件信息,将 obs 展开(flatten 从第 1 维开始)后,形状变为 (B, obs_dim * obs_horizon)。
- 示例:
- 之前 obs 的形状为 (1,2,514),flatten(start_dim=1) 得到 (1, 1028);
- noised_action 的形状为 (1,16,2);
- diffusion_iter 为 tensor([0]);
- noise_pred_net 将根据这些输入预测噪声,输出形状通常与 sample 相同,即 (1,16,2)。
- 将三个输入传入
- 意义:噪声预测网络基于扩散模型的思想,输入噪声化的动作及条件信息后,预测出加入到动作中的噪声量。
# illustration of removing noise
# the actual noise removal is performed by NoiseScheduler
# and is dependent on the diffusion noise schedule
denoised_action = noised_action - noise
- 作用:
- 计算“去噪”后的动作,简单地将预测噪声从 noised_action 中减去。
- 示例:
- 如果 noised_action 中某个元素为 0.5,noise 预测为 0.1,则 denoised_action 为 0.5 - 0.1 = 0.4。
- 意义:
- 这仅作为示例说明,实际去噪过程依赖于更复杂的 NoiseScheduler 机制和扩散噪声计划。
二十六.7 配置扩散噪声调度器
# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
- 作用:设置扩散迭代次数为 100,将该值保存到 num_diffusion_iters。
- 示例:num_diffusion_iters = 100。
- 意义:在扩散模型中,通常需要进行多次迭代来逐步去噪;这里设置为 100 次迭代。
noise_scheduler = DDPMScheduler(
num_train_timesteps=num_diffusion_iters,
# the choise of beta schedule has big impact on performance
# we found squared cosine works the best
beta_schedule='squaredcos_cap_v2',
# clip output to [-1,1] to improve stability
clip_sample=True,
# our network predicts noise (instead of denoised action)
prediction_type='epsilon'
)
- 作用:创建一个 DDPMScheduler 对象,用于管理扩散过程中的噪声调度。
- 各参数说明:
num_train_timesteps=num_diffusion_iters
:设置训练时总共扩散迭代步数为 100。beta_schedule='squaredcos_cap_v2'
:选择 beta 调度策略,这里为 ‘squaredcos_cap_v2’,这种调度在文献中表现较好。clip_sample=True
:启用 clip 操作,将生成样本限制在 [-1,1] 范围内,以提高数值稳定性。prediction_type='epsilon'
:指定模型预测的是噪声 epsilon,而非直接预测去噪后的动作。
- 示例:创建后,noise_scheduler 管理整个扩散过程中的噪声添加与去除策略。
- 意义:扩散模型的实际训练和推理依赖于噪声调度器,合理的 beta 调度与剪切策略能显著影响最终生成质量和训练稳定性。
二十六.8 设备转移
# device transfer
device = torch.device('cuda')
- 作用:设置 device 为 ‘cuda’,表示将模型转移到 GPU 上进行加速训练或推理。
- 示例:device 可能为 torch.device(‘cuda’)。
- 意义:使用 GPU 可显著加速大模型的前向和反向传播计算。
_ = nets.to(device)
- 作用:将整个 nets 模块字典(包含 vision_encoder 与 noise_pred_net)转移到指定的设备上(GPU)。
- 示例:调用
nets.to(device)
后,所有网络参数都被复制到 GPU 内存中。 - 意义:确保所有子模块在同一设备上,避免数据传输错误,同时提高执行效率。
总体说明
大模块组成与意义:
- 视觉编码器构建与替换
- 利用 ResNet18 构建视觉编码器,并用 GroupNorm 替换 BatchNorm,以保证在使用 EMA 训练时效果稳定。
- 输出特征维度为 512。
- 特征维度设置与条件信息构造
- 通过将视觉特征(512 维)与 agent 低维位置信息(2 维)拼接,构成 514 维的观测,每个时间步形成整体观测。
- 同时动作维度设置为 2。
- 噪声预测网络(ConditionalUnet1D)构造
- 根据动作维度和条件维度(观测乘以观测时间步数)构造 UNet 网络,用于扩散模型中预测噪声。
- 条件信息(global_cond)由视觉和位置信息组合形成。
- 整体网络字典
- 使用 nn.ModuleDict 将视觉编码器和噪声预测网络组织在一起,便于统一管理和设备转移。
- Demo 推理流程
- 构造示例输入,包括图像和 agent 位置(均为零张量)。
- 视觉编码器提取图像特征后,重塑并拼接出全局观测 obs(形状 (1, obs_horizon, 514))。
- 随机生成噪声化动作作为扩散模型输入,并设置扩散步(timestep)。
- 通过噪声预测网络输出噪声,并简单地通过相减来模拟去噪效果(实际去噪过程由 NoiseScheduler 控制)。
- 扩散噪声调度器(DDPMScheduler)构造
- 设置扩散迭代步数、beta 调度策略等参数,以确保噪声添加与去噪过程按预期运行。
- 设备转移
- 将所有网络子模块转移到 GPU 上运行,加速计算。
整体输入与输出:
-
输入:
- 图像:形状 (B, obs_horizon, 3, 96, 96)(例如 (1,2,3,96,96))。
- Agent 位置信息:形状 (B, obs_horizon, 2)。
- 噪声动作:形状 (B, pred_horizon, 2)。
- Diffusion timestep:标量或形状 (B,),例如 0 或 100。
-
输出:
- 噪声预测网络输出噪声,与噪声动作形状相同 (B, pred_horizon, action_dim),可用于扩散模型中反向去噪生成目标动作。
设计原因:
- 采用条件 UNet 模型结构,将下采样与上采样结合起来,多尺度融合条件信息,适用于扩散模型中复杂噪声预测任务。
- 视觉编码器提取高维图像特征,结合低维 agent 位置信息构成全局条件,帮助模型在生成过程中获得更丰富的上下文信息。
- 使用 GroupNorm 替换 BatchNorm,适用于 EMA 及小批量训练,能获得更稳定的训练过程。
- 引入扩散步编码(SinusoidalPosEmb)与全局条件拼接,是扩散模型中常见的条件注入方式,有助于动态调节生成过程。