当前位置: 首页 > news >正文

条件生成对抗网络(cGAN)详解与实现

条件生成对抗网络(cGAN)详解与实现

    • 0. 前言
    • 1. 条件生成对抗网络
    • 2. 实现条件生成对抗网络
    • 3. cGAN 变体
      • 3.1 使用嵌入层
      • 3.2 逐元素乘法
      • 3.3 在中间层插入标签

0. 前言

我们已经学习了如何使用变分自编码器 (Variational Autoencoder, VAE)和生成对抗网络 (Generative Adversarial Network, GAN) 生成逼真的图像。这些生成模型可以将一些简单的随机噪声转换为具有复杂分布的高维图像,但是,生成过程是无条件的,并不能很好地控制要生成的图像。如果以 MNIST 为例,我们不知道模型将生成哪个数字。在本节中,我们将学习构建条件生成对抗网络 (Conditional Generative Adversarial Network, cGAN),使我们能够指定要生成的图像的类别,为后续更复杂的网络奠定基础。

1. 条件生成对抗网络

生成模型的首要目标是能够产生高质量的图像。然后,我们希望能够对要生成的图像进行一些控制。
在使用 TensorFlow 生成图像入门一节中,我们了解了条件概率并使用简单的条件概率模型生成了具有某些属性的人脸。在该模型中,我们通过强制模型仅从具有笑脸的图像中采样来生成笑脸。当我们以某种属性为条件时,该属性将始终存在,并且不再是具有随机概率的变量,还可以看到具有这些条件的概率设置为 1
在神经网络上强制条件很简单。我们只需要在训练和推理过程中向网络显示标签。例如,如果我们希望生成器生成数字 1,则除了将通常的随机噪声作为生成器的输入之外,我们还需要显示标签 1。有几种实现方法。下图显示了条件生成对抗网络 (Conditional Generative Adversarial Network, cGAN) 一文中出现的一种实现,该实现首先介绍了 cGAN 的概念:

cGAN架构

在无条件 GAN 中,生成器输入仅是潜矢量 zzz。在条件 GAN 中,潜矢量 zzz 与独热编码输入标签 yyy 结合在一起以形成更长的矢量,如上图所示。下表显示了使用 tf.one_hot() 的独热编码:

类别标签独热编码
0[1,0,0,0,0,0,0,0,0,0]
1[0,1,0,0,0,0,0,0,0,0]
2[0,0,1,0,0,0,0,0,0,0]
3[0,0,0,1,0,0,0,0,0,0]
4[0,0,0,0,1,0,0,0,0,0]
5[0,0,0,0,0,1,0,0,0,0]
6[0,0,0,0,0,0,1,0,0,0]
7[0,0,0,0,0,0,0,1,0,0]
8[0,0,0,0,0,0,0,0,1,0]
9[0,0,0,0,0,0,0,0,0,1]

独热编码将标签转换为尺寸等于类别数量的向量。向量仅有一个位置为 1 其他位置全零。某些机器学习框架在向量中 1 的位序不同;例如,类别标签 0 编码为 0000000001,其中 1 位于最右边。只要在训练和推理中始终使用它们,顺序就无关紧要。这是因为“独热编码”仅用于表示类别,而没有语义信息。

2. 实现条件生成对抗网络

接下来,我们使用 MNIST 数据集实现条件生成对抗网络 (Conditional Generative Adversarial Network, cGAN)。我们已经学习了深度卷积生成对抗网络 (Deep Convolutional Generative Adversarial Network, DCGAN),我们通过添加条件来扩展网络。首先,实现生成器。
第一步是对类别标签进行独热编码。由于 tf.one_hot([1],10) 将创建 (1,10) 的形状,因此我们需要将其重塑为包含 10 个元素的一维向量,以便我们可以将其与潜向量 zzz 连接:

input_label = layers.Input(shape=1, dtype=tf.int32, name='ClassLabel')
one_hot_label = tf.one_hot(input_label, self.num_classes)
one_hot_label = layers.Reshape((self.num_classes,))(one_hot_label)

下一步是使用 Concatenate 层将向量连接在一起。默认情况下,串联发生在最后一个维度上 (axis = -1)。因此,将形状为 (batch_size, 100) 的潜变量与 (batch_size, 10) 的独热编码连接起来将产生张量形状为 (batch_size, 110)

input_z = layers.Input(shape=self.z_dim, name='LatentVector')
generator_input = layers.Concatenate()([input_z,one_hot_label])

这是生成器所需的唯一更改,之后输入将经过一个全连接层,然后经过几个上采样和卷积层,以生成形状为 (32, 32, 1) 的图像,如以下模型图所示:

模型架构

下一步是将标签注入判别器,因为判别器不仅能够分辨图像是真是假,而且还能够分辨图像是否正确满足条件。
原始 cGAN 仅使用全连接层。输入图像被展平并与独热编码类标签相连。但是,这不适用于 DCGAN,因为判别器的第一层是卷积层,需要将 2D 图像作为输入。如果使用相同的方法,则最终将得到 32×32×1 + 10 = 1,034 的输入向量,并且无法将其整形为 2D 图像。我们将需要另一种方式来将独热向量投影到正确形状的张量中。
一种实现方法是使用全连接层将独热编码向量投影为输入图像 (32,32,1) 的形状,并将其连接起来以生成形状 (32, 32, 2)。第一个颜色通道是我们的灰度图像,第二个通道将是投影的一个独热标签。同样,判别器网络的其余部分保持不变,如以下模型摘要所示:

模型架构

可以看到,对网络所做的唯一更改是通过添加另一条将类别标签作为输入的路径。在开始模型训练之前,剩下的最后一点是将附加标签类添加到模型的输入中。要创建具有多个输入的模型,我们传递输入层列表:

discriminator = Model([input_image, input_label], output]

同样,在执行前向传递时,我们以相同的顺序传递图像和标签的列表:

pred_real = discriminator([real_images, class_labels])

在训练期间,我们为生成器创建随机标签:

fake_class_labels = tf.random.uniform((batch_size), minval=0, maxval=10, dtype=tf.dtypes.int32)
fake_images = generator.predict([latent_vector, fake_class_labels])

我们使用 DCGAN 的训练步骤和损失函数。以下是通过对输入标签从 09 进行条件处理而生成的数字示例:

生成结果

我们也可以在 Fashion-MNIST 上训练条件生成对抗网络,而无需进行任何更改。生成样本如下:

生成结果

cGANMNISTFashion-MNIST 取得了良好的效果,接下来,我们将研究在 GAN 上应用类别条件的不同方式。

3. cGAN 变体

我们通过对标签进行独热编码,将其通过密集层并连接输入层来实现条件 DCGAN。实现很简单,并且给出了很好的结果。我们将介绍其他一些实现条件 GAN 的流行方法。

3.1 使用嵌入层

一种流行的实现方式是用嵌入层代替独热编码和全连接层。嵌入层将分类值作为输入,而输出是向量。换句话说,它具有与 label-> one-hot-encoding-> dense 块相同的输入和输出形状:

encoded_label = tf.one_hot(input_label, self.num_classes)
embedding = layers.Dense(32 * 32 * 1, activation=None)(encoded_label)

等效于:

embedding = layers.Embedding(self.num_classes, 32*32*1)(input_label)

两种方法都产生相似的结果,但嵌入层的计算效率更高,因为对于大量类而言,独热编码向量的大小会快速增长。由于词汇量众多,嵌入被广泛用于对单词进行编码。对于诸如MNIST之类的小类,计算优势可忽略不计。

3.2 逐元素乘法

将潜向量与输入图像连接起来会增加网络的维数。除了串联外,我们还可以将标签嵌入与原始网络输入进行逐元素乘法,并保持原始输入形状,研究表明这种方法的性能优于独热编码。在图像和嵌入之间执行逐元素乘法的代码段如下:

x = layers.Multiply()([input_image, embedding])

将前面的代码与嵌入层结合起来,可以得到下图:
模型架构

3.3 在中间层插入标签

无需将标签插入网络的第一层,我们可以选择在中间层进行此操作。这种方法在具有编码器-解码器体系结构的生成器中很流行,其中标签被插入到具有最小尺寸的编码器末端附近的层中。将标签嵌入到判别器输出端,因此判别器的大部分参数可以专注于确定图像是否看起来真实。判别器的最后几层用于确定图像是否与标签匹配。
在后续学习中,我们将学习如何在中间层和归一化化层中插入标签嵌入。


文章转载自:

http://E4f0Xqyy.tfzjL.cn
http://PQZX4nWG.tfzjL.cn
http://FIpxNoyh.tfzjL.cn
http://3KxFSgya.tfzjL.cn
http://Bi4cIcLk.tfzjL.cn
http://03Ef8X9j.tfzjL.cn
http://kKa4s6Tb.tfzjL.cn
http://vORawnYo.tfzjL.cn
http://ojk4esw3.tfzjL.cn
http://2nCrZxM6.tfzjL.cn
http://h3wYDbID.tfzjL.cn
http://tbBiT8V1.tfzjL.cn
http://W4OzvMnb.tfzjL.cn
http://zTnXXrSj.tfzjL.cn
http://SpxMYGQh.tfzjL.cn
http://yA3Dwrxp.tfzjL.cn
http://TOsSPnus.tfzjL.cn
http://CpkSfmOk.tfzjL.cn
http://b7AellGL.tfzjL.cn
http://yDxgn0jq.tfzjL.cn
http://u8Vq9Wla.tfzjL.cn
http://F95vny7l.tfzjL.cn
http://RW5LoVmc.tfzjL.cn
http://piBCEHA5.tfzjL.cn
http://QwDRmQMu.tfzjL.cn
http://PEym6GLK.tfzjL.cn
http://HW6BLwyZ.tfzjL.cn
http://6JawTnoJ.tfzjL.cn
http://1AIfww7d.tfzjL.cn
http://5284gzAe.tfzjL.cn
http://www.dtcms.com/a/384696.html

相关文章:

  • Mysql杂志(十六)——缓存池
  • 408学习之c语言(结构体)
  • 使用Qt实现从文件对话框选择并加载点数据
  • qt5连接mysql数据库
  • C++库的相互包含(即循环依赖,Library Circular Dependency)
  • 如何用GitHub Actions为FastAPI项目打造自动化测试流水线?
  • LVS与Keepalived详解(二)LVS负载均衡实现实操
  • 闪电科创-无人机轨迹预测SCI/EI会议辅导
  • 自动驾驶中的传感器技术48——Radar(9)
  • HDLBits 解题更新
  • Python 自动化测试开发教程:Selenium 从入门到实战(1)
  • 树莓派4B实现网络电视详细指南
  • Docker:在Windows上安装和使用,加速容器应用开发
  • Android中怎么使用C动态库
  • Redis 安装实战:在 CentOS 中通过源码包安装
  • 抛砖引玉:神经网络的激活函数在生活中也有
  • Java生成与解析大疆无人机KMZ航线文件
  • Mysql 主从复制、读写分离
  • Linux网络设备驱动结构
  • 第四阶段C#通讯开发-3:串口通讯之Modbus协议
  • 使用生成式 AI 和 Amazon Bedrock Data Automation 处理大规模智能文档
  • 可可图片编辑 HarmonyOS(7)图片绘画
  • django登录注册案例(上)
  • 查看iOS设备文件管理 访问iPhone用户文件、App沙盒目录 系统日志与缓存
  • 基于Echarts+HTML5可视化数据大屏展示-白茶大数据溯源平台V2
  • android 框架—网络访问Okhttp
  • CUDA 中Thrust exclusive_scan使用详解
  • Quat 四元数库使用教程:应用场景概述
  • GitHub 热榜项目 - 日榜(2025-09-15)
  • 让AI数据中心突破性能极限