KNN算法详解:鸢尾花识别和手写数字识别
KNN算法详解:从理论到实践的手写数字识别
引言
KNN(K-Nearest Neighbor)算法是机器学习中最基础、最直观的算法之一。它基于"物以类聚"的思想,通过计算样本间的距离来进行分类或回归预测。本文将详细介绍KNN算法的核心概念、实现原理以及在实际项目中的应用。
一、KNN算法核心概念
1.1 算法思想
KNN算法的核心思想可以用"近朱者赤,近墨者黑"来概括。具体来说:
- 分类问题:如果一个样本在特征空间中的k个最相似样本中的大多数属于某一个类别,则该样本也属于这个类别
- 回归问题:取k个最近邻样本的平均值作为预测结果
1.2 K值选择的影响
K值的选择对算法性能有重要影响:
- K值过小:容易过拟合,对噪声敏感
- K值过大:容易欠拟合,忽略局部特征
- 经验值:通常选择3-10之间的奇数
二、距离度量方法
2.1 欧式距离(最常用)
# 欧式距离公式
distance = √[(x₁-y₁)² + (x₂-y₂)² + ... + (xₙ-yₙ)²]
特点:最直观的距离度量方法,适用于连续型特征
2.2 曼哈顿距离
# 曼哈顿距离公式
distance = |x₁-y₁| + |x₂-y₂| + ... + |xₙ-yₙ|
特点:计算简单,适合高维数据,适用于离散型特征
2.3 切比雪夫距离
# 切比雪夫距离公式
distance = max(|x₁-y₁|, |x₂-y₂|, ..., |xₙ-yₙ|)
特点:只考虑最大差值,适用于棋盘距离计算
2.4 闵可夫斯基距离
# 闵可夫斯基距离公式
distance = (∑|xᵢ-yᵢ|ᵖ)^(1/p)
特点:距离度量的一般化公式
- p=1时:曼哈顿距离
- p=2时:欧式距离
- p=∞时:切比雪夫距离
三、特征预处理
3.1 为什么需要预处理
KNN算法对特征尺度非常敏感,不同特征的单位或大小相差较大时,会影响距离计算的准确性。
3.2 归一化(Min-Max Scaling)
from sklearn.preprocessing import MinMaxScaler# 归一化公式:X_norm = (X - X_min) / (X_max - X_min)
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)
特点:将数据映射到[0,1]区间,受异常值影响大
3.3 标准化(Z-Score Normalization)
from sklearn.preprocessing import StandardScaler# 标准化公式:X_std = (X - μ) / σ
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
特点:转换为均值为0,标准差为1的标准正态分布,对异常值不敏感
四、超参数优化
4.1 交叉验证
交叉验证是一种更完善的模型评估方法:
from sklearn.model_selection import cross_val_score# 5折交叉验证
scores = cross_val_score(estimator, X, y, cv=5)
原理:将训练集划分为n份,轮流使用每份作为验证集
4.2 网格搜索
from sklearn.model_selection import GridSearchCV# 定义参数网格
param_grid = {'n_neighbors': range(1, 20)}# 网格搜索
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
grid_search.fit(X_train, y_train)# 获取最优参数
best_params = grid_search.best_params_
best_score = grid_search.best_score_
五、实战案例:手写数字识别
5.1 数据集介绍
MNIST手写数字数据集是计算机视觉领域的经典数据集:
- 28×28像素,784个特征
- 0-9十个类别
- 训练集42000条数据
5.2 完整实现代码
import os
# 解决多进程警告
os.environ["LOKY_MAX_CPU_COUNT"] = "4" # 设置CPU核心数import matplotlib
matplotlib.use('TkAgg') # 解决matplotlib显示警告
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import joblib
from collections import Counterdef show_digit(idx):"""显示指定索引的数字图片"""df = pd.read_csv("../data/手写数字识别.csv")if idx < 0 or idx > len(df) - 1:print("索引越界!")returnx = df.iloc[:, 1:] # 特征列y = df.iloc[:, 0] # 标签列print(f"您传入的索引对应的数字是:{y[idx]}")# 重塑为28x28图像digit = x.iloc[idx].values.reshape(28, 28)plt.imshow(digit, cmap='gray')plt.axis('on')plt.show()def train_model():"""训练KNN模型"""# 加载数据df = pd.read_csv("../data/手写数字识别.csv")x = df.iloc[:, 1:] # 特征列y = df.iloc[:, 0] # 标签列# 特征工程 - 二值化处理x = ((df.iloc[:, 1:] > 127) * 255).astype(int)# 数据集划分x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=18, stratify=y)# 模型训练estimator = KNeighborsClassifier(n_neighbors=5)estimator.fit(x_train, y_train)# 模型评估accuracy = accuracy_score(y_test, estimator.predict(x_test))print(f"准确率: {accuracy}")# 保存模型joblib.dump(estimator, "../model/手写数字识别.pkl")print("模型保存成功!")def use_model():"""使用训练好的模型进行预测"""# 读取图片img = plt.imread("../data/demo.png")# 加载模型knn = joblib.load("../model/手写数字识别.pkl")# 解决预测警告:使用DataFrame而不是numpy数组y_predict = knn.predict(pd.DataFrame(img.reshape(1, 784), columns=knn.feature_names_in_))print(f"预测结果: {y_predict}")if __name__ == '__main__':# show_digit(10) # 显示第10个数字train_model() # 训练模型# use_model() # 使用模型预测
5.3 关键优化点
5.3.1 解决多进程警告
os.environ["LOKY_MAX_CPU_COUNT"] = "4" # 设置CPU核心数
5.3.2 解决matplotlib显示警告
matplotlib.use('TkAgg') # 指定后端
5.3.3 解决预测警告
# 使用DataFrame而不是numpy数组进行预测
y_predict = knn.predict(pd.DataFrame(img.reshape(1, 784), columns=knn.feature_names_in_)
)
六、算法优缺点分析
6.1 优点
- 简单易懂:无需训练过程,算法逻辑直观
- 对异常值不敏感:基于距离计算,异常值影响有限
- 适用于多分类:天然支持多分类问题
- 无需假设:不假设数据分布
6.2 缺点
- 计算复杂度高:需要计算所有样本距离
- 内存消耗大:需要存储所有训练数据
- 对特征尺度敏感:必须进行特征预处理
- 维度灾难:高维数据效果差
七、应用场景
7.1 分类问题
- 图像分类(手写数字识别、人脸识别)
- 文本分类(垃圾邮件检测)
- 推荐系统(用户分类)
7.2 回归问题
- 房价预测
- 股票价格预测
- 温度预测
八、总结
KNN算法作为机器学习的基础算法,具有简单直观、易于理解的优点。虽然在某些场景下计算复杂度较高,但在小规模数据集和特征维度不高的情况下,仍然是一个很好的选择。
通过本文的学习,可以掌握:
- KNN算法的核心思想和实现原理
- 不同距离度量方法的特点和适用场景
- 特征预处理的重要性和方法
- 超参数优化的技术
- 实际项目中的完整实现流程
KNN算法为理解更复杂的机器学习算法奠定了重要基础,是每个机器学习学习者的必修内容。
更新日期 2025年8月25日