当前位置: 首页 > news >正文

通过渐进蒸馏实现扩散模型的快速采样

论文(2022'ICLR):Progressive Distillation for fast sampling of diffusion models

摘要

扩散模型用于生成建模前景光明,生成的样本比GANs的更具视觉真实感(感知质量,perceptual quality),对数据分布的建模能力(密度估计,density estimation)比自回归模型要强。但有个缺点就是采样时间太慢,生成高质量的样本需要几百甚至几千次的模型评估,作者提出两点解决这个问题:

1. 新的扩散模型参数化方法,仅用很少的采样步骤就能提供更高的稳定性。

2. 一种多步蒸馏的方法,从一个训练好的确定性扩散采样器中得到一个新的扩散模型,新模型采样步数减半,然后蒸馏新模型,反复迭代。

在经典的图片生成基准数据集CIFAR-10、ImageNet、LSUNS上,当时最先进的采样器也要8192步采样,作者从该采样器开始蒸馏,最终得到只需4步的新模型,在感知质量上没有太大损失。实验验证在CIFAR-10上只需4步就能取得3.0的FID(Frechet Inception Distance)分数,接近于原始多步模型的性能。

同时整个蒸馏过程(对数级迭代次数,每轮成本递减)耗时不超过对原始模型的训练耗时,在训练阶段总时间可控,无需重复训练多个独立模型,测试阶段上面提到的4步模型可比原始模型快2000倍(8192-> 4),适合实时应用,首次实现了扩散模型在少步采样下的实用化,平衡质量和速度。

引入

扩散模型是一类新兴的生成式模型,在ImageNet生成任务上打败了BigGAN-deep和VQ-VAE-2,同样优于自回归图像模型,不限于图像超分辨率和图像修复任务,在3D形状生成、图结构生成和文本生成等领域也被成功应用,成为跨模态生成的统一框架。

但采样速度是扩散模型实际应用的一大阻碍。扩散模型的采样效率高度依赖于任务的条件信息强度,在强条件任务,如文本生成语音、图像超分、分类器引导中,可以快速采样,而在可获取的条件信息较少时,比如仅给出ImageNet类别标签或随机生成人脸,需要成百上千次网络评估,且每步采样都要完整的前向计算,不能像其他生成式模型那样缓存中间结果,计算成本高昂。

应对以往工作中扩散模型最慢的场景,即类别条件生成或无条件(模型只使用自身学到的“图像先验”,将噪声变成图像,不参考类别、文本、标签等信息)生成任务时,“渐进蒸馏”以数量级的形式减少了采样时间,如下图,是两轮迭代的可视化

假设原始采样器f(z; \eta )通过4步确定性采样从噪声z生成图像x,该确定性采样过程可视为求解常微分方程(OED)\frac{dz}{dt}=F(z_t,t;\eta ),其中F由去噪网络\hat{x}_{\eta }(z_t)定义,原始网络通过离散化(如4步欧拉法)近似求解。将其作为目标模型,希望蒸馏出单步采样器f(z;\theta),而蒸馏的本质就是积分步数的压缩和摊销,让学生模型直接预测教师模型多步迭代的稳态解。其间通过最小化两者的输出差异\mathcal{L}(\theta) = \mathbb{E}_z\left \| f(z;\theta)-\tilde{x} \right \|^2,保证样本质量较小地下降。最终在逐步掌握教师模型复杂行为的同时避免进一步压缩导致的质量崩塌。

扩散模型

声明,我们讨论的都是指定在连续时间下定义(离散化是实现细节,而连续时间框架是理论基石 VS.离散时间)的扩散模型。

x \sim p(x)表示训练数据(比如RGB图像),隐变量z = \{z_t|t \in [0,1]\},扩散模型由信号系数\alpha_t(控制原始数据x在隐变量z_t中的保留程度)和噪声系数\sigma_t(控制高斯噪声强度)指定,在连续时间框架下,\alpha_t\sigma_t是平滑函数。\lambda_t = log[\alpha_t^2/\sigma_t^2]为信噪比的对数表示,随时间t单调递减。

前向加噪过程是高斯过程(加噪后的隐变量服从高斯分布):

q(z_t|x)=\mathcal{N}(z_t; \alpha_tx, \sigma_t^2I)

转移概率为

q(z_t|z_s) = \mathcal{N}(z_t;(\alpha_t/\alpha_s)z_s, \sigma_{t|s}^2I)0 \leq s < t < 1,其中\sigma_{t|s}^2 = (1-e^{\lambda_t-\lambda_s})\sigma_t^2


证明:

加噪样本z_t = \alpha_t x+\sigma_t \epsilon_tz_s = \alpha_s x+\sigma_s \epsilon_s,前向过程的马尔可夫性允许我们将\epsilon _t表示为\epsilon _s和一个独立新噪声\epsilon '的线性组合,即

\epsilon _t = e^{(\lambda_t-\lambda_s)/2}\epsilon _s + \sqrt{1-e^{\lambda_t-\lambda_s}}\epsilon ', \epsilon ' \sim \mathcal{N}(0,I),三式联立,得到

z_t = \alpha_t \cdot \frac{z_s-\sigma_s \epsilon _s}{\alpha_s}+\sigma_t \cdot[e^{(\lambda_t-\lambda_s)/2} \epsilon _s + \sqrt{1-e^{\lambda_t-\lambda_s}}\epsilon '] \\ =\frac{\alpha_t}{\alpha_s} \cdot z_s + [\sigma_t e^{(\lambda_t - \lambda_s)/2} - \sigma_s \frac{\alpha_t}{\alpha_s}] \cdot \epsilon _s + \sigma_t \sqrt{1-e^{\lambda_t-\lambda_s}} \cdot \epsilon '

结合信噪比定义\lambda = 2\log(\frac{\alpha}{\sigma}),以及为了保持样本方差有\alpha_t^2+\sigma_t^2=1,知道中间项为0,即

z_t =\frac{\alpha_t}{\alpha_s} \cdot z_s + \sigma_t \sqrt{1-e^{\lambda_t-\lambda_s}} \cdot \epsilon ',便得到了前文的\sigma_{t|s}^2 = (1-e^{\lambda_t-\lambda_s})\sigma_t^2

可以看到新状态z_t是旧状态z_s的线性函数,加上独立的高斯噪声,说明前向过程是一个明确的“线性高斯系统”,构成闭合族,条件/边缘分布都是高斯。

(不知道能否像另一篇博客“修正版(附录B)”这一节中那样证明)


神经网络的任务就是从噪声样本预测原始数据,用加权MSE作为损失函数训练去噪模型\hat{x}_{\theta}

\mathbb{E}_{\epsilon ,t}[w(\lambda_t)\left \| \hat{x}_{\theta}(z_t)-x \right \|_2^2]

其中噪声\epsilon \sim \mathcal{N}(0,I),时间t \sim u[0,1]权重函数w(\lambda_t),加噪样本z_t = \alpha_t x+\sigma_t \epsilon_t。该损失和数据对数似然的变分下界(variational lower bound),或者去噪分数匹配的一种形式(denoising score matching)相关联。

从训练好的模型中采样的方法有好几种。

祖先采样(ancestral sampling)(2020)

DDPM(去噪扩散概率模型)提出的经典方法,按离散时间步迭代去噪,其反向过程是随机的,每一步都从高斯分布中采样,即

z_s \sim q(z_s|z_t,x)=\mathcal{N}(z_s; \tilde{u}_{s|t}(z_t,x), \tilde{\sigma}_{s|t}^2I),注意s<t

这个已知原始数据和当前噪声,求前一步噪声的条件概率分布中有

\tilde{u}_{s|t}(z_t,x)=e^{\lambda_t-\lambda_s}(\alpha_s/\alpha_t)z_t+(1-e^{\lambda_t-\lambda_s})\alpha_sx\tilde{\sigma}_{s|t}^2 = (1-e^{\lambda_t-\lambda_s})\sigma_s^2


 证,这个不死心的菜狗!

 根据贝叶斯公式,有

q(z_s|z_t,x) = \frac{q(z_t|z_s,x)q(z_s|x)}{q(z_t|x)}

由于马尔可夫性,其中q(z_t|z_s,x) = q(z_t|z_s),所有分布均为高斯,有

q(z_t|z_s) \sim \mathcal{N}(z_t;(\alpha_t/\alpha_s)z_s, \sigma_{t|s}^2I)

q(z_s|x) \sim \mathcal{N}(z_s; \alpha_sx, \sigma_s^2I)

q(z_t|x) \sim \mathcal{N}(z_t; \alpha_tx, \sigma_t^2I)

于是所求条件分布的常数项有

\tilde{\sigma}_{s|t}^2 = \frac{\sigma_{t|s}^2 \sigma_s^2}{\sigma_t^2} = (1-e^{\lambda_t-\lambda_s})\sigma_s^2

再忽略常数项,取对数后有

log q(z_s|z_t,x)=-\frac{1}{2} \left [ \frac{\left \| z_t - \frac{\alpha_t}{\alpha_s}z_s \right \|^2}{\sigma_{t|s}^2} + \frac{\left \| z_s-\alpha_sx \right \|^2}{\sigma_s^2} - \frac{\left \| z_t - \alpha_tx \right \|^2}{\sigma_t^2}\right ]

将方括号中z_s前的系数提取出来,即

\frac{-2\tilde{\mu}_{s|t}}{\tilde{\sigma}_{s|t}^2} = -2(\frac{(\alpha_t/\alpha_s)}{\sigma_{t|s}^2} z_t +\frac{\alpha_s}{\sigma_s^2}x)

于是

\tilde{\mu}_{s|t} = (\frac{(\alpha_t/\alpha_s)}{\sigma_{t|s}^2} z_t +\frac{\alpha_s}{\sigma_s^2}x) \cdot (1-e^{\lambda_t-\lambda_s})\sigma_s^2 \\ = \frac{(\alpha_t/\alpha_s)}{(1-e^{\lambda_t-\lambda_s})\sigma_t^2} (1-e^{\lambda_t-\lambda_s})\sigma_s^2 z_t + \frac{\alpha_s}{\sigma_s^2} (1-e^{\lambda_t-\lambda_s})\sigma_s^2 x \\ = \frac{(\alpha_t/\alpha_s)^2}{(\sigma_t/\sigma_s)^2}(\alpha_s/\alpha_t)z_t + (1-e^{\lambda_t-\lambda_s})\alpha_s x \\ = e^{\lambda_t-\lambda_s}(\alpha_s/\alpha_t)z_t + (1-e^{\lambda_t-\lambda_s})\alpha_s x


在反向采样时,原始数据x未知,便用去噪模型\hat{x}_{\theta}(z_t)替代,从纯噪声z_1 \sim \mathcal{N}(0,I)开始,祖先采样器通过以下规则从z_t生成z_ss<t):

z_s = \tilde{u}_{s|t}(z_t, \hat{x}_{\theta}(z_t))+\sqrt{(\tilde{\sigma}^2_{s|t})^{1-\gamma}(\sigma^2_{t|s})^{\gamma}}\cdot \epsilon\epsilon \sim \mathcal{N}(0, I)

均值项:

\tilde{u}_{s|t}(z_t, \hat{x}_{\theta}(z_t)) = e^{\lambda_t-\lambda_s}(\alpha_s/\alpha_t)z_t + (1-e^{\lambda_t-\lambda_s})\alpha_s \hat{x}_{\theta}(z_t)

其中第一项是对当前噪声数据z_t的线性投影,第二项是基于模型预测的去噪方向,权重由信噪比变化e^{\lambda_t-\lambda_s}控制。组合起来已不再是线性函数,多步链式组合后无法写成单一高斯。

方差项:

\sqrt{(\tilde{\sigma}^2_{s|t})^{1-\gamma}(\sigma^2_{t|s})^{\gamma}}

其中\tilde{\sigma}_{s|t}^2 = (1-e^{\lambda_t-\lambda_s})\sigma_s^2是反向过程中来自于模型推导的理论方差,\tilde{\sigma}_{t|s}^2 = (1-e^{\lambda_t-\lambda_s})\sigma_t^2是前向过程中来自于噪声调度的实际方差,超参数\gamma由用户设定,用于控制噪声注入的强度,为0是完全使用反向方差,此时更新规则的有效性依赖于去噪模型的准确性,适合需要稳定输出的任务,如超分辨率、图像编辑,但可能生成重复样本,多样性不足;为1时则完全使用前向方差,保留了前向过程的噪声调度,即使在去噪后期也仍可能注入噪声,适合无条件生成,如艺术创作,但需要更多采样步数才能收敛。

DDPM默认需要大量步数迭代去噪,生成质量高但速度慢,可能过拟合训练数据分布。适用于追求极致质量,不计较时间成本的任务。

概率流OED(2021)

前向扩散过程通常由以下随机微分方程(SDE)描述:

dz_t = f(z_t, t)dt+g(t)dW_t

其中f(z_t,t)漂移项,控制信号的确定性演化,g(t)扩散项,控制噪声的注入强度,W_t标准的布朗运动,作为随机噪声。根据Kingma et al.(2021)的推导,有f(z_t,t) = \frac{dlog \alpha_t}{dt}z_t(正向时为负,指示信号随时间衰减),g^2(t) = \frac{d \sigma_t^2}{dt}-2\frac{dlog \alpha_t}{dt} \sigma_t^2

方差保持(variance-preserving),有\alpha_t^2 = 1-\sigma_t^2 = sigmoid(\lambda_t) = 1/(1+e^{-\lambda_t})

\frac{d \alpha_t^2}{d \lambda_t} = \alpha_t^2(1- \alpha_t^2)= \alpha_t^2\sigma_t^2\frac{d\sigma_t^2}{dt} = \frac{d(1-\alpha_t^2)}{dt} = -\frac{d\alpha_t^2}{d\lambda_t} \cdot \frac{d\lambda_t}{dt} = -\alpha_t^2 \sigma_t^2\frac{d\lambda_t}{dt},代入得

f(z_t,t) = \frac{dlog \alpha_t}{dt}z_t = \frac{dlog (\alpha_t^2)}{2d\lambda_t} \cdot \frac{d\lambda_t}{dt} \cdot z_t = \frac{1}{2(\alpha_t^2)} \cdot \frac{d(\alpha_t^2)}{d\lambda_t} \cdot \frac{d\lambda_t}{dt} \cdot z_t = \frac{\sigma_t^2}{2}\frac{d\lambda_t}{dt} z_t

g^2(t) = \frac{d \sigma_t^2}{dt}-2\frac{dlog \alpha_t}{dt} \sigma_t^2 = -\alpha_t^2 \sigma_t^2\frac{d\lambda_t}{dt} - 2 \cdot \frac{\sigma_t^2}{2}\frac{d\lambda_t}{dt} \cdot \sigma_t^2 = -\sigma_t^2 \frac{d\lambda_t}{dt}

扩散模型中另一种采样方法,就是通过求解下列概率流常微分方程(ODE),实现从初始噪声z_1 \sim \mathcal{N}(0,I)到生成样本x的确定性映射

dz_t = [f(z_t, t)-\frac{1}{2}g^2(t) \bigtriangledown _z log \hat{p}_{\theta}(z_t)]dt

其中对数概率密度的梯度\bigtriangledown _z log \hat{p}_{\theta}(z_t) = \frac{\alpha_t \hat{x}_{\theta}(z_t)-z_t}{\sigma_t^2}q(z_t|x) = \mathcal{N}(z_t; \alpha_tx, \sigma_t^2I)),即得分函数(score function),由去噪网络\hat{x}_{\theta}隐式学习,那么

dz_t = \frac{\sigma_t^2}{2}[z_t+ \frac{\alpha_t \hat{x}_{\theta}(z_t)-z_t}{\sigma_t^2}]d\lambda_t = \frac{1}{2}[\alpha_t \hat{x}_{\theta}(z_t)-(1-\sigma_t^2)z_t]d\lambda_t \\ = \frac{1}{2}[\alpha_t \hat{x}_{\theta}(z_t)-\alpha_t^2z_t]d\lambda_t

这是概率流ODE在信噪比参数化下的形式,表示潜变量随SNR变化的瞬间动力学。

神经ODE由Che et al.(2018)提出,核心思想是用神经网络参数化的ODE来描述数据的动态演变,即\frac{dz(t)}{dt}=f_{\theta}(z(t), t),概率流ODE便是神经ODE在生成建模中的一个特例。

连续归一化流(CNF)由Grathwohl et al.(2018)提出,其核心是通过一个可逆的、连续的变换将简单分布映射到复杂数据分布,其变化过程由神经ODE描述,其关键性质在于可以通过正向/反向ODE求解器实现双向变换,可以利用瞬时变化公式计算概率密度,概率流ODE的正向/反向ODE形式均为式,只是时间方向相反,也是一种CNF。


求解ODE的标准方法有

* 欧拉方法:最简单的ODE求解器,通过线性近似,用当前点的导数预测下一步的值,即z_{t+1}=z_t+h \cdot F(z_t,t)h是步长。实现简单,计算量小,但精度低,O(h),稳定性差。

* Runge-Kutta方法:高阶ODE求解器,通过多阶段计算斜率,加权平均以提高精度,经典四阶Runge-Kutta(RK4)就是

k_1 = h \cdot F(z_t,t)

k_2 = h \cdot F(z_t+\frac{k_1}{2},t+\frac{h}{2})

k_3 = h \cdot F(z_t+\frac{k_2}{2},t+\frac{h}{2})

k_4 = h \cdot F(z_t+k_3,t+h)

z_{t+1}=z_t+\frac{1}{6} (k_1+2 k_2+2 k_3+k_4)

误差为O(h^4)。RK高精度,计算量大,可通过调整步长平衡效率和精度。

对于种种ODE求解器,用户只需要提供ODE的右侧函数F(z_t,t)和初始条件,无需关心其内部是如何数值求解的,故为“黑箱”。


基于Score-Based Models,用梯度场直接控制生成过程,利用SDE/ODE理论实现确定性采样,支持高精度差值。需调用黑箱ODE求解器,计算开销和生成质量均依赖于此。多样性较低,适用于需确定性控制和科学计算任务。

DDIM(2021)

去噪扩散隐式模型,其核心是通过非马尔科夫链的扩散路径,跳过部分中间步骤,减少采样步数,后续发现其更新规则实际上可视为对概率流OED的数值积分方法(类似于欧拉法或RK法),尽管最初设计没从这角度出发。

前文说过,扩散模型的逆向过程可以通过概率流ODE描述:

dz_t = [f(z_t, t)-\frac{1}{2}g^2(t) \bigtriangledown _z log \hat{p}_{\theta}(z_t)]dt

其中f(z_t,t)g(t)由前向过程定义,得分函数\bigtriangledown _zlog p_t(z_t) \approx \frac{\alpha_t \hat{x}_{\theta}(z_t) - z_t}{\sigma_t^2}通过去噪模型近似。

而DDIM的更新规则如下:

z_s=\alpha_s \hat{x}_{\theta}(z_t)+\frac{\sigma_s}{\sigma_t} (z_t-\alpha_t \hat{x}_{\theta}(z_t)) = \\ e^{(\lambda_t-\lambda_s)/2}(\alpha_s/\alpha_t)z_t+(1-e^{(\lambda_t-\lambda_s)/2})\alpha_s\hat{x}_{\theta}(z_t)

可视作在潜空间中对噪声和信号的线性插值,其插值系数由SNR决定。


通过数学变换,可证明该规则是概率流ODE的一阶积分方法:

还是由“方差保持”可知\frac{d \alpha_s}{d \lambda_s} = \frac{1}{2}\alpha_s \sigma_t^s\frac{d\sigma_s}{d\lambda_s} = -\frac{1}{2} \alpha_s^2 \sigma_s,将上式两边对\lambda_s求导

\frac{dz_s}{d\lambda_s} = \frac{d\alpha_s}{d\lambda_s} \cdot \hat{x}_{\theta}(z_t)+\frac{d\sigma_s}{d\lambda_s} \cdot \frac{z_t-\alpha_t \hat{x}_{\theta}(z_t)}{\sigma_t} = \frac{1}{2}\alpha_s \sigma_s^2 \cdot \hat{x}_{\theta}(z_t) -\frac{1}{2} \alpha_s^2 \sigma_s \cdot \frac{z_t-\alpha_t \hat{x}_{\theta}(z_t)}{\sigma_t}

s=t(一阶泰勒展开)时求导就可以得到

\frac{dz_s}{d\lambda_s}|_{s=t} = \frac{1}{2}\alpha_t \sigma_t^2 \cdot \hat{x}_{\theta}(z_t) -\frac{1}{2} \alpha_t^2 \sigma_t \cdot \frac{z_t-\alpha_t \hat{x}_{\theta}(z_t)}{\sigma_t} \\ =\frac{1}{2}[\alpha_t \sigma_t^2 \cdot \hat{x}_{\theta}(z_t)+\alpha_t^3 \cdot \hat{x}_{\theta}(z_t) - \alpha_t^2z_t] = \frac{1}{2}[\alpha_t \hat{x}_{\theta}(z_t) - \alpha_t^2z_t]

完全一致!可知DDIM遵循一阶的概率流ODE。虽理论阶数与欧拉法相同,但,DDIM更新依赖于训练好的神经网络,隐含了数据分布的复杂先验,而欧拉法直接使用显式定义的F(z_t,t),与数据分布无关;DDIM允许非均匀步长(由\lambda_t-\lambda_s\alpha_s/\alpha_t控制),甚至跳跃式生成,而欧拉法需固定步长,且需小步长保证梯度;DDIM源于扩散模型的反向过程推导,目标是最大化生成数据的似然,而欧拉法是通用ODE数值逼近方法,目标是最小化局部截断误差。即,DDIM为隐式自适应步长,匹配扩散过程动力学,更新规则和前向过程形式对称,避免了欧拉法中人为离散化引入的误差。


牺牲了部分理论严谨性,允许用更少的步数生成,速度显著提升,适合实时交互式应用或资源有限的设备。

在去噪模型\hat{x}_{\theta}(z_t)满足一定平滑性的前提下,概率流ODE的数值解可以通过增加积分步数(减小步长)任意逼近理论解,人们便需要在生成质量和时间开销之间进行权衡。一直以来多数模型都需要成百上千步才能达到最佳质量,不能实际应用。而作者提出一种方法,将这些精确但缓慢的ODE积分器蒸馏成很快,且依旧精确的学生模型。

渐进蒸馏

渐进蒸馏和训练原始扩散模型的实现非常相近,算法并列呈现如下,其中渐进蒸馏的相对变化以绿色突出

算法1是标准的扩散模型训练过程,每轮迭代时独立从数据集中采取一个新样本x,并对其进行一次加噪&去噪训练,通过多次训练样本的“单步去噪”误差累积,模型学习出对所有z_t的去噪能力。而在推理阶段才有多步迭代完成去噪的过程,例如在DDPM或DDIM采样过程中,从纯噪声开始,每一步z_t \rightarrow z_{t-1},需多步才逐渐还原出干净样本。

在算法2中,先拿到用标准方法训练出的教师模型,然后在渐进蒸馏的每轮迭代中,通过复制教师模型的方式初始化学生模型,采用相同的模型结构和参数;和算法1一样从训练集中采样数据并加噪,得到z_t喂给学生模型,让它去预测一个“目标值”,并计算损失;关键区别是,我们不再使用原始数据x作为目标,而是构造出一个新目标\tilde{x},让学生只用1步DDIM推理,就能达到老师用2步的效果,怎么做的呢?

假设当前轮希望学生的采样步数为N,从z_t出发,让教师模型执行两步DDIM推理,得到z_{t-1/N},反推学生只走一步,就能达到z_{t-1/N},该预测哪个\tilde{x}才能实现呢?附录G有推导过程:

t'=t-0.5/Nt''=t-1/N,教师模型两步预测得到样本z_{t''},根据DDIM的更新规则,学生模型一步预测得到样本

\tilde{z}_{t''} = \alpha_{t''}\tilde{x}+\frac{\sigma_{t''}}{\sigma_t}(z_t-\alpha_t \tilde{x})

为了匹配两个样本,令\tilde{z}_{t''}=z_{t''},得\tilde{x} = (z_{t''}-\frac{\sigma_{t''}}{\sigma_t}z_t)/(\alpha_{t''}-\frac{\sigma_{t''}}{\sigma_t}\alpha_t)

这是一个明确可计算的“最优去噪目标”,对每个z_t都唯一,会让学生模型学的更加“果断”,反观标准扩散模型训练时,同一个z_t可能由不同的x组合不同的噪声得到,模型预测的是一个模糊的目标(后验分布的均值),往往得到模糊的图像。

在得到一个N步预测的学生模型后,可以重复这个过程以得到N/2步的学生模型,循此往复。


让“学生模型一步”模拟“教师模型两步”的可能性,是DDIM这类确定性采样器所特有的性质,这也是逐步蒸馏能成立的数学基础:在DDIM中,每一步采样是一个确定性的函数,故其组合仍是可逆的、可表示的函数,可组成一个新的神经网络函数。但如果采用的是DDPM,每一步采样虽说都是从一个高斯分布中采样,但其均值是由神经网络预测的非线性函数,两步组合后不再是高斯分布了,故DDPM这种随机采样结构是不可压缩的,不能直接用“更少步骤的DDPM”来精确模拟“多步骤DDPM”,强行蒸馏必然损失保真度。


其实蒸馏并非没有代价,教师模型是在整个连续时间区间t \sim U[0,1]上训练的,有很强的时间泛化能力,而学生模型只在一组离散的采样时间点t=i/N \in (0,1]上训练,但也正是只需在这些固定的点上学习“如何一步去噪”,使其能将模型容量集中在少量关键步骤上。

在选择离散点时,会设第一个时间点对应的信噪比为0,即\alpha_1=0,以对齐纯高斯噪声的采样起点,比起从还有一点点信号成分的噪声开始,在训练原始扩散模型和进行渐进蒸馏的场景下表现会稍稍好些。

参数化和训练损失

怎么参数化这个去噪模型\hat{x}_{\theta}?如何重构这个重建损失权重w(\lambda_t)

首先在(Kingma et al., 2021)中就有证明过,无论如何修改扩散过程的参数或形式,模型都能学到本质相同的分布,不同的扩散过程设计(如variance-preserving和variance-exploding)本质上可以通过对z_t进行线性变换而相互等价。


什么是方差爆炸?

VE的前向过程为z_t=z_0+\sigma_t \epsilon , \epsilon \sim \mathcal{N}(0, I),方差无限增长,最终分布趋近于无界高斯分布。

为什么和方差保持等价?

虽说VP和VE在形式上不同,但它们都可以实现“从数据到高斯分布的逐步扰动”,从而为逆过程(去噪)提供可能,二者的本质等价性源于一个事实,即若两个扩散过程只是隐变量尺度不同,则通过线性变换可以将一个变成另一个。两类扩散过程中的变量x_t^{VP}x_t^{VE}之间存在时间依赖的线性变换x_t^{VP} = \alpha_t \cdot x_t^{VE},只要信号系数选的合适,就能让VP和VE过程满足相同的统计分布形式,甚至对应的反向过程也能相互映射,因为分数函数也只差了一个比例因子:

\bigtriangledown _{z_t^{VE}} log p_t^{VE}(z_t^{VE}) = \alpha(t) \cdot \bigtriangledown _{z_t^{VP}} log p_t^{VP}(z_t^{VP})


作者参考Improved DDPM(Nichol & Dhariwal, 2021),采用余弦调度,即\alpha_t = cos(0.5 \pi t),其特点是t=0时信号保持完整,t=1时完全是噪声,平滑递减,比线性/指数调度更自然、更好控制噪声分布,是一种合理的、兼具性能和理论支持的设计。

Ho et al.(2020)及后续工作选择让神经网络\hat{\epsilon }_{\theta}(z_t)直接预测噪声,反推出原始数据\hat{x}_{\theta}(z_t)=\frac{1}{\alpha_t}(z_t-\sigma_t \hat{\epsilon }_{\theta}(z_t)),损失函数定义在\epsilon空间,有

L_{\theta} = \left \| \epsilon - \hat{\epsilon }_{\theta}(z_t) \right \|_2^2\\ = \left \| \frac{1}{\sigma_t}(z_t-\alpha_tx)-\frac{1}{\sigma_t}(z_t-\alpha_t \hat{x}_{\theta}(z_t)) \right \|_2^2 = \frac{\alpha_t^2}{\sigma_t^2}\left \| x- \hat{x}_{\theta}(z_t)\right \|_2^2

等价于在x空间的误差,只是再带个缩放因子,即权重函数为w(\lambda_t) = exp(\lambda_t)

这种标准的预测噪声参数化方法在训练原始扩散模型时表现很好,但并不适用于蒸馏!在渐进蒸馏初期,模型在不同的t \in [0,1]上训练,看到的SNR分布广泛,从高到低都有,但随着蒸馏逐步推进,采样步数减少,每一步的时间间隔增大,模型主要处理极低SNR(余弦调度,SNR与t的关系是快速下降的非线性曲线),而\alpha_t \rightarrow 0时,网络输出的微小改变可能被放大成x的巨大误差。在采样步数较多的情况下,即使某步预测错了也能靠后续步骤慢慢纠正,系统具备冗余性和纠错能力,但步数变少后,每步都承担着更大的“生成压力”,最终蒸馏到只有一步时,模型的输入是纯噪声,\alpha_t = 0z_t不再含有任何x的信息,\epsilon-预测和x-预测之间的数学链条消失!

故为了渐进蒸馏,对扩散模型的参数化需使得即使在不同信噪比下,对\hat{x}_{\theta}(z_t)的预测始终稳定,下面几种方案表现都不错:

1. 直接预测x

2. 单独的输出通道\left \{ \tilde{x}_{\theta}(z_t), \tilde{\epsilon }_{\theta}(z_t) \right \}预测x\epsilon,再合并\hat{x}=\sigma_t^2 \tilde{x}_{\theta}(z_t)+\alpha_t(z_t-\sigma_t \tilde{\epsilon }_{\theta}(z_t))

3.  预测v = \alpha_t \epsilon - \sigma_t x,取\hat{x} = \alpha_t z_t - \sigma_t \hat{v}_{\theta}(z_t)

除了找一个合适的参数化方法,还需要决定重构损失权重w(\lambda_t),像Ho et al.(2020)那样根据信噪比对重构损失进行加权,不适用于蒸馏,作者考量了以下两种权重:

1. L_{\theta} = max(\left \| x-\hat{x}_t \right \|_2^2, \left \| \epsilon -\hat{\epsilon }_t \right \|_2^2) = max(\frac{\alpha_t^2}{\sigma_t^2}, 1)\left \| x-\hat{x}_t \right \|_2^2,“截断版SNR”加权

2. L_{\theta}=\left \| v_t-\hat{v}_t \right \|_2^2 = (1+\frac{\alpha_t^2}{\sigma_t^2})\left \| x-\hat{x}_t \right \|_2^2,“SNR+1版”加权

实验证明都是用来训练扩散模型不错的选择,实际应用中还需考虑\alpha_t\sigma_t是如何采样的(关于时间t的函数),若训练样本在时间维度上不是均匀采样的,那某些时间区间会更频繁地出现,对损失的期望权重会更大,也就是说,损失实际的“总权重”是“单样本权重×时间采样密度”,比如采用余弦调度,下面左右图分别是不考虑和考虑调度的权重

横轴均为信噪比的对数表示log(SNR) = log[\alpha_t^2/\sigma_t^2] = \lambda_t,左图的纵坐标logw(\lambda_t)=

log(e^{\lambda_t}) = \lambda_t,呈红色直线;

log(max(e^{\lambda_t}, 1)) = max(\lambda_t, 0),呈蓝虚折线;

log(e^{\lambda_t}+1),呈黑虚曲线。(怪,感觉有些位移)

时间均匀,而SNR(t) = (\frac{\alpha_t}{\sigma_t})^2 = cot^2(0.5\pi t)不均匀,有\lambda_t = 2log(cot(0.5\pi t)),得到反函数t = \frac{2}{\pi}acrtan(e^{-\lambda_t/2}),计算SNR密度(即单位\lambda_t对应的时间间隔,导数为负数时只能说明\lambda_t增加时t减少,但变化量依旧是对应非负绝对值)得p(\lambda_t) = \left | \frac{dt}{d\lambda_t} \right | = \left | \frac{2}{\pi} \cdot \frac{1}{1+e^{-\lambda_t}} \cdot e^{-\lambda_t/2} \cdot \frac{-1}{2} \right | = \left | -\frac{1}{\pi} \cdot \frac{e^{-\lambda_t/2}}{1+e^{-\lambda_t}} \right | = \frac{1}{\pi} \cdot \frac{e^{-\lambda_t/2}}{1+e^{-\lambda_t}}

故右图的纵坐标为w(\lambda_t) \cdot p(\lambda_t) = \frac{1}{\pi} \cdot \frac{e^{\lambda_t/2}}{1+e^{-\lambda_t}} (红)……

对这些权重函数本身乘以一个任意的常数因子是不改变训练效果的,故为了便于图像对齐和可视化,作者给每个权重做了调整,使它们在同一张图上更清晰地展示。

实验

下面就是用实验来验证算法,以及针对各式模型参数化和损失加权方法的消融实验了。

所有实验采用余弦调度、和Ho et al.(2020)一样的U-Net模型架构,利用BigGAN风格的残差块(通道数可变、注意力机制优化)进行上下采样,参考Nichol&Dhariwal的改进设计,训练超参和DDPM开源代码一致。

考量了不同的图片生成任务,分辨率从32×32到128×128,最后作者选择无条件CIFAR-10(32×32分辨率)(模型会学着如何从灶神中恢复CIFAR-10风格的自然图像,但不区分猫狗等类别)作为基准,因为小分辨率图像训练速度快,适合大规模消融实验,而且CIFAR-10是扩散模型的经典测试基准,便于与已有研究比对。从头开始训练,避免预训练模型的偏差。

下表中,评估指标为FID(Fréchet Inception Distance,衡量生成图像和真实图像分布的距离,越低越好)和IS(Inception Score,衡量生成图像的多样性与清晰度,越高越好),在无条件CIFAR-10任务上训练原始模型(不蒸馏),所有结果均在200万步训练后,对3个不同的随机种子跑出的最好性能指标做平均,有±0.1的波动,在可接受范围内,训练本身具有一定的随机性。结果显示,当选择预测噪声参数化方法+截断版SNR损失加权时,训练会发散,因为在高噪声区,截断版SNR会被截断为1,而噪声预测在此时梯度计算也不稳定。除此之外参数化和损失加权的各式组合表现差异不大

参数化方法消融

预测v是最稳定的选择,因为DDIM的步长和信噪比无关,避免极端噪声下的数值问题,同时显式解耦了噪声和数据预测,但单从指标上看,预测x可能略好。


附录D

下面对DDIM的更新规则做个简化,用\phi=arctan(\sigma/\alpha)作为参数,而非时间t或者信噪比的对数\lambda_t。针对一个方差保持的模型扩散过程,有\alpha_{\phi}=cos(\phi)\sigma_{\phi}=sin(\phi),于是

z_{\phi}=cos(\phi)x+sin(\phi)\epsilon,现定义

v_{\phi}\equiv \frac{dz_{\phi}}{d\phi} = \frac{dcos(\phi)}{d\phi}x+\frac{dsin(\phi)}{d\phi}\epsilon = cos(\phi)\epsilon - sin(\phi)x

两式消去\epsilon得到x=cos(\phi) \cdot z_{\phi}-sin(\phi) \cdot v_{\phi},消去x得到\epsilon = sin(\phi) \cdot z_{\phi} + cos(\phi) \cdot v_{\phi},(只要知道zv,就能解出x\epsilon,正所谓“解耦噪声和数据预测”)据此重写DDIM的更新规则

z_{\phi_s}=cos(\phi_s) \hat{x}_{\theta}(z_{\phi_t}) + sin(\phi_s) \hat{\epsilon }_{\theta}(z_{\phi_t}) \\ = cos(\phi_s)(cos(\phi_t)z_{\phi_t}-sin(\phi_t)\hat{v}_{\theta}(z_{\phi_t})) + sin(\phi_s)(sin(\phi_t)z_{\phi_t}+cos(\phi_t)\hat{v}_{\theta}(z_{\phi_t})) \\ = [cos(\phi_s)cos(\phi_t)+sin(\phi_s)sin(\phi_t)]z_{\phi_t}+[sin(\phi_s)cos(\phi_t)-cos(\phi_s)sin(\phi_t)]\hat{v}_{\theta}(z_{\phi_t}) \\ = cos(\phi_s - \phi_t)z_{\phi_t}+sin(\phi_s - \phi_t)\hat{v}_{\theta}(z_{\phi_t})

或者等价于z_{\phi_t-\delta} = cos(\delta )z_{\phi_t}-sin(\delta ) \hat{v}_{\theta} (z_{\phi_t}),其中\delta = \phi_t - \phi_s,从这个角度看,DDIM的采样轨迹就像绕着一个圆转动,它从当前状态z_{\phi_t}出发,沿着预测的“速度”方向\hat{v}的反方向移动一个小角度,从而得到下一个状态z_{\phi_s}

深入几何本质,在“角度参数化”下,DDIM的变量z的角度就代表SNR,有SNR=\frac{\alpha_t^2}{\sigma_t^2}=cot^2 (\phi),角度越小,信号越强,角度越大,噪声越强,长度则代表总能量:

z_{\phi}=cos(\phi)x+sin(\phi)\epsilon的范数平方期望为

\mathbb{E}[\left \| z_{\phi} \right \|^2] = cos^2(\phi) \cdot \left \| x \right \|^2+sin^2{\phi} \cdot \mathbb{E}[\left \| \epsilon \right \|^2]

若对x\epsilon做归一化处理,使得\left \| x \right \|^2 = \left \| \epsilon \right \|^2 = d,则\mathbb{E}[\left \| z_{\phi} \right \|^2] = d,即z的期望模长恒定,在单位圆上,但实际训练中可能因数据分布或模型噪声略有波动。故预测的是v时,步长完全由角度差决定,不依赖于SNR的曲线形状,算一种“几何自然、角度均匀”的更新方式。


渐进蒸馏验证

在4个数据集上进行评估:32×32的CIFAR-10、64×64下采样后的ImageNet、128×128的LSUN bedrooms以及128×128的LSUN Church-Outdoor。在每个数据集上先训练一个基线模型,然后在此基础上进行渐进蒸馏。在CIFAR-10上的原始模型的去噪采样步数为8192,这是一个非常精细、近乎权威的教师模型,可以用来生成很优质的训练信号了,而在更大的数据集上,图像分辨率更高,单次采样所需的显存、计算资源大幅上升,如果设置这么多采样步数,单批次训练成本太高了,改用1024步足矣,够生成高质量样本了。

每轮蒸馏时设置参数更新5万次,最后两轮(distill to 2-step和1-step)例外,设置了10万次。每轮迭代后都会报告一次FID结果。最终蒸馏到4步采样的总的计算开销不会超过训练原始模型的成本。在附录I中还有仅在稍稍损失性能的前提下进步减少计算开销的办法。


附录I

为了进一步减少渐进蒸馏的计算成本,在CIFAR-10上进行消融实验,减少了每轮训练学生模型时参数更新的次数,不再用5万了,减到2.5万、1万甚至5千,下图显示,我们可以大幅减少更新次数,并且在使用≥4个采样步数时仍能获得很好的性能,而在采样步数很少时,如果对学生只是进行短时间的训练,性能损失会更加明显

也试过更激进的压缩策略,每轮不是采样步数减半,而是变成1/4 ,考虑到让教师执行4步DDIM推理会大大增加计算成本,从而失去了蒸馏的意义,所以在生成训练目标时仍只让教师执行两步。可以看到红色线表现一般,效果不好,所以说如果你计算资源有限,与其跳过蒸馏阶段,一口气压缩4倍步数,不如老老实实减半蒸馏,每次只训练少量更新步数。

上图中每个点是4个随机种子平均后的结果,对于每个方案会从[5e^{-5}, 1e^{-4},2e^{-4},3e^{-4}]中选择最优学习率。

下面是在ImageNet和LSUN数据集上的结果,即使用了加速版的蒸馏策略,生成质量还是可以的

上图学习率设置照旧,但结果仅来自于单个随机种子。 


下图是特定采样步数后得到的FID分数,和原始的DDIM采样器、高度优化的随机基线采样器比对。对于4个数据集,逐步蒸馏得到的模型在采样步数达到4步或者8步之前,生成质量几乎是最好的,FID分数甚至与原始模型多步结果相当,但进一步见到2步或1步时,质量迅速恶化,而DDIM和随机采样器在低于128步时就质量下降的很快了

总的来时这套蒸馏方案效果还是很诱人的。虽说设计时是针对确定性采样器来训练模型的,得到的学生模型更擅长做DDIM任务,但原则上也可用于随机采样,实验结果就是性能介于使用DDIM采样的蒸馏模型和使用随机采样的原始模型之间


附录F

不论DDPM还是DDIM,本质上都依赖于一个神经网络从噪声样本预测原始数据,前文更新规则

DDPM为z_s = e^{\lambda_t-\lambda_s}(\alpha_s/\alpha_t)z_t + (1-e^{\lambda_t-\lambda_s})\alpha_s \hat{x}_{\theta}(z_t) + \sqrt{(\tilde{\sigma}^2_{s|t})^{1-\gamma}(\sigma^2_{t|s})^{\gamma}}\cdot \epsilon

DDIM为z_s= e^{(\lambda_t-\lambda_s)/2}(\alpha_s/\alpha_t)z_t+(1-e^{(\lambda_t-\lambda_s)/2})\alpha_s\hat{x}_{\theta}(z_t)

只要学生模型学会了正确的分数函数,就可以用到任意形式的采样器上。在DDIM中去掉了噪声项,实现确定性更新,训练的目标也是让学生尽可能还原确定路径,而非建模随机分布,缺乏对其它可能轨迹的泛化性能。实际性能比对看看下图呢

 蒸馏VS.没蒸馏,DDIM VS.随机采样,为了在使用随机采样时尽可能发挥模型性能,作者对11个不同噪声水平(指在反向采样过程中人为注入的噪声强度,在Ho et al.提供的最小和最大方差之间对数均匀分布,更关注低噪声区间)进行网格搜索,然后选择生成质量(FID)最好的结果进行报告。结果就是,用蒸馏模型做随机采样,总体质量介于用原始模型做随机采样和用蒸馏模型做DDIM采样之间:在采样步数少时,用蒸馏模型做DDIM采样表现更好,采样数多时则用随机采样较好了。


下表展示了在CIFAR-10上的结果,和其它文献中快速采样的方法进行比较,效果更好,步骤更少

下图是在不同的蒸馏阶段从模型中获取的一些随机样本,从64×64的ImageNet模型开始蒸馏,以类别标签“malamute(阿拉斯加雪橇犬)”为条件输入,生成与该类别相关的图像,所有生成样本使用相同的初始噪声,区别在于不同阶段蒸馏模型的采样步数,结果就是步数减少时仍能保持从噪声到图像的稳定映射(同输入近似输出,不同输入产生样本多样性)

再补充几张,在CIFAR-10上的无条件生成任务:

 在ImageNet上以“coral reef(珊瑚礁)”为标签:

 在ImageNet上以“sports car(跑车)”为标签:

在128×128的LSUN bedrooms上的无条件生成:

在128×128的LSUN Church-Outdoor上的无条件生成:

快速采样的相关工作

最接近的工作应该是Luhman&Luhman(2021),它将一个DDIM的教师模型蒸馏成一步学生模型,但其不足是必须让教师跑完整个DDIM采样过程,对每个训练样本生成整条轨迹,拿到所有中间状态,来构建一个庞大的训练数据集,所以它们蒸馏的计算成本会随着采样步数N线性增长,就不适用于步数很多的模型(比如训练大图、音频、视频等)。相反作者的方法每轮只需要教师前向两步,每阶段训练量为常数,整个蒸馏过程的总训练时间为O(logN),远好过Luhman的O(N)

DDIM作为一种确定性采样算法,在DDPM的分数函数上重新推导出一个确定性的ODE解,允许在远少于原始DDPM所需步骤数的情况下生成高质量样本,是渐进蒸馏的基础,是这篇工作的先导成果。概率流采样器将扩散过程看作连续时间的概率流ODE,通过ODE数值解法进行采样,本质上是对DDPM的一个确定性极限形式,是和DDIM是非常相似的机制,只在推导方式和理论表述上略有不同。

Jolicoeur-Martineau at al.(2021)研究了反向扩散过程的快速SDE数值积分器,Tzen&Raginsky (2019b)研究了无偏采样器,同样有助于实现快速、高质量的采样。

前面提到的DDIM、概率流采样器、快速SDE数值积分器、无偏采样器,强调的是推理路径重构、采样动力学建模、采样器设计,而后要讲到的研究,强调的是给定一个训练好的扩散模型,如何通过调度设计、采样点选择、后处理机制来加速采样,更偏向于应用层优化、后训练调节策略或外部控制机制。

Nichol & Dhariwal(2021)提出一系列方法,将原本基于大量时间步训练的离散时间扩散模型,调整为可以在较少采样步数下进行高质量生成的模型;Watson at al.(2021)提出了一种基于动态规划的算法,用于在固定步数预算下选择最优的采样时间点组合,从而在最大化对数似然的意义下压缩采样过程;Chen et al.(2021)、Saharia et al.(2021)、Ho et al.(2021)采用基于连续噪声水平的训练方式,并在训练完成后通过调整有限反向扩散过程的噪声水平实现采样器的调优,这类方法在高度条件化的生成任务中表现尤为出色,如语音合成(text2speech)和图像超分辨率(image super-resolution)等;San-Roman et al.(2021)则训练了一个额外的神经网络,用于估计输入数据的噪声水平,并展示了如何利用这一估计结果加速采样过程,从而在保持生成质量的前提下减少推理成本。

对扩散模型进行替代性的建模方式也可以实现快速采样,例如修改前项和反向过程(Nachmani et al.2021,Lam et al.2021),或将训练转移到潜在空间中进行。

讨论

作者对带确定性采样器(如DDIM)的扩散模型进行渐进蒸馏,将原本生成高质量图片等数据所需的采样步数大幅减少,通过减少运行时间和计算需求,希望能增加其实用性。

在这篇工作中,限制了学生模型的架构和参数量(和教师一致),作者希望进一步放宽限制,探索更小的学生模型,在测试时间上得到进一步收益,也希望跳出图像生成,探究针对其它数据模态(如音频)的扩散模型的渐进蒸馏。

除了渐进蒸馏,可以看到不同的参数化方法和对应的损失设置也能带来一定的收益,作者也希望社区的探索能使这个方向上看到更多进步。

作者提供了模型架构、训练过程、超参等细节,算法2对应的具体实现已开源:

https://github.com/google-research/google-research/tree/master/diffusion_distillation

http://www.dtcms.com/a/276130.html

相关文章:

  • Java-线程池
  • 【机器学习实战笔记 16】集成学习:LightGBM算法
  • AV1高层语法
  • PostgreSQL HOT (Heap Only Tuple) 更新机制详解
  • Swin Transformer核心思路讲解(个人总结)
  • 文件上传漏洞2-常规厂商检测限制绕过原理讲解
  • 强化学习、PPO和GRPO的通俗讲解
  • C语言第一章数据类型和变量(下)
  • Java 大视界:基于 Java 的大数据可视化在智慧城市能源消耗动态监测与优化决策中的应用(2025 实战全景)
  • 视频分析应用的搭建
  • 【Linux-云原生-笔记】Apache相关
  • NE综合实验2:RIP与OSPF动态路由优化配置、FTP/TELNET服务部署及精细化访问控制
  • Java反射与注解
  • 树形动态规划详解
  • 大数据时代UI前端的智能化服务升级:基于用户情境的主动服务设计
  • 【PycharmPyqt designer桌面程序设计】
  • 【学习新知识】用 Clang 提取函数体 + 构建代码知识库 + AI 问答系统
  • GD32 CAN1和TIMER0同时开启问题
  • 《通信原理》学习笔记——第一章
  • 细谈kotlin中缀表达式
  • H2在springboot的单元测试中的应用
  • skywalking镜像应用springboot的例子
  • try-catch-finally可能输出的答案?
  • Docker-镜像构建原因
  • C语言基础教程--从入门到精通
  • Spring Boot整合MyBatis+MySQL+Redis单表CRUD教程
  • STM32中的RTC(实时时钟)详解
  • R 语言绘制 10 种精美火山图:转录组差异基因可视化
  • JavaScript 常见10种设计模式
  • 码头智能哨兵:AI入侵检测系统如何终结废钢盗窃困局