生成模型实战 | 深度分层变分自编码器(Nouveau VAE,NVAE)
生成模型实战 | 深度分层变分自编码器(Nouveau VAE,NVAE)
- 0. 前言
- 1. NVAE 技术原理
- 1.1 变分自编码器基础
- 1.2 深度分层架构
- 1.3 多尺度架构设计
- 2. 残差单元与可分离卷积
- 3. 残差参数化与后验分布
- 4. 使用 PyTorch 构建 NVAE
- 4.1 数据集加载
- 4.2 模型构建与训练
0. 前言
变分自编码器 (Variational Autoencoder, VAE) 作为深度学习生成模型的重要分支,具有独特的优势,与生成对抗网络 (Generative Adversarial Network, GAN) 和自回归模型相比,VAE
具有采样速度快、计算可处理性强以及编码网络易于访问等优势。然而,传统的 VAE
模型在生成质量上往往落后于其他先进生成模型,尤其是在处理高分辨率自然图像时表现不佳。为了应对这一挑战,深度分层变分自编码器 (Nouveau VAE
, NVAE
) 通过神经架构设计的创新,推动了 VAE
性能的提升。
1. NVAE 技术原理
1.1 变分自编码器基础
传统变分自编码器 (Variational Autoencoder, VAE) 由编码器和解码器组成。编码器将输入数据 x x x 映射到潜空间的后验分布 q ( z ∣ x ) q(z|x) q(z∣x),解码器从潜在变量 z z z 重建数据 x x x。VA
E的训练目标是最大化证据下界 (Evidence Lower Bound
, ELBO
):
log p ( x ) ≥ E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] − D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) \text {log}p(x)≥\mathbb E_{q_ϕ(z|x)}[\text {log}p_θ(x|z)]−D_{KL}(q_ϕ(z|x)||p(z)) logp(x)≥Eqϕ(z∣x)[logpθ(x∣z)]−DKL(qϕ(z∣x)∣∣p(z))
其中第一项是重建损失,第二项是潜空间分布与先验分布的 KL 散度
。
1.2 深度分层架构
NVAE
采用深度分层架构,将潜变量分为 L L L 组: z = z 1 , z 2 , . . . , z L z = {z_1, z_2, ..., z_L} z=z1,z2,...,zL,其中 z 1 z_1 z1 是最底层(最抽象)的变量, z L z_L zL 是最高层(最接近输入)的变量,形成了层次化的潜表示。这种设计使得先验和后验分布都变成了联合分布,能够在不同层次上捕获数据的抽象特征:
p θ ( x , z 1 : L ) = p θ ( x ∣ z 1 : L ) ∏ i = 1 L p θ ( z i ∣ z i + 1 : L ) q φ ( z 1 : L ∣ x ) = ∏ i = 1 L q φ ( z i ∣ z i + 1 : L , x ) p_θ(x, z_{1:L}) = p_θ(x|z_{1:L})∏_{i=1}^L p_θ(z_i|z_{i+1:L})\\ q_φ(z_{1:L}|x) = ∏_{i=1}^L q_φ(z_i|z_{i+1:L}, x) pθ(x,z1:L)=pθ(x∣z1:L)i=1∏Lpθ(zi∣zi+1:L)qφ(z1:L∣x)=i=1∏Lqφ(zi∣zi+1:L,x)
这种分层设计允许模型在多个分辨率级别上处理输入数据,较低层次的组捕获细节信息,而较高层次的组捕获语义级别的抽象信息。这与人类视觉系统的层次化处理方式相似,使得模型能够生成全局一致且细节丰富的高分辨率图像。
1.3 多尺度架构设计
NVAE
采用了多尺度架构,在处理图像时在不同层次使用不同的分辨率。编码器逐步降低输入图像的分辨率,同时增加通道数;而解码器则执行相反的过程,逐步上采样并减少通道数。这种设计使得计算更加高效,同时保持了模型对细节和全局结构的表现能力。
2. 残差单元与可分离卷积
NVAE
的基础构建模块是专门设计的残差单元 (Residual Cell
),这些单元格在编码器和解码器中都有使用。每个残差单元包含批量归一化 (Batch Normalization
, BN
)、Swish
激活函数和深度可分离卷积 (Depth-wise Separable Convolutions
) 等组件。这种设计不仅保证了数值稳定性,还显著减少了参数量。NVAE
中所用残差单元如下图所示。
深度可分离卷积是 NVAE
中的关键技术创新之一。与标准卷积相比,深度可分离卷积将卷积操作分解为两个步骤:深度卷积(对每个输入通道单独进行空间卷积)和逐点卷积( 1×1
卷积,用于组合通道信息)。这种分解大幅减少了计算复杂度和参数数量,使模型能够快速扩大感受野而不受计算资源的限制。
NVAE
中还采用了挤压和激励 (Squeeze-and-Excitation
, SE
) 模块来增强模型的表示能力。SE
模块通过自适应地重新校准通道特征响应,使模型能够关注最信息丰富的特征。这一机制与残差单元格的结合进一步提升了模型对重要特征的敏感性。
3. 残差参数化与后验分布
NVAE
提出了残差参数化 (residual parameterization
) 方法来改进近似后验分布的表达能力。在传统 VAE
中,近似后验分布通常被假设为对角协方差高斯分布,这种假设限制了模型的表达能力。NVAE
通过残差参数化放松了这一限制,允许更灵活的后验分布形式。
具体而言,对于每个层次的潜在变量,其均值和方差不是直接从网络输出计算得到,而是基于先前层次的残差更新来计算。这种方法确保了潜在变量之间的依赖性,同时保持了计算的可处理性。数学上,这种参数化方式使得 KL 散度
项可以解析计算,避免了复杂的近似方法。
残差参数化还与条件先验 (conditional prior
) 的概念紧密结合。在生成过程中,每个层次的先验分布不仅依赖于先前潜在变量,还依赖于自上而下的网络传递的信息。这种设计使先验分布更加丰富和表达力强,有助于生成更高质量的样本。
4. 使用 PyTorch 构建 NVAE
接下来,我们将使用 Celeb A 人脸图像数据集训练 NVAE
。
4.1 数据集加载
(1) 首先,导入所需库:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torchvision.transforms as transforms
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader, Dataset
import torchvision.utils as vutils
import numpy as np
from PIL import Image
from glob import glob
import torchvision
(2) 定义图像预处理变换:
image_size = 64
transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 将图像归一化到[-1, 1]范围
(3) 创建数据集和数据加载器:
batch_size = 32
ds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8)
4.2 模型构建与训练
(1) 实现残差单元,用于构建 NVAE
的编码器和解码器,使用深度可分离卷积提高效率:
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()# 第一个卷积层self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二个卷积层self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 快捷连接self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))# 应用谱归一化以稳定训练self.conv1 = spectral_norm(self.conv1)self.conv2 = spectral_norm(self.conv2)if len(self.shortcut) > 0:self.shortcut[0] = spectral_norm(self.shortcut[0])def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = F.relu(out)return out
(2) 定义 NVAE
编码器块,包含多个残差单元和下采样操作,每个编码器块处理特定分辨率的特征:
class EncoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(EncoderBlock, self).__init__()self.blocks = nn.ModuleList()# 第一个块进行下采样self.blocks.append(ResidualBlock(in_channels, out_channels, stride=stride))# 添加额外的残差块(不下采样)for _ in range(1, num_blocks):self.blocks.append(ResidualBlock(out_channels, out_channels, stride=1))def forward(self, x):for block in self.blocks:x = block(x)return x
(3) 定义 NVAE
解码器块,包含多个残差块和上采样操作,每个解码器块重建特定分辨率的特征:
class DecoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(DecoderBlock, self).__init__()self.blocks = nn.ModuleList()# 添加残差块for _ in range(num_blocks - 1):self.blocks.append(ResidualBlock(in_channels, in_channels, stride=1))# 最后一个块进行上采样if stride > 1:self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU())else:self.upsample = nn.Identity()def forward(self, x):for block in self.blocks:x = block(x)x = self.upsample(x)return x
(4) 构建完整的 NVAE
模型,包含分层编码器和解码器,采用多尺度潜空间结构:
class NVAE(nn.Module):def __init__(self, image_channels=3, latent_dim=128, num_layers=4):super(NVAE, self).__init__()self.latent_dim = latent_dimself.num_layers = num_layers# 初始卷积层self.initial_conv = nn.Sequential(nn.Conv2d(image_channels, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU())# 编码器层级self.enc_blocks = nn.ModuleList()self.enc_blocks.append(EncoderBlock(32, 64, num_blocks=2, stride=2)) # 64x64 -> 32x32self.enc_blocks.append(EncoderBlock(64, 128, num_blocks=2, stride=2)) # 32x32 -> 16x16self.enc_blocks.append(EncoderBlock(128, 256, num_blocks=2, stride=2)) # 16x16 -> 8x8self.enc_blocks.append(EncoderBlock(256, 512, num_blocks=2, stride=2)) # 8x8 -> 4x4# 潜在空间均值和对数方差预测self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)# 解码器层级self.dec_blocks = nn.ModuleList()self.dec_blocks.append(nn.Sequential(nn.Linear(latent_dim, 512 * 4 * 4),nn.ReLU()))self.dec_blocks.append(DecoderBlock(512, 256, num_blocks=2, stride=2)) # 4x4 -> 8x8self.dec_blocks.append(DecoderBlock(256, 128, num_blocks=2, stride=2)) # 8x8 -> 16x16self.dec_blocks.append(DecoderBlock(128, 64, num_blocks=2, stride=2)) # 16x16 -> 32x32self.dec_blocks.append(DecoderBlock(64, 32, num_blocks=2, stride=2)) # 32x32 -> 64x64# 最终重建层self.final_conv = nn.Sequential(nn.Conv2d(32, image_channels, kernel_size=3, stride=1, padding=1),nn.Tanh() # 输出范围[-1, 1],与输入一致)def encode(self, x):# 编码输入图像,返回潜分布的均值和方差x = self.initial_conv(x)for block in self.enc_blocks:x = block(x)x = x.view(x.size(0), -1) # 展平mu = self.fc_mu(x)logvar = self.fc_logvar(x)return mu, logvardef reparameterize(self, mu, logvar):# 重参数化技巧,从潜分布中采样std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):# 从潜变量解码重建图像x = self.dec_blocks[0](z)x = x.view(-1, 512, 4, 4) # 重塑为特征图for i in range(1, len(self.dec_blocks)):x = self.dec_blocks[i](x)x = self.final_conv(x)return xdef forward(self, x):# 前向传播:编码输入图像,采样潜变量,然后解码重建mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)recon_x = self.decode(z)return recon_x, mu, logvardef sample(self, num_samples, device):# 从潜空间采样生成新图像z = torch.randn(num_samples, self.latent_dim).to(device)samples = self.decode(z)return samples
(5) 定义损失函数,结合重建损失和 KL 散度
:
def loss_function(recon_x, x, mu, logvar, beta=1.0):# 重建损失(使用均方误差)recon_loss = F.mse_loss(recon_x, x, reduction='sum')# KL散度损失kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())# 总损失total_loss = recon_loss + beta * kld_lossreturn total_loss, recon_loss, kld_loss
(6) 定义设备,并实例化模型和优化器:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")model = NVAE(image_channels=3, latent_dim=256, num_layers=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
(7) 定义模型训练和验证函数:
def train(model, dataloader, optimizer, epoch, device, beta=1.0):model.train()train_loss = 0recon_loss = 0kld_loss = 0for batch_idx, (data, _) in enumerate(dataloader):data = data.to(device)optimizer.zero_grad()# 前向传播recon_batch, mu, logvar = model(data)# 计算损失loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)# 反向传播loss.backward()# 梯度裁剪(防止梯度爆炸)torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()train_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()if batch_idx % 100 == 0:print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(dataloader.dataset)} 'f'({100. * batch_idx / len(dataloader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')# 计算平均损失avg_loss = train_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> Epoch: {epoch} 平均损失: {avg_loss:.4f} 'f'平均重建损失: {avg_recon:.4f} 平均KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_klddef validate(model, dataloader, device, beta=1.0):model.eval()val_loss = 0recon_loss = 0kld_loss = 0with torch.no_grad():for i, (data, _) in enumerate(dataloader):data = data.to(device)recon_batch, mu, logvar = model(data)loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)val_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()avg_loss = val_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> 验证集损失: {avg_loss:.4f} 'f'重建损失: {avg_recon:.4f} KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_kld
(8) 训练模型:
# 开始训练
num_epochs = 500
beta = 0.1 # KL散度的权重系数train_losses = []
val_losses = []for epoch in range(1, num_epochs + 1):train_loss, train_recon, train_kld = train(model, dataloader, optimizer, epoch, device, beta)val_loss, val_recon, val_kld = validate(model, dataloader, device, beta)train_losses.append(train_loss)val_losses.append(val_loss)# 更新学习率scheduler.step()# 每10个epoch保存一次模型和样本if epoch % 2 == 0:torch.save(model.state_dict(), f'nvae_celeba_epoch_{epoch}.pth')# 生成并保存样本with torch.no_grad():sample = torch.randn(16, 256).to(device)sample = model.decode(sample).cpu()vutils.save_image(sample, f'sample_epoch_{epoch}.png', nrow=4, normalize=True)# 保存重建示例test_iter = iter(dataloader)test_data, _ = next(test_iter)test_data = test_data.to(device)with torch.no_grad():recon_data, _, _ = model(test_data)comparison = torch.cat([test_data[:8], recon_data[:8]]).cpu()vutils.save_image(comparison, f'reconstruction_epoch_{epoch}.png', nrow=8, normalize=True)# 保存最终模型
torch.save(model.state_dict(), 'nvae_celeba_final.pth')
模型训练过程,模型重建效果如下所示,可以看到重建效果随着训练的进行不断得到改进:
接下来,查看模型训练完成后的生成图像: