Auxiliary Classifier GAN (ACGAN)
- 0. Preface
- 1. ACGAN Network Architecture
- 2. ACGAN Loss Functions
- 3. Implementing ACGAN
- 4. Generation Results
0. Preface
The principle of the Auxiliary Classifier GAN (ACGAN) is similar to that of the Conditional GAN (CGAN). For both CGAN and ACGAN, the generator input is noise together with its class label, and the output is a fake image belonging to that input class. In CGAN, the discriminator's input is an image (real or fake) together with its label, and its output is the probability that the image is real. In ACGAN, the discriminator takes only the image as input, and its output contains both the probability that the image is real and its class label.
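To make the shared generator interface concrete, here is a minimal sketch (toy layer sizes of my own choosing, not the `gan` module used in the implementation below): the generator consumes a noise vector plus a one-hot class label and returns an image of that class.

import numpy as np
from tensorflow import keras

latent_size = 100
num_labels = 10
image_size = 28

# generator: (noise, one-hot label) -> fake image
z = keras.layers.Input(shape=(latent_size,), name='z')
y = keras.layers.Input(shape=(num_labels,), name='label')
h = keras.layers.concatenate([z, y])  # condition the noise on the label
h = keras.layers.Dense(image_size * image_size, activation='sigmoid')(h)
fake_image = keras.layers.Reshape((image_size, image_size, 1))(h)
toy_generator = keras.Model([z, y], fake_image, name='toy_generator')

# generate one fake "3": noise plus the one-hot label for class 3
noise = np.random.uniform(-1., 1., size=(1, latent_size))
label = np.eye(num_labels)[[3]]
img = toy_generator.predict([noise, label])
print(img.shape)  # (1, 28, 28, 1)

Both CGAN and ACGAN condition the generator this way; the two models diverge only on the discriminator side.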
1. ACGAN Network Architecture
The figure below shows the core difference between CGAN and ACGAN during the generator training phase:
In essence, CGAN builds the network by injecting the auxiliary information (the label), whereas ACGAN reconstructs the auxiliary information through an auxiliary class decoder network. The idea behind ACGAN is that forcing the network to perform an additional task has been shown to improve performance on the original task. In this framework, image classification is the additional task, while the original task remains generating fake images.
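As a sketch of that idea (hypothetical dense layers, not the book's `gan.discriminator` builder), the ACGAN discriminator takes only the image and ends in two heads: a source head that outputs the real/fake probability and a label head that reconstructs the class label.

from tensorflow import keras

image_size = 28
num_labels = 10

x = keras.layers.Input(shape=(image_size, image_size, 1), name='image')
h = keras.layers.Flatten()(x)
h = keras.layers.Dense(256, activation='relu')(h)

# head 1: probability that the image is real (original GAN task)
source = keras.layers.Dense(1, activation='sigmoid', name='source')(h)
# head 2: class label of the image (auxiliary classification task)
label = keras.layers.Dense(num_labels, activation='softmax', name='label')(h)

acgan_discriminator = keras.Model(x, [source, label], name='acgan_discriminator')
acgan_discriminator.summary()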
2. ACGAN Loss Functions
The table below compares the loss functions of ACGAN and CGAN:
Network | Loss functions |
---|---|
CGAN | $\mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}\log D(x\mid y) - \mathbb E_z\log(1 - D(G(z\mid y)))$ <br> $\mathcal L^{(G)} = -\mathbb E_z\log D(G(z\mid y))$ |
ACGAN | $\mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}\log D(x) - \mathbb E_z\log(1 - D(G(z\mid y))) - \mathbb E_{x\sim p_{data}}\log P(c\mid x) - \mathbb E_z\log P(c\mid G(z\mid y))$ <br> $\mathcal L^{(G)} = -\mathbb E_z\log D(G(z\mid y)) - \mathbb E_z\log P(c\mid G(z\mid y))$ |
The ACGAN loss functions add classifier terms on top of the CGAN ones. Besides the core task of telling real images from fake ones, $-\mathbb E_{x\sim p_{data}}\log D(x) - \mathbb E_z\log(1 - D(G(z\mid y)))$, the discriminator loss adds the task of correctly classifying both real and fake images, $-\mathbb E_{x\sim p_{data}}\log P(c\mid x) - \mathbb E_z\log P(c\mid G(z\mid y))$. The generator loss means that, besides trying to fool the discriminator with fake images, $-\mathbb E_z\log D(G(z\mid y))$, the generator also requires the discriminator to classify those fake images correctly, $-\mathbb E_z\log P(c\mid G(z\mid y))$.
3. Implementing ACGAN
When implementing ACGAN on top of the CGAN code, only the discriminator and the training function need to be modified. The model has two loss functions: the first is the original binary cross-entropy, which trains the discriminator to estimate the probability that the input image is real; the second is categorical cross-entropy, which trains it to predict the correct class label.
import numpy as np
from tensorflow import keras
import tensorflow as tf
import os
from matplotlib import pyplot as plt

import gan


def build_and_train_models():
    # load and preprocess the MNIST dataset
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255.
    num_labels = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train)
    # hyperparameters
    model_name = 'acgan-mnist'
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels,)
    # discriminator: two outputs, source probability and class label
    inputs = keras.layers.Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs, num_labels=num_labels)
    optimizer = keras.optimizers.RMSprop(lr=lr, decay=decay)
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    discriminator.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    discriminator.summary()
    # generator: noise and label in, fake image out
    input_shape = (latent_size,)
    inputs = keras.layers.Input(shape=input_shape, name='z_input')
    labels = keras.layers.Input(shape=label_shape, name='labels')
    generator = gan.generator(inputs, image_size, labels=labels)
    generator.summary()
    # adversarial model: generator followed by the frozen discriminator
    optimizer = keras.optimizers.RMSprop(lr=lr * 0.5, decay=decay * 0.5)
    discriminator.trainable = False
    adversarial = keras.Model([inputs, labels],
                              discriminator(generator([inputs, labels])),
                              name=model_name)
    adversarial.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    adversarial.summary()
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)


def train(models, data, params):
    generator, discriminator, adversarial = models
    x_train, y_train = data
    batch_size, latent_size, train_steps, num_labels, model_name = params
    save_interval = 500
    # fixed noise and labels used to plot generator output periodically
    noise_input = np.random.uniform(-1., 1., size=[16, latent_size])
    noise_label = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    train_size = x_train.shape[0]
    print(model_name,
          'Labels for generated images: ',
          np.argmax(noise_label, axis=1))
    for i in range(train_steps):
        # train the discriminator for 1 batch
        # 1 batch of real and fake images
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        # generate fake images from noise using the generator
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        fake_images = generator.predict([noise, fake_labels])
        # real + fake images and labels form 1 batch of training data
        x = np.concatenate((real_images, fake_images))
        labels = np.concatenate((real_labels, fake_labels))
        # label real and fake images
        y = np.ones([2 * batch_size, 1])
        y[batch_size:, :] = 0.0
        # train the discriminator on both heads
        metrics = discriminator.train_on_batch(x, [y, labels])
        fmt = '%d: [disc loss: %f, srcloss: %f, '
        fmt += 'lblloss: %f, srcacc: %f, lblacc: %f]'
        log = fmt % (i, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])
        # train the adversarial network for 1 batch
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        y = np.ones([batch_size, 1])
        metrics = adversarial.train_on_batch([noise, fake_labels], [y, fake_labels])
        fmt = "%s [advr loss: %f, srcloss: %f, "
        fmt += "lblloss: %f, srcacc: %f, lblacc: %f]"
        log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])
        print(log)
        if (i + 1) % save_interval == 0:
            # plot generator images on a periodic basis
            gan.plot_images(generator,
                            noise_input=noise_input,
                            noise_label=noise_label,
                            show=False,
                            step=(i + 1),
                            model_name=model_name)
    generator.save(model_name + ".h5")


if __name__ == '__main__':
    build_and_train_models()
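After training, the saved generator can be used on its own to synthesize images of a requested class. The following usage sketch assumes the acgan-mnist.h5 file written by the script above and re-creates the same noise-plus-one-hot-label input convention:

import numpy as np
from tensorflow import keras

latent_size = 100
num_labels = 10

# load the generator saved at the end of training
generator = keras.models.load_model('acgan-mnist.h5')

# 16 noise vectors, all conditioned on the same class label (e.g. digit 7)
noise = np.random.uniform(-1., 1., size=(16, latent_size))
labels = np.eye(num_labels)[np.full(16, 7)]

images = generator.predict([noise, labels])
print(images.shape)  # (16, 28, 28, 1)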
4. Generation Results
It turns out that, by introducing the additional task, ACGAN performs noticeably better than the original GAN model. As the figure below shows, the ACGAN training process is stable: