当前位置: 首页 > news >正文

生成模型实战 | 深度分层变分自编码器(Nouveau VAE,NVAE)

生成模型实战 | 深度分层变分自编码器(Nouveau VAE,NVAE)

    • 0. 前言
    • 1. NVAE 技术原理
      • 1.1 变分自编码器基础
      • 1.2 深度分层架构
      • 1.3 多尺度架构设计
    • 2. 残差单元与可分离卷积
    • 3. 残差参数化与后验分布
    • 4. 使用 PyTorch 构建 NVAE
      • 4.1 数据集加载
      • 4.2 模型构建与训练

0. 前言

变分自编码器 (Variational Autoencoder, VAE) 作为深度学习生成模型的重要分支,具有独特的优势,与生成对抗网络 (Generative Adversarial Network, GAN) 和自回归模型相比,VAE 具有采样速度快、计算可处理性强以及编码网络易于访问等优势。然而,传统的 VAE 模型在生成质量上往往落后于其他先进生成模型,尤其是在处理高分辨率自然图像时表现不佳。为了应对这一挑战,深度分层变分自编码器 (Nouveau VAE, NVAE) 通过神经架构设计的创新,推动了 VAE 性能的提升。

1. NVAE 技术原理

1.1 变分自编码器基础

传统变分自编码器 (Variational Autoencoder, VAE) 由编码器和解码器组成。编码器将输入数据 x x x 映射到潜空间的后验分布 q ( z ∣ x ) q(z|x) q(zx),解码器从潜在变量 z z z 重建数据 x x xVA E的训练目标是最大化证据下界 (Evidence Lower Bound, ELBO):
log⁡ p ( x ) ≥ E q ϕ ( z ∣ x ) [ log⁡ p θ ( x ∣ z ) ] − D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) \text {log}⁡p(x)≥\mathbb E_{q_ϕ(z|x)}[\text {log}⁡p_θ(x|z)]−D_{KL}(q_ϕ(z|x)||p(z)) logp(x)Eqϕ(zx)[logpθ(xz)]DKL(qϕ(zx)∣∣p(z))
其中第一项是重建损失,第二项是潜空间分布与先验分布的 KL 散度

1.2 深度分层架构

NVAE 采用深度分层架构,将潜变量分为 L L L 组: z = z 1 , z 2 , . . . , z L z = {z_1, z_2, ..., z_L} z=z1,z2,...,zL,其中 z 1 z_1 z1 是最底层(最抽象)的变量, z L z_L zL 是最高层(最接近输入)的变量,形成了层次化的潜表示。这种设计使得先验和后验分布都变成了联合分布,能够在不同层次上捕获数据的抽象特征:
p θ ( x , z 1 : L ) = p θ ( x ∣ z 1 : L ) ∏ i = 1 L p θ ( z i ∣ z i + 1 : L ) q φ ( z 1 : L ∣ x ) = ∏ i = 1 L q φ ( z i ∣ z i + 1 : L , x ) p_θ(x, z_{1:L}) = p_θ(x|z_{1:L})∏_{i=1}^L p_θ(z_i|z_{i+1:L})\\ q_φ(z_{1:L}|x) = ∏_{i=1}^L q_φ(z_i|z_{i+1:L}, x) pθ(x,z1:L)=pθ(xz1:L)i=1Lpθ(zizi+1:L)qφ(z1:Lx)=i=1Lqφ(zizi+1:L,x)
这种分层设计允许模型在多个分辨率级别上处理输入数据,较低层次的组捕获细节信息,而较高层次的组捕获语义级别的抽象信息。这与人类视觉系统的层次化处理方式相似,使得模型能够生成全局一致且细节丰富的高分辨率图像。

1.3 多尺度架构设计

NVAE 采用了多尺度架构,在处理图像时在不同层次使用不同的分辨率。编码器逐步降低输入图像的分辨率,同时增加通道数;而解码器则执行相反的过程,逐步上采样并减少通道数。这种设计使得计算更加高效,同时保持了模型对细节和全局结构的表现能力。

2. 残差单元与可分离卷积

NVAE 的基础构建模块是专门设计的残差单元 (Residual Cell),这些单元格在编码器和解码器中都有使用。每个残差单元包含批量归一化 (Batch Normalization, BN)、Swish 激活函数和深度可分离卷积 (Depth-wise Separable Convolutions) 等组件。这种设计不仅保证了数值稳定性,还显著减少了参数量。NVAE 中所用残差单元如下图所示。

残差单元

深度可分离卷积是 NVAE 中的关键技术创新之一。与标准卷积相比,深度可分离卷积将卷积操作分解为两个步骤:深度卷积(对每个输入通道单独进行空间卷积)和逐点卷积( 1×1 卷积,用于组合通道信息)。这种分解大幅减少了计算复杂度和参数数量,使模型能够快速扩大感受野而不受计算资源的限制。
NVAE 中还采用了挤压和激励 (Squeeze-and-Excitation, SE) 模块来增强模型的表示能力。SE 模块通过自适应地重新校准通道特征响应,使模型能够关注最信息丰富的特征。这一机制与残差单元格的结合进一步提升了模型对重要特征的敏感性。

3. 残差参数化与后验分布

NVAE 提出了残差参数化 (residual parameterization) 方法来改进近似后验分布的表达能力。在传统 VAE 中,近似后验分布通常被假设为对角协方差高斯分布,这种假设限制了模型的表达能力。NVAE 通过残差参数化放松了这一限制,允许更灵活的后验分布形式。
具体而言,对于每个层次的潜在变量,其均值和方差不是直接从网络输出计算得到,而是基于先前层次的残差更新来计算。这种方法确保了潜在变量之间的依赖性,同时保持了计算的可处理性。数学上,这种参数化方式使得 KL 散度项可以解析计算,避免了复杂的近似方法。
残差参数化还与条件先验 (conditional prior) 的概念紧密结合。在生成过程中,每个层次的先验分布不仅依赖于先前潜在变量,还依赖于自上而下的网络传递的信息。这种设计使先验分布更加丰富和表达力强,有助于生成更高质量的样本。

4. 使用 PyTorch 构建 NVAE

接下来,我们将使用 Celeb A 人脸图像数据集训练 NVAE

4.1 数据集加载

(1) 首先,导入所需库:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torchvision.transforms as transforms
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader, Dataset
import torchvision.utils as vutils
import numpy as np
from PIL import Image
from glob import glob
import torchvision

(2) 定义图像预处理变换:

image_size = 64
transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 将图像归一化到[-1, 1]范围

(3) 创建数据集和数据加载器:

batch_size = 32
ds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8)

4.2 模型构建与训练

(1) 实现残差单元,用于构建 NVAE的编码器和解码器,使用深度可分离卷积提高效率:

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()# 第一个卷积层self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二个卷积层self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 快捷连接self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))# 应用谱归一化以稳定训练self.conv1 = spectral_norm(self.conv1)self.conv2 = spectral_norm(self.conv2)if len(self.shortcut) > 0:self.shortcut[0] = spectral_norm(self.shortcut[0])def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = F.relu(out)return out

(2) 定义 NVAE 编码器块,包含多个残差单元和下采样操作,每个编码器块处理特定分辨率的特征:

class EncoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(EncoderBlock, self).__init__()self.blocks = nn.ModuleList()# 第一个块进行下采样self.blocks.append(ResidualBlock(in_channels, out_channels, stride=stride))# 添加额外的残差块(不下采样)for _ in range(1, num_blocks):self.blocks.append(ResidualBlock(out_channels, out_channels, stride=1))def forward(self, x):for block in self.blocks:x = block(x)return x

(3) 定义 NVAE 解码器块,包含多个残差块和上采样操作,每个解码器块重建特定分辨率的特征:

class DecoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(DecoderBlock, self).__init__()self.blocks = nn.ModuleList()# 添加残差块for _ in range(num_blocks - 1):self.blocks.append(ResidualBlock(in_channels, in_channels, stride=1))# 最后一个块进行上采样if stride > 1:self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU())else:self.upsample = nn.Identity()def forward(self, x):for block in self.blocks:x = block(x)x = self.upsample(x)return x

(4) 构建完整的 NVAE 模型,包含分层编码器和解码器,采用多尺度潜空间结构:

class NVAE(nn.Module):def __init__(self, image_channels=3, latent_dim=128, num_layers=4):super(NVAE, self).__init__()self.latent_dim = latent_dimself.num_layers = num_layers# 初始卷积层self.initial_conv = nn.Sequential(nn.Conv2d(image_channels, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU())# 编码器层级self.enc_blocks = nn.ModuleList()self.enc_blocks.append(EncoderBlock(32, 64, num_blocks=2, stride=2))  # 64x64 -> 32x32self.enc_blocks.append(EncoderBlock(64, 128, num_blocks=2, stride=2)) # 32x32 -> 16x16self.enc_blocks.append(EncoderBlock(128, 256, num_blocks=2, stride=2)) # 16x16 -> 8x8self.enc_blocks.append(EncoderBlock(256, 512, num_blocks=2, stride=2)) # 8x8 -> 4x4# 潜在空间均值和对数方差预测self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)# 解码器层级self.dec_blocks = nn.ModuleList()self.dec_blocks.append(nn.Sequential(nn.Linear(latent_dim, 512 * 4 * 4),nn.ReLU()))self.dec_blocks.append(DecoderBlock(512, 256, num_blocks=2, stride=2))  # 4x4 -> 8x8self.dec_blocks.append(DecoderBlock(256, 128, num_blocks=2, stride=2))  # 8x8 -> 16x16self.dec_blocks.append(DecoderBlock(128, 64, num_blocks=2, stride=2))   # 16x16 -> 32x32self.dec_blocks.append(DecoderBlock(64, 32, num_blocks=2, stride=2))    # 32x32 -> 64x64# 最终重建层self.final_conv = nn.Sequential(nn.Conv2d(32, image_channels, kernel_size=3, stride=1, padding=1),nn.Tanh()  # 输出范围[-1, 1],与输入一致)def encode(self, x):# 编码输入图像,返回潜分布的均值和方差x = self.initial_conv(x)for block in self.enc_blocks:x = block(x)x = x.view(x.size(0), -1)  # 展平mu = self.fc_mu(x)logvar = self.fc_logvar(x)return mu, logvardef reparameterize(self, mu, logvar):# 重参数化技巧,从潜分布中采样std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):# 从潜变量解码重建图像x = self.dec_blocks[0](z)x = x.view(-1, 512, 4, 4)  # 重塑为特征图for i in range(1, len(self.dec_blocks)):x = self.dec_blocks[i](x)x = self.final_conv(x)return xdef forward(self, x):# 前向传播:编码输入图像,采样潜变量,然后解码重建mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)recon_x = self.decode(z)return recon_x, mu, logvardef sample(self, num_samples, device):# 从潜空间采样生成新图像z = torch.randn(num_samples, self.latent_dim).to(device)samples = self.decode(z)return samples

(5) 定义损失函数,结合重建损失和 KL 散度

def loss_function(recon_x, x, mu, logvar, beta=1.0):# 重建损失(使用均方误差)recon_loss = F.mse_loss(recon_x, x, reduction='sum')# KL散度损失kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())# 总损失total_loss = recon_loss + beta * kld_lossreturn total_loss, recon_loss, kld_loss

(6) 定义设备,并实例化模型和优化器:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")model = NVAE(image_channels=3, latent_dim=256, num_layers=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

(7) 定义模型训练和验证函数:

def train(model, dataloader, optimizer, epoch, device, beta=1.0):model.train()train_loss = 0recon_loss = 0kld_loss = 0for batch_idx, (data, _) in enumerate(dataloader):data = data.to(device)optimizer.zero_grad()# 前向传播recon_batch, mu, logvar = model(data)# 计算损失loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)# 反向传播loss.backward()# 梯度裁剪(防止梯度爆炸)torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()train_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()if batch_idx % 100 == 0:print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(dataloader.dataset)} 'f'({100. * batch_idx / len(dataloader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')# 计算平均损失avg_loss = train_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> Epoch: {epoch} 平均损失: {avg_loss:.4f} 'f'平均重建损失: {avg_recon:.4f} 平均KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_klddef validate(model, dataloader, device, beta=1.0):model.eval()val_loss = 0recon_loss = 0kld_loss = 0with torch.no_grad():for i, (data, _) in enumerate(dataloader):data = data.to(device)recon_batch, mu, logvar = model(data)loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)val_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()avg_loss = val_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> 验证集损失: {avg_loss:.4f} 'f'重建损失: {avg_recon:.4f} KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_kld

(8) 训练模型:

# 开始训练
num_epochs = 500
beta = 0.1  # KL散度的权重系数train_losses = []
val_losses = []for epoch in range(1, num_epochs + 1):train_loss, train_recon, train_kld = train(model, dataloader, optimizer, epoch, device, beta)val_loss, val_recon, val_kld = validate(model, dataloader, device, beta)train_losses.append(train_loss)val_losses.append(val_loss)# 更新学习率scheduler.step()# 每10个epoch保存一次模型和样本if epoch % 2 == 0:torch.save(model.state_dict(), f'nvae_celeba_epoch_{epoch}.pth')# 生成并保存样本with torch.no_grad():sample = torch.randn(16, 256).to(device)sample = model.decode(sample).cpu()vutils.save_image(sample, f'sample_epoch_{epoch}.png', nrow=4, normalize=True)# 保存重建示例test_iter = iter(dataloader)test_data, _ = next(test_iter)test_data = test_data.to(device)with torch.no_grad():recon_data, _, _ = model(test_data)comparison = torch.cat([test_data[:8], recon_data[:8]]).cpu()vutils.save_image(comparison, f'reconstruction_epoch_{epoch}.png', nrow=8, normalize=True)# 保存最终模型
torch.save(model.state_dict(), 'nvae_celeba_final.pth')

模型训练过程,模型重建效果如下所示,可以看到重建效果随着训练的进行不断得到改进:

模型训练过程

接下来,查看模型训练完成后的生成图像:

生成图像


文章转载自:

http://3hrr92H6.dskmq.cn
http://RuEaHdcg.dskmq.cn
http://L1BWNhdE.dskmq.cn
http://Ov7Imwca.dskmq.cn
http://r3xFJqtt.dskmq.cn
http://Fp8EwLuY.dskmq.cn
http://f7EzM8jh.dskmq.cn
http://Y1uy6CLg.dskmq.cn
http://Zd8km7sH.dskmq.cn
http://t3o33wjd.dskmq.cn
http://PONT9Rvb.dskmq.cn
http://wOyeCEun.dskmq.cn
http://RpTE5Rcb.dskmq.cn
http://VpBCNoTK.dskmq.cn
http://woNVTHFe.dskmq.cn
http://Cptl2x0m.dskmq.cn
http://CfeZ0GyN.dskmq.cn
http://Akbx7iW8.dskmq.cn
http://Y4wy8K86.dskmq.cn
http://SjUQhLdd.dskmq.cn
http://3bcXrBC0.dskmq.cn
http://QH28ZWRN.dskmq.cn
http://0McFg5KQ.dskmq.cn
http://Ew99vrYg.dskmq.cn
http://pslUhVSU.dskmq.cn
http://VfYwszbJ.dskmq.cn
http://MtWgalls.dskmq.cn
http://RgD1KfL0.dskmq.cn
http://rcs82c9u.dskmq.cn
http://HA8crCxr.dskmq.cn
http://www.dtcms.com/a/368660.html

相关文章:

  • Windows多开文件夹太乱?Q-Dir四窗口同屏,拖拽文件快一倍
  • 测试驱动开发 (TDD) 与 Claude Code 的协作实践详解
  • Bug 排查日记:打造高效问题定位与解决的技术秘籍
  • MySQL InnoDB索引机制
  • Nextcloud 实战:打造属于你的私有云与在线协作平台
  • linux上nexus安装教程
  • vosk语音识别实战
  • 美团发布 | LongCat-Flash最全解读,硬刚GPT-4.1、Kimi!
  • 七彩喜微高压氧舱:科技与体验的双重革新,重新定义家用氧疗新标杆
  • Gemini-2.5-Flash-Image-Preview 与 GPT-4o 图像生成能力技术差异解析​
  • 敏捷开发-Scrum(上)
  • 超越自动化:为什么说供应链的终局是“AI + 人类专家”的混合智能?
  • 一维水动力模型有限体积法(三):戈杜诺夫框架与近似黎曼求解器大全
  • 2025年互联网行业高含金量证书盘点!
  • 数据库存储大量的json文件怎么样高效的读取和分页,利用文件缓存办法不占用内存
  • springboot redis 缓存入门与实战
  • 在 vue-vben-admin(v5 版本)中,使用 ECharts 图表(豆包版)
  • 数码视讯TR100-OTT-G1_国科GK6323_安卓9_广东联通原机修改-TTL烧录包-可救砖
  • RWA 技术:让实体消费积分变身可信数字资产
  • 蚂蚁 S21 XP+ HYD 500T矿机评测:SHA-256算法与高效冷却技术的结合
  • DAY1:错题日记
  • 直播美颜SDK的技术架构剖析:人脸美型功能的实现原理与优化策略
  • Kafka 消息队列:揭秘海量数据流动的技术心脏
  • 2025 年高教社杯全国大学生数学建模竞赛C 题 NIPT 的时点选择与胎儿的异常判定详解(一)
  • 当低代码遇上AI,有趣,实在有趣
  • 从“找新家”到“走向全球”,布尔云携手涂鸦智能开启机器人新冒险
  • 低代码核心原理总结
  • rust语言 (1.88) egui (0.32.1) 学习笔记(逐行注释)(二十五)窗口图标 / 任务栏图标
  • 安科瑞基站智慧运维云平台:安全管控与节能降耗双效赋能
  • BYOFF(自定义格式函数)(79)