当前位置: 首页 > news >正文

U-Net网络学习笔记(1)

这篇博客不是数据结构或算法相关的,是我在完成导师任务的时候害怕忘记,记录的学习内容,我把重要内容整合,通过博客的形式保存。(我发现之前学习的时候不记真是学完就忘,记在其他地方又不太方便,从我现在的角度看csdn还是太权威了哈哈哈)就是我自己的学习笔记,也无偿分享给大家。如果内容有错误大家可以自行判断,或者在评论区提醒我。内容仅供参考,祝大家学习愉快~

(偷偷夸一下,像csdn这种博客网站,我写一篇博客最多就几万字,还是可以坚持下来。完美解决了我知识记不住、想要写点东西、知识笔记安全以及学习可视化的问题和需求,还可以记录一些感想心得,以及做项目的时候的一些小tips,真是一个非常好的网站!

回想起当时看到B站的博主的建议,让大家写博客,一晃一年时间过去了,365天,多少日夜更替,虽然现在也没有变成什么大拿,但还是能感受到自己在踏踏实实的进步。欲买桂花同载酒,终不似,少年游。我们没法像光一样追上时间,也不会The World,但是我们可以结结实实的成长,搭着时间这班快车,勇往直前。数风流人物,还看今朝。这段话敬自己,也敬诸位;敬未来,更敬现在!)

1、概述

U-Net 是一种基于卷积神经网络(CNN)的架构,最初由 Olaf Ronneberger等人在 2015 年提出,专门用于生物医学图像分割任务。U-Net 的设计灵感来源于经典的全卷积网络(FCN),通过引入跳过连接(skip connections)和对称的编码器-解码器结构,可以显著提升模型在小样本数据集上的性能。目前,U-Net及其变体已经成为许多计算机视觉任务中图像分割的首选方法之一。不仅在医学图像(如肿瘤、器官分割)中广泛使用,也应用于遥感图像、自动驾驶(道路、行人分割)、工业检测等多个领域。

主要优点

什么是图像分割?

图像分割是一种计算机视觉技术,它将数字图像划分为多个部分(像素集合),这些部分被称为片段。每个片段都对应于图像中的一组对象或区域,这有助于进行更高级别的图像分析,比如对象识别、场景理解等。图像分割的结果可以是一组边界框、轮廓线,或者是像素级别的掩膜,具体取决于应用需求和所使用的算法。下图为医学图像经过图像分割后的效果。

2、U型架构

从下图中我们可以看到,U-Net 的架构呈“U”形,由三部分组成:左侧的编码器(Encoder)和右侧的解码器(Decoder),还有中间灰色的跳跃连接(Skip Connections)。

3、蓝色箭头操作

表示 3x3 卷积操作,用于特征提取,旨在捕捉输入数据中的重要特征。

对于上图蓝色箭头的操作有以下解释

如果要保持维度不变,则s(步长)要为1,p(填充)也要为1。

这里对这四个概念展开说明

再对上述内容进行一个拓展,步长也是控制上采样的一个非常重要的参数。我们前面提到了下采样和普通采样(维度不变),所以我干脆也把上采样的过程学习了。

在数学上来说,转置卷积(上采样)是把稀疏矩阵T和输入向量相乘。

我自己手推了最原始的数学过程,就是字有点丑,供大家参考,看不懂就算了不是非学不可这块(不好意思QAQ)

import torch
import torch.nn as nn# 假设输入特征图的通道是 64,我们想通过卷积得到 128 个输出通道
# 卷积核大小是 3x3,填充(padding)为 1 是为了保持输出特征图的尺寸不变
# 步长(stride)为 1 意味着卷积核每次移动一个像素
conv_layer = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=1)# 示例输入:一个批次大小为 1,通道数为 64,尺寸为 256x256 的特征图
# (Batch_Size, Channels, Height, Width)
input_feature_map = torch.randn(1, 64, 256, 256)output_feature_map = conv_layer(input_feature_map)
print(f"3x3 卷积操作的输出尺寸: {output_feature_map.shape}")
# 通常会跟一个激活函数(如 ReLU)
relu_activation = nn.ReLU()
output_feature_map = relu_activation(output_feature_map)


4、灰色箭头操作

表示跳跃连接(skip connection),用于特征融合,确保在解码阶段能够有效地利用编码阶段提取的高分辨率特征。

以顶部连接为例子,64通道568^2维度特征图与128通道392^2维度特征图连接。由于空间维度不一样,要先对568维度做中心剪裁。

接着就是把他们简单的叠放在一起,没有做任何其他的改动。通道数就变成了192。

这里再多说一嘴,后面蓝色箭头的这个卷积双维度降维的过程。和上文提到的增加通道数,减少空间维度的方法不一样,我们只需要控制步长和padding,再选用适当大小的滤波器去卷积,要几个通道就卷积几次,空间维度大小公式如下:

但是这里要在降低空间维度的同时降低通道维度,方法就不一样了。以上面说的数据为例子,对于192个通道,要得到64个通道,就要64个卷积核,每个卷积核,卷积192次在一个位置上,然后把每次的结果通过加权求和的方式,得到新的特征图的对应位置的值,然后一直卷积下去,直到得到一张390^2维度的特征图,再换下一个卷积核。一共卷积64次,得到64个特征图,即64维通道数。

(这里不是很好理解,大家可以多看看,其实看明白了就算是一个交叉计算的数学过程)

# 假设这是来自编码器的特征图(高分辨率,通道数较多)
encoder_feature = torch.randn(1, 128, 128, 128)# 假设这是来自解码器的上采样后的特征图(分辨率恢复,通道数较少)
decoder_upsampled_feature = torch.randn(1, 64, 128, 128)# 使用 torch.cat 在通道维度上进行拼接
# dim=1 表示在通道维度上拼接
concatenated_feature = torch.cat([encoder_feature, decoder_upsampled_feature], dim=1)print(f"跳跃连接(拼接)后的特征尺寸: {concatenated_feature.shape}")
# 拼接后的通道数会是 128 + 64 = 192


5、红色箭头操作

表示池化操作(pooling),用于降低特征图的空间维度,从而减少计算量并提取更具抽象性的特征。

这里的池化操作就相对简单些,在这张 U-Net 示意图中,红色箭头表示的都是 2×2 最大池化(max pooling)操作,stride=2,它只在空间维度上做“下采样”,通道数保持不变。

# 最大池化层,池化窗口大小为 2x2,步长为 2
# 这会将特征图的 H 和 W 都减半
max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)# 假设输入是 128 通道,尺寸 256x256 的特征图
input_feature_map = torch.randn(1, 128, 256, 256)output_feature_map = max_pool_layer(input_feature_map)
print(f"池化操作的输出尺寸: {output_feature_map.shape}") # 尺寸会变为 128x128


6、绿色箭头操作

表示上采样(upsample)操作,用于恢复特征图的空间维度,以便与编码器的特征图进行拼接。

绿色箭头只管放空间不要动通道,通道数的调整全靠后面的卷积(蓝箭头)来完成。具体操作可以看我上面对于上采样的介绍。

# 转置卷积层,用于将特征图的尺寸翻倍
# 假设输入是 256 通道,尺寸 64x64 的特征图,我们想输出 128 通道,尺寸 128x128
# output_padding=1 在 stride=2 的情况下有时用于确保输出尺寸精确匹配
upsample_layer = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)# 假设输入特征图的通道是 256,尺寸为 64x64
input_feature_map = torch.randn(1, 256, 64, 64)output_feature_map = upsample_layer(input_feature_map)
print(f"上采样操作(转置卷积)的输出尺寸: {output_feature_map.shape}") # 尺寸会变为 128x128


7、青色箭头操作

表示 1x1 卷积操作,用于生成最终的输出结果。1×1 卷积是纯粹在通道维度上做线性映射,不会改变图像宽高。

# 假设我们最终的特征图是 64 个通道,我们想将其转换为 2 个输出类别(比如前景和背景)
final_conv_layer = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)# 假设输入是 64 通道,尺寸 256x256 的特征图
input_feature_map = torch.randn(1, 64, 256, 256)output_segmentation_map = final_conv_layer(input_feature_map)
print(f"1x1 卷积(最终输出)的尺寸: {output_segmentation_map.shape}")
# 输出的通道数会是你的类别数

以上就是各个关键操作的解析以及模型的大体架构,其实理解完上述内容,就已经理解了整个模型了,在这里就不再赘述整体操作流程。

7、输出结果处理

这里以二分类为例子说明一下这个过程。

为了运用得到的阈值化二值掩码,通常会采取以下两种策略。

但由于有的时候,直接看掩码(全黑+白块)很不直观,也看不到它在原图上对应什么内容,所以就会使用叠加色的方法。叠加后,就能在原图基础上“披上一层”高亮色,直观地看到哪些部分被分割为前景。一般选用纯红,因为在自然图像中很少出现纯红,所以能很明显地和背景区分开来。(就好像想象给一张照片打了一个“贴纸”——那个贴纸是半透明的红色,只贴在你想强调的区域。看照片时,红色贴纸让你马上注意到哪里是“前景”,同时还能透过贴纸看到下方的原始细节。)

具体流程如下:

通过这些方法,就可以清楚地看到最终的分割出来的图像。

8、代码实现

# u_net_voc_vis.pyimport os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms as T, datasets
import matplotlib.pyplot as plt  # ← 用于可视化# ===== 设备定义 =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ===== U-Net 定义(同前) =====
def double_conv(in_ch, out_ch):return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=0), nn.ReLU(inplace=True),nn.Conv2d(out_ch, out_ch, 3, padding=0), nn.ReLU(inplace=True),)def center_crop(enc_feat, dec_feat):_, _, H, W = dec_feat.size()enc_H, enc_W = enc_feat.size(2), enc_feat.size(3)sh, sw = (enc_H-H)//2, (enc_W-W)//2return enc_feat[:, :, sh:sh+H, sw:sw+W]def crop_and_concat(enc_feat, dec_feat):enc_c = center_crop(enc_feat, dec_feat)return torch.cat([dec_feat, enc_c], dim=1)class UNet(nn.Module):def __init__(self, in_ch=3, n_cls=21):super().__init__()self.enc1, self.pool1 = double_conv(in_ch,64), nn.MaxPool2d(2)self.enc2, self.pool2 = double_conv(64,128), nn.MaxPool2d(2)self.enc3, self.pool3 = double_conv(128,256), nn.MaxPool2d(2)self.enc4, self.pool4 = double_conv(256,512), nn.MaxPool2d(2)self.bottleneck       = double_conv(512,1024)self.up4, self.dec4   = nn.ConvTranspose2d(1024,512,2,2), double_conv(512+512,512)self.up3, self.dec3   = nn.ConvTranspose2d(512,256,2,2), double_conv(256+256,256)self.up2, self.dec2   = nn.ConvTranspose2d(256,128,2,2), double_conv(128+128,128)self.up1, self.dec1   = nn.ConvTranspose2d(128, 64,2,2), double_conv(64+64,64)self.out_conv         = nn.Conv2d(64, n_cls, 1)def forward(self, x):c1, p1 = self.enc1(x), self.pool1(self.enc1(x))c2, p2 = self.enc2(p1), self.pool2(self.enc2(p1))c3, p3 = self.enc3(p2), self.pool3(self.enc3(p2))c4, p4 = self.enc4(p3), self.pool4(self.enc4(p3))bn     = self.bottleneck(p4)u4     = self.up4(bn);   d4 = self.dec4(crop_and_concat(c4, u4))u3     = self.up3(d4);   d3 = self.dec3(crop_and_concat(c3, u3))u2     = self.up2(d3);   d2 = self.dec2(crop_and_concat(c2, u2))u1     = self.up1(d2);   d1 = self.dec1(crop_and_concat(c1, u1))return self.out_conv(d1)# ===== 数据变换与加载 =====
INPUT_SIZE = (256,256)
NUM_CLASSES = 21def voc_transform(img, mask):# 同时 Resize, ToTensorimg = T.Resize(INPUT_SIZE, interpolation=T.InterpolationMode.BILINEAR)(img)mask = T.Resize(INPUT_SIZE, interpolation=T.InterpolationMode.NEAREST)(mask)img = T.ToTensor()(img)mask = torch.as_tensor(np.array(mask), dtype=torch.long)mask[mask==255] = 255return img, maskclass VOCWrapper(torch.utils.data.Dataset):def __init__(self, ds): self.ds=dsdef __len__(self):     return len(self.ds)def __getitem__(self,i):img,mask=self.ds[i]; return voc_transform(img,mask)train_ds = datasets.VOCSegmentation('voc_root',year='2012',image_set='train',download=True)
val_ds   = datasets.VOCSegmentation('voc_root',year='2012',image_set='val',  download=False)
train_loader = DataLoader(VOCWrapper(train_ds),batch_size=8,shuffle=True,num_workers=4)
val_loader   = DataLoader(VOCWrapper(val_ds),  batch_size=1,shuffle=False,num_workers=2)# ===== 加载模型 =====
model = UNet(in_ch=3, n_cls=NUM_CLASSES).to(device)
model.load_state_dict(torch.load('unet_voc.pth'))  # 假设你已训练并保存
model.eval()# ===== 可视化一张验证图 =====
imgs, masks = next(iter(val_loader))     # 取一个 batch(batch_size=1)
imgs = imgs.to(device)
with torch.no_grad():logits = model(imgs)                 # [1,21,256,256]probs  = torch.softmax(logits, dim=1)  # 多分类 softmaxpred   = probs.argmax(dim=1).cpu().squeeze(0).numpy()  # [256,256]# 准备叠加:原图 numpy, 掩码 binary
img_np = imgs.cpu().squeeze(0).permute(1,2,0).numpy()  # HWC, [0,1]
mask_np = (pred>0).astype(np.uint8)                   # 前景(>0)为1,背景(0)为0# 构造红色高亮层
overlay = np.zeros_like(img_np)
overlay[...,0] = 1.0  # 红色通道=1alpha = 0.5
# 将原图与红色 overlay 按 alpha 混合(只在 mask==1 处)
vis = img_np.copy()
vis[mask_np==1] = (1-alpha)*img_np[mask_np==1] + alpha*overlay[mask_np==1]# 显示
plt.figure(figsize=(12,4))
plt.subplot(1,3,1); plt.imshow(img_np);        plt.title('Input Image'); plt.axis('off')
plt.subplot(1,3,2); plt.imshow(pred, cmap='tab20'); plt.title('Pred Mask');  plt.axis('off')
plt.subplot(1,3,3); plt.imshow(vis);           plt.title('Overlay');    plt.axis('off')
plt.tight_layout()
plt.show()

这是基于pytorch的框架来写的,可能不是原版论文里面的代码实现,但功能和结构应该是大差不差。

9、结语

这篇关于u-net的学习笔记就到这里了,因为草稿没法保存太多,如果后面有内容,我再更新下一篇,可能会是关于这个网络更细节的问题,还有代码实现或者实验出现的问题的说明。由于本人比较懒,但有注重美观,所以很多内容都直接截AI的图了,虽然里面有一些废话,但大体内容还是比较清楚的,并且整体逻辑我整理的是一脉相传的。这篇文章我写的蛮快的,一天就写完了,如果有问题欢迎大家评论区指出!

http://www.dtcms.com/a/274618.html

相关文章:

  • ARM单片机OTA解析(二)
  • cesium添加原生MVT矢量瓦片方案
  • 在 Spring Boot 中使用 WebMvcConfigurer
  • 【SpringBoot】配置文件学习
  • linux kernel struct regmap_config结构详解
  • 力扣242.有效的字母异位词
  • MySQL5.7版本出现同步或插入中文出现乱码或???显示问题处理
  • vector之动态二维数组的底层
  • django queryset 去重
  • JavaSE -- StreamAPI 详细介绍(上篇)
  • Java开发新宠!飞算JavaAI深度体验评测
  • 获取华为开源3D引擎 (OpenHarmony),把引擎嵌入VUE中
  • string模拟实现
  • 信号肽预测工具PrediSi本地化
  • 《打破预设的编码逻辑:Ruby元编程的动态方法艺术》
  • 内存踩踏全解析:原理 + 实战案例 + 项目排查技巧
  • 2025十大免费销售管理软件推荐
  • 基于物联网的智能体重秤设计与实现
  • 测试第一定律
  • 如何通过公网IP访问部署在kubernetes中的服务?
  • AVL平衡二叉树
  • 为什么必须掌握Java异常处理机制?——从代码健壮性到面试必考题全解析
  • 阿里云服务器,CentOS7.9上安装YApi 接口管理平台
  • Linux修炼:权限
  • vue2往vue3升级需要注意的点(个人建议非必要别直接升级)
  • 基于规则匹配的文档标题召回
  • Leaflet面试题及答案(21-40)
  • PHT-CAD 笔记
  • 【每日算法】专题八_分治_归并排序
  • k8s新增jupyter服务