当前位置: 首页 > news >正文

PyTorch生成式人工智能——深度分层变分自编码器(NVAE)详解与实现

PyTorch生成式人工智能——深度分层变分自编码器(NVAE)详解与实现

    • 0. 前言
    • 1. NVAE 技术原理
      • 1.1 变分自编码器基础
      • 1.2 深度分层架构
      • 1.3 多尺度架构设计
    • 2. 残差单元与可分离卷积
    • 3. 残差参数化与后验分布
    • 4. 使用 PyTorch 构建 NVAE
      • 4.1 数据集加载
      • 4.2 模型构建与训练
    • 相关链接

0. 前言

变分自编码器 (Variational Autoencoder, VAE) 作为深度学习生成模型的重要分支,具有独特的优势,与生成对抗网络 (Generative Adversarial Network, GAN) 和自回归模型相比,VAE 具有采样速度快、计算可处理性强以及编码网络易于访问等优势。然而,传统的 VAE 模型在生成质量上往往落后于其他先进生成模型,尤其是在处理高分辨率自然图像时表现不佳。为了应对这一挑战,深度分层变分自编码器 (Nouveau VAE, NVAE) 通过神经架构设计的创新,推动了 VAE 性能的提升。

1. NVAE 技术原理

1.1 变分自编码器基础

传统变分自编码器 (Variational Autoencoder, VAE) 由编码器和解码器组成。编码器将输入数据 xxx 映射到潜空间的后验分布 q(z∣x)q(z|x)q(zx),解码器从潜在变量 zzz 重建数据 xxxVA E的训练目标是最大化证据下界 (Evidence Lower Bound, ELBO):
log⁡p(x)≥Eqϕ(z∣x)[log⁡pθ(x∣z)]−DKL(qϕ(z∣x)∣∣p(z))\text {log}⁡p(x)≥\mathbb E_{q_ϕ(z|x)}[\text {log}⁡p_θ(x|z)]−D_{KL}(q_ϕ(z|x)||p(z)) logp(x)Eqϕ(zx)[logpθ(xz)]DKL(qϕ(zx)∣∣p(z))
其中第一项是重建损失,第二项是潜空间分布与先验分布的 KL 散度

1.2 深度分层架构

NVAE 采用深度分层架构,将潜变量分为 LLL 组:z=z1,z2,...,zLz = {z_1, z_2, ..., z_L}z=z1,z2,...,zL,其中 z1z_1z1 是最底层(最抽象)的变量,zLz_LzL 是最高层(最接近输入)的变量,形成了层次化的潜表示。这种设计使得先验和后验分布都变成了联合分布,能够在不同层次上捕获数据的抽象特征:
pθ(x,z1:L)=pθ(x∣z1:L)∏i=1Lpθ(zi∣zi+1:L)qφ(z1:L∣x)=∏i=1Lqφ(zi∣zi+1:L,x)p_θ(x, z_{1:L}) = p_θ(x|z_{1:L})∏_{i=1}^L p_θ(z_i|z_{i+1:L})\\ q_φ(z_{1:L}|x) = ∏_{i=1}^L q_φ(z_i|z_{i+1:L}, x) pθ(x,z1:L)=pθ(xz1:L)i=1Lpθ(zizi+1:L)qφ(z1:Lx)=i=1Lqφ(zizi+1:L,x)
这种分层设计允许模型在多个分辨率级别上处理输入数据,较低层次的组捕获细节信息,而较高层次的组捕获语义级别的抽象信息。这与人类视觉系统的层次化处理方式相似,使得模型能够生成全局一致且细节丰富的高分辨率图像。

1.3 多尺度架构设计

NVAE 采用了多尺度架构,在处理图像时在不同层次使用不同的分辨率。编码器逐步降低输入图像的分辨率,同时增加通道数;而解码器则执行相反的过程,逐步上采样并减少通道数。这种设计使得计算更加高效,同时保持了模型对细节和全局结构的表现能力。

2. 残差单元与可分离卷积

NVAE 的基础构建模块是专门设计的残差单元 (Residual Cell),这些单元格在编码器和解码器中都有使用。每个残差单元包含批量归一化 (Batch Normalization, BN)、Swish 激活函数和深度可分离卷积 (Depth-wise Separable Convolutions) 等组件。这种设计不仅保证了数值稳定性,还显著减少了参数量。NVAE 中所用残差单元如下图所示。

残差单元

深度可分离卷积是 NVAE 中的关键技术创新之一。与标准卷积相比,深度可分离卷积将卷积操作分解为两个步骤:深度卷积(对每个输入通道单独进行空间卷积)和逐点卷积( 1×1 卷积,用于组合通道信息)。这种分解大幅减少了计算复杂度和参数数量,使模型能够快速扩大感受野而不受计算资源的限制。
NVAE 中还采用了挤压和激励 (Squeeze-and-Excitation, SE) 模块来增强模型的表示能力。SE 模块通过自适应地重新校准通道特征响应,使模型能够关注最信息丰富的特征。这一机制与残差单元格的结合进一步提升了模型对重要特征的敏感性。

3. 残差参数化与后验分布

NVAE 提出了残差参数化 (residual parameterization) 方法来改进近似后验分布的表达能力。在传统 VAE 中,近似后验分布通常被假设为对角协方差高斯分布,这种假设限制了模型的表达能力。NVAE 通过残差参数化放松了这一限制,允许更灵活的后验分布形式。
具体而言,对于每个层次的潜在变量,其均值和方差不是直接从网络输出计算得到,而是基于先前层次的残差更新来计算。这种方法确保了潜在变量之间的依赖性,同时保持了计算的可处理性。数学上,这种参数化方式使得 KL 散度项可以解析计算,避免了复杂的近似方法。
残差参数化还与条件先验 (conditional prior) 的概念紧密结合。在生成过程中,每个层次的先验分布不仅依赖于先前潜在变量,还依赖于自上而下的网络传递的信息。这种设计使先验分布更加丰富和表达力强,有助于生成更高质量的样本。

4. 使用 PyTorch 构建 NVAE

接下来,我们将使用 Celeb A 人脸图像数据集构建 NVAE

4.1 数据集加载

(1) 首先,导入所需库:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torchvision.transforms as transforms
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader, Dataset
import torchvision.utils as vutils
import numpy as np
from PIL import Image
from glob import glob
import torchvision

(2) 定义图像预处理变换:

image_size = 64
transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 将图像归一化到[-1, 1]范围

(3) 创建数据集和数据加载器:

batch_size = 32
ds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8)

4.2 模型构建与训练

(1) 实现残差单元,用于构建 NVAE的编码器和解码器,使用深度可分离卷积提高效率:

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()# 第一个卷积层self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二个卷积层self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 快捷连接self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))# 应用谱归一化以稳定训练self.conv1 = spectral_norm(self.conv1)self.conv2 = spectral_norm(self.conv2)if len(self.shortcut) > 0:self.shortcut[0] = spectral_norm(self.shortcut[0])def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = F.relu(out)return out

(2) 定义 NVAE 编码器块,包含多个残差单元和下采样操作,每个编码器块处理特定分辨率的特征:

class EncoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(EncoderBlock, self).__init__()self.blocks = nn.ModuleList()# 第一个块进行下采样self.blocks.append(ResidualBlock(in_channels, out_channels, stride=stride))# 添加额外的残差块(不下采样)for _ in range(1, num_blocks):self.blocks.append(ResidualBlock(out_channels, out_channels, stride=1))def forward(self, x):for block in self.blocks:x = block(x)return x

(3) 定义 NVAE 解码器块,包含多个残差块和上采样操作,每个解码器块重建特定分辨率的特征:

class DecoderBlock(nn.Module):def __init__(self, in_channels, out_channels, num_blocks, stride=2):super(DecoderBlock, self).__init__()self.blocks = nn.ModuleList()# 添加残差块for _ in range(num_blocks - 1):self.blocks.append(ResidualBlock(in_channels, in_channels, stride=1))# 最后一个块进行上采样if stride > 1:self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU())else:self.upsample = nn.Identity()def forward(self, x):for block in self.blocks:x = block(x)x = self.upsample(x)return x

(4) 构建完整的 NVAE 模型,包含分层编码器和解码器,采用多尺度潜空间结构:

class NVAE(nn.Module):def __init__(self, image_channels=3, latent_dim=128, num_layers=4):super(NVAE, self).__init__()self.latent_dim = latent_dimself.num_layers = num_layers# 初始卷积层self.initial_conv = nn.Sequential(nn.Conv2d(image_channels, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU())# 编码器层级self.enc_blocks = nn.ModuleList()self.enc_blocks.append(EncoderBlock(32, 64, num_blocks=2, stride=2))  # 64x64 -> 32x32self.enc_blocks.append(EncoderBlock(64, 128, num_blocks=2, stride=2)) # 32x32 -> 16x16self.enc_blocks.append(EncoderBlock(128, 256, num_blocks=2, stride=2)) # 16x16 -> 8x8self.enc_blocks.append(EncoderBlock(256, 512, num_blocks=2, stride=2)) # 8x8 -> 4x4# 潜在空间均值和对数方差预测self.fc_mu = nn.Linear(512 * 4 * 4, latent_dim)self.fc_logvar = nn.Linear(512 * 4 * 4, latent_dim)# 解码器层级self.dec_blocks = nn.ModuleList()self.dec_blocks.append(nn.Sequential(nn.Linear(latent_dim, 512 * 4 * 4),nn.ReLU()))self.dec_blocks.append(DecoderBlock(512, 256, num_blocks=2, stride=2))  # 4x4 -> 8x8self.dec_blocks.append(DecoderBlock(256, 128, num_blocks=2, stride=2))  # 8x8 -> 16x16self.dec_blocks.append(DecoderBlock(128, 64, num_blocks=2, stride=2))   # 16x16 -> 32x32self.dec_blocks.append(DecoderBlock(64, 32, num_blocks=2, stride=2))    # 32x32 -> 64x64# 最终重建层self.final_conv = nn.Sequential(nn.Conv2d(32, image_channels, kernel_size=3, stride=1, padding=1),nn.Tanh()  # 输出范围[-1, 1],与输入一致)def encode(self, x):# 编码输入图像,返回潜分布的均值和方差x = self.initial_conv(x)for block in self.enc_blocks:x = block(x)x = x.view(x.size(0), -1)  # 展平mu = self.fc_mu(x)logvar = self.fc_logvar(x)return mu, logvardef reparameterize(self, mu, logvar):# 重参数化技巧,从潜分布中采样std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):# 从潜变量解码重建图像x = self.dec_blocks[0](z)x = x.view(-1, 512, 4, 4)  # 重塑为特征图for i in range(1, len(self.dec_blocks)):x = self.dec_blocks[i](x)x = self.final_conv(x)return xdef forward(self, x):# 前向传播:编码输入图像,采样潜变量,然后解码重建mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)recon_x = self.decode(z)return recon_x, mu, logvardef sample(self, num_samples, device):# 从潜空间采样生成新图像z = torch.randn(num_samples, self.latent_dim).to(device)samples = self.decode(z)return samples

(5) 定义损失函数,结合重建损失和 KL 散度

def loss_function(recon_x, x, mu, logvar, beta=1.0):# 重建损失(使用均方误差)recon_loss = F.mse_loss(recon_x, x, reduction='sum')# KL散度损失kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())# 总损失total_loss = recon_loss + beta * kld_lossreturn total_loss, recon_loss, kld_loss

(6) 定义设备,并实例化模型和优化器:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")model = NVAE(image_channels=3, latent_dim=256, num_layers=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

(7) 定义模型训练和验证函数:

def train(model, dataloader, optimizer, epoch, device, beta=1.0):model.train()train_loss = 0recon_loss = 0kld_loss = 0for batch_idx, (data, _) in enumerate(dataloader):data = data.to(device)optimizer.zero_grad()# 前向传播recon_batch, mu, logvar = model(data)# 计算损失loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)# 反向传播loss.backward()# 梯度裁剪(防止梯度爆炸)torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()train_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()if batch_idx % 100 == 0:print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(dataloader.dataset)} 'f'({100. * batch_idx / len(dataloader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')# 计算平均损失avg_loss = train_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> Epoch: {epoch} 平均损失: {avg_loss:.4f} 'f'平均重建损失: {avg_recon:.4f} 平均KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_klddef validate(model, dataloader, device, beta=1.0):model.eval()val_loss = 0recon_loss = 0kld_loss = 0with torch.no_grad():for i, (data, _) in enumerate(dataloader):data = data.to(device)recon_batch, mu, logvar = model(data)loss, r_loss, k_loss = loss_function(recon_batch, data, mu, logvar, beta)val_loss += loss.item()recon_loss += r_loss.item()kld_loss += k_loss.item()avg_loss = val_loss / len(dataloader.dataset)avg_recon = recon_loss / len(dataloader.dataset)avg_kld = kld_loss / len(dataloader.dataset)print(f'====> 验证集损失: {avg_loss:.4f} 'f'重建损失: {avg_recon:.4f} KL损失: {avg_kld:.4f}')return avg_loss, avg_recon, avg_kld

(8) 训练模型:

# 开始训练
num_epochs = 500
beta = 0.1  # KL散度的权重系数train_losses = []
val_losses = []for epoch in range(1, num_epochs + 1):train_loss, train_recon, train_kld = train(model, dataloader, optimizer, epoch, device, beta)val_loss, val_recon, val_kld = validate(model, dataloader, device, beta)train_losses.append(train_loss)val_losses.append(val_loss)# 更新学习率scheduler.step()# 每10个epoch保存一次模型和样本if epoch % 2 == 0:torch.save(model.state_dict(), f'nvae_celeba_epoch_{epoch}.pth')# 生成并保存样本with torch.no_grad():sample = torch.randn(16, 256).to(device)sample = model.decode(sample).cpu()vutils.save_image(sample, f'sample_epoch_{epoch}.png', nrow=4, normalize=True)# 保存重建示例test_iter = iter(dataloader)test_data, _ = next(test_iter)test_data = test_data.to(device)with torch.no_grad():recon_data, _, _ = model(test_data)comparison = torch.cat([test_data[:8], recon_data[:8]]).cpu()vutils.save_image(comparison, f'reconstruction_epoch_{epoch}.png', nrow=8, normalize=True)# 保存最终模型
torch.save(model.state_dict(), 'nvae_celeba_final.pth')

模型训练过程,模型重建效果如下所示,可以看到重建效果随着训练的进行不断得到改进:

模型训练过程

接下来,查看模型训练完成后的生成图像:

生成图像

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
PyTorch生成式人工智能(25)——基于Transformer实现机器翻译
PyTorch生成式人工智能——VQ-VAE详解与实现


文章转载自:

http://TQiP6nxD.pdmmL.cn
http://lOzcxK7m.pdmmL.cn
http://bOaKxukN.pdmmL.cn
http://OoPvgf1I.pdmmL.cn
http://e7D6BUI1.pdmmL.cn
http://bTND32X6.pdmmL.cn
http://oHWrNve4.pdmmL.cn
http://8kVdHQ5m.pdmmL.cn
http://LvIHuP7q.pdmmL.cn
http://Xq4Xc6DY.pdmmL.cn
http://dPuw4wbs.pdmmL.cn
http://uQWvWgYO.pdmmL.cn
http://1RCfrkZk.pdmmL.cn
http://3AV5Yq37.pdmmL.cn
http://Kk6bXVZK.pdmmL.cn
http://FJGHgjpK.pdmmL.cn
http://l7nsSOZh.pdmmL.cn
http://klErUdPk.pdmmL.cn
http://Buiy8NK1.pdmmL.cn
http://4Yl4KBcH.pdmmL.cn
http://fclYKVRW.pdmmL.cn
http://tYYSmuaA.pdmmL.cn
http://tI0Cbndi.pdmmL.cn
http://65xwsgvv.pdmmL.cn
http://gTBzDDaG.pdmmL.cn
http://QS9ez45M.pdmmL.cn
http://TpauqUu3.pdmmL.cn
http://phc6Vb8m.pdmmL.cn
http://fZDlBL35.pdmmL.cn
http://kbL7RWa2.pdmmL.cn
http://www.dtcms.com/a/370053.html

相关文章:

  • Whismer-你的定制化AI问答助手
  • Paimon——官网阅读:配置
  • FPGA会用到UVM吗?
  • 电脑外接显示屏字体和图标过大
  • 深入浅出 HarmonyOS ArkUI 3.0:基于声明式开发范式与高级状态管理构建高性能应用
  • 如何在路由器上配置DHCP服务器?
  • 计算机网络:网络设备在OSI七层模型中的工作层次和传输协议
  • Unity 如何使用ModbusTCP 和PLC通讯
  • Ribbon和LoadBalance-负载均衡
  • 性能监控shell脚本编写
  • 基于SpringBoot和uni-app开发的陪诊陪护软件系统源码
  • 记一次uniapp+nutui-uniapp搭建项目
  • 计算机网络:物理层---物理层的基本概念
  • 【Java】抽象类和接口对比+详解
  • 校园管理系统|基于SpringBoot和Vue的校园管理系统(源码+数据库+文档)
  • LeetCode5最长回文子串
  • Coze源码分析-资源库-编辑提示词-前端源码
  • 《sklearn机器学习——聚类性能指标》Contingency Matrix(列联表)详解
  • 小米笔记本电脑重装C盘教程
  • Linux RCU (Read-Copy-Update) 机制深度分析
  • 贪心算法应用:柔性制造系统(FMS)刀具分配问题详解
  • WSL Ubuntu Docker 代理自动配置教程
  • 基于Scikit-learn集成学习模型的情感分析研究与实现
  • MySQL数据库精研之旅第十七期:深度拆解事务核心(下)
  • Scikit-learn Python机器学习 - 特征降维 压缩数据 - 特征选择 - 单变量特征选择 SelectKBest - 选择Top K个特征
  • 从挑西瓜到树回归:用生活智慧理解机器学习算法
  • LabVIEW无线预警喷淋系统
  • Redis 的三种高效缓存读写策略!
  • 安装MATLAB205软件记录
  • Day28 打卡