当前位置: 首页 > news >正文

Pytorch实现之结合SE注意力和多种损失的特征金字塔架构GAN的图像去模糊方法

简介

简介:提出了一种利用特征金字塔作为框架代替多尺度输入的一种方法来构建生成器模型,减少了模型规模并加快了训练速度。在模型架构中还融合了通道注意力方法来提高训练能力。作者在生成器中采用了三种常见的损失计算,在鉴别器中结合了最小二乘和相对论损失来改善模型训练。

论文题目:Image Deblurring Based on Generative Adversarial Networks(基于生成对抗网络的图像去模糊)

会议:International Conference on Intelligent Computing and Signal Processing (ICSP)

摘要:图像去模糊技术利用深度学习方法解决单幅图像的模糊问题,这是计算机视觉领域的一个具有挑战性的问题。 近年来,深度学习和计算机视觉的快速发展,提高了模糊处理算法的性能。 本文从深度学习的角度研究图像去模糊问题,利用卷积神经网络实现图像去模糊的目的。 针对多尺度网络单次去模糊处理规模庞大,重要特征信息未得到充分利用的问题,提出了一种基于生成对抗网络的去模糊算法。 该模型采用特征金字塔网络作为框架代替多尺度输入,有效地减小了网络规模,加快了训练速度。 为了更好地利用特征信息,在网络中引入了注意机制和双尺度判别器。 为了使训练过程更加稳定,该算法采用最小二乘和相对论相结合的方法改善了鉴别器的损失。 实验结果表明,基于生成对抗网络的图像去模糊算法比其他算法具有更好的恢复效果。

模型结构

生成器架构

生成器设计介绍

作者提到,目前,在现有的图像去模糊任务中,骨干网通常使用类似ResNet的网络。 大多数处理不同程度模糊图像的先进方法都使用多尺度输入方法来消除模糊。

然而,多尺度模式下的输入法往往会消耗更多的时间和占用大量的内存,因此该模型中的生成器使用特征金字塔网络而不是多尺度网络。设计的特征金字塔网络结构是一种编解码的形式,它包括两条路径。从浅到深的路径可以看作是编码部分,主要用于提取输入图像的特征。分辨率降低了,但它可以提取高级特征并压缩更多的上下文语义信息。从深到浅的路径可以看作是解码部分,通过上采样恢复空间分辨率,并结合高级特征和丰富的语义信息生成清晰的图像。此外,两条路径之间的水平连接补充了高分辨率的细节,有助于恢复更清晰的图像。

在论文的提议中,网络模型通过激活权值的方式,在生成器中加入关注模块,加强对重要特征的关注。 当卷积层数较浅时,加入注意力模块会使计算量过大,而卷积层数较深时,特征图之间的差异较小。  

 

class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ConvBlock, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 4, 2, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.model(x)
        return x

class ConvBlock_1(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ConvBlock_1, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.model(x)
        return x

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.ConvBlock1 = ConvBlock(3, 32)
        self.ConvBlock2 = ConvBlock(32, 32)
        self.ConvBlock3 = ConvBlock(32, 32)
        self.ConvBlock4 = ConvBlock(32, 32)
        self.ConvBlock5 = ConvBlock(32, 32)
        self.ConvBlock6 = ConvBlock_1(32, 3)
        self.ConvBlock1_1 = ConvBlock_1(32, 32)
        self.con1_1 = nn.Conv2d(32, 32, 1)
        self.SE = SE(32, 8)
        self.Up = nn.Upsample(scale_factor=2)
        self.Up4 = nn.Upsample(scale_factor=8)
        self.Up3 = nn

相关文章:

  • CLIP学习笔记
  • 安全运维,等保测试常见解决问题。
  • 智慧校园系统在学生学习与生活中的应用
  • RK Android11 WiFi模组 AIC8800 驱动移植调试记录
  • 力扣-回溯-37 解数独
  • JavaScript异步编程方式多,区别是什么?
  • 有时候通过无线上网,有线共享局域网通过该有线为网关进行上网,设置指定的网关IP信息
  • UE5 编辑器辅助/加强 插件搜集
  • C#使用Semantic Kernel:接入本地deepseek-r1
  • 【多模态处理篇五】【DeepSeek文档解析:PDF/Word智能处理引擎】
  • C#初级教程(6)——函数:从基础到实践
  • 后端之路——阿里云OSS云存储
  • 【JavaScript进阶】构造函数数据常用函数
  • 【AI】openEuler 22.03 LTS SP4安装 docker NVIDIA Container Toolkit
  • Java集合框架全解析:从LinkedHashMap到TreeMap与HashSet面试题实战
  • 微信小程序修改个人信息头像(uniapp开发)
  • 机器学习实战(11):时间序列预测——循环神经网络(RNN)与 LSTM
  • NVIDIA A100 SXM4与NVIDIA A100 PCIe版本区别深度对比:架构、性能与场景解析
  • einops测试
  • C#导出dataGridView数据
  • 习近平同俄罗斯总统普京会谈
  • 陈丹燕:赤龙含珠
  • 视频丨习近平主席专机抵达莫斯科,俄战机升空护航
  • A股三大股指集体高开大涨超1%,券商、房地产涨幅居前
  • 短剧剧组在贵州拍戏突遇极端天气,演员背部、手臂被冰雹砸伤
  • 竞彩湃|巴萨客场淘汰国际米兰,巴黎双杀阿森纳