当前位置: 首页 > news >正文

【机器学习】超参数调优指南:交叉验证,网格搜索,混淆矩阵——基于鸢尾花与数字识别案例的深度解析

一、前言:为何要学交叉验证与网格搜索?

大家好!在机器学习的道路上,我们经常面临一个难题:模型调参。比如在 KNN 算法中,选择多少个邻居(n_neighbors)直接影响预测效果。

蛮力猜测:就像在厨房随便“加盐加辣椒”,不仅费时费力,还可能把菜搞砸。

交叉验证 + 网格搜索:更像是让你请来一位“大厨”,提前试好所有配方,帮你挑选出最完美的“调料搭配”。

交叉验证与网格搜索的组合,能让你在众多超参数组合中自动挑选出最佳方案,从而让模型预测达到“哇塞,这也太准了吧!”的境界。


二、概念扫盲:交叉验证 & 网格搜索

1. 交叉验证(Cross-Validation)

核心思路:

分组品尝:将整个数据集平均分成若干份(比如分成 5 份,即“5折交叉验证”)。

轮流担任评委:每次选取其中一份作为“验证集”(就像让这部分数据来“评委打分”),剩下的作为“训练集”来训练模型。

集体评定:重复多次,每一份都轮流担任验证集,然后把所有“评分”取平均,作为模型在数据集上的最终表现。

好处:

• 每个样本都有机会既当“选手”又当“评委”,使得评估结果更稳定、可靠。

• 避免单一划分带来的偶然性,确保你调出来的参数在不同数据切分下都表现良好。


2. 网格搜索(Grid Search)

核心思路:

列出所有可能:将你想尝试的超参数组合“罗列成一个表格(网格)”。

自动试菜:每种组合都进行一次完整的模型训练和评估,记录下它们的表现。

选出最佳配方:最后找出在交叉验证中表现最好的超参数组合。

好处:

• 自动化、系统化地寻找最佳参数组合,避免你手动“胡乱猜测”。

• 和交叉验证结合后,每个参数组合都经过了多次评估,结果更稳健。

3. 网格搜索 + 交叉验证

这两者结合就像“炼丹”高手的秘诀:

交叉验证解决了“数据切分”的问题,让评估更准确;

网格搜索解决了“超参数组合”问题,帮你遍历所有可能性。

合体后,你就能轻松找到最优超参数,让模型发挥出最佳性能!


三、案例一:鸢尾花数据集 + KNN + 交叉验证网格搜索

3.1 数据集介绍

数据来源:scikit-learn 内置的 load_iris

特征:萼片长度、萼片宽度、花瓣长度、花瓣宽度

目标:根据花的外部特征预测其所属的鸢尾花种类


3.2 代码示例

下面代码展示如何在鸢尾花数据集上使用 KNN 算法,并通过 GridSearchCV(交叉验证+网格搜索)自动调优 n_neighbors 参数:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

def iris_knn_cv():
    """
    使用KNN算法在鸢尾花数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。
    """
    # 1. 加载数据
    iris = load_iris()
    X = iris.data      # 特征矩阵,包含四个特征
    y = iris.target    # 标签,分别代表三种鸢尾花

    # 2. 划分训练集和测试集
    # test_size=0.2 表示 20% 的数据用于测试,保证测试结果具有代表性
    # random_state=22 固定随机数种子,确保每次运行划分一致
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=22
    )

    # 3. 数据标准化
    # 标准化可使各特征均值为0、方差为1,消除量纲影响(对于基于距离的KNN非常重要)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # 4. 构建KNN模型及参数调优
    knn = KNeighborsClassifier()  # 初始化KNN模型

    # 4.1 设置网格搜索参数范围:尝试不同的邻居数
    param_grid = {
        'n_neighbors': [1, 3, 5, 7, 9]
    }

    # 4.2 进行网格搜索 + 交叉验证(5折交叉验证)
    grid_search = GridSearchCV(
        estimator=knn,        # 待调参的模型
        param_grid=param_grid,  # 超参数候选列表
        cv=5,                 # 5折交叉验证:将训练集分为5个子集,每次用1个子集验证,其余4个训练
        scoring='accuracy',   # 以准确率作为评估指标
        n_jobs=-1             # 使用所有CPU核心并行计算
    )
    grid_search.fit(X_train_scaled, y_train)  # 自动遍历各参数组合并评估

    # 4.3 输出网格搜索结果
    print("最佳交叉验证分数:", grid_search.best_score_)
    print("最优超参数组合:", grid_search.best_params_)
    print("最优模型:", grid_search.best_estimator_)

    # 5. 模型评估:用测试集评估最优模型的泛化能力
    best_model = grid_search.best_estimator_
    y_pred = best_model.predict(X_test_scaled)
    acc = accuracy_score(y_test, y_pred)
    print("在测试集上的准确率:{:.2f}%".format(acc * 100))

    # 6. 可视化(选做):可进一步绘制混淆矩阵或学习曲线

# 直接调用函数进行测试
if __name__ == "__main__":
    iris_knn_cv()

输出: 

3.3 结果解读

最佳交叉验证分数:表示在5折交叉验证过程中,所有参数组合中平均准确率最高的值。

最优超参数组合:显示在候选参数 [1, 3, 5, 7, 9] 中哪个 n_neighbors 的效果最好。

测试集准确率:验证模型在未见数据上的表现,反映其泛化能力。

通过这个案例,你可以看到交叉验证网格搜索如何自动帮你“挑菜”选料,让 KNN 模型在鸢尾花分类任务上达到最佳表现。


四、案例二:手写数字数据集 + KNN + 交叉验证网格搜索

4.1 数据集介绍

数据来源:scikit-learn 内置的 load_digits

特征:每张 8×8 像素的手写数字图像被拉伸成64维特征向量

目标:识别图片中数字所属类别(0~9)

4.2 代码示例

下面代码展示如何在手写数字数据集上使用 KNN 算法,并通过交叉验证网格搜索调优参数:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits        # 导入手写数字数据集(内置于 scikit-learn)
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler  # 导入数据标准化工具
from sklearn.neighbors import KNeighborsClassifier  # 导入KNN分类器
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns  # 导入 seaborn,用于绘制更美观的图表

def digits_knn_cv():
    """
    使用KNN算法在手写数字数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。
    """
    # 1. 加载数据
    digits = load_digits()  # 从scikit-learn加载内置手写数字数据集
    X = digits.data         # 特征数据,形状为 (1797, 64),每一行对应一张图片的64个像素值
    y = digits.target       # 目标标签,共10个类别(数字 0 到 9)
    
    # 2. 数据可视化:展示前5张图片及其标签
    # 创建一个1行5列的子图区域,图像尺寸为10x2英寸
    fig, axes = plt.subplots(1, 5, figsize=(10, 2))
    for i in range(5):
        # 显示第 i 张图片,使用灰度图(cmap='gray')
        axes[i].imshow(digits.images[i], cmap='gray')
        # 设置每个子图的标题,显示该图片对应的标签
        axes[i].set_title("Label: {}".format(digits.target[i]))
        # 关闭坐标轴显示(避免坐标信息干扰视觉效果)
        axes[i].axis('off')
    plt.suptitle("手写数字数据集示例")  # 为整个图表添加一个总标题
    plt.show()  # 显示图表
    
    # 3. 数据划分 + 标准化
    # 将数据划分为训练集和测试集,其中测试集占20%,random_state保证每次划分一致
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    # 初始化标准化工具,将特征数据转换为均值为0、方差为1的标准正态分布
    scaler = StandardScaler()
    # 仅在训练集上拟合标准化参数,并转换训练集数据
    X_train_scaled = scaler.fit_transform(X_train)
    # 使用相同的转换参数转换测试集数据(避免数据泄露)
    X_test_scaled = scaler.transform(X_test)
    
    # 4. 构建KNN模型及网格搜索调参
    knn = KNeighborsClassifier()  # 初始化KNN分类器,暂未指定 n_neighbors 参数
    # 定义一个字典,列出希望尝试的超参数组合
    # 这里我们测试不同邻居数的效果:[1, 3, 5, 7, 9]
    param_grid = {
        'n_neighbors': [1, 3, 5, 7, 9]
    }
    # 初始化网格搜索对象,结合交叉验证
    grid_search = GridSearchCV(
        estimator=knn,         # 需要调参的KNN模型
        param_grid=param_grid, # 超参数候选组合
        cv=5,                  # 5折交叉验证,将训练数据分成5份,每次用4份训练,1份验证
        scoring='accuracy',    # 使用准确率作为模型评估指标
        n_jobs=-1              # 并行计算,使用所有可用的CPU核心加速计算
    )
    # 在标准化后的训练集上进行网格搜索,自动尝试所有参数组合,并进行交叉验证
    grid_search.fit(X_train_scaled, y_train)
    
    # 5. 输出网格搜索调参结果
    # 打印在交叉验证中获得的最佳平均准确率
    print("手写数字 - 最佳交叉验证分数:", grid_search.best_score_)
    # 打印获得最佳结果时所使用的超参数组合,例如 {'n_neighbors': 3}
    print("手写数字 - 最优超参数组合:", grid_search.best_params_)
    # 打印最佳模型对象,该模型已使用最优参数重新训练
    best_model = grid_search.best_estimator_
    
    # 6. 模型评估:用测试集评估模型效果
    # 使用最优模型对测试集进行预测
    y_pred = best_model.predict(X_test_scaled)
    # 计算测试集上的准确率
    acc = accuracy_score(y_test, y_pred)
    print("手写数字 - 测试集准确率:{:.2f}%".format(acc * 100))
    
    # 7. 可视化混淆矩阵(直观展示各数字分类效果)
    # 混淆矩阵能够显示真实标签与预测标签之间的对应关系
    cm = confusion_matrix(y_test, y_pred)
    plt.figure(figsize=(6, 5))
    # 使用 seaborn 的 heatmap 绘制混淆矩阵,annot=True 表示在每个单元格中显示数字
    sns.heatmap(cm, annot=True, cmap='Blues', fmt='d')
    plt.title("手写数字 - 混淆矩阵")
    plt.xlabel("预测值")
    plt.ylabel("真实值")
    plt.show()

# 直接调用函数进行测试
if __name__ == "__main__":
    digits_knn_cv()

输出:

4.3 结果解读

最优 n_neighbors:通过交叉验证,我们找到了在候选参数中使模型表现最佳的邻居数量。

测试集准确率:在手写数字识别任务上,通常准确率能达到90%以上,证明 KNN 在小数据集上也能表现不错。

混淆矩阵:直观展示哪些数字容易混淆(例如数字“3”和“5”),便于进一步分析和改进。

混淆矩阵图的含义与作用

1. 横纵坐标的含义

行(纵轴)代表真实标签(真实的数字 0~9)。

列(横轴)代表模型预测的标签(预测的数字 0~9)。

2. 数值和颜色深浅

• 单元格 (i, j) 内的数值表示:真实类别为 i 的样本中,有多少被预测为 j。

• 越靠近对角线(i = j)代表预测正确的数量;

• 离对角线越远,说明模型将真实类别 i 的样本错误地预测成类别 j。

• 热力图中颜色越深表示数量越多,浅色则表示数量少。

3. 作用

评估模型分类效果:如果对角线上的数值高且远离对角线的数值低,说明模型分类准确度高;反之,说明某些类别容易被混淆。

发现易混淆的类别:通过观察非对角线位置是否有较大的数值,可以知道哪些数字最容易被误判。例如,模型可能经常把“3”预测成“5”,这能提示我们在后续改进中加强这两个类别的区分。

比单纯的准确率更全面:准确率只能告诉你模型整体正确率,而混淆矩阵能告诉你哪类错误最多,便于更有针对性地提升模型性能。


五、总结 & 彩蛋

1. 交叉验证的价值

• 有效避免过拟合,通过多次分组验证,使得模型评估更稳健。

2. 网格搜索的强大

• 自动遍历所有超参数组合,省去手动调参的烦恼,快速锁定“最佳拍档”。

3. KNN 的局限

• 虽然简单易用,但在大规模、高维数据中计算量较大,且对异常值较敏感。

4. 后续进阶

• 可以尝试随机搜索(RandomizedSearchCV)或贝叶斯优化,甚至转向更复杂的模型如 CNN 进行数字识别。


结语

如果你觉得本篇文章对你有所帮助,请记得点赞、收藏、转发和评论哦!你的支持是我继续创作的最大动力。让我们一起在机器学习的道路上不断探索、不断进步,早日成为调参界的“神仙”!

祝学习愉快,炼丹顺利~

相关文章:

  • 【Mysql】索引
  • 深入解析「卡顿帧堆栈」 | UWA GPM 2.0 技术细节与常见问题
  • 独立开发者倾向于使用哪些技术栈
  • 知识篇 | DeepSeek企业部署模式主要有6种
  • 【AI面板识别】
  • 3.10 实战Hugging Face Transformers:从文本分类到模型部署全流程
  • 基于TI的TDA4高速信号仿真条件的理解 4.1 4.2
  • AI 安全时代:SDL与大模型结合的“王炸组合”——技术落地与实战指南
  • CTF-内核pwn入门1: linux内核模块基础原理
  • 25工商管理研究生复试面试问题汇总 工商管理专业知识问题很全! 工商管理复试全流程攻略 工商管理考研复试真题汇总
  • 在Windows系统中安装Open WebUI并连接Ollama
  • OpenCalib(四)水平棋盘格标靶检测
  • 【AI论文】InfiniteHiP:在单块GPU上将语言模型上下文扩展至300万个令牌
  • Windows 启动 SSH 服务
  • 华为路由器:链路聚合实验
  • PriorityBlockingQueue实现原理
  • Linux 文件与目录管理
  • 翻转硬币(思维题,巧用bitset)
  • 30道Qt面试题(答案公布)
  • 【Python项目】信息安全领域中语义搜索引擎系统
  • 被取消总统候选人资格,金文洙:将采取政治法律措施讨回公道
  • 巴基斯坦对印度发起网络攻击,致其约70%电网瘫痪
  • 拿出压箱底作品,北京交响乐团让上海观众享受音乐盛宴
  • 烈士沈绍藩遗孤、革命家帅孟奇养女舒炜逝世,享年96岁
  • 人民财评:网售“婴儿高跟鞋”?不能让畸形审美侵蚀孩子身心
  • 协会:坚决支持司法机关依法打击涉象棋行业的违法行为