CycleGAN实现MNIST与SVHN风格迁移
CycleGAN实现MNIST与SVHN风格迁移
- 0. 前言
- 1. 代码实现
- 2. 生成结果
- 3. 标签翻转问题
0. 前言
我们已经学习了 CycleGAN 的基本原理,并展示了 CycleGAN 在图像转换中的应用。在本节中,我们将处理一个更具挑战性的任务。假设以灰度 MNIST
手写数字作为源数据,并希望借鉴 SVHN
数据集(作为目标数据)的风格特征。下图展示了两个域中的样本数据:
1. 代码实现
我们可以复用跨域生成对抗网络一节的 CycleGAN
构建和训练函数来实现风格迁移。唯一区别在于需要增加加载 MNIST
和 SVHN 数据的功能。
该 CycleGAN
结构与跨域生成对抗网络一节相同,唯一区别是鉴于两个域存在显著差异,我们采用了大小为 5
的卷积核。实现 MNIST
与 SVHN
间跨域风格迁移的 CycleGAN
:
def mnist_cross_svhn(g_models=None):model_name = 'cyclegan_mnist_svhn'batch_size = 64train_steps = 100000patchgan = Truekernel_size = 5postfix = ('%dp' % kernel_size) if patchgan else ('%d' % kernel_size)data,shapes = mnist_svhn_utils.load_data()source_data,_,test_source_data,test_target_data = datatitles = ('MNIST predicted source images.','SVHN predicted target images.','MNIST reconstructed source images.','SVHN reconstructed target images.')dirs = ('mnist_source-%s' % postfix,'svhn_target-%s' % postfix)#generate predictedif g_models is not None:g_source,g_target = g_modelsother_utils.test_generator((g_source,g_target),(test_source_data,test_target_data),step=0,titles=titles,dirs=dirs,show=True)return#build the cycleganmodels = build_cyclegan(shapes,'mnist-%s' % postfix,'svhn-%s' % postfix,kernel_size=kernel_size,patchgan=patchgan)# patch size is divided by 2^n since we downscaled the input# in the discriminator by 2^n (ie. we use strides=2 n times)patchgan = int(source_data.shape[1])params = (batch_size,train_steps,patchgan,model_name)test_params = (titles,dirs)#train the cyclegantrain_cyclegan(models,data,params,test_params,other_utils.test_generator)
使用以下命令启动模型训练:
$ python cyclegan.py -m
2. 生成结果
模型训练完成后,可以使用以下命令执行模型测试
$ python cyclegan.py --mnist_svhn_g_source=cyclegan_mnist_svhn-g_source.h5 --mnist_svhn_g_target=cyclegan_mnist_svhn-g_target.h5
测试集中 MNIST
数字向 SVHN
风格转换的结果如下图所示。生成图像虽具有 SVHN
的风格特征,但数字形态并未完全转换。例如数字 3
、1
经过 CycleGAN
风格化处理后,仍保留了原始数字的基本形态。而数字 9
、6
的风格化结果出现了显著变化:未使用 PatchGAN
的 CycleGAN
将其转换为了 0
、6
,而采用 PatchGAN
的 CycleGAN
则生成了 0
、65
等混合形态:
反向循环的转换结果如下图所示。目标图像来自 SVHN
测试数据集。生成图像虽具有 MNIST
的风格特征,但数字识别结果存在偏差。例如第一行中,数字 5
、2
和 210
在未使用 PatchGAN
的 CycleGAN
中被风格化为 7
、7
、8
,而在采用 PatchGAN
的版本中则被转换为 3
、3
、1
:
下图展示了 CycleGAN
在前向循环中重建 MNIST
数字的效果。重建后的 MNIST
数字与原始源数字几乎完全相同:
下图则呈现了 CycleGAN
在反向循环中重建 SVHN
数字的表现:
多数目标图像得以重建,部分数字保持高度一致,也有部分数字在风格保留的情况下发生转换。
3. 标签翻转问题
在 MNIST
向 SVHN
的转换过程中出现的源域数字被转换为目标域不同数字的现象,被称为标签翻转。尽管 CycleGAN
的预测结果满足循环一致性,但未必保证语义一致性,数字的原始语义含义在转换过程中出现了丢失。
为解决这一问题,Hoffman
提出了一种改进的 CycleGAN
模型——循环一致对抗域自适应 (Cycle-Consistent Adversarial Domain Adaptation
, CyCADA
)。其核心改进在于引入额外的语义损失项,确保预测结果不仅满足循环一致性,同时保持语义一致性。