当前位置: 首页 > news >正文

生成对抗网络(Generative Adversarial Network,GAN)

生成对抗网络(Generative Adversarial Network,GAN)

    • 0. 前言
    • 1. GAN 基本原理
    • 2. 实现 DCGAN
    • 3. 生成图像

0. 前言

生成对抗网络 (Generative Adversarial Network, GAN) 属于生成模型。与自编码器不同,生成模型能够在给定任意编码的情况下创建新的有意义的输出。本节将详细讨论 GAN 的工作原理,并通过 tf.keras 框架实现 GAN

1. GAN 基本原理

GAN 通过训练两个相互竞争又协作的网络——生成器与判别器——来学习输入数据的分布建模。生成器的任务是持续生成能够欺骗判别器的伪造数据(例如音频和图像),而判别器则被训练来区分真假数据。随着训练推进,判别器将逐渐无法区分合成数据与真实数据。至此,可弃用判别器,单独使用生成器来创造逼真数据。
虽然 GAN 的基本原理直观明了,但其核心挑战在于如何实现生成器-判别器网络的稳定训练。两个网络必须保持良性竞争关系才能实现同步学习。由于损失函数基于判别器的输出计算,其参数更新迅速。当判别器收敛过快时,生成器就无法获得足够的梯度更新,导致收敛失败。除训练困难外,GAN 还可能遭遇局部或完全模式崩溃——即生成器对不同潜编码产生几乎相同输出的情况。
如下图所示,GAN 由生成器与判别器两个网络组成:

GAN

生成器的输入是随机噪声,输出则是伪造数据。而判别器的输入既可以是真实数据也可以是伪造数据——真实数据来源于实际采样数据,伪造数据则来自生成器。所有真实数据都被标记为 1.0 (即 100% 为真的概率),而所有伪造数据则标记为 0.0 (即 0% 为真的概率)。由于标注过程是自动完成的,因此 GAN 仍被归为深度学习中的无监督学习方法。
判别器的目标是通过提供的数据集学习如何区分真实数据与伪造数据。在 GAN 训练的这个阶段,仅更新判别器参数。如同典型的二分类器,判别器被训练输出 0.01.0 之间的置信度值,以判断输入数据与真实数据的接近程度。但这只是训练过程的一半。
生成器会定期将其输出伪装成真实数据,并要求 GAN 将其标记为 1.0。当这些伪造数据被送入判别器时,自然会被判定为伪造数据并打上接近 0.0 的标签。
优化器根据输入的标签(即 1.0 )计算生成器的参数更新,同时在处理这批新数据时也会参考自身的预测结果。这意味着判别器对其预测结果存在不确定性,而 GAN 框架会充分考虑这种不确定性。此时,梯度将从判别器的最后一层反向传播至生成器的第一层。在实践过程中,该训练阶段通常会暂时冻结判别器的参数。生成器将利用这些梯度更新自身参数,从而提升合成伪造数据的能力。
总体而言,整个过程类似于两个网络在相互竞争的同时又保持协作。当 GAN 训练达到收敛时,最终结果是获得一个能生成以假乱真数据的生成器。此时判别器会认为这些合成数据是真实的(输出标签接近 1.0),至此便可弃用判别器。生成器将能够从任意噪声输入中产生有意义的输出。
判别器的训练可通过最小化以下损失函数实现:
L(D)(θ(G),θ(D))=−Ex∼pdatalogD(x)−Ezlog(1−D(G(z)))\mathcal L^{(D)}(\theta^{(G)},\theta^{(D)})=-\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_zlog(1-D(G(z))) L(D)(θ(G),θ(D))=ExpdatalogD(x)Ezlog(1D(G(z)))
该公式即为标准二元交叉熵损失函数。其损失值由两部分期望值的负值和构成:正确识别真实数据 D(x)D(x)D(x) 的期望,以及( 1 减去正确识别合成数据 D(G(z))D(G(z))D(G(z)) 的期望。对数运算不会改变局部极小值的位置。
在训练过程中,判别器会接收两批小批量数据:

  • 真实数据 xxx,来自采样数据(即 x∼pdatax\sim p_{data}xpdata),标签为 1.0
  • 生成器产生的伪造数据 x′=G(z)x'=G(z)x=G(z),标签为 0.0

为最小化损失函数,判别器参数 θ(D)\theta^{(D)}θ(D) 将通过反向传播进行更新,目标是准确识别真实数据 D(x)D(x)D(x) 与合成数据 1−D(G(z))1−D(G(z))1D(G(z))。正确识别真实数据意味着 D(x)→1.0D(x)\rightarrow1.0D(x)1.0,而正确分类伪造数据则等价于 D(G(z))→0.0D(G(z))\rightarrow0.0D(G(z))0.0(1−D(G(z))→1.0(1−D(G(z))\rightarrow1.0(1D(G(z))1.0。此处的 zzz 是生成器用于合成新信号的任意编码或噪声向量。两者共同作用以最小化损失函数。
对于生成器的训练,GAN 将判别器和生成器的损失总和视为零和博弈。生成器损失函数即为判别器损失函数的相反数:
L(G)(θ(G),θ(D))=−L(D)(θ(G),θ(D))\mathcal L^{(G)}(\theta^{(G)},\theta^{(D)})=−\mathcal L^{(D)}(\theta^{(G)},\theta^{(D)}) L(G)(θ(G),θ(D))=L(D)(θ(G),θ(D))
改写为价值函数:
V(G)(θ(G),θ(D))=−L(D)(θ(G),θ(D))\mathcal V^{(G)}(\theta^{(G)},\theta^{(D)})=−\mathcal L^{(D)}(\theta^{(G)},\theta^{(D)}) V(G)(θ(G),θ(D))=L(D)(θ(G),θ(D))
从生成器角度,应最小化上式;从判别器角度,则应最大化该价值函数。因此生成器训练准则可表述为最小最大化问题:
θ(G)∗=argminθ(G)maxθ(D)V(G)(θ(G),θ(D))\theta^{(G)*}=arg\underset {\theta ^{(G)}}{min}\underset {\theta ^{(D)}}{max}\mathcal V^{(G)}(\theta^{(G)},\theta^{(D)}) θ(G)=argθ(G)minθ(D)maxV(G)(θ(G),θ(D))

通过将生成数据伪装成标签 1.0 的真实数据来欺骗判别器。通过对 θ(D)\theta^{(D)}θ(D) 进行最大化,优化器会向判别器参数发送梯度更新,使其将合成数据判定为真实数据。同时通过对 θ(G)\theta^{(G)}θ(G) 进行最小化,优化器会训练生成器参数学习如何欺骗判别器。然而实践中,判别器往往能确信地将合成数据判定为伪造,导致 GAN 参数无法更新。此外,梯度更新量本身较小,且在传播至生成器各层时显著衰减,最终导致生成器无法收敛。
解决方案是重新构造生成器的损失函数:
L(G)(θ(G),θ(D))=−EzlogD(G(z))\mathcal L^{(G)}(\theta^{(G)},\theta^{(D)})=-\mathbb E_zlogD(G(z)) L(G)(θ(G),θ(D))=EzlogD(G(z))
该损失函数通过训练生成器,直接最大化判别器将合成数据判定为真实数据的概率。新的公式不再遵循零和博弈原则,完全由启发式方法驱动。生成器参数仅在整个对抗网络训练时才会更新,这是因为梯度从判别器向后传递至生成器。实际训练中,判别器权重在对抗训练期间会被临时冻结。

2. 实现 DCGAN

接下来,我们将实现深度卷积生成对抗网络 (Deep Convolution Generative Adversarial Network, DCGAN)。生成器的所有网络层均使用 ReLU 激活函数,输出层除外,在该层中使用了 tanh 激活。判别器的所有层中使用 Leaky ReLU

def build_generator(inputs,image_size):image_size = image_size // 4kernel_size = 5layer_filters = [128,64,32,1]x = keras.layers.Dense(image_size*image_size*layer_filters[0])(inputs)x = keras.layers.Reshape((image_size,image_size,layer_filters[0]))(x)for filters in layer_filters:if filters > layer_filters[-2]:strides = 2else:strides = 1x = keras.layers.BatchNormalization()(x)x = keras.layers.Activation('relu')(x)x = keras.layers.Conv2DTranspose(filters=filters,kernel_size=kernel_size,strides=strides,padding='same')(x)x = keras.layers.Activation('sigmoid')(x)generator = keras.Model(inputs,x,name='generator')return generatordef build_discriminator(inputs):kernel_size = 5layer_filters = [32,64,128,256]x = inputsfor filters in layer_filters:if filters == layer_filters[-1]:strides = 1else:strides = 2x = keras.layers.LeakyReLU(alpha=0.2)(x)x = keras.layers.Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding='same')(x)x = keras.layers.Flatten()(x)x = keras.layers.Dense(1)(x)x = keras.layers.Activation('sigmoid')(x)discriminator = keras.Model(inputs,x,name='discriminator')return discriminatordef build_and_train_model():(x_train,_),_ = keras.datasets.mnist.load_data()image_size = x_train.shape[-1]x_train = np.reshape(x_train,[-1,image_size,image_size,1])x_train = x_train.astype('float32') / 255.model_name = 'dcgan_mnist'#超参数latent_size = 100batch_size = 64train_steps = 40000lr = 2e-4decay = 6e-8input_shape = (image_size,image_size,1)inputs = keras.layers.Input(shape=input_shape,name='disriminator_input')discriminator = build_discriminator(inputs)optimizer = keras.optimizers.RMSprop(lr=lr,decay=decay)discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])discriminator.summary()input_shape = (image_size,)inputs = keras.layers.Input(shape=input_shape,name='z_input')generator = build_generator(inputs,image_size)generator.summary()optimizer = keras.optimizers.RMSprop(lr=lr*0.5,decay=decay*0.5)discriminator.trainable = Falseadversarial = keras.Model(inputs,discriminator(generator(inputs)),name=model_name)adversarial.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])adversarial.summary()models = (generator,discriminator,adversarial)params = (batch_size,latent_size,train_steps,model_name)train(models,x_train,params)

由于进行了自定义训练,因此不使用常规的 fit() 函数。取而代之的是,调用 train_on_batch() 为给定的数据批处理运行单个梯度更新。然后通过对抗网络训练生成网络。训练首先从数据集中随机抽取一批真实图像,被标记为 1.0。然后,生成器将生成一批伪造图像。这被标记为 0.0。这两个批次串联在一起用于训练判别器。

def train(models,x_train,params):generator,discriminator,adversarial = modelsbatch_size,latent_size,train_steps,model_name = params#每隔500个steps保存生成的图片save_interval = 500noise_input = np.random.uniform(-1.0,1.0,size=[16,latent_size])train_size = x_train.shape[0]for i in range(train_steps):rand_indexes = np.random.randint(0,train_size,size=batch_size)real_images = x_train[rand_indexes]noise = np.random.uniform(-1.,1.,size=[batch_size,latent_size])fake_images = generator.predict(noise)x = np.concatenate([real_images,fake_images])y = np.ones([2 * batch_size,1])y[batch_size:,:] = 0.0loss,acc = discriminator.train_on_batch(x,y)log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)noise = np.random.uniform(-1.,1.,size=[batch_size,latent_size])y = np.ones([batch_size,1])# print(noise.shape)# print(y.shape)loss,acc = adversarial.train_on_batch(noise,y)log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)print(log)if (i + 1) % save_interval == 0:plot_images(generator,noise_input,show=False,step=(i+1),model_name=model_name)generator.save(model_name+'.h5')def plot_images(generator,noise_input,show=False,step=0,model_name="gan"):os.makedirs(model_name, exist_ok=True)filename = os.path.join(model_name, "%05d.png" % step)images = generator.predict(noise_input)plt.figure(figsize=(2.2, 2.2))num_images = images.shape[0]image_size = images.shape[1]rows = int(math.sqrt(noise_input.shape[0]))for i in range(num_images):plt.subplot(rows, rows, i + 1)image = np.reshape(images[i], [image_size, image_size])plt.imshow(image, cmap='gray')plt.axis('off')plt.savefig(filename)if show:plt.show()else:plt.close('all')def test_generator(generator):noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])plot_images(generator,noise_input=noise_input,show=True,model_name="test_outputs")if __name__ == '__main__':build_and_train_model()

可以看到,随着训练的进行模型的生成质量不断得到提高:
在这里插入图片描述

3. 生成图像

模型训练完成后,可以加载生成器生成图像。

from tensorflow.keras.models import load_model
generator = load_model(args.generator)
test_generator(generator)
http://www.dtcms.com/a/469702.html

相关文章:

  • 18-基于STM32的智能医嘱手环设计与实现
  • encodeURIComponent() 函数详解
  • 在JavaScript中,map方法使用指南
  • 手机网站好还是h5好找大学生做家教的网站
  • vue项目安装使用,npm、webpack版本问题注意
  • Arbess从入门到实战(12) - 使用Arbess+Gitee+SonarQube实现Node.js项目自动化构建部署
  • 旅游网站模板 手机网站构建
  • 单遍聚类:实时数据流聚类解决方案
  • 使用TimeSformer进行模型训练(mvp验证)
  • MES系统业务流程全面解析
  • ASE03-树叶随风晃动-02收尾
  • 有哪些网站可以免费做外销用自己电脑建网站
  • 【算法】1019.链表中的下一个更大节点--通俗讲解
  • 福州seo建站互联网营销师考试题库
  • Flutter中的动效实现方式
  • Agent 的感知-决策-行动循环实现
  • Azure托管标识完整指南:安全无密码的云身份验证
  • Azure Front Door 在中国区正式上线
  • 基础 - 正则表达式
  • 旅游网站系统网站上设置多语言怎么做
  • 第三方软件验收测试公司【如何深入理解SSL/TLS证书】
  • JavaWeb——ServletConfig
  • QwenVL - 202310版-论文阅读
  • 如何从 FastReport .NET 将报表导出为 JPEG / PNG / BMP / GIF / TIFF / EMF
  • .NET MCP Server 开发教程
  • LeetCode 124. 二叉树中的最大路径和(困难)
  • 建设南大街小学网站wordpress首页调用指定文章列表
  • 大型语言模型(LLM)基础:从原理到核心概念详解(GPT-4 / 文心一言 / 通义千问)
  • python高级03——多任务编程
  • 树模型优劣大比拼xgboost/lightgbm/RF/catboost,股价预测怎么选模型