当前位置: 首页 > news >正文

【数学建模学习笔记】机器学习回归:K邻近回归

K 近邻回归入门:从零开始预测房价

对于机器学习初学者来说,K 近邻回归是一个非常适合入门的算法。它原理简单、容易理解,而且不需要复杂的数学推导就能上手实践。本文将用通俗的语言解释 K 近邻回归的核心概念,并通过房价预测的实例带你一步步掌握这个算法。

什么是 K 近邻回归?

K 近邻回归(K-Nearest Neighbors Regression)的核心思想可以用一句话概括:"物以类聚,人以群分"

想象一下,如果你想知道一套房子的价格,但不知道具体行情,你会怎么做?最直观的方法就是看看附近类似的房子都卖多少钱 —— 这其实就是 K 近邻回归的基本思路!

  • "近邻":指的是与我们要预测的样本(比如一套房子)在特征上相似的数据点
  • "K":是我们选择的参考样本数量(比如参考最近的 5 套或 10 套房子)
  • "回归":表示我们的目标是预测一个连续的数值(比如房价)

K 近邻回归的工作原理

K 近邻回归的工作流程非常简单,就像下面这三步:

  1. 找到与待预测样本最相似的 K 个样本(近邻)
  2. 计算这 K 个样本的目标值(比如房价)的平均值
  3. 用这个平均值作为待预测样本的预测结果

举个例子:如果我们要预测一套 80 平米的房子价格,选择 K=3,找到 3 套最相似的房子价格分别是 100 万、110 万和 105 万,那么预测价格就是 (100+110+105)/3 = 105 万。

实战:用 K 近邻回归预测房价

接下来,我们通过一个完整的实例来学习如何用 Python 实现 K 近邻回归。

第一步:准备工作

首先,我们需要导入必要的 Python 库,并加载数据:

# 导入需要的库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns# 机器学习相关库
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error, r2_score# 读取房价数据
df = pd.read_excel('https://labfile.oss.aliyuncs.com/courses/40611/%E5%8E%9F%E5%A7%8B%E6%95%B0%E6%8D%AE_%E6%88%BF%E4%BB%B7%E9%A2%84%E6%B5%8B%EF%BC%88mini%E7%89%88%E6%95%B0%E6%8D%AE%EF%BC%89.xlsx')# 查看数据前5行
print("数据前5行:")
print(df.head())

第二步:理解数据

我们的数据包含了各种可能影响房价的因素,比如面积、房龄、是否有电梯等。让我们先了解一下数据的基本情况:

# 查看数据集中有哪些特征
print("\n数据特征:")
print(df.columns.tolist())# 查看"户型"和"电梯"这两个分类特征的取值情况
print("\n户型分布:")
print(df['户型'].value_counts())
print("\n电梯分布:")
print(df['电梯'].value_counts())

第三步:数据预处理

在使用数据之前,我们需要进行一些预处理工作,包括处理缺失值和转换分类特征:

# 检查缺失值
print("\n各特征缺失值数量:")
print(df.isnull().sum())# 处理分类特征:将文字转换为数字
df['户型'] = df['户型'].map({"高端装修": 3,"简单装修": 1,"精装修": 2
})df['电梯'] = df['电梯'].map({"无": 0,"有": 1
})# 去除含有缺失值的行
df = df.dropna()# 为了方便理解,将部分中文列名改为英文
column_mapping = {'户型': 'Type','电梯': 'Elevator','面积': 'Area','房龄': 'Age','装修程度': 'Decoration','容积率': 'Plot_Ratio','绿化率': 'Greening_Rate','房价': 'House_Price'
}
df.rename(columns=column_mapping, inplace=True)# 查看处理后的数据
print("\n处理后的数据前5行:")
print(df.head())

第四步:准备训练数据和测试数据

我们需要将数据分为特征(影响房价的因素)和目标(房价),并进一步分为训练集(用于训练模型)和测试集(用于评估模型):

# 选择特征和目标变量
# 这里我们选择了部分可能影响房价的关键特征
X = df[['Type', 'Elevator', 'Area', 'Age', 'Decoration', 'Plot_Ratio', 'Greening_Rate']]
y = df['House_Price']# 数据标准化:将不同量级的特征转换到同一量级
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 划分训练集和测试集(80%用于训练,20%用于测试)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)print(f"\n训练集样本数:{X_train.shape[0]}")
print(f"测试集样本数:{X_test.shape[0]}")

为什么需要标准化?因为 K 近邻算法是基于距离计算的,如果某个特征的数值范围很大(比如面积,单位是平方米),而另一个特征的数值范围很小(比如是否有电梯,0 或 1),那么数值范围大的特征会对距离计算产生更大的影响,这可能不公平。标准化可以解决这个问题。

第五步:创建并训练 K 近邻回归模型

现在我们可以创建 K 近邻回归模型并进行训练了:

# 创建K近邻回归器,这里选择K=5
k = 5
regressor = KNeighborsRegressor(n_neighbors=k)# 训练模型
regressor.fit(X_train, y_train)print(f"\nK={k}的K近邻回归模型训练完成!")

K 值的选择很重要:

  • K 值太小:模型容易受噪声影响,泛化能力差
  • K 值太大:可能会包含太多不相似的样本,预测变得模糊

在实际应用中,我们通常会尝试不同的 K 值,选择效果最好的那个。

第六步:评估模型

训练完成后,我们需要评估模型的预测效果:

# 用测试集进行预测
y_pred = regressor.predict(X_test)# 计算评估指标
mse = mean_squared_error(y_test, y_pred)  # 均方误差,值越小越好
r2 = r2_score(y_test, y_pred)  # R²得分,越接近1越好print("\n模型评估结果:")
print(f"均方误差 (MSE): {mse:.2f}")
print(f"R²得分: {r2:.4f}")# 可视化预测结果与实际结果的对比
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.6)
plt.plot([y.min(), y.max()], [y.min(), y.max()], 'r--')  # 理想情况下的预测线
plt.xlabel('实际房价')
plt.ylabel('预测房价')
plt.title('实际房价 vs 预测房价')
plt.show()

第七步:分析特征重要性

K 近邻回归本身不直接提供特征重要性,但我们可以通过分析特征与目标变量的相关性来了解哪些因素对房价影响更大:

# 计算特征与房价的相关性
correlation = df[['Type', 'Elevator', 'Area', 'Age', 'Decoration', 'Plot_Ratio', 'Greening_Rate', 'House_Price']].corr()# 用热力图可视化相关性
plt.figure(figsize=(10, 8))
sns.heatmap(correlation, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
plt.title('特征与房价的相关性热力图')
plt.show()# 显示各特征与房价的相关性
print("\n各特征与房价的相关性:")
print(correlation['House_Price'].sort_values(ascending=False))

相关性的值范围在 - 1 到 1 之间:

  • 接近 1:正相关(特征值越大,房价越高)
  • 接近 - 1:负相关(特征值越大,房价越低)
  • 接近 0:几乎没有线性关系

如何选择合适的 K 值?

K 值的选择对模型性能有很大影响,我们可以通过尝试不同的 K 值并比较模型性能来找到最佳值:

# 尝试不同的K值,找到最佳的K
mse_scores = []
r2_scores = []
k_values = range(1, 31)for k in k_values:regressor = KNeighborsRegressor(n_neighbors=k)regressor.fit(X_train, y_train)y_pred = regressor.predict(X_test)mse_scores.append(mean_squared_error(y_test, y_pred))r2_scores.append(r2_score(y_test, y_pred))# 可视化不同K值对应的模型性能
plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)
plt.plot(k_values, mse_scores, 'bo-')
plt.xlabel('K值')
plt.ylabel('均方误差 (MSE)')
plt.title('不同K值对应的MSE')
plt.grid(True)plt.subplot(1, 2, 2)
plt.plot(k_values, r2_scores, 'ro-')
plt.xlabel('K值')
plt.ylabel('R²得分')
plt.title('不同K值对应的R²')
plt.grid(True)plt.tight_layout()
plt.show()# 找到最佳K值
best_k = k_values[np.argmin(mse_scores)]
print(f"\n最佳K值为:{best_k}")
print(f"在K={best_k}时,MSE最小:{min(mse_scores):.2f}")

总结

通过本文的学习,你应该已经了解了:

  1. K 近邻回归的基本原理:找到相似的样本,用它们的平均值进行预测
  2. 实现 K 近邻回归的完整流程:数据准备→预处理→模型训练→评估
  3. 如何选择合适的 K 值:通过尝试不同 K 值并比较性能
  4. 如何分析特征对预测结果的影响:通过相关性分析

K 近邻回归的优点是简单易懂、实现容易,不需要假设数据分布;缺点是对于大数据集计算速度较慢,对高维数据效果可能不佳。

希望这个入门教程能帮助你理解 K 近邻回归算法,并能应用它解决实际问题!


文章转载自:

http://hYhgyD2B.cgmzt.cn
http://RXaoGIQF.cgmzt.cn
http://dnxxNRpU.cgmzt.cn
http://w4aK7RCZ.cgmzt.cn
http://NtsD7AHA.cgmzt.cn
http://NkHL9c7L.cgmzt.cn
http://LnvH8kew.cgmzt.cn
http://Nysa0tVu.cgmzt.cn
http://xbxgfGCm.cgmzt.cn
http://JyhnB2wb.cgmzt.cn
http://1eDwPsVH.cgmzt.cn
http://j4afx464.cgmzt.cn
http://27qfCLSa.cgmzt.cn
http://Qr2Oe6n7.cgmzt.cn
http://kJmXFOmY.cgmzt.cn
http://PrNCsPq6.cgmzt.cn
http://BqVYosaB.cgmzt.cn
http://vtnwnWcV.cgmzt.cn
http://CaWWp9th.cgmzt.cn
http://yaITwJ8p.cgmzt.cn
http://Jbcc2i1w.cgmzt.cn
http://n8h0kAY2.cgmzt.cn
http://akFLtkLQ.cgmzt.cn
http://RoVqg4wI.cgmzt.cn
http://NTAzIghC.cgmzt.cn
http://bJXZaLpg.cgmzt.cn
http://nFcVfa0u.cgmzt.cn
http://UEuQySxC.cgmzt.cn
http://b5klP3ZB.cgmzt.cn
http://fyyfPkhV.cgmzt.cn
http://www.dtcms.com/a/367093.html

相关文章:

  • JavaEE 进阶第二期:开启前端入门之旅(二)
  • 准确率可达99%!注意力机制+UNet,A会轻松收割!
  • SpringBoot 项目一些语法记录
  • 单通道ADC采集实验(单次非扫描软件触发)
  • 同步安卓手机的照片到NAS的方案(完美)
  • 嵌入式设备的外设驱动优化
  • 51单片机---硬件学习(跑马灯、数码管、外部中断、按键、蜂鸣器)
  • 嵌入式 - 硬件:51单片机(3)uart串口
  • 深度剖析:智能驾驶到底给2025带来了什么
  • MTK Linux DRM分析(三十六)- MTK mtk_cec.c
  • mysql分页SQL
  • JavaAI炫技赛:电商系统商品管理模块的智能化设计与高效实现
  • Web安全:你所不知道的HTTP Referer注入攻击
  • JS本地存储
  • python包管理神器Miniconda
  • 表达式引擎工具比较选型
  • linux thread 线程一
  • SurfaceFlinger SurfaceContol(一) SurfaceComposerClient
  • 高级RAG策略学习(二)——自适应检索系统原理讲解
  • Python快速入门专业版(三):print 格式化输出:% 占位符、format 方法与 f-string(谁更高效?)
  • 2025打磨机器人品牌及自动化打磨抛光设备技术新版分析
  • 只会git push?——git团队协作进阶
  • Ubuntu系统配置镜像源
  • RTSP H.265 与 RTMP H.265 的差异解析:标准、扩展与增强实现
  • Vue基础知识-脚手架开发-子传父(props回调函数实现和自定义事件实现)
  • 九、数据库技术基础
  • Roo Code之自定义指令(Custom Instructions),规则(Rules)
  • 掌握DNS解析:从基础到BIND部署全解析
  • git push -u origin main 这个-u起什么作用
  • 微信小程序日历事件添加实现