UNet and NestedUNet Image Segmentation with PyTorch
Image segmentation is a core computer vision task: it assigns every pixel in an image to a specific class. This post walks through how to implement the classic UNet and its improved variant, NestedUNet, in PyTorch, covering the full pipeline from data preprocessing to model training and evaluation.
Project Overview
This project implements two mainstream image segmentation models:
- the classic UNet model
- NestedUNet (also known as U-Net++)

We use the DSB2018 dataset as the running example and show how to build a complete segmentation system, including data preprocessing, model definition, the training loop, and result evaluation.
Project Structure
First, an overview of the project layout:
```plaintext
.
├── archs.py              # model architectures (UNet and NestedUNet)
├── train.py              # training script
├── val.py                # validation and evaluation script
├── losses.py             # custom loss functions
├── metrics.py            # evaluation metrics
├── dataset.py            # dataset loader
├── utils.py              # helper utilities
└── preprocess_dsb2018.py # data preprocessing script
```
Data Preprocessing
Before training, the raw data needs to be preprocessed. The preprocess_dsb2018.py script takes care of this:
```python
import os
from glob import glob

import cv2
import numpy as np
from tqdm import tqdm


def main():
    img_size = 96  # resize all images to 96x96
    paths = glob('inputs/stage1_train/*')

    # create the output directories
    os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
    os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)

    for i in tqdm(range(len(paths))):
        path = paths[i]

        # read the image
        img = cv2.imread(os.path.join(path, 'images',
                                      os.path.basename(path) + '.png'))

        # merge all instance masks into a single mask
        mask = np.zeros((img.shape[0], img.shape[1]))
        for mask_path in glob(os.path.join(path, 'masks', '*')):
            mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
            mask[mask_] = 1

        # handle images with a different number of channels
        if len(img.shape) == 2:
            img = np.tile(img[..., None], (1, 1, 3))
        if img.shape[2] == 4:
            img = img[..., :3]

        # resize
        img = cv2.resize(img, (img_size, img_size))
        mask = cv2.resize(mask, (img_size, img_size))

        # save the processed image and mask
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size,
                                 os.path.basename(path) + '.png'), img)
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size,
                                 os.path.basename(path) + '.png'),
                    (mask * 255).astype('uint8'))


if __name__ == '__main__':
    main()
```
The preprocessing does the following:
- resizes all images to a uniform 96x96
- merges the per-instance mask files into a single mask
- normalizes images with differing channel counts to 3 channels
- organizes everything into a standard dataset directory layout

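Before moving on, it is worth spot-checking a few preprocessed image/mask pairs. A minimal sketch, assuming the default 96x96 output directories created above already exist:

```python
# Sanity check: show one preprocessed image next to its merged mask.
# Assumes preprocess_dsb2018.py has already been run with img_size = 96.
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt

img_path = sorted(glob('inputs/dsb2018_96/images/*.png'))[0]
img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
mask = cv2.imread(os.path.join('inputs/dsb2018_96/masks/0',
                               os.path.basename(img_path)),
                  cv2.IMREAD_GRAYSCALE)

fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(8, 4))
ax0.imshow(img)
ax0.set_title('image')
ax1.imshow(mask, cmap='gray')
ax1.set_title('mask')
plt.show()
```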
Dataset Loader
dataset.py implements a custom dataset class for loading and preprocessing the image data:
```python
import os

import cv2
import numpy as np
import torch
import torch.utils.data


class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext,
                 num_classes, transform=None):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]

        # read the image
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))

        # read one mask per class
        mask = []
        for i in range(self.num_classes):
            mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
                        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
        mask = np.dstack(mask)

        # apply data augmentation
        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # normalize and reorder channels
        img = img.astype('float32') / 255
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        mask = mask.astype('float32') / 255
        mask = mask.transpose(2, 0, 1)

        return img, mask, {'img_id': img_id}
```
This dataset class supports:
- loading multi-class masks
- applying data augmentation (via the albumentations library)
- automatic image normalization and channel reordering

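As a usage sketch (the augmentation pipeline here is an arbitrary example, not the one train.py uses), the class plugs straight into a PyTorch DataLoader:

```python
# Hypothetical wiring of the Dataset class into a DataLoader.
import os
from glob import glob

import albumentations as albu
import torch

img_ids = [os.path.splitext(os.path.basename(p))[0]
           for p in glob('inputs/dsb2018_96/images/*.png')]

dataset = Dataset(img_ids=img_ids,
                  img_dir='inputs/dsb2018_96/images',
                  mask_dir='inputs/dsb2018_96/masks',
                  img_ext='.png', mask_ext='.png',
                  num_classes=1,
                  transform=albu.Compose([albu.RandomRotate90(),
                                          albu.Resize(96, 96)]))

loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
img, mask, meta = next(iter(loader))
print(img.shape, mask.shape)  # torch.Size([8, 3, 96, 96]) torch.Size([8, 1, 96, 96])
```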
Model Architecture
The archs.py file defines both the UNet and NestedUNet architectures.
The VGGBlock Component
Both models use VGGBlock as their basic building block:
```python
import torch
import torch.nn as nn


class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out
```
Each VGGBlock consists of two convolutional layers, each followed by batch normalization and a ReLU activation.
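Because both 3x3 convolutions use padding=1, a VGGBlock preserves spatial resolution and only changes the channel count, which a quick shape check confirms:

```python
# Shape check: VGGBlock changes channels, not spatial size.
import torch

block = VGGBlock(in_channels=3, middle_channels=32, out_channels=32)
x = torch.randn(1, 3, 96, 96)
print(block(x).shape)  # torch.Size([1, 32, 96, 96])
```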
The UNet Model
The UNet model consists of an encoder, a decoder, and skip connections:
```python
class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]  # number of filters at each level

        self.pool = nn.MaxPool2d(2, 2)  # downsampling
        self.up = nn.Upsample(scale_factor=2, mode='bilinear',
                              align_corners=True)  # upsampling

        # encoder
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # decoder (with skip connections)
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

        # final convolution producing one channel per class
        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # encoder forward pass
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        # decoder forward pass (with skip connections)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

        output = self.final(x0_4)
        return output
```
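A minimal smoke test of the model (note that the input side length must be divisible by 16, since the encoder applies four 2x2 poolings):

```python
# Smoke test: UNet returns a per-pixel map at the input resolution.
import torch

model = UNet(num_classes=1, input_channels=3)
x = torch.randn(2, 3, 96, 96)  # 96 is divisible by 16
print(model(x).shape)  # torch.Size([2, 1, 96, 96])
```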
The NestedUNet Model
NestedUNet (U-Net++) is an improved version of UNet that introduces additional nested skip connections for richer feature fusion:
```python
class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision  # whether to use deep supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # encoder
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # nested decoder
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        # output heads for deep supervision
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # encoder and nested skip pathways
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        # return one or four outputs depending on deep supervision
        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output
```
NestedUNet's main improvement is its dense set of nested skip connections, which lets low-level features flow more directly to the higher decoder levels. It also supports deep supervision: outputs are produced at several levels and optimized jointly, which helps the model converge faster.
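The effect of the deep_supervision flag is easy to verify: with it enabled, the forward pass returns a list of four full-resolution outputs (from x0_1 through x0_4) instead of a single tensor:

```python
# With deep supervision, NestedUNet returns four outputs, one per nesting depth.
import torch

model = NestedUNet(num_classes=1, input_channels=3, deep_supervision=True)
outputs = model(torch.randn(2, 3, 96, 96))
print(len(outputs))                       # 4
print([tuple(o.shape) for o in outputs])  # four times (2, 1, 96, 96)
```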
Loss Functions
losses.py implements loss functions suited to image segmentation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # BCE loss
        bce = F.binary_cross_entropy_with_logits(input, target)

        # Dice loss
        smooth = 1e-5
        input = torch.sigmoid(input)
        num = target.size(0)
        input = input.view(num, -1)
        target = target.view(num, -1)
        intersection = (input * target)
        dice = (2. * intersection.sum(1) + smooth) / \
            (input.sum(1) + target.sum(1) + smooth)
        dice = 1 - dice.sum() / num

        # combined loss
        return 0.5 * bce + dice


class LovaszHingeLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        input = input.squeeze(1)
        target = target.squeeze(1)
        # Lovasz hinge loss; requires a separate lovasz_hinge implementation
        # (e.g. from the LovaszSoftmax package) to be installed and imported
        loss = lovasz_hinge(input, target, per_image=True)
        return loss
```
BCEDiceLoss combines BCE loss and Dice loss and performs well in medical image segmentation:
- BCE loss provides stable per-pixel gradients
- Dice loss directly optimizes the overlap with the foreground region, which makes the combination robust to foreground/background class imbalance

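A quick numerical check on random tensors; note that the loss expects raw logits, since the BCE term uses binary_cross_entropy_with_logits and the Dice term applies its own sigmoid:

```python
# BCEDiceLoss takes raw logits, not probabilities.
import torch

criterion = BCEDiceLoss()
logits = torch.randn(4, 1, 96, 96)                 # raw model outputs
target = (torch.rand(4, 1, 96, 96) > 0.5).float()  # binary ground truth
print(criterion(logits, target).item())  # scalar: 0.5 * BCE + Dice
```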
Evaluation Metrics
metrics.py implements the evaluation metrics commonly used for segmentation:
```python
import numpy as np
import torch
import torch.nn.functional as F


def iou_score(output, target):
    """Compute Intersection over Union (IoU)."""
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()

    # binarize predictions and targets
    output_ = output > 0.5
    target_ = target > 0.5

    # intersection and union
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()

    return (intersection + smooth) / (union + smooth)


def dice_coef(output, target):
    """Compute the Dice coefficient."""
    smooth = 1e-5

    output = torch.sigmoid(output).view(-1).data.cpu().numpy()
    target = target.view(-1).data.cpu().numpy()
    intersection = (output * target).sum()

    return (2. * intersection + smooth) / \
        (output.sum() + target.sum() + smooth)
```
IoU (Intersection over Union) is the most widely used metric in semantic segmentation: the ratio of the intersection to the union of the predicted and ground-truth regions.
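A tiny worked example makes the definition concrete: three predicted foreground pixels and three true ones that overlap in two places give an intersection of 2 and a union of 4, so the IoU is 0.5:

```python
# Worked example: intersection 2, union 4 -> IoU = 0.5.
import numpy as np

output = np.array([0.9, 0.8, 0.7, 0.1])  # probabilities, binarized at 0.5
target = np.array([0.0, 1.0, 1.0, 1.0])
print(iou_score(output, target))  # ~0.5
```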
Training Script
train.py implements the full model training pipeline.
Argument Parsing
It first defines the configurable training arguments:
```python
def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default="dsb2018_96_NestedUNet_woDS",
                        help='model name: (default: arch+timestamp)')
    parser.add_argument('--epochs', default=100, type=int,
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        help='mini-batch size (default: 8)')

    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                        choices=ARCH_NAMES, help='model architecture')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=96, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=96, type=int,
                        help='image height')

    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES, help='loss function')

    # dataset
    parser.add_argument('--dataset', default='dsb2018_96',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # optimizer
    parser.add_argument('--optimizer', default='SGD',
                        choices=['Adam', 'SGD'], help='optimizer')
    parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')

    # learning rate scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau',
                                 'MultiStepLR', 'ConstantLR'])
    # ... other arguments

    return parser.parse_args()
```
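The str2bool used for --deep_supervision is not part of argparse; it comes from the project's utils.py, which is not shown here. A common implementation looks like this (a sketch; details may differ from the actual file):

```python
# Typical str2bool helper for argparse flags (assumed to live in utils.py).
import argparse


def str2bool(v):
    if v.lower() in ['true', '1']:
        return True
    elif v.lower() in ['false', '0']:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
```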
Training and Validation Functions
```python
def train(config, train_loader, model, criterion, optimizer):
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}

    model.train()  # switch to training mode

    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.cuda()
        target = target.cuda()

        # forward pass
        if config['deep_supervision']:
            outputs = model(input)
            loss = 0
            # deep supervision: average the loss over all outputs
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)

        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update metrics
        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        pbar.set_postfix(loss=avg_meters['loss'].avg, iou=avg_meters['iou'].avg)
        pbar.update(1)
    pbar.close()

    return {'loss': avg_meters['loss'].avg, 'iou': avg_meters['iou'].avg}


def validate(config, val_loader, model, criterion):
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}

    model.eval()  # switch to evaluation mode

    with torch.no_grad():  # disable gradient computation
        pbar = tqdm(total=len(val_loader))
        for input, target, _ in val_loader:
            input = input.cuda()
            target = target.cuda()

            # forward pass
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)

            # update metrics
            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))

            pbar.set_postfix(loss=avg_meters['loss'].avg, iou=avg_meters['iou'].avg)
            pbar.update(1)
        pbar.close()

    return {'loss': avg_meters['loss'].avg, 'iou': avg_meters['iou'].avg}
```
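Both functions accumulate running averages with an AverageMeter from utils.py. That file is not shown here, but a minimal sketch of what such a meter needs to provide (an update method and an avg attribute) is:

```python
# Minimal running-average helper (a sketch of utils.py's AverageMeter).
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```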
Main Function
```python
def main():
    config = vars(parse_args())

    # create the output directory
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    # save the configuration
    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define the loss function
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    # enable cudnn autotuning
    cudnn.benchmark = True

    # create the model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.cuda()

    # define the optimizer
    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(params, lr=config['lr'],
                               weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'],
                              momentum=config['momentum'],
                              nesterov=config['nesterov'],
                              weight_decay=config['weight_decay'])

    # define the learning rate scheduler
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=config['factor'], patience=config['patience'],
            verbose=1, min_lr=config['min_lr'])
    # ... other schedulers

    # load the data
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images',
                                '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2,
                                                  random_state=41)

    # data augmentation
    train_transform = Compose([
        albu.RandomRotate90(),
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ], p=1),
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    # create the data loaders
    train_dataset = Dataset(...)
    val_dataset = Dataset(...)
    train_loader = torch.utils.data.DataLoader(...)
    val_loader = torch.utils.data.DataLoader(...)

    # training loop
    log = {'epoch': [], 'lr': [], 'loss': [], 'iou': [],
           'val_loss': [], 'val_iou': []}
    best_iou = 0
    trigger = 0
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # validate
        val_log = validate(config, val_loader, model, criterion)

        # update the learning rate
        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        # print the log
        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'],
                 val_log['loss'], val_log['iou']))

        # save the log
        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'],
                                 index=False)

        trigger += 1  # epochs since the last improvement (needed for early stopping)

        # save the best model
        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(),
                       'models/%s/model.pth' % config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()
```
Validation and Visualization
val.py loads a trained model, evaluates it, and visualizes the segmentation results:
```python
def main():
    args = parse_args()

    # load the configuration
    with open('models/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # create the model
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.cuda()

    # load the trained weights
    model.load_state_dict(torch.load('models/%s/model.pth' % config['name']))
    model.eval()

    # prepare the data
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images',
                                '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    # load the validation set
    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    val_dataset = Dataset(...)
    val_loader = torch.utils.data.DataLoader(...)

    # evaluate and save the results
    avg_meter = AverageMeter()

    for c in range(config['num_classes']):
        os.makedirs(os.path.join('outputs', config['name'], str(c)),
                    exist_ok=True)

    with torch.no_grad():
        for input, target, meta in tqdm(val_loader, total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # run the model
            if config['deep_supervision']:
                output = model(input)[-1]
            else:
                output = model(input)

            # compute IoU
            iou = iou_score(output, target)
            avg_meter.update(iou, input.size(0))

            # save the predictions
            output = torch.sigmoid(output).cpu().numpy()
            for i in range(len(output)):
                for c in range(config['num_classes']):
                    cv2.imwrite(os.path.join('outputs', config['name'], str(c),
                                             meta['img_id'][i] + '.jpg'),
                                (output[i, c] * 255).astype('uint8'))

    print('IoU: %.4f' % avg_meter.avg)

    # visualize some results
    plot_examples(input, target, model, num_examples=3)
```
The visualization function:
```python
def plot_examples(datax, datay, model, num_examples=6):
    fig, ax = plt.subplots(nrows=num_examples, ncols=3,
                           figsize=(18, 4 * num_examples))
    m = datax.shape[0]
    for row_num in range(num_examples):
        image_indx = np.random.randint(m)
        # get the model prediction; sigmoid converts the raw logits to
        # probabilities so the 0.40 threshold below is meaningful
        image_arr = torch.sigmoid(
            model(datax[image_indx:image_indx + 1])).squeeze(0).detach().cpu().numpy()
        # input image
        ax[row_num][0].imshow(np.transpose(
            datax[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row_num][0].set_title("Original Image")
        # predicted segmentation
        ax[row_num][1].imshow(np.squeeze(
            (image_arr > 0.40)[0, :, :].astype(int)))
        ax[row_num][1].set_title("Segmented Image")
        # ground-truth mask
        ax[row_num][2].imshow(np.transpose(
            datay[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row_num][2].set_title("Target Mask")
    plt.show()
```
Training and Usage Guide
Data preparation:
```bash
python preprocess_dsb2018.py
```
Model training:
```bash
python train.py --dataset dsb2018_96 --arch NestedUNet --epochs 100 --batch_size 8
```
Model validation:
```bash
python val.py --name dsb2018_96_NestedUNet_woDS
```
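To train NestedUNet with deep supervision enabled, pass the corresponding flag (this command is an assumption based on the --deep_supervision argument defined in parse_args; the run name then usually drops the woDS suffix):
```bash
python train.py --dataset dsb2018_96 --arch NestedUNet --deep_supervision True
```
With deep supervision on, train.py averages the loss over the four outputs and reports IoU on the last one, as shown in the training function above.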
Summary
This post walked through PyTorch implementations of the UNet and NestedUNet segmentation models. Working through the project teaches:
- how to build the classic UNet model and its improved variant, NestedUNet
- how to design loss functions suited to segmentation (such as BCEDiceLoss)
- how to implement a complete training pipeline, including data loading, augmentation, training, and validation
- how to evaluate segmentation models (with metrics such as IoU)

The project can serve as a base framework for segmentation tasks: swap out the dataset loading code and adjust the model arguments, and it applies to other segmentation problems. Thanks to its nested connections and deep supervision, NestedUNet usually segments better than plain UNet, but it also costs more compute, so choose whichever model fits your requirements.
