CV 医学影像分类、分割、目标检测,之【3D肝脏分割】项目拆解
CV 医学影像分类、分割、目标检测,之【3D肝脏分割】项目拆解
- 第1行:`from posixpath import join`
- 第2行:`from torch.utils.data import DataLoader`
- 第3行:`import os`
- 第4行:`import sys`
- 第5行:`import random`
- 第6行:`from torchvision.transforms import RandomCrop`
- 第7行:`import numpy as np`
- 第8行:`import SimpleITK as sitk`
- 第11行:`import torchvision`
- 第14行:`import glob`
- 第15行:`import pandas as pd`
- 第16行:`import matplotlib.pyplot as plt`
- 第17行:`from PIL import Image`
- 第18-23行:重复导入
- 第24行:`import PIL`
- 第26行:`images=os.listdir('D:/LiverDataset/image')`
- 第27行:`labels=os.listdir('D:/LiverDataset/label')`
- 第29-30行:`image_list=[]` `label_list=[]`
- 第32-35行:第一个for循环
- 第37-40行:第二个for循环
- 第42-46行:`getarrayFromslice`函数
- 第48-49行:调用函数
- 第52-99行:D3UnetData类
- 第52-53行:类定义
- 第54-56行:初始化
- 第58-59行:`__getitem__`方法
- 第62-63行:读取图像
- 第65-66行:切片选择
- 第68行:二值化
- 第70行:类型转换
- 第72-73行:转为张量
- 第75-76行:应用变换
- 第78行:返回
- 第80-81行:`__len__`方法
- 第83-99行:D3UnetData_test类
- 第103-105行:定义变换
- 第107-109行:创建数据集
- 第112-113行:创建DataLoader
- 第117-118行:测试加载
- 第121-122行:可视化
- 第125-126行:显示标签
- 第129-145行:DoubleConv类
- 第148-151行:测试DoubleConv
- 第154-161行:Down类(下采样)
- 第164-179行:Up类(上采样)
- 第182-187行:Out类
- 第190-216行:UNet3d类
- 第219-223行:创建模型
- 第228-229行:定义损失和优化器
- 第232-289行:train函数
- 第292-301行:训练循环
- 第305-318行:预测和可视化
第1行:from posixpath import join
问1:为什么要导入posixpath?
答1:处理文件路径的拼接。
问2:posixpath和os.path有什么区别?
答2:posixpath专门处理UNIX风格路径(用/),os.path会根据操作系统自动选择。
问3:这里为什么不用os.path.join?
答3:实际上这是个错误!后面代码没用到posixpath.join,应该删除或改成os.path。
问4:什么是POSIX?
答4:Portable Operating System Interface,可移植操作系统接口标准。
第2行:from torch.utils.data import DataLoader
问5:DataLoader的作用是什么?
答5:批量加载数据,打乱顺序,多进程读取。
问6:为什么需要批量加载?
答6:GPU并行计算需要批量数据,单个样本浪费算力。
问7:utils是什么意思?
答7:utilities的缩写,工具集。
第3行:import os
问8:os模块提供什么功能?
答8:操作系统接口,文件操作、路径处理、环境变量。
问9:为什么第1行导入了posixpath.join,这里又导入os?
答9:os提供更多功能如listdir,posixpath只处理路径。
第4行:import sys
问10:sys模块是做什么的?
答10:Python解释器相关的变量和函数。
问11:这段代码为什么导入sys但没用?
答11:可能是遗留代码,或准备用于sys.path添加路径。
第5行:import random
问12:random在深度学习中用来做什么?
答12:数据增强、打乱顺序、随机初始化。
问13:这里导入了但没用到,可能用途是什么?
答13:可能原计划做随机裁剪或随机选择训练样本。
第6行:from torchvision.transforms import RandomCrop
问14:RandomCrop是什么?
答14:随机裁剪图像的数据增强方法。
问15:为什么要随机裁剪?
答15:增加数据多样性,防止过拟合。
问16:这里导入了但没用,为什么?
答16:可能原计划用但后来改用Resize了。
第7行:import numpy as np
问17:为什么起别名np?
答17:约定俗成,简化书写。
问18:numpy和torch的关系是什么?
答18:numpy是CPU数组运算,torch支持GPU,可相互转换。
第8行:import SimpleITK as sitk
问19:ITK是什么的缩写?
答19:Insight Segmentation and Registration Toolkit。
问20:为什么叫Simple?
答20:简化版的ITK,C++库的Python封装。
问21:sitk相比其他图像库的优势?
答21:保留医学图像元数据:间距、方向、原点。
第11行:import torchvision
问23:torchvision提供什么?
答23:计算机视觉的数据集、模型、变换工具。
问24:vision是指什么?
答24:计算机视觉,让计算机"看懂"图像。
第14行:import glob
问27:glob是做什么的?
答27:文件路径模式匹配,如*.jpg找所有jpg文件。
问28:glob这个名字的由来?
答28:global的缩写,全局通配符扩展。
问29:这里导入glob但没用到,可能的原因?
答29:可能原来用glob查找文件,后来改用os.listdir。
第15行:import pandas as pd
问30:pandas在这里可能的用途?
答30:读取CSV格式的标注信息或记录训练日志。
问31:为什么没用到?
答31:可能改用直接读取文件夹的方式。
第16行:import matplotlib.pyplot as plt
问32:pyplot是什么?
答32:matplotlib的面向对象绘图接口。
问33:为什么叫pyplot?
答33:模仿MATLAB的plot函数设计的Python版本。
第17行:from PIL import Image
问34:PIL是什么的缩写?
答34:Python Imaging Library。
问35:为什么已有PIL还要导入SimpleITK?
答35:PIL处理普通图像,SimpleITK处理医学图像的3D体积和元数据。
第18-23行:重复导入
问36:这么多重复导入说明什么?
答36:代码没有经过清理,可能是多次实验的累积。
第24行:import PIL
问37:已经from PIL import Image,为什么还import PIL?
答37:可能想用PIL的其他功能,但实际没用到。
第26行:images=os.listdir('D:/LiverDataset/image')
问38:listdir返回什么?
答38:文件夹内所有文件和子文件夹的名称列表。
问39:为什么用D盘?
答39:Windows系统,D盘通常是数据盘,C盘是系统盘。
问40:LiverDataset说明什么?
答40:Liver是肝脏,这是肝脏分割数据集。
第27行:labels=os.listdir('D:/LiverDataset/label')
问41:label和image的对应关系是什么?
答41:同名文件,一个是原图,一个是标注。
第29-30行:image_list=[]
label_list=[]
问42:为什么要创建空列表?
答42:准备存储完整的文件路径。
问43:列表和数组的区别?
答43:列表可变长度、存储任意类型;数组固定类型、支持向量运算。
第32-35行:第一个for循环
for i in images:file_path='D:/LiverDataset/image/{}/{}'.format(i,i)print(file_path)image_list.append(file_path)
问44:format是什么?
答44:字符串格式化方法,{}是占位符。
问45:为什么路径是image/{}/{}
两个i?
答45:数据组织是:image文件夹/病人ID文件夹/病人ID文件。
问46:append是什么操作?
答46:在列表末尾添加元素。
问47:print的作用?
答47:调试用,确认路径正确。
第37-40行:第二个for循环
for i in labels:file_path='D:/LiverDataset/label/{}/{}'.format(i,i)print(file_path)label_list.append(file_path)
问48:这段代码有什么问题?
答48:没有检查image和label是否一一对应。
第42-46行:getarrayFromslice
函数
def getarrayFromslice(file_path):image=sitk.ReadImage(file_path)img_array=sitk.GetArrayFromImage(image)shape=img_array.shapeprint(shape)
问49:为什么函数名有From和slice?
答49:从切片文件获取数组,但命名不准确,应该是from volume。
问50:ReadImage读取的是什么格式?
答50:SimpleITK的Image对象,包含像素数据和元数据。
问51:GetArrayFromImage做了什么转换?
答51:从SimpleITK Image转为numpy array,丢弃元数据。
问52:shape会是什么样的?
答52:(切片数, 高度, 宽度),如(512, 512, 512)。
第48-49行:调用函数
for i in image_list:getarrayFromslice(i)
问53:这个循环的目的是什么?
答53:检查所有图像的尺寸,确保数据一致性。
问54:为什么只打印不保存?
答54:这是数据探索阶段,了解数据结构。
第52-99行:D3UnetData类
第52-53行:类定义
class D3UnetData(Dataset):def __init__(self,image_list,label_list,transformer):
问55:D3是什么意思?
答55:3D的意思,表示三维U-Net。
问56:为什么继承Dataset?
答56:PyTorch要求,实现__getitem__
和__len__
接口。
问57:transformer是什么?
答57:图像变换操作,如缩放、裁剪。
第54-56行:初始化
self.image_list=image_list
self.label_list=label_list
self.transformer=transformer
问58:self是什么?
答58:实例自身的引用,Python的面向对象机制。
第58-59行:__getitem__
方法
def __getitem__(self,index):image=self.image_list[index]label=self.label_list[index]
问59:index从哪里来?
答59:DataLoader自动生成,从0到len-1。
问60:为什么叫__getitem__
不叫get_item?
答60:Python魔术方法,支持[]索引操作。
第62-63行:读取图像
image_ct=sitk.ReadImage(image,sitk.sitkInt16)
label_ct=sitk.ReadImage(label,sitk.sitkInt8)
问61:sitkInt16是什么?
答61:16位有符号整数,范围-32768到32767。
问62:为什么image用Int16,label用Int8?
答62:CT值范围大需要16位,标签只有0/1用8位省内存。
问63:CT值的范围是多少?
答63:-1000(空气)到+3000(骨骼),单位是Hounsfield。
第65-66行:切片选择
ct_array=sitk.GetArrayFromImage(image_ct)[250:300]
label_array=sitk.GetArrayFromImage(label_ct)[250:300]
问64:为什么选250:300?
答64:肝脏在腹部中间位置,这50层大概覆盖肝脏区域。
问65:如果总共只有200层怎么办?
答65:会报错,代码没有边界检查。
第68行:二值化
label_array[label_array>0]=1
问66:这是什么操作?
答66:布尔索引,将所有大于0的值设为1。
问67:为什么要这样做?
答67:确保标签只有0和1两类,可能原始数据有多个器官标注。
第70行:类型转换
ct_array=ct_array.astype(np.float32)
问68:为什么转float32?
答68:神经网络计算需要浮点数,32位在精度和内存间平衡。
问69:float32和float64的区别?
答69:32位精度7位小数,64位精度15位,深度学习32位够用。
第72-73行:转为张量
ct_array=torch.FloatTensor(ct_array).unsqueeze(0)
label_array=torch.LongTensor(label_array)
问70:FloatTensor和LongTensor的区别?
答70:FloatTensor存浮点数,LongTensor存整数索引。
问71:unsqueeze(0)在哪个维度添加?
答71:第0维,[50,512,512]变[1,50,512,512]。
问72:为什么label不需要unsqueeze?
答72:CrossEntropyLoss期望标签是类别索引,不需要通道维。
第75-76行:应用变换
ct_array=self.transformer(ct_array)
label_array=self.transformer(label_array)
问73:transformer会做什么?
答73:这里是Resize到96×96,减少内存使用。
第78行:返回
return ct_array,label_array
问74:返回的是元组吗?
答74:是的,Python自动打包成元组。
第80-81行:__len__
方法
def __len__(self):return len(self.image_list)
问75:为什么需要__len__
?
答75:DataLoader需要知道数据集大小来创建批次。
我继续分析剩余的每一行代码…
第83-99行:D3UnetData_test类
问76:为什么要单独的test类?
答76:测试集可能需要不同的预处理或采样策略。
ct_array=sitk.GetArrayFromImage(image_ct)[200:250]
问77:为什么test用200:250,train用250:300?
答77:避免数据泄露,训练和测试用不同的切片。
问78:这样分割合理吗?
答78:不合理!应该按病人分,不是按切片分。
第103-105行:定义变换
transformer=transforms.Compose([transforms.Resize((96,96)),
])
问79:Compose是什么?
答79:组合多个变换,按顺序执行。
问80:为什么缩放到96×96?
答80:原始512×512太大,3D卷积内存消耗巨大。
问81:96这个数字有什么特殊?
答81:是32的倍数,适合多次下采样(96→48→24→12→6)。
第107-109行:创建数据集
train_ds=D3UnetData(image_list,label_list,transformer)
test_ds=D3UnetData_test(image_list,label_list,transformer)
len(train_ds)
问82:ds是什么缩写?
答82:dataset的缩写。
问83:为什么train和test用同样的image_list?
答83:这是错误!训练集和测试集应该是不同的病人。
第112-113行:创建DataLoader
train_dl=DataLoader(train_ds,batch_size=2,shuffle=True)
test_dl=DataLoader(test_ds,batch_size=2,shuffle=True)
问84:dl是什么缩写?
答84:dataloader的缩写。
问85:batch_size=2为什么这么小?
答85:3D医学图像占用内存大,[2,1,50,96,96]已经很大了。
问86:测试集为什么要shuffle?
答86:不应该!测试集应该shuffle=False保持顺序。
第117-118行:测试加载
img,label=next(iter(train_dl))
print(img.shape,label.shape)
问87:iter是什么?
答87:创建迭代器对象,可以用next获取下一个。
问88:next做什么?
答88:从迭代器获取一个批次。
第121-122行:可视化
img_show=img[0,0,25,:,:].numpy()
plt.imshow(img_show,cmap='gray')
问89:[0,0,25,:,:]各维度是什么?
答89:[批次0,通道0,第25层,所有高,所有宽]。
问90:为什么选第25层?
答90:50层的中间,最可能看到肝脏。
问91:cmap='gray’是什么?
答91:colormap灰度颜色映射,CT图像是灰度的。
第125-126行:显示标签
label_show=label[0,25,:,:].numpy()
plt.imshow(label_show,cmap='gray')
问92:标签为什么没有通道维度?
答92:标签是类别索引[batch,depth,height,width]。
第129-145行:DoubleConv类
class DoubleConv(nn.Module):def __init__(self,in_channels,out_channels,num_groups=8):
问93:nn.Module是什么?
答93:PyTorch所有神经网络层的基类。
问94:in_channels和out_channels是什么?
答94:输入和输出的特征图数量(通道数)。
self.double_conv=nn.Sequential(nn.Conv3d(in_channels,out_channels,kernel_size=3,stride=1,padding=1),
问95:Sequential是什么?
答95:顺序容器,依次执行内部的层。
问96:Conv3d的kernel_size=3是什么意思?
答96:3×3×3的立方体卷积核。
问97:stride=1表示什么?
答97:卷积核每次移动1个像素。
问98:padding=1的作用?
答98:边缘填充1圈0,保持尺寸不变。
nn.GroupNorm(num_groups=num_groups,num_channels=out_channels),
问99:为什么不用BatchNorm3d?
答99:批次太小(2),统计不稳定,GroupNorm不依赖批次大小。
问100:归一化的目的是什么?
答100:稳定训练,防止梯度消失或爆炸。
nn.ReLU(inplace=True),
问101:ReLU是什么?
答101:Rectified Linear Unit,max(0,x)激活函数。
问102:inplace=True是什么意思?
答102:原地操作,覆盖输入省内存。
第148-151行:测试DoubleConv
img.shape
net=DoubleConv(1,64,num_groups=8)
out=net(img)
print(out.shape)
问103:1→64通道变化意味着什么?
答103:从单通道CT图提取64种不同的特征。
第154-161行:Down类(下采样)
class Down(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()
问104:super().init()做什么?
答104:调用父类Module的初始化,注册参数。
self.encoder = nn.Sequential(nn.MaxPool3d(2, 2),DoubleConv(in_channels, out_channels)
)
问105:MaxPool3d(2,2)是什么操作?
答105:2×2×2窗口取最大值,尺寸减半。
问106:为什么先池化再卷积?
答106:U-Net的设计,减少分辨率同时增加通道数。
第164-179行:Up类(上采样)
class Up(nn.Module):def __init__(self, in_channels, out_channels, trilinear=True):
问107:trilinear是什么?
答107:三线性插值,3D图像的平滑上采样方法。
if trilinear:self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
else:self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
问108:Upsample和ConvTranspose3d的区别?
答108:Upsample插值无参数,ConvTranspose3d有可学习参数。
问109://是什么运算?
答109:整数除法,向下取整。
问110:align_corners=True是什么?
答110:对齐角点像素,影响插值计算方式。
def forward(self, x1, x2):x1 = self.up(x1)
问111:x1和x2分别是什么?
答111:x1是深层特征要上采样,x2是浅层特征要拼接。
diffZ = x2.size()[2] - x1.size()[2]
diffY = x2.size()[3] - x1.size()[3]
diffX = x2.size()[4] - x1.size()[4]
问112:为什么要计算差值?
答112:上采样可能有尺寸误差,需要padding对齐。
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2, diffZ // 2, diffZ - diffZ // 2])
问113:这个padding顺序是什么?
答113:[左右,上下,前后]每个维度的padding量。
问114:为什么用diffX - diffX // 2
?
答114:处理奇数差值,如diff=3时,一边pad 1,另一边pad 2。
x = torch.cat([x2, x1], dim=1)
问115:dim=1是哪个维度?
答115:通道维度,拼接特征图。
第182-187行:Out类
class Out(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 1)
问116:kernel_size=1的作用?
答116:1×1×1卷积,只改变通道数不改变空间尺寸。
问117:这里out_channels应该是多少?
答117:2,对应背景和肝脏两类。
第190-216行:UNet3d类
class UNet3d(nn.Module):def __init__(self, in_channels, n_classes, n_channels):
问118:三个参数分别是什么?
答118:输入通道(1)、类别数(2)、基础通道数(24)。
self.conv = DoubleConv(in_channels, n_channels)
self.enc1 = Down(n_channels, 2 * n_channels)
self.enc2 = Down(2 * n_channels, 4 * n_channels)
self.enc3 = Down(4 * n_channels, 8 * n_channels)
self.enc4 = Down(8 * n_channels, 8 * n_channels)
问119:通道数为什么是24,48,96,192,192?
答119:每层翻倍增加特征,最后一层保持防止爆内存。
问120:enc是什么缩写?
答120:encoder编码器,提取特征。
self.dec1 = Up(16 * n_channels, 4 * n_channels)
问121:为什么是16×n_channels?
答121:8(enc4)+8(enc3跳跃连接)=16。
def forward(self, x):x1 = self.conv(x)x2 = self.enc1(x1)x3 = self.enc2(x2)x4 = self.enc3(x3)x5 = self.enc4(x4)
问122:x1到x5的尺寸变化?
答122:
- x1: [2,24,50,96,96]
- x2: [2,48,25,48,48]
- x3: [2,96,12,24,24]
- x4: [2,192,6,12,12]
- x5: [2,192,3,6,6]
mask = self.dec1(x5, x4)
mask = self.dec2(mask, x3)
mask = self.dec3(mask, x2)
mask = self.dec4(mask, x1)
mask = self.out(mask)
问123:为什么叫mask?
答123:分割结果是掩码,标记每个像素的类别。
第219-223行:创建模型
model=UNet3d(1,2,24).cuda()
img,label=next(iter(train_dl))
print(img.shape,label.shape)
img=img.cuda()
pred=model(img)
问124:.cuda()做什么?
答124:把模型和数据移到GPU上。
问125:24这个基础通道数怎么选的?
答125:平衡性能和内存,太小欠拟合,太大爆显存。
第228-229行:定义损失和优化器
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.00001)
问126:CrossEntropyLoss包含什么操作?
答126:LogSoftmax + NLLLoss,多分类标准损失。
问127:Adam是什么?
答127:Adaptive Moment Estimation,自适应学习率优化器。
问128:lr=0.00001为什么这么小?
答128:医学图像精度要求高,小学习率稳定训练。
第232-289行:train函数
from tqdm import tqdm
def train(epoch, model, trainloader, testloader):
问129:tqdm是什么?
答129:进度条库,显示训练进度。
correct = 0
total = 0
running_loss = 0
epoch_iou = []
问130:这些变量分别统计什么?
答130:正确像素数、总像素数、累计损失、每批次IOU。
model.train()
问131:model.train()做什么?
答131:启用dropout和batch normalization的训练模式。
for x, y in tqdm(trainloader):x, y = x.to('cuda'), y.to('cuda')
问132:为什么用.to(‘cuda’)而不是.cuda()?
答132:.to()更通用,可以指定设备字符串。
y_pred = model(x)
loss = loss_fn(y_pred, y)
问133:y_pred的形状是什么?
答133:[2,2,50,96,96],批次×类别×深×高×宽。
optimizer.zero_grad()
loss.backward()
optimizer.step()
问134:这三步分别做什么?
答134:清除梯度、反向传播计算梯度、更新参数。
问135:为什么要zero_grad?
答135:PyTorch梯度会累积,不清零会叠加。
with torch.no_grad():y_pred = torch.argmax(y_pred, dim=1)
问136:no_grad()的作用?
答136:禁用梯度计算,节省内存。
问137:argmax在dim=1做什么?
答137:在类别维度取最大值索引,得到预测类别。
correct += (y_pred == y).sum().item()
total += y.size(0)
问138:.item()是什么?
答138:将单元素张量转为Python数值。
intersection = torch.logical_and(y, y_pred)
union = torch.logical_or(y, y_pred)
batch_iou = torch.sum(intersection) / torch.sum(union)
问139:logical_and和logical_or是什么运算?
答139:逻辑与(交集)和逻辑或(并集)。
问140:IOU的值域是什么?
答140:0到1,1表示完全重合。
epoch_loss = running_loss / len(trainloader.dataset)
epoch_acc = correct / (total*96*96*50)
问141:为什么除以total×96×96×50?
答141:总像素数=批次数×每批像素数。
model.eval()
问142:eval()模式改变什么?
答142:关闭dropout,batch norm用运行统计。
if np.mean(epoch_test_iou)>0.9:static_dict=model.state_dict()torch.save(static_dict,'./checkpoint/{}_trainIOU_{}_testIOU_{}.pth'.format(epoch,round(np.mean(epoch_iou), 3),round(np.mean(epoch_test_iou),3)))
问143:state_dict包含什么?
答143:模型所有参数的字典。
问144:.pth是什么格式?
答144:PyTorch的模型文件扩展名。
问145:round(x, 3)做什么?
答145:保留3位小数。
第292-301行:训练循环
epochs = 100
train_loss = []
train_acc = []
test_loss = []
test_acc = []
问146:为什么创建这些列表但没用?
答146:可能原计划画损失曲线,但没实现。
for epoch in range(epochs):train(epoch, model, train_dl, test_dl)
问147:range(epochs)生成什么?
答147:0到99的整数序列。
第305-318行:预测和可视化
img,label=next(iter(train_dl))
print(img.shape,label.shape)
img=img.to('cuda')
pred=model(img)
问148:为什么又预测一次?
答148:训练后查看效果。
label_show=label[0,20,:,:]
plt.imshow(label_show,cmap='gray')
问149:为什么选第20层?
答149:随意选择的中间层查看。
preds=pred.cpu()
问150:为什么要.cpu()?
答150:matplotlib不能直接显示GPU张量。
plt.imshow(torch.argmax(preds.permute(1,2,0), axis=-1).detach().numpy(),cmap='gray')
问151:这行代码有什么问题?
答151:维度不对!preds是5维的,不能直接permute(1,2,0)。
plt.imshow(torch.argmax(pred.permute(1,2,0), axis=-1).detach().numpy())
问152:这行和上一行的区别?
答152:用了pred而不是preds,但pred在GPU上会报错。
问153:detach()的作用?
答153:断开计算图,返回不需要梯度的张量。