Generative Models in Practice | MUNIT Explained and Implemented
- 0. Preface
- 1. How MUNIT Works
- 2. MUNIT Architecture
- 3. Autoencoder Design
- 4. Implementing MUNIT
- 4.1 Dataset Handling
- 4.2 Building the Model
- 4.3 Training and Testing
0. Preface
Innovations in style transfer have influenced the development of the Generative Adversarial Network (GAN). Although GANs can generate realistic images, most of them are generated from random latent codes, and we know very little about what those codes represent. Even though a multimodal GAN can produce variation in its generated images, we still do not know how to control the latent code to achieve the result we want.
We would like to control the features we generate independently of one another; this is known as a disentangled representation. The idea is to separate an image into independent factors. For example, a face has two eyes, a nose, and a mouth. As we learned from style transfer, an image can be decomposed into content and style, and researchers have brought this idea into GANs.
Next, we introduce a style-based GAN: Multimodal Unsupervised Image-to-Image Translation (MUNIT). We will explore the whole architecture in detail to understand how style is used in this kind of model.
1. How MUNIT Works
MUNIT is an image-to-image translation model similar to BicycleGAN. Both can generate multimodal images with a continuous distribution, but BicycleGAN requires paired data while MUNIT does not. BicycleGAN produces multimodal images by using two models that relate the target image to a latent code, yet it is not clear how those models work or how to control the latent code to modify the output. MUNIT takes a conceptually different, and easier to understand, approach: it assumes that the source and target images share the same content space but have different styles.
The following figure illustrates the principle behind MUNIT:
Suppose we have two images, $X_1$ and $X_2$. Each of them can be represented as a content-style code pair, $(C_1, S_1)$ and $(C_2, S_2)$ respectively. $C_1$ and $C_2$ are assumed to lie in a shared content space $C$; in other words, the contents may not be identical, but they are similar. The styles live in their own domain-specific style spaces. We can therefore translate $X_1$ into the domain of $X_2$ by combining the content code from $X_1$ with the style code from $X_2$, that is, by generating an image from the code pair $(C_1, S_2)$.
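To make the idea concrete, here is a minimal sketch of the encode-swap-decode flow using toy stand-in networks (content_encoder, style_encoder, and ToyDecoder here are illustrative assumptions, not the MUNIT networks we build in Section 4):

import torch
import torch.nn as nn

# Toy stand-ins for the real networks (assumptions, for illustration only)
content_encoder = nn.Conv2d(3, 64, 3, padding=1)                      # X -> content code C: [B, 64, H, W]
style_encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                              nn.AdaptiveAvgPool2d(1), nn.Flatten())  # X -> style code S: [B, 8]

class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(64 + 8, 3, 3, padding=1)

    def forward(self, content, style):
        # broadcast the style vector over the spatial grid and decode (C, S) -> image
        b, _, h, w = content.shape
        style_map = style.view(b, -1, 1, 1).expand(b, style.size(1), h, w)
        return torch.tanh(self.conv(torch.cat([content, style_map], dim=1)))

decoder = ToyDecoder()
x1, x2 = torch.randn(1, 3, 128, 128), torch.randn(1, 3, 128, 128)
c1, s1 = content_encoder(x1), style_encoder(x1)   # (C1, S1)
c2, s2 = content_encoder(x2), style_encoder(x2)   # (C2, S2)
x1_to_2 = decoder(c1, s2)   # translate X1 into domain 2: content of X1, style of X2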
In style transfer, we think of style as artistic style: brush strokes, colors, and textures. Here the meaning of style is extended beyond artistic painting. For example, tigers and lions are both felines, but they have different styles of whiskers, skin, fur, and body shape. Next, let's look at the MUNIT model architecture.
2. MUNIT Architecture
The following figure shows the MUNIT architecture:
There are two autoencoders, one for each domain. Each autoencoder encodes an image into its style and content codes, and its decoder decodes them back into the original image. The autoencoders are trained with an adversarial loss; in other words, the model is built from autoencoders but trained like a GAN.
In the figure above, image reconstruction is shown on the left and cross-domain translation on the right. As mentioned earlier, to translate from $X_1$ to $X_2$, we first encode the images into their respective content and style codes and then perform the following steps:
- Use $(C_1, S_2)$ to generate a fake image in style domain 2; this step is also trained adversarially, like a GAN.
- Encode the generated image back into content and style codes. If the translation works well, they should be close to $(C_1, S_2)$.

This is similar to CycleGAN's cycle-consistency constraint, except that here cycle consistency is applied not to the images but to the content and style codes.
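This idea can be written down as a loss directly on the codes. Below is a minimal sketch with placeholder tensors standing in for the codes the encoders would produce (the shapes and the commented decoder2/content_encoder2/style_encoder2 names are illustrative assumptions; the real losses appear in the training loop of Section 4):

import torch
import torch.nn.functional as F

# Placeholder tensors standing in for codes produced by the encoders (shapes are assumptions)
c1 = torch.randn(1, 256, 32, 32)        # content code of X1 (from the domain-1 content encoder)
s2 = torch.randn(1, 8)                  # style code sampled for domain 2
# fake_x2 = decoder2(c1, s2)            # translate X1 into domain 2 (not run here)
c1_recon = torch.randn(1, 256, 32, 32)  # = content_encoder2(fake_x2), re-encoded content
s2_recon = torch.randn(1, 8)            # = style_encoder2(fake_x2), re-encoded style

# Cycle consistency is enforced on the codes rather than on the pixels
loss_content_recon = F.l1_loss(c1_recon, c1)
loss_style_recon = F.l1_loss(s2_recon, s2)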
3. Autoencoder Design
Finally, let's look at the detailed architecture of the autoencoders, as shown in the following figure:
Unlike other style transfer models, MUNIT does not use VGG as its encoder. It uses two separate encoders, one for content and one for style. The content encoder is made up of several residual blocks with instance normalization and downsampling, which plays a role much like the VGG features used in style transfer.
The style encoder differs from the content encoder in two ways:
- First, there is no normalization: normalizing the activations to zero mean would remove the style information.
- Second, the residual blocks are replaced with fully connected layers. Style is treated as spatially invariant, so we do not need convolutional layers to carry spatial information.
In other words, the style code contains only information such as eye color, without knowing where the eyes are; locating them is the content code's job. The style code is a low-dimensional vector, typically of size 8, unlike the high-dimensional latent codes used in GANs and Variational Autoencoders (VAEs) or the style features used in style transfer. Keeping the style code small means we control style with only a few features, which is easier to manage. The following figure shows how the content and style codes are fed into the decoder:
The generator part of the decoder consists of a set of residual blocks. Only the residual blocks in the first group use adaptive instance normalization (AdaIN) as their normalization layer. The AdaIN equation is given below, where $z$ is the activation from the preceding convolutional layer:
$$\mathrm{AdaIN}(z,\gamma,\beta)=\gamma\left(\frac{z-\mu(z)}{\sigma(z)}\right)+\beta$$
In feed-forward neural style transfer, the mean and standard deviation of a single style layer serve as $\gamma$ and $\beta$ in AdaIN. In MUNIT, a multilayer perceptron (MLP) generates $\gamma$ and $\beta$ from the style code.
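As a rough, self-contained sketch of this MLP-to-AdaIN path (assuming a single 256-channel feature map and a hypothetical two-layer MLP; the actual modules are implemented in Section 4):

import torch
import torch.nn as nn

style_dim, channels = 8, 256
# MLP that maps the 8-dim style code to one (gamma, beta) pair per channel
mlp = nn.Sequential(nn.Linear(style_dim, 256), nn.ReLU(),
                    nn.Linear(256, 2 * channels))

z = torch.randn(4, channels, 32, 32)   # activation from the previous conv layer
style = torch.randn(4, style_dim)      # style code

params = mlp(style)                    # [B, 2*C]
gamma, beta = params.chunk(2, dim=1)   # each [B, C]
gamma = gamma.view(-1, channels, 1, 1)
beta = beta.view(-1, channels, 1, 1)

# AdaIN: normalize z per channel, then scale/shift with the style-derived gamma/beta
mu = z.mean(dim=(2, 3), keepdim=True)
sigma = z.std(dim=(2, 3), keepdim=True) + 1e-5
adain_out = gamma * (z - mu) / sigma + beta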
4. Implementing MUNIT
Next, we implement the MUNIT model in PyTorch and train it on the edges2shoes dataset, which contains edge maps of shoes paired with real shoe photos.
4.1 Dataset Handling
(1) Import the required libraries, parse the arguments, and define the device:
import os
import argparse
import random
from glob import glob

from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, required=True, help='dataset root; expects train/ and val/ folders of paired images (and testA/, testB/ for testing)')
parser.add_argument('--save_dir', type=str, default='./outputs', help='where to save samples and checkpoints')
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--img_size', type=int, default=128)
parser.add_argument('--style_dim', type=int, default=8)
parser.add_argument('--n_res', type=int, default=4)
args = parser.parse_args()

os.makedirs(args.save_dir, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
(2) The raw dataset consists of side-by-side images, with the input edge map on the left and the real photo on the right, as shown below:
Split each sample into the edge image (left) and the real photo (right), and apply random augmentation (horizontal flips):
class PairedImageDataset(Dataset):
    def __init__(self, root, mode='train', transform=None):
        super().__init__()
        self.dir = os.path.join(root, mode)
        self.paths = sorted(glob(os.path.join(self.dir, '*.jpg')) + glob(os.path.join(self.dir, '*.png')))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert('RGB')
        w, h = img.size
        w2 = w // 2
        # left half is the edge map, right half is the photo
        input_img = img.crop((0, 0, w2, h))
        target_img = img.crop((w2, 0, w, h))
        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)
        return input_img, target_img

transform = transforms.Compose([
    transforms.Resize((args.img_size, args.img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3)
])

train_dataset = PairedImageDataset(args.data_root, 'train', transform=transform)
val_dataset = PairedImageDataset(args.data_root, 'val', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=True)
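As an optional sanity check on the loaders defined above, we can inspect one batch to confirm the split and the normalization range:

# Optional sanity check: inspect one batch from the paired loader
edges, photos = next(iter(train_loader))
print(edges.shape, photos.shape)               # expected: [batch_size, 3, img_size, img_size] each
print(edges.min().item(), edges.max().item())  # roughly within [-1, 1] after normalization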
4.2 Building the Model
(1) Define the weight-initialization function and a helper that computes per-channel mean and standard deviation:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, a=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

def calc_mean_std(feat, eps=1e-5):
    # feat: [B, C, H, W] -> per-channel mean/std (for AdaIN)
    b, c = feat.size()[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std
(2) Implement adaptive instance normalization (AdaIN):
class AdaIN(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, content_feat, gamma, beta):
        # gamma, beta shape: [B, C] or [B, C, 1, 1]
        b, c = content_feat.size()[:2]
        content_mean, content_std = calc_mean_std(content_feat)
        if gamma.dim() == 2:
            gamma = gamma.view(b, c, 1, 1)
            beta = beta.view(b, c, 1, 1)
        normalized = (content_feat - content_mean) / content_std
        return normalized * gamma + beta
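As an optional check of the class above: after AdaIN, each channel of the output should have a mean close to $\beta$ and a standard deviation close to $|\gamma|$:

# Optional check: per-channel statistics of the AdaIN output should match gamma/beta
feat = torch.randn(2, 16, 32, 32)
gamma = torch.rand(2, 16) + 0.5   # keep gamma positive so the std comparison is direct
beta = torch.randn(2, 16)
out = AdaIN()(feat, gamma, beta)
out_mean, out_std = calc_mean_std(out)
print(torch.allclose(out_mean.view(2, 16), beta, atol=1e-3))   # True: means match beta
print(torch.allclose(out_std.view(2, 16), gamma, atol=1e-2))   # True (approximately): stds match gamma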
(3) Implement a residual block that uses the AdaIN class:
class ResBlock(nn.Module):
    def __init__(self, dim, norm='in', use_adain=False):
        super().__init__()
        self.use_adain = use_adain
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim, dim, 3, 1, 1)
        if not use_adain:
            self.norm1 = nn.InstanceNorm2d(dim, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim, affine=True)
        else:
            # AdaIN parameters are supplied externally at forward time
            self.norm1 = None
            self.norm2 = None
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, adain1=None, adain2=None):
        y = self.conv1(x)
        y = adain1(y) if self.use_adain else self.norm1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = adain2(y) if self.use_adain else self.norm2(y)
        return x + y
(4) Implement the content encoder:
class ContentEncoder(nn.Module):
    def __init__(self, in_channels=3, dim=64, n_downsample=2, n_res=4):
        super().__init__()
        layers = [nn.Conv2d(in_channels, dim, 7, 1, 3), nn.InstanceNorm2d(dim, affine=True), nn.ReLU(True)]
        # downsampling
        for i in range(n_downsample):
            layers += [nn.Conv2d(dim, dim*2, 4, 2, 1), nn.InstanceNorm2d(dim*2, affine=True), nn.ReLU(True)]
            dim *= 2
        # residual blocks
        for i in range(n_res):
            layers += [ResBlock(dim, use_adain=False)]
        self.model = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, x):
        return self.model(x)  # output shape: [B, C, H', W']
(5) Implement the style encoder:
class StyleEncoder(nn.Module):
    def __init__(self, in_channels=3, dim=64, n_downsample=4, style_dim=8):
        super().__init__()
        layers = [nn.Conv2d(in_channels, dim, 7, 1, 3), nn.ReLU(True)]
        for i in range(n_downsample):
            layers += [nn.Conv2d(dim, dim*2, 4, 2, 1), nn.ReLU(True)]
            dim *= 2
        layers += [nn.AdaptiveAvgPool2d(1), nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, x):
        return self.model(x).view(x.size(0), -1)  # [B, style_dim]
(6) Define the MLP that maps the style vector to the AdaIN parameters ($\gamma$, $\beta$), producing one pair of parameters per residual-block layer:
class MLP(nn.Module):
    def __init__(self, style_dim=8, hidden=256, num_adain_params=0):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(style_dim, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, num_adain_params)
        )
        self.apply(weights_init)

    def forward(self, style):
        return self.fc(style)
(7) Implement the decoder, which synthesizes an image from the content features and the style parameters. For simplicity, the allocation of AdaIN parameters is simplified here: the $\gamma$ and $\beta$ for each residual block are sliced channel-wise straight from the MLP output, whereas the original implementation is more fine-grained (different layers have different channel dimensions, and AdaIN must act correctly after each normalization):
class Decoder(nn.Module):
    def __init__(self, out_channels=3, dim=256, n_upsample=2, n_res=4, style_dim=8):
        super().__init__()
        # we assume the content feature channels equal dim (e.g., 256)
        self.dim = dim
        self.n_res = n_res
        # residual blocks with AdaIN
        self.resblocks = nn.ModuleList([ResBlock(dim, use_adain=True) for _ in range(n_res)])
        # upsampling
        ups = []
        for i in range(n_upsample):
            ups += [nn.Upsample(scale_factor=2, mode='nearest'),
                    nn.Conv2d(dim, dim//2, 5, 1, 2),
                    nn.InstanceNorm2d(dim//2, affine=True),
                    nn.ReLU(True)]
            dim = dim // 2
        self.ups = nn.Sequential(*ups)
        self.last = nn.Conv2d(dim, out_channels, 7, 1, 3)
        self.tanh = nn.Tanh()
        self.apply(weights_init)

    def forward(self, content, adain_params):
        # adain_params shape: [B, n_res * 2 * C], sliced per residual block (simplified scheme)
        b, c, h, w = content.size()
        ptr = 0
        x = content
        for i, rb in enumerate(self.resblocks):
            # take 2*C parameters per block: gamma, then beta, each of shape [B, C]
            gamma = adain_params[:, ptr:ptr + c]; ptr += c
            beta = adain_params[:, ptr:ptr + c]; ptr += c
            # wrap gamma/beta into AdaIN callables for this block
            adain1 = lambda feat, g=gamma, b_=beta: AdaIN()(feat, g, b_)
            adain2 = adain1  # both convs in the block share the same gamma/beta (simplification)
            x = rb(x, adain1, adain2)
        x = self.ups(x)
        x = self.last(x)
        return self.tanh(x)
(8) Define a PatchGAN discriminator (used for the adversarial loss):
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, ndf=64):
        super().__init__()
        layers = [nn.Conv2d(in_channels, ndf, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        nf = ndf
        for n in range(1, 4):
            layers += [nn.Conv2d(nf, nf*2, 4, 2, 1), nn.InstanceNorm2d(nf*2), nn.LeakyReLU(0.2, True)]
            nf = nf * 2
        layers += [nn.Conv2d(nf, 1, 4, 1, 1)]
        self.model = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, x):
        return self.model(x)  # patch map of real/fake scores
(9) Build the model instances and prepare the optimizers:
# model params
style_dim = args.style_dim
content_dim = 256  # final channel count of the content encoder (must match the Decoder)

# instantiate
E_content = ContentEncoder(in_channels=3, dim=64, n_downsample=2, n_res=args.n_res).to(device)
E_style = StyleEncoder(in_channels=3, dim=64, n_downsample=4, style_dim=style_dim).to(device)

# number of AdaIN params needed: for each ResBlock, 2 * C params (gamma + beta)
num_adain_params = args.n_res * 2 * content_dim
MLP_style = MLP(style_dim=style_dim, hidden=256, num_adain_params=num_adain_params).to(device)
Dec = Decoder(out_channels=3, dim=content_dim, n_upsample=2, n_res=args.n_res, style_dim=style_dim).to(device)

D_A = Discriminator(3).to(device)
D_B = Discriminator(3).to(device)

# optimizers
g_params = list(E_content.parameters()) + list(E_style.parameters()) + list(MLP_style.parameters()) + list(Dec.parameters())
opt_G = torch.optim.Adam(g_params, lr=args.lr, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(list(D_A.parameters()) + list(D_B.parameters()), lr=args.lr, betas=(0.5, 0.999))
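Before training, an optional smoke test (assuming the default arguments above with img_size=128) confirms that the tensor shapes line up end to end:

# Optional smoke test: check that tensor shapes line up end to end before training
with torch.no_grad():
    dummy = torch.randn(2, 3, args.img_size, args.img_size, device=device)
    c = E_content(dummy)          # e.g. [2, 256, 32, 32] for 128x128 inputs
    s = E_style(dummy)            # [2, style_dim]
    params = MLP_style(s)         # [2, num_adain_params]
    fake = Dec(c, params)         # [2, 3, img_size, img_size]
    patch = D_A(fake)             # patch map of scores, e.g. [2, 1, 7, 7]
    print(c.shape, s.shape, params.shape, fake.shape, patch.shape)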
(10) Define the adversarial losses, using a hinge loss:
l1_loss = nn.L1Loss()
mse_loss = nn.MSELoss()

def discriminator_hinge_loss(real_pred, fake_pred):
    loss_real = torch.mean(F.relu(1. - real_pred))
    loss_fake = torch.mean(F.relu(1. + fake_pred))
    return 0.5 * (loss_real + loss_fake)

def generator_hinge_loss(fake_pred):
    # the generator wants fake_pred to be large
    return -torch.mean(fake_pred)
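Note that the original MUNIT implementation uses a least-squares (LSGAN) adversarial objective rather than the hinge loss above; if you prefer to match it, a drop-in alternative could look like the following sketch:

# Least-squares (LSGAN-style) alternative to the hinge loss above
def discriminator_ls_loss(real_pred, fake_pred):
    return 0.5 * (F.mse_loss(real_pred, torch.ones_like(real_pred)) +
                  F.mse_loss(fake_pred, torch.zeros_like(fake_pred)))

def generator_ls_loss(fake_pred):
    return F.mse_loss(fake_pred, torch.ones_like(fake_pred))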
(11) Define the training loop. For each batch we:
- Sample images from domain A and domain B.
- Encode content_A, style_A, content_B, style_B.
- Randomly sample style_z from a normal distribution as the target-domain style (or extract it with the style encoder).
- Generate A2B = Dec(content_A, style_z_B), and B2A in the reverse direction.
- Compute the image reconstruction loss (reconstruct A_rec from A's content and style_A).
- Compute the latent (content/style) reconstruction losses and the GAN loss.
- Update the discriminators and the generator.
global_iter = 0
sample_dir = os.path.join(args.save_dir, 'samples')
os.makedirs(sample_dir, exist_ok=True)
ckpt_dir = os.path.join(args.save_dir, 'checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(args.epochs):
    pbar = tqdm(train_loader)
    for i, (A_img, B_img) in enumerate(pbar):
        A_img = A_img.to(device)
        B_img = B_img.to(device)

        content_A = E_content(A_img)
        content_B = E_content(B_img)
        style_A = E_style(A_img)
        style_B = E_style(B_img)

        # sample random style codes for multimodal generation
        style_rand_B = torch.randn(A_img.size(0), style_dim, device=device)
        style_rand_A = torch.randn(A_img.size(0), style_dim, device=device)

        # produce adain params from style vectors
        adain_params_B = MLP_style(style_rand_B)  # shape [B, num_adain_params]
        adain_params_A = MLP_style(style_rand_A)

        A2B = Dec(content_A, adain_params_B)
        B2A = Dec(content_B, adain_params_A)

        # D on real
        D_A_real = D_A(A_img)
        D_B_real = D_B(B_img)
        # D on fake (detach to avoid gradient to G)
        D_A_fake = D_A(B2A.detach())
        D_B_fake = D_B(A2B.detach())

        loss_D_A = discriminator_hinge_loss(D_A_real, D_A_fake)
        loss_D_B = discriminator_hinge_loss(D_B_real, D_B_fake)
        loss_D = loss_D_A + loss_D_B

        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # adversarial loss (want D to predict real for fakes)
        D_A_fake_for_G = D_A(B2A)
        D_B_fake_for_G = D_B(A2B)
        loss_G_adv = generator_hinge_loss(D_A_fake_for_G) + generator_hinge_loss(D_B_fake_for_G)

        # reconstruction losses
        adain_params_A_from_styleA = MLP_style(style_A)
        adain_params_B_from_styleB = MLP_style(style_B)
        A_rec = Dec(content_A, adain_params_A_from_styleA)
        B_rec = Dec(content_B, adain_params_B_from_styleB)
        loss_img_rec = l1_loss(A_rec, A_img) + l1_loss(B_rec, B_img)

        # latent reconstruction losses on the translated image
        content_A2B = E_content(A2B)
        style_A2B = E_style(A2B)
        loss_content_recon = l1_loss(content_A2B, content_A.detach())
        loss_style_recon = l1_loss(style_A2B, style_rand_B.detach())

        # total generator loss (weights chosen as in the MUNIT idea, simplified here)
        loss_G = loss_G_adv * 1.0 + loss_img_rec * 10.0 + loss_content_recon * 1.0 + loss_style_recon * 1.0

        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        # logging
        pbar.set_description(f"Epoch {epoch} D:{loss_D.item():.4f} G:{loss_G.item():.4f} rec:{loss_img_rec.item():.4f}")

        # save sample images occasionally
        if global_iter % 500 == 0:
            with torch.no_grad():
                # sample: A -> multiple B styles (random + reconstruction)
                n = min(4, A_img.size(0))
                # random styles
                rand_styles = torch.randn(n, style_dim, device=device)
                rand_adain = MLP_style(rand_styles)
                outs = Dec(content_A[:n], rand_adain)
                # also reconstructions
                recs = A_rec[:n]
                grid = torch.cat([A_img[:n], outs, recs], dim=0)
                # denormalize & save
                save_image((grid + 1) / 2.0, os.path.join(sample_dir, f'{global_iter}.png'), nrow=n)
        global_iter += 1

    # save a checkpoint after every epoch
    torch.save({
        'E_content': E_content.state_dict(),
        'E_style': E_style.state_dict(),
        'MLPStyle': MLP_style.state_dict(),
        'Dec': Dec.state_dict(),
        'D_A': D_A.state_dict(),
        'D_B': D_B.state_dict(),
        'optG': opt_G.state_dict(),
        'optD': opt_D.state_dict()
    }, os.path.join(ckpt_dir, f'ckpt_epoch_{epoch}.pth'))
4.3 Training and Testing
(1) Save the code as munit.py and launch training with the following command:
$ python munit.py --data_root data/edges2shoes/ --epochs 50 --batch_size 32
(2) Define a test function and call it with the trained model:
from glob import glob
import torchvision.transforms as T

@torch.no_grad()
def test_munit(checkpoint_path,
               data_root,
               out_dir='./test_results',
               domain='A2B',
               num_style=3,
               img_size=128,
               style_dim=8,
               device='cuda' if torch.cuda.is_available() else 'cpu'):
    os.makedirs(out_dir, exist_ok=True)

    # ---- build the model ----
    content_dim = 256
    E_content = ContentEncoder(3, 64, n_downsample=2, n_res=4).to(device)
    E_style = StyleEncoder(3, 64, n_downsample=4, style_dim=style_dim).to(device)
    num_adain_params = 4 * 2 * content_dim
    MLP_style = MLP(style_dim=style_dim, hidden=256, num_adain_params=num_adain_params).to(device)
    Dec = Decoder(3, content_dim, n_upsample=2, n_res=4, style_dim=style_dim).to(device)

    # ---- load the weights ----
    ckpt = torch.load(checkpoint_path, map_location=device)
    E_content.load_state_dict(ckpt['E_content'])
    E_style.load_state_dict(ckpt['E_style'])
    MLP_style.load_state_dict(ckpt['MLPStyle'])
    Dec.load_state_dict(ckpt['Dec'])
    E_content.eval(); E_style.eval(); MLP_style.eval(); Dec.eval()

    # ---- image loading ----
    transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize([0.5]*3, [0.5]*3)
    ])
    if domain == 'A2B':
        test_paths = sorted(glob(os.path.join(data_root, 'testA', '*')))
    else:
        test_paths = sorted(glob(os.path.join(data_root, 'testB', '*')))
    print(f"[INFO] Found {len(test_paths)} images for domain {domain} testing")

    for path in tqdm(test_paths, desc=f"Testing {domain}"):
        img_name = os.path.basename(path)
        x = Image.open(path).convert('RGB')
        x = transform(x).unsqueeze(0).to(device)
        # content code
        c = E_content(x)
        # sample several random styles
        results = []
        for j in range(num_style):
            style_z = torch.randn(1, style_dim, device=device)
            adain_params = MLP_style(style_z)
            out = Dec(c, adain_params)
            results.append(out)
        # concatenate the outputs
        results = torch.cat(results, dim=0)  # [num_style, 3, H, W]
        grid = torch.cat([x.repeat(num_style, 1, 1, 1), results], dim=0)
        save_path = os.path.join(out_dir, f"{os.path.splitext(img_name)[0]}_{domain}.png")
        save_image((grid + 1) / 2.0, save_path, nrow=num_style)

test_munit(checkpoint_path='./outputs/checkpoints/ckpt_epoch_49.pth',
           data_root='./data/edges2shoes',
           out_dir='./outputs/test_results',
           domain='A2B',
           num_style=8,
           img_size=128)
The generated results are shown below: