基于 U-Net 的医学图像分割
项目概述
该项目实现了一个端到端的医学图像分割流程,包括:
数据预处理与增强
U-Net 模型构建与训练
模型验证与可视化
结果保存与分析
数据预处理
项目使用 DSB2018 数据集,通过 preprocess_dsb2018.py 进行数据预处理:
def main():img_size = 96paths = glob('inputs/stage1_train/*')# 创建输出目录os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)for i in tqdm(range(len(paths))):path = paths[i]img = cv2.imread(os.path.join(path, 'images', os.path.basename(path) + '.png'))mask = np.zeros((img.shape[0], img.shape[1]))# 合并多个掩码for mask_path in glob(os.path.join(path, 'masks', '*')):mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127mask[mask_] = 1# 调整图像尺寸img = cv2.resize(img, (img_size, img_size))mask = cv2.resize(mask, (img_size, img_size))# 保存处理后的图像和掩码cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size, os.path.basename(path) + '.png'), img)cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size, os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))预处理步骤包括:
统一图像尺寸为 96×96 像素
合并多个掩码文件为单个二值掩码
标准化图像格式和通道
数据集类设计
dataset.py 中实现了自定义数据集类,支持数据增强:
class Dataset(torch.utils.data.Dataset):def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):self.img_ids = img_idsself.img_dir = img_dirself.mask_dir = mask_dirself.img_ext = img_extself.mask_ext = mask_extself.num_classes = num_classesself.transform = transformdef __getitem__(self, idx):img_id = self.img_ids[idx]# 读取图像和掩码img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))mask = []for i in range(self.num_classes):mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])mask = np.dstack(mask)# 数据增强if self.transform is not None:augmented = self.transform(image=img, mask=mask)img = augmented['image']mask = augmented['mask']# 标准化和维度调整img = img.astype('float32') / 255img = img.transpose(2, 0, 1)mask = mask.astype('float32') / 255mask = mask.transpose(2, 0, 1)return img, mask, {'img_id': img_id}数据增强策略
项目使用 Albumentations 库进行数据增强:
训练集增强:
train_transform = Compose([albu.RandomRotate90(),albu.HorizontalFlip(),albu.OneOf([albu.HueSaturationValue(),albu.RandomBrightnessContrast(),], p=1),albu.Resize(config['input_h'], config['input_w']),albu.Normalize(), ])
验证集增强:
val_transform = Compose([albu.Resize(config['input_h'], config['input_w']),albu.Normalize(), ])
模型训练
train.py 实现了完整的训练流程:
参数配置
def parse_args():parser = argparse.ArgumentParser()parser.add_argument('--name', default="dsb2018_96_NestedUNet_woDS")parser.add_argument('--epochs', default=100, type=int)parser.add_argument('--batch_size', default=8, type=int)parser.add_argument('--arch', default='NestedUNet')parser.add_argument('--deep_supervision', default=False, type=str2bool)parser.add_argument('--loss', default='BCEDiceLoss')parser.add_argument('--optimizer', default='SGD')parser.add_argument('--lr', default=1e-3, type=float)# ... 更多参数训练循环
def train(config, train_loader, model, criterion, optimizer):model.train()for input, target, _ in train_loader:# 前向传播if config['deep_supervision']:outputs = model(input)loss = 0for output in outputs:loss += criterion(output, target)loss /= len(outputs)iou = iou_score(outputs[-1], target)else:output = model(input)loss = criterion(output, target)iou = iou_score(output, target)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()
模型验证与可视化
val.py 提供了模型验证和结果可视化功能:
def plot_examples(datax, datay, model, num_examples=6):fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18,4*num_examples))m = datax.shape[0]for row_num in range(num_examples):image_indx = np.random.randint(m)image_arr = model(datax[image_indx:image_indx+1]).squeeze(0).detach().cpu().numpy()ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1,2,0))[:,:,0])ax[row_num][0].set_title("Orignal Image")ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0,:,:].astype(int)))ax[row_num][1].set_title("Segmented Image localization")ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1,2,0))[:,:,0])ax[row_num][2].set_title("Target image")plt.show()关键特性
1. 深度监督
支持深度监督训练,通过多个输出层提供中间监督信号。
2. 灵活的损失函数
支持多种损失函数,包括 BCEWithLogitsLoss 和自定义的 BCEDiceLoss。
3. 学习率调度
提供多种学习率调度策略:
CosineAnnealingLR
ReduceLROnPlateau
MultiStepLR
ConstantLR
4. 早停机制
通过监控验证集性能实现早停,防止过拟合。
使用方式
训练模型
python train.py --dataset dsb2018_96 --arch NestedUNet
验证模型
python val.py --name dsb2018_96_NestedUNet_woDS
总结
该项目提供了一个完整的医学图像分割解决方案,具有以下优点:
模块化设计:各个组件独立,便于修改和扩展
丰富的数据增强:提高模型泛化能力
灵活的配置:通过配置文件管理所有超参数
完整的训练监控:记录训练过程中的各项指标
结果可视化:直观展示分割效果
这个项目不仅适用于细胞核分割,通过调整配置也可以应用于其他医学图像分割任务,为医学图像分析研究提供了有力的工具。
