当前位置: 首页 > news >正文

K - 近邻(KNN)算法:基于约会数据集的分类任务全流程

K - 近邻(K-Nearest Neighbors, KNN)是机器学习中最直观的分类算法之一,核心逻辑为 “近朱者赤,近墨者黑”—— 通过寻找新样本的 K 个最近邻训练样本,根据多数邻域样本的类别完成分类。本文基于约会数据集,从数据加载、探索、预处理到模型训练、评估,完整拆解 KNN 算法的实现过程,并深入讲解 K 值选择等关键优化点。

一、项目背景与核心目标

1. 数据集介绍

本次使用的约会数据集(datingTestSet2.txt) 包含 1000 条样本,每条样本含 4 列数据,前 3 列为特征,第 4 列为类别标签,具体含义如下(根据 KNN 任务场景推测):

  • 特征 1:可能为 “每年飞行里程数”(数值较大,如 14488、75136 等);
  • 特征 2:可能为 “玩游戏所耗时间占比”(小数,如 0.9539、1.6698 等);
  • 特征 3:可能为 “每周消费冰淇淋公升数”(小数,如 0.2、1.3324 等);
  • 类别标签(第 4 列):取值 1、2、3,代表 3 种不同的约会对象类型(如 “不喜欢”“一般喜欢”“非常喜欢”)。

2. 核心目标

  • 构建 KNN 分类模型,基于 3 个特征预测约会对象类型;
  • 掌握 KNN 算法的实现流程(数据预处理、模型训练、评估);
  • 理解 K 值对模型性能的影响,学习 K 值选择的核心方法。

二、技术工具与环境准备

  • 编程语言:Python 3.9
  • 核心库说明
    库名核心用途
    numpy数值计算(模型底层依赖)
    pandas数据加载、探索与结构化处理
    matplotlib数据可视化(可选,用于特征分布分析)
    sklearn.neighborsKNN 分类模型实现(KNeighborsClassifier
    sklearn.model_selection数据集拆分(训练集 / 测试集)
    sklearn.preprocessing特征标准化(消除量纲影响)
    sklearn.metrics模型分类性能评估(准确率、精确率等)

三、K - 近邻算法实现步骤详解

1. 导入依赖库

首先导入所有需要的工具库,避免后续代码中重复引入:

# 数值计算与数据处理库
import numpy as np
import pandas as pd
# 可视化库(可选,用于后续特征分析)
import matplotlib.pyplot as plt
# KNN模型与数据处理工具
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# 模型评估指标
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

2. 加载数据集

通过pandas.read_table()读取本地 TXT 格式的约会数据集,指定分隔符为制表符(\t),且数据集无表头(header=None):

# 加载约会数据集
# 注意:需将文件路径替换为你本地的实际路径
data_test_set = pd.read_table(r"D:\Desktop\CC是小陈\Machine Learning\datingTestSet2.txt",sep="\t",  # 分隔符为制表符(TXT文件常用格式)header=None  # 数据集无表头,列名默认设为0、1、2、3
)# 打印前5行数据,初步观察数据结构
print("数据集前5行预览:")
print(data_test_set.head())# 打印数据集基本维度
print(f"\n数据集维度:{data_test_set.shape}({data_test_set.shape[0]}条样本,{data_test_set.shape[1]}列)")

3. 数据探索与预处理

数据探索是建模前的关键步骤,需了解数据分布、缺失值情况,为后续预处理提供依据。

(1)查看数据基本信息
# 1. 查看数据类型与非空值数量
print("\n=== 数据类型与非空值统计 ===")
print(data_test_set.info())# 2. 查看数据统计描述(均值、标准差、最值等)
print("\n=== 数据统计描述 ===")
print(data_test_set.describe())# 3. 检查缺失值(KNN对缺失值敏感,需确保无缺失)
print("\n=== 各列缺失值数量 ===")
print(data_test_set.isnull().sum())
(2)特征标准化(关键预处理步骤)

KNN 算法基于 “距离”(如欧氏距离)判断样本相似度,若特征量纲差异大,会导致 “数值大的特征”(如飞行里程)在距离计算中占据主导地位,掩盖其他特征的影响。因此必须通过标准化将所有特征转换为 “均值 = 0,标准差 = 1” 的分布:

# 初始化标准化器
scaler = StandardScaler()# 提取特征矩阵(前3列)
X = data_test_set.iloc[:, :3]# 对特征进行标准化(拟合+转换)
X_scaled = scaler.fit_transform(X)# 查看标准化后的特征统计信息
print("\n=== 标准化后特征统计描述 ===")
print(pd.DataFrame(X_scaled, columns=["特征1_标准化", "特征2_标准化", "特征3_标准化"]).describe())

4. 拆分训练集与测试集

将标准化后的特征与类别标签组合,按 “75:25” 的比例拆分为训练集(用于模型学习邻居关系)和测试集(用于评估泛化能力):

# 提取类别标签(第4列)
y = data_test_set.iloc[:, 3]# 拆分数据集:test_size=0.25表示测试集占25%,random_state=42确保结果可复现
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.25, random_state=42
)# 打印拆分后的数据维度
print("\n=== 拆分后数据集维度 ===")
print(f"训练集特征:{X_train.shape},训练集标签:{y_train.shape}")
print(f"测试集特征:{X_test.shape},测试集标签:{y_test.shape}")
拆分结果
  • 训练集:750 条样本(1000×75%),用于训练 KNN 模型的 “邻居判断规则”;
  • 测试集:250 条样本(1000×25%),模拟 “新数据”,验证模型对未知样本的分类能力;
  • random_state=42:固定随机种子,确保每次运行代码时拆分结果一致,便于调试和对比。

5. 构建并训练 KNN 模型

sklearn中的KNeighborsClassifier默认使用欧氏距离投票法(K 个邻居中多数类为预测类别),核心参数为n_neighbors(即 K 值,默认 K=5):

# 初始化KNN模型(默认K=5,使用欧氏距离)
model = KNeighborsClassifier()# 用训练集训练模型(KNN无“训练过程”,实际是存储训练样本用于后续距离计算)
model.fit(X_train, y_train)# 查看模型默认参数(重点关注K值和距离度量)
print("\n=== KNN模型默认参数 ===")
print(f"K值(n_neighbors):{model.n_neighbors}")
print(f"距离度量(metric):{model.metric}")
print(f"权重方式(weights):{model.weights}")  # uniform:等权重投票,distance:距离越近权重越大
KNN 模型特点

与逻辑回归、线性回归不同,KNN 是惰性学习(Lazy Learner) 算法:

  • 训练阶段不进行参数学习,仅存储训练样本;
  • 预测阶段通过计算新样本与所有训练样本的距离,找到 K 个最近邻并投票,因此预测速度较慢(尤其大数据集)。

6. 模型预测与性能评估

用训练好的模型对测试集进行预测,通过多维度指标评估模型性能(准确率、精确率、召回率、F1 值):

(1)模型预测
# 对测试集进行分类预测
y_pred = model.predict(X_test)# 查看前10条测试数据的真实标签与预测标签
print("\n=== 测试集前10条预测结果 ===")
result_df = pd.DataFrame({"真实标签": y_test.iloc[:10].values,  # 注意:若y为Series需转成数组"预测标签": y_pred[:10]
})
print(result_df)
(2)模型评估

#7、评估模型
accuracy = accuracy_score(y_test,y_pred)
precision = precision_score(y_test,y_pred,average = 'weighted')
recall = recall_score(y_test,y_pred,average = 'weighted')
f1 = f1_score(y_test,y_pred,average = 'weighted')
print("accuracy:",accuracy)
print("precision:",precision)
print("recall:",recall)
print("f1:",f1)

四、K 值选择:KNN 模型的核心优化点

K 值是影响 KNN 性能的关键参数,直接决定 “邻居范围” 的大小,需结合业务场景和数据特点选择。

1. K 值对模型的影响

K 值大小特点风险适用场景
较小(如 K=1)邻居范围小,模型对局部数据敏感易过拟合(受噪声样本影响大,训练准、测试差)数据噪声少、局部特征明显的场景
较大(如 K=10)邻居范围大,模型更稳健易欠拟合(忽略局部特征,决策边界过于平滑)数据噪声多、全局特征明显的场景
极端大(如 K = 样本数)所有样本都是邻居,预测结果为训练集多数类完全欠拟合,模型失去意义无实际适用场景

2. K 值选择的常用方法

(1)经验法则
  • 多数场景下,K 值取1-10之间的整数;
  • 若数据集类别数多,可适当增大 K 值(如类别数 = 5,K=5-8);
  • 若数据集样本数多,可适当增大 K 值(如 10 万样本,K=10-20)。
(2)交叉验证(最可靠方法)

通过K 折交叉验证(如 5 折、10 折),在不同 K 值下计算模型平均准确率,选择准确率最高的 K 值:

# 示例:用5折交叉验证选择最优K值
from sklearn.model_selection import cross_val_score# 定义待测试的K值范围(1-10)
k_range = range(1, 11)
# 存储不同K值的交叉验证准确率
cv_scores = []# 遍历K值,计算交叉验证准确率
for k in k_range:knn = KNeighborsClassifier(n_neighbors=k)# 5折交叉验证,计算平均准确率scores = cross_val_score(knn, X_train, y_train, cv=5, scoring='accuracy')cv_scores.append(scores.mean())# 找到最优K值(准确率最高的K)
best_k = k_range[cv_scores.index(max(cv_scores))]
print(f"\n=== 交叉验证结果 ===")
print(f"不同K值的平均准确率:{cv_scores}")
print(f"最优K值:{best_k},对应准确率:{max(cv_scores):.4f}")# 可视化K值与准确率的关系
plt.plot(k_range, cv_scores, marker='o', linestyle='-')
plt.xlabel('K值(n_neighbors)')
plt.ylabel('5折交叉验证平均准确率')
plt.title('K值选择:准确率随K值的变化')
plt.xticks(k_range)
plt.grid(alpha=0.3)
plt.show()
(3)考虑数据量
  • 小数据集(如 < 1000 样本):K 值取较小值(1-5),避免因邻居过多引入无关样本;
  • 大数据集(如 > 10 万样本):K 值取较大值(5-20),模型更稳健且计算效率影响较小。

五、完整可运行代码

# coding:'UTF-8'
# author:ChenShaoQi
# create_file_time:2025/9/15#1、导入相关库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score#2、导入数据
data_test_set = pd.read_table(r"D:\Desktop\CC是小陈\Machine Learning\datingTestSet2.txt",sep = "\t",header = None)
print(data_test_set)#3、查看数据的基本信息
print(data_test_set.info())
print(data_test_set.describe())
print(data_test_set.isnull().sum())#4、将数据拆分为训练集和测试集
x = data_test_set.iloc[:,:3]
y = data_test_set.iloc[:,3]
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.25,random_state=42)#5、创建并训练模型
model = KNeighborsClassifier()
model.fit(x_train,y_train)#6、预测模型
y_pred = model.predict(x_test)#7、评估模型
accuracy = accuracy_score(y_test,y_pred)
precision = precision_score(y_test,y_pred,average = 'weighted')
recall = recall_score(y_test,y_pred,average = 'weighted')
f1 = f1_score(y_test,y_pred,average = 'weighted')
print("accuracy:",accuracy)
print("precision:",precision)
print("recall:",recall)
print("f1:",f1)

👏觉得文章对自己有用的宝子可以收藏文章并给小编点个赞!

👏想了解更多统计学、数据分析、数据开发、机器学习算法、数据治理、数据资产管理和深度学习等有关知识的宝子们,可以关注小编,希望以后我们一起成长!


文章转载自:

http://LfTIYX2m.smhtg.cn
http://uYRQR6Bq.smhtg.cn
http://bvV9j2lz.smhtg.cn
http://ddeQPIIO.smhtg.cn
http://BqCgW5e1.smhtg.cn
http://apgki3QS.smhtg.cn
http://u1fyV7hI.smhtg.cn
http://H3L3gu16.smhtg.cn
http://txMcJg3y.smhtg.cn
http://EUqG9dL9.smhtg.cn
http://bcflVF9k.smhtg.cn
http://0K1GxMcQ.smhtg.cn
http://L1fLtmJP.smhtg.cn
http://5oCTGPDU.smhtg.cn
http://QNYoAQEw.smhtg.cn
http://AmqWtmI6.smhtg.cn
http://9H1a9bco.smhtg.cn
http://iyqcJ3Io.smhtg.cn
http://XASCvqAF.smhtg.cn
http://UlaLTXBo.smhtg.cn
http://mPcI52Wj.smhtg.cn
http://gbISBoR9.smhtg.cn
http://Hwlpgbl8.smhtg.cn
http://zHXyfoVI.smhtg.cn
http://ivhkRSxC.smhtg.cn
http://oLq0vR5c.smhtg.cn
http://aSt4Kmvh.smhtg.cn
http://4ggNFMEQ.smhtg.cn
http://ir515WGO.smhtg.cn
http://qCMGV9hC.smhtg.cn
http://www.dtcms.com/a/384537.html

相关文章:

  • 机器学习实战第四章 线性回归
  • 概率统计面试题2:随机抛掷两点到圆心距离较小值的期望
  • 什么是 OFDM?它如何解决频率选择性衰落?
  • 第一部分:VTK基础入门(第3章:VTK架构与核心概念)
  • 基于深度学习的中文方言识别模型训练实战
  • 【机器学习】用Anaconda安装学习环境
  • 【C语言】C语言内存存储底层原理:整数补码、浮点数IEEE754与大小端(数据内存存储的深度原理与实践)
  • MongoDB - 连接
  • 【Day 57】Linux-Redis
  • Go语言爬虫:爬虫入门
  • HarmonyOS图表组件库对比:UCharts、VChart、Omni-UI、mcCharts
  • 生活中的花花草草和各色人物
  • HTML属性和值
  • 【STL库】unordered_map/unordered_set 类学习
  • 学习threejs,使用自定义GLSL 着色器,实现水面、粒子特效
  • 机器学习-第二章
  • 贪心算法在SDN流表优化中的应用
  • 植物1区TOP——GWAS eQTL如何精准定位调控棉花衣分的候选基因
  • iOS 灵动岛 ActivityKit 开发实践
  • JVM 垃圾收集器
  • 学习日记-XML-day55-9.14
  • SenseVoice + WebRTC:打造行业级实时语音识别系统的底层原理与架构设计
  • C++ 异常机制深度解析:从原理到实战的完整指南
  • 在 Qoder 等 AI 二创 IDE 里用 VS Code Remote-SSH 的“曲线连接”实战
  • 云计算与大数据技术深入解析
  • 如何用Verdi APP抽出某个指定module的interface hierarchy
  • MySQL 锁机制详解+示例
  • 消息队列的“翻车“现场:当Kafka和RocketMQ遇到异常时会发生什么?
  • 在Cursor上安装检索不到的扩展库cline的方式方法
  • 第二十一章 ESP32S3 IIC_OLED 实验