UNet++
This is a PyTorch-based medical image segmentation project. Its core goal is binary image segmentation (e.g. lesion detection, cell segmentation), built on the UNet family of architectures (NestedUNet by default), with a complete training pipeline and performance evaluation.
Project core features
- Network architectures: supports multiple UNet-derived architectures (extensible via the archs module); NestedUNet is the default, with an optional deep supervision mode to improve segmentation accuracy.
- Data processing: efficient data augmentation via Albumentations (rotation, flipping, brightness/contrast adjustment, etc.), automatic train/validation split (8:2), and support for custom dataset formats.
- Training configuration: flexible settings for batch size, learning rate, optimizer (Adam/SGD), learning-rate scheduler (cosine annealing, ReduceLROnPlateau, etc.), and loss function (custom BCEDiceLoss or the built-in BCEWithLogitsLoss).
- Performance evaluation: IoU (intersection over union) as the core metric, with live monitoring of training/validation loss and IoU, automatic saving of the best model, and early stopping to prevent overfitting.
- Compatibility: automatic CPU/GPU selection; works on both Windows and Linux (the number of data-loading workers is configurable).
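The custom BCEDiceLoss mentioned above is defined in the losses module, which is not reproduced in this document. Below is a minimal sketch of such a combined BCE + Dice loss for binary segmentation, assuming raw logits as input and {0, 1} masks as targets; the actual implementation in losses may differ in details such as the smoothing constant:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BCEDiceLoss(nn.Module):
    """Sketch of a combined BCE + Dice loss for binary segmentation.

    Assumes `input` holds raw logits and `target` holds {0, 1} masks
    of the same shape (N, 1, H, W).
    """
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, input, target):
        # BCE term computed directly on logits for numerical stability
        bce = F.binary_cross_entropy_with_logits(input, target)
        # Dice term computed on sigmoid probabilities, per sample
        prob = torch.sigmoid(input)
        num = target.size(0)
        prob = prob.view(num, -1)
        target = target.view(num, -1)
        intersection = (prob * target).sum(dim=1)
        dice = (2. * intersection + self.smooth) / (
            prob.sum(dim=1) + target.sum(dim=1) + self.smooth)
        return bce + (1 - dice.mean())
```

The Dice term rewards overlap directly and counteracts the class imbalance typical of small lesion masks, which is why it is commonly paired with BCE.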
Project structure and dependencies
- Core dependencies: PyTorch, Albumentations, Pandas, Scikit-learn, TQDM, PyYAML.
- Custom modules: archs (network architecture definitions), losses (custom loss functions), dataset (dataset loading), metrics (IoU computation), utils (utility functions).
- Dataset format: place images (input images) and masks (segmentation masks) under inputs/[dataset name]; image and mask filenames must correspond one-to-one.
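Since image and mask filenames must correspond one-to-one, a quick standalone check can catch missing masks before a long training run. A small sketch (the helper name `check_pairs` and the flat masks/ layout are assumptions based on the format described above):

```python
import os
from glob import glob

def check_pairs(dataset_dir, img_ext='.png', mask_ext='.png'):
    """Return image IDs under dataset_dir/images that lack a matching
    mask file under dataset_dir/masks."""
    img_paths = glob(os.path.join(dataset_dir, 'images', '*' + img_ext))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_paths]
    return [i for i in img_ids
            if not os.path.exists(os.path.join(dataset_dir, 'masks', i + mask_ext))]
```

Running this against inputs/[dataset name] before training avoids a crash partway through the first epoch when the Dataset class tries to load a missing mask.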
Core training flow
- Parse command-line arguments and build the training configuration (model, data, optimizer, and other parameters).
- Initialize the network model, loss function, optimizer, and learning-rate scheduler, and move them to the training device (CPU/GPU).
- Load the dataset, apply data augmentation, and build the training/validation data loaders.
- Iterative training: each training epoch computes the loss and updates parameters via backpropagation; each validation epoch evaluates performance and records metrics.
- Dynamically save the best model (based on validation IoU), support early stopping, and persist the training log and configuration file to disk.
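The iou_score function from the metrics module is not shown in this document. A minimal sketch consistent with how train.py calls it (logits in, scalar IoU out); the 0.5 probability threshold is an assumption:

```python
import torch

def iou_score(output, target, smooth=1e-5):
    """IoU between the thresholded sigmoid of `output` (logits) and a
    binary `target` mask, returned as a Python float."""
    with torch.no_grad():
        pred = torch.sigmoid(output) > 0.5
        truth = target > 0.5
        intersection = (pred & truth).float().sum()
        union = (pred | truth).float().sum()
    # smoothing keeps the ratio defined when both masks are empty
    return ((intersection + smooth) / (union + smooth)).item()
```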
Project advantages
- High flexibility: switch between architectures, loss functions, and schedulers to suit different segmentation scenarios.
- Engineering practice: configuration persistence, logging, and progress-bar visualization make experiments reproducible and easy to compare.
- Performance: GPU acceleration, deep supervision, and data augmentation improve segmentation accuracy and generalization.
Core code: train.py
```python
import argparse
import os
from collections import OrderedDict
from glob import glob

import albumentations as albu
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import yaml
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from tqdm import tqdm

import archs
import losses
from dataset import Dataset
from metrics import iou_score
from utils import AverageMeter, str2bool

ARCH_NAMES = archs.__all__
LOSS_NAMES = losses.__all__
LOSS_NAMES.append('BCEWithLogitsLoss')

"""
Example arguments:
    --dataset dsb2018_96 --arch NestedUNet
"""


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default="dsb2018_96_NestedUNet_woDS",
                        help='model name: (default: arch+timestamp)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 8)')

    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                        choices=ARCH_NAMES,
                        help='model architecture: ' +
                        ' | '.join(ARCH_NAMES) +
                        ' (default: NestedUNet)')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=96, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=96, type=int,
                        help='image height')

    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                        ' | '.join(LOSS_NAMES) +
                        ' (default: BCEDiceLoss)')

    # dataset
    parser.add_argument('--dataset', default='dsb2018_96',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # optimizer
    parser.add_argument('--optimizer', default='SGD',
                        choices=['Adam', 'SGD'],
                        help='optimizer: ' +
                        ' | '.join(['Adam', 'SGD']) +
                        ' (default: SGD)')
    parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')

    # scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau',
                                 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2/3, type=float)
    parser.add_argument('--early_stopping', default=-1, type=int,
                        metavar='N', help='early stopping (default: -1)')
    parser.add_argument('--num_workers', default=0, type=int)

    config = parser.parse_args()

    return config


def train(config, train_loader, model, criterion, optimizer, device):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    model.train()

    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.to(device)
        target = target.to(device)

        # compute output
        if config['deep_supervision']:
            outputs = model(input)
            loss = 0
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)
    pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])


def validate(config, val_loader, model, criterion, device):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for input, target, _ in val_loader:
            input = input.to(device)
            target = target.to(device)

            # compute output
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)

            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))

            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])


def main():
    config = vars(parse_args())

    # automatically pick the device (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().to(device)
    else:
        criterion = losses.__dict__[config['loss']]().to(device)

    cudnn.benchmark = True if device.type == 'cuda' else False

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.to(device)

    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(params, lr=config['lr'],
                               weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              nesterov=config['nesterov'],
                              weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=config['factor'], patience=config['patience'],
            verbose=1, min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in config['milestones'].split(',')],
            gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images',
                                '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    train_img_ids, val_img_ids = train_test_split(
        img_ids, test_size=0.2, random_state=41)

    # data augmentation: RandomBrightnessContrast replaces the deprecated
    # RandomBrightness and RandomContrast transforms
    train_transform = Compose([
        albu.RandomRotate90(),
        albu.HorizontalFlip(),
        OneOf([
            albu.HueSaturationValue(),        # adjust hue, saturation, value
            albu.RandomBrightnessContrast(),  # adjust brightness and contrast
        ], p=1),
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
    ])

    best_iou = 0
    trigger = 0
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer, device)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion, device)

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'],
                 val_log['loss'], val_log['iou']))

        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])

        pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'], index=False)

        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'models/%s/model.pth' % config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        if device.type == 'cuda':
            torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
```
Configuration file
```yaml
arch: NestedUNet
batch_size: 8
dataset: dsb2018_96
deep_supervision: false
early_stopping: -1
epochs: 100
factor: 0.1
gamma: 0.6666666666666666
img_ext: .png
input_channels: 3
input_h: 96
input_w: 96
loss: BCEDiceLoss
lr: 0.001
mask_ext: .png
milestones: 1,2
min_lr: 1.0e-05
momentum: 0.9
name: dsb2018_96_NestedUNet_woDS
nesterov: false
num_classes: 1
num_workers: 0
optimizer: SGD
patience: 2
scheduler: CosineAnnealingLR
weight_decay: 0.0001
```
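The saved models/[name]/config.yml can be reloaded later to inspect or reproduce a run, using the project's existing PyYAML dependency. A minimal sketch (`load_config` is a hypothetical helper, not part of the repository):

```python
import yaml

def load_config(path):
    """Load a saved training configuration back into a plain dict.

    safe_load is used since the file contains only scalar values.
    """
    with open(path) as f:
        return yaml.safe_load(f)
```

A dict returned this way can be passed straight into the train/validate functions above, which index the configuration by key (e.g. config['deep_supervision']).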
