当前位置: 首页 > news >正文

CV 医学影像分类、分割、目标检测,之【皮肤病分类】项目拆解

CV 医学影像分类、分割、目标检测,之【皮肤病分类】项目拆解

    • 第1-12行:导入库
    • 第14-17行:读取标签文件
    • 第19-21行:获取疾病名称
    • 第23-26行:获取图片名列表
    • 第28-35行:筛选有标签的图片
    • 第38-43行:提取标签
    • 第47-51行:创建字典映射
    • 第53-59行:创建类别ID映射
    • 第61-70行:获取筛选后图片的标签
    • 第72-90行:定义数据变换
    • 第92-107行:自定义数据集类
    • 第114-120行:划分训练集测试集
    • 第122-130行:创建数据加载器
    • 第132-146行:可视化数据
    • 第148-151行:加载预训练模型
    • 第153-156行:定义损失和优化器
    • 第158-160行:GPU设置
    • 第163-211行:训练函数
    • 第213-228行:训练循环
    • 第230-233行:绘制损失曲线
    • 第235-260行:加载最佳模型测试
    • 第262-305行:预测新图片
    • 核心流程总结
    • 替换不同模型

 


目标:构建一个基于深度学习的皮肤病分类系统,能够自动识别8种皮肤病类型(黑色素瘤、黑素细胞痣、基底细胞癌、光化性角化病、良性角化病、皮肤纤维瘤、血管病变、鳞状细胞癌)

皮肤病数据集可以在阿里云天池里面搜索获取。

def skin_classification_main():"""皮肤病分类系统主函数 - 领导式全局规划"""# 阶段1:标签数据读取与处理部门label_processor = LabelProcessor()skinDisease, pic_data, df_Key = label_processor.load_and_parse_labels()# 阶段2:图片文件筛选与整理部门  image_organizer = ImageOrganizer()image_organizer.copy_valid_images(pic_data)# 阶段3:类别映射构建部门mapping_builder = MappingBuilder()skinLable_dic, class_to_id, id_to_class = mapping_builder.create_mappings(pic_data, df_Key, skinDisease)# 阶段4:数据集构建部门dataset_builder = DatasetBuilder()train_loader, test_loader = dataset_builder.create_dataloaders(class_to_id, skinLable_dic)# 阶段5:模型构建部门model_builder = ModelBuilder()model, loss_fn, optimizer = model_builder.build_training_components()# 阶段6:训练执行部门trainer = ModelTrainer()train_history = trainer.train_model(model, train_loader, test_loader, epochs=150)# 阶段7:训练结果可视化部门visualizer = TrainingVisualizer()visualizer.plot_training_curves(train_history)# 阶段8:模型评估部门evaluator = ModelEvaluator()final_performance = evaluator.evaluate_best_model(model, test_loader)# 阶段9:预测展示部门predictor = PredictionDemo()predictor.demo_batch_prediction(model, test_loader, id_to_class)# 阶段10:单图预测部门single_predictor = SingleImagePredictor()result = single_predictor.predict_single_image(model, image_path, id_to_class)return model, result

第1-12行:导入库

from PIL import Image

问1: PIL是什么缩写?
答1: Python Imaging Library(Python图像处理库)

问2: 为什么用from…import而不是import?
答2: 只导入需要的Image类,避免命名空间污染

问3: Image类能做什么?
答3: 打开、创建、修改、保存各种格式的图片文件

import torch

问4: torch的核心是什么?
答4: 张量(Tensor)运算和自动微分

问5: 什么是张量?
答5: 多维数组,0维是标量,1维是向量,2维是矩阵,3维以上叫张量

from torch.utils import data

问6: utils是什么?
答6: utilities工具集,data是数据加载工具

问7: 为什么需要专门的数据加载工具?
答7: 批量加载、打乱顺序、多线程预处理

import numpy as np

问8: numpy和torch的区别?
答8: numpy在CPU运算,torch可在GPU运算且支持自动求导

import pandas as pd

问9: pandas擅长什么?
答9: 表格数据处理,像Excel一样操作数据

from torchvision import transforms

问10: transforms是做什么变换?
答10: 图像预处理:裁剪、旋转、归一化等

import torchvision

问11: torchvision和torch的关系?
答11: torchvision是torch的计算机视觉扩展包

import matplotlib.pyplot as plt

问12: pyplot的plt是约定俗成吗?
答12: 是的,社区约定,便于代码交流

import torch.nn.functional as F

问13: functional和nn.Module的区别?
答13: functional是无状态函数,Module是有参数的层

import torch.nn as nn

问14: nn代表什么?
答14: Neural Network,神经网络模块

from tqdm import tqdm

问15: tqdm是什么意思?
答15: 阿拉伯语"进展",用来显示进度条

import os
import glob
import shutil

问16: 这三个都是文件操作,有什么区别?
答16: os基础操作,glob模式匹配,shutil高级操作


第14-17行:读取标签文件

df=pd.read_table('./skin_label.txt',sep='\t',header='infer')

问17: ./是什么路径?
答17: 当前目录,相对路径

问18: header='infer’是什么意思?
答18: 自动推断第一行是否为列名

df_Key=np.array(df.iloc[:,1:])

问19: iloc和loc的区别?
答19: iloc用整数位置索引,loc用标签索引

问20: 为什么转成numpy数组?
答20: numpy运算更快,且后续要用argmax

df_Key.shape

问21: shape返回什么?
答21: 元组(行数, 列数),这里是(6000, 9)


第19-21行:获取疾病名称

skinDisease=df.columns[1:].to_numpy()

问22: columns是什么?
答22: DataFrame的列名,Index对象

问23: to_numpy()和values的区别?
答23: to_numpy()是新方法,values将被弃用

skinDisease

问24: 不加print为什么也能输出?
答24: Jupyter/交互模式下,最后一个表达式自动显示


第23-26行:获取图片名列表

pic_data=np.array(df.iloc[:,0])

问25: 第0列是什么?
答25: 图片文件名列

pic_data=pic_data.tolist()

问26: 为什么要转成list?
答26: 后面要用in判断,list的in操作比array快

len(pic_data)

问27: len对不同对象的含义?
答27: list是元素个数,string是字符数,dict是键值对数


第28-35行:筛选有标签的图片

imgs=glob.glob('./data/skin_data/*.jpg')

问28: 是什么通配符?
答28: 匹配任意字符,
.jpg匹配所有jpg文件

for im in imgs:

问29: im是什么类型?
答29: 字符串,完整文件路径

    im_name=im[17:-4]

问30: 为什么是17?
答30: './data/skin_data/'正好17个字符

问31: -4是什么?
答31: 倒数第4个字符开始,去掉’.jpg’

    print(im_name)

问32: 这个print是调试用的?
答32: 是的,确认提取的文件名正确

    if im_name in pic_data:

问33: in的时间复杂度?
答33: list是O(n),set是O(1)

        print('E:/皮肤病分类/data/clear_skin_data/{}'.format(im_name))

问34: format和f-string的区别?
答34: format是旧语法,f-string(f’{im_name}')更简洁

        shutil.copy(im,'E:/皮肤病分类/data/clear_skin_data/{}.jpg'.format(im_name))

问35: copy和move的区别?
答35: copy保留原文件,move是剪切


第38-43行:提取标签

skin_label=[]

问36: 为什么用列表不用数组?
答36: 要逐个append,列表动态增长更高效

index=np.argmax(df_Key,axis=1)

问37: argmax返回什么?
答37: 最大值的索引位置

问38: axis=1和axis=0的记忆方法?
答38: axis=0沿着行方向(↓),axis=1沿着列方向(→)

for i in index:skin_index=skinDisease[i]skin_label.append(skin_index)

问39: i是什么值?
答39: 0-8的整数,表示疾病类别索引

问40: append和extend的区别?
答40: append加单个元素,extend加多个元素


第47-51行:创建字典映射

skinLable_dic={}
lableSkin_dic={}

问41: 为什么建两个字典?
答41: 双向映射:图片→标签,标签→图片

for i in range(6000):skinLable_dic[pic_data[i]]=skin_label[i]

问42: range(6000)和range(len(pic_data))哪个好?
答42: range(len(pic_data))更好,自适应数据长度


第53-59行:创建类别ID映射

class_id = list(set(skinLable_dic.values()))

问43: set的作用?
答43: 去重,获取唯一的疾病类别

问44: 为什么又转回list?
答44: set无序,list可以索引访问

id_to_class={}
class_to_id={}
for i,e in enumerate(class_id):class_to_id[e]=iid_to_class[i]=e

问45: enumerate返回什么?
答45: (索引, 元素)的元组

问46: 为什么需要数字ID?
答46: 神经网络输出是数字,不是字符串


第61-70行:获取筛选后图片的标签

clear_img_path=glob.glob('./data/clear_skin_data/*.jpg')

问47: 这是第二次glob,为什么?
答47: 获取筛选后的图片路径列表

clear_img_lable=[]
for img in clear_img_path:img_name=img[23:-4]

问48: 23是怎么算的?
答48: './data/clear_skin_data/'是23个字符

    classes=skinLable_dic[img_name]ids=class_to_id[classes]clear_img_lable.append(ids)

问49: 这里做了几次映射?
答49: 两次:文件名→疾病名→数字ID


第72-90行:定义数据变换

train_transformer=transforms.Compose([  transforms.RandomHorizontalFlip(0.2),

问50: Compose是什么设计模式?
答50: 组合模式,串联多个变换

问51: 0.2的概率是每张图片独立的吗?
答51: 是的,每次调用独立决定

   transforms.RandomRotation(68),

问52: 为什么是68度不是90度?
答52: 可能是经验值,避免过度旋转丢失信息

    transforms.RandomGrayscale(0.2),

问53: 灰度化的目的?
答53: 增强模型对颜色变化的鲁棒性

   transforms.Resize((128,128)),

问54: 为什么是128不是224?
答54: 平衡精度和速度,128够用且更快

   transforms.ToTensor(),

问55: Tensor和array的内存布局区别?
答55: Tensor是CHW(通道-高-宽),array通常是HWC

   transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) 

问56: 这个归一化后的范围?
答56: (pixel-0.5)/0.5,从[0,1]变为[-1,1]

问57: 为什么要归一化到[-1,1]?
答57: 零中心化,有助于梯度下降收敛


第92-107行:自定义数据集类

class Skindataset(data.Dataset):

问58: 为什么必须继承Dataset?
答58: DataLoader需要调用固定接口

    def __init__(self, img_paths, labels, transform):self.imgs = clear_img_pathself.labels = clear_img_lable

问59: 这里有bug吗?
答59: 有!应该用参数img_paths和labels,不是全局变量

    def __getitem__(self, index):

问60: 这个方法什么时候被调用?
答60: DataLoader迭代时自动调用

        img = self.imgs[index]label = self.labels[index]pil_img = Image.open(img)    data = self.transforms(pil_img)

问61: 每次都打开文件会不会慢?
答61: 会,但省内存,是时间换空间

        return data, label

问62: 返回顺序重要吗?
答62: 重要,约定是(输入, 标签)

    def __len__(self):return len(self.imgs)

问63: 为什么需要__len__?
答63: DataLoader需要知道数据集大小来计算批次数


第114-120行:划分训练集测试集

s = int(len(clear_img_path)*0.8)

问65: 为什么是0.8?
答65: 经验值,80%训练20%测试

问66: int()是向下取整吗?
答66: 是的,截断小数部分

train_imgs = clear_img_path[:s]
test_imgs = clear_img_path[s:]

问67: 这样分割有什么问题?
答67: 没打乱,可能有顺序偏差


第122-130行:创建数据加载器

train = Skindataset(train_imgs, train_labels, train_transformer)

问68: 这里会调用__init__吗?
答68: 会,创建实例时自动调用

dl_train = data.DataLoader(train,batch_size=32,shuffle=True)

问69: batch_size=32的含义?
答69: 每次送入网络32张图片

问70: 为什么要batch不要单张?
答70: 并行计算快,梯度估计更稳定

问71: shuffle=True的作用?
答71: 打乱顺序,防止模型记住顺序


第132-146行:可视化数据

img, label = next(iter(dl_train))

问72: iter()做了什么?
答72: 创建迭代器对象

问73: next()返回什么?
答73: 一个批次的(图片张量, 标签张量)

plt.rcParams['font.sans-serif'] = ['SimHei']

问74: SimHei是什么?
答74: 黑体字体,支持中文显示

skin_chinese={'MEL':'黑色素瘤','NV':'黑素细胞痣',...}

问75: 字典的键为什么用英文缩写?
答75: 与数据集标签保持一致

for i,(img,label) in enumerate(zip(img[:8],label[:8])):

问76: zip的作用?
答76: 将两个序列配对成元组

    img=(img.permute(1,2,0).numpy()+1)/2

问77: permute(1,2,0)在做什么?
答77: CHW转HWC,适配matplotlib

问78: +1再/2是为什么?
答78: [-1,1]恢复到[0,1]用于显示

    plt.subplot(2,4,i+1)

问79: i+1是因为?
答79: subplot索引从1开始,不是0


第148-151行:加载预训练模型

model=torchvision.models.resnet50()

问80: resnet50的50指什么?
答80: 网络深度,50层

问81: 预训练是在什么数据上?
答81: ImageNet,1000类日常物体

model.fc.out_features=8

问82: 这样直接赋值有用吗?
答82: 没用,应该替换整个fc层

问83: 正确写法是什么?
答83: model.fc = nn.Linear(model.fc.in_features, 8)


第153-156行:定义损失和优化器

loss_fn=nn.CrossEntropyLoss()

问84: 交叉熵适合什么任务?
答84: 多分类任务

问85: 交叉熵的数学本质?
答85: 衡量两个概率分布的差异

from torch.optim import lr_scheduler

问86: lr_scheduler没用到?
答86: 是的,导入了但没使用

optim=torch.optim.Adam(model.parameters(),lr=0.001)

问87: Adam是什么的缩写?
答87: Adaptive Moment Estimation

问88: lr=0.001是经验值吗?
答88: 是的,Adam的常用默认值


第158-160行:GPU设置

if torch.cuda.is_available():model.to('cuda')

问89: cuda是什么?
答89: NVIDIA的并行计算平台

问90: to(‘cuda’)做了什么?
答90: 把模型参数移到GPU内存


第163-211行:训练函数

def fit(epoch, model, trainloader, testloader):correct = 0total = 0running_loss = 0

问91: 这三个变量追踪什么?
答91: 正确数、总数、累积损失

    model.train()

问92: train()模式改变什么?
答92: 启用Dropout和BatchNorm的训练行为

    for x, y in tqdm(trainloader):

问93: tqdm包装的效果?
答93: 显示进度条

        if torch.cuda.is_available():x, y = x.to('cuda'), y.to('cuda')

问94: 每个batch都要to(‘cuda’)吗?
答94: 是的,数据在CPU,要移到GPU

        y_pred = model(x)

问95: model(x)等价于?
答95: model.forward(x)

        loss = loss_fn(y_pred, y)

问96: y_pred和y的形状?
答96: y_pred是(32,8),y是(32,)

        optim.zero_grad()

问97: 不清零会怎样?
答97: 梯度累加,相当于更大的batch

        loss.backward()

问98: backward()计算什么?
答98: loss对所有参数的偏导数

        optim.step()

问99: step()的更新公式?
答99: 参数 = 参数 - 学习率 × 梯度

        with torch.no_grad():

问100: no_grad()的作用?
答100: 禁用梯度计算,节省内存

            y_pred = torch.argmax(y_pred, dim=1)

问101: argmax后的形状?
答101: 从(32,8)变为(32,)

            correct += (y_pred == y).sum().item()

问102: .item()的作用?
答102: 将单元素张量转为Python数值

    epoch_loss = running_loss / len(trainloader.dataset)

问103: 为什么除以dataset长度不是batch数?
答103: 获得每个样本的平均损失

    model.eval()

问104: eval()改变什么?
答104: 关闭Dropout,BatchNorm用运行均值

    torch.save(static_dict,'./data/resnet_Chepoint/{}_train_acc_{}_test_acc_{}.pth'.format(epoch,round(epoch_acc, 3),round(epoch_test_acc,3)))

问105: .pth是什么格式?
答105: PyTorch的模型文件格式

问106: state_dict包含什么?
答106: 所有层的权重和偏置参数


第213-228行:训练循环

epochs = 150

问107: 150轮够吗?
答107: 看验证集性能,可能过拟合

for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(...)

问108: 每轮都保存模型?
答108: 是的,可以选最好的


第230-233行:绘制损失曲线

plt.plot(range(1, epochs+1), train_loss, label='train_loss')

问109: range从1开始?
答109: 让横轴从1开始,更直观


第235-260行:加载最佳模型测试

model.load_state_dict(torch.load('./data/resnet_Chepoint/143_train_acc_0.904_test_acc_0.981.pth'))

问110: 为什么选143轮的?
答110: 测试准确率最高(98.1%)


第262-305行:预测新图片

t_img='C:/Users/MSI-NB/AppData/Local/Temp/vasssssss.jpeg'

问111: 这是什么路径?
答111: Windows临时文件夹的图片

img_tensor=test_transformer(img)
img_tensor=img_tensor.unsqueeze(0)

问112: unsqueeze(0)做什么?
答112: 增加batch维度,(3,128,128)→(1,3,128,128)

pre=torch.argmax(out,axis=1).cpu().numpy()[0]

问113: .cpu()为什么需要?
答113: GPU张量不能直接转numpy

id_to_class[pre]

问114: 最终输出什么?
答114: 疾病类别名称,如’MEL’(黑色素瘤)


核心流程总结

这个项目的本质是一个迁移学习流程:

  1. 数据准备:图片+标签 → 数字化
  2. 数据增强:翻转旋转 → 泛化能力
  3. 特征提取:ResNet50 → 图像特征
  4. 微调分类:1000类 → 8类疾病
  5. 迭代优化:梯度下降 → 最小损失
  6. 模型应用:新图片 → 疾病诊断

替换不同模型

如果只想快速替换模型,最少只需改2处:

  • 不同模型架构不同,最后一层名称不同
  • 输入尺寸要求可能不同

除了最后一层,还有什么要改? 输入尺寸要求不同!

但手动修改,还是容易出现 BUG。

代码中哪些地方依赖于ResNet50?

答1: 主要是两处:

model = torchvision.models.resnet50()    # 第148行
model.fc.out_features = 8                # 第150行,这行还有bug
# ResNet系列
model = torchvision.models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 8)  # fc层# VGG系列
model = torchvision.models.vgg16()
model.classifier[6] = nn.Linear(4096, 8)  # classifier是个Sequential# DenseNet系列
model = torchvision.models.densenet121()
model.classifier = nn.Linear(model.classifier.in_features, 8)  # classifier层# EfficientNet系列
model = torchvision.models.efficientnet_b0()
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 8)  # classifier[1]# MobileNet系列
model = torchvision.models.mobilenet_v2()
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 8)  # classifier[1]

更优雅的解决方案:使用timm库

import timmdef get_model_timm(model_name='resnet50', num_classes=8, pretrained=True):"""问5:timm是什么?答5:PyTorch Image Models,包含700+预训练模型问6:为什么用timm更好?答6:统一接口,自动处理最后一层"""model = timm.create_model(model_name,pretrained=pretrained,num_classes=num_classes  # 自动替换最后一层!)return model# 使用示例 - 可以用任何模型!
model = get_model_timm('resnet50', num_classes=8)
model = get_model_timm('efficientnet_b7', num_classes=8)
model = get_model_timm('vit_base_patch16_224', num_classes=8)  # Vision Transformer!
http://www.dtcms.com/a/327596.html

相关文章:

  • OHEM (在线难例挖掘) 详细讲解
  • 【Vue.js】生产设备规划工具(报价单Word文档生成)【开发全流程】
  • 无人机航拍数据集|第14期 无人机水体污染目标检测YOLO数据集3000张yolov11/yolov8/yolov5可训练
  • etcd 备份与恢复
  • Etcd客户端工具Etcd Workbench更新了1.2.0版本!多语言支持了中文,新增了许多快捷功能使用体验再次提升
  • Spark 运行流程核心组件(一)作业提交
  • 干货分享|如何从0到1掌握R语言数据分析
  • 小红书笔记信息获取_实在智能RPA源码解读
  • 邦纳BANNER相机视觉加镜头PresencePLUSP4 RICOH FL-CC2514-2M工业相机
  • C++实现LINGO模型处理程序
  • Java结课案例-景点人数统计的几种场景
  • 日期格式化成英文月,必須指定語言環境
  • Secure CRT做代理转发
  • HTTP应用层协议-长连接
  • 记对外国某服务器的内网渗透
  • C++少儿编程(二十二)—条件结构
  • 机械臂运动规划与控制12讲
  • SQL 语言分类
  • 后端学习路线
  • 3D文档控件Aspose.3D实用教程:在 C# 中将 3MF 文件转换为 STL
  • 开疆智能Ethernet转ModbusTCP网关连接发那科机器人与三菱PLC配置案例
  • Spring Boot部署万亿参数模型推理方案(深度解析)
  • css之再谈浮动定位float(深入理解篇)
  • 物联网、大数据与云计算持续发展,楼宇自控系统应用日益广泛
  • 黑马程序员mysql课程p65 安装linux版本的mysql遇到问题
  • [密码学实战]基于国密TLCP协议的Java服务端实现详解(四十四)
  • 【基于DesignStart的M3 SoC】
  • 4/5G中频段频谱全球使用现状概述(截止2025 年7月)
  • 【unity实战】在 Unity 中实现卡牌翻转或者翻书的效果
  • 现代化水库运行管理矩阵建设的要点