当前位置: 首页 > news >正文

WGAN-GP 原理及实现(pytorch版)

WGAN-GP 原理及实现

  • 一、WGAN-GP 原理
    • 1.1 WGAN-GP 核心原理
    • 1.2 WGAN-GP 实现步骤
    • 1.3 总结
  • 二、WGAN-GP 实现
    • 2.1 导包
    • 2.2 数据加载和处理
    • 2.3 构建生成器
    • 2.4 构建判别器
    • 2.5 训练和保存模型
    • 2.6 图片转GIF

一、WGAN-GP 原理

Wasserstein GAN with Gradient Penalty (WGAN-GP) 是对原始 WGAN 的改进,通过梯度惩罚(Gradient Penalty)替代权重裁剪(Weight Clipping),解决了 WGAN 训练不稳定、权重裁剪导致梯度消失或爆炸的问题。


1.1 WGAN-GP 核心原理

(1) Wasserstein 距离(Earth-Mover 距离)

  • 原始 GAN 的 JS 散度在分布不重叠时梯度消失,而 WGAN 使用 Wasserstein 距离衡量生成分布 P g P_g Pg 和真实分布 P r P_r Pr 的距离:
    W ( P r , P g ) = inf ⁡ γ ∼ Π ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(P_r, P_g) = \inf_{\gamma \sim \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim \gamma} [\|x-y\|] W(Pr,Pg)=infγΠ(Pr,Pg)E(x,y)γ[xy]
  • 通过 Kantorovich-Rubinstein 对偶形式,转化为:
    W ( P r , P g ) = sup ⁡ ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] W(P_r, P_g) = \sup_{\|D\|_L \leq 1} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] W(Pr,Pg)=supDL1ExPr[D(x)]EzPz[D(G(z))],其中 D D D 是 1-Lipschitz 函数(梯度范数不超过 1)

(2) 梯度惩罚(Gradient Penalty)

  • 原始 WGAN 的问题:通过权重裁剪强制判别器(Critic)满足 Lipschitz 约束,但会导致梯度不稳定或容量下降
  • WGAN-GP 的改进:直接对判别器的梯度施加惩罚项,强制其梯度范数接近 1: λ ⋅ E x ^ ∼ P x ^ \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} λEx^Px^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] \left [(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right] [(x^D(x^)21)2]
    • x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点: x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1ϵ)G(z) ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵU[0,1]
    • λ \lambda λ 是惩罚系数(通常设为 10)

1.2 WGAN-GP 实现步骤

(1) 判别器(Critic)的损失函数
判别器的目标是最大化 Wasserstein 距离,同时满足梯度约束:
L D = E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] ⏟ Wasserstein 距离 + λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] ⏟ 梯度惩罚 L_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]}_{\text{梯度惩罚}} LD=Wasserstein 距离 ExPr[D(x)]EzPz[D(G(z))]+梯度惩罚 λEx^Px^[(x^D(x^)21)2]

(2) 生成器(Generator)的损失函数
生成器的目标是最小化 Wasserstein 距离: L G = − E z ∼ P z [ D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim P_z}[D(G(z))] LG=EzPz[D(G(z))]

(3) 训练流程

  1. 输入:真实数据 x x x,噪声 z ∼ N ( 0 , 1 ) z \sim \mathcal{N}(0,1) zN(0,1)
  2. 生成数据 G ( z ) G(z) G(z)
  3. 插值采样 x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1ϵ)G(z) ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵU[0,1]
  4. 计算梯度惩罚
    • 对插值样本 x ^ \hat{x} x^ 计算判别器输出 D ( x ^ ) D(\hat{x}) D(x^)
    • 求梯度 ∇ x ^ D ( x ^ ) \nabla_{\hat{x}} D(\hat{x}) x^D(x^) 并计算惩罚项
  5. 更新判别器:最小化 L D L_D LD
  6. 更新生成器:最小化 L G L_G LG(每 n critic n_{\text{critic}} ncritic 次判别器更新后更新 1 次生成器)

1.3 总结

WGAN-GP 通过梯度惩罚替代权重裁剪,显著提升了 WGAN 的训练稳定性,是生成对抗网络的重要改进之一。实际应用中需注意:

  • 判别器架构设计
  • 梯度惩罚的正确实现
  • 学习率和训练次数的调优

二、WGAN-GP 实现

2.1 导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

import os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  
from torchsummary import summary

# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 指定存放日志路径
writer=SummaryWriter(log_dir="./runs/wgan_gp")

os.makedirs("./img/wgan_gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录

2.2 数据加载和处理

# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,28,28)):
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到[-1,1]
    ])
    
    # 下载训练集和测试集
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    # 创建 DataLoader
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)
    return train_loader, test_loader

2.3 构建生成器

class Generator(nn.Module):
    """生成器"""
    def __init__(self, latent_dim=100,img_shape=(1,28,28)):
        super(Generator,self).__init__()

        # 网络块
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh() # 输出归一化到[-1,1] 
        )

        
    def forward(self,z): # 噪声z,2维[batch_size,latent_dim]
        gen_img=self.model(z) 
        gen_img=gen_img.view(gen_img.shape[0],*img_shape)
        return gen_img # 4维[batch_size,1,H,W]

2.4 构建判别器

class Discriminator(nn.Module):
    """判别器"""
    def __init__(self,img_shape=(1,28,28)):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
                nn.Linear(int(np.prod(img_shape)), 512),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(512, 256),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(256, 1)
            )

    def forward(self,img): # 输入图片,4维[batc_size,1,H,W]
        img=img.view(img.shape[0], -1) 
        pred = self.model(img)
        return pred # 2维[batch_size,1] 

2.5 训练和保存模型

  • WGAN-GP 算法流程

  • 定义梯度惩罚函数

def compute_gradient_penalty(critic, real, fake, device):
    batch_size = real.shape[0]
    epsilon = torch.rand(batch_size, 1, 1, 1).to(device)  # 随机插值系数
    interpolates = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
    critic_interpolates = critic(interpolates)
    
    # 计算梯度
    gradients = torch.autograd.grad(
        outputs=critic_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    
    gradients = gradients.view(gradients.shape[0], -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
  • 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本

# WGAN的特别设置
num_iter_critic = 5
lambda_gp = 10

# 设置图片形状1*28*28
img_shape = (1,28,28)

# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)

# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)

# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# 开始训练
batches_done=0
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):
    # 进入训练模式
    G.train()
    D.train()
    
    loop = tqdm(train_loader, desc=f"第{epoch+1}轮")
    for i, (real_imgs, _) in enumerate(loop):
        real_imgs=real_imgs.to(device)  # [B,C,H,W]

        
        # -----------------
        #  训练判别器
        # -----------------
        
        # 获取噪声样本[B,latent_dim)
        z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device)  #从正态分布中抽样
 
        # Step-1 计算判断器损失=判断真实图片损失+判断生成图片损失+惩罚项
        fake_imgs=G(z).detach()
        gradient_penalty=compute_gradient_penalty(D, real_imgs, fake_imgs, device)
        dis_loss=-torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))+lambda_gp*gradient_penalty
       
        # Step-2 更新判别器参数
        optimizer_D.zero_grad() # 梯度清零
        dis_loss.backward() #反向传播,计算梯度
        optimizer_D.step()  #更新判别器 

        # -----------------
        #  训练生成器
        # -----------------
 
        # 判别器每迭代 num_iter_critic 次,生成器迭代一次
        if i % num_iter_critic ==0 :

            gen_imgs=G(z).detach()

            # 更新生成器参数
            optimizer_G.zero_grad() #梯度清零
            gen_loss=-torch.mean(D(gen_imgs))
            gen_loss.backward() #反向传播,计算梯度
            optimizer_G.step()  #更新生成器  

             # 更新进度条
            loop.set_postfix(
                gen_loss=f"{gen_loss:.8f}",
                dis_loss=f"{dis_loss:.8f}"
            )
            

        # 每 sample_interval 次迭代保存生成样本
        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], f"./img/wgan_gp_mnist/{epoch}_{i}.png", nrow=5, normalize=True)
        batches_done += 1

print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))

#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/WGAN-GP_G.pth") 
torch.save(D.state_dict(), "./model/WGAN-GP_D.pth") 

2.6 图片转GIF

from PIL import Image

def create_gif(img_dir="./img/wgan_gp_mnist", output_file="./img/wgan_gp_mnist/wgan_gp_figure.gif", duration=100):
    images = []
    img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]
    
    # 自定义排序:按 "x_y.png" 的 x 和 y 排序
    img_paths_sorted = sorted(
        img_paths,
        key=lambda x: (
            int(x.split('_')[0]),  # 第一个数字(如 0_400.png 的 0)
            int(x.split('_')[1].split('.')[0])  # 第二个数字(如 0_400.png 的 400)
        )
    )
    
    for img_file in img_paths_sorted:
        img = Image.open(os.path.join(img_dir, img_file))
        images.append(img)
    
    images[0].save(output_file, save_all=True, append_images=images[1:], 
                  duration=duration, loop=0)
    print(f"GIF已保存至 {output_file}")
create_gif()

相关文章:

  • 卫龙的网站是谁做的安卓优化大师官方版本下载
  • 网站安装模板外贸建站推广公司
  • 前几年做那个网站致富百度手机助手官网
  • 惠州公司做网站seo文章外包
  • 江苏省国家示范校建设专题网站樱桃电视剧西瓜视频在线观看
  • 网站的后缀名怎么建设网络seo哈尔滨
  • MySQL 备份与恢复:数据库的灾难保险计划
  • 兔单B细胞单抗制备服务
  • 蓝桥杯嵌入式十五届模拟二(串口DMA,占空比的另一种测量方式)
  • Python人工智能算法 基于遗传算法解决流水车间调度问题
  • (学习总结33)Linux Ext2 文件系统与软硬链接
  • js 效果展示
  • 机器学习 | 强化学习 vs 深度学习 vs 深度强化学习 | 概念向
  • 初入Web网页开发
  • 基于大模型的阵发性室上性心动过速风险预测与治疗方案研究
  • mySQL数据库和mongodb数据库的详细对比
  • LeetCode】寻找重复子树:深度解析与高效解法
  • Dynamics 365 Business Central Recurring Sales Lines 经常购买销售行 来作 订阅
  • 2025年美国CPI数据公布时间表
  • 循环神经网络 - 参数学习之实时循环学习
  • UML类图综合实验三补档
  • 类初始化、类加载、垃圾回收---JVM
  • Heap_dijkstra
  • SnakeMake搭建pipeline 1
  • 隔行换色总结
  • MCP vs LangChain:标准化协议与开发框架的优劣对比