PyTorch Deep Learning Framework 60-Day Advanced Study Plan - Day 36: Medical Image Diagnosis (Part 1)
Friends, I can hardly believe we've made it to Day 36! Today we step into a field that is as challenging as it is meaningful: medical image diagnosis. We will learn how to analyze lung CT scans with a 3D ResNet, explore data augmentation techniques suited to medical images, and tackle the class imbalance that is so common in medical data.
There is a running joke in medical AI: "When an ordinary AI model makes a mistake, it might just call a cat a dog; when a medical AI makes a mistake, it might send a healthy person to the ICU." So let's begin today's lesson with the respect the subject deserves.
1. Overview of Medical Image Diagnosis
Medical image diagnosis is one of the most promising applications of AI in healthcare. Unlike ordinary images, medical images typically have the following characteristics:
- Higher dimensionality: CT and MRI scans are 3D volumes, not simple 2D images
- Data scarcity: labeled medical data is far scarcer than ordinary image datasets
- Class imbalance: diseased samples are usually far fewer than healthy ones
- High accuracy requirements: medical diagnosis demands very high accuracy and tolerates little error
Today we focus on lung CT analysis, which plays an important role in diagnosing lung cancer, pneumonia, and COVID-19.
2. Designing a 3D ResNet
2.1 Why ResNet?
Medical images usually require extracting complex features. ResNet's residual connections effectively mitigate the vanishing-gradient problem in deep networks, which lets us build deeper models. Medical findings also tend to hide in subtle variations, and ResNet's strong feature-extraction ability makes it an ideal choice.
2.2 From 2D to 3D
Converting a 2D ResNet into its 3D version mainly involves the following substitutions (a short sketch follows the table):
2D component | 3D counterpart | What changes |
---|---|---|
Conv2d | Conv3d | Kernel goes from (k×k) to (k×k×k) |
MaxPool2d | MaxPool3d | Pooling window goes from (k×k) to (k×k×k) |
BatchNorm2d | BatchNorm3d | Normalization gains one spatial dimension |
AdaptiveAvgPool2d | AdaptiveAvgPool3d | Adaptive pooling gains one spatial dimension |
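To make the mapping concrete, here is a minimal sketch (separate from the model below) showing how each 2D layer is swapped for its 3D counterpart; the layer hyperparameters are only illustrative:
import torch.nn as nn

# 2D layer and its 3D counterpart: only the class name (and the extra depth axis) changes
conv2d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)   # 7x7 kernel
conv3d = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3)   # 7x7x7 kernel
pool3d = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)       # 3x3x3 pooling window
bn3d = nn.BatchNorm3d(64)                                       # normalizes over N, D, H, W per channel
gap3d = nn.AdaptiveAvgPool3d((1, 1, 1))                         # global average pooling over D, H, W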
2.3 Basic Structure of the 3D ResNet
Our 3D ResNet consists of the following parts:
- Initial convolution layer: captures basic features
- Residual blocks: extract complex features while mitigating vanishing gradients
- Global pooling layer: reduces dimensionality while keeping the important features
- Fully connected layer: performs the final classification
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock3D(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1, downsample=None):
super(BasicBlock3D, self).__init__()
self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv3d(planes, planes, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck3D(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, downsample=None):
super(Bottleneck3D, self).__init__()
self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes)
self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(planes)
self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm3d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet3D(nn.Module):
def __init__(self, block, layers, num_classes=2, zero_init_residual=False):
super(ResNet3D, self).__init__()
self.in_planes = 64
# Initial convolution layer (single-channel CT input)
self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm3d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
# Residual stages
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# Classification head
self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
# Weight initialization
for m in self.modules():
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Optional special initialization of the last BN in each residual block
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck3D):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock3D):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.in_planes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv3d(self.in_planes, planes * block.expansion, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm3d(planes * block.expansion),
)
layers = []
layers.append(block(self.in_planes, planes, stride, downsample))
self.in_planes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.in_planes, planes))
return nn.Sequential(*layers)
def forward(self, x):
# Input processing and initial feature extraction
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
# Feature extraction stages
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# Classification head
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def resnet18_3d(num_classes=2, **kwargs):
"""18层3D ResNet"""
return ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=num_classes, **kwargs)
def resnet34_3d(num_classes=2, **kwargs):
"""34层3D ResNet"""
return ResNet3D(BasicBlock3D, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
def resnet50_3d(num_classes=2, **kwargs):
"""50层3D ResNet"""
return ResNet3D(Bottleneck3D, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
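Before moving on to the data, it is worth sanity-checking the model with a dummy input. This is just a quick shape check under assumed input dimensions (a batch of two single-channel 64x64x64 volumes), not a real CT:
import torch

model = resnet18_3d(num_classes=2)
dummy = torch.randn(2, 1, 64, 64, 64)  # [batch, channel, depth, height, width]
logits = model(dummy)
print(logits.shape)  # expected: torch.Size([2, 2])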
3. Medical Image Data Processing and Augmentation
3.1 Medical Imaging Datasets
Medical imaging data is usually organized in a more complex way than ordinary images. Let's first look at the common CT data formats:
- DICOM: the standard format for medical imaging; stores the image together with patient metadata
- NIfTI: common in neuroimaging, used mostly in research
- NRRD: well suited to storing multi-dimensional medical data
For lung CT we work with a stack of axial slices; a single patient may have anywhere from a few dozen to several hundred slices. (A short SimpleITK reading example follows.)
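As a quick illustration, SimpleITK (used below for DICOM as well) can read NIfTI and NRRD files directly; the file name here is only a placeholder:
import SimpleITK as sitk

image = sitk.ReadImage("chest_ct.nii.gz")   # hypothetical NIfTI file; .nrrd works the same way
volume = sitk.GetArrayFromImage(image)      # numpy array with axes [depth, height, width]
print(volume.shape, image.GetSpacing())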
3.2 Data Preprocessing
Preprocessing of medical images usually involves the following steps:
- Loading: parse DICOM or another medical image format
- Windowing: clip the CT values to highlight the tissue of interest (the lung window is typically -1000 to 400 HU)
- Resampling: bring CTs of different resolutions to a common voxel size
- Cropping: remove irrelevant regions and keep only the lungs
- Normalization: scale voxel values to a range suitable for neural networks
import numpy as np
import pydicom
import glob
import os
import SimpleITK as sitk
from skimage import measure
from scipy import ndimage
def load_dicom_series(directory):
"""
Load a DICOM series and convert it into a 3D volume.
Args:
directory: directory containing the DICOM files
Returns:
A 3D numpy array with shape [depth, height, width] and the SimpleITK image object.
"""
reader = sitk.ImageSeriesReader()
dicom_names = reader.GetGDCMSeriesFileNames(directory)
reader.SetFileNames(dicom_names)
image = reader.Execute()
# Convert to a numpy array
array = sitk.GetArrayFromImage(image)
return array, image
def apply_lung_window(ct_scan, min_bound=-1000, max_bound=400):
"""
Apply the lung window.
Args:
ct_scan: 3D numpy array of the CT scan
min_bound: lower HU bound
max_bound: upper HU bound
Returns:
The windowed and normalized CT scan.
"""
# Clip the HU values
ct_scan = np.clip(ct_scan, min_bound, max_bound)
# Normalize to [0, 1]
ct_scan = (ct_scan - min_bound) / (max_bound - min_bound)
return ct_scan
def resample_volume(img, spacing, new_spacing=[1.0, 1.0, 1.0]):
"""
Resample a CT volume to the given voxel spacing.
Args:
img: SimpleITK image object
spacing: original voxel spacing
new_spacing: target voxel spacing
Returns:
The resampled SimpleITK image object.
"""
# Compute the new size
spacing = np.array(spacing)
new_spacing = np.array(new_spacing)
orig_size = np.array(img.GetSize())
resize_factor = spacing / new_spacing
new_size = orig_size * resize_factor
new_size = np.round(new_size).astype(int)
# Perform the resampling
resample = sitk.ResampleImageFilter()
resample.SetOutputSpacing(new_spacing)
resample.SetSize(new_size.tolist())
resample.SetOutputDirection(img.GetDirection())
resample.SetOutputOrigin(img.GetOrigin())
resample.SetTransform(sitk.Transform())
resample.SetDefaultPixelValue(-1000)  # pad with air (in HU); GetPixelIDValue() returns the pixel type, not an intensity
resample.SetInterpolator(sitk.sitkLinear)
return resample.Execute(img)
def segment_lungs(ct_scan, fill_lung_structures=True):
"""
Segment the lung regions.
Args:
ct_scan: 3D numpy array of the CT scan (in HU)
fill_lung_structures: whether to fill structures inside the lungs
Returns:
The lung mask and the masked CT scan.
"""
# Threshold to obtain a binary image (air is below about -320 HU)
binary_image = np.array(ct_scan < -320, dtype=np.int8)
# Label all connected regions
labels = measure.label(binary_image)
# The background air around the body is usually the largest connected region,
# so remove it; the lungs are assumed not to be the largest region
background_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
binary_image[labels == background_label] = 0
# Fill holes inside the lungs slice by slice
if fill_lung_structures:
for i in range(ct_scan.shape[0]):
slc = binary_image[i]  # avoid shadowing the built-in `slice`
binary_image[i] = ndimage.binary_fill_holes(slc)
# Create the lung mask
lung_mask = binary_image
# Apply the mask to the CT scan
masked_ct = ct_scan * lung_mask
return lung_mask, masked_ct
def normalize_scan(ct_scan):
"""
Normalize the CT values to the range [0, 1].
Args:
ct_scan: 3D numpy array of the CT scan
Returns:
The normalized CT scan.
"""
ct_scan = ct_scan.astype(np.float32)
ct_scan = (ct_scan - np.min(ct_scan)) / (np.max(ct_scan) - np.min(ct_scan))
return ct_scan
def preprocess_ct_scan(dicom_dir, output_size=(128, 128, 128)):
"""
Full preprocessing pipeline for a CT scan.
Args:
dicom_dir: directory containing the DICOM files
output_size: size of the output volume
Returns:
The preprocessed CT scan, ready for a deep learning model.
"""
# Load the DICOM files
ct_array, ct_image = load_dicom_series(dicom_dir)
# Resample to a uniform resolution (values still in HU)
spacing = ct_image.GetSpacing()
resampled_ct_image = resample_volume(ct_image, spacing)
resampled_ct = sitk.GetArrayFromImage(resampled_ct_image)
# Lung segmentation (the thresholds assume HU values, so do this before windowing)
lung_mask, _ = segment_lungs(resampled_ct)
# Apply the lung window (clips to [-1000, 400] HU and scales to [0, 1])
windowed_ct = apply_lung_window(resampled_ct)
# Keep only the lung regions
masked_ct = windowed_ct * lung_mask
# Normalize
normalized_ct = normalize_scan(masked_ct)
# Resize to the target shape
# Simple interpolation is used here; real applications may need something more sophisticated
from scipy.ndimage import zoom
resize_factor = np.array(output_size) / np.array(normalized_ct.shape)
final_ct = zoom(normalized_ct, resize_factor, order=1)
return final_ct
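A typical way to use this pipeline is to preprocess each study once and cache the result as a .npy file for the Dataset defined below. The paths here are hypothetical:
import numpy as np

# Hypothetical paths; preprocess once, then reuse the cached .npy files during training
volume = preprocess_ct_scan("data/raw/patient_001", output_size=(128, 128, 128))
np.save("data/preprocessed/patient_001.npy", volume.astype(np.float32))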
3.3 Data Augmentation for Medical Images
Data augmentation for medical images must be applied with care so that no unrealistic changes are introduced. The following strategies are suitable for lung CT:
Augmentation | Description | Suitability |
---|---|---|
Rotation | Small-angle rotation around each axis | High; lung diagnosis usually does not depend on orientation |
Scaling | Slight volume scaling | Medium; anatomical proportions must stay plausible |
Brightness/contrast | Slight adjustment of the CT window | High; simulates different scanner settings |
Noise injection | Add Gaussian noise | Medium; the main features must remain clear |
Elastic deformation | Local non-rigid deformation | Low; may introduce unrealistic lesion shapes |
Random cropping | Crop sub-volumes from the original volume | High; well suited to large volumes |
import torch
import numpy as np
from scipy.ndimage import rotate, zoom, shift
import elasticdeform
from torchvision import transforms
class CTAugmentation3D:
"""
Data augmentation for 3D medical images (CT in particular).
"""
def __init__(self,
rotation_range=(-10, 10),
scale_range=(0.9, 1.1),
shift_range=(-5, 5),
noise_factor=0.05,
brightness_range=(0.9, 1.1),
contrast_range=(0.9, 1.1),
p_rotation=0.5,
p_scale=0.5,
p_shift=0.5,
p_noise=0.3,
p_brightness=0.3,
p_contrast=0.3,
p_elastic=0.2):
"""
Initialize the 3D augmenter.
Args:
rotation_range: rotation angle range in degrees
scale_range: scaling factor range
shift_range: translation range in voxels
noise_factor: noise strength
brightness_range: brightness adjustment range
contrast_range: contrast adjustment range
p_*: probability of applying each augmentation
"""
self.rotation_range = rotation_range
self.scale_range = scale_range
self.shift_range = shift_range
self.noise_factor = noise_factor
self.brightness_range = brightness_range
self.contrast_range = contrast_range
self.p_rotation = p_rotation
self.p_scale = p_scale
self.p_shift = p_shift
self.p_noise = p_noise
self.p_brightness = p_brightness
self.p_contrast = p_contrast
self.p_elastic = p_elastic
def apply_rotation(self, volume):
"""应用随机旋转"""
# 为每个轴随机生成旋转角度
angles = [np.random.uniform(self.rotation_range[0],
self.rotation_range[1]) for _ in range(3)]
# Rotate around each axis in turn
for i, angle in enumerate(angles):
axes = tuple([j for j in range(3) if j != i])
volume = rotate(volume, angle, axes=axes, reshape=False, order=1, mode='nearest')
return volume
def apply_scaling(self, volume):
"""应用随机缩放"""
# 为每个维度随机生成缩放因子
scale_factor = np.random.uniform(self.scale_range[0], self.scale_range[1])
# Apply the scaling
volume = zoom(volume, scale_factor, order=1)
# Restore the original size if scaling changed the shape
if volume.shape != self.original_shape:
# Compute how much to crop or pad
diffs = np.array(volume.shape) - np.array(self.original_shape)
# Crop or pad
result = np.zeros(self.original_shape)
# Determine the slice ranges for each dimension
slices_src = []
slices_dst = []
for i in range(3):
if diffs[i] > 0: # need to crop
# Crop from the center
start_src = diffs[i] // 2
end_src = start_src + self.original_shape[i]
start_dst = 0
end_dst = self.original_shape[i]
else: # need to pad
# Pad symmetrically around the center
start_src = 0
end_src = volume.shape[i]
start_dst = -diffs[i] // 2
end_dst = start_dst + volume.shape[i]
slices_src.append(slice(start_src, end_src))
slices_dst.append(slice(start_dst, end_dst))
# Copy the scaled volume into the result
result[tuple(slices_dst)] = volume[tuple(slices_src)]
volume = result
return volume
def apply_shift(self, volume):
"""应用随机平移"""
shifts = [np.random.uniform(self.shift_range[0],
self.shift_range[1]) for _ in range(3)]
return shift(volume, shifts, order=1, mode='nearest')
def apply_noise(self, volume):
"""添加高斯噪声"""
noise = np.random.normal(0, self.noise_factor, volume.shape)
volume = volume + noise
volume = np.clip(volume, 0, 1) # keep values in the valid range
return volume
def apply_brightness(self, volume):
"""调整亮度"""
factor = np.random.uniform(self.brightness_range[0], self.brightness_range[1])
volume = volume * factor
volume = np.clip(volume, 0, 1)
return volume
def apply_contrast(self, volume):
"""调整对比度"""
factor = np.random.uniform(self.contrast_range[0], self.contrast_range[1])
mean = np.mean(volume)
volume = (volume - mean) * factor + mean
volume = np.clip(volume, 0, 1)
return volume
def apply_elastic_deformation(self, volume):
"""应用弹性变形"""
# 为3D体积生成变形场
# sigma控制变形的平滑度,较大的值产生更平滑的变形
# points控制变形网格的粗细,较大的值产生更精细的变形
deformed_volume = elasticdeform.deform_random_grid(
volume,
sigma=3,
points=3,
order=1,
mode='nearest'
)
return deformed_volume
def __call__(self, volume):
"""
Apply the augmentations to an input 3D volume.
Args:
volume: numpy array with shape [D, H, W]
Returns:
The augmented volume.
"""
# Remember the original shape so scaling can be corrected back to it
self.original_shape = volume.shape
# Apply each augmentation with its own probability
if np.random.random() < self.p_rotation:
volume = self.apply_rotation(volume)
if np.random.random() < self.p_scale:
volume = self.apply_scaling(volume)
if np.random.random() < self.p_shift:
volume = self.apply_shift(volume)
if np.random.random() < self.p_noise:
volume = self.apply_noise(volume)
if np.random.random() < self.p_brightness:
volume = self.apply_brightness(volume)
if np.random.random() < self.p_contrast:
volume = self.apply_contrast(volume)
if np.random.random() < self.p_elastic:
volume = self.apply_elastic_deformation(volume)
return volume
# PyTorch Dataset for 3D CT volumes
class LungCTDataset(torch.utils.data.Dataset):
def __init__(self, ct_paths, labels=None, transform=None, phase='train'):
"""
Lung CT dataset.
Args:
ct_paths: list of paths to the CT volumes
labels: list of corresponding labels
transform: data augmentation / transform
phase: 'train', 'val', or 'test'
"""
self.ct_paths = ct_paths
self.labels = labels
self.transform = transform
self.phase = phase
def __len__(self):
return len(self.ct_paths)
def __getitem__(self, idx):
# Load the preprocessed CT volume
# Each path is assumed to be a .npy file containing a preprocessed volume
ct_volume = np.load(self.ct_paths[idx])
# Apply data augmentation (training phase only)
if self.transform and self.phase == 'train':
ct_volume = self.transform(ct_volume)
# Make sure the data is float32 with shape [C, D, H, W]
ct_volume = ct_volume.astype(np.float32)
ct_volume = np.expand_dims(ct_volume, axis=0) # add the channel dimension
# Convert to a PyTorch tensor
ct_tensor = torch.from_numpy(ct_volume)
# Return the data and the label (if available)
if self.labels is not None:
label = self.labels[idx]
return ct_tensor, label
else:
return ct_tensor
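Here is a small sketch of how the augmenter and dataset are wired together with a DataLoader; the .npy paths and labels are placeholders:
import torch

train_dataset = LungCTDataset(
    ct_paths=["data/preprocessed/patient_001.npy", "data/preprocessed/patient_002.npy"],  # placeholder paths
    labels=[0, 1],
    transform=CTAugmentation3D(),
    phase='train'
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
volumes, targets = next(iter(train_loader))
print(volumes.shape)  # expected: torch.Size([2, 1, 128, 128, 128]) for 128^3 volumes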
4. Loss Function Design for Class Imbalance
4.1 Common Solutions to Class Imbalance
Method | Description | Pros | Cons |
---|---|---|---|
Undersampling | Drop majority-class samples | Faster training | Discards information; the model may underfit |
Oversampling | Repeat minority-class samples | Keeps all data | May overfit the minority class |
Synthetic samples | Generate minority samples, e.g. with SMOTE | Balances the data without discarding anything | Generated samples may be unrealistic |
Class weights | Give the minority class a higher weight in the loss | Simple, effective, keeps all data | Weights need tuning |
Focal Loss | Focus on hard-to-classify samples | Automatically reweights samples | Hyperparameters need tuning |
Combined sampling | Mix under- and oversampling | Balances the trade-offs of both | More complex to implement |
In medical imaging, data is precious and expensive to acquire, so we rarely rely on pure undersampling; instead we tend to combine loss-function adjustments with oversampling (a minimal oversampling sketch follows below).
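As an example of the oversampling side, PyTorch's WeightedRandomSampler can resample the minority class at the DataLoader level. This is a minimal sketch assuming a list of integer training labels; the toy label array is made up:
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

train_labels = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])   # toy example: 8 normal, 2 diseased
class_counts = np.bincount(train_labels)
sample_weights = 1.0 / class_counts[train_labels]          # rarer classes get larger sampling weights
sampler = WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(train_labels),
    replacement=True,
)
# loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)  # do not combine with shuffle=True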
4.2 Designing Loss Functions for Imbalanced Data
For lung CT analysis we will implement several loss functions that cope well with class imbalance:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class WeightedCrossEntropyLoss(nn.Module):
"""
Cross-entropy loss with class weights.
Suitable for class-imbalanced problems.
"""
def __init__(self, weight=None, reduction='mean'):
"""
Args:
weight: per-class weights, usually higher for the minority class
reduction: one of 'none', 'mean', 'sum'
"""
super(WeightedCrossEntropyLoss, self).__init__()
self.weight = weight
self.reduction = reduction
def forward(self, input, target):
return F.cross_entropy(
input, target,
weight=self.weight,
reduction=self.reduction
)
class FocalLoss(nn.Module):
"""
Focal Loss.
Down-weights easy examples so that training focuses on hard-to-classify samples.
"""
def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
"""
Args:
alpha: per-class weights
gamma: focusing parameter; larger values down-weight easy samples more strongly
reduction: one of 'none', 'mean', 'sum'
"""
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, input, target):
ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.alpha)
pt = torch.exp(-ce_loss)
focal_loss = (1 - pt) ** self.gamma * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
class DiceLoss(nn.Module):
"""
Dice Loss.
Commonly used for medical image segmentation; adapted here to classification outputs.
"""
def __init__(self, smooth=1.0, reduction='mean'):
"""
Args:
smooth: smoothing constant that prevents division by zero
reduction: one of 'none', 'mean', 'sum'
"""
super(DiceLoss, self).__init__()
self.smooth = smooth
self.reduction = reduction
def forward(self, input, target):
# Convert logits to class probabilities
prob = F.softmax(input, dim=1)
# Convert targets to one-hot encoding, shape [N, C]
target = F.one_hot(target, num_classes=input.size(1)).float()
# Per-class Dice coefficient computed over the batch
numerator = 2 * torch.sum(prob * target, dim=0) + self.smooth
denominator = torch.sum(prob + target, dim=0) + self.smooth
dice_coeff = numerator / denominator
dice_loss = 1 - dice_coeff
if self.reduction == 'mean':
return dice_loss.mean()
elif self.reduction == 'sum':
return dice_loss.sum()
else:
return dice_loss
class CombinedLoss(nn.Module):
"""
Combined loss mixing Focal Loss and Dice Loss.
Aims to benefit from the strengths of both.
"""
def __init__(self, alpha=None, gamma=2.0, weight_focal=0.5, weight_dice=0.5, smooth=1.0):
"""
Args:
alpha: class weights for the Focal Loss
gamma: focusing parameter of the Focal Loss
weight_focal: weight of the Focal Loss term
weight_dice: weight of the Dice Loss term
smooth: smoothing constant of the Dice Loss
"""
super(CombinedLoss, self).__init__()
self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)
self.dice_loss = DiceLoss(smooth=smooth)
self.weight_focal = weight_focal
self.weight_dice = weight_dice
def forward(self, input, target):
return (
self.weight_focal * self.focal_loss(input, target) +
self.weight_dice * self.dice_loss(input, target)
)
class AsymmetricLoss(nn.Module):
"""
Asymmetric loss.
Uses different gamma values for positive and negative targets, giving more flexibility under class imbalance.
"""
def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, reduction='mean'):
"""
Args:
gamma_pos: gamma for positive targets
gamma_neg: gamma for negative targets
clip: clamping threshold for numerical stability
reduction: one of 'none', 'mean', 'sum'
"""
super(AsymmetricLoss, self).__init__()
self.gamma_pos = gamma_pos
self.gamma_neg = gamma_neg
self.clip = clip
self.reduction = reduction
def forward(self, input, target):
# Convert targets to one-hot encoding
target = F.one_hot(target, num_classes=input.size(1)).float()
# Sigmoid probabilities
prob = torch.sigmoid(input)
# Clamp probabilities for numerical stability
prob = torch.clamp(prob, self.clip, 1.0 - self.clip)
# Focal terms for positive and negative targets
pos_loss = target * torch.log(prob) * (1 - prob) ** self.gamma_pos
neg_loss = (1 - target) * torch.log(1 - prob) * prob ** self.gamma_neg
loss = -(pos_loss + neg_loss)
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
def calculate_class_weights(labels, method='inverse', beta=0.999):
"""
Compute class weights.
Args:
labels: list of training-set labels
method: 'inverse' (inverse frequency) or 'effective' (effective number of samples)
beta: balancing factor for the effective-number method
Returns:
A weight for each class.
"""
# Count the samples in each class
class_counts = np.bincount(labels)
n_classes = len(class_counts)
if method == 'inverse':
# Weights inversely proportional to class frequency
weights = 1.0 / np.array(class_counts)
# Normalize the weights
weights = weights / np.sum(weights) * n_classes
elif method == 'effective':
# Weights based on the effective number of samples
effective_num = 1.0 - np.power(beta, class_counts)
weights = (1.0 - beta) / effective_num
# Normalize the weights
weights = weights / np.sum(weights) * n_classes
return torch.FloatTensor(weights)
4.3 Choosing a Loss Function
In lung CT diagnosis, the loss functions above suit different scenarios:
Loss function | Typical scenario | Strength |
---|---|---|
Weighted cross-entropy | Moderate imbalance | Simple, effective, easy to understand and tune |
Focal Loss | Severe imbalance | Adaptively focuses on hard examples, down-weights easy ones |
Dice Loss | Binary problems | Insensitive to the class ratio; measures overlap |
Combined loss | Complex imbalance | Combines the strengths of several losses |
Asymmetric loss | Extreme imbalance | Separate focusing parameters for the positive and negative classes |
A rule of thumb: when the positive-sample ratio is below 10%, consider Focal Loss or a combined loss; between 10% and 30%, weighted cross-entropy is usually enough; and if recall matters most, Dice Loss is a good choice. (A small helper encoding this rule follows.)
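The rule of thumb above can be written down as a small helper. This is only a heuristic sketch built on the losses defined in 4.2:
import numpy as np
import torch.nn as nn

def choose_loss(labels, class_weights=None):
    """Pick a loss function based on the positive-class ratio (heuristic only)."""
    pos_ratio = float(np.mean(np.array(labels) == 1))
    if pos_ratio < 0.10:
        # severe imbalance: focus on hard examples
        return FocalLoss(alpha=class_weights, gamma=2.0)
    elif pos_ratio < 0.30:
        # moderate imbalance: class weights are usually enough
        return WeightedCrossEntropyLoss(weight=class_weights)
    else:
        # roughly balanced data
        return nn.CrossEntropyLoss()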
5. The Complete Training Pipeline
Now let's put the components above together into a complete training pipeline for lung CT analysis:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import time
import random
from tensorboardX import SummaryWriter
# Set random seeds for reproducibility
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class LungCTTrainer:
def __init__(self, model, train_dataset, val_dataset, test_dataset=None,
batch_size=8, lr=0.001, loss_fn=None, device=None,
class_weights=None, experiment_name="lung_ct_analysis"):
"""
Trainer for lung CT analysis.
Args:
model: 3D ResNet model
train_dataset: training dataset
val_dataset: validation dataset
test_dataset: test dataset
batch_size: batch size
lr: learning rate
loss_fn: loss function
device: training device
class_weights: class weights
experiment_name: experiment name
"""
self.model = model
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.batch_size = batch_size
self.lr = lr
# Select the device
self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {self.device}")
# Move the model to the device
self.model = self.model.to(self.device)
# Set up the loss function
if loss_fn is None:
if class_weights is not None:
self.loss_fn = WeightedCrossEntropyLoss(weight=class_weights.to(self.device))
else:
self.loss_fn = nn.CrossEntropyLoss()
else:
self.loss_fn = loss_fn
# Set up the optimizer
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
# Learning-rate scheduler
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
)
# Create the data loaders
self.train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=4, pin_memory=True
)
self.val_loader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False,
num_workers=4, pin_memory=True
)
if test_dataset:
self.test_loader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=False,
num_workers=4, pin_memory=True
)
else:
self.test_loader = None
# Set up TensorBoard
self.writer = SummaryWriter(f"runs/{experiment_name}_{time.strftime('%Y%m%d_%H%M%S')}")
# Training-state tracking
self.best_val_loss = float('inf')
self.best_model_path = f"models/{experiment_name}_best_model.pth"
self.early_stop_patience = 15
self.early_stop_counter = 0
# Create the directory for saved models
os.makedirs("models", exist_ok=True)
def train_epoch(self, epoch):
"""训练一个epoch"""
self.model.train()
running_loss = 0.0
all_preds = []
all_targets = []
# Progress bar with tqdm
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1} [Train]")
for inputs, targets in pbar:
# Move the data to the device
inputs = inputs.to(self.device, non_blocking=True)
targets = targets.to(self.device, non_blocking=True)
# Zero the gradients
self.optimizer.zero_grad()
# Forward pass
outputs = self.model(inputs)
loss = self.loss_fn(outputs, targets)
# Backward pass
loss.backward()
# Gradient clipping to prevent exploding gradients
nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Update the parameters
self.optimizer.step()
# Accumulate statistics
running_loss += loss.item() * inputs.size(0)
# Collect predictions and targets
_, preds = torch.max(outputs, 1)
all_preds.extend(preds.cpu().numpy())
all_targets.extend(targets.cpu().numpy())
# Update the progress bar
pbar.set_postfix({"loss": loss.item()})
# Compute the average loss and the metrics
epoch_loss = running_loss / len(self.train_dataset)
epoch_acc = accuracy_score(all_targets, all_preds)
epoch_prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
# Log to TensorBoard
self.writer.add_scalar('Loss/train', epoch_loss, epoch)
self.writer.add_scalar('Accuracy/train', epoch_acc, epoch)
self.writer.add_scalar('Precision/train', epoch_prec, epoch)
self.writer.add_scalar('Recall/train', epoch_recall, epoch)
self.writer.add_scalar('F1/train', epoch_f1, epoch)
return epoch_loss, epoch_acc, epoch_prec, epoch_recall, epoch_f1
def validate_epoch(self, epoch):
"""验证一个epoch"""
self.model.eval()
running_loss = 0.0
all_preds = []
all_targets = []
all_probs = []
with torch.no_grad():
pbar = tqdm(self.val_loader, desc=f"Epoch {epoch+1} [Val]")
for inputs, targets in pbar:
# Move the data to the device
inputs = inputs.to(self.device, non_blocking=True)
targets = targets.to(self.device, non_blocking=True)
# Forward pass
outputs = self.model(inputs)
loss = self.loss_fn(outputs, targets)
# Accumulate statistics
running_loss += loss.item() * inputs.size(0)
# Collect predictions, targets, and probabilities
probs = torch.softmax(outputs, dim=1)
_, preds = torch.max(outputs, 1)
all_preds.extend(preds.cpu().numpy())
all_targets.extend(targets.cpu().numpy())
all_probs.extend(probs.cpu().numpy())
# Update the progress bar
pbar.set_postfix({"loss": loss.item()})
# Compute the average loss and the metrics
epoch_loss = running_loss / len(self.val_dataset)
epoch_acc = accuracy_score(all_targets, all_preds)
epoch_prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
# For binary classification, also compute the AUC
if len(np.unique(all_targets)) == 2:
epoch_auc = roc_auc_score(all_targets, np.array(all_probs)[:, 1])
self.writer.add_scalar('AUC/val', epoch_auc, epoch)
else:
epoch_auc = None
# Log to TensorBoard
self.writer.add_scalar('Loss/val', epoch_loss, epoch)
self.writer.add_scalar('Accuracy/val', epoch_acc, epoch)
self.writer.add_scalar('Precision/val', epoch_prec, epoch)
self.writer.add_scalar('Recall/val', epoch_recall, epoch)
self.writer.add_scalar('F1/val', epoch_f1, epoch)
# Update the learning rate
self.scheduler.step(epoch_loss)
# Save the best model
if epoch_loss < self.best_val_loss:
self.best_val_loss = epoch_loss
torch.save(self.model.state_dict(), self.best_model_path)
print(f"Best model saved with val loss: {epoch_loss:.4f}")
self.early_stop_counter = 0
else:
self.early_stop_counter += 1
return epoch_loss, epoch_acc, epoch_prec, epoch_recall, epoch_f1, epoch_auc
def train(self, epochs=100):
"""训练模型"""
print(f"Starting training for {epochs} epochs...")
# Training history
history = {
'train_loss': [], 'train_acc': [], 'train_prec': [],
'train_recall': [], 'train_f1': [],
'val_loss': [], 'val_acc': [], 'val_prec': [],
'val_recall': [], 'val_f1': [], 'val_auc': []
}
# Training loop
for epoch in range(epochs):
# Training phase
train_loss, train_acc, train_prec, train_recall, train_f1 = self.train_epoch(epoch)
# Validation phase
val_loss, val_acc, val_prec, val_recall, val_f1, val_auc = self.validate_epoch(epoch)
# Record the history
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['train_prec'].append(train_prec)
history['train_recall'].append(train_recall)
history['train_f1'].append(train_f1)
history['val_loss'].append(val_loss)
history['val_acc'].append(val_acc)
history['val_prec'].append(val_prec)
history['val_recall'].append(val_recall)
history['val_f1'].append(val_f1)
history['val_auc'].append(val_auc)
# Print the current results
print(f"Epoch {epoch+1}/{epochs}")
print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
if val_auc:
print(f"Val AUC: {val_auc:.4f}")
print("-" * 50)
# Early-stopping check
if self.early_stop_counter >= self.early_stop_patience:
print(f"Early stopping at epoch {epoch+1}")
break
# Training finished; close the TensorBoard writer
self.writer.close()
# Plot the training history
self.plot_training_history(history)
return history
def plot_training_history(self, history):
"""绘制训练历史"""
# 创建一个2x2的子图布局
fig, axes = plt.subplots(2, 2, figsize=(18, 12))
# Loss curves
axes[0, 0].plot(history['train_loss'], label='Train Loss')
axes[0, 0].plot(history['val_loss'], label='Val Loss')
axes[0, 0].set_title('Loss')
axes[0, 0].set_xlabel('Epochs')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)
# Accuracy curves
axes[0, 1].plot(history['train_acc'], label='Train Accuracy')
axes[0, 1].plot(history['val_acc'], label='Val Accuracy')
axes[0, 1].set_title('Accuracy')
axes[0, 1].set_xlabel('Epochs')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True)
# F1-score curves
axes[1, 0].plot(history['train_f1'], label='Train F1')
axes[1, 0].plot(history['val_f1'], label='Val F1')
axes[1, 0].set_title('F1 Score')
axes[1, 0].set_xlabel('Epochs')
axes[1, 0].set_ylabel('F1 Score')
axes[1, 0].legend()
axes[1, 0].grid(True)
# AUC curve (if available)
if None not in history['val_auc']:
axes[1, 1].plot(history['val_auc'], label='Val AUC')
axes[1, 1].set_title('AUC')
axes[1, 1].set_xlabel('Epochs')
axes[1, 1].set_ylabel('AUC')
axes[1, 1].legend()
axes[1, 1].grid(True)
else:
# If no AUC is available, plot precision instead
axes[1, 1].plot(history['train_prec'], label='Train Precision')
axes[1, 1].plot(history['val_prec'], label='Val Precision')
axes[1, 1].set_title('Precision')
axes[1, 1].set_xlabel('Epochs')
axes[1, 1].set_ylabel('Precision')
axes[1, 1].legend()
axes[1, 1].grid(True)
plt.tight_layout()
plt.savefig('training_history.png')
plt.show()
def test(self, load_best_model=True):
"""测试模型"""
if self.test_loader is None:
print("No test dataset provided.")
return None
# Load the best model
if load_best_model:
self.model.load_state_dict(torch.load(self.best_model_path))
print(f"Loaded best model from {self.best_model_path}")
self.model.eval()
all_preds = []
all_targets = []
all_probs = []
with torch.no_grad():
for inputs, targets in tqdm(self.test_loader, desc="Testing"):
# Move the data to the device
inputs = inputs.to(self.device, non_blocking=True)
# Forward pass
outputs = self.model(inputs)
probs = torch.softmax(outputs, dim=1)
_, preds = torch.max(outputs, 1)
all_preds.extend(preds.cpu().numpy())
all_targets.extend(targets.numpy())
all_probs.extend(probs.cpu().numpy())
# Compute the evaluation metrics
acc = accuracy_score(all_targets, all_preds)
prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
# For binary classification, also compute the AUC
if len(np.unique(all_targets)) == 2:
auc = roc_auc_score(all_targets, np.array(all_probs)[:, 1])
else:
auc = None
# Print the results
print("\nTest Results:")
print(f"Accuracy: {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
if auc:
print(f"AUC: {auc:.4f}")
return {
'accuracy': acc,
'precision': prec,
'recall': recall,
'f1': f1,
'auc': auc,
'predictions': all_preds,
'targets': all_targets,
'probabilities': all_probs
}
# Usage example
def main():
# Set the random seed
set_seed(42)
# Assume the data has already been preprocessed
# Placeholder paths only; replace them with real data paths
ct_paths = ["path/to/ct1.npy", "path/to/ct2.npy", "..."]
labels = [0, 1, "..."] # 0 = normal, 1 = diseased
# Compute the class weights
class_weights = calculate_class_weights(labels, method='effective')
# Create the augmenter
augmentation = CTAugmentation3D(
rotation_range=(-10, 10),
scale_range=(0.9, 1.1),
shift_range=(-5, 5),
noise_factor=0.03,
brightness_range=(0.9, 1.1),
contrast_range=(0.9, 1.1)
)
# Create the dataset
full_dataset = LungCTDataset(ct_paths, labels, transform=augmentation)
# Split the dataset
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
full_dataset, [train_size, val_size, test_size]
)
# Create the model
model = resnet18_3d(num_classes=2)
# Create the loss function
# For severe class imbalance, Focal Loss works well (move class_weights to the training device first when using a GPU)
loss_fn = FocalLoss(alpha=class_weights, gamma=2.0)
# Create the trainer
trainer = LungCTTrainer(
model=model,
train_dataset=train_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
batch_size=8,
lr=0.001,
loss_fn=loss_fn,
class_weights=class_weights,
experiment_name="lung_ct_3d_resnet"
)
# Train the model
history = trainer.train(epochs=50)
# Evaluate on the test set
test_results = trainer.test()
# Print a summary of the test results
print("\nTest Results Summary:")
print(f"Accuracy: {test_results['accuracy']:.4f}")
print(f"F1 Score: {test_results['f1']:.4f}")
if test_results['auc']:
print(f"AUC: {test_results['auc']:.4f}")
if __name__ == "__main__":
main()
How did you like today's content? Thanks again for reading, and see you in the next installment!