当前位置: 首页 > news >正文

PyTorch生成式人工智能——VQ-VAE详解与实现

PyTorch生成式人工智能——VQ-VAE详解与实现

    • 0. 前言
    • 1. VQ-VAE 技术原理
      • 1.1 引入离散潜变量
      • 1.2 向量量化
      • 1.3 损失函数
      • 1.4 指数滑动平均
      • 1.5 梯度直通 (Straight-Through)
    • 2. VQ-VAE 网络架构
    • 3. 实现 VQ-VAE
      • 3.1 模型构建
      • 3.2 模型训练
    • 相关链接

0. 前言

在传统的变分自编码器 (Variational Auto-Encoder, VAE) 中,模型学习的是一个连续的潜在表示 (latent representation)。然而,对于许多模态的数据,如语言、语音或某些图像特征,离散的潜表示往往更加自然有效。
VQ-VAE (Vector Quantised-Variational AutoEncoder) 的核心思想就是将 VAE 的连续潜变量离散化。它通过学习一个码本 (Codebook) 来实现这一点,码本是一个包含有限个嵌入向量的字典。模型不是直接输出一个连续的潜在向量,而是从码本中找出与编码器输出最接近的嵌入向量来代替它。这种离散化带来了以下优势:

  • 兼容性:离散的潜在空间可以很自然地与自回归模型(如 PixelCNN、Transformer )结合,用于强大的先验建模,从而生成高质量的新样本
  • 计算效率:对于下游任务,处理离散的 token 通常比处理连续的向量更高效
  • 可解释性:码本中的每个向量可以看作是学习到的一种“视觉单词”或基本概念

本节首先详细讲解 VQ-VAE 的技术原理,然后使用 PyTorch 从零开始实现 VQ-VAE 模型。

1. VQ-VAE 技术原理

1.1 引入离散潜变量

传统的变分自编码器 (Variational Auto-Encoder, VAE) 的潜变量 zzz 连续且服从高斯先验,重建质量与生成保真度在某些任务上(如语音、图像纹理)不够理想。VQ-VAE 通过码本 (Codebook) 引入离散潜变量:编码器将输入映射到隐空间连续向量 ze(x)z_e(x)ze(x),随后用最近邻查找从码本 {ek}k=1K\{e_k\}_{k=1}^K{ek}k=1K 选出索引 kkk,离散化为 zq(x)=ekz_q(x)=e^kzq(x)=ek,再由解码器重建。优势在于:

  • 信息瓶颈自然离散化:离散索引更像“符号化”的语义单元
  • 便于后续建模:可以对索引序列用自回归(如 PixelCNN、Transformer )建立先验模型

1.2 向量量化

对每个位置的连续表示 zez_eze,选择使欧氏距离最小的码本向量:
k=arg⁡mink∣∣ze−ek∣∣2,zq=ekk=\underset k{arg⁡min}||z_e−e_k||_2,z_q=e_k k=kargmin∣∣zeek2,zq=ek

1.3 损失函数

VQ-VAE 的核心损失包含三部分:
L=∣∣x−x^∣∣1⏟reconstruction+∣∣sg[ze]−e∣∣22⏟codebook+β∣∣ze−sg[e]∣∣22⏟commitment\mathcal L=\underbrace {||x−\hat x||_1}_{reconstruction} +  \underbrace {||sg[z_e]−e||_2^2}_{codebook}  +  \underbrace {β||z_e−sg[e]||_2^2}_{commitment} L=reconstruction∣∣xx^1+  codebook∣∣sg[ze]e22  +  commitmentβ∣∣zesg[e]22
其中:

  • sg[⋅]sg[\cdot]sg[] (stop gradient) 阻止梯度回传,即在反向传播时将其视为常数
  • 第一部分为重构损失 (Reconstruction Loss),用于最小化输入与重构的差异
  • 第二部分为码本损失 (Codebook Loss),让码本向量 eee 向编码器输出 zez_eze 靠近,通常使用 L2 损失,stop gradient 操作作用在编码器上,因此这个损失只更新码本,不更新编码器
  • 第三部分为 Commitment Loss,让编码器的输出 zez_eze 向选中的码本向量 eee 靠近,防止编码器的输出在码本空间内随意波动,通常使用 L2 损失,stop gradient 作用在码本向量上,因此这个损失只更新编码器,不更新码本
  • βββ 是一个超参数,通常取 0.25–0.5

1.4 指数滑动平均

指数滑动平均 (Exponential Moving Average, EMA) 使用聚类视角更新码本参数,可以避免显式的码本惩罚 (codebook-penalty):
Nk←γNk+(1−γ)⋅countkmk←γmk+(1−γ)⋅∑i∈kze,iek←mkNk+ϵN_k\leftarrow\gamma N_k+(1-\gamma)\cdot count_k\\ m_k\leftarrow\gamma m_k+(1-\gamma)\cdot \sum_{i\in k}z_{e,i}\\ e_k\leftarrow \frac{m_k}{N_k+\epsilon} NkγNk+(1γ)countkmkγmk+(1γ)ikze,iekNk+ϵmk
只保留 Reconstruction LossCommitment Loss,收敛更稳,码本利用率更好。

1.5 梯度直通 (Straight-Through)

量化过程 (argmin) 是不可导的,这阻碍了梯度从解码器传回编码器。VQ-VAE采用了一个巧妙的技巧:在反向传播时,直接将解码器关于 z_q 的梯度 ∂L/∂zq∂\mathcal L/∂z_qL/zq 复制给编码器的输出 zez_eze。即:
∂L/∂ze=∂L/∂zq∂\mathcal L/∂z_e = ∂\mathcal L/∂z_q L/ze=L/zq
这样,虽然量化操作本身没有梯度,但编码器仍然可以接收到来自解码器的梯度信号并进行更新。

2. VQ-VAE 网络架构

VQ-VAE 的整体结构如下图所示,其核心是一个由编码器、码本、解码器组成的架构。
网络架构

网络训练流程如下:

  • 编码器生成 zez_eze
  • 量化层选择最近码本向量 zq=ekz_q = e_kzq=ek
  • 解码器重构 x^\hat{x}x^
  • 计算总损失并反向传播,更新编码器、解码器和码本。

3. 实现 VQ-VAE

了解了 VQ-VAE 的核心原理和训练流程后,接下来,使用 PyTorch 从零开始实现 VQ-VAE 模型。

3.1 模型构建

(1) 首先,导入所需库,并定义命令行参数解析函数:

# vqvae.py
import os, math, random, argparse, time
from pathlib import Pathimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torch import ampdef save_image_grid(tensors, filename, nrow=8):grid = utils.make_grid(torch.clamp(tensors, -0.5, 0.5) + 0.5, nrow=nrow)os.makedirs(os.path.dirname(filename), exist_ok=True)utils.save_image(grid, filename)def parse_args():p = argparse.ArgumentParser()p.add_argument('--data', type=str, default='./data')p.add_argument('--epochs', type=int, default=30)p.add_argument('--batch_size', type=int, default=128)p.add_argument('--lr', type=float, default=2e-4)p.add_argument('--commit', type=float, default=0.25, help='commitment beta')p.add_argument('--ema', action='store_true', help='use EMA codebook')p.add_argument('--codebook_size', type=int, default=512)p.add_argument('--embed_dim', type=int, default=64, help='latent channel dim (before quantize)')p.add_argument('--levels', type=int, default=1, help='number of VQ levels (this demo uses 1)')p.add_argument('--ckpt', type=str, default='./checkpoints/vqvae_best.pt')p.add_argument('--eval_only', action='store_true')p.add_argument('--amp', action='store_true', help='enable AMP mixed precision')return p.parse_args()

(2) 定义残差块、编码器与解码器:

class ResBlock(nn.Module):def __init__(self, c):super().__init__()self.net = nn.Sequential(nn.ReLU(),nn.Conv2d(c, c, 3, padding=1),nn.ReLU(),nn.Conv2d(c, c, 1))def forward(self, x):return x + self.net(x)class Encoder(nn.Module):def __init__(self, in_channels=3, hidden=128, embed_dim=64):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channels, hidden//2, 4, stride=2, padding=1), # 32->16nn.ReLU(),nn.Conv2d(hidden//2, hidden, 4, stride=2, padding=1),      # 16->8nn.ReLU(),ResBlock(hidden),ResBlock(hidden),nn.ReLU(),nn.Conv2d(hidden, embed_dim, 1)  # project to z_e (C=embed_dim))def forward(self, x):return self.net(x)class Decoder(nn.Module):def __init__(self, out_channels=3, hidden=128, embed_dim=64):super().__init__()self.net = nn.Sequential(nn.Conv2d(embed_dim, hidden, 3, padding=1),ResBlock(hidden),ResBlock(hidden),nn.ReLU(),nn.ConvTranspose2d(hidden, hidden//2, 4, stride=2, padding=1), # 8->16nn.ReLU(),nn.ConvTranspose2d(hidden//2, out_channels, 4, stride=2, padding=1), # 16->32nn.Tanh()  # output in [-1,1])

(3) 实现标准向量量化器,将编码器输出的连续向量映射到码本中最近的离散嵌入向量:

class VectorQuantizer(nn.Module):def __init__(self, codebook_size=512, embed_dim=64, beta=0.25):super().__init__()self.codebook_size = codebook_sizeself.embed_dim = embed_dimself.beta = betaself.embedding = nn.Embedding(codebook_size, embed_dim)nn.init.uniform_(self.embedding.weight, -1.0 / codebook_size, 1.0 / codebook_size)@torch.no_grad()def _nearest_indices(self, z_e):# z_e: (B,C,H,W) -> (BHW,C)z = z_e.permute(0,2,3,1).contiguous().view(-1, self.embed_dim)# distances: |z|^2 + |e|^2 - 2 z e^Te = self.embedding.weight  # (K,C)z2 = (z ** 2).sum(dim=1, keepdim=True)  # (N,1)e2 = (e ** 2).sum(dim=1)                # (K,)# (N,K)distances = z2 + e2.unsqueeze(0) - 2.0 * z @ e.t()indices = distances.argmin(dim=1)       # (N,)return indicesdef forward(self, z_e):B, C, H, W = z_e.shapewith torch.no_grad():indices = self._nearest_indices(z_e)   # (BHW,)# straight-through estimatorz_q = self.embedding(indices).view(B, H, W, C).permute(0,3,1,2).contiguous()# codebook + commitment losses# codebook: ||sg[z_e] - e||^2 -> pull e toward z_eloss_cb = F.mse_loss(z_q.detach(), z_e)# commitment: beta * ||z_e - sg[e]||^2 -> pull z_e toward eloss_commit = F.mse_loss(z_e, z_q.detach())loss_vq = loss_cb + self.beta * loss_commit# straight-through: copy gradientsz_q = z_e + (z_q - z_e).detach()# perplexitywith torch.no_grad():one_hot = F.one_hot(indices, num_classes=self.codebook_size).float()probs = one_hot.mean(dim=0)perplexity = torch.exp(-(probs * (probs + 1e-10).log()).sum())return z_q, loss_vq, perplexity, indices.view(B, H, W)

(4) 定义数滑动平均 (Exponential Moving Average, EMA) 向量量化器,用 EMA 聚类更新码本:

class VectorQuantizerEMA(nn.Module):def __init__(self, codebook_size=512, embed_dim=64, beta=0.25, decay=0.99, eps=1e-5):super().__init__()self.codebook_size = codebook_sizeself.embed_dim = embed_dimself.beta = betaself.decay = decayself.eps = epsembed = torch.randn(codebook_size, embed_dim) * 0.1self.register_buffer('embedding', embed)self.register_buffer('ema_cluster_size', torch.zeros(codebook_size))self.register_buffer('ema_embed', embed.clone())@torch.no_grad()def _nearest_indices(self, z_e):z = z_e.permute(0,2,3,1).contiguous().view(-1, self.embed_dim)  # (N,C)e = self.embedding  # (K,C)z2 = (z ** 2).sum(dim=1, keepdim=True)e2 = (e ** 2).sum(dim=1)distances = z2 + e2.unsqueeze(0) - 2.0 * z @ e.t()return distances.argmin(dim=1)def forward(self, z_e):B, C, H, W = z_e.shapewith torch.no_grad():indices = self._nearest_indices(z_e)  # (BHW,)one_hot = F.one_hot(indices, num_classes=self.codebook_size).float()  # (N,K)cluster_size = one_hot.sum(dim=0)  # (K,)# EMA updatesself.ema_cluster_size.mul_(self.decay).add_(cluster_size, alpha=1 - self.decay)z_sum = (z_e.permute(0,2,3,1).contiguous().view(-1, C).unsqueeze(2) * one_hot.unsqueeze(1)).sum(dim=0)  # (C,K)self.ema_embed.mul_(self.decay).add_(z_sum.t(), alpha=1 - self.decay)  # (K,C)n = self.ema_cluster_size.sum()cluster_size = (self.ema_cluster_size + self.eps) / (n + self.codebook_size * self.eps) * nembed_normalized = self.ema_embed / cluster_size.unsqueeze(1)self.embedding.copy_(embed_normalized)z_q = self.embedding[indices].view(B, H, W, C).permute(0,3,1,2).contiguous()# only commitment termloss_commit = F.mse_loss(z_e, z_q.detach())loss_vq = self.beta * loss_commit# straight-throughz_q = z_e + (z_q - z_e).detach()with torch.no_grad():probs = one_hot.mean(dim=0)perplexity = torch.exp(-(probs * (probs + 1e-10).log()).sum())return z_q, loss_vq, perplexity, indices.view(B, H, W)

(5) 将编码器、量化器与解码器组合为端到端模型:

class VQVAE(nn.Module):def __init__(self, in_channels=3, hidden=128, embed_dim=64, codebook_size=512, beta=0.25, use_ema=False):super().__init__()self.encoder = Encoder(in_channels, hidden, embed_dim)if use_ema:self.quantizer = VectorQuantizerEMA(codebook_size, embed_dim, beta=beta)else:self.quantizer = VectorQuantizer(codebook_size, embed_dim, beta=beta)self.decoder = Decoder(in_channels, hidden, embed_dim)def forward(self, x):z_e = self.encoder(x)                          # (B, C=embed_dim, H=8, W=8)z_q, loss_vq, perplexity, indices = self.quantizer(z_e)x_hat = self.decoder(z_q)return x_hat, loss_vq, perplexity, indices

3.2 模型训练

接下来,使用 CIFAR-10 数据集训练模型。
(1) 定义数据集加载函数:

def get_dataloaders(data_root, batch_size):tfm = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(0.5, 0.5)  # for all 3 channels: (x-0.5)/0.5 -> [-1,1]])tfm_val = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])train_set = datasets.CIFAR10(data_root, train=True, download=True, transform=tfm)val_set = datasets.CIFAR10(data_root, train=False, download=True, transform=tfm_val)train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)return train_loader, val_loader

(2) 定义训练与验证循环,使用 L1 重建损失,图像更锐利,我们也可以使用 L2 损失观察图像重建效果:

def train_one_epoch(model, loader, opt, scaler, device, use_amp=True):model.train()rec_loss_meter, vq_loss_meter, ppl_meter = 0.0, 0.0, 0.0n = 0for x, _ in loader:x = x.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)with amp.autocast("cuda", enabled=use_amp):x_hat, vq_loss, perplexity, _ = model(x)rec_loss = F.l1_loss(x_hat, x)loss = rec_loss + vq_lossscaler.scale(loss).backward()scaler.step(opt)scaler.update()bs = x.size(0)rec_loss_meter += rec_loss.item() * bsvq_loss_meter += vq_loss.item() * bsppl_meter += perplexity.item() * bsn += bsreturn rec_loss_meter/n, vq_loss_meter/n, ppl_meter/n@torch.no_grad()
def evaluate(model, loader, device, save_samples_path=None, max_batches=1):model.eval()rec_loss_meter, vq_loss_meter, ppl_meter = 0.0, 0.0, 0.0n = 0saved = Falsefor i, (x, _) in enumerate(loader):x = x.to(device, non_blocking=True)x_hat, vq_loss, perplexity, _ = model(x)rec_loss = F.l1_loss(x_hat, x)bs = x.size(0)rec_loss_meter += rec_loss.item() * bsvq_loss_meter += vq_loss.item() * bsppl_meter += perplexity.item() * bsn += bsif (save_samples_path is not None) and (not saved):# 拼接输入/重建cat = torch.cat([x[:32], x_hat[:32]], dim=0).detach().cpu()save_image_grid(cat, save_samples_path, nrow=8)saved = Trueif i+1 >= max_batches:breakreturn rec_loss_meter/n, vq_loss_meter/n, ppl_meter/n

(3) 组织训练/评估流程、保存最优权重与重建结果:

def main():args = parse_args()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f'Using device: {device}')train_loader, val_loader = get_dataloaders(args.data, args.batch_size)model = VQVAE(in_channels=3,hidden=128,embed_dim=args.embed_dim,codebook_size=args.codebook_size,beta=args.commit,use_ema=args.ema).to(device)opt = torch.optim.Adam(model.parameters(), lr=args.lr)scaler = amp.GradScaler("cuda", enabled=args.amp)best_val = float('inf')ckpt_dir = Path(args.ckpt).parentckpt_dir.mkdir(parents=True, exist_ok=True)if args.eval_only and os.path.isfile(args.ckpt):print(f'Loading checkpoint: {args.ckpt}')state = torch.load(args.ckpt, map_location=device)model.load_state_dict(state['model'])rec, vq, ppl = evaluate(model, val_loader, device, save_samples_path='./samples/recon_eval.png')print(f'[Eval] rec={rec:.4f} vq={vq:.4f} ppl={ppl:.2f}')returnfor epoch in range(1, args.epochs+1):t0 = time.time()tr_rec, tr_vq, tr_ppl = train_one_epoch(model, train_loader, opt, scaler, device, use_amp=args.amp)val_rec, val_vq, val_ppl = evaluate(model, val_loader, device, save_samples_path=f'./samples/recon_epoch_{epoch:03d}.png')elapsed = time.time() - t0val_total = val_rec + val_vqprint(f'Epoch {epoch:03d} | {elapsed:.1f}s | 'f'train rec={tr_rec:.4f} vq={tr_vq:.4f} ppl={tr_ppl:.2f} || 'f'val rec={val_rec:.4f} vq={val_vq:.4f} ppl={val_ppl:.2f}')if val_total < best_val:best_val = val_totaltorch.save({'model': model.state_dict(),'args': vars(args)}, args.ckpt)print(f'  -> Saved best to {args.ckpt}')if __name__ == '__main__':main()

(4) 在命令行中使用以下命令运行模型训练过程:

# 训练(非 EMA)
python vqvae.py --epochs 20 --batch_size 128 --commit 0.25# 训练(EMA 码本)
python vqvae.py --epochs 20 --batch_size 128 --commit 0.25 --ema# 从已训练权重做重建可视化
python vqvae.py --eval_only --ckpt ./checkpoints/vqvae_best.pt

模型训练过程保存的重建图像如下所示,可以看到随着训练的进行,模型的重建效果逐步得到提升:

模型重建效果

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
PyTorch生成式人工智能(25)——基于Transformer实现机器翻译
PyTorch生成式人工智能(26)——使用PyTorch构建GPT模型
PyTorch生成式人工智能(27)——从零开始训练GPT模型
PyTorch生成式人工智能(28)——MuseGAN详解与实现

http://www.dtcms.com/a/346685.html

相关文章:

  • chapter06_应用上下文与门面模式
  • pcie实现虚拟串口
  • k8s之 Pod 资源管理与 QoS
  • 深入理解 C++ SFINAE:从编译技巧到现代元编程的演进
  • rust语言 (1.88) egui (0.32.1) 学习笔记(逐行注释)(八)按键事件
  • vscode 中自己使用的 launch.json 设置
  • SpringBoot中实现接口查询数据动态脱敏
  • 倍福下的EC-A10020-P2-24电机调试说明
  • NVIDIA Nsight Systems性能分析工具
  • ISO 22341 及ISO 22341-2:2025安全与韧性——防护安全——通过环境设计预防犯罪(CPTED)
  • 武大智能与集成导航小组!i2Nav-Robot:用于的室内外机器人导航与建图的大规模多传感器融合数据集
  • 【字母异位分组】
  • 火车头使用Post方法采集Ajax页面教程
  • 量子计算驱动的Python医疗诊断编程前沿展望(中)
  • kubernetes-dashboard使用http不登录
  • 快速了解命令行界面(CLI)的行编辑模式
  • PyTorch框架之图像识别模型与训练策略
  • 一键部署开源 Coze Studio
  • 蓝牙链路层状态机精解:从待机到连接的状态跃迁与功耗控制
  • 全面解析了Java微服务架构的设计模式
  • 新疆地州市1米分辨率土地覆盖图
  • GOLANG 接口
  • 可自定义的BMS管理系统
  • 论文阅读:Inner Monologue: Embodied Reasoning through Planning with Language Models
  • SpringBoot 自动配置深度解析:从注解原理到自定义启动器​
  • 【JVM】JVM的内存结构是怎样的?
  • 调味品生产过程优化中Ethernet/IP转ProfiNet协议下施耐德 PLC 与欧姆龙 PLC 的关键通信协同案例
  • 字符串的大小写字母转换
  • linux中文本文件操作之grep命令
  • Linux-常用文件IO函数