当前位置: 首页 > news >正文

深度学习-计算机视觉-风格迁移

使用卷积神经网络,自动将一个图像中的风格应用在另一图像之上,即——风格迁移(style transfer)

这里我们需要两张输入图像:一张是内容图像,另一张是风格图像

我们将使用神经网络修改内容图像,使其在风格上接近风格图像。

下面的内容图像为西雅图郊区的雷尼尔山国家公园拍摄的风景照,而风格图像则是一幅主题为秋天橡树的油画。 最终输出的合成图像应用了风格图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。

1. 基于卷积神经网络的风格迁移方法

  • 首先,我们初始化合成图像,例如将其初始化为内容图像。 该合成图像是风格迁移过程中唯一需要更新的变量,即风格迁移所需迭代的模型参数。

  • 然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。 这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们可以选择其中某些层的输出作为内容特征或风格特征

以下图为例,这里选取的预训练的神经网络含有3个卷积层,其中第二层输出内容特征,第一层和第三层输出风格特征。

接下来,我们通过前向传播(实线箭头方向)计算风格迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像。 风格迁移常用的损失函数由3部分组成

  1. 内容损失使合成图像与内容图像在内容特征上接近;

  2. 风格损失使合成图像与风格图像在风格特征上接近;

  3. 全变分损失则有助于减少合成图像中的噪点。

最后,当模型训练结束时,我们输出风格迁移的模型参数,即得到最终的合成图像。

  • 可以通过预训练的卷积神经网络来抽取图像的特征,并最小化损失函数来不断更新合成图像来作为模型参数。

  • 我们使用格拉姆矩阵表达风格层输出的风格。

2. 代码实现

代码实现如下:

2.1 阅读内容和风格图像

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2ld2l.set_figsize()# 内容图片(左)
content_img = d2l.Image.open('../img/rainier.jpg')
d2l.plt.imshow(content_img);# 样式图片(右)
style_img = d2l.Image.open('../img/autumn-oak.jpg')
d2l.plt.imshow(style_img);

2.2 预处理和后处理

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0)def postprocess(img):img = img[0].to(rgb_std.device)img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

2.3 抽取图像特征

使用基于ImageNet数据集预训练的VGG-19模型来抽取图像特征:

pretrained_net = torchvision.models.vgg19(pretrained=True)

为了抽取图像的内容特征和风格特征,我们可以选择VGG网络中某些层的输出。

一般来说,越靠近输入层,越容易抽取图像的细节信息;越靠近输出层,则越容易抽取图像的全局信息

为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。 我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层

在使用的VGG网络中,包含5个卷积块。 实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。 这些层的索引可以通过打印pretrained_net实例获取。

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

构建一个新的网络net,它只保留需要用到的VGG的所有层(上面风格和内容最大为28层,因此只取28层):

net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])

给定输入X,如果我们简单地调用前向传播net(X),只能获得最后一层的输出。

由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出

def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X) # 逐层计算输入 X 的结果if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles

get_contents函数对内容图像抽取内容特征; get_styles函数对风格图像抽取风格特征。

def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Ydef get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y

2.4 定义损失函数

内容损失:

def content_loss(Y_hat, Y):# 我们从动态计算梯度的树中分离目标:# 这是一个规定的值,而不是一个变量。return torch.square(Y_hat - Y.detach()).mean()

风格损失:匹配统计分布

1. Gram 矩阵:把「样式」变成可以比较的数字

  1. Gram 矩阵 G_ij = Σ_k X_{i,k}·X_{j,k}(内积)统计了任意两个通道之间的共现强度;通道 i 与通道 j 的响应若总是同时出现,则 G_ij 很大。

  2. 这种「通道间相关性」对纹理、笔触、颜色分布等样式信息非常敏感,而对空间位置不敏感,因此能很好地代表「风格」。

2. 风格损失:让「生成图」的 Gram 去逼近「风格图」的 Gram

  1. gram(Y_hat):当前生成图像在这一层的 归一化 Gram 矩阵(形状 (C, C))。

  2. gram_Y:风格图像在同一层 提前算好并缓存 的 Gram 矩阵。其中的.detach() 告诉 PyTorch:“这是常量,别给我求梯度,我只想让生成图去逼近它。”

  3. gram(Y_hat) - gram_Y.detach():两张图 Gram 矩阵的 逐元素误差。

  4. torch.square(...).mean():把误差逐元素平方后取平均 → 均方误差 (MSE),也就是该层的 风格损失值。

def gram(X):# X 的形状:(batch, C, H, W)num_channels, n = X.shape[1], X.numel() // X.shape[1]  # C 和 H·WX = X.reshape((num_channels, n))                       # (C, H·W)return torch.matmul(X, X.T) / (num_channels * n)       # matmul(X, X.T)是(C, C),再归一化def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

全变分损失:tv降噪

def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

损失函数

风格转移的损失函数是内容损失、风格损失和总变化损失的加权和。

通过调节这些权重超参数,我们可以权衡合成图像在保留内容、迁移风格以及去噪三方面的相对重要性。

content_weight, style_weight, tv_weight = 1, 1e3, 10def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):# 分别计算内容损失、风格损失和全变分损失contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weight# 对所有损失求和l = sum(10 * styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l

2.5 初始化合成图像

class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weightdef get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer

2.6 训练模型

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)l.backward()trainer.step()scheduler.step()if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return X

现在我们训练模型: 将内容图像和风格图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像。

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

我们可以看到,合成图像保留了内容图像的风景和物体,并同时迁移了风格图像的色彩。例如,合成图像具有与风格图像中一样的色彩块,其中一些甚至具有画笔笔触的细微纹理。


文章转载自:

http://ktOipRmO.wwznd.cn
http://piXh3TW2.wwznd.cn
http://iuYZ4p2a.wwznd.cn
http://YW7bt0mR.wwznd.cn
http://T4lk007A.wwznd.cn
http://Kxu1qHUr.wwznd.cn
http://Yw2yjJLi.wwznd.cn
http://LOknsKVL.wwznd.cn
http://85FsRWAB.wwznd.cn
http://UzXmgU1O.wwznd.cn
http://ibY8BhCn.wwznd.cn
http://dc5Ju67Q.wwznd.cn
http://1OGQDu8Q.wwznd.cn
http://Q2CPrpQl.wwznd.cn
http://Ues3ODGs.wwznd.cn
http://0fqbTRGM.wwznd.cn
http://oemnZ3ZD.wwznd.cn
http://i4u93M7B.wwznd.cn
http://fCQM3do6.wwznd.cn
http://Svcw4Azc.wwznd.cn
http://BmNP5zFF.wwznd.cn
http://jDrvmsYf.wwznd.cn
http://RlEQropc.wwznd.cn
http://bzWKrvqS.wwznd.cn
http://oA11WhGg.wwznd.cn
http://9A9M5c2k.wwznd.cn
http://ZjcJ4gJj.wwznd.cn
http://6LPLSGhL.wwznd.cn
http://pFjBGJ8d.wwznd.cn
http://I7feL5oT.wwznd.cn
http://www.dtcms.com/a/385372.html

相关文章:

  • 机器学习面试题:请介绍一下你理解的集成学习算法
  • C2000基础-GPIO介绍及使用
  • 【CTF-WEB】Web基础工具的使用(burpsuit抓包并修改数值)
  • 重学前端015 --- 响应式网页设计 CSS变换
  • Spring Boot + MyBatis 报 Invalid bean definition 如何排查解决
  • 从 APP 界面设计到用户体验优化:如何让你的应用脱颖而出?
  • RabbitMQ 高可用与集群机制
  • 迎中秋庆国庆,易天假期安排通知
  • IFNet.py代码学习 自学
  • 深度学习之PyTorch基本使用(一)
  • Python 异常处理与文件操作全解析
  • 记一次神通数据库的链接不释放问题
  • FLASK 框架 (关于Flask框架的简单学习和项目实战)
  • Flutter学习项目
  • Linux中报错记录以及libRadtran的安装—Ubuntu
  • 仓颉编程语言青少年基础教程:enum(枚举)类型和Option类型
  • 124.stm32故障:程序下载不能运行,stlink调试时可以正常运行
  • 3.DSP学习记录之GPIO按键输入
  • OpenCV:图像拼接(SIFT 特征匹配 + 透视变换)
  • 基于大语言模型的有人与无人驾驶航空器协同作战框架
  • 差分: 模板+题目
  • 解读IEC62061-2021
  • SQL数据库操作语言
  • UE4工程启动卡很久如何在运行时进行dump查看堆栈
  • Day24_【深度学习—广播机制】
  • 【试题】传输专业设备L1~L3实操考题
  • CSP认证练习题目推荐(4)
  • nginx如何添加CSP策略
  • 计算机网络(一些知识与思考)
  • 【开题答辩全过程】以 4s店汽车销售系统为例,包含答辩的问题和答案