当前位置: 首页 > news >正文

【机器学习算法篇】K-近邻算法

K-近邻算法

文章目录

  • K-近邻算法
    • 一、算法概述
      • 1.1 什么是K近邻算法
    • 二、距离度量方式
      • 1. 欧式距离
      • 2. 曼哈顿距离
      • 3. 切比雪夫距离
      • 4. 闵可夫斯基距离
      • 5. 标准化欧氏距离
      • 6. 余弦相似度
      • 7. 马氏距离
    • 三、机器学习库Scikit-learn
      • 1. Scikit-learn包含目录
      • 2. 任务类别
      • 3. Scikit-learn 的主要模块结构
      • 4. Scikit-learn 的典型流程
      • 5. K值的选择
    • 四、KNN算法的高效实现 —— KD树与Ball Tree
      • 4.1 为什么需要加速结构
    • 五、机器学习完整流程(以鸢尾花数据集为例)
      • 5.1 获取数据集
      • 5.2 划分训练集与测试集
      • 5.3 特征标准化
      • 5.4 建立模型并训练
      • 5.5 模型评估
    • 六、超参数优化:交叉验证与网格搜索

一、算法概述

1.1 什么是K近邻算法

K-近邻算法,也叫KNN算法,是一种基于实例的监督学习算法

核心思想:“物以类聚”——如果一个样本的特征与某个类别的样本相似,那么它很可能也属于这个类别。

算法的基本逻辑如下:

  1. 计算待分类样本与训练集中各样本的距离;
  2. 选取距离最近的K个邻居;
  3. 统计这K个邻居的类别出现频率;
  4. 将频率最高的类别作为预测结果。

KNN由Cover与Hart在1968年提出,至今仍是机器学习中最基础且常用的分类算法之一。

二、距离度量方式

KNN中最关键的一步是——如何衡量样本之间的“距离”。不同的距离度量方式,会直接影响分类结果。

1. 欧式距离

欧氏距离是我们小初高接触到的距离度量方法,只是我们接触到的维度较低,该距离度量方法直观易于理解。
二维平面上点 a ( x 1 , y 1 ) 与 b ( x 2 , y 2 ) 间的欧氏距离 : d 12 = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 二维平面上点a(x_{1},y_{1})与b(x_{2},y_{2})间的欧氏距离:d_{12} = \sqrt{(x_{1}-x_{2})^{2}+(y_{1}-y_{2})^{2}} 二维平面上点a(x1,y1)b(x2,y2)间的欧氏距离:d12=(x1x2)2+(y1y2)2

三维空间点 a ( x 1 , y 1 , z 1 ) 与 b ( x 2 , y 2 , z 2 ) 间的欧氏距离 : d 12 = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 + ( z 1 − z 2 ) 2 三维空间点a(x_{1},y_{1},z_{1})与b(x_{2},y_{2},z_{2})间的欧氏距离:d_{12} = \sqrt{(x_{1}-x_{2})^{2}+(y_{1}-y_{2})^{2}+(z_{1}-z_{2})^{2}} 三维空间点a(x1,y1,z1)b(x2,y2,z2)间的欧氏距离:d12=(x1x2)2+(y1y2)2+(z1z2)2

n 维空间点 a ( x 11 , x 12 , … , x 1 n ) 与 b ( x 21 , x 22 , … , x 2 n ) 间的欧氏距离(两个 n 维向量) : d 12 = ∑ k = 1 n ( x 1 k − x 2 k ) 2 n维空间点a(x_{11},x_{12},\ldots,x_{1n})与b(x_{21},x_{22},\ldots,x_{2n})间的欧氏距离(两个n维向量):d_{12} = \sqrt{\sum_{k=1}^{n}(x_{1k}-x_{2k})^{2}} n维空间点a(x11,x12,,x1n)b(x21,x22,,x2n)间的欧氏距离(两个n维向量):d12=k=1n(x1kx2k)2

2. 曼哈顿距离

顾名思义,“曼哈顿距离”源于曼哈顿整齐的方格式街区规划。想象一下,从一个十字路口到另一个,你无法斜穿大楼,只能驾车沿着棋盘般的街道曲折前行。这段实际行驶路径就是“曼哈顿距离”,它也因此得名“城市街区距离”。
d ( a , b ) = ∣ a 1 − b 1 ∣ + ∣ a 2 − b 2 ∣ + . . . + ∣ a n − b n ∣ d(a,b) = |a_1-b_1| + |a_2-b_2| + ... + |a_n-b_n| d(a,b)=a1b1+a2b2+...+anbn


3. 切比雪夫距离

国际象棋中的国王是“八方之王”,可以直行、横行或斜向移动,其一步之遥的领地覆盖了相邻的八个方格。由此,国王从 (x1, y1) 抵达 (x2, y2) 所需的最少步数,在数学上被定义为“切比雪夫距离”。

d ( a , b ) max ⁡ i ∣ a i − b i ∣ d(a,b)\max_i |a_i - b_i| d(a,b)imaxaibi
在国际象棋中,国王走一步的最远距离即是切比雪夫距离。


4. 闵可夫斯基距离

闵氏距离不是一个单一的距离标准,而是一个“距离家族”。它为我们提供了一个统一的数学框架,用以概括和生成多种具体的距离度量公式。

d ( a , b ) = ( ∑ i ∣ a i − b i ∣ p ) 1 / p d(a,b) = \left(\sum_i |a_i - b_i|^p\right)^{1/p} d(a,b)=(iaibip)1/p

当 p = 1 是曼哈顿距离;当 p = 2 是欧式距离;当 p → ∞ 是切比雪夫距离。 当 p=1 是曼哈顿距离;当p=2 是欧式距离;当 p→∞是切比雪夫距离。 p=1是曼哈顿距离;当p=2是欧式距离;当p是切比雪夫距离。


5. 标准化欧氏距离

当各特征的量纲不一致时,欧氏距离可能失真。
因此常用标准化处理:
d ( a , b ) = ∑ i ( a i − b i ) 2 s i 2 d(a,b)=\sqrt{\sum_i \frac{(a_i-b_i)^2}{s_i^2}} d(a,b)=isi2(aibi)2

其中 s i 为特征的标准差。 其中 s_i 为特征的标准差。 其中si为特征的标准差。

6. 余弦相似度

夹角余弦衡量的是向量方向的“亲密程度”,其取值范围在-1到1之间。值越接近1,意味着二者方向越一致,如同“同舟共济”;值越接近-1,则意味着方向越相反,如同“背道而驰”。该值为0时,表示二者方向垂直。

适用于文本或高维稀疏向量

cos ⁡ ( θ ) = a ⋅ b ∥ a ∥ ∥ b ∥ \cos(\theta)=\frac{a\cdot b}{\|a\|\|b\|} cos(θ)=a∥∥bab

案例

假设我们有三个文本片段:

  • 文档A: “人工智能改变世界”
  • 文档B: “人工智能塑造未来”
  • 文档C: “今天 天气很好”

第一步:构建词袋模型
首先,我们列出所有文档中出现的不重复的词,形成一个词汇表:
["人工智能", "改变", "世界", "塑造", "未来", "今天", "天气", "很好"]

第二步:文本向量化
我们将每个文档都转化为一个基于该词汇表的向量。向量的每个维度对应一个词,其值可以是该词在文档中出现的次数(词频)。

  • 文档A向量: [1, 1, 1, 0, 0, 0, 0, 0]
    • ("人工智能"出现1次, "改变"出现1次, "世界"出现1次,其他词均未出现)
  • 文档B向量: [1, 0, 0, 1, 1, 0, 0, 0]
    • ("人工智能"出现1次, "塑造"出现1次, "未来"出现1次)
  • 文档C向量: [0, 0, 0, 0, 0, 1, 1, 1]
    • (“今天”、“天气”、"很好"各出现1次)

第三步:计算余弦相似度
现在我们使用夹角余弦公式来计算每对文档向量的相似度。

  • 文档A vs 文档B:
    • 它们共享一个共同的词"人工智能"。
    • 计算其向量夹角的余弦值,结果约为 0.33
    • 解读:虽然它们用词不完全相同,但都围绕“人工智能”这一核心主题,因此表现出一定的相似性。
  • 文档A vs 文档C:
    • 它们没有任何共同的词汇。
    • 计算其向量夹角的余弦值,结果为 0
    • 解读:它们的向量在向量空间中相互垂直,意味着内容完全不相关。
  • 文档B vs 文档C:
    • 同样没有共同词汇,余弦相似度也为 0

通过这个例子,我们可以看到:

  • 余弦值接近1:表示文本内容非常相似。
  • 余弦值接近0:表示文本内容没有关联,主题迥异。
  • 余弦值接近-1:在文本向量中很少出现,因为词频不能为负。

因此,余弦相似度通过将文本转化为向量并计算其方向夹角,巧妙地规避了文本长度差异的干扰,专注于捕捉内容主题上的一致性,从而成为一种强大且高效的文本相似度度量方法。

7. 马氏距离

马氏距离是一种由印度统计学家马哈拉诺比斯提出的统计距离。它不仅是计算两点间的几何距离,更是度量它们相对于整个数据集分布的位置相似度。其关键优势在于,它独立于测量尺度,并能通过协方差矩阵排除各特征间相关性的干扰,从而提供比欧氏距离更合理的相似度判断。

d ( a , b ) = ( a − b ) T Σ − 1 ( a − b ) d(a,b) = \sqrt{(a-b)^T \Sigma^{-1} (a-b)} d(a,b)=(ab)TΣ1(ab)
其中 Σ是协方差矩阵。该距离在多维统计分析中十分常见。

三、机器学习库Scikit-learn

在这里插入图片描述

Scikit-learn是 Python 中最常用的机器学习库之一,提供了从数据预处理到模型训练、评估、优化的一整套工具。

  • 封装了C/C++后端的算法实现,性能优秀。
  • 统一的API设计,包括许多知名的机器学习算法的实现。
  • 涵盖分类、回归、聚类、降维、特征工程等常见任务。

官方网址

1. Scikit-learn包含目录

在这里插入图片描述

2. 任务类别

任务类别定义应用场景常用算法
分类识别对象属于哪个类别垃圾邮件检测、图像识别梯度提升、最近邻、随机森林、逻辑回归等
回归预测与对象相关的连续值属性药物反应、股票价格梯度提升、最近邻、随机森林、岭回归等
模型选择比较、验证和选择参数和模型通过参数调整提高准确率网格搜索、交叉验证、指标等
聚类自动将相似的对象分组客户细分、实验结果分组k-Means、HDBSCAN、层次聚类等
预处理特征提取和规范化转换文本等输入数据以供算法使用预处理、特征提取等
降维减少需要考虑的随机变量的数量可视化、提高效率主成分分析 (PCA)、特征选择、非负矩阵分解等

3. Scikit-learn 的主要模块结构

模块功能描述类/函数案例
sklearn.datasets内置与外部数据集接口load_iris(), fetch_20newsgroups()
sklearn.preprocessing特征工程与数据预处理StandardScaler, MinMaxScaler
sklearn.model_selection数据划分与模型选择train_test_split, GridSearchCV
sklearn.neighborsK近邻算法模块KNeighborsClassifier, KNeighborsRegressor
sklearn.tree决策树算法DecisionTreeClassifier, export_graphviz
sklearn.ensemble集成学习模型RandomForestClassifier, GradientBoosting
sklearn.linear_model线性/逻辑回归模型LinearRegression, LogisticRegression
sklearn.svm支持向量机SVC, SVR
sklearn.cluster聚类算法KMeans, DBSCAN
sklearn.metrics模型评估指标accuracy_score, confusion_matrix
sklearn.decomposition降维算法PCA, TruncatedSVD

4. Scikit-learn 的典型流程

几乎所有机器学习任务在 sklearn 中都可以遵循统一的五步流程:

步骤说明典型函数
1️⃣ 获取数据从文件或内置数据集中加载datasets.load_iris()
2️⃣ 数据预处理标准化、编码、特征选择StandardScaler, LabelEncoder
3️⃣ 构建模型选择算法模型KNeighborsClassifier()
4️⃣ 模型训练使用训练集进行学习.fit(X_train, y_train)
5️⃣ 模型评估在测试集上评估性能.score(X_test, y_test)

5. K值的选择

knn = KNeighborsClassifier(n_neighbors=5) 

n_neighbors:int,可选,默认= 5。

K的取值对模型性能影响极大

  • K过小: 容易受噪声影响,过拟合;
  • K过大: 平滑过度,容易欠拟合。

策略选择:

  • 通过交叉验证选择最优K;
  • 常用范围:3~15;
  • 若类别分布不平衡,可加入距离加权,即距离近的邻居权重更大。

四、KNN算法的高效实现 —— KD树与Ball Tree

4.1 为什么需要加速结构

KNN的缺点之一是计算量大
在每次预测时都需要计算所有训练样本的距离,时间复杂度为 O(ND)。

为此,可通过空间划分结构加速搜索:

  • KD树:适合低维数据;
  • Ball Tree:适合高维、稠密数据。

KD树通过不断用垂直于坐标轴的超平面划分空间,从而在搜索时能快速排除不可能的邻居。

五、机器学习完整流程(以鸢尾花数据集为例)

5.1 获取数据集

from sklearn.datasets import load_iris # 导入iris数据集加载器
iris = load_iris() # 加载iris数据集

5.2 划分训练集与测试集

from sklearn.model_selection import train_test_split # 导入训练集和测试集分割工具
# 将数据分割为训练集和测试集,测试集占比20%
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=22) # 进行数据分割

5.3 特征标准化

from sklearn.preprocessing import StandardScaler # 导入标准化器
# 初始化标准化器
scaler = StandardScaler() # 创建StandardScaler对象
# 对训练数据进行标准化
x_train = scaler.fit_transform(x_train) # 对训练数据进行拟合和转换
# 对测试数据进行标准化
x_test = scaler.transform(x_test) # 对测试数据进行转换

5.4 建立模型并训练

from sklearn.neighbors import KNeighborsClassifier # 导入KNeighborsClassifier分类器
# 初始化一个K近邻分类器,设置近邻数为5
knn = KNeighborsClassifier(n_neighbors=5) # 创建KNeighborsClassifier对象。n_neighbors:int,可选(默认= 5)
# 使用训练数据训练KNN分类器
knn.fit(x_train, y_train) # 使用训练数据训练模型

5.5 模型评估

# 评估KNN分类器在测试集上的准确率
print("预测准确率:", knn.score(x_test, y_test))

预测准确率: 0.9333333333333333

六、超参数优化:交叉验证与网格搜索

KNN算法中最重要的超参数是K值

可通过 GridSearchCV 结合交叉验证自动寻找最优组合:

from sklearn.model_selection import GridSearchCV# 定义要搜索的超参数网格
param_grid = {"n_neighbors": [1,3,5,7,9]}
# 初始化GridSearchCV,使用KNN分类器,超参数网格,并进行5折交叉验证
grid = GridSearchCV(KNeighborsClassifier(), param_grid=param_grid, cv=5)
# 在训练数据上执行网格搜索
grid.fit(x_train, y_train)# 打印网格搜索找到的最佳超参数
print("最优参数:", grid.best_params_)
# 打印网格搜索期间获得的最高准确率
print("最高准确率:", grid.best_score_)

最优参数: {‘n_neighbors’: 5} 最高准确率: 0.9583333333333333

http://www.dtcms.com/a/491007.html

相关文章:

  • K8S高可用集群-二进制部署 + nginx-proxy静态Pod版本
  • 使用Open CASCADE和BRepOffsetAPI_MakePipeShell创建螺旋槽钻头三维模型
  • 邯郸网站制作多少钱网站如何收录
  • 如何区分Android、Android Automotive、Android Auto
  • 企业融资方式有哪几种淄博网站seo价格
  • python - 第五天
  • 凡科网的网站建设怎么做网站 建设 公司
  • 透过浏览器原理学习前端三剑客:HTML、CSS与JavaScript
  • 镇江市网站建设江西省建设厅教育网站上查询
  • dede网站怎么设置首页相亲网站透露自己做理财的女生
  • Docker在已经构建好的镜像中安装包
  • 智慧物流赛项竞赛内容与技能要求深度解析
  • GPU散热革命:NVIDIA微通道液冷板(MLCP)技术深度解析
  • Docker安装部署MySQL一主二从集群
  • 搭建网站服务器多少钱网站在建设中是什么意思
  • Java 11对集合类做了哪些增强?
  • SQLSugar框架数据库优先
  • 工程建设教育网站北京网站建设cnevo
  • Vector数据库性能大比武:Pinecone、Weaviate、Chroma速度与准确率实测
  • 天津老区建设促进会网站移动开发的现状和前景
  • 笔试强训(六)
  • Iterator迭代器 【ES6】
  • spring boot实现接口数据脱敏,整合jackson实现敏感信息隐藏脱敏
  • 基于单片机的汽车多参数安全检测与报警系统设计
  • C++设计模式_行为型模式_备忘录模式Memento
  • 温州h5建站关于网站建设的文章
  • 大连专业做网站wordpress 4.5 汉化主题
  • Spring Boot 3零基础教程,Spring Boot 日志分组,笔记20
  • 【单调向量 单调栈】3676. 碗子数组的数目|1848
  • 【JUnit实战3_01】第一章:JUnit 起步