
PyTorch Deep Learning Framework: 60-Day Advanced Study Plan - Day 36: Medical Image Diagnosis (Part 1)

Friends! I can hardly believe we have made it to Day 36! Today we step into a field that is both challenging and meaningful: medical image diagnosis. We will learn how to analyze lung CT scans with a 3D ResNet, explore data augmentation techniques suited to medical images, and tackle the class imbalance that is so common in medical data.

There is a running joke in medical AI: "When an ordinary AI model makes a mistake, it might just mislabel a cat as a dog; when a medical AI makes a mistake, it might send a healthy person to the ICU." So let's approach today's material with due respect and get started!

1. Overview of Medical Image Diagnosis

Medical image diagnosis is one of the most promising applications of AI in healthcare. Unlike ordinary photographs, medical images typically have the following characteristics:

  1. Varied dimensionality: CT and MRI scans are 3D volumes rather than simple 2D images
  2. Data scarcity: labeled medical data is far scarcer than general-purpose image datasets
  3. Class imbalance: diseased samples are usually far fewer than healthy ones
  4. High accuracy requirements: diagnosis demands very high accuracy and leaves little room for error

Today we focus on lung CT analysis, which plays an important role in diagnosing diseases such as lung cancer, pneumonia, and COVID-19.

2. Designing a 3D ResNet

2.1 Why ResNet?

In medical imaging we usually need to extract complex features. ResNet's residual connections effectively mitigate the vanishing-gradient problem of deep networks: each block computes y = F(x) + x, so gradients can always flow back through the identity path, which lets us build much deeper models. Since medical findings often have to be picked up from subtle local changes, ResNet's strong feature extraction makes it an ideal choice.

2.2 Converting from 2D to 3D

Converting a 2D ResNet into a 3D version mainly involves the following substitutions:

| 2D component | 3D counterpart | Change |
| --- | --- | --- |
| Conv2d | Conv3d | Convolution kernel goes from k×k to k×k×k |
| MaxPool2d | MaxPool3d | Pooling window goes from k×k to k×k×k |
| BatchNorm2d | BatchNorm3d | Normalization runs over one extra dimension |
| AdaptiveAvgPool2d | AdaptiveAvgPool3d | Adaptive pooling runs over one extra dimension |
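
To make the swap concrete, here is a minimal sketch (input sizes are arbitrary examples) that passes dummy tensors through a 2D stem and its 3D counterpart; the only difference in usage is the extra depth dimension in the 5D input [batch, channels, depth, height, width]:

import torch
import torch.nn as nn

stem_2d = nn.Sequential(nn.Conv2d(1, 64, 7, stride=2, padding=3),
                        nn.BatchNorm2d(64), nn.MaxPool2d(3, 2, 1))
stem_3d = nn.Sequential(nn.Conv3d(1, 64, 7, stride=2, padding=3),
                        nn.BatchNorm3d(64), nn.MaxPool3d(3, 2, 1))

x2d = torch.randn(2, 1, 224, 224)        # [N, C, H, W]
x3d = torch.randn(2, 1, 64, 224, 224)    # [N, C, D, H, W]
print(stem_2d(x2d).shape)    # torch.Size([2, 64, 56, 56])
print(stem_3d(x3d).shape)    # torch.Size([2, 64, 16, 56, 56])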

2.3 Basic Structure of the 3D ResNet

Our 3D ResNet consists of the following parts:

  1. Initial convolution layer: captures basic features
  2. Residual blocks: extract complex features while mitigating vanishing gradients
  3. Global pooling layer: reduces dimensionality while keeping the important features
  4. Fully connected layer: performs the final classification
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock3D(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(BasicBlock3D, self).__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=3, 
                              stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, 
                              stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck3D(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck3D, self).__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                              padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet3D(nn.Module):
    def __init__(self, block, layers, num_classes=2, zero_init_residual=False):
        super(ResNet3D, self).__init__()
        self.in_planes = 64
        
        # 初始卷积层
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        
        # 残差层
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        # 分类头
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        # 权重初始化
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # 残差块特殊初始化
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck3D):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock3D):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, planes * block.expansion, kernel_size=1, 
                         stride=stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.in_planes, planes, stride, downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 输入处理和初始特征提取
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # 特征提取网络
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # 分类头
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

def resnet18_3d(num_classes=2, **kwargs):
    """18层3D ResNet"""
    return ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=num_classes, **kwargs)

def resnet34_3d(num_classes=2, **kwargs):
    """34层3D ResNet"""
    return ResNet3D(BasicBlock3D, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet50_3d(num_classes=2, **kwargs):
    """50层3D ResNet"""
    return ResNet3D(Bottleneck3D, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
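
With the factory functions in place, a quick shape check (the input size below is just an example) confirms that the network accepts single-channel volumes and outputs one logit per class:

# Quick sanity check: one single-channel volume of size 64x128x128
model = resnet18_3d(num_classes=2)
dummy = torch.randn(1, 1, 64, 128, 128)   # [batch, channel, depth, height, width]
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)   # torch.Size([1, 2])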

3. Medical Image Data Processing and Augmentation

3.1 Medical Imaging Datasets

Medical imaging data is usually organized in a more complex way than ordinary image datasets. Let's first look at the common CT data formats:

  • DICOM: the standard format for medical imaging; stores the pixel data together with patient and acquisition metadata
  • NIfTI: widely used in neuroimaging, mostly in research settings
  • NRRD: well suited to storing multi-dimensional medical data

For lung CT we work with a stack of axial slices; a single patient may have anywhere from a few dozen to several hundred slices.
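
Whatever the format, the first step is always to get the scan into a numpy volume. As a small example (the file name is a placeholder), SimpleITK reads a NIfTI file in one call and exposes the same array conversion we will use for DICOM below:

import SimpleITK as sitk

nifti_image = sitk.ReadImage("scan.nii.gz")            # placeholder path
nifti_array = sitk.GetArrayFromImage(nifti_image)      # numpy array, shape [depth, height, width]
print(nifti_array.shape, nifti_image.GetSpacing())     # spacing is (x, y, z), typically in mm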

3.2 Data Preprocessing

Medical image preprocessing typically involves the following steps:

  1. Loading: parse DICOM or another medical imaging format
  2. Windowing: clip the CT values to highlight the tissue of interest (a lung window is typically -1000 to 400 HU)
  3. Resampling: bring scans of different resolutions to a common voxel spacing
  4. Cropping: remove irrelevant regions and keep only the lungs
  5. Normalization: scale voxel values into a range suitable for neural networks
import numpy as np
import pydicom
import glob
import os
import SimpleITK as sitk
from skimage import measure
from scipy import ndimage

def load_dicom_series(directory):
    """
    加载DICOM系列文件并转换为3D体积
    
    参数:
        directory: 包含DICOM文件的目录
    
    返回:
        3D numpy数组,形状为 [深度, 高度, 宽度]
    """
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(directory)
    reader.SetFileNames(dicom_names)
    image = reader.Execute()
    
    # 转换为numpy数组
    array = sitk.GetArrayFromImage(image)
    return array, image

def apply_lung_window(ct_scan, min_bound=-1000, max_bound=400):
    """
    应用肺窗口值
    
    参数:
        ct_scan: CT扫描的3D numpy数组
        min_bound: HU值下限
        max_bound: HU值上限
    
    返回:
        窗口化和归一化后的CT扫描
    """
    # 截断HU值
    ct_scan = np.clip(ct_scan, min_bound, max_bound)
    
    # 归一化到[0,1]
    ct_scan = (ct_scan - min_bound) / (max_bound - min_bound)
    
    return ct_scan
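
# Worked example: a voxel of -600 HU (typical aerated lung tissue) maps to
# (-600 - (-1000)) / (400 - (-1000)) = 400 / 1400 ≈ 0.286 after lung windowing,
# while anything denser than 400 HU (e.g. bone) saturates at 1.0.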

def resample_volume(img, spacing, new_spacing=[1.0, 1.0, 1.0]):
    """
    重采样CT体积到指定的体素间距
    
    参数:
        img: SimpleITK图像对象
        spacing: 原始体素间距
        new_spacing: 目标体素间距
        
    返回:
        重采样后的SimpleITK图像对象
    """
    # 计算新的尺寸
    spacing = np.array(spacing)
    new_spacing = np.array(new_spacing)
    orig_size = np.array(img.GetSize())
    resize_factor = spacing / new_spacing
    new_size = orig_size * resize_factor
    new_size = np.round(new_size).astype(int)
    
    # Perform the resampling
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing.tolist())
    resample.SetSize(new_size.tolist())
    resample.SetOutputDirection(img.GetDirection())
    resample.SetOutputOrigin(img.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # Fill value for voxels that fall outside the original volume (-1000 HU, i.e. air)
    resample.SetDefaultPixelValue(-1000)
    resample.SetInterpolator(sitk.sitkLinear)
    
    return resample.Execute(img)

def segment_lungs(ct_scan, fill_lung_structures=True):
    """
    分割肺部区域
    
    参数:
        ct_scan: CT扫描的3D numpy数组
        fill_lung_structures: 是否填充肺内结构
        
    返回:
        肺部掩码和应用掩码后的CT扫描
    """
    # 阈值化得到二值图像
    binary_image = np.array(ct_scan < -320, dtype=np.int8)
    
    # 标记所有区域
    labels = measure.label(binary_image)
    
    # 假设肺区域不是最大的连通区域
    # 背景通常是最大的连通区域
    background_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
    binary_image[labels == background_label] = 0
    
    # Fill holes inside the lung regions, slice by slice
    if fill_lung_structures:
        for i in range(ct_scan.shape[0]):
            binary_image[i] = ndimage.binary_fill_holes(binary_image[i])
    
    # 创建肺部掩码
    lung_mask = binary_image
    
    # 应用掩码到原始CT扫描
    masked_ct = ct_scan * lung_mask
    
    return lung_mask, masked_ct

def normalize_scan(ct_scan):
    """
    标准化CT扫描值到0-1范围
    
    参数:
        ct_scan: CT扫描的3D numpy数组
        
    返回:
        标准化后的CT扫描
    """
    ct_scan = ct_scan.astype(np.float32)
    ct_scan = (ct_scan - np.min(ct_scan)) / (np.max(ct_scan) - np.min(ct_scan))
    return ct_scan

def preprocess_ct_scan(dicom_dir, output_size=(128, 128, 128)):
    """
    完整的CT扫描预处理流程
    
    参数:
        dicom_dir: DICOM文件目录
        output_size: 输出体积的尺寸
        
    返回:
        预处理后的CT扫描,准备用于深度学习模型
    """
    # Load the DICOM series (raw HU values)
    ct_array, ct_image = load_dicom_series(dicom_dir)
    
    # Resample to a uniform voxel spacing first; the following steps expect HU values
    spacing = ct_image.GetSpacing()
    resampled_ct_image = resample_volume(ct_image, spacing)
    resampled_ct = sitk.GetArrayFromImage(resampled_ct_image)
    
    # Lung segmentation (the threshold operates on HU values)
    lung_mask, masked_ct = segment_lungs(resampled_ct)
    
    # Apply the lung window to the masked volume, then normalize to [0, 1]
    windowed_ct = apply_lung_window(masked_ct)
    normalized_ct = normalize_scan(windowed_ct)
    
    # 调整到目标大小
    # 这里使用简单的缩放,实际应用中可能需要更复杂的方法
    from scipy.ndimage import zoom
    resize_factor = np.array(output_size) / np.array(normalized_ct.shape)
    final_ct = zoom(normalized_ct, resize_factor, order=1)
    
    return final_ct
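
In practice the pipeline above is usually run once offline and the results cached, because the dataset class later in this post expects preprocessed .npy volumes. A minimal caching sketch, assuming one DICOM directory per patient (the directory names are placeholders):

import os
import numpy as np

dicom_root = "data/raw"              # one sub-directory per patient (assumed layout)
output_root = "data/preprocessed"
os.makedirs(output_root, exist_ok=True)

for patient_id in sorted(os.listdir(dicom_root)):
    volume = preprocess_ct_scan(os.path.join(dicom_root, patient_id),
                                output_size=(128, 128, 128))
    np.save(os.path.join(output_root, f"{patient_id}.npy"), volume.astype(np.float32))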

3.3 Data Augmentation for Medical Images

Augmenting medical images has to be done with particular care so that no anatomically implausible changes are introduced. The following strategies are suitable for lung CT:

| Augmentation | Description | Suitability |
| --- | --- | --- |
| Rotation | Small-angle rotation around each axis | High; lung diagnosis usually does not depend on orientation |
| Scaling | Slight volume scaling | Medium; anatomical proportions must remain plausible |
| Brightness/contrast | Slight adjustment of the CT window | High; simulates different scanner settings |
| Noise | Additive Gaussian noise | Medium; the main features must stay clearly visible |
| Elastic deformation | Local non-rigid deformation | Low; can introduce unrealistic lesion shapes |
| Random cropping | Crop sub-volumes from the original volume | High; well suited to large volumes |
import torch
import numpy as np
from scipy.ndimage import rotate, zoom, shift
import elasticdeform
from torchvision import transforms

class CTAugmentation3D:
    """
    用于3D医学影像(特别是CT)的数据增强类
    """
    def __init__(self, 
                 rotation_range=(-10, 10),
                 scale_range=(0.9, 1.1),
                 shift_range=(-5, 5),
                 noise_factor=0.05,
                 brightness_range=(0.9, 1.1),
                 contrast_range=(0.9, 1.1),
                 p_rotation=0.5,
                 p_scale=0.5,
                 p_shift=0.5, 
                 p_noise=0.3,
                 p_brightness=0.3,
                 p_contrast=0.3,
                 p_elastic=0.2):
        """
        初始化3D增强器
        
        参数:
            rotation_range: 旋转角度范围(度)
            scale_range: 缩放因子范围
            shift_range: 平移像素范围
            noise_factor: 噪声强度系数
            brightness_range: 亮度调整范围
            contrast_range: 对比度调整范围
            p_*: 各增强方法的应用概率
        """
        self.rotation_range = rotation_range
        self.scale_range = scale_range
        self.shift_range = shift_range
        self.noise_factor = noise_factor
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range
        
        self.p_rotation = p_rotation
        self.p_scale = p_scale
        self.p_shift = p_shift
        self.p_noise = p_noise
        self.p_brightness = p_brightness
        self.p_contrast = p_contrast
        self.p_elastic = p_elastic
    
    def apply_rotation(self, volume):
        """应用随机旋转"""
        # 为每个轴随机生成旋转角度
        angles = [np.random.uniform(self.rotation_range[0], 
                                   self.rotation_range[1]) for _ in range(3)]
        
        # 沿着每个轴旋转
        for i, angle in enumerate(angles):
            axes = tuple([j for j in range(3) if j != i])
            volume = rotate(volume, angle, axes=axes, reshape=False, order=1, mode='nearest')
        
        return volume
    
    def apply_scaling(self, volume):
        """应用随机缩放"""
        # 为每个维度随机生成缩放因子
        scale_factor = np.random.uniform(self.scale_range[0], self.scale_range[1])
        
        # 应用缩放
        volume = zoom(volume, scale_factor, order=1)
        
        # 确保大小一致(如果缩放后尺寸变化)
        if volume.shape != self.original_shape:
            # 计算需要裁剪或padding的量
            diffs = np.array(volume.shape) - np.array(self.original_shape)
            
            # 裁剪或padding
            result = np.zeros(self.original_shape)
            
            # 为每个维度确定切片范围
            slices_src = []
            slices_dst = []
            
            for i in range(3):
                if diffs[i] > 0:  # 需要裁剪
                    # 从中心裁剪
                    start_src = diffs[i] // 2
                    end_src = start_src + self.original_shape[i]
                    start_dst = 0
                    end_dst = self.original_shape[i]
                else:  # 需要padding
                    # 中心padding
                    start_src = 0
                    end_src = volume.shape[i]
                    start_dst = -diffs[i] // 2
                    end_dst = start_dst + volume.shape[i]
                
                slices_src.append(slice(start_src, end_src))
                slices_dst.append(slice(start_dst, end_dst))
            
            # 将缩放后的体积复制到结果中
            result[tuple(slices_dst)] = volume[tuple(slices_src)]
            volume = result
        
        return volume
    
    def apply_shift(self, volume):
        """应用随机平移"""
        shifts = [np.random.uniform(self.shift_range[0], 
                                   self.shift_range[1]) for _ in range(3)]
        return shift(volume, shifts, order=1, mode='nearest')
    
    def apply_noise(self, volume):
        """添加高斯噪声"""
        noise = np.random.normal(0, self.noise_factor, volume.shape)
        volume = volume + noise
        volume = np.clip(volume, 0, 1)  # 确保值在有效范围内
        return volume
    
    def apply_brightness(self, volume):
        """调整亮度"""
        factor = np.random.uniform(self.brightness_range[0], self.brightness_range[1])
        volume = volume * factor
        volume = np.clip(volume, 0, 1)
        return volume
    
    def apply_contrast(self, volume):
        """调整对比度"""
        factor = np.random.uniform(self.contrast_range[0], self.contrast_range[1])
        mean = np.mean(volume)
        volume = (volume - mean) * factor + mean
        volume = np.clip(volume, 0, 1)
        return volume
    
    def apply_elastic_deformation(self, volume):
        """应用弹性变形"""
        # 为3D体积生成变形场
        # sigma控制变形的平滑度,较大的值产生更平滑的变形
        # points控制变形网格的粗细,较大的值产生更精细的变形
        deformed_volume = elasticdeform.deform_random_grid(
            volume, 
            sigma=3, 
            points=3,
            order=1,
            mode='nearest'
        )
        return deformed_volume
    
    def __call__(self, volume):
        """
        对输入的3D体积应用增强
        
        参数:
            volume: numpy数组,形状为[D, H, W]
            
        返回:
            增强后的体积
        """
        # 保存原始形状以便于缩放后的形状修正
        self.original_shape = volume.shape
        
        # 应用各种增强,每种增强都有一定概率应用
        if np.random.random() < self.p_rotation:
            volume = self.apply_rotation(volume)
            
        if np.random.random() < self.p_scale:
            volume = self.apply_scaling(volume)
            
        if np.random.random() < self.p_shift:
            volume = self.apply_shift(volume)
            
        if np.random.random() < self.p_noise:
            volume = self.apply_noise(volume)
            
        if np.random.random() < self.p_brightness:
            volume = self.apply_brightness(volume)
            
        if np.random.random() < self.p_contrast:
            volume = self.apply_contrast(volume)
            
        if np.random.random() < self.p_elastic:
            volume = self.apply_elastic_deformation(volume)
        
        return volume

# PyTorch的3D CT数据集类
class LungCTDataset(torch.utils.data.Dataset):
    def __init__(self, ct_paths, labels=None, transform=None, phase='train'):
        """
        肺部CT数据集
        
        参数:
            ct_paths: CT数据路径列表
            labels: 对应的标签列表
            transform: 数据增强和转换
            phase: 'train', 'val' 或 'test'
        """
        self.ct_paths = ct_paths
        self.labels = labels
        self.transform = transform
        self.phase = phase
        
    def __len__(self):
        return len(self.ct_paths)
    
    def __getitem__(self, idx):
        # 加载预处理好的CT体积
        # 假设每个路径是一个.npy文件,包含预处理好的CT体积
        ct_volume = np.load(self.ct_paths[idx])
        
        # 应用数据增强
        if self.transform and self.phase == 'train':
            ct_volume = self.transform(ct_volume)
        
        # 确保数据是浮点数并且形状正确([C, D, H, W])
        ct_volume = ct_volume.astype(np.float32)
        ct_volume = np.expand_dims(ct_volume, axis=0)  # 添加通道维度
        
        # 转换为PyTorch张量
        ct_tensor = torch.from_numpy(ct_volume)
        
        # 返回数据和标签(如果有)
        if self.labels is not None:
            label = self.labels[idx]
            return ct_tensor, label
        else:
            return ct_tensor
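
A quick way to check that the dataset and augmentation play together (the .npy paths are placeholders for volumes produced by the preprocessing step above):

from torch.utils.data import DataLoader

train_dataset = LungCTDataset(
    ct_paths=["data/preprocessed/patient_001.npy", "data/preprocessed/patient_002.npy"],
    labels=[0, 1],
    transform=CTAugmentation3D(),
    phase='train'
)
loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
volumes, targets = next(iter(loader))
print(volumes.shape, targets)   # torch.Size([2, 1, 128, 128, 128]) tensor([...])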

4. Designing Loss Functions for Class Imbalance

4.1 Common Approaches to Class Imbalance

| Method | Description | Pros | Cons |
| --- | --- | --- | --- |
| Undersampling | Drop majority-class samples | Shorter training time | Discards information; the model may underfit |
| Oversampling | Duplicate minority-class samples | Keeps all data | May overfit the minority class |
| Synthetic samples | Generate minority samples, e.g. with SMOTE | Balances the dataset without discarding data | Generated samples may be unrealistic |
| Class weights | Give the minority class a higher weight in the loss | Simple, effective, keeps all data | Weights need tuning |
| Focal Loss | Focuses on hard-to-classify samples | Reweights samples automatically | Hyperparameters need tuning |
| Combined sampling | Combine under- and oversampling | Balances the trade-offs of both | More complex to implement |

Because medical data is precious and expensive to acquire, plain undersampling is rarely used in medical imaging; instead, a combination of loss-function adjustments and oversampling is usually preferred, as sketched below.
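
One simple way to oversample without duplicating files is PyTorch's WeightedRandomSampler, which draws each training sample with probability inversely proportional to its class frequency. A minimal sketch, using a made-up label vector:

import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

labels = np.array([0] * 900 + [1] * 100)        # hypothetical imbalanced labels
class_counts = np.bincount(labels)              # [900, 100]
sample_weights = 1.0 / class_counts[labels]     # each sample weighted by 1 / its class count

sampler = WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(labels),    # one epoch still draws len(labels) samples
    replacement=True            # minority samples are drawn multiple times
)
# loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)  # sampler replaces shuffle=True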

4.2 Loss Functions for This Task

For lung CT analysis we will implement several loss functions that are suited to class imbalance:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class WeightedCrossEntropyLoss(nn.Module):
    """
    带类别权重的交叉熵损失
    适合处理类别不平衡问题
    """
    def __init__(self, weight=None, reduction='mean'):
        """
        参数:
            weight: 各类别的权重,通常少数类权重更高
            reduction: 'none', 'mean', 'sum'中的一个
        """
        super(WeightedCrossEntropyLoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
    
    def forward(self, input, target):
        return F.cross_entropy(
            input, target, 
            weight=self.weight, 
            reduction=self.reduction
        )

class FocalLoss(nn.Module):
    """
    Focal Loss(聚焦损失)
    通过降低易分类样本的权重,关注难以分类的样本
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        """
        参数:
            alpha: 各类别的权重
            gamma: 聚焦参数,越大对易分类样本的惩罚越大
            reduction: 'none', 'mean', 'sum'中的一个
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, input, target):
        # Keep the class weights on the same device as the logits
        alpha = self.alpha.to(input.device) if self.alpha is not None else None
        ce_loss = F.cross_entropy(input, target, reduction='none', weight=alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
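
# Intuition for gamma: with gamma = 2, an easy sample with p_t = 0.9 has its loss scaled
# by (1 - 0.9)^2 = 0.01, while a hard sample with p_t = 0.3 is scaled by (1 - 0.3)^2 = 0.49,
# so hard samples dominate the gradient.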

class DiceLoss(nn.Module):
    """
    Dice Loss
    常用于医学图像分割,也适用于分类问题
    """
    def __init__(self, smooth=1.0, reduction='mean'):
        """
        参数:
            smooth: 平滑系数,防止分母为0
            reduction: 'none', 'mean', 'sum'中的一个
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.reduction = reduction
        
    def forward(self, input, target):
        # Convert logits to class probabilities
        prob = F.softmax(input, dim=1)
        
        # One-hot encode the class targets: [B] -> [B, num_classes]
        target = F.one_hot(target, num_classes=input.size(1)).float()
        
        # Soft Dice coefficient per class, aggregated over the batch
        numerator = 2 * torch.sum(prob * target, dim=0) + self.smooth
        denominator = torch.sum(prob + target, dim=0) + self.smooth
        dice_coeff = numerator / denominator
        dice_loss = 1 - dice_coeff
        
        if self.reduction == 'mean':
            return dice_loss.mean()
        elif self.reduction == 'sum':
            return dice_loss.sum()
        else:
            return dice_loss

class CombinedLoss(nn.Module):
    """
    结合Focal Loss和Dice Loss的复合损失
    综合利用两种损失函数的优点
    """
    def __init__(self, alpha=None, gamma=2.0, weight_focal=0.5, weight_dice=0.5, smooth=1.0):
        """
        参数:
            alpha: Focal Loss的类别权重
            gamma: Focal Loss的聚焦参数
            weight_focal: Focal Loss的权重
            weight_dice: Dice Loss的权重
            smooth: Dice Loss的平滑系数
        """
        super(CombinedLoss, self).__init__()
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)
        self.dice_loss = DiceLoss(smooth=smooth)
        self.weight_focal = weight_focal
        self.weight_dice = weight_dice
        
    def forward(self, input, target):
        return (
            self.weight_focal * self.focal_loss(input, target) + 
            self.weight_dice * self.dice_loss(input, target)
        )

class AsymmetricLoss(nn.Module):
    """
    非对称损失
    对不同类别使用不同的gamma值,更加灵活地处理类别不平衡
    """
    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, reduction='mean'):
        """
        参数:
            gamma_pos: 正类的gamma值
            gamma_neg: 负类的gamma值
            clip: 截断阈值
            reduction: 'none', 'mean', 'sum'中的一个
        """
        super(AsymmetricLoss, self).__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.reduction = reduction
    
    def forward(self, input, target):
        # 将目标转换为one-hot编码
        target = F.one_hot(target, num_classes=input.size(1)).float()
        
        # Sigmoid输出
        prob = torch.sigmoid(input)
        
        # 裁剪概率,增加数值稳定性
        prob = torch.clamp(prob, self.clip, 1.0 - self.clip)
        
        # 计算正样本和负样本的聚焦因子
        pos_loss = target * torch.log(prob) * (1 - prob) ** self.gamma_pos
        neg_loss = (1 - target) * torch.log(1 - prob) * prob ** self.gamma_neg
        
        loss = -(pos_loss + neg_loss)
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

def calculate_class_weights(labels, method='inverse', beta=0.999):
    """
    计算类别权重
    
    参数:
        labels: 训练集标签列表
        method: 计算方法,'inverse'(反比例)或'effective'(有效样本数)
        beta: 有效样本数方法的平衡因子
        
    返回:
        各类别的权重
    """
    # 计算每个类别的样本数
    class_counts = np.bincount(labels)
    n_classes = len(class_counts)
    
    if method == 'inverse':
        # 权重与类别频率成反比
        weights = 1.0 / np.array(class_counts)
        # 归一化权重
        weights = weights / np.sum(weights) * n_classes
    
    elif method == 'effective':
        # 使用有效样本数计算权重
        effective_num = 1.0 - np.power(beta, class_counts)
        weights = (1.0 - beta) / effective_num
        # 归一化权重
        weights = weights / np.sum(weights) * n_classes
    
    return torch.FloatTensor(weights)
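
As a quick sanity check of the weighting, 900 negative and 100 positive labels give inverse-frequency weights of [0.2, 1.8] after rescaling so that the weights sum to the number of classes:

import numpy as np

labels = np.array([0] * 900 + [1] * 100)
print(calculate_class_weights(labels, method='inverse'))
# tensor([0.2000, 1.8000]) -> each minority sample contributes 9x more to the loss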

4.3 Choosing a Loss Function

Typical scenarios for each loss function in lung CT diagnosis:

| Loss function | Typical scenario | Strength |
| --- | --- | --- |
| Weighted cross-entropy | Moderate imbalance | Simple and effective, easy to understand and tune |
| Focal Loss | Severe imbalance | Adaptively focuses on hard examples and down-weights easy ones |
| Dice Loss | Binary classification | Insensitive to class ratio; measures overlap |
| Combined loss | Complex imbalance | Combines the strengths of several losses |
| Asymmetric loss | Extreme imbalance | Separate focusing parameters for positives and negatives |

A rule of thumb: when positive samples make up less than 10% of the data, consider Focal Loss or a combined loss; between 10% and 30%, weighted cross-entropy is usually enough; and when recall is the priority, Dice Loss is a good choice.
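
A hedged translation of this rule of thumb into code; the thresholds are just the ones quoted above, and the helper name is made up for illustration:

import numpy as np
import torch.nn as nn

def pick_loss(labels, class_weights=None, prefer_recall=False):
    """Illustrative helper that applies the heuristic above."""
    positive_ratio = np.mean(np.array(labels) == 1)
    if prefer_recall:
        return DiceLoss()
    if positive_ratio < 0.10:
        return FocalLoss(alpha=class_weights, gamma=2.0)
    if positive_ratio < 0.30:
        return WeightedCrossEntropyLoss(weight=class_weights)
    return nn.CrossEntropyLoss()

# pick_loss([0] * 950 + [1] * 50) -> FocalLoss, since positives are only 5%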

5. The Complete Training Pipeline

Now let's put the previous components together into a complete training pipeline for lung CT analysis:

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import time
import random
from tensorboardX import SummaryWriter

# 设置随机种子,确保可重复性
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class LungCTTrainer:
    def __init__(self, model, train_dataset, val_dataset, test_dataset=None, 
                 batch_size=8, lr=0.001, loss_fn=None, device=None, 
                 class_weights=None, experiment_name="lung_ct_analysis"):
        """
        肺部CT分析训练器
        
        参数:
            model: 3D ResNet模型
            train_dataset: 训练数据集
            val_dataset: 验证数据集
            test_dataset: 测试数据集
            batch_size: 批处理大小
            lr: 学习率
            loss_fn: 损失函数
            device: 训练设备
            class_weights: 类别权重
            experiment_name: 实验名称
        """
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.lr = lr
        
        # 设置设备
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # 将模型移动到设备上
        self.model = self.model.to(self.device)
        
        # 设置损失函数
        if loss_fn is None:
            if class_weights is not None:
                self.loss_fn = WeightedCrossEntropyLoss(weight=class_weights.to(self.device))
            else:
                self.loss_fn = nn.CrossEntropyLoss()
        else:
            self.loss_fn = loss_fn
        
        # 设置优化器
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        # 学习率调度器
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )
        
        # 创建数据加载器
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, 
            num_workers=4, pin_memory=True
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, 
            num_workers=4, pin_memory=True
        )
        if test_dataset:
            self.test_loader = DataLoader(
                test_dataset, batch_size=batch_size, shuffle=False, 
                num_workers=4, pin_memory=True
            )
        else:
            self.test_loader = None
        
        # 设置TensorBoard
        self.writer = SummaryWriter(f"runs/{experiment_name}_{time.strftime('%Y%m%d_%H%M%S')}")
        
        # 训练状态跟踪
        self.best_val_loss = float('inf')
        self.best_model_path = f"models/{experiment_name}_best_model.pth"
        self.early_stop_patience = 15
        self.early_stop_counter = 0
        
        # 创建保存模型的目录
        os.makedirs("models", exist_ok=True)
    
    def train_epoch(self, epoch):
        """训练一个epoch"""
        self.model.train()
        running_loss = 0.0
        all_preds = []
        all_targets = []
        
        # 使用tqdm创建进度条
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1} [Train]")
        
        for inputs, targets in pbar:
            # 将数据移到设备上
            inputs = inputs.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)
            
            # 清零梯度
            self.optimizer.zero_grad()
            
            # 前向传播
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            
            # 反向传播
            loss.backward()
            
            # 梯度裁剪,防止梯度爆炸
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            # 更新参数
            self.optimizer.step()
            
            # 统计
            running_loss += loss.item() * inputs.size(0)
            
            # 收集预测和目标
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            
            # 更新进度条
            pbar.set_postfix({"loss": loss.item()})
        
        # 计算平均损失和评估指标
        epoch_loss = running_loss / len(self.train_dataset)
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        
        # 记录到TensorBoard
        self.writer.add_scalar('Loss/train', epoch_loss, epoch)
        self.writer.add_scalar('Accuracy/train', epoch_acc, epoch)
        self.writer.add_scalar('Precision/train', epoch_prec, epoch)
        self.writer.add_scalar('Recall/train', epoch_recall, epoch)
        self.writer.add_scalar('F1/train', epoch_f1, epoch)
        
        return epoch_loss, epoch_acc, epoch_prec, epoch_recall, epoch_f1
    
    def validate_epoch(self, epoch):
        """验证一个epoch"""
        self.model.eval()
        running_loss = 0.0
        all_preds = []
        all_targets = []
        all_probs = []
        
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc=f"Epoch {epoch+1} [Val]")
            for inputs, targets in pbar:
                # 将数据移到设备上
                inputs = inputs.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
                
                # 前向传播
                outputs = self.model(inputs)
                loss = self.loss_fn(outputs, targets)
                
                # 统计
                running_loss += loss.item() * inputs.size(0)
                
                # 收集预测、目标和概率
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
                
                # 更新进度条
                pbar.set_postfix({"loss": loss.item()})
        
        # 计算平均损失和评估指标
        epoch_loss = running_loss / len(self.val_dataset)
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        
        # 如果是二分类问题,计算AUC
        if len(np.unique(all_targets)) == 2:
            epoch_auc = roc_auc_score(all_targets, np.array(all_probs)[:, 1])
            self.writer.add_scalar('AUC/val', epoch_auc, epoch)
        else:
            epoch_auc = None
        
        # 记录到TensorBoard
        self.writer.add_scalar('Loss/val', epoch_loss, epoch)
        self.writer.add_scalar('Accuracy/val', epoch_acc, epoch)
        self.writer.add_scalar('Precision/val', epoch_prec, epoch)
        self.writer.add_scalar('Recall/val', epoch_recall, epoch)
        self.writer.add_scalar('F1/val', epoch_f1, epoch)
        
        # 更新学习率
        self.scheduler.step(epoch_loss)
        
        # 保存最佳模型
        if epoch_loss < self.best_val_loss:
            self.best_val_loss = epoch_loss
            torch.save(self.model.state_dict(), self.best_model_path)
            print(f"Best model saved with val loss: {epoch_loss:.4f}")
            self.early_stop_counter = 0
        else:
            self.early_stop_counter += 1
        
        return epoch_loss, epoch_acc, epoch_prec, epoch_recall, epoch_f1, epoch_auc
    
    def train(self, epochs=100):
        """训练模型"""
        print(f"Starting training for {epochs} epochs...")
        
        # 训练历史记录
        history = {
            'train_loss': [], 'train_acc': [], 'train_prec': [],
            'train_recall': [], 'train_f1': [],
            'val_loss': [], 'val_acc': [], 'val_prec': [],
            'val_recall': [], 'val_f1': [], 'val_auc': []
        }
        
        # 训练循环
        for epoch in range(epochs):
            # 训练阶段
            train_loss, train_acc, train_prec, train_recall, train_f1 = self.train_epoch(epoch)
            
            # 验证阶段
            val_loss, val_acc, val_prec, val_recall, val_f1, val_auc = self.validate_epoch(epoch)
            
            # 记录历史
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['train_prec'].append(train_prec)
            history['train_recall'].append(train_recall)
            history['train_f1'].append(train_f1)
            
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['val_prec'].append(val_prec)
            history['val_recall'].append(val_recall)
            history['val_f1'].append(val_f1)
            history['val_auc'].append(val_auc)
            
            # 打印当前结果
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
            if val_auc:
                print(f"Val AUC: {val_auc:.4f}")
            print("-" * 50)
            
            # 早停检查
            if self.early_stop_counter >= self.early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
        
        # 训练完成,关闭TensorBoard writer
        self.writer.close()
        
        # 绘制训练历史
        self.plot_training_history(history)
        
        return history
    
    def plot_training_history(self, history):
        """绘制训练历史"""
        # 创建一个2x2的子图布局
        fig, axes = plt.subplots(2, 2, figsize=(18, 12))
        
        # 损失图
        axes[0, 0].plot(history['train_loss'], label='Train Loss')
        axes[0, 0].plot(history['val_loss'], label='Val Loss')
        axes[0, 0].set_title('Loss')
        axes[0, 0].set_xlabel('Epochs')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        # 准确率图
        axes[0, 1].plot(history['train_acc'], label='Train Accuracy')
        axes[0, 1].plot(history['val_acc'], label='Val Accuracy')
        axes[0, 1].set_title('Accuracy')
        axes[0, 1].set_xlabel('Epochs')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        # F1分数图
        axes[1, 0].plot(history['train_f1'], label='Train F1')
        axes[1, 0].plot(history['val_f1'], label='Val F1')
        axes[1, 0].set_title('F1 Score')
        axes[1, 0].set_xlabel('Epochs')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True)
        
        # AUC图(如果有)
        if None not in history['val_auc']:
            axes[1, 1].plot(history['val_auc'], label='Val AUC')
            axes[1, 1].set_title('AUC')
            axes[1, 1].set_xlabel('Epochs')
            axes[1, 1].set_ylabel('AUC')
            axes[1, 1].legend()
            axes[1, 1].grid(True)
        else:
            # 如果没有AUC,可以绘制其他指标
            axes[1, 1].plot(history['train_prec'], label='Train Precision')
            axes[1, 1].plot(history['val_prec'], label='Val Precision')
            axes[1, 1].set_title('Precision')
            axes[1, 1].set_xlabel('Epochs')
            axes[1, 1].set_ylabel('Precision')
            axes[1, 1].legend()
            axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.savefig('training_history.png')
        plt.show()
    
    def test(self, load_best_model=True):
        """测试模型"""
        if self.test_loader is None:
            print("No test dataset provided.")
            return None
        
        # 加载最佳模型
        if load_best_model:
            self.model.load_state_dict(torch.load(self.best_model_path))
            print(f"Loaded best model from {self.best_model_path}")
        
        self.model.eval()
        all_preds = []
        all_targets = []
        all_probs = []
        
        with torch.no_grad():
            for inputs, targets in tqdm(self.test_loader, desc="Testing"):
                # 将数据移到设备上
                inputs = inputs.to(self.device, non_blocking=True)
                
                # 前向传播
                outputs = self.model(inputs)
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)
                
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.numpy())
                all_probs.extend(probs.cpu().numpy())
        
        # 计算评估指标
        acc = accuracy_score(all_targets, all_preds)
        prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
        f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        
        # 如果是二分类问题,计算AUC
        if len(np.unique(all_targets)) == 2:
            auc = roc_auc_score(all_targets, np.array(all_probs)[:, 1])
        else:
            auc = None
        
        # 打印结果
        print("\nTest Results:")
        print(f"Accuracy: {acc:.4f}")
        print(f"Precision: {prec:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1 Score: {f1:.4f}")
        if auc:
            print(f"AUC: {auc:.4f}")
        
        return {
            'accuracy': acc,
            'precision': prec,
            'recall': recall,
            'f1': f1,
            'auc': auc,
            'predictions': all_preds,
            'targets': all_targets,
            'probabilities': all_probs
        }

# 使用示例
def main():
    # 设置随机种子
    set_seed(42)
    
    # 假设我们已经有预处理好的数据
    # 这里仅作示例,实际使用需替换为真实数据路径
    ct_paths = ["path/to/ct1.npy", "path/to/ct2.npy", "..."]
    labels = [0, 1, "..."]  # 0代表正常,1代表疾病
    
    # 计算类别权重
    class_weights = calculate_class_weights(labels, method='effective')
    
    # 创建数据增强器
    augmentation = CTAugmentation3D(
        rotation_range=(-10, 10),
        scale_range=(0.9, 1.1),
        shift_range=(-5, 5),
        noise_factor=0.03,
        brightness_range=(0.9, 1.1),
        contrast_range=(0.9, 1.1)
    )
    
    # Split the patient list first so augmentation is applied only to the training set
    indices = np.random.permutation(len(ct_paths))
    train_end = int(0.7 * len(ct_paths))
    val_end = train_end + int(0.15 * len(ct_paths))
    train_idx, val_idx, test_idx = indices[:train_end], indices[train_end:val_end], indices[val_end:]
    
    def take(idx):
        return [ct_paths[i] for i in idx], [labels[i] for i in idx]
    
    train_paths, train_labels = take(train_idx)
    val_paths, val_labels = take(val_idx)
    test_paths, test_labels = take(test_idx)
    
    train_dataset = LungCTDataset(train_paths, train_labels, transform=augmentation, phase='train')
    val_dataset = LungCTDataset(val_paths, val_labels, phase='val')
    test_dataset = LungCTDataset(test_paths, test_labels, phase='test')
    
    # 创建模型
    model = resnet18_3d(num_classes=2)
    
    # 创建损失函数
    # 对于严重类别不平衡,可以使用Focal Loss
    loss_fn = FocalLoss(alpha=class_weights, gamma=2.0)
    
    # 创建训练器
    trainer = LungCTTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        batch_size=8,
        lr=0.001,
        loss_fn=loss_fn,
        class_weights=class_weights,
        experiment_name="lung_ct_3d_resnet"
    )
    
    # 训练模型
    history = trainer.train(epochs=50)
    
    # 测试模型
    test_results = trainer.test()
    
    # 打印测试结果汇总
    print("\nTest Results Summary:")
    print(f"Accuracy: {test_results['accuracy']:.4f}")
    print(f"F1 Score: {test_results['f1']:.4f}")
    if test_results['auc']:
        print(f"AUC: {test_results['auc']:.4f}")

if __name__ == "__main__":
    main()

How did you like today's content? Thanks again for reading, and if it helped, please leave a like!
