当前位置: 首页 > news >正文

CWGAN-GP 原理及实现(pytorch版)

CWGAN-GP

  • 一、CWGAN-GP 原理
    • 1.1 CWGAN-GP 的核心改进
    • 1.2 CWGAN-GP 的损失函数
    • 1.3 CWGAN-GP 的优势
    • 1.4 关键参数选择
    • 1.5 应用场景
  • 二、CWGAN-GP 实现
    • 2.1 导包
    • 2.2 数据加载和处理
    • 2.3 构建生成器
    • 2.4 构建判别器
    • 2.5 训练和保存模型
    • 2.6 查看训练损失
    • 2.7 图片转GIF
    • 2.8 模型加载和推理

一、CWGAN-GP 原理

CWGAN-GPConditional Wasserstein GAN with Gradient Penalty)是 WGAN-GP 的条件生成版本,结合了 条件生成对抗网络(CGAN)Wasserstein GAN 的梯度惩罚(GP),以提高生成样本的质量和训练稳定性。

1.1 CWGAN-GP 的核心改进

(1)条件生成(Conditional Generation)

  • CGAN 思想:生成器 G G G 和判别器 D D D 都接收额外的条件信息 y y y(如类别标签、文本描述等),使生成的数据 G ( z ∣ y ) G(z|y) G(zy) 与条件 y y y 相关。
  • 数学表达
    • 生成器: G ( z ∣ y ) → x f a k e G(z|y) \rightarrow x_{fake} G(zy)xfake
    • 判别器: D ( x ∣ y ) → 评分 D(x|y) \rightarrow \text{评分} D(xy)评分

(2)Wasserstein 距离(WGAN)

  • WGAN 用 Earth-Mover(EM)距离(即 Wasserstein-1 距离)替代原始 GAN 的 JS 散度,解决梯度消失/爆炸问题:
    W ( P r , P g ) = sup ⁡ ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E x ∼ P g [ D ( x ) ] W(P_r, P_g) = \underset{\|D\|_L \leq 1}{\sup} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{x \sim P_g}[D(x)] W(Pr,Pg)=DL1supExPr[D(x)]ExPg[D(x)]
    • D D D 需要是 1-Lipschitz 函数(梯度范数 ≤ 1)

(3) 梯度惩罚(Gradient Penalty, GP)

  • WGAN-GP 通过 梯度惩罚 强制判别器满足 Lipschitz 约束,替代 WGAN 的权重裁剪(更稳定):
    L G P = λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ∣ y ) ∥ 2 − 1 ) 2 ] \mathcal{L}_{GP} = \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2 \right] LGP=λEx^Px^[(x^D(x^y)21)2]
    • x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点:
      x ^ = ϵ x r e a l + ( 1 − ϵ ) x f a k e , ϵ ∼ U [ 0 , 1 ] \hat{x} = \epsilon x_{real} + (1 - \epsilon) x_{fake}, \quad \epsilon \sim U[0,1] x^=ϵxreal+(1ϵ)xfake,ϵU[0,1]
    • λ \lambda λ 是惩罚系数(通常设为 10)

1.2 CWGAN-GP 的损失函数

(1)判别器损失
L D = E x ∼ P r [ D ( x ∣ y ) ] − E z ∼ p ( z ) [ D ( G ( z ∣ y ) ∣ y ) ] ⏟ Wasserstein 距离 + λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ∣ y ) ∥ 2 − 1 ) 2 ] ⏟ 梯度惩罚 \mathcal{L}_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x|y)] - \mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x}|y)\|_2 - 1)^2]}_{\text{梯度惩罚}} LD=Wasserstein 距离 ExPr[D(xy)]Ezp(z)[D(G(zy)y)]+梯度惩罚 λEx^Px^[(x^D(x^y)21)2]

(2) 生成器损失
L G = − E z ∼ p ( z ) [ D ( G ( z ∣ y ) ∣ y ) ] \mathcal{L}_G = -\mathbb{E}_{z \sim p(z)}[D(G(z|y)|y)] LG=Ezp(z)[D(G(zy)y)]

  • 生成器的目标是让判别器对生成样本的评分最大化(即最小化 − L G -\mathcal{L}_G LG

1.3 CWGAN-GP 的优势

改进点WGANWGAN-GPCWGAN-GP
距离度量WassersteinWasserstein + GPWasserstein + GP
条件生成
训练稳定性权重裁剪(不稳定)梯度惩罚(稳定)梯度惩罚 + 条件控制
模式崩溃较少更少最少(因条件约束)

1.4 关键参数选择

  • 梯度惩罚系数 λ \lambda λ:通常设为 10
  • 判别器训练次数 n c r i t i c n_{critic} ncritic:通常 3~5 次/生成器 1 次
  • 学习率:建议 1 0 − 4 10^{-4} 104(Adam 优化器, β 1 = 0.5 , β 2 = 0.9 \beta_1=0.5, \beta_2=0.9 β1=0.5,β2=0.9

1.5 应用场景

  • 图像生成(如条件生成 MNIST 数字)
  • 文本到图像合成(如 StackGAN)
  • 数据增强(生成特定类别的数据)

CWGAN-GP 通过 条件输入 + 梯度惩罚显著提升了生成质量和训练稳定性,是许多现代 GAN 变种(如 ProGAN、StyleGAN)的基础架构


二、CWGAN-GP 实现

2.1 导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

import os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  
from torchsummary import summary

# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 设置日志
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) 
log_dir = os.path.join("./logs/cwgan-gp", time_str) 
os.makedirs(log_dir, exist_ok=True) 
writer = SummaryWriter(log_dir=log_dir) 

os.makedirs("./img/cwgan-gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录

# 超参数配置
config = {
    "num_epochs": 100,
    "batch_size": 64,
    "lr": 1e-4,
    "img_channels": 1,
    "img_size": 32,
    "features_g": 64,
    "features_d": 64,
    "z_dim": 100, 
    "num_classes": 10, # 分类数
    "label_embed_dim": 10,  # 嵌入维度
    "lambda_gp": 10,  # 梯度惩罚系数
    "n_critic": 5,    # 判别器更新次数/生成器更新1次     
}

2.2 数据加载和处理

# 加载 MNIST 数据集
def load_data(config):
    transform = transforms.Compose([
        transforms.Resize(config["img_size"]),
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到[-1,1]
    ])
    
    # 下载训练集和测试集
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    # 创建 DataLoader
    train_loader = DataLoader(dataset=train_dataset, batch_size=config["batch_size"],shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=config["batch_size"], shuffle=False)
    return train_loader, test_loader 

2.3 构建生成器

class Generator(nn.Module):
    """生成器"""
    def __init__(self, config):
               
        super(Generator,self).__init__()

        # 定义嵌入层 [B]-> [B,label_embed_dim]=[64,10]
        self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入

        # 定义模型结构
        self.model = nn.Sequential(
            # 转置卷积-1: [B,z_dim + label_embed_dim,1,1]-> [B,features_g*8,4,4]
            nn.ConvTranspose2d(config["z_dim"] + config["label_embed_dim"],config["features_g"]*8, 4, 1, 0),
            nn.BatchNorm2d(config["features_g"]*8),
            nn.LeakyReLU(negative_slope=0.0001, inplace=True),
            
            # 转置卷积-2: [B,features_g*8,4,4]-> [B,features_g*4,8,8]
            nn.ConvTranspose2d(config["features_g"]*8, config["features_g"]*4, 4, 2, 1),
            nn.BatchNorm2d(config["features_g"]*4),
            nn.LeakyReLU(negative_slope=0.0001, inplace=True),
            
            # 转置卷积-3: [B,features_g*4,8,8]-> [B,features_g*2,16,16]
            nn.ConvTranspose2d(config["features_g"]*4, config["features_g"]*2, 4, 2, 1),
            nn.BatchNorm2d(config["features_g"]*2),
            nn.LeakyReLU(negative_slope=0.0001, inplace=True),
            
            # 转置卷积-4: [B,features_g*2,16,16]-> [B,img_channels,32,32]
            nn.ConvTranspose2d(config["features_g"]*2, config["img_channels"], 4, 2, 1),
            nn.Tanh() # 输出归一化到[-1,1]  
        )
    
    def forward(self, z, labels): # z[B,z_dim,1,1]
        # 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]
        label_embed = self.label_embed(labels)
        
        # 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]
        label_embed = label_embed.unsqueeze(2).unsqueeze(3)
        
        # 拼接噪声和嵌入标签 ->[B,z_dim + label_embed_dim ,1,1]=[B,100+10,1,1]
        gen_input= torch.cat([z,label_embed], dim=1)
       
        # 生成图片 [B,label_embed_dim + z_dim,1,1]-> [B,img_channels,32,32]
        img=self.model(gen_input)
        return img # [B,1,32,32]

2.4 构建判别器

class Discriminator(nn.Module):
    def __init__(self, img_shape=(1,32,32),num_classes=10,label_embed_dim=10,features_d=64):
        "判别器"
        super(Discriminator, self).__init__()

        # 定义嵌入层 [B]-> [B,label_embed_dim]=[B,10]
        self.label_embed = nn.Embedding(config["num_classes"], config["label_embed_dim"]) # num_classes 个类别, label_embed_dim 维嵌入

        # 定义模型结构
        self.model = nn.Sequential(
            # 卷积-1:[B,img_channels + label_embed_dim,32,32]-> [B,features_d,16,16]
            nn.Conv2d(config["img_channels"] + config["label_embed_dim"], config["features_d"], 4,2,1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            # 卷积-2:[B,features_d,16,16]-> [B,features_d*2,8,8]
            nn.Conv2d( config["features_d"],  config["features_d"]* 2, 4, 2, 1),
            nn.InstanceNorm2d( config["features_d"]* 2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            # 卷积-3:[B,features_d*2,8,8]-> [B,features_d*4,4,4]
            nn.Conv2d( config["features_d"]*2,  config["features_d"]*4, 4, 2, 1),
            nn.InstanceNorm2d( config["features_d"]*4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            # 卷积-4:[B,features_d*4,4,4]-> [B,1,1,1]
            nn.Conv2d( config["features_d"]*4, 1, 4, 1, 0),
            nn.Flatten(), # 展平,4维[B,1,1,1]-> 2维[B,1,1,1]
        )
    
    def forward(self, img, labels):
        # 嵌入标签 [B]-> [B,label_embed_dim]=[B,10]
        label_embed = self.label_embed(labels)
        
        # 标签升维 [B,label_embed_dim]-> [B,label_embed_dim,1,1]=[B,10,1,1]
        label_embed = label_embed.unsqueeze(2).unsqueeze(3)

        # 沿着空间维度(高度和宽度)进行复制扩展,使其与目标图像的空间尺寸匹配
        label_embed = label_embed.repeat(1, 1, img.shape[2], img.shape[3])

        # 拼接图片和嵌入标签 ->[B,img.shape[0]+label_embed_dim ,img.shape[2],img.shape[3]]=[B,1+10,32,32]
        dis_input= torch.cat([img,label_embed], dim=1)
        
        # 进行判定
        validity = self.model(dis_input)
        return validity # [B,1]

2.5 训练和保存模型

  1. 定义梯度惩罚函数
# 梯度惩罚计算函数
def compute_gradient_penalty(disc, real_samples, fake_samples, labels,device=device):
    """计算梯度惩罚项"""
    # 随机插值系数
    alpha = torch.rand(real_samples.shape[0], 1, 1, 1).to(device)
    # 生成插值样本
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    # 计算判别器对插值样本的输出
    d_interpolates = disc(interpolates, labels)
    # 计算梯度
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    
    # 计算梯度惩罚项 
    gradients = gradients.view(gradients.shape[0], -1) # 多维-> 2维 [B,*]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
  1. 定义保存生成样本
def sample_image(G,n_row, batches_done,latent_dim=100,device=device):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # 随机噪声-> [n_row ** 2,latent_dim]=[100,100]
    z=torch.normal(0,1,size=(n_row ** 2,latent_dim,1,1),device=device)  #从正态分布中抽样
    # 条件标签->[100]
    labels = torch.arange(n_row, dtype=torch.long, device=device).repeat_interleave(n_row)
    gen_imgs = G(z, labels)
    save_image(gen_imgs.data, "./img/cwgan-gp_mnist/%d.png" % batches_done, nrow=n_row, normalize=True)
  1. 训练和保存
# 加载数据
train_loader,_= load_data(config)

# 实例化生成器G、判别器D
G=Generator(config).to(device)
D=Discriminator(config).to(device)

# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=config["lr"],betas=(0.5, 0.9))
optimizer_D = torch.optim.Adam(D.parameters(), lr=config["lr"],betas=(0.5, 0.9))


# 开始训练
start_time = time.time()  # 计时器
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(config["num_epochs"]):
    # 进入训练模式
    G.train()
    D.train()
     
    loop = tqdm(train_loader, desc=f"第{epoch+1}轮")
    for i, (real_imgs, labels) in enumerate(loop):
        real_imgs=real_imgs.to(device)  # [B,C,H,W]
        labels=labels.to(device) # [B]

        # -----------------
        #  训练判别器
        # 1. 必须 gen_imgs.detach(),防止生成器的梯度干扰判别器更新
        # 2. 损失函数为 D(real_img) - D(gen_img.detach()) + 惩罚项
        # -----------------
        
        for _ in range(config["n_critic"]):
            # 获取噪声样本[B,latent_dim,1,1]及对应的条件标签 [B]
            z=torch.normal(0,1,size=(real_imgs.shape[0],config["z_dim"],1,1),device=device)  #从正态分布中抽样
            
            # --- 计算判别器损失 ---
            # Step-1:对真实图片损失
            valid_loss=torch.mean(D(real_imgs,labels))
                                  
            # Step-2:对生成图片损失
            gen_imgs=G(z,labels) # 生成图片
            fake_loss=torch.mean(D(gen_imgs.detach(),labels))
            
            # Step-3:计算梯度惩罚
            gradient_penalty = compute_gradient_penalty(D, real_imgs.data, gen_imgs.data, labels)
            
            # Step-4:整体损失
            dis_loss= -(valid_loss - fake_loss) + config["lambda_gp"]* gradient_penalty

            # 更新判别器参数
            optimizer_D.zero_grad() #梯度清零
            dis_loss.backward() #反向传播,计算梯度
            optimizer_D.step()  #更新判别器   
        
        # -----------------
        #  训练生成器
        # 1.禁止gen_imgs.detach(),需保持完整的计算图
        # 2.损失函数应为 -D(gen_img)(WGAN 的目标是最大化判别器对生成样本的评分)
        # -----------------

        # 计算生成器损失       
        gen_loss = -torch.mean(D(gen_imgs, labels)) 
        
        # 更新生成器参数
        optimizer_G.zero_grad() #梯度清零
        gen_loss.backward() #反向传播,计算梯度
        optimizer_G.step()  #更新生成器  


        # --- 记录损失到TensorBoard ---
        batches_done = epoch * loader_len + i
        if batches_done % 100 == 0:
            writer.add_scalars(
                "CWGAN-GP",
                {
                    "Generator": gen_loss.item(),
                    "Discriminator": dis_loss.item(),
                },
                global_step=epoch * loader_len + i,  # 全局步数
            )

            # 更新进度条
            loop.set_postfix(gen_loss=f"{gen_loss:.8f}",dis_loss=f"{dis_loss:.8f}")

        if batches_done % 400 == 0:
            sample_image(G,10, batches_done)
            
writer.close()
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))

#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/CWGAN-GP_G.pth") 
torch.save(D.state_dict(), "./model/CWGAN-GP_D.pth") 

2.6 查看训练损失

# 加载魔术命令
%reload_ext tensorboard

# 启动TensorBoard(
%tensorboard --logdir ./logs/cwgan-gp --port=6007 

在这里插入图片描述

2.7 图片转GIF

from PIL import Image

def create_gif(img_dir="./img/cwgan-gp_mnist", output_file="./img/cwgan-gp_mnist/cwgan-gp_figure.gif", duration=100):
    images = []
    img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]
    
    # 自定义排序:按 "x.png" 的排序
    img_paths_sorted = sorted(
        img_paths,
        key=lambda x: (
            int(x.split('.')[0]),  
        )
    )
    
    for img_file in img_paths_sorted:
        img = Image.open(os.path.join(img_dir, img_file))
        images.append(img)
    
    images[0].save(output_file, save_all=True, append_images=images[1:], 
                  duration=duration, loop=0)
    print(f"GIF已保存至 {output_file}")
create_gif()

2.8 模型加载和推理

#载入训练好的模型
G = Generator(config) # 定义模型结构
G.load_state_dict(torch.load("./model/CWGAN-GP_G.pth",weights_only=True,map_location=device)) # 加载保存的参数
G.to(device) 
G.eval() 

# 获取噪声样本[10,100]及对应的条件标签 [10]
z=torch.normal(0,1,size=(10,100,1,1),device=device)  #从正态分布中抽样
gen_labels = torch.arange(10, dtype=torch.long, device=device) #0~9整数

#生成假样本
gen_imgs=G(z,gen_labels).view(-1,32,32) # 4维->3维
gen_imgs=gen_imgs.detach().cpu().numpy()

#绘制
plt.figure(figsize=(8, 5)) 
plt.rcParams.update({
    'font.family': 'serif',  
})
for i in range(10):
    plt.subplot(2, 5, i + 1)  
    plt.xticks([], []) 
    plt.yticks([], [])  
    plt.imshow(gen_imgs[i], cmap='gray')  
    plt.title(f"Figure {i}", fontsize=12)  
plt.tight_layout()  
plt.show()

相关文章:

  • CISCO组建RIP V2路由网络
  • 【面试分享】Spring Boot 面试题及答案整理,最新面试题
  • Android Spotify-v9.0.36.443-arm64-Experimental Merged版
  • SaaS微服务架构的智慧工地源码,基于Spring Cloud +UniApp +MySql开发
  • 第十八天 - ELK日志体系集成 - 自定义Logstash插件 - 练习:分布式日志分析平台
  • 物联网传感器技术架构与功能解析
  • LIB-ZC, 一个跨平台(Linux)平台通用C/C++扩展库,命令行参数和配置文件
  • Redis核心功能实现
  • 科技项目验收测试包括哪些内容?有什么作用?
  • ESP32小智AI机器人全栈开发:从云端部署到语音交互实战(附源码)
  • Eclipse 悬浮提示功能详解
  • Android12源码编译之预置Android Studio项目Android.mk文件编写
  • 泰鸿万立上市:加强产品规划和前瞻性研发 打造优质汽车零部件制造商
  • 电子学会 信息素养大赛图形化、python、c++历年试题
  • C++中的虚克隆模式:实现多态对象的安全深拷贝
  • 最新版DataGrip超详细图文安装教程,带补丁包(2025最新版保姆级教程)
  • 【Bug】BEVFormer配置bug:ModuleNotFoundError: No module named ‘tools.data_converter‘
  • [python] 作用域
  • BlueNRG-LP v3.x 协议栈主要事件列表与含义解析
  • 玩转ESP32-S3:UDP网络通信技术详解
  • 二级域名对网站帮助/网络推广app是干什么的
  • 分类信息网站建设方案/常州seo关键词排名
  • 凡科快图网站/查关键词的排名工具
  • 做网站公司圣辉友联/广州网站优化服务
  • 简单写文章的网站/廊坊网络推广公司
  • 长春网站建设公司怎么样/浏览器大全