当前位置: 首页 > news >正文

深度学习ResNet模型提取影响特征

        大家好,我是带我去滑雪!

        影像组学作为近年来医学影像分析领域的重要研究方向,致力于通过从医学图像中高通量提取大量定量特征,以辅助疾病诊断、分型、预后评估及治疗反应预测。这些影像特征涵盖了形状、纹理、灰度统计及波形变换等多个维度,能够在无需侵入性操作的前提下,为临床提供潜在的生物标志物。相比于传统医生的主观评估,影像组学特征更加客观、量化,能捕捉到肉眼难以察觉的图像差异,从而提升疾病识别的敏感性与特异性。然而,随着特征数量和维度的不断增加,如何从海量数据中有效提取具有判别力的高阶特征,成为制约影像组学发展的关键问题。传统的手工特征提取方法依赖于专家经验设计,往往具有局限性,难以适应复杂多变的临床场景。此时,深度学习技术的引入为影像组学的发展带来了新的突破。特别是卷积神经网络(CNN)在图像处理方面的卓越表现,为自动化特征提取提供了有力工具。

         在众多CNN架构中,ResNet(残差网络)因其独特的残差连接机制,有效缓解了网络加深带来的梯度消失问题,能够训练更深层次的模型,从而捕捉更复杂、抽象的影像特征。相比浅层网络或传统方法,ResNet能自动从原始图像中学习出更具区分性的表征,提升分类和预测性能。在影像组学应用中,ResNet不仅可以代替手工特征提取过程,还能与传统特征融合,实现更高层次的特征整合,增强模型的泛化能力。因此,将ResNet应用于影像组学特征提取,不仅符合当前智能医疗发展的趋势,也为精准医学提供了强有力的技术支撑。

       这里使用公开的肺炎X-ray数据集,数据集包含5856张经过验证的胸部X光片图像,且被分为训练集和独立患者的测试集。数据集中所有文件已经被转换为JPEG格式,并且已经被划分为正常、肺炎两个类别。

        ResNet模型提取影响特征代码:

import os
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
from tqdm import tqdm
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt

import torch
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision import transforms


# 用于遍历图像
def images_iterator(image_dir):
    dataset_images = glob(f"{image_dir}/**/*.jpeg", recursive=True)
    for image_path in dataset_images:
        file_name = os.path.basename(image_path).split(".")[0]
        image = Image.open(image_path)
        # Resnet在预训练使用的参数
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        # 参照Resnet预训练时的图像处理方式处理
        transform = transforms.Compose([
            # 转换为 Tensor,自动将像素值归一化到 [0, 1]
            transforms.ToTensor(),
            # 调整大小
            transforms.Resize((224, 224)),
            # resnet输入要求通道数量为3
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
            # 对像素取值归一化
            transforms.Normalize(
                mean=imagenet_mean,
                std=imagenet_std)
        ])
        input_tensor = transform(image)
        input_tensor = input_tensor.unsqueeze(0)

        # dimensions of input_tensor are [1, 3, 224, 224]
        # 返回文件名和预处理之后的图像
        yield file_name, input_tensor


def pca_plot(csv_path="./test.csv"):
    # 加载提出的特征
    df_data = pd.read_csv(csv_path)
    df_x = df_data.iloc[:, 2:].to_numpy()
    df_y = df_data.iloc[:, 1].to_numpy()
    # 通过PCA算法将Resnet提取的2048个特征降为3个
    pca = PCA(n_components=3)
    x = pca.fit(df_x).transform(df_x)

    # 绘制图像的一些参数设置
    category_names = ["NORMAL", "PNEUMONIA"]
    ax = plt.figure().add_subplot(projection='3d')
    colors = ["navy", "turquoise"]
    lw = 2
    # 绘制图像
    for color, target_name in zip(colors, category_names):
        ax.scatter(
            x[df_y == target_name, 0],
            x[df_y == target_name, 1],
            x[df_y == target_name, 2],
            color=color, alpha=0.8,
            lw=lw,
            label=target_name
        )
    plt.title("PCA of Chest X-ray")
    plt.show()


def main():
    # 加载预训练的resnet50模型
    resnet50_weight = ResNet50_Weights.DEFAULT
    print(resnet50_weight.transforms())
    resnet50_mdl = resnet50(weights=ResNet50_Weights.DEFAULT).eval()
    # 输出每一层的名称
    nodes_name, _ = get_graph_node_names(resnet50_mdl)
    # 根据输出结果,我们可以知道resnet50在通过全连接层分类之前的网络层名称是flatten
    print(nodes_name)
    return_nodes = {
        "flatten": "final_feature_map"
    }
    # 基于选择的输出和resnet50构建特征提取器
    feature_extracter = create_feature_extractor(resnet50_mdl, return_nodes=return_nodes)
    # 设定影像路径
    chest_xray = os.path.join(os.getcwd(), "chest_xray")
    train_dataset_dir = os.path.join(chest_xray, "train")
    test_dataset_dir = os.path.join(chest_xray, "test")
    # 设定列名
    column_names = ["patient_id", "category"] + [f"resnet_feature_{i + 1}" for i in range(2048)]

    with torch.no_grad():
        df_train = pd.DataFrame()
        df_test = pd.DataFrame()
        # 进行训练集印象特征提取
        for image_name, image_tensor in tqdm(images_iterator(train_dataset_dir)):
            # dimensions of out is [1, 2048]
            out = feature_extracter(image_tensor)
            out_features = out["final_feature_map"]
            out_features = out_features.cpu().numpy()[0]
            if "NORMAL" in image_name:
                category = "NORMAL"
            else:
                category = "PNEUMONIA"
            row_data = [image_name, category] + list(out_features)
            df_train = pd.concat([df_train, pd.Series(row_data)], ignore_index=True, axis=1)
        # 保存到csv文件
        df_train = df_train.T
        df_train.columns = column_names
        df_train.to_csv("train.csv", index=False)
        # 进行测试集影像特征提取
        for image_name, image_tensor in tqdm(images_iterator(test_dataset_dir)):
            # dimensions of out is [1, 2048]
            out = feature_extracter(image_tensor)
            out_features = out["final_feature_map"]
            out_features = out_features.cpu().numpy()[0]
            if "NORMAL" in image_name:
                category = "NORMAL"
            else:
                category = "PNEUMONIA"
            row_data = [image_name, category] + list(out_features)
            df_test = pd.concat([df_test, pd.Series(row_data)], ignore_index=True, axis=1)
        # 保存到csv文件
        df_test = df_test.T
        df_test.columns = column_names
        df_test.to_csv("test.csv", index=False)


if __name__ == "__main__":
    main()
    pca_plot()

       输出结果:

         使用ResNet模型提取的影像特征和传统方法提取的影像特征是一样的,最主要的区别只是深度学习算法提取的特征没有名称。获取ResNet提取的特征之后,就可以使用其它统计分析方法进行下一步分析。


更多优质内容持续发布中,请移步主页查看。

   点赞+关注,下次不迷路!

相关文章:

  • 小米运维面试题及参考答案(80道面试题)
  • CST1016.基于Spring Boot+Vue高校竞赛管理系统
  • DOM解析XML:Java程序员的“乐高积木式“数据搭建
  • 国内AI大模型卷到什么程度了?
  • Linux虚拟内存详解
  • LLaMA 常见面试题
  • 探索加密期权波动率交易的系统化实践——动态对冲工具使用
  • 配置SecureCRT8.5的粘贴复制等快捷键
  • 代码生成工具explain的高级用法
  • 【随身wifi】青龙面板保姆级教程
  • ROS2---std_msgs基础消息包
  • AI+高德MCP:制作一份旅游攻略
  • PyTorch进阶学习笔记[长期更新]
  • Magnet 库教程与命名规范指南
  • 【2025最新】windows本地部署LightRAG,完成neo4j知识图谱保存
  • AnythingLLM:windows部署体验
  • 信息系统项目管理师-软考高级(软考高项)​​​​​​​​​​​2025最新(二)
  • idea配置spring MVC项目启动(maven配置完后)
  • 组合数学——二项式系数
  • linux以C方式和内核交互监听键盘[香橙派搞机日记]
  • 国家消防救援局应急通信和科技司负责人张昊接受审查调查
  • 山西持续高温:阳城地表温度72.9℃破纪录,明日局部地区仍将超40℃
  • 央媒:安徽凤阳鼓楼坍塌楼宇部分非文物,系违规复建的“假古董”
  • 上海国际电影节将于6月3日公布排片表,6月5日中午开票
  • 事关政府信息公开,最高法发布最新司法解释
  • 王毅同德国外长瓦德富尔通电话