当前位置: 首页 > news >正文

生成模型实战 | β-VAE详解与实现

生成模型实战 | β-VAE详解与实现

    • 0. 前言
    • 1. β-VAE 原理
      • 1.1 VAE 回顾
      • 1.2 β-VAE 核心公式
    • 2. 实现 β-VAE
    • 3. 结果分析

0. 前言

在 InfoGAN 一节中,我们讨论了潜编码解耦表征的概念及其重要性。解耦表征是指单个潜编码对单一生成因子的变化敏感,而对其他因子的变化保持相对不变。调整某个潜编码仅会改变生成输出的某一属性,而其他特性保持不变。使用 InfoGANMNIST 数据集中,可以同时控制生成数字的类别以及书写风格的倾斜度和笔画粗细。本节将通过对 VAE 损失函数进行简单修正,实现 β-VAE,可强制潜编码实现解耦。

1. β-VAE 原理

1.1 VAE 回顾

变分自编码器 (Variational Autoencoder, VAE) 通过编码器学习后验分布 qϕ(z∣x)q_\phi(z|x)qϕ(zx),解码器学习似然 pθ(x∣z)p_\theta(x|z)pθ(xz)。目标是最大化证据下界 (evidence lower bound, ELBO):

logPθ(x)−DKL(Qϕ(z∣x)∣∣Pθ(z∣x))=Ez∼Q[logPθ(x∣z)]−DKL(Qϕ(z∣x)∣∣Pθ(z))logP_\theta (x)-D_{KL}(Q_\phi (z|x)||P_\theta (z|x))=\mathbb E_{z\sim Q}[logP_\theta (x|z)]-D_{KL}(Q_\phi (z|x)||P_\theta (z)) logPθ(x)DKL(Qϕ(zx)∣∣Pθ(zx))=EzQ[logPθ(xz)]DKL(Qϕ(zx)∣∣Pθ(z))

此方程为 VAE 的核心。左侧项为待最大化的 Pθ(x)P_\theta (x)Pθ(x) 减去 Qϕ(z∣x)Q_\phi (z|x)Qϕ(zx) 与真实 Pθ(z∣x)P_\theta (z|x)Pθ(zx) 的分布误差。需注意对数运算不改变极值位置。当推理模型能准确估计 Pθ(z∣x)P_\theta (z|x)Pθ(zx) 时,DKL(Qϕ(z∣x)∣∣Pθ(z∣x))D_{KL}(Q_\phi (z|x)||P_\theta (z|x))DKL(Qϕ(zx)∣∣Pθ(zx)) 近似为零。
右侧第一项 Pθ(x∣z)P_\theta (x|z)Pθ(xz) 对应解码器,它从推理模型采样以重构输入;第二项是 Qϕ(z∣x)Q_\phi (z|x)Qϕ(zx) 与先验 Pθ(z)P_\theta (z)Pθ(z) 的距离。以上公式左侧称为变分下界 (variational lower bound) 或证据下界 (evidence lower bound, ELBO)。由于 KL 散度恒正,ELBO 构成 log(Pθ(x))log(P_\theta (x))log(Pθ(x)) 的下界。通过优化神经网络参数 ϕ\phiϕθ\thetaθ 来最大化 ELBO 意味着:

  • DKL(Qϕ(z∣x)∣∣Pθ(z∣x))→0D_{KL}(Q_\phi (z|x)||P_\theta (z|x))\rightarrow 0DKL(Qϕ(zx)∣∣Pθ(zx))0,即推理模型能更精准地将属性 xxx 编码至 zzz
  • 右侧的 logPθ(x∣z)logP_{\theta}(x|z)logPθ(xz) 被最大化,即解码器能更准确地从潜在向量 zzz 重构 xxx

1.2 β-VAE 核心公式

通过对变分自编码器 (Variational Autoencoder, VAE) 损失函数进行简单修正,可强制潜编码实现进一步解耦。具体方法是在 KL 损失项前添加大于 1 的正则化权重系数 β\betaβ
Lβ−VAE=LR+βLKL\mathcal L_{\beta-VAE} = \mathcal L_R + \beta\mathcal L_{KL} LβVAE=LR+βLKL
VAE 变体称为 β-VAEβ 的隐含效应是约束更严格的标准差,即强制后验分布 Qϕ(z∣x)Q_{\phi}(z|x)Qϕ(zx) 中的潜编码保持相互独立。当 β>1β>1β>1 时,模型更偏向压缩隐空间,但可能牺牲重构精度;β<1β<1β<1 则更注重重构。

2. 实现 β-VAE

实现 β-VAE 非常简单,只需在 kl_loss 中引入额外的 beta 因子。条件变分自编码器实际上是 β=1\beta=1β=1 时的特殊 β-VAE,其他组件保持不变。但确定 βββ 值需要反复试验,必须在重构误差与潜在代码独立性正则化之间谨慎权衡。当 β≈9\beta \approx9β9 时解耦效果达到最优;当 β>9\beta>9β>9 时,β-VAE 被迫仅学习一个解耦表征,同时抑制其他潜维度。

(1) 导入所需库、选择训练设备:

import os
import random
import numpy as np
import matplotlib.pyplot as pltimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utilsdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

(2) 设置训练参数,准备训练数据集(本节使用 MNIST 数据集),ToTensor() 将像素缩放到 [0,1],适合重构,num_classes 用于独热编码标签:

batch_size = 128
lr = 1e-3
epochs = 50
latent_dim = 2
beta = 4.0  # Beta weight for KL
num_classes = 10
img_size = 28
channels = 1
save_dir = './cvae_beta_checkpoints'
os.makedirs(save_dir, exist_ok=True)transform = transforms.Compose([transforms.ToTensor()])  # MNIST -> [0,1]train_ds = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_ds = datasets.MNIST(root='./data', train=False, transform=transform, download=True)train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

(3) 实现 β-VAE,编码器先用卷积提取空间特征并展平,再与独热标签拼接进入全连接层,产生 mulogvar (即 qϕ(z∣x,y)q_ϕ(z|x,y)qϕ(zx,y)),解码器把 z 与独热标签拼接,再经过线性层整形回特征图,使用反卷积恢复到图片尺寸,输出 logits (未经过 Sigmoid 激活函数,便于进行二元交叉熵损失),这种设计保证条件信息在编码与解码阶段都可用,从而更好地把标签和隐变量解耦:

class CVAEBeta(nn.Module):def __init__(self, latent_dim=20, num_classes=10):super().__init__()self.latent_dim = latent_dimself.num_classes = num_classes# Encoder: conv feature extractor -> flatten -> concat label -> fc -> mu/logvarself.enc_conv = nn.Sequential(nn.Conv2d(1, 32, 4, 2, 1),  # 32 x 14 x 14nn.ReLU(True),nn.Conv2d(32, 64, 4, 2, 1),  # 64 x 7 x 7nn.ReLU(True),nn.Conv2d(64, 128, 3, 2, 1),  # 128 x 4 x 4 (approx)nn.ReLU(True))self.enc_out_dim = 128 * 4 * 4# after flatten we'll concat label one-hotself.fc_enc = nn.Linear(self.enc_out_dim + num_classes, 512)self.fc_mu = nn.Linear(512, latent_dim)self.fc_logvar = nn.Linear(512, latent_dim)# Decoder: z concat label -> fc -> reshape -> convtranspose -> logitsself.fc_dec = nn.Linear(latent_dim + num_classes, self.enc_out_dim)self.dec_deconv = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, 1),  # upsamplenn.ReLU(True),nn.ConvTranspose2d(64, 32, 4, 2, 1),nn.ReLU(True),nn.ConvTranspose2d(32, 1, 3, 1, 1)  # output logits)def encode(self, x, y_onehot):# x: (B,1,28,28); y_onehot: (B, num_classes)h = self.enc_conv(x)  # (B, C, H, W)h = h.view(h.size(0), -1)  # (B, enc_out_dim)# concat labelh_cat = torch.cat([h, y_onehot], dim=1)h2 = F.relu(self.fc_enc(h_cat))mu = self.fc_mu(h2)logvar = self.fc_logvar(h2)return mu, logvardef reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z, y_onehot):# z: (B, latent_dim); y_onehot: (B, num_classes)inp = torch.cat([z, y_onehot], dim=1)h = self.fc_dec(inp)h = h.view(h.size(0), 128, 4, 4)logits = self.dec_deconv(h)logits = F.interpolate(logits, size=(28, 28), mode='bilinear', align_corners=False)return logitsdef forward(self, x, y_onehot):mu, logvar = self.encode(x, y_onehot)z = self.reparameterize(mu, logvar)logits = self.decode(z, y_onehot)return logits, mu, logvarmodel = CVAEBeta(latent_dim=latent_dim, num_classes=num_classes).to(device)
print(model)

(4) 定义损失函数与训练/验证步骤:

recon_criterion = nn.BCEWithLogitsLoss(reduction='sum')  # sum over pixelsoptimizer = torch.optim.Adam(model.parameters(), lr=lr)def loss_function(x_logits, x, mu, logvar, beta=1.0):# reconstruction term (sum over pixels and batch)recon = recon_criterion(x_logits, x)# KL divergence (sum over latent dims and batch)kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())# average per batch elementloss = (recon + beta * kld) / x.size(0)return loss, recon / x.size(0), kld / x.size(0)def train_epoch(model, loader, optimizer, device, beta=1.0):model.train()stats = {'loss':0.0, 'recon':0.0, 'kld':0.0}for imgs, labels in loader:imgs = imgs.to(device)labels = labels.to(device)# one-hot labelsy_onehot = F.one_hot(labels, num_classes=num_classes).float().to(device)optimizer.zero_grad()logits, mu, logvar = model(imgs, y_onehot)loss, recon, kld = loss_function(logits, imgs, mu, logvar, beta=beta)loss.backward()optimizer.step()bsize = imgs.size(0)stats['loss'] += loss.item() * bsizestats['recon'] += recon.item() * bsizestats['kld'] += kld.item() * bsizen = len(loader.dataset)for k in stats:stats[k] /= nreturn statsdef test_epoch(model, loader, device, beta=1.0):model.eval()stats = {'loss':0.0, 'recon':0.0, 'kld':0.0}with torch.no_grad():for imgs, labels in loader:imgs = imgs.to(device)labels = labels.to(device)y_onehot = F.one_hot(labels, num_classes=num_classes).float().to(device)logits, mu, logvar = model(imgs, y_onehot)loss, recon, kld = loss_function(logits, imgs, mu, logvar, beta=beta)bsize = imgs.size(0)stats['loss'] += loss.item() * bsizestats['recon'] += recon.item() * bsizestats['kld'] += kld.item() * bsizen = len(loader.dataset)for k in stats:stats[k] /= nreturn stats

(5) 定义训练主循环:

best_val = float('inf')
history = {'train': [], 'val': []}for epoch in range(1, epochs + 1):tr = train_epoch(model, train_loader, optimizer, device, beta=beta)va = test_epoch(model, test_loader, device, beta=beta)history['train'].append(tr)history['val'].append(va)print(f"Epoch {epoch}/{epochs} | train_loss {tr['loss']:.4f} recon {tr['recon']:.4f} kld {tr['kld']:.4f} | "f"val_loss {va['loss']:.4f} recon {va['recon']:.4f} kld {va['kld']:.4f}")if va['loss'] < best_val:best_val = va['loss']torch.save({'epoch': epoch,'model_state': model.state_dict(),'optimizer_state': optimizer.state_dict(),'latent_dim': latent_dim,'beta': beta}, os.path.join(save_dir, 'cvae_beta_best.pth'))

3. 结果分析

(1) 可视化潜空间,对测试集中的每张图像,通过编码器得到 mu,将 mu 的二维坐标画成散点图并按标签着色,观察不同类别在潜空间中的聚类情况:

import matplotlib.pyplot as pltdef visualize_latent_space(model, dataloader, device, num_points=None, save_path=None):model.eval()zs = []ys = []with torch.no_grad():for imgs, labels in dataloader:imgs = imgs.to(device)labels = labels.to(device)y_onehot = F.one_hot(labels, num_classes=num_classes).float().to(device)mu, logvar = model.encode(imgs, y_onehot)# use mu (deterministic embedding) for visualizationzs.append(mu.cpu().numpy())ys.append(labels.cpu().numpy())if num_points is not None and sum(len(z) for z in zs) >= num_points:breakzs = np.concatenate(zs, axis=0)ys = np.concatenate(ys, axis=0)if num_points is not None:zs = zs[:num_points]ys = ys[:num_points]plt.figure(figsize=(8,8))cmap = plt.get_cmap('tab10')  # 10 distinct colors for MNISTfor c in range(num_classes):mask = ys == cplt.scatter(zs[mask,0], zs[mask,1], s=6, color=cmap(c), label=str(c), alpha=0.8)plt.legend(title='label')plt.xlabel('z[0]')plt.ylabel('z[1]')plt.title('Latent space (mu) colored by label')if save_path:plt.savefig(save_path, dpi=200)print("Saved latent scatter to", save_path)plt.show()visualize_latent_space(model, test_loader, device=device, num_points=None, save_path=os.path.join(save_dir, 'latent_scatter.png'))

潜空间可视化

(2) 在二维空间上做小范围平移(例如沿 xy 轴),并把同一标签传给解码器,生成一系列图像,从而观察潜空间方向上语义变化而保持类别:

from torchvision.utils import save_imagedef perturb_and_generate(model, img, label, deltas, device, out_prefix='perturb'):model.eval()img = img.to(device)label = torch.tensor([label], dtype=torch.long).to(device)y_onehot = F.one_hot(label, num_classes=num_classes).float().to(device)with torch.no_grad():mu, logvar = model.encode(img, y_onehot)         # mu: (1,2)mu = mu.squeeze(0)  # (2,)all_imgs = []for i, d in enumerate(deltas):d = torch.tensor(d, dtype=torch.float32).to(device)z = (mu + d).unsqueeze(0)  # (1,2)logits = model.decode(z, y_onehot)gen = torch.sigmoid(logits).cpu()all_imgs.append(gen)grid = torch.cat(all_imgs, dim=0)  # (N,1,28,28)save_image(grid, os.path.join(save_dir, f'{out_prefix}_label{int(label.item())}.png'), nrow=len(all_imgs), normalize=False)print(f"Saved perturbation grid to {os.path.join(save_dir, f'{out_prefix}_label{int(label.item())}.png')}")imgs_iter = iter(test_loader)
imgs_all, labels_all = next(imgs_iter)
sample_img = imgs_all[0:1]   # (1,1,28,28)
sample_label = int(labels_all[0].item())
# 生成一系列小偏移(例如在 -2..2 范围内 9 个点沿 x 轴)
deltas = [[dx, 0.0] for dx in np.linspace(-2.0, 2.0, 9)]
perturb_and_generate(model, sample_img, sample_label, deltas, device, out_prefix='perturb_x')
# 同理沿 y 轴
deltas_y = [[0.0, dy] for dy in np.linspace(-2.0, 2.0, 9)]
perturb_and_generate(model, sample_img, sample_label, deltas_y, device, out_prefix='perturb_y')

生成结果

(3) 在二维平面上构建规则网格,对每个网格点解码(属于同一类别标签),观察该类在潜空间不同位置的生成效果:

def generate_grid_around_mu(model, img, label, grid_range=2.0, grid_size=9, device='cpu', out_name='grid'):model.eval()img = img.to(device)label = torch.tensor([label], dtype=torch.long, device=device)y_onehot = F.one_hot(label, num_classes=num_classes).float().to(device)with torch.no_grad():mu, _ = model.encode(img, y_onehot)  # (1,2)mu = mu.squeeze(0)  # (2,)xs = np.linspace(-grid_range, grid_range, grid_size)ys = np.linspace(-grid_range, grid_range, grid_size)imgs = []for yi in ys[::-1]:  # top row = large yrow = []for xi in xs:offset = torch.tensor([xi, yi], dtype=torch.float32, device=device)z = (mu + offset).unsqueeze(0)logits = model.decode(z, y_onehot)gen = torch.sigmoid(logits).cpu()row.append(gen)row = torch.cat(row, dim=0)imgs.append(row)grid_imgs = torch.cat(imgs, dim=0)save_image(grid_imgs, os.path.join(save_dir, f'{out_name}_label{int(label.item())}.png'),nrow=grid_size, normalize=False)print(f"Saved grid to {os.path.join(save_dir, f'{out_name}_label{int(label.item())}.png')}")generate_grid_around_mu(model, sample_img, sample_label, grid_range=2.0, grid_size=9, device=device, out_name='grid_centered')

生成结果

通过生成结果可以看出,当 β=4\beta=4β=4 时,β-VAE 的两个潜编码实际相互独立:一个控制笔迹的倾斜角度,而另一个则决定数字的宽度与圆润度。

http://www.dtcms.com/a/570290.html

相关文章:

  • 司马阅与众创集团达成生态战略合作,构建 “综合企业服务资源 + AI智能技术”的创新赋能体系
  • 一张白纸,无限画布:SkyReels刚刚重新定义了AI视频创作
  • Java_ArrayList底层结构和源码分析
  • 局域网创建网站怎么自建一个网站
  • 网站建设问题及解决办法北京网站建设方案品牌公司
  • 网站建设电销话术开场白搜索网排名
  • 中国建设银行官网站汽车卡一级做ae视频直播可以吗多少钱
  • 电子学会青少年机器人技术(三级)等级考试试卷-理论综合(2025年9月)
  • 长沙公司核名网站wordpress的图片存在哪里
  • 【IC】NoC设计入门 -- router模块
  • 网站做项目网络营销方案策划书
  • 外贸功能网站建设电脑课程培训零基础
  • 网站建设策划公司凡科建站怎样建站中站
  • 侯捷STL标准库和泛型编程
  • BigDecimal是怎么比较大小的
  • 【MCU控制 初级手札】1.6 电解质 【化学基础】
  • Paimon 文件索引深度解析:以 Bitmap 索引为例
  • wap网站cms金乡网站建设多少钱
  • 能浏览的海外网站网页制作三剑客不包括
  • Python自己处理不了异步结束线程
  • 双指针。。。。。
  • 北京有多少家网站吉县网站建设
  • 教案怎么写模板抖音seo搜索引擎优化
  • 算法题(254):灾后重建
  • 理解全连接层:深度学习中的基础构建块
  • vs网站开发教程厦门模板建站
  • c sql网站开发wordpress搜索无效
  • 如何防止 iOS 应用资源文件被替换 工程化防护与多工具组合实战
  • 网站在线支付接口网络推广经验分享
  • 18-Python 操作 Redis 实战指南:redis-py 客户端全解析与场景落地