当前位置: 首页 > news >正文

unet结构, 为什么要下采样, 上采样?

U-Net 结构中,下采样(downsampling)和上采样(upsampling)是核心部分,它们分别帮助网络从输入图像中提取特征并最终恢复输出图像的空间分辨率。具体来说,U-Net 的结构包括一个 编码器(下采样)部分和一个 解码器(上采样)部分,这两部分共同作用,实现图像的分割或重建任务。

1. 下采样(Downsampling)

下采样通常是通过 卷积层池化层(如 MaxPooling)实现的。在 U-Net 中,编码器负责提取图像的低级和高级特征。下采样有以下几个目的:

特征提取:下采样通过卷积操作在较小的空间范围内提取图像的特征(如边缘、纹理、形状等)。随着网络层数的增加,网络会逐渐捕捉到更复杂的特征。

减少空间维度:每进行一次下采样,空间尺寸会减小,但通道数通常会增加。这使得模型能够以更高的抽象级别处理图像,而无需保留过多的细节。

捕捉全局信息:通过下采样,网络能够在更大的感受野内捕获全局信息。这对于分割等任务非常重要,因为要决定像素或区域是否属于某个类别,通常需要上下文信息。

2. 上采样(Upsampling)

上采样通常是通过 转置卷积(Deconvolution)上采样层(如 上采样 + 卷积)来实现的。U-Net 的解码器部分负责将低分辨率的特征图恢复成与原始图像相同的空间分辨率。上采样的目的包括:

恢复空间分辨率:经过下采样后,图像的空间分辨率降低,特征图的尺寸变小。上采样的目的是逐渐恢复这些低分辨率特征图的空间尺寸,直到恢复成原始图像的尺寸。

精细化预测:上采样通过将高层的抽象特征与低层的细节特征结合,帮助模型恢复图像的精细结构。在 U-Net 中,解码器的每一层通常会与编码器的对应层进行 跳跃连接(skip connection),将低层的特征与高层的特征融合,以便保留细节信息(如边缘、纹理等)。

细化分割边界:上采样在解码器中不仅是恢复分辨率,还通过跳跃连接保留了细节信息,有助于提高分割任务的精确度,特别是在分割边界或小区域时。

3. 为什么要下采样和上采样:

下采样 使得网络可以提取图像的高级抽象特征,同时通过减少空间维度减小计算量。

上采样 使得网络能够恢复图像的空间细节和结构,同时将高级特征与低级特征结合,进行精细化的输出。

U-Net 中,通过这种对称结构的 下采样和上采样,网络能够同时处理全局信息和局部细节,从而获得更好的分割效果。

4. U-Net 结构概述:

U-Net 结构通常包含以下几个关键部分:

编码器(下采样部分):由多个卷积层和池化层组成,逐渐降低图像的空间分辨率,同时增加通道数,提取特征。

解码器(上采样部分):通过上采样操作逐步恢复空间分辨率,同时结合编码器中对应层的特征。

跳跃连接(Skip Connections):将编码器和解码器中对应层的特征图进行连接,帮助恢复图像的细节。

5. U-Net 的下采样和上采样举例:

假设输入图像的大小为 256✖️256,那么:

• 在编码器中,经过多次卷积和池化(如 2x2 max pooling),空间分辨率可能降到 32 \times 32,但特征图的通道数会显著增加。

• 在解码器中,利用转置卷积或上采样操作将特征图的尺寸恢复到原始的 256 \times 256,同时通过跳跃连接引入高分辨率的特征,以便精确恢复细节。

总结:

下采样:通过减少空间分辨率提取高级特征,减少计算量并捕捉全局信息。

上采样:逐步恢复图像的空间分辨率,同时通过跳跃连接融合低层细节,帮助提高分割精度。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义U-Net的基本卷积块(双层卷积 + ReLU)
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

# U-Net 编码(下采样)部分
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        # 编码部分(下采样)
        self.encoder1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 尺寸缩小一半
        self.encoder2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck 层
        self.bottleneck = DoubleConv(512, 1024)

        # 解码部分(上采样)
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)  # 上采样
        self.decoder4 = DoubleConv(1024, 512)  # 连接跳跃层
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(128, 64)

        # 最后的 1x1 卷积层用于输出
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # 编码路径(下采样)
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool4(enc4))

        # 解码路径(上采样 + 跳跃连接)
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((enc4, dec4), dim=1)  # 跳跃连接
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((enc3, dec3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((enc2, dec2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((enc1, dec1), dim=1)
        dec1 = self.decoder1(dec1)

        # 最终输出
        output = self.final_conv(dec1)

        return output

# 实例化模型
model = UNet(in_channels=3, out_channels=1)
print(model)

http://www.dtcms.com/a/109653.html

相关文章:

  • Docker安装开源项目x-ui详细图文教程
  • 【一步步开发AI运动APP】六、运动计时计数能调用
  • 天津大学合成生物技术全国重点实验室-随笔09
  • USB(通用串行总线)数据传输机制和包结构简介
  • 【蓝桥杯】算法笔记2
  • 怎么让一台云IPPBX实现多家酒店相同分机号码一起使用
  • LJF-Framework 第13章 LjfAsyncManager异步任务管理
  • keep-alive缓存
  • [dp5_多状态dp] 按摩师 | 打家劫舍 II | 删除并获得点数 | 粉刷房子
  • HTTP数据传输的几个关键字Header
  • 《操作系统真象还原》第五章(1)——获取内存容量
  • Leetcode 1262 -- 动态规划
  • #window系统php-v提示错误#
  • 一周学会Pandas2 Python数据处理与分析-Pandas2简介
  • Node.js 与 MySQL:深入理解与高效实践
  • VisMin:视觉最小变化理解
  • 强化学习_Paper_1988_Learning to predict by the methods of temporal differences
  • 【Pandas】pandas DataFrame values
  • MacOS中配置完环境变量后执行source ~/.bash_profile后,只能在当前shell窗口中生效
  • 【eNSP实验】RIP协议
  • WHAT - JWT(JSON Web Token)
  • 颜色归一化操作
  • 设计心得——状态机
  • STM32单片机入门学习——第12节: [5-2]对射式红外传感器计次旋转编码器计次
  • 多模态学习(八):2022 TPAMI——U2Fusion: A Unified Unsupervised Image Fusion Network
  • MySQL数据库脱敏实战指南:从原理到企业级实现
  • torch.nn中的非线性激活介绍合集——Pytorch中的非线性激活
  • Webacy 利用 Walrus 技术构建链上风险分析决策层
  • 软考又将迎来新的改革?
  • c#和c++脚本解释器科学运算