使用CNN构建VAE
使用CNN构建VAE
- 0. 前言
- 1. 网络架构
- 2. 模型构建
- 3. 生成结果
0. 前言
在原始变分自编码器 (Variational Autoencoder, VAE) 中,VAE 网络采用全连接网络实现。本节将使用卷积神经网络 (Convolutional Neural Network, CNN) 提升生成数字的质量,同时将参数数量大幅减少至 134165 个。
1. 网络架构
下图展示了 CNN 编码器模型的两个输出——潜向量的均值与对数方差。其中的 lambda 函数实现了重参数化技巧,将随机潜编码的采样过程移至 VAE 网络外部:
下图展示了 CNN 解码器模型。其二维输入来自 lambda 函数,输出为重构的 MNIST 数字图像:
下图展示了完整的 CNN-VAE 模型,由编码器与解码器连接构成:
2. 模型构建
编码器由两层 CNN 和两层全连接层组成,用于生成潜编码。其输出结构与原始变分自编码器 (Variational Autoencoder, VAE) 类似。解码器则包含一层全连接层和三层转置卷积层。
# NOTE(review): this listing shows no import statements; the code below
# expects `os`, `np` (numpy), `plt` (matplotlib.pyplot) and `keras` to be
# available at module level.

# reparameterization trick
# z = z_mean + sqrt(var) * eps
def sampling(args):
    """Sample a latent vector with the reparameterization trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of 2D tensors (batch, dim)
            holding the mean and the log-variance of the latent
            distribution.

    Returns:
        Tensor ``z_mean + exp(0.5 * z_log_var) * epsilon`` where
        ``epsilon`` is drawn from ``keras.backend.random_normal`` with
        shape ``(batch, dim)``.
    """
    z_mean, z_log_var = args
    batch = keras.backend.shape(z_mean)[0]
    dim = keras.backend.shape(z_mean)[1]
    epsilon = keras.backend.random_normal(shape=(batch, dim))
    # exp(0.5 * log_var) == sqrt(var), i.e. the standard deviation
    return z_mean + keras.backend.exp(0.5 * z_log_var) * epsilon


def plot_results(models, data, batch_size=128, model_name='vae_mnist'):
    """Plot the latent space and a manifold of decoded digits.

    Two figures are saved inside the directory ``model_name`` (created
    if missing) and shown on screen:

    * ``vae_mean.png`` - test labels annotated at the first output of
      ``encoder.predict`` (the latent mean).
    * ``digits_over_latent.png`` - a 30 x 30 grid of digits decoded from
      evenly spaced latent points in [-4, 4] x [-4, 4].

    Args:
        models: tuple ``(encoder, decoder)``.
        data: tuple ``(x_test, y_test)`` of test images and labels.
        batch_size: batch size passed to ``predict``.
        model_name: name of the output directory.
    """
    encoder, decoder = models
    x_test, y_test = data
    xmin = ymin = -4
    xmax = ymax = +4
    os.makedirs(model_name, exist_ok=True)

    filename = os.path.join(model_name, 'vae_mean.png')
    # display a 2D plot of the digit classes in the latent space
    z, _, _ = encoder.predict(x_test, batch_size=batch_size)
    plt.figure(figsize=(12, 10))

    # axes x and y ranges
    axes = plt.gca()
    axes.set_xlim([xmin, xmax])
    axes.set_ylim([ymin, ymax])

    # subsampling to reduce density of points on the plot
    z = z[0::2]
    y_test = y_test[0::2]
    # empty marker: only the digit annotations below are visible
    plt.scatter(z[:, 0], z[:, 1], marker='')
    for i, digit in enumerate(y_test):
        axes.annotate(digit, (z[i, 0], z[i, 1]))

    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.savefig(filename)
    plt.show()

    filename = os.path.join(model_name, 'digits_over_latent.png')
    # display a 30*30 2D manifold of digits
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(xmin, xmax, n)
    grid_y = np.linspace(ymin, ymax, n)[::-1]

    # Image row i corresponds to z[1] = grid_y[i] (top row = ymax) and
    # image column j to z[0] = grid_x[j], so the picture agrees with the
    # tick labels set from grid_x / grid_y below.
    # All grid points are decoded in one batched predict call.
    z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
    decoded = decoder.predict(z_grid, batch_size=batch_size)
    for i in range(n):
        for j in range(n):
            digit = decoded[i * n + j].reshape(digit_size, digit_size)
            figure[i * digit_size:(i + 1) * digit_size,
                   j * digit_size:(j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    # put one tick at the centre of every digit tile
    start_range = digit_size // 2
    end_range = (n - 1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()


# MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Images are square; add a trailing channel axis and scale pixels to [0, 1].
image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# hyperparameters
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
filters = 16  # doubled before each encoder convolution
latent_dim = 2
epochs = 50

# VAE model
# encoder
inputs = keras.layers.Input(shape=input_shape, name='encoder_input')
x = inputs
# two stride-2 convolutions; the filter count doubles before each one
for _ in range(2):
    filters *= 2
    x = keras.layers.Conv2D(filters=filters,
                            kernel_size=kernel_size,
                            activation='relu',
                            strides=2,
                            padding='same')(x)

# remember the feature-map shape so the decoder can rebuild it
shape = keras.backend.int_shape(x)

x = keras.layers.Flatten()(x)
x = keras.layers.Dense(16, activation='relu')(x)
z_mean = keras.layers.Dense(latent_dim, name='z_mean')(x)
z_log_var = keras.layers.Dense(latent_dim, name='z_log_var')(x)

# the Lambda layer applies the reparameterization trick (see `sampling`)
z = keras.layers.Lambda(sampling,
                        output_shape=(latent_dim,),
                        name='z')([z_mean, z_log_var])

encoder = keras.Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
keras.utils.plot_model(encoder,
                       to_file='vae_cnn_encoder.png',
                       show_shapes=True)

# decoder
latent_inputs = keras.layers.Input(shape=(latent_dim,), name='z_sampling')
# expand the latent vector back to the encoder's last feature-map shape
x = keras.layers.Dense(shape[1] * shape[2] * shape[3],
                       activation='relu')(latent_inputs)
x = keras.layers.Reshape((shape[1], shape[2], shape[3]))(x)
# two stride-2 transposed convolutions; the filter count halves after each
for _ in range(2):
    x = keras.layers.Conv2DTranspose(filters=filters,
                                     kernel_size=kernel_size,
                                     activation='relu',
                                     strides=2,
                                     padding='same')(x)
    filters //= 2
outputs = keras.layers.Conv2DTranspose(filters=1,
                                       kernel_size=kernel_size,
                                       activation='sigmoid',
                                       padding='same',
                                       name='decoder_output')(x)
decoder = keras.Model(latent_inputs, outputs, name='decoder')
decoder.summary()
keras.utils.plot_model(decoder,
                       to_file='vae_cnn_decoder.png',
                       show_shapes=True)

# full VAE: the decoder consumes the sampled z (third encoder output)
outputs = decoder(encoder(inputs)[2])
vae = keras.Model(inputs, outputs, name='vae_cnn')

# NOTE(review): `argparse` and `os` are used below but no import statement
# is visible in this listing.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load tf model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Use binary cross entropy instead of mse (default)"
    parser.add_argument("--bce", help=help_, action='store_true')
    args = parser.parse_args()
    models = (encoder, decoder)
    data = (x_test, y_test)

    # VAE loss = mse_loss or xent_loss + kl_loss
    if args.bce:
        reconstruction_loss = keras.losses.binary_crossentropy(
            keras.backend.flatten(inputs),
            keras.backend.flatten(outputs))
    else:
        reconstruction_loss = keras.losses.mse(
            keras.backend.flatten(inputs),
            keras.backend.flatten(outputs))

    # the Keras losses average over elements; rescale by the pixel count
    reconstruction_loss *= image_size * image_size
    kl_loss = (1 + z_log_var
               - keras.backend.square(z_mean)
               - keras.backend.exp(z_log_var))
    kl_loss = keras.backend.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = keras.backend.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    vae.summary()
    keras.utils.plot_model(vae, to_file='vae_cnn.png', show_shapes=True)

    save_dir = 'vae_cnn_weights'
    os.makedirs(save_dir, exist_ok=True)
    if args.weights:
        filepath = os.path.join(save_dir, args.weights)
        # load in place; do not rebind `vae` to the call's return value
        vae.load_weights(filepath)
    else:
        # train
        vae.fit(x_train,
                epochs=epochs,
                batch_size=batch_size,
                validation_data=(x_test, None))
        filepath = os.path.join(save_dir, 'vae_cnn.mnist.tf')
        vae.save_weights(filepath)

    plot_results(models, data, batch_size=batch_size, model_name='vae_cnn')
3. 生成结果
VAE 模型经过 30 个 epoch 训练后,下图展示了在连续潜空间中遍历时数字的分布情况。例如从中心区域向右移动时,数字会从 2 逐渐转变为 0:
下图展示了生成模型的输出结果。从质量上看,与基于 MLP 实现 VAE 相比,模糊难辨的数字数量明显减少: