当前位置: 首页 > news >正文

Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一)

Pytorch深度学习框架60天进阶学习计划 - 第41天:生成对抗网络进阶(一)

今天我们将深入探讨生成对抗网络(GAN)的进阶内容,特别是Wasserstein GAN(WGAN)的梯度惩罚机制,以及条件生成与无监督生成在模式坍塌方面的差异。

生成对抗网络是近年来深度学习领域最激动人心的进展之一,它由Ian Goodfellow于2014年提出,通过生成器和判别器的博弈来学习生成真实数据分布的样本。随着研究的深入,GAN的改进版本层出不穷,其中WGAN及其梯度惩罚版本(WGAN-GP)解决了原始GAN训练不稳定的问题,成为了GAN研究的重要里程碑。

今天我们将从理论到实践,系统地学习这些进阶概念,并通过PyTorch实现相关模型,探索其工作原理。

1. GAN基础回顾

在深入WGAN之前,让我们简要回顾GAN的基本原理:

1.1 GAN的基本架构

GAN由两部分组成:

  • 生成器(Generator): 学习从随机噪声生成看起来真实的数据
  • 判别器(Discriminator): 学习区分真实数据和生成器生成的假数据

这两个网络通过对抗训练相互提高:生成器尝试生成越来越逼真的样本以欺骗判别器,而判别器则努力提高其区分真假样本的能力。

1.2 原始GAN的问题

虽然GAN的思想非常优雅,但原始GAN在训练过程中存在一些问题:

  1. 训练不稳定:很难找到生成器和判别器之间的平衡点
  2. 梯度消失:当判别器表现过好时,生成器梯度接近于零
  3. 模式坍塌:生成器只生成有限种类的样本,无法覆盖真实数据的全部分布
  4. 难以量化训练进度:缺乏有效的指标来衡量生成样本的质量

这些问题促使研究者寻找GAN的改进版本,其中WGAN是最重要的改进之一。

2. Wasserstein GAN详解

2.1 从JS散度到Wasserstein距离

原始GAN隐式地最小化生成分布与真实分布之间的Jensen-Shannon(JS)散度,这在两个分布没有显著重叠时会导致梯度问题。

Wasserstein距离(也称Earth Mover’s Distance,简称EMD)提供了一种更平滑的度量方式,即使两个分布没有重叠或重叠很少,也能提供有意义的梯度。

Wasserstein距离的直观解释:想象将一个分布的概率质量移动到另一个分布所需的最小"工作量",其中工作量定义为概率质量乘以移动距离。

2.2 WGAN的核心改进

WGAN相比原始GAN有以下关键改进:

  1. 目标函数改变:使用Wasserstein距离而非JS散度
  2. 判别器(现称为评论家/Critic)输出不再是概率:移除了最后的sigmoid激活函数
  3. 权重裁剪:限制评论家的参数在一定范围内,满足Lipschitz约束
  4. 避免使用基于动量的优化器:建议使用RMSProp或Adam优化器(学习率较小)

2.3 WGAN的目标函数

WGAN的目标函数如下:

min ⁡ G max ⁡ D ∈ D E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] GminDDmaxExPr[D(x)]EzPz[D(G(z))]

其中 D \mathcal{D} D是满足1-Lipschitz约束的函数集合。

2.4 Lipschitz约束与权重裁剪

为了满足Wasserstein距离计算中的Lipschitz约束,WGAN对评论家的参数进行了权重裁剪:将权重限制在 [ − c , c ] [-c, c] [c,c]的范围内,其中 c c c是一个小常数(如0.01)。

然而,权重裁剪是一种粗糙的方法,会导致优化问题和容量浪费。这就引出了WGAN的进一步改进:梯度惩罚机制。

3. WGAN的梯度惩罚机制

3.1 权重裁剪的局限性

WGAN中的权重裁剪虽然简单有效,但存在以下问题:

  1. 容量浪费:强制权重接近0或c,导致模型倾向于使用更简单的函数
  2. 优化困难:可能导致梯度爆炸或消失
  3. 对架构敏感:不同网络架构可能需要不同的裁剪范围

3.2 梯度惩罚的原理

WGAN-GP(带梯度惩罚的WGAN)提出了一种更优雅的方式来满足Lipschitz约束。其核心思想是:

对于一个1-Lipschitz函数,其梯度范数在任何地方都不应超过1。因此,我们可以通过惩罚评论家函数梯度范数偏离1的行为来满足这一约束。

具体来说,WGAN-GP在真实数据和生成数据之间的随机插值点上施加梯度惩罚:

L G P = E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] \mathcal{L}_{GP} = \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] LGP=Ex^Px^[(∣∣x^D(x^)21)2]

其中 x ^ \hat{x} x^是在真实样本 x x x和生成样本 G ( z ) G(z) G(z)之间的随机插值:

x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon)G(z) x^=ϵx+(1ϵ)G(z)

ϵ \epsilon ϵ是一个在 [ 0 , 1 ] [0,1] [0,1]之间均匀采样的随机数。

3.3 WGAN-GP的完整目标函数

将梯度惩罚添加到WGAN的目标函数中,我们得到WGAN-GP的目标函数:

L = E z ∼ p ( z ) [ D ( G ( z ) ) ] − E x ∼ p d a t a [ D ( x ) ] + λ E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] \mathcal{L} = \mathbb{E}_{z \sim p(z)}[D(G(z))] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}}D(\hat{x})||_2 - 1)^2] L=Ezp(z)[D(G(z))]Expdata[D(x)]+λEx^Px^[(∣∣x^D(x^)21)2]

其中 λ \lambda λ是梯度惩罚的权重,通常设为10。

3.4 WGAN-GP的优势

WGAN-GP相比WGAN有以下优势:

  1. 更好的稳定性:避免了权重裁剪带来的问题
  2. 更快的收敛:通常需要更少的迭代次数
  3. 更好的生成质量:能生成更多样、更高质量的样本
  4. 架构灵活性:适用于各种GAN架构,包括深度卷积网络

4. PyTorch实现WGAN-GP

下面我们使用PyTorch实现一个简单的WGAN-GP模型,用于生成MNIST手写数字。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 超参数
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
lambda_gp = 10  # 梯度惩罚权重

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # 归一化到[-1, 1]
])

mnist_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

dataloader = DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

# 生成器网络
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        def block(in_features, out_features, normalize=True):
            layers = [nn.Linear(in_features, out_features)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_features, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()  # 输出归一化到[-1, 1]
        )
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

# 判别器网络(评论家)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)
            # 注意:没有sigmoid激活函数
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# 初始化网络
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 初始化优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

# 计算梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples):
    """计算WGAN-GP中的梯度惩罚"""
    # 在真实样本和生成样本之间随机插值
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    
    # 计算插值点的判别器输出
    d_interpolates = D(interpolates)
    
    # 计算梯度
    fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    # 计算梯度范数
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    
    return gradient_penalty

# 训练函数
def train_wgan_gp():
    # 用于记录损失
    d_losses = []
    g_losses = []
    
    for epoch in range(n_epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.shape[0]
            
            # ---------------------
            #  训练判别器
            # ---------------------
            optimizer_D.zero_grad()
            
            # 生成随机噪声
            z = torch.randn(batch_size, latent_dim, device=device)
            
            # 生成一批假图像
            fake_imgs = generator(z)
            
            # 判别器前向传播
            real_validity = discriminator(real_imgs)
            fake_validity = discriminator(fake_imgs.detach())
            
            # 计算梯度惩罚
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
            
            # WGAN-GP 判别器损失
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
            
            d_loss.backward()
            optimizer_D.step()
            
            # 每n_critic次迭代训练一次生成器
            n_critic = 5
            if i % n_critic == 0:
                # ---------------------
                #  训练生成器
                # ---------------------
                optimizer_G.zero_grad()
                
                # 生成一批新的假图像
                gen_imgs = generator(z)
                
                # 判别器评估假图像
                fake_validity = discriminator(gen_imgs)
                
                # WGAN 生成器损失
                g_loss = -torch.mean(fake_validity)
                
                g_loss.backward()
                optimizer_G.step()
                
                if i % 50 == 0:
                    print(
                        f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
                        f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
                    )
                    
                    d_losses.append(d_loss.item())
                    g_losses.append(g_loss.item())
                    
        # 每个epoch结束后保存生成的图像样本
        if (epoch + 1) % 10 == 0:
            save_sample_images(epoch)
    
    # 绘制损失曲线
    plt.figure(figsize=(10, 5))
    plt.plot(d_losses, label='Discriminator Loss')
    plt.plot(g_losses, label='Generator Loss')
    plt.xlabel('Iterations (x50)')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('wgan_gp_loss.png')
    plt.close()

# 保存样本图像
def save_sample_images(epoch):
    # 生成并保存样本图像
    z = torch.randn(25, latent_dim, device=device)
    gen_imgs = generator(z).detach().cpu()
    
    # 将图像像素值从[-1, 1]转换为[0, 1]
    gen_imgs = 0.5 * gen_imgs + 0.5
    
    # 创建图像网格
    fig, axs = plt.subplots(5, 5, figsize=(10, 10))
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(gen_imgs[i*5+j, 0, :, :], cmap='gray')
            axs[i, j].axis('off')
    
    # 保存图像
    plt.savefig(f'wgan_gp_epoch_{epoch+1}.png')
    plt.close()

# 运行训练
if __name__ == "__main__":
    train_wgan_gp()

这段代码实现了一个基本的WGAN-GP模型,用于生成MNIST数字图像。下面我们来解析代码的关键部分:

  1. 梯度惩罚计算compute_gradient_penalty函数实现了WGAN-GP的核心——在真实样本和生成样本之间的插值点上计算梯度惩罚。
  2. 判别器损失:包括真实数据的评论家值、生成数据的评论家值,以及梯度惩罚项。
  3. 生成器损失:仅包含生成数据的评论家值的负期望。
  4. 优化器设置:使用Adam优化器,但β1参数设为0.5,这是GAN训练的常见设置。
  5. 训练循环:判别器和生成器交替训练,但判别器通常训练多次(n_critic=5)后才训练一次生成器。

5. WGAN-GP训练流程图

以下是WGAN-GP的训练流程图,帮助理解整个训练过程:

┌────────────────────┐
│  初始化网络和优化器  │
└──────────┬─────────┘
           │
           ▼
┌────────────────────┐
│    开始训练循环     │
└──────────┬─────────┘
           │
           ▼
┌────────────────────┐
│  从数据集加载真实样本 │
└──────────┬─────────┘
           │
           ▼
┌────────────────────┐
│  生成随机噪声并产生  │
│     假样本         │
└──────────┬─────────┘
           │
           ▼
┌────────────────────┐
│  计算判别器对真实   │
│   和假样本的输出    │
└──────────┬─────────┘
           │
           ▼
┌────────────────────┐
│  在样本插值点上计算  │
│    梯度惩罚        │
└──────────┬─────────┘
           │
           ▼
┌────────────────────┐
│   计算判别器损失    │
│   并更新判别器参数   │
└──────────┬─────────┘
           │
           ▼
      ┌────┴─────┐
      │ i % n_critic │
      │   == 0?   │
      └────┬─────┘
 No        │       Yes
 ┌─────────┘       └──────────┐
 │                            ▼
 │              ┌────────────────────┐
 │              │   重新生成假样本    │
 │              └──────────┬─────────┘
 │                         │
 │                         ▼
 │              ┌────────────────────┐
 │              │   计算生成器损失    │
 │              │   并更新生成器参数   │
 │              └──────────┬─────────┘
 │                         │
 └─────────────────────────┘
           │
           ▼
┌────────────────────┐
│ 是否达到预定训练轮数? │
└──────────┬─────────┘
      No   │   Yes
      ┌────┘       └──────────┐
      │                       ▼
      │            ┌────────────────────┐
      └──────▶     │      结束训练      │
                   └────────────────────┘

这个流程图展示了WGAN-GP的训练过程,包括梯度惩罚的计算和判别器多次训练的机制。与普通GAN相比,WGAN-GP的关键区别在于梯度惩罚的引入和目标函数的改变。

6. 条件生成与无监督生成的对比

接下来,我们将探讨条件生成与无监督生成在模式坍塌方面的差异。

6.1 无监督生成与模式坍塌

无监督生成是指生成器仅从随机噪声生成样本,没有额外的条件输入。

模式坍塌(Mode Collapse)是GAN训练中的常见问题,指生成器只学会生成数据分布中的少数几种模式,而忽略了其他模式。例如,在MNIST数据集上,模型可能只生成数字"1"而不生成其他数字。

导致模式坍塌的原因:

  1. 判别器更新不足:判别器无法有效区分真假样本
  2. 梯度消失:当判别器表现过好时,生成器梯度接近零
  3. 目标函数设计问题:JS散度在两个分布不重叠时提供有限的梯度信息

6.2 条件生成对模式坍塌的缓解

条件生成是指生成器不仅接收随机噪声,还接收额外的条件信息(如类别标签)作为输入。

条件GAN(CGAN)通过以下方式缓解模式坍塌:

  1. 强制生成器覆盖所有类别:通过提供不同的类别条件,迫使生成器学习生成不同类别的样本
  2. 简化学习任务:条件信息使生成器只需要学习条件分布,而非整个联合分布
  3. 提供更多监督信号:条件信息为生成器提供了额外的指导

6.3 条件生成与无监督生成的模式坍塌差异表

特性无监督生成条件生成
输入仅随机噪声随机噪声 + 条件信息
模式覆盖容易忽略部分模式被条件强制覆盖更多模式
生成样本多样性较低,倾向于生成相似样本较高,不同条件生成不同样本
训练稳定性较差,易发生模式坍塌较好,条件信息提供稳定指导
应用灵活性生成过程不可控可控制生成特定类别/属性的样本
实现复杂度相对简单需要额外的条件嵌入机制

7. 实现条件WGAN-GP

下面我们将实现一个条件版本的WGAN-GP,以比较其与无监督版本在模式坍塌方面的差异。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 超参数
batch_size = 64
lr = 0.0002
n_epochs = 50
latent_dim = 100
img_shape = (1, 28, 28)
n_classes = 10  # MNIST有10个类别
lambda_gp = 10  # 梯度惩罚权重

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # 归一化到[-1, 1]
])

mnist_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

dataloader = DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

# 条件生成器网络
class ConditionalGenerator(nn.Module):
    def __init__(self):
        super(ConditionalGenerator, self).__init__()
        
        # 嵌入层将类别标签转换为嵌入向量
        self.label_embedding = nn.Embedding(n_classes, n_classes)
        
        # 输入层处理噪声和类别嵌入
        self.input_layer = nn.Linear(latent_dim + n_classes, 128)
        
        # 主要模型
        self.model = nn.Sequential(
            nn.BatchNorm1d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
    
    def forward(self, noise, labels):
        # 将标签嵌入向量与噪声拼接
        label_embedding = self.label_embedding(labels)
        x = torch.cat([noise, label_embedding], dim=1)
        
        # 通过输入层
        x = self.input_layer(x)
        
        # 通过主模型
        x = self.model(x)
        
        # 重塑为图像格式
        img = x.view(x.size(0), *img_shape)
        return img

# 条件判别器网络
class ConditionalDiscriminator(nn.Module):
    def __init__(self):
        super(ConditionalDiscriminator, self).__init__()
        
        # 嵌入层将类别标签转换为嵌入向量
        self.label_embedding = nn.Embedding(n_classes, n_classes)
        
        # 处理图像和标签
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)) + n_classes, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)
        )
    
    def forward(self, img, labels):
        # 将图像展平
        img_flat = img.view(img.size(0), -1)
        
        # 获取标签嵌入
        label_embedding = self.label_embedding(labels)
        
        # 拼接图像特征和标签嵌入
        x = torch.cat([img_flat, label_embedding], dim=1)
        
        # 通过判别器网络
        validity = self.model(x)
        return validity

# 初始化网络
generator = ConditionalGenerator().to(device)
discriminator = ConditionalDiscriminator().to(device)

# 初始化优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

# 计算梯度惩罚(条件版本)
def compute_gradient_penalty(D, real_samples, fake_samples, labels):
    """计算条件WGAN-GP的梯度惩罚"""
    # 在真实样本和生成样本之间随机插值
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    
    # 计算插值点的判别器输出(带条件)
    d_interpolates = D(interpolates, labels)
    
    # 计算梯度
    fake = torch.ones(d_interpolates.size(), device=device, requires_grad=False)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    # 计算梯度范数
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    
    return gradient_penalty

# 训练条件WGAN-GP
def train_conditional_wgan_gp():
    # 用于记录损失
    d_losses = []
    g_losses = []
    
    # 用于记录生成样本的多样性(通过类别分布)
    class_distributions = []
    
    for epoch in range(n_epochs):
        for i, (real_imgs, labels) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            labels = labels.to(device)
            batch_size = real_imgs.shape[0]
            
            # ---------------------
            #  训练判别器
            # ---------------------
            optimizer_D.zero_grad()
            
            # 生成随机噪声
            z = torch.randn(batch_size, latent_dim, device=device)
            
            # 为生成器生成随机标签
            gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)
            
            # 生成一批假图像
            fake_imgs = generator(z, gen_labels)
            
            # 判别器前向传播
            real_validity = discriminator(real_imgs, labels)
            fake_validity = discriminator(fake_imgs.detach(), gen_labels)
            
            # 计算梯度惩罚
            gradient_penalty = compute_gradient_penalty(
                discriminator, real_imgs.data, fake_imgs.data, labels
            )
            
            # WGAN-GP 判别器损失
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
            
            d_loss.backward()
            optimizer_D.step()
            
            # 每n_critic次迭代训练一次生成器
            n_critic = 5
            if i % n_critic == 0:
                # ---------------------
                #  训练生成器
                # ---------------------
                optimizer_G.zero_grad()
                
                # 为生成器生成新的随机标签
                gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)
                
                # 生成一批新的假图像
                gen_imgs = generator(z, gen_labels)
                
                # 判别器评估假图像
                fake_validity = discriminator(gen_imgs, gen_labels)
                
                # WGAN 生成器损失
                g_loss = -torch.mean(fake_validity)
                
                g_loss.backward()
                optimizer_G.step()
                
                if i % 50 == 0:
                    print(
                        f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
                        f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
                    )
                    
                    d_losses.append(d_loss.item())
                    g_losses.append(g_loss.item())
        
        # 每个epoch结束后,评估生成样本的类别分布
        if (epoch + 1) % 10 == 0:
            class_dist = evaluate_class_distribution()
            class_distributions.append(class_dist)
            
            # 保存生成的图像样本
            save_sample_images(epoch)
    
    # 绘制损失曲线
    plt.figure(figsize=(10, 5))
    plt.plot(d_losses, label='Discriminator Loss')
    plt.plot(g_losses, label='Generator Loss')
    plt.xlabel('Iterations (x50)')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('cond_wgan_gp_loss.png')
    plt.close()
    
    # 绘制类别分布变化
    plot_class_distributions(class_distributions)

# 评估生成样本的类别分布
def evaluate_class_distribution():
    """评估生成样本在各类别上的分布情况"""
    # 创建一个预训练的分类器
    classifier = torchvision.models.resnet18(pretrained=True)
    # 修改第一个卷积层以适应灰度图
    classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # 修改最后的全连接层以适应MNIST的10个类别
    classifier.fc = nn.Linear(classifier.fc.in_features, 10)
    
    # 加载预先训练好的MNIST分类器权重(这里假设我们有一个预训练的模型)
    # classifier.load_state_dict(torch.load('mnist_classifier.pth'))
    
    # 简化起见,这里我们使用一个简单的CNN分类器
    classifier = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(64 * 7 * 7, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    ).to(device)
    
    # 这里假设这个简单分类器已经在MNIST上训练好了
    # 实际应用中,应该加载一个预先训练好的模型
    
    # 生成1000个样本
    z = torch.randn(1000, latent_dim, device=device)
    # 均匀采样所有类别
    gen_labels = torch.tensor([i % 10 for i in range(1000)], device=device)
    gen_imgs = generator(z, gen_labels)
    
    # 使用分类器预测类别
    with torch.no_grad():
        classifier.eval()
        preds = torch.softmax(classifier(gen_imgs), dim=1)
        pred_labels = torch.argmax(preds, dim=1)
    
    # 计算每个类别的样本数量
    class_counts = torch.zeros(10)
    for i in range(10):
        class_counts[i] = (pred_labels == i).sum().item() / 1000
    
    return class_counts.numpy()

# 绘制类别分布变化
def plot_class_distributions(class_distributions):
    """绘制生成样本类别分布的变化"""
    epochs = [10, 20, 30, 40, 50]  # 假设每10个epoch评估一次
    plt.figure(figsize=(12, 8))
    
    for i, dist in enumerate(class_distributions):
        plt.subplot(len(class_distributions), 1, i+1)
        plt.bar(np.arange(10), dist)
        plt.ylabel(f'Epoch {epochs[i]}')
        plt.ylim(0, 0.3)  # 限制y轴范围,便于比较
        if i == len(class_distributions) - 1:
            plt.xlabel('Digit Class')
    
    plt.tight_layout()
    plt.savefig('class_distribution.png')
    plt.close()

# 保存样本图像(条件版本)
def save_sample_images(epoch):
    """保存按类别排列的样本图像"""
    # 为每个类别生成样本
    n_row = 10  # 每个类别一行
    n_col = 10  # 每个类别10个样本
    
    fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12))
    
    for i in range(n_row):
        # 固定类别
        fixed_class = torch.tensor([i] * n_col, device=device)
        # 随机噪声
        z = torch.randn(n_col, latent_dim, device=device)
        # 生成图像
        gen_imgs = generator(z, fixed_class).detach().cpu()
        # 转换到[0, 1]范围
        gen_imgs = 0.5 * gen_imgs + 0.5
        
        # 显示图像
        for j in range(n_col):
            axs[i, j].imshow(gen_imgs[j, 0, :, :], cmap='gray')
            axs[i, j].axis('off')
    
    plt.savefig(f'cond_wgan_gp_epoch_{epoch+1}.png')
    plt.close()

# 运行条件WGAN-GP训练
if __name__ == "__main__":
    train_conditional_wgan_gp()


清华大学全五版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。

怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!

相关文章:

  • 《LNMP架构+Nextcloud私有云超维部署:量子级安全与跨域穿透实战》
  • 手动部署内网穿透
  • 有序数组的平方
  • 【云安全】云原生-centos7搭建/安装/部署k8s1.23.6单节点
  • 【开源项目】Excel手撕AI算法深入理解(二):Transformer
  • 头歌educoder——数据库 第10-11章
  • 对自己的优缺点评价
  • 导入 Excel 批量替换文件夹名称
  • MySQL 分区与分库分表策略
  • 【场景应用6】Autoformer在时间序列预测任务中的应用
  • LangGraph——Agent AI的持久化状态
  • 038-flatbuffers
  • ngx_set_worker_processes
  • 考研数据结构之串的模式匹配算法——KMP算法详解(包含真题及解析)
  • 回顾CSA,CSA复习
  • Linux的网络配置的资料
  • python对mysql数据库的操作
  • 深度学习中多机训练概念下的DP与DDP
  • C++ 编程指南35 - 为保持ABI稳定,应避免模板接口
  • SQL查询语句的执行顺序
  • 以色列高等法院裁定政府解职辛贝特局长程序不当
  • 31只北交所主题基金齐刷净值纪录,年内最高涨超80%,已有产品打出“限购牌”
  • 欧洲加大力度招募美国科研人员
  • 前4个月全国新建商品房销售面积降幅收窄,房地产库存和新开工有所改善
  • 国家主席习近平任免驻外大使
  • 罗马教皇利奥十四世正式任职