CWGAN-GP 原理及实现(pytorch版)
CWGAN-GP
- 一、CWGAN-GP 原理
- 1.1 CWGAN-GP 的核心改进
- 1.2 CWGAN-GP 的损失函数
- 1.3 CWGAN-GP 的优势
- 1.4 关键参数选择
- 1.5 应用场景
- 二、CWGAN-GP 实现
- 2.1 导包
- 2.2 数据加载和处理
- 2.3 构建生成器
- 2.4 构建判别器
- 2.5 训练和保存模型
- 2.6 查看训练损失
- 2.7 图片转GIF
- 2.8 模型加载和推理
一、CWGAN-GP 原理
CWGAN-GP
(Conditional Wasserstein GAN with Gradient Penalty)是 WGAN-GP 的条件生成版本,结合了 条件生成对抗网络
(CGAN) 和 Wasserstein GAN 的梯度惩罚(GP)
,以提高生成样本的质量和训练稳定性。
1.1 CWGAN-GP 的核心改进
(1)条件生成(Conditional Generation)
- CGAN 思想:生成器 G G G 和判别器 D D D 都接收额外的条件信息 y y y(如类别标签、文本描述等),使生成的数据 G ( z ∣ y ) G(z|y) G(z∣y) 与条件 y y y 相关。
- 数学表达:
- 生成器: G ( z ∣ y ) → x f a k e G(z|y) \rightarrow x_{fake} G(z∣y)→xfake
- 判别器: D ( x ∣ y ) → 评分 D(x|y) \rightarrow \text{评分} D(x∣y)→评分
(2)Wasserstein 距离(WGAN)
- WGAN 用
Earth-Mover(EM)距离
(即 Wasserstein-1 距离)替代原始 GAN 的 JS 散度,解决梯度消失/爆炸问题:
W ( P r , P g ) = sup ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E x ∼ P g [ D ( x ) ] W(P_r, P_g) = \underset{\|D\|_L \leq 1}{\sup} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{x \sim P_g}[D(x)] W(Pr,Pg)=∥D∥L≤1supEx∼Pr[D(x)]−Ex∼Pg[D(x)]- D D D 需要是 1-Lipschitz 函数(梯度范数 ≤ 1)
(3) 梯度惩罚(Gradient Penalty, GP)
- WGAN-GP 通过
梯度惩罚
强制判别器满足 Lipschitz 约束,替代 WGAN 的权重裁剪(更稳定):
L G P = λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ∣ y ) ∥ 2 − 1 ) 2 ] \mathcal{L}_{GP} = \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2 \right] LGP=λ⋅Ex^∼Px^[(∥∇x^D(x^∣y)∥2−1)2]-
x
^
\hat{x}
x^ 是真实数据和生成数据的随机插值点:
x ^ = ϵ x r e a l + ( 1 − ϵ ) x f a k e , ϵ ∼ U [ 0 , 1 ] \hat{x} = \epsilon x_{real} + (1 - \epsilon) x_{fake}, \quad \epsilon \sim U[0,1] x^=ϵxreal+(1−ϵ)xfake,ϵ∼U[0,1] - λ \lambda λ 是惩罚系数(通常设为 10)
-
x
^
\hat{x}
x^ 是真实数据和生成数据的随机插值点:
1.2 CWGAN-GP 的损失函数
(1)判别器损失
L
D
=
E
x
∼
P
r
[
D
(
x
∣
y
)
]
−
E
z
∼
p
(
z
)
[
D
(
G
(
z
∣
y
)
∣
y
)
]
⏟
Wasserstein 距离
+
λ
⋅
E
x
^
∼
P
x
^
[
(
∥
∇
x
^
D
(
x
^
∣
y
)
∥
2
−
1
)
2
]
⏟
梯度惩罚
\mathcal{L}_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x|y)] - \mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2]}_{\text{梯度惩罚}}
LD=Wasserstein 距离
Ex∼Pr[D(x∣y)]−Ez∼p(z)[D(G(z∣y)∣y)]+梯度惩罚
λ⋅Ex^∼Px^[(∥∇x^D(x^∣y)∥2−1)2]
(2) 生成器损失
L
G
=
−
E
z
∼
p
(
z
)
[
D
(
G
(
z
∣
y
)
∣
y
)
]
\mathcal{L}_G = -\mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)]
LG=−Ez∼p(z)[D(G(z∣y)∣y)]
- 生成器的目标是
让判别器对生成样本的评分最大化
(即最小化 − L G -\mathcal{L}_G −LG)
1.3 CWGAN-GP 的优势
改进点 | WGAN | WGAN-GP | CWGAN-GP |
---|---|---|---|
距离度量 | Wasserstein | Wasserstein + GP | Wasserstein + GP |
条件生成 | ❌ | ❌ | ✅ |
训练稳定性 | 权重裁剪(不稳定) | 梯度惩罚(稳定) | 梯度惩罚 + 条件控制 |
模式崩溃 | 较少 | 更少 | 最少(因条件约束) |
1.4 关键参数选择
- 梯度惩罚系数 λ \lambda λ:通常设为 10
- 判别器训练次数 n c r i t i c n_{critic} ncritic:通常 3~5 次/生成器 1 次
- 学习率:建议 1 0 − 4 10^{-4} 10−4(Adam 优化器, β 1 = 0.5 , β 2 = 0.9 \beta_1=0.5, \beta_2=0.9 β1=0.5,β2=0.9)
1.5 应用场景
- 图像生成(如条件生成 MNIST 数字)
- 文本到图像合成(如 StackGAN)
- 数据增强(生成特定类别的数据)
CWGAN-GP 通过 条件输入 + 梯度惩罚
,显著提升了生成质量和训练稳定性,是许多现代 GAN 变种(如 ProGAN、StyleGAN)的基础架构
二、CWGAN-GP 实现
2.1 导包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchsummary import summary
# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 设置日志
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
log_dir = os.path.join("./logs/cwgan-gp", time_str)
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
os.makedirs("./img/cwgan-gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录
# 超参数配置
config = {
"num_epochs": 100,
"batch_size": 64,
"lr": 1e-4,
"img_channels": 1,
"img_size": 32,
"features_g": 64,
"features_d": 64,
"z_dim": 100,
"num_classes": 10, # 分类数
"label_embed_dim": 10, # 嵌入维度
"lambda_gp": 10, # 梯度惩罚系数
"n_critic": 5, # 判别器更新次数/生成器更新1次
}
2.2 数据加载和处理
# 加载 MNIST 数据集
def load_data(config):
transform = transforms.Compose([
transforms.Resize(config["img_size"]),
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到[-1,1]
])
# 下载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=config["batch_size"],shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=config["batch_size"], shuffle=False)
return train_loader, test_loader
2.3 构建生成器
class Generator(nn.Module):
"""生成器"""
def __init__(self, config):
super(Generator,self).__init__()
# 定义嵌入层 [B]-> [B,label_embed_dim]=[64,10]
self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入
# 定义模型结构
self.model = nn.Sequential(
# 转置卷积-1: [B,z_dim + label_embed_dim,1,1]-> [B,features_g*8,4,4]
nn.ConvTranspose2d(config["z_dim"] + config["label_embed_dim"],config["features_g"]*8, 4, 1, 0),
nn.BatchNorm2d(config["features_g"]*8),
nn.LeakyReLU(negative_slope=0.0001, inplace=True),
# 转置卷积-2: [B,features_g*8,4,4]-> [B,features_g*4,8,8]
nn.ConvTranspose2d(config["features_g"]*8, config["features_g"]*4, 4, 2, 1),
nn.BatchNorm2d(config["features_g"]*4),
nn.LeakyReLU(negative_slope=0.0001, inplace=True),
# 转置卷积-3: [B,features_g*4,8,8]-> [B,features_g*2,16,16]
nn.ConvTranspose2d(config["features_g"]*4, config["features_g"]*2, 4, 2, 1),
nn.BatchNorm2d(config["features_g"]*2),
nn.LeakyReLU(negative_slope=0.0001, inplace=True),
# 转置卷积-4: [B,features_g*2,16,16]-> [B,img_channels,32,32]
nn.ConvTranspose2d(config["features_g"]*2, config["img_channels"], 4, 2, 1),
nn.Tanh() # 输出归一化到[-1,1]
)
def forward(self, z, labels): # z[B,z_dim,1,1]
# 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]
label_embed = self.label_embed(labels)
# 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]
label_embed = label_embed.unsqueeze(2).unsqueeze(3)
# 拼接噪声和嵌入标签 ->[B,z_dim + label_embed_dim ,1,1]=[B,100+10,1,1]
gen_input= torch.cat([z,label_embed], dim=1)
# 生成图片 [B,label_embed_dim + z_dim,1,1]-> [B,img_channels,32,32]
img=self.model(gen_input)
return img # [B,1,32,32]
2.4 构建判别器
class Discriminator(nn.Module):
def __init__(self, img_shape=(1,32,32),num_classes=10,label_embed_dim=10,features_d=64):
"判别器"
super(Discriminator, self).__init__()
# 定义嵌入层 [B]-> [B,label_embed_dim]=[B,10]
self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入
# 定义模型结构
self.model = nn.Sequential(
# 卷积-1:[B,img_channels + label_embed_dim,32,32]-> [B,features_d,16,16]
nn.Conv2d(config["img_channels"] + config["label_embed_dim"], config["features_d"], 4,2,1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
# 卷积-2:[B,features_d,16,16]-> [B,features_d*2,8,8]
nn.Conv2d( config["features_d"], config["features_d"]* 2, 4, 2, 1),
nn.InstanceNorm2d( config["features_d"]* 2),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
# 卷积-3:[B,features_d*2,8,8]-> [B,features_d*4,4,4]
nn.Conv2d( config["features_d"]*2, config["features_d"]*4, 4, 2, 1),
nn.InstanceNorm2d( config["features_d"]*4),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
# 卷积-4:[B,features_d*4,4,4]-> [B,1,1,1]
nn.Conv2d( config["features_d"]*4, 1, 4, 1, 0),
nn.Flatten(), # 展平,4维[B,1,1,1]-> 2维[B,1,1,1]
)
def forward(self, img, labels):
# 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]
label_embed = self.label_embed(labels)
# 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]
label_embed = label_embed.unsqueeze(2).unsqueeze(3)
# 沿着空间维度(高度和宽度)进行复制扩展,使其与目标图像的空间尺寸匹配
label_embed = label_embed.repeat(1, 1, img.shape[2], img.shape[3])
# 拼接图片和嵌入标签 ->[B,img.shape[0]+label_embed_dim ,img.shape[2],img.shape[3]]=[B,1+10,32,32]
dis_input= torch.cat([img,label_embed], dim=1)
# 进行判定
validity = self.model(dis_input)
return validity # [B,1]
2.5 训练和保存模型
- 定义梯度惩罚函数
# 梯度惩罚计算函数
def compute_gradient_penalty(disc, real_samples, fake_samples, labels,device=device):
"""计算梯度惩罚项"""
# 随机插值系数
alpha = torch.rand(real_samples.shape[0], 1, 1, 1).to(device)
# 生成插值样本
interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
# 计算判别器对插值样本的输出
d_interpolates = disc(interpolates, labels)
# 计算梯度
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(d_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
# 计算梯度惩罚项
gradients = gradients.view(gradients.shape[0], -1) # 多维-> 2维 [B,*]
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
- 定义保存生成样本
def sample_image(G,n_row, batches_done,latent_dim=100,device=device):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# 随机噪声-> [n_row ** 2,latent_dim]=[100,100]
z=torch.normal(0,1,size=(n_row ** 2,latent_dim,1,1),device=device) #从正态分布中抽样
# 条件标签->[100]
labels = torch.arange(n_row, dtype=torch.long, device=device).repeat_interleave(n_row)
gen_imgs = G(z, labels)
save_image(gen_imgs.data, "./img/cwgan-gp_mnist/%d.png" % batches_done, nrow=n_row, normalize=True)
- 训练和保存
# 加载数据
train_loader,_= load_data(config)
# 实例化生成器G、判别器D
G=Generator(config).to(device)
D=Discriminator(config).to(device)
# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=config["lr"],betas=(0.5, 0.9))
optimizer_D = torch.optim.Adam(D.parameters(), lr=config["lr"],betas=(0.5, 0.9))
# 开始训练
start_time = time.time() # 计时器
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(config["num_epochs"]):
# 进入训练模式
G.train()
D.train()
loop = tqdm(train_loader, desc=f"第{epoch+1}轮")
for i, (real_imgs, labels) in enumerate(loop):
real_imgs=real_imgs.to(device) # [B,C,H,W]
labels=labels.to(device) # [B]
# -----------------
# 训练判别器
# 1. 必须 gen_imgs.detach(),防止生成器的梯度干扰判别器更新
# 2. 损失函数为 D(real_img) - D(gen_img.detach()) + 惩罚项
# -----------------
for _ in range(config["n_critic"]):
# 获取噪声样本[B,latent_dim,1,1]及对应的条件标签 [B]
z=torch.normal(0,1,size=(real_imgs.shape[0],config["z_dim"],1,1),device=device) #从正态分布中抽样
# --- 计算判别器损失 ---
# Step-1:对真实图片损失
valid_loss=torch.mean(D(real_imgs,labels))
# Step-2:对生成图片损失
gen_imgs=G(z,labels) # 生成图片
fake_loss=torch.mean(D(gen_imgs.detach(),labels))
# Step-3:计算梯度惩罚
gradient_penalty = compute_gradient_penalty(D, real_imgs.data, gen_imgs.data, labels)
# Step-4:整体损失
dis_loss= -(valid_loss - fake_loss) + config["lambda_gp"]* gradient_penalty
# 更新判别器参数
optimizer_D.zero_grad() #梯度清零
dis_loss.backward() #反向传播,计算梯度
optimizer_D.step() #更新判别器
# -----------------
# 训练生成器
# 1.禁止gen_imgs.detach(),需保持完整的计算图
# 2.损失函数应为 -D(gen_img)(WGAN 的目标是最大化判别器对生成样本的评分)
# -----------------
# 计算生成器损失
gen_loss = -torch.mean(D(gen_imgs, labels))
# 更新生成器参数
optimizer_G.zero_grad() #梯度清零
gen_loss.backward() #反向传播,计算梯度
optimizer_G.step() #更新生成器
# --- 记录损失到TensorBoard ---
batches_done = epoch * loader_len + i
if batches_done % 100 == 0:
writer.add_scalars(
"CWGAN-GP",
{
"Generator": gen_loss.item(),
"Discriminator": dis_loss.item(),
},
global_step=epoch * loader_len + i, # 全局步数
)
# 更新进度条
loop.set_postfix(gen_loss=f"{gen_loss:.8f}",dis_loss=f"{dis_loss:.8f}")
if batches_done % 400 == 0:
sample_image(G,10, batches_done)
writer.close()
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))
#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/CWGAN-GP_G.pth")
torch.save(D.state_dict(), "./model/CWGAN-GP_D.pth")
2.6 查看训练损失
# 加载魔术命令
%reload_ext tensorboard
# 启动TensorBoard(
%tensorboard --logdir ./logs/cwgan-gp --port=6007
2.7 图片转GIF
from PIL import Image
def create_gif(img_dir="./img/cwgan-gp_mnist", output_file="./img/cwgan-gp_mnist/cwgan-gp_figure.gif", duration=100):
images = []
img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]
# 自定义排序:按 "x.png" 的排序
img_paths_sorted = sorted(
img_paths,
key=lambda x: (
int(x.split('.')[0]),
)
)
for img_file in img_paths_sorted:
img = Image.open(os.path.join(img_dir, img_file))
images.append(img)
images[0].save(output_file, save_all=True, append_images=images[1:],
duration=duration, loop=0)
print(f"GIF已保存至 {output_file}")
create_gif()

2.8 模型加载和推理
#载入训练好的模型
G = Generator(config) # 定义模型结构
G.load_state_dict(torch.load("./model/CWGAN-GP_G.pth",weights_only=True,map_location=device)) # 加载保存的参数
G.to(device)
G.eval()
# 获取噪声样本[10,100]及对应的条件标签 [10]
z=torch.normal(0,1,size=(10,100,1,1),device=device) #从正态分布中抽样
gen_labels = torch.arange(10, dtype=torch.long, device=device) #0~9整数
#生成假样本
gen_imgs=G(z,gen_labels).view(-1,32,32) # 4维->3维
gen_imgs=gen_imgs.detach().cpu().numpy()
#绘制
plt.figure(figsize=(8, 5))
plt.rcParams.update({
'font.family': 'serif',
})
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.xticks([], [])
plt.yticks([], [])
plt.imshow(gen_imgs[i], cmap='gray')
plt.title(f"Figure {i}", fontsize=12)
plt.tight_layout()
plt.show()
