DDPM 做了什么
本博客主要侧重点在于HOW也就是DDPM怎么做的而不是WHY为什么要这样做
那么第一个问题DDPM做了一件什么事:这个算法通过逐渐向原图像添加噪声来破坏图像,然后再学习如何从噪声成恢复图像。
第二件事如何做到的:通过训练一个网络,这个网络输入为加噪声图片和添加噪声的次数,输出为网络预测施加在图像上的噪声
添加噪声的过程 也就是前向扩散过程 满足这个式子:
逐步添加高斯噪声到数据
x
0
x_0
x0
q
(
x
t
∣
x
t
−
1
)
=
N
(
x
t
;
1
−
β
t
x
t
−
1
,
β
t
I
)
q(x_t | x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I\right)
q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
最终隐式表达:
q
(
x
t
∣
x
0
)
=
N
(
x
t
;
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
q(x_t | x_0) = \mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I\right)
q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
其中:
- α t = 1 − β t \alpha_t = 1 - \beta_t αt=1−βt
-
α
ˉ
t
=
∏
i
=
1
t
α
i
\bar{\alpha}_t = \prod_{i=1}^t \alpha_i
αˉt=∏i=1tαi
这边的 β t \beta_t βt是自己设的
这个式子用人话来说就是由原图像加噪t
次后产生的图像(就命名为
I
t
I_t
It吧)要满足偏差为
α
ˉ
t
x
0
\sqrt{\bar{\alpha}_t} x_0
αˉtx0 方差为
(
1
−
α
ˉ
t
)
I
(1-\bar{\alpha}_t)I
(1−αˉt)I 的正态分布。
听起来是不是还是不像人话,没事代码一看便懂
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
mean = gather(self.alpha_bar, t) ** 0.5 * x0
var = 1 - gather(self.alpha_bar, t)
return mean, var
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
if eps is None:
eps = torch.randn_like(x0)
mean, var = self.q_xt_x0(x0, t)
return mean + (var ** 0.5) * eps
也就是
I
t
I_t
It是由
I
0
I_0
I0乘上一个系数然后加上由标准正态分布采样得到的和原图像大小一致的随机噪声乘上系数得到的。
那么为什么mean + (var ** 0.5) * eps
=
N
(
x
t
;
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
\mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I\right)
N(xt;αˉtx0,(1−αˉt)I) 呢?
因为这边的
e
p
s
∼
N
(
0
,
I
)
eps\sim \mathcal{N}(0, I)
eps∼N(0,I) 所以
(
v
a
r
∗
∗
0.5
)
∗
e
p
s
∼
N
(
0
,
(
1
−
α
ˉ
t
)
I
)
(var ** 0.5) * eps \sim \mathcal{N}(0,(1-\bar{\alpha}_t)I)
(var∗∗0.5)∗eps∼N(0,(1−αˉt)I) (这块看不懂去看看概率论吧) 那么
m
e
a
n
+
(
v
a
r
∗
∗
0.5
)
∗
e
p
s
∼
N
(
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
mean + (var ** 0.5) * eps \sim N(\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I)
mean+(var∗∗0.5)∗eps∼N(αˉtx0,(1−αˉt)I) 满足了隐式表达的式子 。
说完了添加噪声,那么自然来到了如何去除噪声,前面也说过,我们训练一个网络网络输入为
I
t
I_t
It和t,输出为网络预测的第t次施加在图像上的噪声。我们把这个网络就记作
ϵ
θ
(
I
t
,
t
)
\epsilon_\theta(I_t, t)
ϵθ(It,t) ,我们的目标是使得网络预测的噪声和添加在图像上的噪声越相似越好,就得到了网络的损失函数
L
(
θ
)
=
E
t
,
x
0
,
ϵ
[
∥
ϵ
−
ϵ
θ
(
I
t
,
t
)
∥
2
]
\mathcal{L}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(I_t, t) \|^2 \right]
L(θ)=Et,x0,ϵ[∥ϵ−ϵθ(It,t)∥2]
训练过程就是采样,计算损失函数,反向传播更新参数。具体就不多说了
TODO:DDPM的噪声预测网络结构