永康建设投标网站宁德网站建设制作
CWGAN-GP
- 一、CWGAN-GP 原理
- 1.1 CWGAN-GP 的核心改进
- 1.2 CWGAN-GP 的损失函数
- 1.3 CWGAN-GP 的优势
- 1.4 关键参数选择
- 1.5 应用场景
- 二、CWGAN-GP 实现
- 2.1 导包
- 2.2 数据加载和处理
- 2.3 构建生成器
- 2.4 构建判别器
- 2.5 训练和保存模型
- 2.6 查看训练损失
- 2.7 图片转GIF
- 2.8 模型加载和推理
一、CWGAN-GP 原理
CWGAN-GP
(Conditional Wasserstein GAN with Gradient Penalty)是 WGAN-GP 的条件生成版本,结合了 条件生成对抗网络
(CGAN) 和 Wasserstein GAN 的梯度惩罚(GP)
,以提高生成样本的质量和训练稳定性。
1.1 CWGAN-GP 的核心改进
(1)条件生成(Conditional Generation)
- CGAN 思想:生成器 G G G 和判别器 D D D 都接收额外的条件信息 y y y(如类别标签、文本描述等),使生成的数据 G ( z ∣ y ) G(z|y) G(z∣y) 与条件 y y y 相关。
- 数学表达:
- 生成器: G ( z ∣ y ) → x f a k e G(z|y) \rightarrow x_{fake} G(z∣y)→xfake
- 判别器: D ( x ∣ y ) → 评分 D(x|y) \rightarrow \text{评分} D(x∣y)→评分
(2)Wasserstein 距离(WGAN)
- WGAN 用
Earth-Mover(EM)距离
(即 Wasserstein-1 距离)替代原始 GAN 的 JS 散度,解决梯度消失/爆炸问题:
W ( P r , P g ) = sup ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E x ∼ P g [ D ( x ) ] W(P_r, P_g) = \underset{\|D\|_L \leq 1}{\sup} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{x \sim P_g}[D(x)] W(Pr,Pg)=∥D∥L≤1supEx∼Pr[D(x)]−Ex∼Pg[D(x)]- D D D 需要是 1-Lipschitz 函数(梯度范数 ≤ 1)
(3) 梯度惩罚(Gradient Penalty, GP)
- WGAN-GP 通过
梯度惩罚
强制判别器满足 Lipschitz 约束,替代 WGAN 的权重裁剪(更稳定):
L G P = λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ∣ y ) ∥ 2 − 1 ) 2 ] \mathcal{L}_{GP} = \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2 \right] LGP=λ⋅Ex^∼Px^[(∥∇x^D(x^∣y)∥2−1)2]- x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点:
x ^ = ϵ x r e a l + ( 1 − ϵ ) x f a k e , ϵ ∼ U [ 0 , 1 ] \hat{x} = \epsilon x_{real} + (1 - \epsilon) x_{fake}, \quad \epsilon \sim U[0,1] x^=ϵxreal+(1−ϵ)xfake,ϵ∼U[0,1] - λ \lambda λ 是惩罚系数(通常设为 10)
- x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点:
1.2 CWGAN-GP 的损失函数
(1)判别器损失
L D = E x ∼ P r [ D ( x ∣ y ) ] − E z ∼ p ( z ) [ D ( G ( z ∣ y ) ∣ y ) ] ⏟ Wasserstein 距离 + λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ∣ y ) ∥ 2 − 1 ) 2 ] ⏟ 梯度惩罚 \mathcal{L}_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x|y)] - \mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2]}_{\text{梯度惩罚}} LD=Wasserstein 距离 Ex∼Pr[D(x∣y)]−Ez∼p(z)[D(G(z∣y)∣y)]+梯度惩罚 λ⋅Ex^∼Px^[(∥∇x^D(x^∣y)∥2−1)2]
(2) 生成器损失
L G = − E z ∼ p ( z ) [ D ( G ( z ∣ y ) ∣ y ) ] \mathcal{L}_G = -\mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)] LG=−Ez∼p(z)[D(G(z∣y)∣y)]
- 生成器的目标是
让判别器对生成样本的评分最大化
(即最小化 − L G -\mathcal{L}_G −LG)
1.3 CWGAN-GP 的优势
改进点 | WGAN | WGAN-GP | CWGAN-GP |
---|---|---|---|
距离度量 | Wasserstein | Wasserstein + GP | Wasserstein + GP |
条件生成 | ❌ | ❌ | ✅ |
训练稳定性 | 权重裁剪(不稳定) | 梯度惩罚(稳定) | 梯度惩罚 + 条件控制 |
模式崩溃 | 较少 | 更少 | 最少(因条件约束) |
1.4 关键参数选择
- 梯度惩罚系数 λ \lambda λ:通常设为 10
- 判别器训练次数 n c r i t i c n_{critic} ncritic:通常 3~5 次/生成器 1 次
- 学习率:建议 1 0 − 4 10^{-4} 10−4(Adam 优化器, β 1 = 0.5 , β 2 = 0.9 \beta_1=0.5, \beta_2=0.9 β1=0.5,β2=0.9)
1.5 应用场景
- 图像生成(如条件生成 MNIST 数字)
- 文本到图像合成(如 StackGAN)
- 数据增强(生成特定类别的数据)
CWGAN-GP 通过 条件输入 + 梯度惩罚
,显著提升了生成质量和训练稳定性,是许多现代 GAN 变种(如 ProGAN、StyleGAN)的基础架构
二、CWGAN-GP 实现
2.1 导包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as npimport os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchsummary import summary# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 设置日志
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
log_dir = os.path.join("./logs/cwgan-gp", time_str)
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir) os.makedirs("./img/cwgan-gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录# 超参数配置
config = {"num_epochs": 100,"batch_size": 64,"lr": 1e-4,"img_channels": 1,"img_size": 32,"features_g": 64,"features_d": 64,"z_dim": 100, "num_classes": 10, # 分类数"label_embed_dim": 10, # 嵌入维度"lambda_gp": 10, # 梯度惩罚系数"n_critic": 5, # 判别器更新次数/生成器更新1次
}
2.2 数据加载和处理
# 加载 MNIST 数据集
def load_data(config):transform = transforms.Compose([transforms.Resize(config["img_size"]),transforms.ToTensor(), # 将图像转换为张量transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到[-1,1]])# 下载训练集和测试集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建 DataLoadertrain_loader = DataLoader(dataset=train_dataset, batch_size=config["batch_size"],shuffle=True)test_loader = DataLoader(dataset=test_dataset, batch_size=config["batch_size"], shuffle=False)return train_loader, test_loader
2.3 构建生成器
class Generator(nn.Module):"""生成器"""def __init__(self, config):super(Generator,self).__init__()# 定义嵌入层 [B]-> [B,label_embed_dim]=[64,10]self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入# 定义模型结构self.model = nn.Sequential(# 转置卷积-1: [B,z_dim + label_embed_dim,1,1]-> [B,features_g*8,4,4]nn.ConvTranspose2d(config["z_dim"] + config["label_embed_dim"],config["features_g"]*8, 4, 1, 0),nn.BatchNorm2d(config["features_g"]*8),nn.LeakyReLU(negative_slope=0.0001, inplace=True),# 转置卷积-2: [B,features_g*8,4,4]-> [B,features_g*4,8,8]nn.ConvTranspose2d(config["features_g"]*8, config["features_g"]*4, 4, 2, 1),nn.BatchNorm2d(config["features_g"]*4),nn.LeakyReLU(negative_slope=0.0001, inplace=True),# 转置卷积-3: [B,features_g*4,8,8]-> [B,features_g*2,16,16]nn.ConvTranspose2d(config["features_g"]*4, config["features_g"]*2, 4, 2, 1),nn.BatchNorm2d(config["features_g"]*2),nn.LeakyReLU(negative_slope=0.0001, inplace=True),# 转置卷积-4: [B,features_g*2,16,16]-> [B,img_channels,32,32]nn.ConvTranspose2d(config["features_g"]*2, config["img_channels"], 4, 2, 1),nn.Tanh() # 输出归一化到[-1,1] )def forward(self, z, labels): # z[B,z_dim,1,1]# 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]label_embed = self.label_embed(labels)# 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]label_embed = label_embed.unsqueeze(2).unsqueeze(3)# 拼接噪声和嵌入标签 ->[B,z_dim + label_embed_dim ,1,1]=[B,100+10,1,1]gen_input= torch.cat([z,label_embed], dim=1)# 生成图片 [B,label_embed_dim + z_dim,1,1]-> [B,img_channels,32,32]img=self.model(gen_input)return img # [B,1,32,32]
2.4 构建判别器
class Discriminator(nn.Module):def __init__(self, img_shape=(1,32,32),num_classes=10,label_embed_dim=10,features_d=64):"判别器"super(Discriminator, self).__init__()# 定义嵌入层 [B]-> [B,label_embed_dim]=[B,10]self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入# 定义模型结构self.model = nn.Sequential(# 卷积-1:[B,img_channels + label_embed_dim,32,32]-> [B,features_d,16,16]nn.Conv2d(config["img_channels"] + config["label_embed_dim"], config["features_d"], 4,2,1),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 卷积-2:[B,features_d,16,16]-> [B,features_d*2,8,8]nn.Conv2d( config["features_d"], config["features_d"]* 2, 4, 2, 1),nn.InstanceNorm2d( config["features_d"]* 2),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 卷积-3:[B,features_d*2,8,8]-> [B,features_d*4,4,4]nn.Conv2d( config["features_d"]*2, config["features_d"]*4, 4, 2, 1),nn.InstanceNorm2d( config["features_d"]*4),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 卷积-4:[B,features_d*4,4,4]-> [B,1,1,1]nn.Conv2d( config["features_d"]*4, 1, 4, 1, 0),nn.Flatten(), # 展平,4维[B,1,1,1]-> 2维[B,1,1,1])def forward(self, img, labels):# 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]label_embed = self.label_embed(labels)# 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]label_embed = label_embed.unsqueeze(2).unsqueeze(3)# 沿着空间维度(高度和宽度)进行复制扩展,使其与目标图像的空间尺寸匹配label_embed = label_embed.repeat(1, 1, img.shape[2], img.shape[3])# 拼接图片和嵌入标签 ->[B,img.shape[0]+label_embed_dim ,img.shape[2],img.shape[3]]=[B,1+10,32,32]dis_input= torch.cat([img,label_embed], dim=1)# 进行判定validity = self.model(dis_input)return validity # [B,1]
2.5 训练和保存模型
- 定义梯度惩罚函数
# 梯度惩罚计算函数
def compute_gradient_penalty(disc, real_samples, fake_samples, labels,device=device):"""计算梯度惩罚项"""# 随机插值系数alpha = torch.rand(real_samples.shape[0], 1, 1, 1).to(device)# 生成插值样本interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)# 计算判别器对插值样本的输出d_interpolates = disc(interpolates, labels)# 计算梯度gradients = torch.autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=torch.ones_like(d_interpolates),create_graph=True,retain_graph=True,only_inputs=True,)[0]# 计算梯度惩罚项 gradients = gradients.view(gradients.shape[0], -1) # 多维-> 2维 [B,*]gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty
- 定义保存生成样本
def sample_image(G,n_row, batches_done,latent_dim=100,device=device):"""Saves a grid of generated digits ranging from 0 to n_classes"""# 随机噪声-> [n_row ** 2,latent_dim]=[100,100]z=torch.normal(0,1,size=(n_row ** 2,latent_dim,1,1),device=device) #从正态分布中抽样# 条件标签->[100]labels = torch.arange(n_row, dtype=torch.long, device=device).repeat_interleave(n_row)gen_imgs = G(z, labels)save_image(gen_imgs.data, "./img/cwgan-gp_mnist/%d.png" % batches_done, nrow=n_row, normalize=True)
- 训练和保存
# 加载数据
train_loader,_= load_data(config)# 实例化生成器G、判别器D
G=Generator(config).to(device)
D=Discriminator(config).to(device)# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=config["lr"],betas=(0.5, 0.9))
optimizer_D = torch.optim.Adam(D.parameters(), lr=config["lr"],betas=(0.5, 0.9))# 开始训练
start_time = time.time() # 计时器
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(config["num_epochs"]):# 进入训练模式G.train()D.train()loop = tqdm(train_loader, desc=f"第{epoch+1}轮")for i, (real_imgs, labels) in enumerate(loop):real_imgs=real_imgs.to(device) # [B,C,H,W]labels=labels.to(device) # [B]# -----------------# 训练判别器# 1. 必须 gen_imgs.detach(),防止生成器的梯度干扰判别器更新# 2. 损失函数为 D(real_img) - D(gen_img.detach()) + 惩罚项# -----------------for _ in range(config["n_critic"]):# 获取噪声样本[B,latent_dim,1,1]及对应的条件标签 [B]z=torch.normal(0,1,size=(real_imgs.shape[0],config["z_dim"],1,1),device=device) #从正态分布中抽样# --- 计算判别器损失 ---# Step-1:对真实图片损失valid_loss=torch.mean(D(real_imgs,labels))# Step-2:对生成图片损失gen_imgs=G(z,labels) # 生成图片fake_loss=torch.mean(D(gen_imgs.detach(),labels))# Step-3:计算梯度惩罚gradient_penalty = compute_gradient_penalty(D, real_imgs.data, gen_imgs.data, labels)# Step-4:整体损失dis_loss= -(valid_loss - fake_loss) + config["lambda_gp"]* gradient_penalty# 更新判别器参数optimizer_D.zero_grad() #梯度清零dis_loss.backward() #反向传播,计算梯度optimizer_D.step() #更新判别器 # -----------------# 训练生成器# 1.禁止gen_imgs.detach(),需保持完整的计算图# 2.损失函数应为 -D(gen_img)(WGAN 的目标是最大化判别器对生成样本的评分)# -----------------# 计算生成器损失 gen_loss = -torch.mean(D(gen_imgs, labels)) # 更新生成器参数optimizer_G.zero_grad() #梯度清零gen_loss.backward() #反向传播,计算梯度optimizer_G.step() #更新生成器 # --- 记录损失到TensorBoard ---batches_done = epoch * loader_len + iif batches_done % 100 == 0:writer.add_scalars("CWGAN-GP",{"Generator": gen_loss.item(),"Discriminator": dis_loss.item(),},global_step=epoch * loader_len + i, # 全局步数)# 更新进度条loop.set_postfix(gen_loss=f"{gen_loss:.8f}",dis_loss=f"{dis_loss:.8f}")if batches_done % 400 == 0:sample_image(G,10, batches_done)writer.close()
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/CWGAN-GP_G.pth")
torch.save(D.state_dict(), "./model/CWGAN-GP_D.pth")
2.6 查看训练损失
# 加载魔术命令
%reload_ext tensorboard# 启动TensorBoard(
%tensorboard --logdir ./logs/cwgan-gp --port=6007
2.7 图片转GIF
from PIL import Imagedef create_gif(img_dir="./img/cwgan-gp_mnist", output_file="./img/cwgan-gp_mnist/cwgan-gp_figure.gif", duration=100):images = []img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]# 自定义排序:按 "x.png" 的排序img_paths_sorted = sorted(img_paths,key=lambda x: (int(x.split('.')[0]), ))for img_file in img_paths_sorted:img = Image.open(os.path.join(img_dir, img_file))images.append(img)images[0].save(output_file, save_all=True, append_images=images[1:], duration=duration, loop=0)print(f"GIF已保存至 {output_file}")
create_gif()

2.8 模型加载和推理
#载入训练好的模型
G = Generator(config) # 定义模型结构
G.load_state_dict(torch.load("./model/CWGAN-GP_G.pth",weights_only=True,map_location=device)) # 加载保存的参数
G.to(device)
G.eval() # 获取噪声样本[10,100]及对应的条件标签 [10]
z=torch.normal(0,1,size=(10,100,1,1),device=device) #从正态分布中抽样
gen_labels = torch.arange(10, dtype=torch.long, device=device) #0~9整数#生成假样本
gen_imgs=G(z,gen_labels).view(-1,32,32) # 4维->3维
gen_imgs=gen_imgs.detach().cpu().numpy()#绘制
plt.figure(figsize=(8, 5))
plt.rcParams.update({'font.family': 'serif',
})
for i in range(10):plt.subplot(2, 5, i + 1) plt.xticks([], []) plt.yticks([], []) plt.imshow(gen_imgs[i], cmap='gray') plt.title(f"Figure {i}", fontsize=12)
plt.tight_layout()
plt.show()
