
UNet++

This is a PyTorch-based medical image segmentation project. Its core goal is binary image segmentation (e.g., lesion detection or cell segmentation), built on the UNet family of architectures (NestedUNet by default), with a complete training pipeline and performance evaluation.

Core Features

  1. Network architectures: multiple UNet variants are supported (extensible via the archs module); NestedUNet is the default, and an optional deep-supervision mode can improve segmentation accuracy.
  2. Data handling: efficient augmentation via Albumentations (rotation, flipping, brightness/contrast adjustment, etc.), an automatic 8:2 train/validation split, and support for custom dataset layouts.
  3. Training configuration: flexible choices of batch size, learning rate, optimizer (Adam/SGD), learning-rate scheduler (cosine annealing, ReduceLROnPlateau, etc.), and loss function (the custom BCEDiceLoss or the built-in BCEWithLogitsLoss).
  4. Evaluation: IoU (intersection over union) is the core metric; training/validation loss and IoU are monitored live, the best model is saved automatically, and early stopping guards against overfitting (a sketch of the loss and metric follows this list).
  5. Compatibility: trains on CPU or GPU automatically and runs on both Windows and Linux (the number of data-loader workers is configurable).
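The losses and metrics modules are not reproduced in this post. A minimal sketch of what they plausibly contain, consistent with how train.py calls BCEDiceLoss and iou_score (a hypothetical reconstruction, not necessarily the project's exact code):

# losses.py (sketch): binary cross-entropy plus a soft Dice term
import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEDiceLoss(nn.Module):
    def forward(self, input, target):
        # BCE on the raw logits
        bce = F.binary_cross_entropy_with_logits(input, target)
        # soft Dice on the sigmoid probabilities
        smooth = 1e-5
        prob = torch.sigmoid(input)
        num = target.size(0)
        prob = prob.view(num, -1)
        target = target.view(num, -1)
        intersection = (prob * target).sum(1)
        dice = (2. * intersection + smooth) / (prob.sum(1) + target.sum(1) + smooth)
        return 0.5 * bce + (1 - dice.mean())


# metrics.py (sketch): IoU after thresholding probabilities at 0.5
def iou_score(output, target):
    smooth = 1e-5
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output_ = output > 0.5
    target_ = target > 0.5
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()
    return (intersection + smooth) / (union + smooth)

Combining BCE with a Dice term is a common choice for binary segmentation because the Dice component is less sensitive to foreground/background class imbalance than BCE alone.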

Project Structure and Dependencies

  • Core dependencies: PyTorch, Albumentations, Pandas, Scikit-learn, TQDM, PyYAML.
  • Custom modules: archs (network architecture definitions), losses (custom loss functions), dataset (dataset loading), metrics (IoU computation), utils (helper functions).
  • Dataset layout: place images (input images) and masks (segmentation masks) under inputs/[dataset name]; each image must have a mask with the same filename (an example layout follows this list).
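For the default dsb2018_96 dataset, the expected layout would look like this (the filenames here are placeholders):

inputs/
└── dsb2018_96/
    ├── images/
    │   ├── 0001.png
    │   └── 0002.png
    └── masks/
        ├── 0001.png
        └── 0002.png

The dataset module is also not shown in the post. A minimal sketch that matches how train.py constructs Dataset(...) and how the training loop unpacks (input, target, _); this is hypothetical, and the project's actual dataset.py may differ, for example in how masks are stored or scaled:

# dataset.py (sketch): pairs each image with the same-named mask
import os

import cv2
import torch


class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext,
                 num_classes, transform=None):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
        mask = cv2.imread(os.path.join(self.mask_dir, img_id + self.mask_ext),
                          cv2.IMREAD_GRAYSCALE)[..., None]
        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']
        # HWC -> CHW; assumes masks are stored as 0/255 and rescales to {0, 1}
        img = img.astype('float32').transpose(2, 0, 1)
        mask = mask.astype('float32').transpose(2, 0, 1) / 255.0
        # the train/validate loops unpack (input, target, meta)
        return img, mask, {'img_id': img_id}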

Training Pipeline

  1. Parse command-line arguments and assemble the training configuration (model, data, optimizer, and so on).
  2. Initialize the model, loss function, optimizer, and learning-rate scheduler, and move them to the training device (CPU/GPU).
  3. Load the dataset, apply augmentation, and build the training/validation data loaders.
  4. Iterate: each training epoch computes the loss and updates parameters by backpropagation; each validation pass evaluates performance and records the metrics.
  5. Keep the best model as training progresses (by validation IoU), optionally stop early, and persist the training log and configuration to disk (typical invocations are shown after this list).
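Based on the argument hints in train.py's docstring, a typical run looks like this; each run writes config.yml, log.csv, and the best model.pth to models/<name>/:

python train.py --dataset dsb2018_96 --arch NestedUNet
python train.py --dataset dsb2018_96 --arch NestedUNet --deep_supervision True --name dsb2018_96_NestedUNet_wDS

Note that --name must be set explicitly for the deep-supervision run: the script only derives a dataset_arch_wDS/_woDS name when --name is None, and the default here is the fixed string dsb2018_96_NestedUNet_woDS.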

Strengths

  • Flexibility: architectures, loss functions, and schedulers can all be swapped to fit different segmentation scenarios.
  • Engineering discipline: saved configuration files, logging, and progress-bar visualization make experiments reproducible and easy to compare.
  • Performance: GPU acceleration, deep supervision, and data augmentation improve segmentation accuracy and generalization.

Core code: train.py

import argparse
import os
from collections import OrderedDict
from glob import glob

import albumentations as albu
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import yaml
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from tqdm import tqdm

import archs
import losses
from dataset import Dataset
from metrics import iou_score
from utils import AverageMeter, str2bool

ARCH_NAMES = archs.__all__
LOSS_NAMES = losses.__all__
LOSS_NAMES.append('BCEWithLogitsLoss')

"""
Example arguments:
--dataset dsb2018_96
--arch NestedUNet
"""


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default="dsb2018_96_NestedUNet_woDS",
                        help='model name (default: dataset_arch[_wDS/_woDS])')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 8)')

    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                        choices=ARCH_NAMES,
                        help='model architecture: ' +
                             ' | '.join(ARCH_NAMES) +
                             ' (default: NestedUNet)')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=96, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=96, type=int,
                        help='image height')

    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                             ' | '.join(LOSS_NAMES) +
                             ' (default: BCEDiceLoss)')

    # dataset
    parser.add_argument('--dataset', default='dsb2018_96',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # optimizer
    parser.add_argument('--optimizer', default='SGD',
                        choices=['Adam', 'SGD'],
                        help='optimizer: ' +
                             ' | '.join(['Adam', 'SGD']) +
                             ' (default: SGD)')
    parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')

    # scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau',
                                 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2 / 3, type=float)
    parser.add_argument('--early_stopping', default=-1, type=int,
                        metavar='N', help='early stopping (default: -1)')
    parser.add_argument('--num_workers', default=0, type=int)

    config = parser.parse_args()

    return config


def train(config, train_loader, model, criterion, optimizer, device):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    model.train()

    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.to(device)
        target = target.to(device)

        # compute output
        if config['deep_supervision']:
            # deep supervision: average the loss over all side outputs,
            # but score IoU on the final output only
            outputs = model(input)
            loss = 0
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)
    pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])


def validate(config, val_loader, model, criterion, device):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for input, target, _ in val_loader:
            input = input.to(device)
            target = target.to(device)

            # compute output
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)

            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))

            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])


def main():
    config = vars(parse_args())

    # pick the training device automatically (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().to(device)
    else:
        criterion = losses.__dict__[config['loss']]().to(device)

    cudnn.benchmark = True if device.type == 'cuda' else False

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.to(device)

    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(
            params, lr=config['lr'], momentum=config['momentum'],
            nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=config['factor'], patience=config['patience'],
            min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in config['milestones'].split(',')],
            gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    img_ids = glob(os.path.join('inputs', config['dataset'],
                                'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    # 8:2 train/validation split
    train_img_ids, val_img_ids = train_test_split(
        img_ids, test_size=0.2, random_state=41)

    # data augmentation; RandomBrightnessContrast replaces the deprecated
    # RandomBrightness and RandomContrast transforms
    train_transform = Compose([
        albu.RandomRotate90(),
        albu.HorizontalFlip(),
        OneOf([
            albu.HueSaturationValue(),        # jitter hue, saturation, value
            albu.RandomBrightnessContrast(),  # jitter brightness and contrast together
        ], p=1),
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
    ])

    best_iou = 0
    trigger = 0
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer, device)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion, device)

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'],
                 val_log['loss'], val_log['iou']))

        log['epoch'].append(epoch)
        # record the learning rate actually in effect (the scheduler changes it)
        log['lr'].append(optimizer.param_groups[0]['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])

        pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'],
                                 index=False)

        trigger += 1

        # checkpoint the best model by validation IoU
        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(),
                       'models/%s/model.pth' % config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        if device.type == 'cuda':
            torch.cuda.empty_cache()


if __name__ == '__main__':
    main()

Configuration file (config.yml)

arch: NestedUNet
    batch_size: 8
    dataset: dsb2018_96
    deep_supervision: false
    early_stopping: -1
    epochs: 100
    factor: 0.1
    gamma: 0.6666666666666666
    img_ext: .png
    input_channels: 3
    input_h: 96
    input_w: 96
    loss: BCEDiceLoss
    lr: 0.001
    mask_ext: .png
    milestones: 1,2
    min_lr: 1.0e-05
    momentum: 0.9
    name: dsb2018_96_NestedUNet_woDS
    nesterov: false
    num_classes: 1
    num_workers: 0
    optimizer: SGD
    patience: 2
    scheduler: CosineAnnealingLR
    weight_decay: 0.0001
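
The saved config.yml and model.pth can later be used to rebuild the network for inference. A minimal loader sketch, assuming archs exposes the same constructor signature that train.py uses (this script is not part of the original post, and the model path below is just the default run name):

import torch
import yaml

import archs

# load the configuration that train.py dumped alongside the checkpoint
with open('models/dsb2018_96_NestedUNet_woDS/config.yml') as f:
    config = yaml.safe_load(f)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# rebuild the model exactly as in train.py, then load the best weights
model = archs.__dict__[config['arch']](config['num_classes'],
                                       config['input_channels'],
                                       config['deep_supervision'])
model.load_state_dict(torch.load('models/%s/model.pth' % config['name'],
                                 map_location=device))
model = model.to(device).eval()

# binary segmentation: threshold the sigmoid of the logits at 0.5
# (with deep_supervision=True the model returns a list; use its last element)
with torch.no_grad():
    x = torch.randn(1, config['input_channels'],
                    config['input_h'], config['input_w'], device=device)
    pred = torch.sigmoid(model(x)) > 0.5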
    

                                                                                   
