当前位置: 首页 > news >正文

机器学习中的过拟合及示例

文章目录

  • 机器学习中的过拟合及示例
    • 1. 过拟合的定义
    • 2. 过拟合的常见例子
      • 例1:图像分类中的过拟合
      • 例2:回归任务中的过拟合
      • 例3:自然语言处理(NLP)中的过拟合
    • 3. Python代码示例:过拟合的直观演示
      • 示例1:多项式回归中的过拟合
      • 示例2:神经网络在分类任务中的过拟合
    • 4. 过拟合的解决方案(简要)
    • 5. 总结
    • 附记

机器学习中的过拟合及示例

1. 过拟合的定义

过拟合(Overfitting) 是机器学习模型对训练数据拟合得过于紧密,因而在训练数据上表现优异,但在新数据(测试集或真实场景)上性能显著下降的现象。其本质是模型过度学习了训练数据中的噪声、异常值、随机波动或局部特征,而非数据背后的真实规律,导致泛化能力差。

核心特征

  • 训练误差低,验证/测试误差高。
  • 模型复杂度过高(例如神经网络层数过多、决策树分支过细)。

值得注意的是,虽然我们有一些方法可以缓解过拟合问题,但很难将其彻底消除,为此,我们的目标就是尽量将过拟合最小化。减少过拟合现象的最有效方法是采集更多高质量的有标签数据,但如果没有更多有标签数据,也可以通过增强现有数据或利用无标签数据进行与训练等方法来缓解过拟合问题。

2. 过拟合的常见例子

例1:图像分类中的过拟合

  • 场景:训练一个猫狗分类模型时,模型可能记住训练图片中的背景(如草地、沙发)而非动物本身特征。
  • 表现:训练准确率99%,测试准确率仅60%。

例2:回归任务中的过拟合

  • 场景:用高阶多项式拟合简单的线性数据,模型强行穿过所有噪声点。
  • 表现:训练集均方误差(MSE)趋近于0,但测试集MSE极高。

例3:自然语言处理(NLP)中的过拟合

  • 场景:文本情感分类模型过度依赖某些特定词汇(如“awesome”代表正面),而忽略上下文语义。
  • 表现:对包含训练集中高频词的新句子分类准确,但对新词或复杂句式表现差。

3. Python代码示例:过拟合的直观演示

示例1:多项式回归中的过拟合

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error# 生成合成数据(真实关系为 y = 0.5x + 噪声)
np.random.seed(42)
X = np.linspace(0, 10, 20)
y = 0.5 * X + np.random.normal(0, 1, X.shape[0])# 将数据分为训练集和测试集
X_train, X_test = X[:15], X[15:]
y_train, y_test = y[:15], y[15:]# 使用高阶多项式拟合(复杂度远超真实关系)
poly = PolynomialFeatures(degree=15)  # 故意使用过高的阶数
X_poly_train = poly.fit_transform(X_train.reshape(-1, 1))
X_poly_test = poly.transform(X_test.reshape(-1, 1))model = LinearRegression()
model.fit(X_poly_train, y_train)# 计算训练集和测试集误差
train_pred = model.predict(X_poly_train)
test_pred = model.predict(X_poly_test)
train_mse = mean_squared_error(y_train, train_pred)
test_mse = mean_squared_error(y_test, test_pred)print(f"训练集MSE: {train_mse:.4f}")  # 输出:训练集MSE: 0.1583
print(f"测试集MSE: {test_mse:.4f}")    # 输出:测试集MSE: 21.3476# 可视化拟合曲线
plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
plt.scatter(X, y, label="真实数据")
X_plot = np.linspace(0, 10, 100)
y_plot = model.predict(poly.transform(X_plot.reshape(-1, 1)))
plt.plot(X_plot, y_plot, 'r', label="过拟合模型")
plt.legend()
plt.show()

运行效果
多项式回归过拟合
输出结果

  • 训练集MSE接近0,测试集MSE极大。
  • 拟合曲线剧烈震荡,强行穿过所有训练点(如上图)。

示例2:神经网络在分类任务中的过拟合

from keras import layers, models
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split# 生成非线性可分数据(月亮数据集)
X, y = make_moons(n_samples=200, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)# 构建过于复杂的神经网络
model = models.Sequential([layers.Input(shape=(2,)),layers.Dense(128, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# 训练模型(故意不设早停和正则化)
history = model.fit(X_train, y_train, epochs=500, validation_data=(X_test, y_test), verbose=0)# 绘制训练与验证准确率曲线
plt.rcParams['font.sans-serif']=['Microsoft YaHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.legend()
plt.show()

神经网络过拟合
输出结果

  • 训练准确率接近100%,验证准确率停滞在约90%-95%。
  • 验证准确率曲线后期下降(如上图)。

4. 过拟合的解决方案(简要)

  1. 降低模型复杂度:减少神经网络层数/神经元数、限制决策树深度。
  2. 正则化:L1/L2正则化(添加权重惩罚项)、Dropout(随机禁用神经元)。
  3. 数据增强:增加训练数据多样性。
  4. 早停(Early Stopping):在验证误差开始上升时终止训练。
  5. 交叉验证:合理划分数据,避免评估偏差。

5. 总结

过拟合的本质是模型对训练数据的“死记硬背”,通过代码示例可以直观看到:

  • 高阶多项式回归在训练集完美拟合,但测试集表现灾难性下降。
  • 复杂神经网络的训练准确率与验证准确率差距显著。
    实际应用中需通过监控训练/验证误差、调整模型复杂度及数据策略来缓解过拟合。

附记

本文中Python程序的执行环境为Python 3.12.9,依赖的包如下:

  • “keras>=3.9.2”
  • “matplotlib>=3.10.3”
  • “scikit-learn>=1.6.1”

相关文章:

  • 计算机图形学编程(使用OpenGL和C++)(第2版)学习笔记 12.曲面细分
  • MySQL初阶:sql事务和索引
  • 深入解析Spring Boot与Redis集成:高效缓存实践
  • 如何彻底清空docker里面不使用的容器?
  • 面向对象详解和JVM底层内存分析
  • Windows运维工具批处理版
  • 使用Python和`python-docx`库复制Word文档样式
  • mysql中4种扫描方式和聚簇索引非聚簇索引【爽文一篇】
  • BG开发者日志517:demo数据分析与修改方向
  • muduo库TcpConnection模块详解——C++
  • MySQL只操作同一条记录也会死锁吗?
  • Linux面试题集合(5)
  • 通俗版解释CPU、核心、进程、线程、协程的定义及关系
  • AGI大模型(21):混合检索之混合搜索
  • CSS 浮动与定位以及定位中z-index的堆叠问题
  • 管理前端项目依赖版本冲突导致启动失败的问题的解决办法
  • 深度学习---知识蒸馏(Knowledge Distillation, KD)
  • 代码随想录算法训练营第60期第三十九天打卡
  • C# 深入理解类(静态函数成员)
  • UDP三种通信方式
  • 陈刚:推动良好政治生态和美好自然生态共生共优相得益彰
  • 梅花奖在上海|湘剧《夫人如见》竞梅,长沙文旅来沪推广
  • 上海一保租房社区亮相,首批546套房源可拎包入住
  • 韩正会见美国景顺集团董事会主席瓦格纳
  • 师爷、文士、畸人:会稽范啸风及其著述
  • 音乐节困于流量