当前位置: 首页 > news >正文

01_K近邻

描述

k-NN 算法可以说是最简单的机器学习算法。构建模型只需要保存训练数据集即可。

借助k-NN算法,理解机器学习的一些概念。

K-近邻(KNN)是一种基于实例的学习方法,预测时通过计算待预测样本与训练集中所有样本的距离,选取距离最近的 K 个邻居,并根据邻居的标签进行预测。

K近邻分类

导入库约定

import pandas as pd
import numpy as np

创建样例数据

import mglearn # 需要单独安装,主要是生成一些数据,帮助我们理解一些概念(生产环境不需要)X,Y = mglearn.datasets.make_forge()
# 数据集绘图
mglearn.discrete_scatter(X[:, 0], X[:, 1], Y)
plt.legend(["Class 0", "Class 1"], loc=4)
plt.xlabel("First feature")
plt.ylabel("Second feature")
print("X.shape: {}".format(X.shape))

K参数

k-NN 算法最简单的版本只考虑一个最近邻,也就是与想要预测的数据点最近的训练数据点。

mglearn.plots.plot_knn_classification(n_neighbors=1) # 执行后可以看到一个图

除了仅考虑最近邻,我还可以考虑任意个(k 个)邻居。这也是 k 近邻算法名字的来历。在考虑多于一个邻居的情况时,用“投票法”(voting)来指定标签。也就是说,对于每个测试点,数一数多少个邻居属于类别 0,多少个邻居属于类别 1。然后将出现次数更多的类别(也就是 k 个近邻中占多数的类别)作为预测结果。

mglearn.plots.plot_knn_classification(n_neighbors=3)# 执行后可以看到一个图,与n_neighbors=1的情况下对比,看一下差异就明白了

训练数据

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier# 将数据分为训练集和测试集,以便评估泛化性能
# random_state 是一个用于控制随机性的参数。它的主要作用是确保每次运行代码时,能够生成相同的随机数序列,从而保证结果的可重复性(目前取值随意,对结果没有多大影响)
X_train,X_test,Y_train,Y_test = train_test_split(X,Y,random_state=0) # train_test_split将数据分为# 例化KNeighborsClassifier分类器。这时可以设定参数,比如邻居(n_neighbors)的个数
clf = KNeighborsClassifier(n_neighbors=3)clf.fit(X_train,Y_train) # 训练分类器,对于 KNeighborsClassifier 来说就是保存数据集

模型评估

predict 方法来对测试数据进行预测。对于测试集中的每个数据点,都要计算它在训练集的最近邻,然后找出其中出现次数最多的类别

print("Test set predictions: {}".format(clf.predict(X_test)))

为了评估模型的泛化能力好坏,可以对测试数据和测试标签调用 score 方法

print("Test set accuracy: {:.2f}".format(clf.score(X_test, Y_test))) # 准确率评估

预测分类

现在可以用clf去预测一个新数据的类别了

x_new = [[5.678,2.000]] # 注意格式是二维数组
prediction = clf.predict(x_new)
print('Predicted target :{}'.format(Y[prediction]))

一些概念

泛化

如果一个模型能够对没见过的数据做出准确预测,我们就说它能够从训练集泛化(generalize)到测试集。

过拟合

构建一个对现有信息量来说过于复杂的模型就会出现过拟合 (overfitting)。如果你在拟合模型时过分关注训练集的细节,得到了一个在训练集上表现很好、但不能泛化到新数据上的模型,那么就存在过拟合。

欠拟合

过拟合相反,选择过于简单的模型被称为欠拟合(underfitting)

鸢尾花的例子

import pandas as pd
import numpy as np
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifieriris = pd.read_csv(r'..\..\seaborn-data\iris.csv') # 加载数据
sns.pairplot(iris,hue='species') # 分析数据 通过sepl_length、sepal_length、petal_length、petal_width 可以区分出类别# 处理iris,提取出类别、特征数据
# iris.groupby('species').groups.keys()
iris_class = ['setosa', 'versicolor', 'virginica']
iris_tz_array=iris.select_dtypes(include='number')
iris_class_array=iris['species'].map({'setosa':0, 'versicolor':1, 'virginica':2})# 训练数据
X_train,X_test,Y_train,Y_test = train_test_split(iris_tz_array,iris_class_array,random_state=0)
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train,Y_train)# 模型评估
# knn.predict(X_test) # 返回预测结果,可以与Y_test对比
knn.score(X_test,Y_test) # 匹配度(精度)# 预测一个新的分类
x_new= [[5,2.8,1,0.3]]
iris_class[knn.predict(x_new)[0]]

K近邻回归

对于二维数据集,还可以在 xy 平面上画出所有可能的测试点的预测结果。根据平面中每个点所属的类别对平面进行着色。这样可以查看决策边界(decision boundary),即算法对类别 0 和类别 1 的分界线。

fig, axes = plt.subplots(1, 3, figsize=(10, 3))for n_neighbors, ax in zip([1, 3, 9], axes):# fit方法返回对象本身,所以我们可以将实例化和拟合放在一行代码中clf = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X, Y)mglearn.plots.plot_2d_separator(clf, X, fill=True, eps=0.5, ax=ax, alpha=.4)mglearn.discrete_scatter(X[:, 0], X[:, 1], Y, ax=ax)ax.set_title("{} neighbor(s)".format(n_neighbors))ax.set_xlabel("feature 0")ax.set_ylabel("feature 1")
axes[0].legend(loc=3)

使用单一邻居绘制的决策边界紧跟着训练数据。随着邻居个数越来越多,决策边界也越来越平滑。更平滑的边界对应更简单的模型。

X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, stratify=cancer.target, random_state=66)
training_accuracy = []
test_accuracy = []
# n_neighbors取值从1到10
neighbors_settings = range(1, 11)
for n_neighbors in neighbors_settings:# 构建模型clf = KNeighborsClassifier(n_neighbors=n_neighbors)clf.fit(X_train, y_train)# 记录训练集精度training_accuracy.append(clf.score(X_train, y_train))# 记录泛化精度test_accuracy.append(clf.score(X_test, y_test))plt.plot(neighbors_settings, training_accuracy, label="training accuracy")
plt.plot(neighbors_settings, test_accuracy, label="test accuracy")
plt.ylabel("Accuracy")
plt.xlabel("n_neighbors")
plt.legend()

图像的 x 轴是 n_neighbors,y 轴是训练集精度和测试集精度。虽然现实世界的图像很少有非常平滑的,但我们仍可以看出过拟合与欠拟合的一些特征

k 近邻算法还可以用于回归。我们还是先从单一近邻开始,这次使用 wave 数据集。我们添加了 3 个测试数据点,在 x 轴上用绿色五角星表示。利用单一邻居的预测结果就是最近邻的目标值(尝试执行下面两段代码)

mglearn.plots.plot_knn_regression(n_neighbors=1)
mglearn.plots.plot_knn_regression(n_neighbors=3)

用 于 回 归 的 k 近 邻 算 法 在 scikit-learn 的 KNeighborsRegressor 类 中 实 现。 其 用 法 与KNeighborsClassifier 类似:

from sklearn.neighbors import KNeighborsRegressorX, Y = mglearn.datasets.make_wave(n_samples=40)# 将wave数据集分为训练集和测试集
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=0)
# 模型实例化,并将邻居个数设为3
reg = KNeighborsRegressor(n_neighbors=3)
# 利用训练数据和训练目标值来拟合模型
reg.fit(X_train, Y_train)# 模型评估
print("Test set predictions:\n{}".format(reg.predict(X_test)))
print("Test accuracy: {:.2f}".format(reg.score(X_test, Y_test)))

鸢尾花的例子

import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressoriris = pd.read_csv(r'..\..\seaborn-data\iris.csv')
iris_class = ['setosa', 'versicolor', 'virginica']
iris_tz_array=iris.select_dtypes(include='number')
iris_class_array=iris['species'].map({'setosa':0, 'versicolor':1, 'virginica':2})X_train,X_test,Y_train,Y_test = train_test_split(iris_tz_array,iris_class_array,random_state=0)
knn = KNeighborsRegressor(n_neighbors=3)
knn.fit(X_train,Y_train)knn.score(X_test,Y_test) # 可以调整一下n_neighbors的值,看看匹配度(精度)的变化# 预测一个新的分类
x_new= [[5,2.8,1,0.3]]
iris_class[int(knn.predict(x_new)[0])]
# 虽然只是将KNeighborsClassifier换成了KNeighborsRegressor ,但是匹配度(精度)有一定的差异

总结

k-NN 的优点之一就是模型很容易理解,通常不需要过多调节就可以得到不错的性能。在考虑使用更高级的技术之前,尝试此算法是一种很好的基准方法。构建最近邻模型的速度通常很快,但如果训练集很大(特征数很多或者样本数很大),预测速度可能会比较慢。

虽然 k 近邻算法很容易理解,但由于预测速度慢且不能处理具有很多特征的数据集,所以在实践中往往不会用到。

相关文章:

  • 网络基础-----C语言经典题目(12)
  • kivy android打包buildozer.spec GUI配置
  • LeetCode 1295.统计位数为偶数的数字:模拟
  • 4:机器人目标识别无序抓取程序二次开发
  • 4.30阅读
  • 变量char2、*char2、pChar3、*pChar3的存储位置
  • Qwen3-32B的幻觉问题
  • uv安装及使用
  • C++初阶-string类2
  • Vue Router路由原理
  • 网工_ICMP协议
  • ZYNQ MPSOC之PL与PS数据交互DMA方式
  • MCP 服务器搭建【sse 类型】实现上市公司年报查询总结, 127.0.0.1:8000/sse直接配置配合 Cherry Studio使用简单
  • 讯飞星辰焕新发布!Agent规模化应用的通关密码
  • 学习笔记——《Java面向对象程序设计》-常用实用类
  • 复刻低成本机械臂 SO-ARM100 材料齐活篇
  • 欧拉计划 Project Euler61(循环的多边形数)题解
  • Java中的多态与继承
  • 共筑数字经济新生态 共绘数字中国新蓝图 ——思特奇受邀出席2025年第八届数字中国建设峰会
  • 动画震动效果
  • 软硬件企业集中发布未成年人模式使用手册
  • 【社论】人工智能,年轻的事业
  • 发布亮眼一季度报后,东阿阿胶股价跌停:现金流隐忧引发争议
  • 美国“杜鲁门”号航母一战机坠海
  • 餐饮店直播顾客用餐,律师:公共场所并非无隐私,需对方同意
  • 外交部回应涉长江和记出售巴拿马运河港口交易:望有关各方审慎行事,充分沟通