用 PyTorch 训练 NestedUNet 分割细胞核
一、项目背景:为什么要做细胞核分割?
细胞核分割是医学影像分析的基础任务之一,在病理诊断、细胞计数、疾病研究中都有重要应用。比如:
- 病理医生通过分析细胞核的形态(大小、形状、分布)判断细胞是否癌变;
- 细胞实验中,需要精确分割单个细胞核以统计数量或观察分裂状态。
传统方法依赖人工标注或阈值分割,效率低且精度差。而深度学习模型(如 U-net 系列)能自动学习细胞核的特征,实现高精度分割,大幅降低人工成本。
我们今天的目标是:用 NestedUNet 模型(U-net++ 的改进版)实现细胞核自动分割,最终在验证集上达到较高的 IoU(交并比,分割任务的核心指标)。
二、环境准备:一行代码搞定依赖
首先确保你的环境安装了以下库,推荐用 Anaconda 创建虚拟环境(避免版本冲突):
bash
# 创建虚拟环境(可选但推荐)
conda create -n seg_env python=3.8
conda activate seg_env# 安装核心依赖
pip install torch torchvision torchaudio # PyTorch框架(根据CUDA版本选择,详见官网)
pip install albumentations # 数据增强库(比torchvision更强大)
pip install numpy pandas matplotlib # 数据处理与可视化
pip install scikit-image tqdm # 图像处理与进度条
验证环境:运行python -c "import torch; print(torch.cuda.is_available())",输出True说明 GPU 可用(训练会快 10 倍以上),False则用 CPU 训练(适合入门调试)。
三、数据集解析:2018 Data Science Bowl 细胞核数据
我们使用的数据集是dsb2018_96,源自 2018 年 Data Science Bowl 比赛,已预处理为 96×96 的小尺寸图像,非常适合新手练手。
1. 数据集结构
数据集按 “图像 - 掩码” 对应存储,目录结构如下:
plaintext
inputs/
└── dsb2018_96/ # 数据集名称├── images/ # 输入图像(细胞核原始图)│ ├── 0.png│ ├── 1.png│ ...└── masks/ # 掩码(标注的细胞核区域)├── 0.png├── 1.png...
- 图像(images):96×96 像素的灰度图(单通道),显示细胞核的显微镜图像;
- 掩码(masks):与图像同名的二值图,白色区域(像素值 1)表示细胞核,黑色区域(像素值 0)表示背景。
2. 数据特点
- 任务类型:二分类语义分割(仅区分 “细胞核” 和 “背景”);
- 难点:细胞核大小不一、形状不规则,且存在重叠(比如两个细胞核粘在一起),对模型的细节捕捉能力要求高;
- 数据量:约 600 张训练图 + 150 张验证图(按 8:2 划分),数量适中,适合中等规模模型训练。
3. 数据获取
如果你没有数据集,可以按以下方式生成类似结构:
- 从Kaggle 官网下载原始 DSB2018 数据;
- 用
scikit-image将图像 Resize 到 96×96:python
运行
from skimage import io, transform img = io.imread("original_image.png") img_resized = transform.resize(img, (96, 96), anti_aliasing=True) io.imsave("inputs/dsb2018_96/images/0.png", img_resized)
四、代码实战:从 0 到 1 训练 NestedUNet
我们的代码分为 5 个核心模块:参数配置、数据加载、模型定义、训练 / 验证循环、主流程。每个模块都有详细注释,确保你能看懂每一行的作用。
1. 完整代码结构
先看整体框架,后面会逐部分解析:
python
运行
import os
import argparse
import yaml
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import matplotlib.pyplot as plt# 自定义模块(后面会实现)
from dataset import SegDataset # 数据集类
from archs import NestedUNet # 模型类
from loss import BCEDiceLoss # 损失函数
from utils import calculate_iou # 评估指标计算# 参数解析
def parse_args():# 省略,后面详细讲pass# 数据加载
def get_loaders(args):# 省略,后面详细讲pass# 训练函数
def train_fn(train_loader, model, criterion, optimizer, device):# 省略,后面详细讲pass# 验证函数
def validate_fn(valid_loader, model, criterion, device):# 省略,后面详细讲pass# 主函数
def main():# 省略,后面详细讲passif __name__ == "__main__":main()
2. 参数配置(parse_args 函数)
通过命令行参数灵活配置训练细节,核心参数如下(可根据需求调整):
python
运行
def parse_args():parser = argparse.ArgumentParser()# 模型参数parser.add_argument("--arch", default="NestedUNet", help="模型架构(NestedUNet/Unet等)")parser.add_argument("--deep_supervision", action="store_true", help="是否使用深度监督")parser.add_argument("--input_channels", default=1, type=int, help="输入通道数(灰度图为1,RGB为3)")parser.add_argument("--num_classes", default=1, type=int, help="输出类别数(二分类为1)")# 训练参数parser.add_argument("--epochs", default=50, type=int, help="训练轮数")parser.add_argument("--batch_size", default=16, type=int, help="批次大小")parser.add_argument("--lr", default=1e-4, type=float, help="初始学习率")parser.add_argument("--loss", default="bce_dice", help="损失函数(bce/bce_dice)")parser.add_argument("--optimizer", default="adam", help="优化器(adam/sgd)")parser.add_argument("--scheduler", default="cosine", help="学习率调度器")# 数据参数parser.add_argument("--dataset", default="dsb2018_96", help="数据集名称")parser.add_argument("--img_ext", default=".png", help="图像文件扩展名")parser.add_argument("--mask_ext", default=".png", help="掩码文件扩展名")parser.add_argument("--input_w", default=96, type=int, help="图像宽度")parser.add_argument("--input_h", default=96, type=int, help="图像高度")# 其他参数parser.add_argument("--name", default="nested_unet_dsb2018", help="实验名称(用于保存模型)")parser.add_argument("--early_stopping", default=10, type=int, help="早停轮数(防止过拟合)")return parser.parse_args()
关键参数说明:
--deep_supervision:NestedUNet 的核心特性,开启后模型会在多个解码阶段输出结果,损失函数对多输出加权,提升小目标分割精度;--loss:推荐用bce_dice(BCE 损失 + Dice 损失),BCE 擅长平衡类别,Dice 擅长处理样本不平衡(细胞核像素少);--early_stopping:若连续 10 轮验证集 IoU 不提升,则停止训练,避免过拟合。
3. 数据加载与增强(get_loaders 函数)
数据增强是提升分割精度的关键,尤其是医学数据量少时,通过增强可以 “伪造” 更多样本,提升模型泛化能力。
(1)自定义数据集类(dataset.py)
python
运行
import os
import numpy as np
from skimage import io
import torch
from torch.utils.data import Datasetclass SegDataset(Dataset):def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, transform=None):self.img_ids = img_ids # 图像文件名列表(不含扩展名)self.img_dir = img_dir # 图像目录self.mask_dir = mask_dir # 掩码目录self.img_ext = img_ext # 图像扩展名self.mask_ext = mask_ext # 掩码扩展名self.transform = transform # 数据增强器def __len__(self):return len(self.img_ids)def __getitem__(self, idx):img_id = self.img_ids[idx]# 读取图像和掩码(转为float32,便于PyTorch处理)img = io.imread(os.path.join(self.img_dir, img_id + self.img_ext)).astype(np.float32)mask = io.imread(os.path.join(self.mask_dir, img_id + self.mask_ext)).astype(np.float32)# 若图像是单通道(灰度图),添加通道维度([H,W]→[H,W,1])if len(img.shape) == 2:img = img[..., np.newaxis]mask = mask[..., np.newaxis]# 应用数据增强if self.transform is not None:augmented = self.transform(image=img, mask=mask)img = augmented["image"]mask = augmented["mask"]# 掩码二值化(确保只有0和1)mask = (mask > 0.5).astype(np.float32)return img, mask, img_id
(2)数据增强与加载器
python
运行
def get_loaders(args):# 数据路径img_dir = os.path.join("inputs", args.dataset, "images")mask_dir = os.path.join("inputs", args.dataset, "masks")img_ids = [os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith(args.img_ext)]# 划分训练集和验证集(8:2,随机种子41确保可复现)train_img_ids, valid_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)# 训练集增强:随机旋转、翻转、色彩抖动(提升模型鲁棒性)train_transform = A.Compose([A.RandomRotate90(), # 随机旋转90度A.Flip(), # 随机水平/垂直翻转A.OneOf([ # 随机选一种色彩增强A.RandomBrightnessContrast(),A.RandomGamma(),], p=0.5),A.Resize(args.input_h, args.input_w), # 调整尺寸A.Normalize(mean=[0.485], std=[0.229]), # 归一化(单通道用一个均值和标准差)ToTensorV2(), # 转为PyTorch张量([H,W,C]→[C,H,W])])# 验证集增强:仅调整尺寸和归一化(不添加噪声,保证评估准确)valid_transform = A.Compose([A.Resize(args.input_h, args.input_w),A.Normalize(mean=[0.485], std=[0.229]),ToTensorV2(),])# 创建数据集和加载器train_dataset = SegDataset(train_img_ids, img_dir, mask_dir, args.img_ext, args.mask_ext, train_transform)valid_dataset = SegDataset(valid_img_ids, img_dir, mask_dir, args.img_ext, args.mask_ext, valid_transform)train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)return train_loader, valid_loader, train_img_ids, valid_img_ids
增强技巧:
- 训练集用
OneOf随机选一种增强,避免过度增强导致特征失真; - 验证集不做随机变换,确保评估结果稳定;
- 单通道图像的归一化均值 / 标准差可根据数据集统计(这里用 ImageNet 的近似值)。
4. 模型定义:NestedUNet(U-net++)
NestedUNet 是 U-net 的升级版,通过密集特征融合和深度监督解决 U-net 的 “语义鸿沟” 问题,特别适合分割小目标(如细胞核)。
核心结构(简化版,完整代码见archs.py):
python
运行
import torch
import torch.nn as nn
import torch.nn.functional as Fclass NestedUNet(nn.Module):def __init__(self, input_channels=1, num_classes=1, deep_supervision=True):super().__init__()self.deep_supervision = deep_supervision# 编码端(下采样):提取语义特征self.down1 = self._down_block(input_channels, 64) # 输出64通道self.down2 = self._down_block(64, 128) # 输出128通道self.down3 = self._down_block(128, 256) # 输出256通道self.down4 = self._down_block(256, 512) # 输出512通道# 瓶颈层(最深层)self.center = self._conv_block(512, 1024) # 输出1024通道# 解码端(上采样):密集特征融合self.up4 = self._up_block(1024, 512)self.up3 = self._up_block(512, 256)self.up2 = self._up_block(256, 128)self.up1 = self._up_block(128, 64)# 输出层(深度监督:多个输出分支)self.out1 = nn.Conv2d(64, num_classes, kernel_size=1)self.out2 = nn.Conv2d(128, num_classes, kernel_size=1)self.out3 = nn.Conv2d(256, num_classes, kernel_size=1)self.out4 = nn.Conv2d(512, num_classes, kernel_size=1)# 卷积块(2次卷积+ReLU)def _conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True))# 下采样块(卷积块+最大池化)def _down_block(self, in_channels, out_channels):return nn.Sequential(self._conv_block(in_channels, out_channels),nn.MaxPool2d(kernel_size=2, stride=2))# 上采样块(上采样+特征拼接+卷积块)def _up_block(self, in_channels, out_channels):return nn.Sequential(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),nn.Conv2d(in_channels, out_channels, kernel_size=1), # 降维self._conv_block(out_channels * 2, out_channels) # 拼接编码端特征(×2是因为拼接))def forward(self, x):# 编码端输出x1 = self.down1(x)x2 = self.down2(x1)x3 = self.down3(x2)x4 = self.down4(x3)# 瓶颈层center = self.center(x4)# 解码端输出(密集融合)up4 = self.up4(center)up3 = self.up3(up4)up2 = self.up2(up3)up1 = self.up1(up2)# 深度监督:输出多个分支out1 = self.out1(up1)if self.deep_supervision:out2 = self.out2(up2)out3 = self.out3(up3)out4 = self.out4(up4)return [out1, out2, out3, out4] # 多输出用于深度监督else:return out1
NestedUNet 核心优势:
- 解码端每个阶段都融合编码端多个层次的特征,解决 “语义鸿沟”;
- 深度监督(多输出)让模型同时学习粗粒度和细粒度特征,小目标分割更准。
5. 损失函数:BCEDiceLoss(平衡类别 + 样本)
细胞核分割中,背景像素远多于细胞核(样本不平衡),且边界难区分,因此需要定制损失函数:
python
运行
import torch
import torch.nn as nn
import torch.nn.functional as Fclass BCEDiceLoss(nn.Module):def __init__(self, weight=None, size_average=True):super().__init__()def forward(self, inputs, targets, smooth=1):# Sigmoid激活(将输出转为0-1概率)inputs = torch.sigmoid(inputs) # 展平张量(计算全局损失)inputs = inputs.view(-1)targets = targets.view(-1)# BCE损失(处理类别不平衡)bce_loss = F.binary_cross_entropy(inputs, targets, reduction='mean')# Dice损失(衡量重叠度,对边界敏感)intersection = (inputs * targets).sum() dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) # 总损失:BCE + Dice(权重可调整)return bce_loss + dice_loss
为什么用组合损失:
- BCE 损失:通过交叉熵惩罚错分样本,适合平衡正负类别;
- Dice 损失:直接衡量预测与真实掩码的重叠度,对边界误差更敏感,适合分割任务。
6. 训练与验证循环
训练循环的核心是 “正向传播算损失→反向传播更新参数→验证集评估泛化能力”:
(1)训练函数
python
运行
def train_fn(train_loader, model, criterion, optimizer, device):model.train() # 训练模式(启用Dropout、BN等)total_loss = 0.0total_iou = 0.0# 进度条显示训练过程loop = tqdm(train_loader, total=len(train_loader))for imgs, masks, _ in loop:# 数据移到GPU/CPUimgs = imgs.to(device)masks = masks.to(device)# 梯度清零optimizer.zero_grad()# 正向传播outputs = model(imgs)# 计算损失(深度监督时,对多个输出加权)if isinstance(outputs, list):loss = 0.0for out in outputs:loss += criterion(out, masks)loss /= len(outputs) # 平均多输出损失else:loss = criterion(outputs, masks)# 反向传播+参数更新loss.backward()optimizer.step()# 计算IoU(评估指标)with torch.no_grad(): # 不计算梯度,节省内存if isinstance(outputs, list):pred = torch.sigmoid(outputs[0]) # 用第一个输出(最精细)计算IoUelse:pred = torch.sigmoid(outputs)pred = (pred > 0.5).float() # 二值化(0.5为阈值)iou = calculate_iou(pred, masks)# 累计损失和IoUtotal_loss += loss.item()total_iou += iou.item()# 更新进度条loop.set_postfix(loss=loss.item(), iou=iou.item())# 计算平均损失和IoUavg_loss = total_loss / len(train_loader)avg_iou = total_iou / len(train_loader)return avg_loss, avg_iou
(2)验证函数
python
运行
def validate_fn(valid_loader, model, criterion, device):model.eval() # 评估模式(冻结BN、Dropout)total_loss = 0.0total_iou = 0.0with torch.no_grad(): # 验证时不计算梯度loop = tqdm(valid_loader, total=len(valid_loader))for imgs, masks, _ in loop:imgs = imgs.to(device)masks = masks.to(device)outputs = model(imgs)# 计算损失(同训练函数)if isinstance(outputs, list):loss = 0.0for out in outputs:loss += criterion(out, masks)loss /= len(outputs)else:loss = criterion(outputs, masks)# 计算IoUif isinstance(outputs, list):pred = torch.sigmoid(outputs[0])else:pred = torch.sigmoid(outputs)pred = (pred > 0.5).float()iou = calculate_iou(pred, masks)total_loss += loss.item()total_iou += iou.item()loop.set_postfix(loss=loss.item(), iou=iou.item())avg_loss = total_loss / len(valid_loader)avg_iou = total_iou / len(valid_loader)return avg_loss, avg_iou
(3)IoU 计算函数(utils.py)
python
运行
import torchdef calculate_iou(pred, target, smooth=1e-6):# pred和target都是二值张量(0或1)intersection = (pred & target).sum()union = (pred | target).sum()iou = (intersection + smooth) / (union + smooth)return iou
7. 主流程:整合所有模块
python
运行
def main():args = parse_args()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"使用设备:{device}")# 创建模型保存目录os.makedirs(f"models/{args.name}", exist_ok=True)# 保存配置参数(方便复现)with open(f"models/{args.name}/config.yml", "w") as f:yaml.dump(vars(args), f)# 加载数据train_loader, valid_loader, train_ids, valid_ids = get_loaders(args)print(f"训练集样本数:{len(train_ids)},验证集样本数:{len(valid_ids)}")# 初始化模型、损失函数、优化器model = NestedUNet(input_channels=args.input_channels,num_classes=args.num_classes,deep_supervision=args.deep_supervision).to(device)if args.loss == "bce_dice":criterion = BCEDiceLoss()else:criterion = nn.BCEWithLogitsLoss() # 自带Sigmoidif args.optimizer == "adam":optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)else:optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)# 学习率调度器(cosine退火,自动调整学习率)if args.scheduler == "cosine":scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)else:scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=5)# 记录训练日志log = {"train_loss": [], "train_iou": [],"valid_loss": [], "valid_iou": []}# 早停相关变量best_iou = 0.0early_stopping_counter = 0# 训练循环for epoch in range(1, args.epochs + 1):print(f"\n===== Epoch {epoch}/{args.epochs} =====")# 训练train_loss, train_iou = train_fn(train_loader, model, criterion, optimizer, device)# 验证valid_loss, valid_iou = validate_fn(valid_loader, model, criterion, device)# 更新日志log["train_loss"].append(train_loss)log["train_iou"].append(train_iou)log["valid_loss"].append(valid_loss)log["valid_iou"].append(valid_iou)print(f"训练集:损失={train_loss:.4f},IoU={train_iou:.4f}")print(f"验证集:损失={valid_loss:.4f},IoU={valid_iou:.4f}")# 调整学习率if args.scheduler == "cosine":scheduler.step()else:scheduler.step(valid_iou) # 基于验证集IoU调整# 保存最佳模型(验证集IoU最高)if valid_iou > best_iou:best_iou = valid_ioutorch.save(model.state_dict(), f"models/{args.name}/best_model.pth")print(f"保存最佳模型(IoU={best_iou:.4f})")early_stopping_counter = 0 # 重置早停计数器else:early_stopping_counter += 1print(f"早停计数器:{early_stopping_counter}/{args.early_stopping}")if early_stopping_counter >= args.early_stopping:print("早停触发,停止训练")break# 绘制训练曲线plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(log["train_loss"], label="Train Loss")plt.plot(log["valid_loss"], label="Valid Loss")plt.title("Loss Curve")plt.legend()plt.subplot(1, 2, 2)plt.plot(log["train_iou"], label="Train IoU")plt.plot(log["valid_iou"], label="Valid IoU")plt.title("IoU Curve")plt.legend()plt.savefig(f"models/{args.name}/curves.png")print(f"训练曲线已保存到 models/{args.name}/curves.png")if __name__ == "__main__":main()
五、训练结果与分析
1. 预期效果
在 GTX 1080Ti 上训练 50 轮(约 1 小时),验证集 IoU 可达 0.85 以上(越高越好,1.0 为完美分割)。训练曲线应呈现:
- 损失曲线:训练集和验证集损失均逐渐下降,且差距不大(无过拟合);
- IoU 曲线:训练集和验证集 IoU 均逐渐上升,最终稳定在 0.85 左右。
2. 分割结果可视化
随机选择验证集图像,对比 “原始图像→真实掩码→模型预测”:
python
运行
import matplotlib.pyplot as plt
from skimage import io# 加载模型(略)
model.load_state_dict(torch.load("models/nested_unet_dsb2018/best_model.pth"))
model.eval()# 取一张验证集图像
img, mask, img_id = valid_dataset[0]
with torch.no_grad():pred = model(img.unsqueeze(0).to(device)) # 加batch维度pred = torch.sigmoid(pred[0])[0].cpu().numpy() # 转为numpypred = (pred > 0.5).astype(np.float32) # 二值化# 可视化
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.imshow(img[0], cmap="gray") # 原始图像(单通道)
plt.title("Original Image")
plt.subplot(132)
plt.imshow(mask[0], cmap="gray") # 真实掩码
plt.title("True Mask")
plt.subplot(133)
plt.imshow(pred[0], cmap="gray") # 预测掩码
plt.title("Predicted Mask")
plt.show()
理想结果:预测掩码与真实掩码高度重合,尤其是细胞核的边缘和重叠区域能被准确分割。
3. 常见问题与调优
-
过拟合:训练集 IoU 高(>0.9),验证集 IoU 低(<0.7)。解决:增加数据增强强度(如添加高斯噪声)、减小模型深度、使用早停。
-
分割边界模糊:预测掩码边缘不清晰。解决:增加 Dice 损失权重(让模型更关注边界)、使用更大的输入尺寸(如 128×128)。
-
小细胞核漏检:小目标未被分割。解决:开启深度监督(
--deep_supervision)、减小批次大小(让模型更关注小样本)。
六、总结与拓展
通过这个项目,你已经掌握了图像分割的核心流程:
- 数据预处理与增强(提升模型鲁棒性的关键);
- NestedUNet 模型的原理与实现(密集融合 + 深度监督);
- 损失函数与评估指标(BCE+Dice 损失、IoU 计算);
- 训练循环与调优技巧(早停、学习率调度)。
拓展方向
- 尝试更先进的模型:如 U-net+++、SegFormer(结合 Transformer);
- 多模态数据:融合细胞核的染色图像和荧光图像,提升分割精度;
- 后处理优化:用形态学操作(如腐蚀、膨胀)去除预测掩码中的噪声。
希望这篇教程能帮你快速入门图像分割,如果你在实战中遇到问题,欢迎在评论区交流~ 代码已整理到 GitHub,关注我获取完整项目链接!
