当前位置: 首页 > news >正文

【人工智能99问】BERT的训练过程和推理过程是怎么样的?(24/99)

文章目录

  • BERT的训练过程与推理过程
    • 一、预训练过程:学习通用语言表示
      • 1. 数据准备
      • 2. MLM任务训练(核心)
      • 3. NSP任务训练
      • 4. 预训练优化
    • 二、微调过程:适配下游任务
      • 1. 任务定义与数据
      • 2. 输入处理
      • 3. 模型结构调整
      • 4. 微调训练
    • 三、推理过程:对新数据预测
      • 1. 推理输入预处理
      • 2. 模型前向传播
      • 3. 输出预测结果
    • 四、训练与推理的核心区别
    • 五、总结

BERT的训练过程与推理过程

BERT的核心流程分为预训练(无监督学习通用语言知识)和微调(监督学习适配下游任务),而推理则是基于微调后的模型对新数据进行预测。以下结合具体实例详细说明。

一、预训练过程:学习通用语言表示

预训练是BERT的“筑基阶段”,通过无标注文本学习语言规律,核心任务为Masked Language Model(MLM)Next Sentence Prediction(NSP)

1. 数据准备

  • 语料来源:大规模无标注文本(如BooksCorpus、Wikipedia),示例文本片段:
    “自然语言处理是人工智能的重要分支。它研究计算机与人类语言的交互。”
  • 预处理步骤
    1. 分词:用WordPiece分词工具拆分文本为子词(Subword)。例如“人工智能”→“人工”+“智能”;“交互”→“交”+“互”。
    2. 构造输入序列:对单句或句子对添加特殊符号,最长长度限制为512 Token。
      • 单句格式:[CLS] 自然 语言 处理 是 人工 智能 的 重要 分支 。 [SEP]
      • 句子对格式:[CLS] 自然 语言 处理 是 人工智能 的 重要 分支 。 [SEP] 它 研究 计算机 与 人类 语言 的 交互 。 [SEP]

2. MLM任务训练(核心)

目标:随机掩盖部分Token,让模型预测被掩盖的Token,强制学习双向上下文。

  • 具体操作
    从输入序列中随机选择15%的Token进行处理:
    • 80%概率替换为[MASK]
      原序列:[CLS] 自然 语言 处理 是 人工 智能 的 重要 分支 。 [SEP]
      处理后:[CLS] 自然 [MASK] 处理 是 人工 智能 的 [MASK] 分支 。 [SEP]
      (需预测被掩盖的“语言”和“重要”)
    • 10%概率替换为随机Token:
      处理后:[CLS] 自然 图像 处理 是 人工 智能 的 重要 分支 。 [SEP]
      (“语言”被随机替换为“图像”,模型需识别错误并预测正确Token)
    • 10%概率不替换(保持原Token):
      处理后:[CLS] 自然 语言 处理 是 人工 智能 的 重要 分支 。 [SEP]
      (模型需“假装”不知道哪些Token被掩盖,避免依赖[MASK]符号)
  • 损失计算:对被掩盖的Token,通过输出层的Softmax计算预测概率,用交叉熵损失优化模型,让预测值接近真实Token。

3. NSP任务训练

目标:让模型学习句子间的逻辑关系,判断句子对是否连续。

  • 具体操作
    构造句子对(A, B),50%为真实连续句(正例),50%为随机拼接句(负例):
    • 正例:A=“自然语言处理是人工智能的重要分支。”,B=“它研究计算机与人类语言的交互。”(B是A的下一句)。
    • 负例:A=“自然语言处理是人工智能的重要分支。”,B=“猫是一种常见的家养宠物。”(B与A无关)。
      输入序列格式:[CLS] A [SEP] B [SEP],并添加Segment Embedding(A部分为0,B部分为1)。
  • 损失计算:通过[CLS]Token的输出向量接二分类层,用交叉熵损失优化模型,区分正例(标签1)和负例(标签0)。

4. 预训练优化

  • 训练配置:BERT-Base使用12层Transformer,批量大小256,训练步数100万,优化器为Adam(学习率5e-5,β1=0.9,β2=0.999)。
  • 正则化:每层加入Dropout(概率0.1),防止过拟合;使用层归一化稳定训练。
  • 硬件支持:需大规模GPU集群(如16块TPU,训练约4天),通过梯度累积处理大批次数据。

二、微调过程:适配下游任务

预训练完成后,模型已具备通用语言理解能力,微调阶段需针对具体任务(如情感分析、问答)调整输出层并优化参数。以下以情感分析任务(判断文本为“正面”或“负面”)为例说明。

1. 任务定义与数据

  • 任务目标:输入文本(如“这部电影剧情紧凑,演员演技出色!”),输出情感标签(正面/负面)。
  • 微调数据:带标签的情感语料,示例:
    文本标签
    “这部电影剧情紧凑,演员演技出色!”正面
    “画面模糊,音效刺耳,不推荐观看。”负面

2. 输入处理

  • 单句输入格式:[CLS] 这 部 电影 剧情 紧凑 , 演员 演技 出色 ! [SEP]
  • 嵌入向量构造:同预训练,包含Token Embedding(子词向量)、Segment Embedding(全为0,因单句)、Position Embedding(位置编码)。

3. 模型结构调整

  • 复用预训练的Transformer编码器,仅在输出层添加分类头
    [CLS]Token经过12层Transformer后的输出向量(768维),接入全连接层+Softmax,输出“正面”“负面”的概率。
    即:P(正面/负面)=Softmax(W⋅H[CLS]+b)P(正面/负面) = Softmax(W \cdot H_{[CLS]} + b)P(正面/负面)=Softmax(WH[CLS]+b),其中H[CLS]H_{[CLS]}H[CLS][CLS]的最终向量,WWWbbb是微调阶段待学习的参数。

4. 微调训练

  • 参数设置:冻结部分底层Transformer参数(或微调所有参数),学习率设为2e-5(低于预训练,避免破坏预训练知识),批量大小32,训练轮次3-5轮。
  • 损失计算:对每个样本,用交叉熵损失计算预测概率与真实标签的差异,通过反向传播更新分类头和少量Transformer参数。
  • 早停策略:用验证集监控性能,若连续多轮无提升则停止训练,防止过拟合。

三、推理过程:对新数据预测

微调完成后,模型可对未标注的新文本进行情感预测,推理过程即模型的前向传播计算。

1. 推理输入预处理

  • 输入新文本:“这部电影的特效震撼,剧情感人至深。”
  • 分词与格式转换:
    分词后:[CLS] 这 部 电影 的 特效 震撼 , 剧情 感人 至 深 。 [SEP]
    转换为嵌入向量:Token Embedding + Segment Embedding(全0) + Position Embedding。

2. 模型前向传播

  • 嵌入向量输入12层Transformer编码器:
    每层通过多头自注意力捕捉Token间关联(如“特效”与“震撼”、“剧情”与“感人”的语义强化),前馈网络进行非线性变换,最终输出每个Token的上下文向量。
  • 提取[CLS]向量:H_{[CLS]}(768维)聚合了整个句子的情感语义。

3. 输出预测结果

  • H_{[CLS]}输入分类头,计算概率:
    P(正面)=0.92P(正面) = 0.92P(正面)=0.92P(负面)=0.08P(负面) = 0.08P(负面)=0.08
  • 决策:取概率最高的标签,输出“正面”。

四、训练与推理的核心区别

维度预训练过程微调过程推理过程
数据类型无标注文本(大规模)有标注任务数据(小规模)无标注新文本(单条/批量)
目标学习通用语言规律(MLM+NSP)适配具体任务(如情感分类)输出新数据的预测结果
参数更新全量参数优化(数百万步)少量参数微调(数轮)无参数更新(仅前向传播)
计算复杂度极高(需大规模算力)中等(单GPU可完成)低(实时响应)

五、总结

BERT的训练过程通过“预训练筑基+微调适配”实现知识迁移:预训练用无监督任务从海量文本中学习语言本质,微调则用少量标注数据将通用知识转化为任务能力;而推理则是微调后模型对新数据的高效预测。这种模式大幅降低了NLP任务的落地门槛,成为现代自然语言处理的核心范式。

http://www.dtcms.com/a/323950.html

相关文章:

  • 部署一个自己的音乐播放器教程
  • Windows安装MySql8.0
  • MariaDB 数据库管理与web服务器
  • 双非二本如何找工作?
  • NVIDIA-SMI has failed because it couldn’t communicate with the NVIDIA driver.
  • 软件编程1-shell命令
  • RabbitMQ面试精讲 Day 18:内存与磁盘优化配置
  • 深度学习-卷积神经网络CNN-AlexNet
  • LeetCode_哈希表
  • 智能体革命:网络安全人的角色重塑与突围指南
  • GPU指令集入门教程
  • 安全运维工具链全解析
  • 代码可读性与维护性的实践与原则
  • H3C(基于Comware操作系统)与eNSP平台(模拟华为VRP操作系统)的命令差异
  • Vulhub靶场组件漏洞(XStream,fastjson,Jackson)
  • 【Vue✨】Vue3 中英文切换功能实现
  • kubernetes安装搭建
  • nginx+Lua环境集成、nginx+Lua应用
  • 【东枫科技】NTN-IOT 卫星互联网原型系统,高达1.6G大带宽
  • LeetCode简单题 - 学习
  • java生成用户登录token
  • Android Camera 打开和拍照APK源码
  • Redis实现消息队列三种方式
  • 前端学习日记 - 前端函数防抖详解
  • c#属性(Property)的概念定义及使用详解
  • 音视频学习(五十二):ADTS
  • i2c dump工具使用(202589)
  • WAV音频数据集MFCC特征提取处理办法
  • 人工智能正在学习自我提升的方式
  • Agent在游戏行业的应用:NPC智能化与游戏体验提升