当前位置: 首页 > news >正文

【机器学习】深入浅出KNN算法:原理解析与实践案例分享

在机器学习中,K-最近邻算法(K-Nearest Neighbors, KNN)是一种既直观又实用的算法。它既可以用于分类,也可以用于回归任务。本文将简单介绍KNN算法的基本原理、优缺点以及常见应用场景,并通过一个简单案例帮助大家快速入门。

1. KNN算法简介

KNN算法基于一个非常直观的思想:对于一个未知类别的数据点,可以通过查看它在特征空间中距离最近的K个邻居的类别或数值信息,来决定该数据点的类别或预测其值。算法的主要步骤如下:

1. 计算距离:常用的距离度量方法有欧氏距离、曼哈顿距离等。对于一个待预测的数据点,计算它与训练集中所有数据点的距离。

2. 选择最近邻:根据计算得到的距离,选取距离最小的K个数据点。

3. 决策机制

分类:采用投票机制,将待预测点归为K个邻居中出现频率最高的类别。

回归:计算K个邻居的数值平均值或加权平均值,作为预测结果。

由于KNN算法没有显式的训练过程,所以它属于一种懒惰学习(Lazy Learning)方法,即在训练阶段只存储数据,在预测时才进行计算。


2. KNN的优缺点

优点

简单易懂:KNN算法实现简单,容易理解,非常适合初学者入门机器学习。

无需训练过程:KNN不需要构建复杂的模型,直接利用存储的训练数据进行预测。

适应性强:既可以用于分类问题,也可以用于回归问题,具有较强的通用性。

缺点

计算成本高:当数据量较大时,每次预测都需要计算与所有训练样本之间的距离,计算量较大。

对噪声敏感:噪声数据或异常点可能会影响预测结果,尤其是当K值较小时。

数据不平衡问题:在类别分布不平衡的情况下,少数类可能会被多数类所掩盖,影响模型效果。


3. 应用场景

KNN算法在许多领域都有应用,包括但不限于:

手写数字识别:利用KNN对手写数字图片进行分类,实现简单而高效的数字识别。

推荐系统:基于用户相似性推荐商品或电影,利用KNN寻找兴趣相似的用户。

医学诊断:通过分析病人数据,预测疾病类别或风险值。

回归预测:例如房价预测,通过相似特征房屋的历史价格进行估值。


4. 实战案例:KNN分类

下面通过一个简单的案例,使用Python和scikit-learn库对Iris数据集进行KNN分类,帮助大家直观了解KNN的实际应用。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 加载Iris数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建KNN分类器,设置K值为3
knn_classifier = KNeighborsClassifier(n_neighbors=3)
knn_classifier.fit(X_train, y_train)

# 对测试集进行预测
y_pred = knn_classifier.predict(X_test)

# 计算并输出准确率
accuracy = accuracy_score(y_test, y_pred)
print("KNN分类器在Iris数据集上的准确率:{:.2f}%".format(accuracy * 100))

运行上述代码,你将会看到KNN分类器在Iris数据集上的表现。通过调整K值或选择不同的距离度量方式,可以进一步优化模型效果。


下面给出两个案例,分别使用在线下载的数据集,演示如何用 KNN 实现分类和回归。我们分别用 OpenML 上的 Iris 数据集(分类)和 scikit-learn 内置的 California Housing 数据集(回归)来说明。

案例 1:KNN 分类(Iris 数据集)

我们通过 fetch_openml 从 OpenML 下载 Iris 数据集,然后用 KNeighborsClassifier 进行分类,并输出预测准确率。

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 下载 Iris 数据集(注意:as_frame=True 会返回 Pandas DataFrame 格式)
iris = fetch_openml(name='iris', version=1, as_frame=True)
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构造并训练 KNN 分类器(这里取 k=3)
knn_classifier = KNeighborsClassifier(n_neighbors=3)
knn_classifier.fit(X_train, y_train)

# 对测试集进行预测
y_pred = knn_classifier.predict(X_test)

# 输出分类准确率
print("KNN 分类器准确率:", accuracy_score(y_test, y_pred))

运行该代码后,会输出模型在测试集上的准确率,说明 KNN 分类器在 Iris 数据集上的表现。


案例 2:KNN 回归(California Housing 数据集)

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error

# 下载 California Housing 数据集
housing = fetch_california_housing(as_frame=True)
X = housing.data
y = housing.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构造并训练 KNN 回归器(这里取 k=5)
knn_regressor = KNeighborsRegressor(n_neighbors=5)
knn_regressor.fit(X_train, y_train)

# 对测试集进行预测
y_pred = knn_regressor.predict(X_test)

# 计算并输出均方误差(MSE)
mse = mean_squared_error(y_test, y_pred)
print("KNN 回归器的均方误差:", mse)

运行该代码后,将输出模型在 California Housing 数据集上预测的均方误差,从而评估回归效果。

以上两个案例分别展示了如何利用在线数据和 scikit-learn 中的 KNN 模型进行分类和回归任务。根据具体问题的特点,可以调整 k 值、数据预处理及评估指标以获得更好的效果。


5. 总结

KNN算法因其简单直观而在入门机器学习时备受推崇,虽然在大规模数据和高维数据上存在计算和噪声问题,但其易于实现和理解的特点,使其成为很多初学者和实际应用场景中的不错选择。通过本文的介绍,希望大家对KNN算法有了基本的认识,并能在实践中灵活运用。


如果你有任何问题或想进一步讨论,欢迎在评论区留言交流!

希望这篇文章能帮助你快速上手KNN算法,开启机器学习之旅。

   

相关文章:

  • Dav_笔记14:优化程序提示 HINTs -4
  • 自动驾驶---基于深度学习模型的轨迹预测
  • TS语言自定义脚手架
  • 神经网络新手入门(1)目录
  • 责任链模式解析FilterChain
  • 2000-2020年年汇率平均价数据
  • Ubuntu 22.04.5 LTS 安装企业微信,(2025-02-17安装可行)
  • 二十多年前的苹果电源Power Mac G4 Mdd 电源接口
  • 宝塔docker 安装oracle11G
  • 【097】基于51单片机排队叫号系统【Keil程序+报告+原理图】
  • 4.【线性代数】——矩阵的LU分解
  • STC 51单片机63——关于STC8H的ADC通道切换问题
  • 软硬链接?
  • 附录2:组维接口信息大全
  • 过于依赖chatgpt编程会有哪些弊端?
  • IOT-CVE-2018-17066(D-Link命令注入漏洞)
  • ubuntu22.04安装kvm、virt-manage并配置SR-IOV操作
  • Spring Boot 启动优化✨
  • TCP协议(Transmission Control Protocol)
  • Kubernetes控制平面组件:Kubernetes如何使用etcd
  • 云主机如何做两个网站/今日国际新闻最新消息事件
  • 北京公司网站制作要多少钱/浏览器谷歌手机版下载
  • 个人备案的网站能做什么/怎么卸载windows优化大师
  • 网站建设比较好/网页设计代码案例
  • 小程序开发者工具下载/济南seo外贸网站建设
  • 法院网站建设方案/seo高级教程