Diffusion Policy Visuomotor Policy Learning via Action Diffusion官方项目解读(二)(5)
运行官方代码库中提供的Colab代码:vision-based environment(二)(5)
- Network
- 十八、类`SinusoidalPosEmb`,继承自`nn.Module`
- 十八.1 `def __init__()`
- 十八.2 `def forward()`
- 总体说明
- 十九、类`Downsample1d`,继承自`nn.Module`
- 十九.1 `def __init__()`
- 十九.2 `def forward()`
- 总体说明
- 二十、类`Upsample1d`,继承自`nn.Module`
- 二十.1 `def __init__()`
- 二十.2 `def forward()`
- 总体说明
- 二十一、类`Conv1dBlock`,继承自`nn.Module`
- 二十一.1 `def __init__()`
- 二十一.2 `def forward()`
- 总体说明
- 二十二、类`ConditionalResidualBlock1D`,继承自`nn.Module`
- 二十二.1 `def __init__()`
- 二十二.2 `def forward()`
- 总体说明
- 二十三、类`ConditionalUnet1D `,继承自`nn.Module`
- 二十三.1 `def __init__()`
- 二十三.2 `def forward()`
- 总体说明
官方项目地址:https://diffusion-policy.cs.columbia.edu/
Colab代码:vision-based environment
Network
十八、类SinusoidalPosEmb
,继承自nn.Module
class SinusoidalPosEmb(nn.Module):
- 作用:定义一个名为
SinusoidalPosEmb
的类,继承自PyTorch
的nn.Module
。 - 意义:该类实现了正弦位置嵌入(sinusoidal positional embedding),常用于为输入数据(如时间步或序列位置)添加位置信息,帮助模型捕捉序列中各个位置的相对关系。
- 示例:在 Transformer 中通常用于添加位置信息,使得模型不丢失顺序信息。
十八.1 def __init__()
def __init__(self, dim):
- 作用:定义构造函数,接收一个参数 dim,用于指定嵌入的总维度。
- 意义:dim 决定了输出位置编码的维度。通常 dim 是偶数,因为后面会将 sin 与 cos 拼接。
- 示例:如果传入 dim=64,则位置编码的维度为 64,其中前 32 维为 sin 部分,后 32 维为 cos 部分。
super().__init__()
- 作用:调用父类(
nn.Module
)的构造函数,完成 nn.Module 的初始化工作。 - 意义:确保该模块可以正常使用诸如
.to(device)
、.parameters()
等 PyTorch 内置方法。 - 示例:这一步通常是所有 nn.Module 子类必须调用的初始化步骤。
self.dim = dim
- 作用:将传入的参数 dim 保存到实例变量 self.dim 中。
- 意义:后续在 forward() 方法中需要使用 self.dim 来计算嵌入维度。
- 示例:如果 dim=64,则 self.dim 的值为 64。
十八.2 def forward()
def forward(self, x):
- 作用:定义前向传播方法 forward,输入参数 x。
- 意义:x 一般为一个包含位置信息的张量(例如,时间步或标量),函数会根据 x 生成对应的正弦位置编码。
- 示例:假设 x 的形状为 (batch_size, ),例如 x=torch.tensor([1.0, 2.0, 3.0]),batch_size 为 3。
device = x.device
- 作用:获取输入张量 x 所在的设备(CPU 或 GPU),并保存到变量 device 中。
- 意义:后续生成的张量需要放在与 x 相同的设备上,以避免设备不匹配错误。
- 示例:如果 x 位于 GPU 上,则 device 为 torch.device(‘cuda:0’);如果在 CPU,则 device 为 torch.device(‘cpu’)。
half_dim = self.dim // 2
- 作用:计算嵌入维度的一半,并赋值给 half_dim。
- 意义:后续生成 sin 与 cos 两部分,维度各为 half_dim,总维度为 half_dim * 2(即 self.dim)。
- 示例:如果 self.dim 为 64,则 half_dim = 64 // 2 = 32。
emb = math.log(10000) / (half_dim - 1)
具身智能零碎知识点(一):深入解析Transformer位置编码
- 作用:计算一个标量 emb,表示频率指数的步长。
- 意义:这里的
math.log(10000) / (half_dim - 1)
来自于 Transformer 论文中提出的公式,用于生成不同频率的正弦和余弦函数。
公式中 log(10000) 保证不同维度之间频率呈指数分布。 - 示例:
- 如果 half_dim=32,则 emb = math.log(10000) / 31 ≈ 9.21034 / 31 ≈ 0.2971。
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
- 作用:
- 首先通过
torch.arange(half_dim, device=device)
生成一个从 0 到 half_dim-1 的整数序列。 - 然后将该序列乘以 -emb,再取指数得到一个包含 half_dim 个元素的张量 emb。
- 首先通过
- 意义:这一步生成不同尺度(频率)的因子,用于对输入 x 进行缩放,使得每个维度的正弦/余弦周期不同。
- 示例:
-
假设 half_dim=4,为便于示例:torch.arange(4) 生成 [0,1,2,3]。
-
假设 emb 步长计算为 0.5(为了简单起见),则 [0,1,2,3] * -0.5 = [0, -0.5, -1.0, -1.5]。
-
torch.exp([0, -0.5, -1.0, -1.5]) ≈ [1.0, 0.6065, 0.3679, 0.2231].
令: scale = log(10000) / (half_dim-1) emb = exp( -i * scale ) 则: emb = exp( -i * log(10000)/(half_dim-1) ) = [exp(log(10000))]^(-i/(half_dim-1)) = 10000^(-i/(half_dim-1)) 所以和transformer中的公式是等价的
-
emb = x[:, None] * emb[None, :]
- 作用:
x[:, None]
将 x 增加一个新的维度,变为形状 (batch_size, 1)。emb[None, :]
将 emb 增加一个维度,变为形状 (1, half_dim)。- 乘法操作利用广播机制,将每个 x 的数值与 emb 的每个元素相乘,得到形状为 (batch_size, half_dim) 的张量 emb。
- 意义:这一步将输入
x
扩展到不同频率尺度上,生成位置编码的原始值。 - 示例:
- 假设 x = tensor([1.0, 2.0]),形状为 (2,),half_dim=4,且前面计算得到 emb ≈ [1.0, 0.6065, 0.3679, 0.2231]。
- 则 x[:, None] 为 [[1.0], [2.0]];乘法后得到:
- 第一个样本: [1.0*1.0, 1.0*0.6065, 1.0*0.3679, 1.0*0.2231] = [1.0, 0.6065, 0.3679, 0.2231].
- 第二个样本: [2.0, 1.2130, 0.7358, 0.4462].
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
- 作用:
- 计算 emb 的正弦值和余弦值,分别得到两个张量,形状均为 (batch_size, half_dim)。
- 使用 torch.cat 将两个张量在最后一维(dim=-1)上拼接,得到形状为 (batch_size, half_dim*2),也就是 (batch_size, self.dim)。
- 意义:生成最终的正弦位置编码,其中每个输入 x 被映射为一个包含正弦和余弦分量的向量。
- 示例:
- 继续前面例子,对于第一个样本,其 emb 原始值为 [1.0, 0.6065, 0.3679, 0.2231]。
- sin 部分 ≈ [0.8415, 0.5698, 0.3590, 0.2210],cos 部分 ≈ [0.5403, 0.8200, 0.9320, 0.9750](数值仅为示例)。
- 拼接后得到向量: [0.8415, 0.5698, 0.3590, 0.2210, 0.5403, 0.8200, 0.9320, 0.9750],总维度为 8(当 self.dim=8,half_dim=4)。
return emb
- 作用:返回生成的正弦位置编码 emb。
- 输出:张量形状为 (batch_size, self.dim),数值在 [-1,1](正弦和余弦函数的输出范围),用于为输入的每个位置提供位置信息。
总体说明
- 大函数/大类意义:
- SinusoidalPosEmb 类:
- 目的:生成正弦位置嵌入,用于将标量位置信息 x 映射为高维嵌入向量,便于 Transformer 和其他序列模型捕捉序列中各位置之间的关系。
- 输入:在 forward 方法中输入 x,通常是一个形状为 (batch_size,) 的张量,代表位置或者时间步标量。
- 输出:输出一个形状为 (batch_size, dim) 的张量,其中每个位置被编码为一组正弦和余弦值,满足 Transformer 论文中提出的公式。
- 设计原因:
- 位置编码能为模型提供序列中位置信息,弥补模型(如 Transformer)自身对顺序敏感性不足的缺陷。
- 通过固定的正弦和余弦函数,生成的编码具有平滑性和周期性,同时允许模型直接学习如何结合位置与输入信息。
- SinusoidalPosEmb 类:
十九、类Downsample1d
,继承自nn.Module
class Downsample1d(nn.Module):
- 作用:创建一个名为 Downsample1d 的新类,并继承自 PyTorch 的 nn.Module。
- 意义:使该模块能够利用 nn.Module 的所有特性,如参数管理、设备切换以及自动梯度计算。
- 示例:在构造神经网络时,可以将 Downsample1d 嵌入网络结构,调用其 forward() 方法来处理一维数据。
十九.1 def __init__()
def __init__(self, dim):
- 作用:定义初始化方法,接收一个参数 dim。
- 意义:dim 决定了输入和输出通道数(channels),保证下采样层保持通道数量不变。
- 示例:如果构造时传入 dim=64,则该模块处理 64 通道的一维信号。
super().__init__()
- 作用:调用 nn.Module 的构造函数来初始化基础模块属性。
- 意义:确保所有继承自 nn.Module 的必要初始化步骤都已完成,如注册子模块、参数等。
- 示例:这是标准的写法,确保 Downsample1d 能正确注册其内部参数,如 conv 中的卷积核参数。
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
- 作用:创建一个一维卷积层,将其赋值给实例变量
self.conv
。 - 输入通道数:dim
- 示例:如果 dim=64,则输入通道为 64。
- 输出通道数:dim
- 示例:输出通道也为 64,保证特征维数不变。
- 卷积核大小 (kernel_size):3
- 意义:每次卷积计算考虑 3 个连续的采样点。
- 示例:若输入序列为 [x1, x2, x3, x4, …],在某次卷积时计算窗口可能为 [x1, x2, x3]。
- 步幅 (stride):2
- 意义:卷积核每次移动 2 个步长,从而实现下采样,降低序列长度约为原始的一半。
- 示例:如果原始长度为 100,使用 stride=2 大致会得到长度约为 50(还会受 padding 和 kernel_size 的影响)。
- padding:1
- 意义:在输入序列两侧各补 1 个值(通常为 0),以保持边缘信息并对齐输出尺寸。
- 示例:对于输入序列 [x1, x2, x3, …],经过 padding 处理变为 [0, x1, x2, x3, …, xn, 0]。
- 整体意义:此卷积层对一维数据进行下采样,同时保持通道数不变,常用于信号抽样、特征压缩。
十九.2 def forward()
def forward(self, x):
- 作用:实现模块在输入 x 经过模块后的前向计算。
- 输入:x 应为形状为 (batch_size, dim, L) 的张量,其中 L 为输入一维序列长度。
- 示例:若 batch_size=32,dim=64,L=100,则 x.shape 为 (32, 64, 100)。
return self.conv(x)
- 作用:将输入 x 传入之前定义的 self.conv 卷积层,并返回其输出。
- 细节说明:
- 对输入一维数据执行下采样卷积,计算过程由 nn.Conv1d 内部自动实现。
总体说明
- Downsample1d 类整体设计:
- 目的:实现一个用于一维信号下采样的模块,通过卷积操作减少序列长度,同时保持特征通道数不变。
- 输入:形状为 (batch_size, dim, L) 的张量,其中 dim 是特征数,L 是原始长度。
- 输出:形状为 (batch_size, dim, L_out) 的张量,L_out 根据 stride、kernel_size 和 padding 计算得到,通常约为原序列长度的一半。
- 应用场景:在构建神经网络模型(如用于时间序列、语音信号或具身智能中的传感器数据)时,下采样能够减少计算量、提取更高层次的特征,同时压缩数据表示。
二十、类Upsample1d
,继承自nn.Module
class Upsample1d(nn.Module):
- 作用:声明一个名为
Upsample1d
的类,它继承自 PyTorch 的nn.Module
。 - 意义:使该模块具备 PyTorch 模块的所有特性,如参数注册、自动求导、设备管理等。
- 示例:在构造神经网络时,可以直接将 Upsample1d 加入网络结构,并利用
model.to(device)
将其放入 GPU。
二十.1 def __init__()
def __init__(self, dim):
- 作用:定义初始化方法,接收一个参数
dim
。 - 意义:
dim
指定了输入和输出信号的通道数;例如,如果输入信号有 64 个通道,则 dim=64。 - 示例:假设我们构造时传入
dim=64
,那么该模块将处理 64 通道的一维数据。
super().__init__()
- 作用:调用
nn.Module
的初始化方法,完成必要初始化。 - 意义:确保所有 nn.Module 内部的初始化(如注册子模块等)能够正常进行。
- 示例:这是一行标准代码,用于保证在后续调用
Upsample1d
的参数方法如parameters()
时,内部参数能被正确识别。
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
注意这里是转置卷积!!!
- 作用:创建一个一维转置卷积(反卷积)层,并将其赋值给实例变量
self.conv
。 - 输入通道数:dim
- 例如,如果
dim=64
,输入通道数即为 64。
- 例如,如果
- 输出通道数:dim
- 同样为 64,确保通道数不变。
- 卷积核大小 (kernel_size):4
- 意味着滤波器覆盖 4 个时间步的数据。
- 步幅 (stride):2
- 每次卷积核移动 2 个时间步,这将使得输出序列长度大约为输入序列长度的 2 倍,从而实现上采样(放大)。
- padding:1
- 在进行转置卷积计算时,将边缘填充1个单位,有助于对齐输出。
- 数值示例:
- 假设输入张量的形状为 (batch_size=32, dim=64, L=50);
- 根据转置卷积的公式,其输出长度为
L o u t = ( L − 1 ) × stride − 2 × padding + kernel_size L_{out} = (L - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel\_size} Lout=(L−1)×stride−2×padding+kernel_size
代入数值:
L o u t = ( 50 − 1 ) × 2 − 2 × 1 + 4 = 98 − 2 + 4 = 100 L_{out} = (50 - 1) \times 2 - 2 \times 1 + 4 = 98 - 2 + 4 = 100 Lout=(50−1)×2−2×1+4=98−2+4=100 - 输出张量形状为 (32, 64, 100)。
- 意义总结:该转置卷积层用于一维数据的上采样,将输入序列长度扩大一倍,同时保持通道数不变,常用于生成模型、信号重构或特征尺度恢复等场景。
二十.2 def forward()
def forward(self, x):
- 作用:实现模块的前向计算,输入 x。
- 输入说明:
- x 应为形状为 (batch_size, dim, L) 的一维数据张量。
- 示例:假设 x.shape = (32, 64, 50)。
- 意义:forward 方法是 nn.Module 的标准接口,定义了模块如何处理输入数据。
return self.conv(x)
- 作用:将输入 x 传给之前定义的转置卷积层 self.conv,并返回其输出。
- 数值示例:
- 若 x 的形状为 (32, 64, 50);经过 self.conv 处理后,输出形状为 (32, 64, 100)(如前面所计算)。
- 意义:完成上采样操作,将较短的序列扩展成更长的序列,方便后续特征融合或恢复原始时间分辨率。
总体说明
- Upsample1d 类整体作用:
- 目的:实现对一维数据的上采样,通过使用转置卷积(nn.ConvTranspose1d)来增加序列长度。
- 输入:形状为 (batch_size, dim, L) 的数据张量,其中 dim 由构造函数参数给定,L 为原始序列长度。
- 输出:形状为 (batch_size, dim, L_out) 的数据张量,其长度 L_out 根据卷积参数计算(例如,上采样使长度从 L=50 变为 L_out=100)。
- 应用场景:在神经网络中,当需要将低分辨率的特征图恢复至较高分辨率(如图像生成、语音信号重构或具身智能传感器数据恢复)时,常使用上采样层。
- 具体参数选择:
- 核大小 4、步幅 2、padding 1 的组合能够近似使输出长度为输入长度的 2 倍,同时保持连续性和平滑性。
二十一、类Conv1dBlock
,继承自nn.Module
class Conv1dBlock(nn.Module):
- 作用:声明一个名为
Conv1dBlock
的类,并继承自 PyTorch 的nn.Module
。 - 意义:使该模块拥有 PyTorch 模块的所有特性,如参数自动注册、设备迁移、自动求导等。
- 示例:在构造神经网络时,可以将
Conv1dBlock
作为基础构建块嵌入到网络中。
'''
Conv1d --> GroupNorm --> Mish
'''
- 作用:说明该模块的主要组成顺序:先经过一维卷积,再进行 GroupNorm 归一化,最后使用 Mish 激活函数。
- 意义:帮助阅读代码的人快速了解模块设计的结构与功能。
- 示例:例如输入数据先经过卷积提取特征,再用 GroupNorm 归一化各通道特征,接着通过 Mish 激活提供非线性变换。
二十一.1 def __init__()
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
- 作用:定义初始化方法,接收 4 个参数:
inp_channels
: 输入通道数。例如,对于 1D 信号,如果有 32 个特征,则 inp_channels=32。out_channels
: 输出通道数,例如可能设为 64。kernel_size
: 卷积核的大小,决定滤波器覆盖连续数据的窗口长度。例如 kernel_size=3 表示每次处理连续 3 个时间步。n_groups
: 分组归一化中分组数,默认设置为 8。
- 意义:这些参数控制模块的主要特征变换,保证输入和输出特征维度以及卷积窗口大小可调节。
- 示例:可以调用
Conv1dBlock(32, 64, 3)
来构造一个从 32 通道映射到 64 通道、卷积核大小为 3 的模块。
super().__init__()
- 作用:调用
nn.Module
的构造函数,完成所有必要的初始化工作。 - 意义:这是所有自定义模块必须执行的步骤,确保内部参数和子模块正确注册。
- 示例:标准写法,用于保证后续
.parameters()
、.to(device)
能正确工作。
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
-
作用:利用
nn.Sequential
按顺序组合多个层,并赋值给实例变量self.block
。 -
详细说明:
- 第 1 个子模块:nn.Conv1d
- 构造一维卷积层,将输入通道数设为
inp_channels
,输出通道数为out_channels
。 - kernel_size:指定卷积核大小。
- padding:设置为
kernel_size // 2
,保证“same”卷积效果(输入与输出长度大致相同)。 - 示例:如果
inp_channels=32
、out_channels=64
、kernel_size=3
,则 padding = 3 // 2 = 1。
对于输入张量形状 (batch_size, 32, L),经过卷积后输出形状为 (batch_size, 64, L)(假设 stride 默认为 1)。
- 构造一维卷积层,将输入通道数设为
- 第 2 个子模块:nn.GroupNorm
- 对输出特征进行归一化。
- 参数:第一个参数是分组数
n_groups
(例如 8),第二个参数是要归一化的通道数,即out_channels
(例如 64)。 - 内部机制:将 64 个通道划分成 8 组,每组 8 个通道,对每组计算均值和方差进行归一化。
- 意义:GroupNorm 有助于在小批量训练中保持训练稳定性,比 BatchNorm 对批次大小依赖更小。
- 第 3 个子模块:nn.Mish
- 应用 Mish 激活函数,公式为:
Mish ( x ) = x × tanh ( ln ( 1 + e x ) ) \text{Mish}(x) = x \times \tanh(\ln(1 + e^x)) Mish(x)=x×tanh(ln(1+ex)) - 意义:Mish 是一种平滑、非单调的激活函数,能带来更好的梯度传递和性能,常见于一些最新的网络架构中。
- 应用 Mish 激活函数,公式为:
- 第 1 个子模块:nn.Conv1d
-
整体意义:该组合模块依次完成了特征提取(通过卷积)、归一化(GroupNorm 降低内部协变量偏移)和非线性激活(Mish 提供丰富表达能力)。
-
示例:给定输入张量 x 形状为 (batch_size=16, 32, 100),输出经过 Conv1d 后变为 (16, 64, 100),GroupNorm 与 Mish 不改变形状,最终输出仍为 (16, 64, 100)。
二十一.2 def forward()
def forward(self, x):
- 作用:重写前向传播方法,定义模块如何处理输入 x。
- 输入要求:x 应为一个张量,形状通常为 (batch_size, inp_channels, L)。
- 示例:如果 x 的形状为 (16, 32, 100),表示 16 个样本,每个样本有 32 个通道,长度为 100。
return self.block(x)
- 作用:将输入 x 依次通过
self.block
中的卷积、GroupNorm 和 Mish 层,并将结果返回。 - 意义:封装好的组合模块简化了前向传播代码,保持模块可重复调用。
- 具体数值示例:
- 输入 x 形状 (16, 32, 100)。
- 经 nn.Conv1d 后输出 (16, 64, 100)(因为 kernel_size=3,padding=1,stride=1 保持长度不变)。
- GroupNorm 归一化 (16, 64, 100);Mish 激活后数据保持形状 (16, 64, 100)。
- 最终返回一个形状为 (16, 64, 100) 的张量。
总体说明
- Conv1dBlock 模块主要用途:
- 目的:为一维序列数据构建一个标准化的特征变换模块,顺序执行卷积操作、归一化和激活操作。
- 输入:一个形状为 (batch_size, inp_channels, L) 的一维信号张量,如 (16, 32, 100)。
- 输出:处理后的张量形状为 (batch_size, out_channels, L),如 (16, 64, 100)。
- 设计原因:
- 卷积层 用于捕捉局部时序模式和提取特征;
- GroupNorm 可以在各种批量大小下保持稳定训练效果;
- Mish 激活 提供平滑且高效的非线性变换,增强模型表达能力。
- 整体结构:利用
nn.Sequential
将各层组合起来,使得模块调用简单、易于复用。
二十二、类ConditionalResidualBlock1D
,继承自nn.Module
- in_channels = 32
- out_channels = 64
- cond_dim = 10
- kernel_size = 3
- n_groups = 8
整体来说,该模块实现了一个条件残差块,其流程是:
- 使用两个 1D 卷积块(每个块内部包含 Conv1d → GroupNorm → Mish)对输入信号进行非线性变换;
- 对条件信息(cond)通过一个条件编码器进行处理,输出一个形状能拆分为 per-channel 的 scale 和 bias(FiLM 调制);
- 将中间表示与条件信息做 “scale×feature+bias” 调制;
- 将调制后的结果再通过第二个卷积块处理;
- 与输入(经过 1×1 卷积调整尺寸,如必要)相加,构成残差连接。
class ConditionalResidualBlock1D(nn.Module):
- 这行声明了一个名为
ConditionalResidualBlock1D
的类,并继承自 PyTorch 的nn.Module
。 - 意义:该模块可作为神经网络中的一个层,利用 nn.Module 的特性自动管理参数和计算图。
- 示例:该模块可以嵌入深度网络中,用于条件生成任务或条件控制任务。
二十二.1 def __init__()
def __init__(self,
in_channels,
out_channels,
cond_dim,
kernel_size=3,
n_groups=8):
- 定义初始化方法,接收 5 个参数,其中 in_channels、out_channels、cond_dim 必须传入,kernel_size 默认为 3,n_groups 默认为 8。
- 示例:传入 in_channels=32, out_channels=64, cond_dim=10。
- 意义:这些参数确定输入输出的通道数、条件信息的维度以及卷积核大小和 GroupNorm 组数,为模块内部各层构建提供依据。
super().__init__()
- 调用 nn.Module 的初始化函数,确保模块内部状态和参数注册正常。
- 意义:这是所有 nn.Module 子类必须执行的初始化步骤。
- 示例:保证后续调用 self.parameters() 能包含模块中的卷积和线性层参数。
self.blocks = nn.ModuleList([
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
])
- 使用 nn.ModuleList 存储两个 Conv1dBlock 模块。
- 第一个 Conv1dBlock 接受输入通道数 in_channels(32)并输出 out_channels(64);
- 例如:输入 x 的形状 (batch_size, 32, L) 经过后变为 (batch_size, 64, L)(卷积核 3、padding=kernel_size//2 保持长度不变)。
- 第二个 Conv1dBlock 将通道数保持为 64,即输入和输出都为 64。
- 意义:组合多个卷积块进行特征变换,为后续条件调制做准备。
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
- 描述接下来使用 FiLM(Feature-wise Linear Modulation)策略,将条件向量转换为每个通道的缩放(scale)与偏置(bias)参数。
cond_channels = out_channels * 2
- 将 out_channels 乘以 2 得到 cond_channels。
- 示例:out_channels 为 64,则 cond_channels = 64 * 2 = 128。
- 意义:因为 FiLM 需要为每个通道输出一对参数 (scale 和 bias),因此输出维度翻倍。
self.out_channels = out_channels
- 将参数 out_channels 保存到实例变量 self.out_channels,便于后续使用。
- 示例:self.out_channels 将存储 64。
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
nn.Unflatten(-1, (-1, 1))
)
- 使用 nn.Sequential 依次组合三层:
- nn.Mish():应用 Mish 激活函数,对条件输入先进行非线性变换;
- nn.Linear(cond_dim, cond_channels):全连接层将输入维度 cond_dim(例如 10)映射到 cond_channels(128);
- nn.Unflatten(-1, (-1, 1)):将输出的最后一维重塑,将 128 个特征拆分为两个维度。
- 示例:
- 若 cond 输入为形状 (batch_size, 10),nn.Linear 输出 (batch_size, 128);
- nn.Unflatten(-1, (-1, 1)) 将 (batch_size, 128) 变为 (batch_size, 128, 1)。
- 后续会对 embed 进行 reshape,将 128 重新组织为 (batch_size, 2, out_channels, 1)(因为 2×64=128)。
- 意义:条件编码器将条件信息 cond 转换为用于 FiLM 调制的 scale 和 bias 参数。
# make sure dimensions compatible
- 提示下面的代码用于保证输入与输出维度相匹配,以便进行残差相加。
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
- 使用条件表达式:如果 in_channels 与 out_channels 不同,则使用 1×1 卷积将输入通道数调整为 out_channels;否则,使用 nn.Identity() 保持原样。
- 示例:
- 如果 in_channels=32且 out_channels=64(本例),那么 self.residual_conv 为 nn.Conv1d(32, 64, kernel_size=1);
- 意义:确保残差连接中原始输入 x 经调整后与经过主分支的输出尺寸一致,以便直接相加。
二十二.2 def forward()
def forward(self, x, cond):
'''
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
'''
- 接收两个输入:
- x:形状 [batch_size x in_channels x horizon],例如 (batch_size, 32, L);
- cond:形状 [batch_size x cond_dim],例如 (batch_size, 10)。
- 意义:x 是一维数据输入,cond 包含条件信息(例如任务、时间信息等)。
out = self.blocks[0](x)
- 将输入 x 传入第一个 Conv1dBlock 进行处理。
- 示例:
- 假设 x.shape = (batch_size, 32, L);
- 第一个 Conv1dBlock 将 32 通道映射为 64 通道,并保持时间步长 L不变(由于 kernel_size=3 和 padding=1),所以 out.shape = (batch_size, 64, L).
- 意义:初步提取特征并将通道数提升到 out_channels。
embed = self.cond_encoder(cond)
- 将条件向量 cond(形状 (batch_size, cond_dim))传入 cond_encoder 模块。
- 示例:
- 若 cond.shape = (batch_size, 10);
- 经过 nn.Linear 后映射为 (batch_size, 128),再经 Unflatten 得到 (batch_size, 128, 1)。
- 最终 embed.shape = (batch_size, 128, 1).
- 意义:将条件信息转换为一个包含每个通道需要调制的参数信息的向量。
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
- 将 embed 重塑为 (batch_size, 2, out_channels, 1)。
- 细节:
- embed 原始 shape 为 (batch_size, 128, 1);
- 因为 128 = 2 × 64(即 2 × out_channels),reshape 后拆分为两个部分,其中第一维 2 表示分别为 scale 与 bias。
- 示例:
- 如果 batch_size=32,则重塑后的 embed.shape = (32, 2, 64, 1).
- 意义:为后续 FiLM 调制做好数据组织,分离出 scale 和 bias 参数。
scale = embed[:,0,...]
- 从 embed 中提取第 0 个维度,即 scale,结果保留所有 batch 数据及通道。
- 示例:
- 若 embed.shape = (32, 2, 64, 1),则 scale.shape = (32, 64, 1).
- 意义:scale 用于对中间卷积输出进行逐通道缩放。
bias = embed[:,1,...]
- 同理,从 embed 提取第 1 个维度作为 bias。
- 示例:
- 得到 bias.shape = (32, 64, 1)。
- 意义:bias 用于逐通道进行偏置补偿。
out = scale * out + bias
- 对第一个卷积块的输出 out 进行条件调制:每个通道的特征乘以对应 scale 并加上 bias。
- 细节与广播:
- out.shape 为 (batch_size, 64, L);
- scale 和 bias 的形状为 (batch_size, 64, 1),自动广播到 (batch_size, 64, L)。
- 示例:
- 若某通道在某时间步的值为 0.5,scale 对应为 1.2,bias 为 -0.1,则更新后值为 1.2*0.5 + (-0.1) = 0.5.
- 意义:引入条件信息调制特征,使网络能够根据外部条件动态调整内部特征表示,这就是 FiLM(Feature-wise Linear Modulation)的核心。
out = self.blocks[1](out)
- 将经过条件调制的 out 传入第二个 Conv1dBlock 进行进一步非线性特征提取。
- 示例:
- 输入 out 形状 (batch_size, 64, L),输出依然为 (batch_size, 64, L)(因该卷积块保持通道数和长度)。
- 意义:进一步融合和非线性变换条件调制后的特征。
out = out + self.residual_conv(x)
- 将经过第二个卷积块的输出 out 与原始输入 x 经过 residual_conv 处理后的结果相加。
- 细节:
- 如果 in_channels != out_channels(例如 32 ≠ 64),self.residual_conv 为一个 1×1 卷积,将 x 形状从 (batch_size, 32, L) 调整为 (batch_size, 64, L);
- 如果 in_channels 与 out_channels 相同,则 self.residual_conv 为 nn.Identity(),直接返回 x。
- 示例:
- 在本例中,x.shape 为 (batch_size, 32, L),经过 residual_conv 转换后变为 (batch_size, 64, L);
- 与 out (形状 (batch_size, 64, L))按元素相加,实现残差连接,帮助梯度流动和稳定训练。
- 意义:残差连接有助于保持输入信息,使得网络更易训练,同时提升性能。
return out
- 这行返回最终计算结果 out。
- 输出形状为 (batch_size, out_channels, horizon);例如 (32, 64, L)。
- 意义:返回条件残差块的输出,已经包含了卷积非线性变换、条件调制(FiLM)和残差连接的效果。
总体说明
- ConditionalResidualBlock1D 模块总体作用:
- 实现了一个条件残差块,用于 1D 数据(如时间序列或语音信号)的处理。在该模块中,输入 x 经过两次 Conv1dBlock 非线性变换,并在中间通过 FiLM 模块(条件编码器产生 per-channel scale 和 bias)将条件信息 cond 注入到特征中,最后通过残差连接将原始信息叠加到输出上。
- 输入:
- x: 张量形状 [batch_size, in_channels, horizon](例如 (32,32,L))。
- cond: 张量形状 [batch_size, cond_dim](例如 (32, 10))。
- 输出:
- out: 张量形状 [batch_size, out_channels, horizon](例如 (32,64,L)),是经过条件调制和残差连接的特征表示。
- 设计原因:
- 利用 FiLM 调制(scale * feature + bias),将外部条件信息引入卷积处理过程,使得网络在不同条件下能调整内部表示;
- 采用残差连接,有助于缓解梯度消失问题和保留初始信息;
- 通过两个连续的 Conv1dBlock 提高非线性特征提取能力。
二十三、类ConditionalUnet1D
,继承自nn.Module
- input_dim = 10
- global_cond_dim = 20
- diffusion_step_embed_dim = 256
- down_dims = [256, 512, 1024]
- kernel_size = 5
- n_groups = 8
该模块实现了一个条件 UNet,用于 1D 数据(例如时间序列或语音)上采样与下采样,并在各层引入条件信息,条件信息由两部分组成:
- “扩散步”位置编码(diffusion_step_encoder)
- 全局条件(global_cond),通常由多个观察值拼接而成。
模块流程大致为:
- 使用一个位置编码模块对扩散步(timestep)编码;
- 将编码结果与全局条件拼接,构成条件向量(cond_dim = diffusion_step_embed_dim + global_cond_dim);
- 通过下采样路径(down_modules)逐级下采样并保存中间特征;
- 经过两个“中间”残差块(mid_modules);
- 再通过上采样路径(up_modules)逐级上采样,并与下采样路径特征通过跳跃连接融合;
- 最后通过 final_conv 输出最终结果。
class ConditionalUnet1D(nn.Module):
- 作用:定义一个名为 ConditionalUnet1D 的类,继承自 PyTorch 的 nn.Module。
- 意义:使该模块能够像其他神经网络层一样被调用、管理参数、支持自动梯度。
- 示例:在模型中调用时,可写
model = ConditionalUnet1D(...)
。
二十三.1 def __init__()
def __init__(self,
input_dim,
global_cond_dim,
diffusion_step_embed_dim=256,
down_dims=[256,512,1024],
kernel_size=5,
n_groups=8
):
"""
input_dim: Dim of actions.
global_cond_dim: Dim of global conditioning applied with FiLM
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
down_dims: Channel size for each UNet level.
The length of this array determines numebr of levels.
kernel_size: Conv kernel size
n_groups: Number of groups for GroupNorm
"""
- 作用:定义初始化方法,接收 6 个参数:
input_dim
: 输入数据(通常为动作)的维度,假设 input_dim=10。global_cond_dim
: 全局条件的维度,假设 global_cond_dim=20;如全局观测的拼接结果。diffusion_step_embed_dim
: 扩散步位置编码的嵌入维度,默认 256。down_dims
: 一个列表,表示 UNet 每个下采样层后的通道数,这里为 [256, 512, 1024]。列表长度决定了 UNet 的层数。kernel_size
: 卷积核大小,这里为 5。n_groups
: 用于 GroupNorm 的分组数,这里为 8。
- 意义:这些参数定义了整个 UNet 模块的结构配置和条件信息的大小。
super().__init__()
- 作用:调用父类 nn.Module 的构造函数。
- 意义:确保父类内部初始化正确,注册子模块、参数等。
all_dims = [input_dim] + list(down_dims)
- 作用:构造 all_dims 列表,将输入维度放在最前面,后接每个 UNet 层的通道数。
- 示例:input_dim=10 且 down_dims=[256,512,1024],则 all_dims = [10, 256, 512, 1024].
- 意义:方便后续计算每个下采样层的输入和输出通道数。
start_dim = down_dims[0]
- 作用:将 down_dims 的第一个值保存为 start_dim。
- 示例:start_dim = 256。
- 意义:最后的 final_conv 部分将使用 start_dim 来还原维度。
dsed = diffusion_step_embed_dim
- 作用:将 diffusion_step_embed_dim 保存到局部变量 dsed,便于后续书写。
- 示例:dsed = 256。
- 意义:简化代码书写,后续所有与扩散步编码相关的尺寸均使用 dsed。
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
- 作用:构造一个扩散步位置编码模块,使用 nn.Sequential 顺序组合以下几层:
SinusoidalPosEmb(dsed)
: 将扩散步(timestep)编码为 dsed 维向量,采用正弦位置编码方法。nn.Linear(dsed, dsed * 4)
: 全连接层将向量扩展到 4 倍大小,即 256→1024。nn.Mish()
: 使用 Mish 激活函数引入非线性。nn.Linear(dsed * 4, dsed)
: 再将维度降回 dsed(256)。
- 示例:若输入 timestep 标量(或向量),经过 SinusoidalPosEmb 后输出形状 (B, 256),经过两个全连接层和激活后仍保持 (B,256)。
- 意义:用于将扩散步信息编码成一个固定维度的嵌入,便于后续与 global_cond 拼接;这种设计借鉴 Transformer 中位置编码的思想。
cond_dim = dsed + global_cond_dim
- 作用:计算条件向量的总维度,等于扩散步编码维度 dsed 加上全局条件维度 global_cond_dim。
- 示例:dsed=256,global_cond_dim=20,则 cond_dim = 256 + 20 = 276.
- 意义:后续在 ConditionalResidualBlock1D 中会用到该条件向量,作为 FiLM 调制输入。
in_out = list(zip(all_dims[:-1], all_dims[1:]))
- 作用:构造 in_out 列表,将 all_dims 中连续的两个元素配对,表示每层 UNet 下采样模块的输入和输出通道数。
- 示例:all_dims = [10, 256, 512, 1024],则 zip(all_dims[:-1], all_dims[1:]) 得到 [(10,256), (256,512), (512,1024)]。
- 意义:便于后续循环构建下采样模块,使得每层能够知道输入和输出通道数。
mid_dim = all_dims[-1]
- 作用:将 all_dims 最后一个元素赋值给 mid_dim。
- 示例:mid_dim = 1024。
- 意义:中间部分(最底部)的通道数为 mid_dim,用于构建中间模块(mid_modules)。
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups
),
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups
),
])
- 作用:构造两个中间模块,将它们包装在 nn.ModuleList 中,保存到 self.mid_modules。
- 每个模块均为 ConditionalResidualBlock1D,其输入和输出通道均为 mid_dim(1024),条件维度为 cond_dim(276),卷积核大小为 kernel_size(5),GroupNorm 分组数为 n_groups(8)。
- 示例:输入 shape 为 (B, 1024, L_down),输出保持 (B, 1024, L_down)。
- 意义:中间模块用于在最粗尺度下处理特征(连接下采样和上采样部分),增强非线性特征表达,同时引入条件信息。
down_modules = nn.ModuleList([])
- 作用:初始化一个空的 nn.ModuleList,用于存储下采样模块。
- 意义:后续将依次构造每一层下采样块并添加到该列表中。
for ind, (dim_in, dim_out) in enumerate(in_out):
- 作用:遍历 in_out 列表,获取每一层下采样块的输入和输出通道数,同时记录索引 ind。
- 示例:对于 in_out = [(10,256), (256,512), (512,1024)],循环依次得到:
- ind=0, (dim_in, dim_out) = (10,256)
- ind=1, (dim_in, dim_out) = (256,512)
- ind=2, (dim_in, dim_out) = (512,1024)
- 意义:为构造每个下采样模块提供必要的通道参数和层级索引。
is_last = ind >= (len(in_out) - 1)
- 作用:判断当前下采样层是否为最后一层。
- 示例:
- 当 ind=2 且 len(in_out)=3,则 is_last = (2>=2) → True;
- 对于 ind=0,1,则 is_last=False。
- 意义:在最后一层下采样时,不进行下采样操作(使用 nn.Identity() 代替 Downsample1d)。
down_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
ConditionalResidualBlock1D(
dim_out, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
- 作用:为当前层构造一个下采样模块,并添加到 down_modules 列表中,该模块包含三个子模块组成一个 nn.ModuleList:
- 第一层:ConditionalResidualBlock1D,将输入通道 dim_in 映射到 dim_out,条件维度为 cond_dim。
- 第二层:ConditionalResidualBlock1D,将通道维度保持为 dim_out。
- 第三层:如果当前层不是最后一层,则使用 Downsample1d(dim_out) 进行下采样(通常减半时间步长度);否则使用 nn.Identity(),保持尺寸不变。
- 示例:
- 对于第一层 (ind=0, (10,256)):
- 第一个块将 (B, 10, L) → (B, 256, L);
- 第二个块: (B, 256, L) → (B, 256, L);
- Downsample1d(dim_out=256) 对 (B,256,L) 进行下采样,例如 L 从 100 降到约 50。
- 对于最后一层 (ind=2, (512,1024)):
- 最后一步使用 nn.Identity(),不改变时间步长度。
- 对于第一层 (ind=0, (10,256)):
- 意义:下采样路径构建了多个层级,每一层通过条件残差块增强特征表达并进行下采样,形成多尺度特征表示,同时保存中间特征用于后续上采样跳跃连接。
up_modules = nn.ModuleList([])
- 作用:初始化一个空的 nn.ModuleList 用于存储上采样模块。
- 意义:后续循环构造上采样路径模块。
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
- 作用:遍历 in_out 列表中除第一个元素外的部分,并反向排列,用于构造上采样模块。
- 示例:
- 原 in_out = [(10,256), (256,512), (512,1024)],in_out[1:] = [(256,512), (512,1024)];
- reversed(in_out[1:]) 得到 [(512,1024), (256,512)];
- 循环依次:
- ind=0, (dim_in, dim_out) = (512,1024)
- ind=1, (dim_in, dim_out) = (256,512)
- 意义:上采样路径一般与下采样路径对称,跳跃连接顺序需逆序。
is_last = ind >= (len(in_out) - 1)
- 作用:判断当前上采样模块是否为最后一个模块。
- 示例:
- 对于 reversed(in_out[1:]) 长度为 2;当 ind=1,则 is_last=True;当 ind=0,则 is_last=False.
- 意义:在最后一层不进行上采样操作,而使用 nn.Identity() 保持尺寸。
up_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_out*2, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
ConditionalResidualBlock1D(
dim_in, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
- 作用:构造每个上采样模块,并添加到 up_modules 中,其模块包括:
- 第一层:ConditionalResidualBlock1D,将输入通道为 dim_out*2 (因为将与对应下采样阶段的特征拼接)映射到 dim_in。
- 第二层:ConditionalResidualBlock1D,将通道数从 dim_in 保持不变。
- 第三层:如果不是最后一层则使用 Upsample1d(dim_in) 进行上采样,恢复时间步长,否则使用 nn.Identity()。
- 示例:
- 对于第一层上采样模块(ind=0, (dim_in, dim_out) = (512,1024)):
- 输入拼接后通道数为 1024*2。注意:在上采样阶段,上采样的输入经过下采样路径跳跃连接获得,与当前上采样输入拼接,故通道数是 dim_out*2 = 1024*2
- 例如:当前下采样特征 h[-1]具有通道=1024,x 当前经过中间模块也有 1024,因此拼接后通道=2048;第一层上采样块设计中将 2048 映射到 dim_in=512。
- 输入拼接后通道数为 1024*2。注意:在上采样阶段,上采样的输入经过下采样路径跳跃连接获得,与当前上采样输入拼接,故通道数是 dim_out*2 = 1024*2
- 上采样模块最后使用 Upsample1d(dim_in) 则将特征维持 512 通道,且上采样时间步数加倍(假设非最后一层)。
- 对于第一层上采样模块(ind=0, (dim_in, dim_out) = (512,1024)):
- 意义:上采样路径通过跳跃连接融合下采样层的特征,并借助条件残差块逐步恢复高分辨率特征,同时引入条件信息调制。
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
- 作用:构造 final_conv 模块,用于将上采样后的特征进一步变换到最终输出。
- 第一个子层:Conv1dBlock,将 start_dim 通道(此处 start_dim = down_dims[0] = 256)保持不变(256→256),卷积核大小为 kernel_size(5)。
- 第二个子层:一个 1×1 卷积层,将通道数从 start_dim(256)映射回 input_dim(10)。
- 示例:
- 假设最终上采样输出 x 的形状为 (B, 256, L_final),经过 Conv1dBlock 后形状保持 (B, 256, L_final),再经过 nn.Conv1d 变换为 (B, 10, L_final)。
- 意义:使网络输出符合原始输入维度(例如动作维数或其他数据维度),完成 UNet 的逆变换。
self.diffusion_step_encoder = diffusion_step_encoder
- 作用:将前面构造的 diffusion_step_encoder 模块赋值到实例变量中。
- 意义:在 forward() 方法中利用该模块编码扩散步(timestep)信息。
- 示例:self.diffusion_step_encoder 是一个 nn.Sequential 模块。
self.up_modules = up_modules
- 作用:将上采样模块列表保存到实例变量 self.up_modules。
self.down_modules = down_modules
- 作用:将下采样模块列表保存到实例变量 self.down_modules。
self.final_conv = final_conv
- 作用:将 final_conv 模块保存到实例变量。
print("number of parameters: {:e}".format(
sum(p.numel() for p in self.parameters()))
)
- 作用:打印整个模块中参数的总数量,格式采用科学计数法。
- 示例:输出可能为 “number of parameters: 1.234568e+07”。
- 意义:便于调试和了解模型规模。
二十三.2 def forward()
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
global_cond=None):
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
global_cond: (B,global_cond_dim)
returns:
out: (B,T,input_dim)
- 作用:定义前向传播方法,接收三个输入:
sample
: 输入样本,形状 (B, T, input_dim);例如 (batch_size, time_steps, 10)。timestep
: 扩散步信息,可以是张量、浮点或整数;例如,若每个样本对应一个时间步,可能为 (B,) 张量或单个数值 100。global_cond
: 全局条件信息,形状 (B, global_cond_dim);例如 (B, 20)。
- 意义:将输入 sample 经过 UNet 处理,并结合扩散步和全局条件生成输出。
- 文档说明:下方多行注释说明各输入和输出形状。
# (B,T,C)
sample = sample.moveaxis(-1,-2)
# (B,C,T)
- 作用:将 sample 的最后两维交换;将数据从 (B,T,input_dim) 转为 (B,input_dim,T)。
- 示例:若 sample.shape 初始为 (32, 50, 10)(batch 32,50 时间步,10维),经过 moveaxis 后变为 (32,10,50)。
- 意义:符合 Conv1d 的输入要求((B, channels, length))。
# 1. time
timesteps = timestep
- 作用:将传入的 timestep 赋值给变量 timesteps。
- 意义:为后续统一处理 timestep。
if not torch.is_tensor(timesteps):
- 作用:检查 timesteps 是否为张量,如果不是则执行下面分支。
- 示例:如果 timestep 是一个整数 100,则条件为 True。
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
- 作用:注释说明,如果 timesteps 不是张量,后续转换可能涉及 CPU 和 GPU 之间的同步延迟。将非张量的 timestep 转换为张量。
- 示例:若 timestep 为 100,则转为 tensor([100]),设备与 sample 相同。
- 意义:确保后续操作统一在张量上执行。
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
- 作用:如果 timesteps 是零维张量,则增加一个维度,保证其为一维,并将其移动到 sample 的设备上。
- 示例:若 timesteps = tensor(100)(零维),则变为 tensor([100])。
- 意义:保证 timesteps 的形状一致便于后续操作。
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
- 作用:将 timesteps 扩展到与样本批次相同的长度。
- 示例:若 sample.shape[0] 为 32,则 timesteps 由 tensor([100]) 扩展为 tensor([100,100,…,100]),形状 (32,)。
- 意义:为每个样本提供相同的扩散步信息,兼容导出至 ONNX、Core ML 的要求。
global_feature = self.diffusion_step_encoder(timesteps)
- 作用:将扩散步信息 timesteps 传入 diffusion_step_encoder 得到扩散步嵌入。
- 示例:假设 timesteps 为 (32,) 全部为 100,经扩散步编码后得到 global_feature 的形状为 (32, 256)。
- 意义:将扩散时间信息编码为一个向量,作为后续条件的一部分。
if global_cond is not None:
- 作用:检查是否提供了全局条件 global_cond。
- 意义:如果有,则将其与 diffusion 步嵌入结合。
global_feature = torch.cat([
global_feature, global_cond
], axis=-1)
- 作用:将 global_feature 与 global_cond 在最后一维拼接。
- 示例:假设 global_feature.shape为 (32,256),global_cond.shape为 (32,20),拼接后 global_feature 变为 (32,276)。
- 意义:组合扩散步编码和全局条件信息,形成最终条件向量,维度应等于 cond_dim(256+20=276)。
x = sample
- 作用:将 sample 赋值给变量 x,作为后续网络输入。
- 意义:保留原输入 x 用于后续逐层变换。
h = []
- 作用:初始化空列表 h,用于保存下采样过程中的中间特征,便于跳跃连接使用。
- 意义:UNet 的典型设计,在上采样阶段融合下采样的特征。
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
- 作用:遍历 down_modules 中的每个子模块,每个子模块由三个部分组成:两个 ConditionalResidualBlock1D(命名为 resnet 和 resnet2)和一个 downsample 层。
- 示例:例如,第 0 个下采样模块对应于将特征从 10 维转换到 256 维(第一层)。
- 意义:依次通过每个下采样层提取多尺度特征。
x = resnet(x, global_feature)
- 作用:将输入 x 与条件向量 global_feature 输入到当前模块的第一个条件残差块。
- 示例:若 x 的形状为 (32, 10, L)(第一层),经过 resnet 后 shape 变为 (32, 256, L)。
- 意义:进行非线性变换同时注入条件信息。
x = resnet2(x, global_feature)
- 作用:接着将输出 x 传入当前模块的第二个条件残差块,同样应用全局条件调制。
- 示例:x 形状维持 (32, 256, L) 或相应层的通道数。
- 意义:加深特征表示,在同一层内进行两次条件调制。
h.append(x)
- 作用:将经过当前下采样层前的 x 保存到列表 h,用于后续上采样阶段的跳跃连接。
- 示例:假设下采样共有 3 层,则 h 最终保存 3 个特征图,形状分别为 (32,256,L0), (32,512,L1), (32,1024,L2)。
- 意义:保留多尺度特征,帮助上采样过程恢复细节。
x = downsample(x)
- 作用:将当前 x 传入 downsample 层进行下采样处理,通常会减半时间步长度。
- 示例:若 x 形状为 (32,256,100),经过 Downsample1d 后变为 (32,256,50)(假设步幅为2)。
- 意义:下采样提取抽象特征,并逐步扩大感受野。
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
- 作用:遍历中间模块,对 x 进行处理,每个 mid_module 均为 ConditionalResidualBlock1D。
- 示例:若 x 形状为 (32,1024,25)(最底层),经过两个中间模块后依然保持 (32,1024,25)。
- 意义:在最低分辨率的特征上进一步进行条件调制和特征提取,构成 UNet 中心部分。
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
- 作用:遍历上采样模块,每个上采样模块包含两个条件残差块和一个上采样层。
- 意义:上采样过程逐级恢复高分辨率特征,同时融合跳跃连接的信息。
x = torch.cat((x, h.pop()), dim=1)
- 作用:从 h 中取出最后保存的下采样特征(通过 h.pop()),与当前 x 在通道维度(dim=1)拼接。
- 示例:如果 x 当前形状为 (32,1024,25) 且 h.pop() 返回一个形状 (32,512,25) 的特征,则拼接后 x 形状为 (32,1024+512,25) = (32,1536,25)。
- 意义:跳跃连接确保上采样过程中保留下采样细节信息,帮助恢复空间细节。
x = resnet(x, global_feature)
- 作用:将拼接后的 x 和 global_feature 输入到当前上采样模块的第一个条件残差块处理。
- 示例:x 形状由前一层 (32,1536,25) 经过处理后变为 (32,dim_in,25),其中 dim_in 来自构造时计算的上采样模块第一层。
- 意义:对融合后的特征进行条件调制和非线性变换。
x = resnet2(x, global_feature)
- 作用:将经过第一块处理的 x 继续输入到当前上采样模块的第二个条件残差块。
- 意义:进一步提取特征,增加网络深度。
x = upsample(x)
- 作用:将 x 经过上采样层 upsample,该层为 Upsample1d 或 nn.Identity(如果是最后一层)。
- 示例:若 x 的形状为 (32,dim_in,25) 且采用 Upsample1d(dim_in) 将长度翻倍(假设步幅2),则 x 形状变为 (32,dim_in,50)。
- 意义:逐步恢复时间步长,使输出分辨率回到与输入相匹配。
x = self.final_conv(x)
- 作用:将上采样后的 x 输入到 final_conv 模块进行最后的卷积变换。
- 示例:假设 x 的形状为 (32, start_dim, L_final)(例如 (32,256,L_final)),经过 final_conv 后,先经过 Conv1dBlock 变为 (32,256,L_final),再经过 1×1 卷积映射到 (32, input_dim, L_final)(如 (32,10,L_final))。
- 意义:将特征映射还原为原始输入维度,完成 UNet 的输出阶段。
# (B,C,T)
x = x.moveaxis(-1,-2)
# (B,T,C)
- 作用:将 x 的最后两维交换,将形状从 (B, C, T) 转换为 (B, T, C)。
- 示例:若 x 形状为 (32, 10, L_final)(例如 (32,10,50)),经过 moveaxis 后变为 (32,50,10)。
- 意义:使输出与最初输入 sample 的形状一致(最初 sample 为 (B, T, input_dim))。
return x
- 作用:返回最终输出 x。
- 输出示例:例如 (32,50,10)。
- 意义:模块完成条件 UNet 处理,输出经过上采样和全局条件调制的结果。
总体说明
ConditionalUnet1D 模块
- 目的:构建一个条件 UNet 模型用于 1D 数据,结合扩散步位置编码和全局条件,通过下采样、中心处理和上采样构造多尺度特征表示,并最终还原至与输入相同的形状。
- 输入:
- sample: (B, T, input_dim),例如 (32,50,10) 的一维序列。
- timestep: (B,) 或单个数值,表示扩散过程的当前步,例如 100。
- global_cond: (B, global_cond_dim),例如 (32,20) 的全局条件向量。
- 输出:
- out: (B, T, input_dim) 与输入维度保持一致,经条件调制、下采样、上采样后的输出结果。
各部分作用:
- 扩散步编码(diffusion_step_encoder):
将 timestep 信息编码为一个向量,再与 global_cond 拼接构成条件向量。 - 下采样路径(down_modules):
多层次提取特征并降低时间分辨率,同时保存各层输出用于跳跃连接。 - 中间模块(mid_modules):
在最底层进行两次条件残差块处理,增强特征表示。 - 上采样路径(up_modules):
每一层融合对应下采样层保存的特征(跳跃连接),并进行上采样以恢复序列长度。 - 最终卷积(final_conv):
将特征映射调回原始 input_dim,得到最终输出。 - 全局条件调制:
在每个条件残差块中,通过 FiLM(scale 与 bias)将条件向量注入特征中,使得网络具有条件适应性。
设计意义:
- 这种结构借鉴 UNet 和条件网络的思路,适用于条件生成、扩散模型等任务,能够在多尺度上有效融合条件信息,提高模型表达和生成能力。