当前位置: 首页 > news >正文

生成模型实战 | 实时任意风格迁移

生成模型实战 | 实时任意风格迁移

    • 0. 前言
    • 1. 自适应实例归一化
    • 2. 风格迁移网络架构
    • 3. 构建编码器
    • 4. 构建解码器
    • 5. 构建风格迁移网络

0. 前言

我们已经学习了如何使用条件实例归一化来传输固定数量的风格。在本节中,我们将实现自适应实例归一化 (adaptive instance normalization, AdaIN),以使用多样性的风格执行风格迁移。

1. 自适应实例归一化

自适应实例归一化 (adaptive instance normalization, AdaIN) 也是条件实例归一化 (conditional instance normalization, CIN) 的一种,这意味着均值和标准差是在每个图像和每个通道 (H, W) 上计算的,而批归一化是在 (N, H, W) 上计算的。在 CIN 中, γ γ γ β β β 系数是可训练的变量,它们学习不同风格所需的均值和方差。在 AdaIN 中, γ γ γ β β β 被风格特征的标准差和均值所取代:
A d a I N ( x , y ) = σ ( y ) x − μ ( x ) σ ( x ) + μ ( y ) AdaIN(x,y)=\sigma(y)\frac {x-\mu (x)}{\sigma(x)} + \mu(y) AdaIN(x,y)=σ(y)σ(x)xμ(x)+μ(y)
AdaIN 可以理解为条件实例规范化的一种形式,其中条件是风格特征而不是风格标签。在训练和推理时,我们使用 VGG 提取风格层输出并将其统计信息用作风格条件,这样避免了只能预先定义一组固定风格。使用 PyTorch 实现 AdaIN 函数:

import os
import random
from PIL import Image
from tqdm import tqdmimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.utils as vutils
import torchvision# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ImageNet normalization for VGG
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1).to(device)
imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1).to(device)def normalize_batch(batch):return (batch - imagenet_mean) / imagenet_stddef denormalize_batch(batch):# inverse of normalize, clamp to [0,1]batch = batch * imagenet_std + imagenet_meanreturn torch.clamp(batch, 0, 1)def calc_mean_std(feat, eps=1e-5):# feat: N x C x H x WN, C = feat.size()[:2]feat_var = feat.view(N, C, -1).var(dim=2, unbiased=False) + epsstd = feat_var.sqrt().view(N, C, 1, 1)mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)return mean, stddef adaptive_instance_normalization(content_feat, style_feat):"""AdaIN: align content feature's channelwise mean/std to style's"""c_mean, c_std = calc_mean_std(content_feat)s_mean, s_std = calc_mean_std(style_feat)normalized = (content_feat - c_mean) / c_stdreturn normalized * s_std + s_mean

可以看出,adaptive_instance_normalization() 是对 AdaIN 方程式的直接实现。其中 calc_mean_std 用于计算特征图的均值和方差,还使用 view() 方法使结果保持四个维度,形状为 (N, 1, 1, C),而不是默认值 (N, C)。接下来,我们将 AdaIN 整合到风格迁移中。

2. 风格迁移网络架构

下图显示了风格迁移网络的架构和训练管道:

风格迁移网络的架构和训练管道

风格迁移网络 (style transfer network, STN) 是编码器/解码器网络,其中,编码器使用固定的 VGG 对内容和风格特征进行编码。然后,AdaIN 将风格特征编码至内容特征的统计信息,然后解码器采用这些新特征来生成风格化图像。

3. 构建编码器

使用 VGG 构建编码器:

class VGGEncoder(nn.Module):def __init__(self):super().__init__()vgg = models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT).features# indices: up to relu4_1 (this slicing is commonly used)self.enc_layers = nn.Sequential(*[vgg[i] for i in range(21)])for p in self.enc_layers.parameters():p.requires_grad = Falsedef forward(self, x):return self.enc_layers(x)def get_vgg_features(x, vgg):features = []vgg_full = models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT).features.to(device)vgg_full.eval()h = x# we will collect at layer indices corresponding to relu1_1, relu2_1, relu3_1, relu4_1collect_indices = [1, 6, 11, 20]  # these index choices match common mappingfor i, layer in enumerate(vgg_full):h = layer(h)if i in collect_indices:features.append(h)if i == 20:breakreturn features

4. 构建解码器

尽管我们在编码器代码中使用了 4VGG 层 (block1_conv1block4_conv1),但 AdaIN 仅使用编码器的最后一层 block4_conv1。因此,解码器的输入张量具有与 block4_conv1 相同的激活。解码器架构由卷积和上采样层组成:

class Decoder(nn.Module):"""A decoder that mirrors part of VGG; simple Upsample + Conv blocks.Input channels: 512 -> output: 3"""def __init__(self):super().__init__()# we design a reasonable decoder; channel sizes inspired by VGGself.model = nn.Sequential(# input: 512 x H x Wnn.Conv2d(512, 256, 3, 1, 1),nn.ReLU(inplace=True),nn.Upsample(scale_factor=2, mode='nearest'),  # 256 -> upnn.Conv2d(256, 256, 3, 1, 1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, 3, 1, 1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, 3, 1, 1),nn.ReLU(inplace=True),nn.Conv2d(256, 128, 3, 1, 1),nn.ReLU(inplace=True),nn.Upsample(scale_factor=2, mode='nearest'),  # upnn.Conv2d(128, 128, 3, 1, 1),nn.ReLU(inplace=True),nn.Conv2d(128, 64, 3, 1, 1),nn.ReLU(inplace=True),nn.Upsample(scale_factor=2, mode='nearest'),  # upnn.Conv2d(64, 64, 3, 1, 1),nn.ReLU(inplace=True),nn.Conv2d(64, 3, 3, 1, 1),# output no activation; we'll clamp later)def forward(self, x):return self.model(x)

除不具有任何非线性激活功能的输出层外,所有层均使用 ReLU 激活功能。现在,我们已经完成了 AdaIN,编码器和解码器,并且可以继续进行图像预处理流程了。现在,我们已经准备好所有构件,剩下要做的就是将它们放在一起以创建 STN 和训练管道。

5. 构建风格迁移网络

(1) 定义数据集加载器:

class ImageFolderDataset(Dataset):def __init__(self, folder, transform):super().__init__()self.files = []for root, _, fnames in os.walk(folder):for f in fnames:if f.lower().endswith(('png', 'jpg', 'jpeg', 'bmp')):self.files.append(os.path.join(root, f))self.transform = transformdef __len__(self):return len(self.files)def __getitem__(self, idx):path = self.files[idx]img = Image.open(path).convert('RGB')return self.transform(img)class ContentStyleDataset(Dataset):"""Returns a pair (content, style) by sampling from content_dir and style_dir."""def __init__(self, content_dir, style_dir, size=256):self.content_ds = ImageFolderDataset(content_dir, transforms.Compose([transforms.Resize(int(size*1.15)),transforms.RandomCrop(size),transforms.ToTensor()]))self.style_ds = ImageFolderDataset(style_dir, transforms.Compose([transforms.Resize(int(size*1.15)),transforms.RandomCrop(size),transforms.ToTensor()]))def __len__(self):return max(len(self.content_ds), len(self.style_ds))def __getitem__(self, idx):c = self.content_ds[idx % len(self.content_ds)]s = self.style_ds[random.randint(0, len(self.style_ds)-1)]return c, s

(2) 构造 STN 非常简单,只需连接编码器,AdaIN 和解码器即可,如以上架构图所示。 STN 还是我们将用来执行推理的模型:

def train_adain(content_dir, style_dir, save_dir,image_size=256, batch_size=8, epochs=10,content_weight=1.0, style_weight=10.0, lr=1e-4,save_every=1000):os.makedirs(save_dir, exist_ok=True)# datasets & loadersdataset = ContentStyleDataset(content_dir, style_dir, size=image_size)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)# modelsvgg = VGGEncoder().to(device).eval()# freeze a separate full-vgg for multi-layer featuresvgg_full = models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT).features.to(device).eval()for p in vgg_full.parameters():p.requires_grad = Falsedecoder = Decoder().to(device)optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)iter_count = 0for epoch in range(epochs):pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")for content, style in pbar:iter_count += 1content = content.to(device)style = style.to(device)# normalize for vggc_in = normalize_batch(content)s_in = normalize_batch(style)# extract featureswith torch.no_grad():c_feat = vgg.forward(c_in)s_feat = vgg.forward(s_in)# compute AdaIN target (t)t = adaptive_instance_normalization(c_feat, s_feat)t = t.detach()  # do not backprop through encoder# decodeg = decoder(t)# content loss: compare features of generated g and t at relu4_1g_in = normalize_batch(g)def extract_multifeats(x):feats = []h = xcollect_indices = [1, 6, 11, 20]for i, layer in enumerate(vgg_full):h = layer(h)if i in collect_indices:feats.append(h)if i == 20:breakreturn featswith torch.no_grad():# style feats from original style images_feats = extract_multifeats(s_in)g_feats = extract_multifeats(g_in)content_loss = F.mse_loss(g_feats[-1], t)# style loss: compare mean/std of g_feats and s_featsstyle_loss = style_loss_from_features(g_feats, s_feats)loss = content_weight * content_loss + style_weight * style_lossoptimizer.zero_grad()loss.backward()optimizer.step()pbar.set_postfix({'loss': f"{loss.item():.4f}",'c_loss': f"{content_loss.item():.4f}",'s_loss': f"{style_loss.item():.4f}"})if iter_count % save_every == 0:torch.save(decoder.state_dict(), os.path.join(save_dir, f"decoder_iter_{iter_count}.pth"))print(f"Saved decoder at iter {iter_count}")# final savetorch.save(decoder.state_dict(), os.path.join(save_dir, "decoder_final.pth"))print("Training finished, model saved.")

内容和风格图像经过预处理,然后送入编码器。最后一个特征层 block4_conv1 进入 AdaIN()。然后风格化特征进入解码器以生成 RGB 风格化的图像。

(3) 内容损失和风格损失是根据固定 VGG 提取的激活来计算的。内容损失是 L2 范数,将生成的风格化图像的内容特征与 AdaIN 的输出进行比较,这使收敛速度更快。对于风格损失,将常用的 Gram 矩阵替换为均值和方差激活统计的 L2 范数。这产生的结果类似于 Gram 矩阵,但从概念上讲更清晰。以下是风格损失函数方程:
L s = ∑ i = 1 L ∣ ∣ μ ( ϕ i ( s t y l i z e d ) ) − μ ( ϕ i ( s t y l e ) ) ∣ ∣ 2 + ∣ ∣ σ ( ϕ i ( s t y l i z e d ) ) − σ ( ϕ i ( s t y l e ) ) ∣ ∣ 2 \mathcal L_s=\sum_{i=1}^L||\mu (\phi_i(stylized))-\mu(\phi_i(style))||_2+||\sigma(\phi_i(stylized))-\sigma(\phi_i(style))||_2 Ls=i=1L∣∣μ(ϕi(stylized))μ(ϕi(style))2+∣∣σ(ϕi(stylized))σ(ϕi(style))2
其中, ϕ i \phi_i ϕi 表示 VGG-19 中用于计算风格损失的层。
我们使用 calc_mean_std() 来计算来自风格化图像和风格图像的特征之间的统计量和 L2 范数,对内容层的损失求均值:

def style_loss_from_features(generated_feats, style_feats):# both lists for multiple layersloss = 0.0for gf, sf in zip(generated_feats, style_feats):gm, gs = calc_mean_std(gf)sm, ss = calc_mean_std(sf)loss += F.mse_loss(gm, sm) + F.mse_loss(gs, ss)return loss

(4) 实现模型推理管道:

from torchvision.transforms.functional import to_pil_imagedef load_image(path, size=512):img = Image.open(path).convert('RGB')transform = transforms.Compose([transforms.Resize(size),transforms.CenterCrop(size),transforms.ToTensor()])return transform(img).unsqueeze(0).to(device)  # 1 x 3 x H x Wdef stylize(content_path, style_path, decoder, vgg, alpha=1.0, size=512):c = load_image(content_path, size=size)s = load_image(style_path, size=size)c_in = normalize_batch(c)s_in = normalize_batch(s)with torch.no_grad():c_feat = vgg(c_in)s_feat = vgg(s_in)t = adaptive_instance_normalization(c_feat, s_feat)t = alpha * t + (1 - alpha) * c_featg = decoder(t)img = denormalize_batch(g).cpu().squeeze(0)pil = to_pil_image(img)return pil

(5) 最后,实现主函数:

if __name__ == '__main__':import argparseparser = argparse.ArgumentParser()parser.add_argument('--content_dir', type=str, default='data/content')parser.add_argument('--style_dir', type=str, default='data/style')parser.add_argument('--save_dir', type=str, default='checkpoints')parser.add_argument('--mode', type=str, default='train', choices=['train', 'stylize'])parser.add_argument('--content_image', type=str, default='')parser.add_argument('--style_image', type=str, default='')parser.add_argument('--alpha', type=float, default=1.0)parser.add_argument('--image_size', type=int, default=256)args = parser.parse_args()if args.mode == 'train':train_adain(args.content_dir, args.style_dir, args.save_dir,image_size=args.image_size, batch_size=8, epochs=10,content_weight=1.0, style_weight=10.0, lr=1e-4, save_every=2000)else:# load decoderdecoder = Decoder().to(device)# load checkpoint - make sure path existsckpt = os.path.join(args.save_dir, 'decoder_final.pth')if not os.path.exists(ckpt):raise RuntimeError("Please train model or provide checkpoint path")decoder.load_state_dict(torch.load(ckpt, map_location=device))vgg = VGGEncoder().to(device).eval()out = stylize(args.content_image, args.style_image, decoder, vgg, alpha=args.alpha, size=512)out.save('stylized_result.png')print("Stylized image saved to stylized_result.png")

我们将内容权重固定为 1,并调整风格权重,在此示例中,我们将风格权重设置为 1e-4。在网络架构中,看起来好像有三个要训练的网络,但是其中两个是固定的 VGG,因此唯一可训练的网络是解码器。因此,我们仅跟踪梯度并将其应用于解码器。

(6) 在本节中,将使用 CelebA 人脸图像 作为内容图像,并使用 Best Artworks of All Time 作为风格图像。Best Artworks of All Time 数据集收录了 50 位历史上极具影响力的艺术家的作品,包含约 12000 张画作图像。数据集下载完成后,将内容图像和风格图像分别放在 data/contentdata/style 目录下,使用以下命令启动训练过程:

$ python3 adain_train.py --content_dir data/content --style_dir data/style

训练完成后,使用以下命令生成风格化图像:

python3 adain_train.py --mode stylize --content_image data/content/016280.jpg --style_image data/style/test/Vincent_van_Gogh_875.jpg --alpha 10.0

效果展示

上图中的图像显示了使用网络训练时未看到的风格图像在推理时进行风格迁移的情况。每种风格转移仅通过单个前向计算进行,这比原始神经风格迁移算法的迭代优化快得多。

http://www.dtcms.com/a/486228.html

相关文章:

  • C++ --- 模版初阶
  • 外贸家具网站.net网站开发简介
  • Django 的文档接口
  • blender中对合并的物体重复设置材质,删除重复材质,批量复制材质
  • IDEA界面突然出现一条“竖线”,附解决办法
  • Git 学习及使用
  • 使用OpenGL加速图像处理
  • CUDA 调试器 sanitizer,检测数据竞争,竞争条件 race condition
  • Blender布料物理模拟生成插件 Simply Cloth Studio V1.4.4 + Simply Cloth Pro v3.0附使用教程
  • AWS CloudWatch:服务器的“眼睛”,实时监控一切动向
  • 云南省建设厅合同网站嵊州门户网站
  • 做网站需要学jsp我也来做外国网站购物
  • 异步数据采集实践:用 Python/Node.js 构建高并发淘宝商品 API 调用引擎
  • Spring Boot 3零基础教程,yml文件中配置和类的属性绑定,笔记15
  • Lua C API 中一段LUA建表过程解释
  • 用于大语言模型后训练阶段的新方法GVPO(Group Variance Policy Optimization)
  • k8s集群环境下Jenkins环境性能测试项目实战
  • 【k8s】在 k8s上部署一个 web 服务
  • 怎做网站网页设计属于什么行业
  • 02 Oracle JDK 下载及配置(解压缩版)
  • 「10.11」闪崩前比特币做空风波
  • 目标检测学习总结
  • java求职学习day40
  • 服装公司网站首页做头像的网站空白
  • 在 Microsoft Azure 上部署 ClickHouse 数据仓库:托管服务与自行部署的全面指南
  • 橙色可以做哪些网站沈阳网站建设024w
  • 网络设备配置:交换机、路由器OSPF和BGP、防火墙策略管理
  • 深圳建设工程信息网站科技有限公司网页设计
  • h5网站开发培训哪里好项目网创业
  • C++ Hash