当前位置: 首页 > news >正文

什么是 Perceptual Loss(感知损失)?

“Perceptual Loss”(感知损失)

“Perceptual Loss”(感知损失)是一种在图像处理和计算机视觉任务中常用的损失函数,旨在衡量两张图像在感知上的相似性,而不是仅仅依赖像素级别的差异。它通过利用预训练的深度神经网络(通常是图像分类网络,如 VGG)提取图像的高级特征,来捕捉图像的内容和风格等语义信息,而不是简单地比较像素值。这种方法特别适用于需要生成高质量图像的任务,例如图像风格迁移(Style Transfer)和超分辨率(Super-Resolution)。

以下是对 Perceptual Loss 的详细解释,基于论文《Perceptual Losses for Real-Time Style Transfer and Super-Resolution》以及通用知识:

下文中的图片来自原论文:https://arxiv.org/pdf/1603.08155


1. 什么是 Perceptual Loss?

传统的损失函数(如像素级损失,Pixel Loss)通常直接计算输出图像 ( y ^ \hat{y} y^ ) 和目标图像 ( y y y ) 之间的像素差异,例如均方误差(MSE):
ℓ pixel ( y ^ , y ) = 1 C H W ∥ y ^ − y ∥ 2 2 \ell_{\text{pixel}}(\hat{y}, y) = \frac{1}{CHW} \|\hat{y} - y\|_2^2 pixel(y^,y)=CHW1y^y22
其中 ( C , H , W C, H, W C,H,W ) 分别是图像的通道数、高度和宽度。然而,这种方法有一个显著缺点:它无法捕捉图像在人类感知上的相似性。例如,两张图像可能只有一像素的偏移,像素损失会认为它们差异很大,但人类几乎察觉不到这种差别。

Perceptual Loss 通过引入一个预训练的神经网络(称为“损失网络”,Loss Network,通常是 VGG-16 或 VGG-19),利用其提取的高级特征来定义图像之间的相似性。具体来说,它比较的是两张图像在损失网络特定层上的特征表示,而不是直接比较像素值。这种方法能够更好地反映图像的语义内容(如物体结构)和风格特征(如纹理、颜色分布)。


2. Perceptual Loss 的两种主要形式

在论文中,Perceptual Loss 被分为两种具体形式:Feature Reconstruction Loss(特征重构损失)和 Style Reconstruction Loss(风格重构损失),分别用于衡量内容相似性和风格相似性。

(1) Feature Reconstruction Loss(特征重构损失)
  • 定义:衡量输出图像 ( y ^ \hat{y} y^ ) 和目标图像 ( y y y ) 在损失网络 ( ϕ \phi ϕ ) 的某层 ( j j j ) 上的特征表示差异:
    ℓ feat ϕ , j ( y ^ , y ) = 1 C j H j W j ∥ ϕ j ( y ^ ) − ϕ j ( y ) ∥ 2 2 \ell_{\text{feat}}^{\phi, j}(\hat{y}, y) = \frac{1}{C_j H_j W_j} \|\phi_j(\hat{y}) - \phi_j(y)\|_2^2 featϕ,j(y^,y)=CjHjWj1ϕj(y^)ϕj(y)22
    其中:

    • ( ϕ j ( x ) \phi_j(x) ϕj(x) ) 是损失网络第 ( j j j ) 层的激活(特征图),形状为 ( C j × H j × W j C_j \times H_j \times W_j Cj×Hj×Wj );
    • ( C j , H j , W j C_j, H_j, W_j Cj,Hj,Wj ) 分别是该层特征图的通道数、高度和宽度。
  • 作用:鼓励输出图像 ( y ^ \hat{y} y^ ) 在内容上与目标图像 ( y y y ) 相似,但不要求像素完全一致。文档中提到,从较低层重构时,图像几乎与目标无差别;从较高层重构时,保留了整体结构,但颜色、纹理和精确形状可能改变(见图3)。

在这里插入图片描述

  • 应用:在风格迁移中,用于保留输入图像的内容;在超分辨率中,用于生成更符合人类感知的高分辨率细节。
(2) Style Reconstruction Loss(风格重构损失)
  • 定义:基于特征图的 Gram 矩阵,衡量输出图像 ( y ^ \hat{y} y^ ) 和目标图像 ( y y y ) 在风格上的差异:

    • 首先计算第 ( j j j ) 层的 Gram 矩阵 ( G j ϕ ( x ) G_j^\phi(x) Gjϕ(x) ):
      G j ϕ ( x ) c , c ′ = 1 C j H j W j ∑ h = 1 H j ∑ w = 1 W j ϕ j ( x ) h , w , c ϕ j ( x ) h , w , c ′ G_j^\phi(x)_{c, c'} = \frac{1}{C_j H_j W_j} \sum_{h=1}^{H_j} \sum_{w=1}^{W_j} \phi_j(x)_{h, w, c} \phi_j(x)_{h, w, c'} Gjϕ(x)c,c=CjHjWj1h=1Hjw=1Wjϕj(x)h,w,cϕj(x)h,w,c
      Gram 矩阵捕捉了特征之间的相关性,反映了图像的纹理和风格信息。
    • 然后计算风格损失:
      ℓ style ϕ , j ( y ^ , y ) = ∥ G j ϕ ( y ^ ) − G j ϕ ( y ) ∥ F 2 \ell_{\text{style}}^{\phi, j}(\hat{y}, y) = \|G_j^\phi(\hat{y}) - G_j^\phi(y)\|_F^2 styleϕ,j(y^,y)=Gjϕ(y^)Gjϕ(y)F2
      其中 ( ∥ ⋅ ∥ F \| \cdot \|_F F ) 是 Frobenius 范数。
  • 作用:鼓励 ( y ^ \hat{y} y^ ) 在颜色、纹理和模式等风格特征上与目标图像 ( y y y ) 一致,但不关心空间结构。文档中提到,从较高层重构风格时,会转移更大尺度的结构(见图4)。
    在这里插入图片描述

  • 应用:主要用于风格迁移任务,确保输出图像具有目标风格图像的艺术特性。


3. Perceptual Loss 的优势

与传统的像素损失相比,Perceptual Loss 有以下优势:

  • 感知鲁棒性:它基于人类感知更相关的高级特征,而非低级像素差异。例如,文档中提到,即使两张图像像素差异很大(如偏移一像素),感知损失仍能认为它们相似。
  • 语义信息:利用预训练网络(如 VGG)已经学习到的语义知识,适用于需要推理的任务(如超分辨率中的细节生成,风格迁移中的内容保留)。
  • 实时性:通过训练一个前馈网络(如论文中的 Image Transformation Network)来近似优化感知损失,速度比传统优化方法快三个数量级(表1)。

在这里插入图片描述


4. 应用

《Perceptual Losses for Real-Time Style Transfer and Super-Resolution》提出了将 Perceptual Loss 用于训练前馈网络,解决两种任务:

  1. 风格迁移(Style Transfer)

    • 目标:生成一张图像 ( y ^ \hat{y} y^ ),结合内容图像 ( y c y_c yc ) 的内容和风格图像 ( y s y_s ys ) 的风格。
    • 损失函数:
      L = λ c ℓ feat ϕ , j ( y ^ , y c ) + λ s ℓ style ϕ , J ( y ^ , y s ) + λ T V ℓ T V ( y ^ ) L = \lambda_c \ell_{\text{feat}}^{\phi, j}(\hat{y}, y_c) + \lambda_s \ell_{\text{style}}^{\phi, J}(\hat{y}, y_s) + \lambda_{TV} \ell_{TV}(\hat{y}) L=λcfeatϕ,j(y^,yc)+λsstyleϕ,J(y^,ys)+λTVTV(y^)
      其中 ( ℓ T V \ell_{TV} TV ) 是总变差正则化,用于平滑输出。
    • 结果:与 Gatys 等人的优化方法相比,质量相似,但速度提升至实时级别。
  2. 单图像超分辨率(Super-Resolution)

    • 目标:从低分辨率输入 ( x x x ) 生成高分辨率输出 ( y ^ \hat{y} y^ ),接近真实高分辨率图像 ( y c y_c yc )。
    • 损失函数:仅使用特征重构损失 ( ℓ feat \ell_{\text{feat}} feat ),不使用风格损失。
    • 结果:相比像素损失,感知损失生成的图像细节更丰富(如边缘和纹理),尽管 PSNR/SSIM 可能较低。

5. 为什么用 Perceptual Loss?

  • 弥补像素损失的不足:像素损失假设输出和目标图像应逐像素匹配,但这在风格迁移(无唯一正确输出)和超分辨率(多解问题)中不适用。感知损失通过高级特征捕捉更符合人类视觉的质量。
  • 知识迁移:预训练的损失网络(如 VGG)已学会提取语义信息,训练时无需从头学习。
  • 实用性:结合前馈网络后,生成速度大幅提升,适合实时应用。

6. 代码示例(简要)

以下是一个简单的 PyTorch 示例,展示如何计算 Feature Reconstruction Loss:

import torch
import torch.nn as nn
import torchvision.models as models

# 加载预训练的 VGG-16 作为损失网络
vgg = models.vgg16(pretrained=True).features.eval()
loss_network = nn.Sequential(*list(vgg.children())[:9])  # 取到 relu2_2 层

# 假设输入图像和目标图像
x = torch.randn(1, 3, 256, 256)  # 输入图像
y = torch.randn(1, 3, 256, 256)  # 目标图像

# 计算特征
feat_x = loss_network(x)
feat_y = loss_network(y)

# 特征重构损失
feat_loss = torch.mean((feat_x - feat_y) ** 2)
print(f"Feature Reconstruction Loss: {feat_loss.item()}")

总结

Perceptual Loss 是一种强大的工具,通过预训练网络的高级特征衡量图像相似性,弥补了像素损失的局限性。它在风格迁移和超分辨率中表现出色,尤其在需要语义推理和实时性时。文档中提出的方法将其与前馈网络结合,进一步推动了其实用性。

( c c c ) 和 ( c ′ c' c ) 的含义

在计算 Gram 矩阵 ( G j ϕ ( x ) G_j^\phi(x) Gjϕ(x) ) 和风格损失 ( ℓ style ϕ , j ( y ^ , y ) \ell_{\text{style}}^{\phi, j}(\hat{y}, y) styleϕ,j(y^,y) ) 的公式中,( c c c ) 和 ( c ′ c' c ) 表示的是特征图(feature map)的通道索引。下面解释一下它们的含义以及在上下文中的作用。


1. Gram 矩阵的定义

Gram 矩阵 ( G j ϕ ( x ) G_j^\phi(x) Gjϕ(x) ) 是用来捕捉图像在第 ( j j j ) 层特征图中不同通道之间的相关性,反映纹理和风格信息。它的公式是:

G j ϕ ( x ) c , c ′ = 1 C j H j W j ∑ h = 1 H j ∑ w = 1 W j ϕ j ( x ) h , w , c ϕ j ( x ) h , w , c ′ G_j^\phi(x)_{c, c'} = \frac{1}{C_j H_j W_j} \sum_{h=1}^{H_j} \sum_{w=1}^{W_j} \phi_j(x)_{h, w, c} \phi_j(x)_{h, w, c'} Gjϕ(x)c,c=CjHjWj1h=1Hjw=1Wjϕj(x)h,w,cϕj(x)h,w,c

其中:

  • ( ϕ j ( x ) \phi_j(x) ϕj(x) ):损失网络(如 VGG-16)在第 ( j j j ) 层的特征图,形状为 ( ( C j , H j , W j ) (C_j, H_j, W_j) (Cj,Hj,Wj) ),分别表示通道数(channels)、高度(height)和宽度(width)。
  • ( h h h ) 和 ( w w w ):特征图的空间位置索引,分别对应高度和宽度的坐标。
  • ( c c c ) 和 ( c ′ c' c ):特征图的通道索引,范围是 ( 0 ≤ c , c ′ < C j 0 \leq c, c' < C_j 0c,c<Cj)。
( c c c ) 和 ( c ′ c' c ) 的具体含义
  • ( c c c ) 表示第 ( j j j ) 层特征图中的某个通道(channel),例如第 ( c c c ) 个卷积核的输出。
  • ( c ′ c' c ) 表示第 ( j j j ) 层特征图中的另一个通道,通常与 ( c c c ) 不同(但也可以相同,例如 ( c = c ′ c = c' c=c ))。
  • ( ϕ j ( x ) h , w , c \phi_j(x)_{h, w, c} ϕj(x)h,w,c ):在位置 ( ( h , w ) (h, w) (h,w) ) 上,第 ( c c c ) 个通道的特征值。
  • ( ϕ j ( x ) h , w , c ′ \phi_j(x)_{h, w, c'} ϕj(x)h,w,c ):在同一位置 ( ( h , w ) (h, w) (h,w) ) 上,第 ( c ′ c' c ) 个通道的特征值。

Gram 矩阵的元素 ( G j ϕ ( x ) c , c ′ G_j^\phi(x)_{c, c'} Gjϕ(x)c,c ) 是 ( c c c ) 和 ( c ′ c' c ) 这两个通道在整个特征图上的点积(内积),然后归一化(除以 ( C j H j W j C_j H_j W_j CjHjWj ))。它衡量了这两个通道的特征在空间上的相关性:

  • 如果 ( c c c ) 和 ( c ′ c' c ) 的特征值在空间上经常同时激活(比如都对应某种纹理模式),则 ( G j ϕ ( x ) c , c ′ G_j^\phi(x)_{c, c'} Gjϕ(x)c,c ) 的值会较大。
  • 如果它们激活模式不相关,则值接近 0。
Gram 矩阵的形状
  • ( G j ϕ ( x ) G_j^\phi(x) Gjϕ(x) ) 是一个 ( C j × C j C_j \times C_j Cj×Cj ) 的矩阵:
    • 行索引 ( c c c ) 和列索引 ( c ′ c' c ) 分别对应第 ( j j j ) 层特征图的通道。
    • 矩阵大小与空间维度 ( H j H_j Hj ) 和 ( W j W_j Wj ) 无关,仅依赖通道数 ( C j C_j Cj )。

2. 为什么用 ( c c c ) 和 ( c ′ c' c )?

Gram 矩阵的核心思想是捕捉特征之间的协方差关系,从而忽略空间结构,专注于风格信息(纹理、颜色分布等)。具体来说:

  • ( c c c ) 和 ( c ′ c' c ) 的组合:通过遍历所有可能的 ( ( c , c ′ ) (c, c') (c,c) ) 对,Gram 矩阵记录了第 ( j j j ) 层所有通道两两之间的相关性。
  • 去空间化:对 ( h h h ) 和 ( w w w ) 求和消除了空间信息,使得 ( G j ϕ ( x ) G_j^\phi(x) Gjϕ(x) ) 只关心特征的全局统计特性,而不是它们在图像中的具体位置。

例如,在 VGG-16 的 relu3_3 层,特征图可能有 ( C j = 256 C_j = 256 Cj=256 ) 个通道。Gram 矩阵 ( G j ϕ ( x ) G_j^\phi(x) Gjϕ(x) ) 就是 ( 256 × 256 256 \times 256 256×256 ) 的矩阵,每个元素 ( G j ϕ ( x ) c , c ′ G_j^\phi(x)_{c, c'} Gjϕ(x)c,c ) 表示通道 ( c c c ) 和 ( c ′ c' c ) 的相关性。


3. 风格损失的计算

风格损失 ( ℓ style ϕ , j ( y ^ , y ) \ell_{\text{style}}^{\phi, j}(\hat{y}, y) styleϕ,j(y^,y) ) 是基于 Gram 矩阵的差异:

ℓ style ϕ , j ( y ^ , y ) = ∥ G j ϕ ( y ^ ) − G j ϕ ( y ) ∥ F 2 \ell_{\text{style}}^{\phi, j}(\hat{y}, y) = \|G_j^\phi(\hat{y}) - G_j^\phi(y)\|_F^2 styleϕ,j(y^,y)=Gjϕ(y^)Gjϕ(y)F2

  • ( G j ϕ ( y ^ ) G_j^\phi(\hat{y}) Gjϕ(y^) ):输出图像 ( y ^ \hat{y} y^ ) 在第 ( j j j ) 层的 Gram 矩阵。
  • ( G j ϕ ( y ) G_j^\phi(y) Gjϕ(y) ):目标风格图像 ( y y y ) 在第 ( j j j ) 层的 Gram 矩阵。
  • ( ∥ ⋅ ∥ F \| \cdot \|_F F ):Frobenius 范数,计算两个矩阵的元素级平方差之和的平方根。

这里,( c c c ) 和 ( c ′ c' c ) 的作用体现在 Gram 矩阵的构造中。风格损失的目标是最小化 ( G j ϕ ( y ^ ) G_j^\phi(\hat{y}) Gjϕ(y^) ) 和 ( G j ϕ ( y ) G_j^\phi(y) Gjϕ(y)) 的差异,确保输出图像 ( y ^ \hat{y} y^ ) 的特征相关性(即风格)与目标图像 ( y y y ) 一致。


4. 举个例子

假设 ( ϕ j ( x ) \phi_j(x) ϕj(x) ) 是形状为 ( ( 4 , 2 , 2 ) (4, 2, 2) (4,2,2) ) 的特征图(( C j = 4 , H j = 2 , W j = 2 C_j = 4, H_j = 2, W_j = 2 Cj=4,Hj=2,Wj=2 )),表示 4 个通道的特征:

通道 0: [[1, 2], [3, 4]]
通道 1: [[0, 1], [1, 0]]
通道 2: [[2, 2], [2, 2]]
通道 3: [[1, 0], [0, 1]]

计算 ( G j ϕ ( x ) 0 , 1 G_j^\phi(x)_{0, 1} Gjϕ(x)0,1 )(通道 0 和通道 1 的相关性):

  • ( ϕ j ( x ) h , w , 0 \phi_j(x)_{h, w, 0} ϕj(x)h,w,0 ) 和 ( ϕ j ( x ) h , w , 1 \phi_j(x)_{h, w, 1} ϕj(x)h,w,1 ) 的值分别是:
    • ( ( 1 , 0 ) , ( 2 , 1 ) , ( 3 , 1 ) , ( 4 , 0 ) (1, 0), (2, 1), (3, 1), (4, 0) (1,0),(2,1),(3,1),(4,0) )
  • ( G j ϕ ( x ) 0 , 1 = 1 4 ⋅ 2 ⋅ 2 ∑ ( 1 ⋅ 0 + 2 ⋅ 1 + 3 ⋅ 1 + 4 ⋅ 0 ) = 1 16 ⋅ ( 0 + 2 + 3 + 0 ) = 5 16 G_j^\phi(x)_{0, 1} = \frac{1}{4 \cdot 2 \cdot 2} \sum (1 \cdot 0 + 2 \cdot 1 + 3 \cdot 1 + 4 \cdot 0) = \frac{1}{16} \cdot (0 + 2 + 3 + 0) = \frac{5}{16} Gjϕ(x)0,1=4221(10+21+31+40)=161(0+2+3+0)=165 )

类似地计算所有 ( ( c , c ′ ) (c, c') (c,c) ) 对,得到 ( 4 × 4 4 \times 4 4×4 ) 的 Gram 矩阵。然后风格损失比较 ( y ^ \hat{y} y^ ) 和 ( y y y ) 的 Gram 矩阵差异。


5. 总结 ( c c c ) 和 ( c ′ c' c ) 的含义

  • ( c c c ) 和 ( c ′ c' c ) 是第 ( j j j ) 层特征图的两个通道索引。
  • ( G j ϕ ( x ) c , c ′ G_j^\phi(x)_{c, c'} Gjϕ(x)c,c ) 表示通道 ( c c c ) 和 ( c ′ c' c ) 的特征在空间上的相关性。
  • 在风格损失中,( c c c ) 和 ( c ′ c' c ) 的作用是通过 Gram 矩阵提取风格特征(如纹理和颜色模式),并通过比较 ( y ^ \hat{y} y^ ) 和 ( y y y ) 的 Gram 矩阵,确保输出图像具有目标图像的风格。

实验代码:Style Transfer with Perceptual Loss

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# 设备选择
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 图像预处理
def load_image(image_path, size=256):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0).to(device)

# 图像后处理
def post_process(tensor):
    tensor = tensor.cpu().clone().squeeze(0)
    tensor = tensor * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    tensor = tensor.clamp(0, 1)
    return transforms.ToPILImage()(tensor)

# 图像转换网络(Transformation Network)
class TransformNet(nn.Module):
    def __init__(self):
        super(TransformNet, self).__init__()
        # 下采样
        self.downsampling = nn.Sequential(
            nn.Conv2d(3, 32, 9, 1, 4),  # 9x9 卷积
            nn.InstanceNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 3, 2, 1, padding_mode='reflect'),  # Stride=2 下采样
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, 2, 1, padding_mode='reflect'),  # Stride=2 下采样
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
        )
        # 残差块
        self.residuals = nn.Sequential(
            *[ResidualBlock(128) for _ in range(5)]  # 5 个残差块
        )
        # 上采样
        self.upsampling = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),  # 上采样
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),  # 上采样
            nn.InstanceNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 3, 9, 1, 4, padding_mode='reflect'),  # 输出 3 通道
            nn.Tanh()
        )

    def forward(self, x):
        x = self.downsampling(x)
        x = self.residuals(x)
        x = self.upsampling(x)
        return x * 150 + 255 / 2  # 论文中提到输出范围调整到 [0, 255]

# 残差块
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.Conv2d(channels, channels, 3, 1, 1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels)
        )
    
    def forward(self, x):
        return x + self.block(x)

# 感知损失网络(VGG-16)
class VGG16Loss(nn.Module):
    def __init__(self):
        super(VGG16Loss, self).__init__()
        vgg = models.vgg16(pretrained=True).features.eval().to(device)
        self.layers = {'3': 'relu1_2', '8': 'relu2_2', '15': 'relu3_3', '22': 'relu4_3'}
        self.model = nn.ModuleDict()
        for name, layer in vgg.named_children():
            self.model[name] = layer
            if name in self.layers:
                break
    
    def forward(self, x):
        features = {}
        for name, layer in self.model.items():
            x = layer(x)
            if name in self.layers:
                features[self.layers[name]] = x
        return features

# Gram 矩阵计算
def gram_matrix(feature):
    b, c, h, w = feature.size()
    feature = feature.view(b * c, h * w)
    gram = torch.mm(feature, feature.t())
    return gram.div(b * c * h * w)

# 训练函数
def train_style_transfer(content_path, style_path, epochs=1000, content_weight=1e5, style_weight=1e10, tv_weight=1e-6):
    # 加载图像
    content_img = load_image(content_path)
    style_img = load_image(style_path)

    # 初始化网络
    transform_net = TransformNet().to(device)
    vgg_loss = VGG16Loss().to(device)

    # 优化器
    optimizer = optim.Adam(transform_net.parameters(), lr=1e-3)

    # 获取风格目标特征
    style_features = vgg_loss(style_img)
    style_grams = {layer: gram_matrix(feat) for layer, feat in style_features.items()}

    # 训练循环
    for epoch in range(epochs):
        transform_net.train()
        optimizer.zero_grad()

        # 生成图像
        output = transform_net(content_img)

        # 计算内容损失
        content_features = vgg_loss(content_img)
        output_features = vgg_loss(output)
        content_loss = content_weight * F.mse_loss(output_features['relu2_2'], content_features['relu2_2'])

        # 计算风格损失
        output_grams = {layer: gram_matrix(feat) for layer, feat in output_features.items()}
        style_loss = 0
        for layer in style_grams:
            style_loss += style_weight * F.mse_loss(output_grams[layer], style_grams[layer])
        style_loss /= len(style_grams)

        # 计算总变差正则化
        tv_loss = tv_weight * (torch.sum(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) +
                               torch.sum(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :])))

        # 总损失
        total_loss = content_loss + style_loss + tv_loss
        total_loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Total Loss: {total_loss.item():.4f}, "
                  f"Content Loss: {content_loss.item():.4f}, Style Loss: {style_loss.item():.4f}")

    # 保存结果
    transform_net.eval()
    with torch.no_grad():
        output = transform_net(content_img)
    result = post_process(output)
    result.save("styled_image.png")
    return result

# 示例运行
content_path = "content.jpg"  # 替换为你的内容图像路径
style_path = "style.jpg"      # 替换为你的风格图像路径
result = train_style_transfer(content_path, style_path, epochs=1000)

# 显示结果
plt.imshow(result)
plt.axis('off')
plt.show()

代码解释

1. 网络架构(TransformNet)
  • 结构:基于论文 PAGE5 的描述,使用了:
    • 下采样:通过 stride=2 的卷积层(3x3 卷积)将输入缩小。
    • 残差块:5 个残差块(每个包含两个 3x3 卷积),参考 He 等人 [43] 的设计。
    • 上采样:通过 stride=1/2 的转置卷积(fractionally-strided convolution)恢复到原始尺寸。
    • 激活:ReLU 和 Tanh(输出层调整到 [0, 255])。
  • 特点:无池化层,使用 InstanceNorm(替代 BatchNorm),反射填充(padding_mode=‘reflect’)以减少边界伪影。
2. 感知损失(VGG16Loss)
  • 损失网络:使用预训练的 VGG-16,提取特定层的特征(relu1_2, relu2_2, relu3_3, relu4_3),如 PAGE9 所述。
  • 内容损失:在 relu2_2 层计算特征重构损失:
    ℓ feat = ∥ ϕ relu2_2 ( y ^ ) − ϕ relu2_2 ( y c ) ∥ 2 2 \ell_{\text{feat}} = \|\phi_{\text{relu2\_2}}(\hat{y}) - \phi_{\text{relu2\_2}}(y_c)\|_2^2 feat=ϕrelu2_2(y^)ϕrelu2_2(yc)22
  • 风格损失:在多层计算 Gram 矩阵差异:
    ℓ style = ∑ j ∥ G j ( y ^ ) − G j ( y s ) ∥ 2 2 \ell_{\text{style}} = \sum_{j} \|G_j(\hat{y}) - G_j(y_s)\|_2^2 style=jGj(y^)Gj(ys)22
  • 总变差正则化:( ℓ T V \ell_{TV} TV ),鼓励输出的空间平滑性。
3. 训练过程
  • 数据集:论文使用 MS-COCO 数据集,这里简化为单张内容图像和风格图像的训练。
  • 优化器:Adam,学习率 1e-3(PAGE9)。
  • 超参数
    • 内容权重 ( λ c = 1 e 5 \lambda_c = 1e5 λc=1e5 )
    • 风格权重 ( λ s = 1 e 10 \lambda_s = 1e10 λs=1e10 )
    • TV 权重 ( λ T V = 1 e − 6 \lambda_{TV} = 1e-6 λTV=1e6 )
    • 这些值可通过交叉验证调整(PAGE9)。
4. 输出
  • 训练后,网络生成风格化图像,保存为 styled_image.png,并显示。

运行说明

  1. 依赖:需要安装 PyTorch、torchvision 和 PIL。
  2. 图像准备:替换 content_pathstyle_path 为本地内容图像和风格图像路径。
  3. 调整
    • 增加 epochs 或使用更大批次数据(如 MS-COCO)以提升效果。
    • 修改超参数(如权重)以适配不同风格。

这个代码是一个简化版,真实实验可能需要更多训练数据和调参,但它完整展示了论文中风格迁移的核心思想。

后记

2025年3月10日19点45分于上海,在Grok 3大模型辅助下完成。

相关文章:

  • ForceMimic:以力为中心的模仿学习,采用力运动捕捉系统进行接触丰富的操作
  • webpack和vite的区别
  • pyspark 数据处理的三种方式RDD、DataFrame、Spark SQL案例
  • 大模型中的微调LoRA是什么
  • 多视图几何--对极几何--从0-1理解对极几何
  • 个人记录的一个插件,Unity-RuntimeMonitor
  • static 用法,函数递归与迭代详解
  • Spring Cloud之远程调用OpenFeign参数传递
  • Unity单例模式更新金币数据
  • CI/CD—Jenkins配置Poll SCM触发自动构建
  • DETR详解
  • 基于SpringBoot实现旅游酒店平台功能六
  • 【C#学习笔记02】基本元素与数据类型
  • mac本地部署Qwq-32b记录
  • 供应链工作效率如何提升
  • Java常见面试技术点整理讲解——后端框架(整理中,未完成)
  • 什么是一致性模型,在实践中如何选择?
  • 程序化广告行业(3/89):深度剖析行业知识与数据处理实践
  • MOM成功实施分享(七)电力电容制造MOM工艺分析与解决方案(第二部分)
  • 菜鸟打印机组件安装后重启显示“Windows 找不到文件‘CNPrintClient,exe‘。请确定文件名是否正确后,再试一次。”的正确解决方案
  • 海南征集民生领域涉嫌垄断违法行为线索,包括行业协会等领域
  • 特朗普与普京开始电话会谈,稍后将致电泽连斯基
  • 招商基金总经理徐勇因任期届满离任,“老将”钟文岳回归接棒
  • 4年间职务侵占、受贿逾亿元,北京高院:严惩民企内部腐败
  • 520、521婚登预约迎高峰?上海民政:将增派力量,新人可现场办理
  • 经济日报:政府采购监管篱笆要扎得更牢