当前位置: 首页 > news >正文

源码解析(二):nnUNet

原文

系统框架

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

  1. 实验规划:分析数据集属性并生成管道配置
  2. 预处理:准备训练数据集(规范化、重采样等)
  3. 模型训练:使用配置的设置训练模型
  4. 模型评估:计算训练模型的性能指标
  5. 最佳配置选择:确定最佳模型或集成
  6. 推理:将选定的模型应用于新数据

安装

  1. 硬件要求

    • CPU
      • 现代多核处理器
    • 内存
      • 最低:16GB
      • 建议:32GB 或更多(特别是对于包含大图像的 3D 数据集)
    • 图形处理器
      • 用于训练:NVIDIA GPU,至少具有 11GB VRAM(RTX 2080 Ti、3090、4090、A5000 或更高版本)
      • 仅用于推理:配备 8GB+ VRAM 的 NVIDIA GPU,或仅使用 CPU(速度明显较慢)
    • 存储空间
      • 至少 100GB 可用空间(根据数据集大小而变化)
  2. 软件要求

    • Python
      • 3.10 或更高版本
    • 操作系统
      • Linux(推荐,尤其是 Ubuntu)
      • Windows 10/11
      • macOS(通过 MPS 或仅 CPU 提供有限的 GPU 支持)
    • CUDA 和 cuDNN
      • GPU 加速所需(兼容 PyTorch 2.1.2+)
  3. 安装方法

    通过pip安装

    安装 nnU-Net v2 最简单的方法是使用 pip:

    pip install nnunetv2
    

    这将自动安装 nnU-Net v2 及其所有依赖项,如pyproject.toml32-55文件。

    从 GitHub 仓库安装

    对于最新的开发版本或者如果您想为代码库做出贡献:

    git clone https://github.com/MIC-DKFZ/nnUNet
    cd nnUNet
    pip install -e .
    

核心模块代码解读

1.训练数据预处理

nnU-Net v2 中的预处理系统将原始医学影像数据转换为适用于神经网络训练和推理的标准化输入。

预处理系统的核心是**DefaultPreprocessor**类,它协调所有预处理操作。

代码文件为:preprocessing/preprocessors/default_preprocessor.py

a.图像加载与转置
第一步使用计划中指定的读取器加载图像并应用轴转置以确保方向一致:

data, data_properties = rw.read_images(image_files)
data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])

b.图像裁减

裁剪通过删除没有相关信息的背景区域来减少内存需求,nnU-Net 记录用于裁剪的边界框,以便在推理过程中进行逆转:

shape_before_cropping = data.shape[1:]
properties['shape_before_cropping'] = shape_before_cropping
data, seg, bbox = crop_to_nonzero(data, seg)
properties['bbox_used_for_cropping'] = bbox

c.正则化

归一化使图像间的强度值标准化,使网络训练更加稳定。在重采样之前应用归一化,以确保插值的准确性:

data = self._normalize(data, seg, configuration_manager,plans_manager.foreground_intensity_properties_per_channel)

d.重采样(各向异性处理)

重采样将体素之间的间距调整为配置中指定的目标间距:

重采样过程包括:

  1. 根据原始和目标间距计算新形状
  2. 应用适当的插值(图像和分割不同)
# /preprocessing/resampling/default_resampling.pynew_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)
data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing)
seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing)

e. 前景采样以提高训练效率

对于分割任务,nnU-Net 对前景位置进行采样,以通过平衡的块采样实现高效的训练:

properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this, verbose=self.verbose)

2.推理预处理

在推理过程中,预处理是即时执行的,流程如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3.用户单个case的图像预处理

preprocessor = DefaultPreprocessor()
data, seg, properties = preprocessor.run_case(input_images, seg_file, plans_manager, configuration_manager, dataset_json
)

实验自动化设置

设置的理念

nnUNet “无配置”理念指根据数据集属性自动设计和配置网络架构,平衡性能和硬件限制。该部分负责根据数据集特征自动配置神经网络架构和训练参数。实验规划系统会分析数据集属性和硬件约束,从而生成预处理、训练和推理的最佳设置。

实验自动化分析数据集特征以确定:

  • 重采样的目标间距
  • 网络架构和拓扑(池化操作、内核大小)
  • 内存高效的patch和batch size的大小
  • 适当的数据增强和预处理策略

代码文件位置:

nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py

**文件中ExperimentPlanner**几个主要职责:

  1. 读取数据集属性——分析数据集指纹以了解图像特征,如形状、间距和强度分布。
  2. 网络架构配置——根据数据集属性选择适当的网络深度、内核大小和特征图。
  3. 硬件感知优化——它估计 GPU 内存需求并调整补丁和批次大小以适应可用资源。
  4. 多配置规划它为 2D、3D 全分辨率和 3D 低分辨率方法创建配置计划。

设置的类型

配置维度解决用例
2d2D全分辨率训练速度快,适合高度各向异性的数据
3d_全分辨率3D全分辨率适合中等大小的 3D 体积
3d_lowres3D分辨率降低对于非常大的 3D 体积
3d_cascade_fullres3D全分辨率3d_lowres 之后的第二阶段,用于大型数据集

系统还会根据GPU的现存自动化设计ResEncUnet的参数大小

系统规划目标 GPU 内存用例
nnUNetPlannerResEncM8 GBRTX 2080Ti、1080Ti等
nnUNetPlannerResEncL24 GBRTX 3090、RTX 4090、A5000
nnUNetPlannerResEncXL40 GBA100 40GB、A6000等

模型的训练

训练框架

训练模块是 nnU-Net 的核心组件,负责模型训练,协调从数据加载到模型验证的整个过程

代码文件位置:

nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

训练系统的核心是**nnUNetTrainer类**,它协调整个训练过程。它提供了一个灵活的框架,可以扩展以适应不同的训练策略。在

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

主要职责:

  • 初始化和管理网络架构
  • 配置优化器和损失函数
  • 管理数据加载器和增强管道
  • 执行训练和验证循环
  • 实施检查点和日志记录

和模型训练相关的文件都在nnUNetTrainer文件夹下面,比如优化器,网络结构定义,训练流程等。

模型结构

网络架构根据计划和配置动态构建。默认架构类似 U-Net,但可以自定义,该方法**build_network_architecture**负责按照计划构建适当的网络架构。

# Implementation in nnUNetTrainer class
def build_network_architecture(architecture_class_name: str,arch_init_kwargs: dict,arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],num_input_channels: int,num_output_channels: int,enable_deep_supervision: bool = True) -> nn.Module:

数据增强

nnU-Net 采用大量数据增强来提高模型泛化能力:

增强类型示例
空间旋转、缩放、弹性变形、镜像
强度亮度、对比度、伽马校正
噪音高斯噪声、高斯模糊
其他模拟低分辨率、随机二进制算子

增强管道是根据数据集属性配置的:

def get_training_transforms(patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, use_mask_for_norm, is_cascaded, foreground_labels, regions, ignore_label):# Configures the training augmentation pipeline

损失函数

nnU-Net 结合使用了 Dice 损失和交叉熵(或基于区域的分割的 BCE),损失函数是根据任务类型(基于标签或基于区域的分割)以及是否启用深度监督来构建的

def _build_loss(self):if self.label_manager.has_regions:loss = DC_and_BCE_loss({},{'batch_dice': self.configuration_manager.batch_dice,'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},use_ignore_label=self.label_manager.ignore_label is not None,dice_class=MemoryEfficientSoftDiceLoss)else:loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)if self._do_i_compile():loss.dc = torch.compile(loss.dc)# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases# this gives higher resolution outputs more weight in the lossif self.enable_deep_supervision:deep_supervision_scales = self._get_deep_supervision_scales()weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])if self.is_ddp and not self._do_i_compile():# very strange and stupid interaction. DDP crashes and complains about unused parameters due to# weights[-1] = 0. Interestingly this crash doesn't happen with torch.compile enabled. Strange stuff.# Anywho, the simple fix is to set a very low weight to this.weights[-1] = 1e-6else:weights[-1] = 0# we don't use the lowest 2 outputs. Normalize weights so that they sum to 1weights = weights / weights.sum()# now wrap the lossloss = DeepSupervisionWrapper(loss, weights)return loss

超参设置

默认情况下,nnU-Net 使用带有 Nesterov 动量的 SGD 进行优化和多项式学习率衰减:

def configure_optimizers(self):optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,momentum=0.99, nesterov=True)lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)return optimizer, lr_scheduler

可通过训练器变体获得替代优化器:

  • nnUNetTrainerAdam:使用 Adam 或 AdamW 优化器
  • nnUNetTrainerAdan:使用Adan优化器(需要安装adan-pytorch)

多卡训练

nnU-Net 支持使用 PyTorch 的 DistributedDataParallel (DDP) 进行分布式训练,当使用多个 GPU 时,批量大小分布在各个工作器上,并且系统处理跨设备的梯度同步。

多种训练方法

nnU-Net 提供了基础训练器的几种变体,以支持不同的用例:

训练变体目的
nnUNetTrainer无深度监督缺乏深度监督的训练
nnUNetTrainerAdan使用 Adan 优化器
nnUNetTrainerAdam使用 Adam 优化器
nnUNetTrainer_Xepochs训练指定数量的 epoch
nnUNetTrainerBenchmark_5epochs用于基准性能

变体系统允许轻松定制,而无需修改核心训练器,比如下面的变体,直接用True或者False设置即可:

class nnUNetTrainerNoDeepSupervision(nnUNetTrainer):def __init__(self, plans, configuration, fold, dataset_json, device):super().__init__(plans, configuration, fold, dataset_json, device)self.enable_deep_supervision = False

开始训练

训练过程通常通过命令行启动:

nnUNetv2_train DATASET_NAME_OR_ID CONFIGURATION FOLD [-tr TRAINER] [-p PLANS][-pretrained_weights PATH] [-num_gpus NUM] [--npz] [--c] [--val][--val_best] [--disable_checkpointing] [-device DEVICE]

关键参数:

  • DATASET_NAME_OR_ID:用于训练的数据集
  • CONFIGURATION:要使用的配置(例如,“2d”、“3d_fullres”)
  • FOLD:交叉验证倍数(0-4 或“全部”)
  • tr:自定义训练器类(默认值:‘nnUNetTrainer’)
  • p:计划标识符(默认值:‘nnUNetPlans’)
  • num_gpus:用于训练的 GPU 数量

模型的推理

nnU-Net 推理系统负责应用已训练的模型对新的医学图像进行预测,将原始输入数据转换为精确的分割图。

推理系统的核心类是**nnUNetPredictor**,它协调整个预测过程。它管理:

  1. 模型初始化——加载网络架构和权重
  2. 预处理协调——确保正确的图像准备
  3. 预测执行——运行滑动窗口算法
  4. 结果处理——将逻辑转换为最终分割

初始化配置

使用前必须**nnUNetPredictor**进行初始化,一般使用以下参数:

范围默认描述
tile_step_size0.5移动滑动窗口的量(0.5 = 50%重叠)
use_gaussianTrue是否应用高斯加权进行窗口混合
use_mirroringTrue是否通过镜像使用测试时间增强
perform_everything_on_deviceTrue处理期间是否将数据保留在 GPU 上
deviceCUDA计算设备(推荐使用 CUDA)

初始化预测器后,必须使用经过训练的模型对其进行配置:

predictor.initialize_from_trained_model_folder(model_folder,  # Path to trained model folderuse_folds=(0,),  # Which folds to use (can combine multiple)checkpoint_name='checkpoint_final.pth'  # Which checkpoint to use
)

推理方法

代码文件位置:nnunetv2/inference/predict_from_raw_data.py

推理系统提供了多种预测方法,每种方法适用于不同的用例:

方法用例优势缺点
predict_from_files()基于多个文件的图像的批量预测最佳内存效率,并行处理需要磁盘上的文件
predict_from_list_of_npy_arrays()多张图片已作为数组加载无需文件 I/O更高的内存使用率
predict_single_npy_array()单幅图像预测最简单的 API最慢,无并行化
predict_from_data_iterator()自定义数据加载方案最大的灵活性更复杂的实现

推理系统的一个关键组件是滑动窗口预测机制,它可以处理可能无法一次性放入 GPU 内存的大型医学图像。

# nnunetv2/inference/sliding_window_prediction.py#L10-L54
def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \-> torch.Tensor:tmp = np.zeros(tile_size)center_coords = [i // 2 for i in tile_size]sigmas = [i * sigma_scale for i in tile_size]tmp[tuple(center_coords)] = 1gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)gaussian_importance_map = torch.from_numpy(gaussian_importance_map)gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor)gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype)# gaussian_importance_map cannot be 0, otherwise we may end up with nans!mask = gaussian_importance_map == 0gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])return gaussian_importance_mapdef compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \List[List[int]]:assert [i >= j for i, j in zip(image_size, tile_size)], "image size must be as large or larger than patch_size"assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'# our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of# 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]steps = []for dim in range(len(tile_size)):# the highest step value for this dimension ismax_step_value = image_size[dim] - tile_size[dim]if num_steps[dim] > 1:actual_step_size = max_step_value / (num_steps[dim] - 1)else:actual_step_size = 99999999999  # does not matter because there is only one step at 0steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]steps.append(steps_here)return steps

多折交叉预测

推理系统通过集成平均支持多个网络(通常来自不同的交叉验证折叠)的预测:

  1. 使用多重折叠进行初始化:use_folds=(0, 1, 2, 3, 4)
  2. 对于每个输入:
    • 循环遍历所有模型权重
    • 从每个模型生成预测
    • 平均预测

这通常比使用单折叠产生更稳健的结果,但代价是预测时间更长。集成平均发生在logits级别(softmax/argmax之前),从数学上讲,这比平均分割更合理。

模型评估与最佳选择

nnU-Net v2 中的评估和模型选择系统提供了一个强大的框架,用于评估已训练分割模型的性能、选择最佳配置,并通过后处理和集成来改进结果。

评估指标

nnU-Net 的评估系统主要使用 Dice 系数作为主要性能指标,但也会计算:

  • Dice系数:测量预测和地面实况之间的空间重叠
  • 交并比(IoU):替代重叠度量
  • 真正例(TP)假正例(FP)假负例(FN)真负例(TN)

模型最佳选择

nnU-Net 为每个数据集训练多个模型配置,并自动选择性能最佳的模型。选择过程包括评估单个模型和模型集成,以确定哪个模型能获得最高的 Dice 分数。

默认情况下,nnU-Net 会考虑以下配置进行评估:

配置描述
2d二维U-Net
3d_全分辨率全分辨率 3D U-Net
3d_lowres低分辨率的 3D U-Net
3d_cascade_fullres3D U-Net 级联(低分辨率 → 全分辨率)

交叉验证结果收集

在模型选择之前,收集并合并所有交叉验证的结果:

  1. 将各个折叠的所有验证预测复制到统一文件夹中
  2. 根据事实评估收集到的预测
  3. 生成交叉验证性能的摘要
# Example of accumulating cross-validation results
merged_output_folder = join(output_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
accumulate_cv_results(output_folder, merged_output_folder, folds, num_processes, overwrite)

寻找最佳配置

代码文件位置:nnunetv2/evaluation/find_best_configuration.py

该**find_best_configuration**函数比较所有模型变体:

  1. 评估单个模型配置(2D、3D、级联等)
  2. 创建和评估模型集成(如果允许)
  3. 选择 Dice 分数最高的配置
  4. 确定最佳配置的最佳后处理
  5. 生成并保存推理指令

相关文章:

  • 03.MySQL表的操作详解
  • K3s简介、实战、问题记录
  • Java高效处理大文件:避免OOM的深度实践
  • 【STM32F1标准库】理论——外部中断
  • 用提示词写程序(3),VSCODE+Claude3.5+deepseek开发edge扩展插件V2
  • 纯汇编自制操作系统(四、应用程序等的实现)
  • vue3(入门,setup,ref,计算属性,watch)
  • 财管5-投资项目的评价指标现金流量构成
  • C# 类和继承(构造函数的执行)
  • Spring Ai 从Demo到搭建套壳项目(一)初识与实现与deepseek对话模式
  • YOLOv5-入门篇笔记
  • 鸿蒙OSUniApp声纹识别与语音验证:打造安全可靠的跨平台语音应用#三方框架 #Uniapp
  • Java并发编程实战 Day 3:volatile关键字与内存可见性
  • 3D Gaussian splatting 05: 代码阅读-训练整体流程
  • CSS篇-5
  • 箱式不确定集
  • 广东WordPress开发公司及服务
  • 搭建基于VsCode的ESP32的开发环境教程
  • Spring Boot DevTools 热部署
  • MATLAB实战:传染病模型仿真实现
  • 做外贸如何浏览国外网站/优化师助理
  • 做的最少的网站/整站seo定制
  • 临安区建设局网站/链爱交易平台
  • 中科汇联网站建设手册/厦门网站关键词推广
  • 企业 宣传 还要网站吗/成人大专
  • 加强公司网站建设/百家号权重查询