当前位置: 首页 > news >正文

辅助分类器GAN(ACGAN)

辅助分类器GAN(ACGAN)

    • 0. 前言
    • 1. ACGAN 网络架构
    • 2. ACGAN 损失函数
    • 3. 实现 ACGAN
    • 4. 生成结果

0. 前言

辅助分类器生成对抗网络 (Auxiliary Classifier GAN, ACGAN) 的原理与条件 GAN (CGAN) 相似。对于 CGANACGAN,生成器输入均为噪声及其标签,输出均为属于输入类别标签的伪造图像。在 CGAN 中,判别器的输入是图像(真实或伪造)及其标签,输出是图像为真实概率。而在 ACGAN 中,判别器仅以图像作为输入,输出则同时包含图像真实性概率及其类别标签。

1. ACGAN 网络架构

下图展示了生成器训练阶段 CGANACGAN 的核心差异:

CGAN与ACGAN

本质上,CGAN 通过注入辅助信息(标签)来构建网络,而 ACGAN 则通过辅助类别解码器网络重构辅助信息。ACGAN 理论指出,强制网络执行额外任务已被证实能提升原始任务的性能。在此框架中,图像分类作为附加任务,而原始任务仍是生成伪造图像。

2. ACGAN 损失函数

下表展示了 ACGAN 与 CGAN 的损失函数对比:

网络类型损失函数表达式
CGANL(D)=−Ex∼pdatalogD(x∣y)−Ezlog(1−D(G(z∣y)))L(G)=−EzlogD(G(z∣y))\mathcal L^{(D)} =-\mathbb E_{x\sim p_{data}}logD(x|y)-\mathbb E_zlog(1-D(G(z|y)))\\ \mathcal L^{(G)}=-\mathbb E_zlogD(G(z|y))L(D)=ExpdatalogD(xy)Ezlog(1D(G(zy)))L(G)=EzlogD(G(zy))
ACGANL(D)=−Ex∼pdatalogD(x)−Ezlog(1−D(G(z∣y)))−Ex∼pdatalogP(c∣x)−EzlogP(c∣G(z∣y))L(G)=−EzlogD(G(z∣y))−EzlogP(c∣G(z∣y))\mathcal L^{(D)} =-\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_zlog(1-D(G(z|y)))-\mathbb E_{x\sim p_{data}}logP(c|x)-\mathbb E_zlogP(c|G(z|y))\\\mathcal L^{(G)}=-\mathbb E_zlogD(G(z|y))-\mathbb E_zlogP(c|G(z|y))L(D)=ExpdatalogD(x)Ezlog(1D(G(zy)))ExpdatalogP(cx)EzlogP(cG(zy))L(G)=EzlogD(G(zy))EzlogP(cG(zy))

ACGAN 的损失函数在 CGAN 基础上增加了分类器损失函数。除了辨别图像真伪的核心任务 −Ex∼pdatalogD(x)−Ezlog(1−D(G(z∥y)))-\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_zlog(1-D(G(z\|y)))ExpdatalogD(x)Ezlog(1D(G(zy))),判别器公式还增加了对真实与伪造图像正确分类的任务 −Ex∼pdatalogP(c∥x)−EzlogP(c∥G(z∥y))-\mathbb E_{x\sim p_{data}}logP(c\|x)-\mathbb E_zlogP(c\|G(z\|y))ExpdatalogP(cx)EzlogP(cG(zy))。生成器公式意味着除了试图用伪造图像欺骗判别器 −EzlogD(G(z∥y))-\mathbb E_zlogD(G(z\|y))EzlogD(G(zy)),还要求判别器对这些伪造图像进行正确分类 −EzlogP(c∥G(z∥y))-\mathbb E_zlogP(c\|G(z\|y))EzlogP(cG(zy))

3. 实现 ACGAN

基于 CGAN 代码实现 ACGAN 时,仅需修改判别器和训练函数。模型包含两个损失函数:第一个是原始二元交叉熵损失,用于训练判别器评估输入图像为真实图像的概率。

import numpy as np
from tensorflow import keras
import tensorflow as tf
import os
from matplotlib import pyplot as pyplot
import gandef build_and_train_models():#数据加载及预处理(x_train,y_train),_ = keras.datasets.mnist.load_data()image_size = x_train.shape[1]x_train = np.reshape(x_train,[-1,image_size,image_size,1])x_train = x_train.astype('float32') / 255.num_labels = len(np.unique(y_train))y_train = keras.utils.to_categorical(y_train)#超参数model_name = 'acgan-mnist'latent_size = 100batch_size = 64train_steps = 40000lr = 2e-4decay = 6e-8input_shape = (image_size,image_size,1)label_shape = (num_labels,)#discriminatorinputs = keras.layers.Input(shape=input_shape,name='discriminator_input')discriminator = gan.discriminator(inputs,num_labels=num_labels)optimizer = keras.optimizers.RMSprop(lr=lr,decay=decay)loss = ['binary_crossentropy','categorical_crossentropy']discriminator.compile(loss=loss,optimizer=optimizer,metrics=['acc'])discriminator.summary()#generatorinput_shape = (latent_size,)inputs = keras.layers.Input(shape=input_shape,name='z_input')labels = keras.layers.Input(shape=label_shape,name='labels')generator = gan.generator(inputs,image_size,labels=labels)generator.summary()optimizer = keras.optimizers.RMSprop(lr=lr*0.5,decay=decay*0.5)discriminator.trainable = Falseadversarial = keras.Model([inputs,labels],discriminator(generator([inputs,labels])),name=model_name)adversarial.compile(loss=loss,optimizer=optimizer,metrics=['acc'])adversarial.summary()models = (generator,discriminator,adversarial)data = (x_train,y_train)params = (batch_size,latent_size,train_steps,num_labels,model_name)train(models,data,params)def train(models,data,params):generator,discriminator,adversarial = modelsx_train,y_train = databatch_size,latent_size,train_steps,num_labels,model_name = paramssave_interval = 500noise_input = np.random.uniform(-1.,1.,size=[16,latent_size])noise_label = np.eye(num_labels)[np.arange(0,16) % num_labels]train_size = x_train.shape[0]print(model_name,'Labels for generated images: ',np.argmax(noise_label,axis=1))for i in range(train_steps):#train the diacriminator for 1 batch#1 batch of real and fake imagesrand_indexes = np.random.randint(0,train_size,size=batch_size)real_images = x_train[rand_indexes]real_labels = y_train[rand_indexes]#generate fake images from noise using generatornoise = np.random.uniform(-1.,1.,size=(batch_size,latent_size))fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]fake_images = generator.predict([noise,fake_labels])#train babelsx = np.concatenate((real_images,fake_images))labels = np.concatenate((real_labels,fake_labels))#label real and fake imagesy = np.ones([2*batch_size,1])y[batch_size:,:] = 0.0#train modelmetrics = discriminator.train_on_batch(x,[y,labels])fmt = '%d: [disc loss: %f, srcloss: %f],'fmt += 'lbloss: %f, srcacc: %f, lblacc: %f'log = fmt % (i,metrics[0],metrics[1],metrics[2],metrics[3],metrics[4])#train adversarial network for 1 batchnoise = np.random.uniform(-1.,1.,size=(batch_size,latent_size))fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]y = np.ones([batch_size,1])metrics = adversarial.train_on_batch([noise,fake_labels],[y,fake_labels])fmt = "%s [advr loss: %f, srcloss: %f,"fmt += "lblloss: %f, srcacc: %f, lblacc: %f]"log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])print(log)if (i + 1) % save_interval == 0:# plot generator images on a periodic basisgan.plot_images(generator,noise_input=noise_input,noise_label=noise_label,show=False,step=(i + 1),model_name=model_name)generator.save(model_name + ".h5")if __name__ == '__main__':build_and_train_models()

4. 生成结果

事实证明,通过引入附加任务,ACGAN 的性能相比原始 GAN 模型有了显著提升。如下图所示,ACGAN 的训练过程表现稳定:

生成结果

http://www.dtcms.com/a/481579.html

相关文章:

  • 交流网站建设心得体会wordpress首页固定页面
  • 专门做有机食品的网站dedecms怎么部署网站
  • 大学生个体创业的网站建设百度搭建wordpress
  • 自己公司网站维护上海动易 网站
  • 摄影网站设计企业官网用什么系统
  • 网站开发工具最适合网站建设和网络优化
  • 东莞东坑网站设计中牟网站建设
  • 官方模板关键字生成的代码添加在网站的什么地方?郴州网站建设服务
  • 查看一个网站开发语言wap网站分享到微信
  • 网站备案流程教程今天的热点新闻
  • 网站站内链接福田欧曼前四后八
  • 网站建设具体方案免费企业邮箱排名
  • 温州建设小学的网站企业网站建设与实现的论文
  • 网站机房建设目的wordpress导航设置
  • 怎么构建网站wordpress 关闭伪静态
  • 做app推广上哪些网站做金融的看哪些网站
  • 机械设备做公司网站下载好了网站模板怎么开始做网站
  • 珠宝网站模板网络营销的概念与含义谷歌
  • 沧州网站建设联系电话做学徒哪个网站好
  • 著名的网站有哪些网页设计工资一般多少
  • 网站建设能挣钱免费的宣传平台有哪些
  • 外贸网站经典营销案例网站空间商是什么意思
  • 做教案比较好的网站国外友链买卖平台
  • 广东网站建设人员网址在线生成二维码
  • 东莞seo整站优化怎么做网站下载链接
  • 用路由器做简单的网站宁波正规seo推广
  • 有关商业网站的风格特征杭州seo公司
  • 做网站帮外国人淘宝深圳市龙岗区建设工程交易中心
  • 地产网站建设ghost和wordpress
  • 电子 公司 网站建设自助广告位网站源码