当前位置: 首页 > news >正文

【机器学习09】调试策略、错误分析、数据增强、迁移学习

文章目录

  • 一 策略性调试机器学习模型
  • 二 偏差与方差的权衡
  • 三 神经网络中的偏差与方差
  • 四 机器学习开发的迭代循环
  • 五 错误分析:深入理解模型缺陷
    • 5.1 垃圾邮件分类实例
    • 5.2 确定优化方向
    • 5.3 错误分析流程
  • 六 以数据为中心的AI视角
    • 6.1 策略性地增加数据
    • 6.2 数据增强技术
    • 6.3 人工数据合成
    • 6.4 范式转移:从模型为中心到以数据为中心
  • 七 迁移学习:站在巨人的肩膀上


视频链接
吴恩达机器学习p76-83


一 策略性调试机器学习模型

在上一篇文章中,我们探讨了如何通过训练误差和交叉验证误差来诊断模型存在的偏差或方差问题。诊断是第一步,接下来我们需要根据诊断结果,采取相应的策略来优化模型。

[在此处插入图片1]

现在,我们可以为文章开头提出的诸多优化选项,给出更明确的“药方”:

  • 修复高方差(过拟合)的策略

    • 获取更多的训练样本:让模型见到更多样的数据,学习到更通用的规律,降低对训练集噪声的敏感度。
    • 尝试使用更少的特征:减少模型复杂度,直接限制其过拟合的能力。
    • 尝试增大正则化参数 λ:加大对模型参数的惩罚,使模型趋于简化,从而抑制过拟合。
  • 修复高偏差(欠拟合)的策略

    • 尝试获取更多的特征:为模型提供更多信息,使其有能力学习更复杂的规律。
    • 尝试添加多项式特征:通过增加特征的复杂度(如从线性变为多项式)来提升模型的拟合能力。
    • 尝试减小正则化参数 λ:减小对模型参数的惩罚,释放模型的复杂度,使其能更好地拟合数据。

二 偏差与方差的权衡

机器学习的核心挑战之一,就是在偏差和方差之间找到一个最佳的平衡点,这被称为“偏差-方差权衡”(The bias variance tradeoff)。

[在此处插入图片2]

  • 简单模型(如低阶多项式):模型本身表达能力有限,容易产生高偏差(欠拟合)。
  • 复杂模型(如高阶多项式):模型表达能力过强,容易过度学习训练数据中的细节和噪声,从而产生高方差(过拟合)。

如图中的曲线所示,随着模型复杂度(多项式次数 d)的增加,训练误差 J_train 会持续下降,但交叉验证误差 J_cv 会先下降后上升。我们的目标就是找到那个使 J_cv 达到最低点的“最佳拟合”区域,这正是偏差与方差权衡的最佳点。

三 神经网络中的偏差与方差

对于神经网络这种强大的模型,我们同样可以应用偏差-方差分析来进行系统性的调试。

[在此处插入图片3]

现代深度学习的一个重要观点是:一个足够大的神经网络,其本身是一个“低偏差机器”。这意味着,只要网络规模够大,理论上它总能很好地拟合训练数据。基于此,我们可以建立一个清晰的调试流程:

  1. 第一步:检查训练集表现。模型在训练集上表现好吗?即 J_train 是否足够低?
    • 如果“否”J_train 很高),说明模型存在高偏差问题。主要解决方案是使用一个更大的网络(增加层数或每层的神经元数量)或者尝试更优的网络架构。这通常能有效降低偏差。当然,更大的网络也意味着需要更强的计算能力(如GPU)。
  2. 第二步:检查交叉验证集表现。在训练集表现良好的前提下,模型在交叉验证集上表现好吗?即 J_cv 是否足够低?
    • 如果“否”J_cv 远高于 J_train),说明模型存在高方差问题(过拟合)。此时,最有效的解决方案通常是获取更多的数据
  3. 如果两步都回答“是”,那么恭喜你,模型的偏差和方差都得到了很好的控制,调试工作完成!

这个流程构成了一个简单而强大的迭代循环,为神经网络的优化提供了清晰的路径。

[在此处插入图片4]

一个常见的误解是担心网络太大导致过拟合。然而,现代深度学习的实践表明:只要正则化选择得当,一个更大的神经网络通常会比一个小的网络表现得一样好,甚至更好。因此,与其小心翼翼地调整网络大小,一个更有效的方法是,从一个足够大的、可能过拟合的网络开始,然后使用正则化来控制方差。

[在此处插入图片5]

神经网络的正则化与线性回归类似,其代价函数在原始损失的基础上,增加了一个对网络中所有权重 W 的惩罚项。在TensorFlow/Keras等框架中实现这一点非常直接,只需在定义网络层时,通过 kernel_regularizer 参数添加L2正则化项即可。如上图所示,通过为每一层增加正则化,我们就可以有效地控制模型的复杂度,防止过拟合。

四 机器学习开发的迭代循环

综上所述,机器学习的开发并非一蹴而就,而是一个持续迭代的闭环过程。

[在此处插入图片6]

这个循环主要包括三个阶段:

  1. 选择架构:基于问题和现有知识,提出一个模型、数据和训练策略的初始方案。
  2. 训练模型:执行训练过程,得到模型的参数。
  3. 诊断:通过偏差/方差分析和后续将介绍的错误分析,来诊断模型的不足之处。

诊断的结果会反过来指导我们如何调整第一步中的架构,例如是该用更大的模型,还是该收集更多的数据,从而开启新一轮的迭代,直至模型性能满足要求。

五 错误分析:深入理解模型缺陷

当模型的整体性能指标(如准确率)无法给我们提供更具体的优化方向时,我们就需要进行“错误分析”(Error analysis),通过深入研究模型犯错的样本来寻找改进的线索。

5.1 垃圾邮件分类实例

[在此处插入图片7]

我们以一个垃圾邮件(Spam)分类器为例。分类器的任务是区分一封邮件是垃圾邮件(如左侧的推销邮件)还是正常邮件(如右侧的私人邮件)。

[在此处插入图片8]

构建这样一个分类器,一个常见的监督学习方法是:

  • 定义特征 x:从邮件中提取特征。一个简单有效的方法是,首先确定一个包含N个高频词的词汇表(例如N=10,000),然后将每封邮件表示成一个N维的向量,向量的每个元素对应一个词汇表中的词,其值为1或0,表示该词是否在邮件中出现。
  • 定义标签 y:1代表垃圾邮件,0代表正常邮件。

5.2 确定优化方向

[在此处插入图片9]

假设我们的分类器已经构建完成,但错误率仍然较高。我们有很多潜在的改进方向:

  • 收集更多的数据,例如通过“蜜罐”项目诱捕垃圾邮件。
  • 开发更复杂的特征,例如基于邮件头的路由信息。
  • 定义更精细的文本特征,例如将“discounting”和“discount”视为同一个词(词干提取)。
  • 设计专门的算法来检测故意拼写错误的单词,如 “w4tches”, “med1cine”。

面对这么多选项,我们应该优先投入时间在哪一项上呢?错误分析将为我们提供答案。

5.3 错误分析流程

[在此处插入图片10]

错误分析的具体步骤如下:

  1. 从交叉验证集中,收集一批被模型错误分类的样本。例如,假设有500个验证样本,其中100个被分错了。
  2. 手动检查这100个错误样本,并根据错误的共同特征进行分类和统计
  3. 分析统计结果,找出导致模型犯错的主要原因。

例如,在检查了100个错误样本后,我们发现:

  • 与“医药/药品”(Pharma)相关的垃圾邮件有21例。
  • 包含“故意拼写错误”的有3例。
  • “不寻常的邮件路由”有7例。
  • “网络钓鱼/窃取密码”的有18例。

这个结果清晰地告诉我们,“医药”和“钓鱼”是导致模型出错的最主要类别。相比之下,投入大量精力去开发一个能识别各种拼写错误的复杂算法可能收效甚微,因为它最多只能解决3%的错误。因此,我们应该优先为“医药”和“钓鱼”邮件设计更具针对性的特征。

六 以数据为中心的AI视角

错误分析的结果往往指向同一个方向:我们需要更好、更有针对性的数据。

6.1 策略性地增加数据

[在此处插入图片11]

错误分析指导我们如何更高效地增加数据。

  • 增加错误类型的数据:与其盲目地收集所有类型的数据,不如重点去寻找并标注那些模型最容易犯错的类型。例如,专门去收集更多与“医药”相关的垃圾邮件样本。
  • 数据增强 (Data Augmentation):这是一种在不收集新数据的情况下,创造更多训练样本的强大技术。

6.2 数据增强技术

[在此处插入图片12]

数据增强的核心思想是:通过对一个已有的训练样本进行各种合理的变换,来生成一个新的、与原始样本标签相同的训练样本。

[在此处插入图片13]

在计算机视觉领域,数据增强的应用非常广泛。例如,对于一张手写字母“A”的图片,我们可以通过旋转、缩放、裁剪、扭曲、改变亮度对比度等方式,生成大量新的“A”的图片。

[在此- 处插入图片14]

在语音识别领域,我们可以对一段原始音频进行增强,例如,将原始语音与各种背景噪声(如人群嘈杂声、汽车行驶声)混合,或者模拟在信号不好的手机上通话的效果,从而生成新的训练音频。

[在此处插入图片15]

使用数据增强的一个关键原则是:所引入的变换或失真,应该能够代表真实世界中(尤其是测试集中)可能出现的变化。例如,为语音识别数据增加背景噪声是合理的,因为真实场景中总有噪声。但如果只是为图片添加纯粹的、无意义的随机像素噪声,通常对提升模型性能帮助不大。

6.3 人工数据合成

[在此处插入图片16]

数据增强的一个延伸是“人工数据合成”(Artificial data synthesis),即从零开始创造全新的训练数据。一个经典的应用场景是照片中的光学字符识别(Photo OCR)。

[在此处插入图片17]

在真实世界的照片中寻找并标注文字是一项非常耗时耗力的工作。而通过数据合成,我们可以利用计算机程序,将不同字体、大小、颜色的字符,渲染到各种各样的背景图片上,并自动生成精确的标注。这样,我们就能以极低的成本,获得几乎无限的训练数据。

6.4 范式转移:从模型为中心到以数据为中心

[在此处插入图片18]

传统上,机器学习的发展更多遵循“以模型为中心”(model-centric)的方法,即保持数据集固定,不断地迭代和优化模型代码(算法)。然而,吴恩达老师强调,在许多现代AI应用中,“以数据为中心”(data-centric)的AI开发范式正变得越来越重要。

AI = Code (模型) + Data (数据)

在以数据为中心的范式中,我们可能将模型架构固定下来,转而将主要精力投入到系统性地提升数据质量上,包括改善数据标注的一致性、通过错误分析和数据增强来扩充关键场景的数据。在许多实际问题中,提升数据质量所带来的性能增益,往往比无休止地调优模型要大得多。

七 迁移学习:站在巨人的肩膀上

当我们面临自己的训练数据量不足的问题时,“迁移学习”(Transfer learning)提供了一个非常强大的解决方案。

[在此处插入图片19]

迁移学习的核心思想是,将一个在大型数据集上预训练好的模型的“知识”迁移到我们自己的、数据量较小的任务上。

例如,我们要训练一个手写数字(0-9)分类器,但只有几千张图片。我们可以这样做:

  1. 找到一个已经在海量图像数据集(如ImageNet,包含上百万张图片和1000个类别)上训练好的、性能强大的神经网络。
  2. 保留这个预训练模型的大部分结构和参数(尤其是前面的层),只替换掉其最后的输出层,使其适应我们的新任务(例如,换成一个10个输出单元的Softmax层)。
  3. 然后,用我们自己的手写数字数据对这个新网络进行“微调”(fine tuning)。

微调有两种主要策略:

  • 选项1:只训练输出层。冻结预训练模型所有原有部分的参数,只更新我们新添加的输出层的参数。这种方法计算开销小,适用于我们的数据集非常小的情况。
  • 选项2:训练所有参数。将预训练模型的参数作为初始值,然后用我们的数据继续训练整个网络的所有参数(通常使用一个较小的学习率)。这种方法可以让模型更好地适应新数据,通常效果更好,但需要的数据量也相对更多。

[在此处插入图片20]

迁移学习为什么有效?因为在大型图像数据集上训练的神经网络,其浅层网络已经学会了如何识别非常基础和通用的图像特征,比如边缘、角点、曲线和基本形状。这些底层特征对于几乎所有的视觉任务都是有用的。因此,通过迁移学习,我们相当于“借用”了这些已经学好的特征检测器,而无需在自己的小数据集上从零开始学习,从而大大提高了学习效率和模型性能。

[在此处插入图片21]

总结一下迁移学习的流程:

  1. 下载一个在与你的任务输入类型相同(如都是图像,或都是音频)的大型数据集上预训练好的神经网络模型。
  2. 用你自己的(通常规模小得多的)数据对这个网络进行进一步的训练(微调)。

通过这种方式,即使只有几十或几百个样本,我们也能构建出性能相当不错的模型,真正实现了“站在巨人的肩膀上”。

http://www.dtcms.com/a/541273.html

相关文章:

  • 网站开发相关技术天空影院手机免费观看在线
  • 南昌哪里做网站好几百块钱建网站
  • 目标使用过期的TLS1.0 版协议
  • 笔试-计算网络信号
  • 一流的网站建设做ui必要的网站
  • 摄像头拍摄照片
  • UVa 1620 Lazy Susan
  • 中国工程建筑门户网站官网房产网签
  • RabbitMQ事务机制详解
  • 网站开发人员的水平wordpress听说对百度不友好
  • 中国网站空间西安营销策划推广公司
  • 【AI工具】dify智能体-Kimi-K2+Mermaid ,一键生成系统架构图
  • 如何利用代理 IP 构建分布式爬虫系统架构?
  • 拿别的公司名字做网站凡科网站怎么修改昨天做的网站
  • Gin 框架中路由的底层实现原理
  • 公司网站开发费进什么费用利用小米路由器mini做网站
  • h5游戏免费下载:飞机炸弹?
  • 【c++ qt】QtConcurrent与QFutureWatcher:实现高效异步计算
  • puppeteer生成PDF实践
  • Windows桌面添加我的电脑
  • 响应式网站和非响应式网站的区别wordpress 兼容php7
  • 03.OpenStack界面管理
  • 深度学习与大模型完全指南:从神经网络基础到模型训练实战
  • 神经网络发展【深度学习】
  • 类似红盟的网站怎么做阿里巴巴官网登录
  • 自创字 网站php开源网站管理系统
  • Linux Shell 中静默登录另一台机器并执行SQL文件
  • Python 实战:Web 漏洞 Python POC 代码及原理详解(1)
  • 前端学习之八股和算法
  • dataonline.vn免费Web容器的使用