当前位置: 首页 > news >正文

Python 机器学习核心入门与实战进阶 Day 2 - KNN(K-近邻算法)分类实战与调参

✅ 今日目标

  • 理解 KNN 的原理与“以邻为近”的思想
  • 掌握 K 值选择与模型效果的关系
  • 学会使用 sklearn 训练 KNN 模型
  • 实现 KNN 分类 + 模型评估 + 超参数调优

📘 一、KNN 算法原理

KNN(K-Nearest Neighbors)核心思想:

给定一个待预测样本,找到训练集中“距离它最近”的 K 个样本,用这些样本的类别进行多数投票预测。

特点描述
模型类型懒惰学习(无显式训练过程)
距离度量欧几里得距离(默认)或自定义
参数调优K 值、距离函数、权重方式
适用场景数据量不大,维度不高,需快速建模时

🧪 二、KNN 分类流程(代码实践)

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score# 生成数据
X = [[1], [2], [3], [10], [11], [12]]
y = [0, 0, 0, 1, 1, 1]# 训练测试划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)# 建模
model = KNeighborsClassifier(n_neighbors=3)
model.fit(X_train, y_train)# 预测
y_pred = model.predict(X_test)
print("准确率:", accuracy_score(y_test, y_pred))

🧠 三、K 值选择对模型的影响

K 值模型表现
K 太小模型过拟合,受噪声影响大
K 太大模型过于平滑,泛化能力下降
一般建议使用奇数,避免投票平局;通过交叉验证选择最佳 K

🔧 四、模型调参建议(使用 GridSearchCV)

from sklearn.model_selection import GridSearchCVparam_grid = {'n_neighbors': list(range(1, 11))}
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
grid_search.fit(X_train, y_train)print("最优K值:", grid_search.best_params_)
print("最佳准确率:", grid_search.best_score_)

🧾 今日总结

技能工具
快速建模KNeighborsClassifier
评估效果accuracy_score()
参数调优GridSearchCV()
可视化分类边界使用 matplotlibseaborn

🧪 建议练习脚本

  • 使用 sklearn 中的 KNN 模型实现学生是否及格分类

  • 尝试多种 K 值进行训练,并绘制准确率变化图

  • 使用 GridSearchCV 找出最优 K

  • 可视化分类边界(二维特征时)

    # KNN 分类实战演示:学生是否及格预测from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split, GridSearchCV
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.metrics import accuracy_score, classification_report
    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as npplt.rcParams['font.family'] = 'Arial Unicode MS'  # Mac 用户可用
    plt.rcParams['axes.unicode_minus'] = False
    # 1. 模拟学生成绩数据(两个特征:成绩 + 性别)
    np.random.seed(42)
    size = 100
    scores = np.random.randint(40, 100, size)
    genders = np.random.choice([0, 1], size=size)  # 0=女, 1=男
    pass_label = (scores >= 60).astype(int)X = np.column_stack(((scores - scores.mean()) / scores.std(), genders))  # 标准化+性别
    y = pass_label# 2. 拆分训练集与测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 3. 不同 K 值准确率比较
    acc_list = []
    k_values = range(1, 16)for k in k_values:model = KNeighborsClassifier(n_neighbors=k)model.fit(X_train, y_train)y_pred = model.predict(X_test)acc = accuracy_score(y_test, y_pred)acc_list.append(acc)# 4. 可视化不同 K 值的准确率
    plt.plot(k_values, acc_list, marker='o', linestyle='--')
    plt.title("不同 K 值下的准确率")
    plt.xlabel("K 值")
    plt.ylabel("准确率")
    plt.xticks(k_values)
    plt.grid(True)
    plt.tight_layout()
    plt.show()# 5. 使用 GridSearchCV 找最佳 K
    param_grid = {'n_neighbors': list(range(1, 16))}
    grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
    grid_search.fit(X_train, y_train)print("✅ 最佳 K 值:", grid_search.best_params_)
    print("📋 最佳交叉验证准确率:", grid_search.best_score_)# 6. 在测试集上评估
    best_model = grid_search.best_estimator_
    y_pred = best_model.predict(X_test)print("\\n=== 最终模型评估(测试集) ===")
    print("准确率:", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    

    运行输出:

    在这里插入图片描述

    ✅ 最佳 K 值: {'n_neighbors': 1}
    📋 最佳交叉验证准确率: 0.9875
    \n=== 最终模型评估(测试集) ===
    准确率: 0.95precision    recall  f1-score   support0       0.88      1.00      0.93         71       1.00      0.92      0.96        13accuracy                           0.95        20macro avg       0.94      0.96      0.95        20
    weighted avg       0.96      0.95      0.95        20
    
http://www.dtcms.com/a/266559.html

相关文章:

  • 【MATLAB代码】AOA与TDOA混合定位例程,适用于三维环境、4个锚点的情况,订阅专栏后可以获得完整代码
  • 计算机网络笔记(不全)
  • Windows 本地安装部署 Apache Druid
  • 无人机载重模块技术要点分析
  • Science Robotics发表 | 20m/s自主飞行+避开2.5mm电线的微型无人机!
  • CSS长度单位问题
  • 通过Claude 生成图片的prompt集锦(一)
  • 7.4项目一问题准备
  • 实验五-Flask的简易登录系统
  • 数据结构 之 【堆】(堆的概念及结构、大根堆的实现、向上调整法、向下调整法)(C语言实现)
  • K8s服务发布基础
  • CI/CD持续集成与持续部署
  • 基于大模型的强直性脊柱炎全周期预测与诊疗方案研究
  • 力扣面试150(15/150)
  • 7.4 arm作业
  • 玩转n8n工作流教程(一):Windows系统本地部署n8n自动化工作流(n8n中文汉化)
  • 全平台兼容+3倍加载提速:GISBox将重新定义三维可视化标准
  • Java 实现excel大批量导出
  • 什么是金字塔思维?
  • 三体融合实战:Django+讯飞星火+Colossal-AI的企业级AI系统架构
  • RK-Android11-系统增加一个属性值
  • 【HDMI CEC】 设备 OSD 名称功能详解
  • 《设计模式之禅》笔记摘录 - 3.工厂方法模式
  • 【modbus学习笔记】Modbus协议解析
  • WPF学习(四)
  • 分布式集合通信--学习笔记
  • ComfyUI工作流:一键换背景体验不同场景
  • 如何搭建 OLAP 系统?OLAP与数据仓库有什么关系?
  • 2-2 PID-代码部分
  • Fiddler 中文版怎么配合 Postman 与 Wireshark 做多环境接口调试?