当前位置: 首页 > news >正文

从代码学习深度学习 - 自然语言推断:微调BERT PyTorch版

文章目录

  • 前言
  • 加载预训练的BERT
    • 构建PyTorch数据集
    • 创建数据加载器
    • 微调BERT
    • 定义分类模型
    • 训练模型
  • 辅助工具代码
  • 总结


前言

自然语言推断(NLI)是自然语言处理(NLP)领域一个核心且富有挑战性的任务。它的目标是判断两个句子——“前提(Premise)”和“假设(Hypothesis)”——之间的逻辑关系。这种关系通常分为三类:

  1. 蕴含(Entailment): 假设的意义可以从前提中推断出来。
  2. 矛盾(Contradiction): 假设的意义与前提相矛盾。
  3. 中性(Neutral): 前提和假设之间没有明确的逻辑关系。

例如:

  • 前提: 一个人在马上。
  • 假设: 一个人在动物身上。
  • 关系: 蕴含

近年来,以BERT(Bidirectional Encoder Representations from Transformers)为代表的预训练语言模型在众多NLP任务中取得了革命性的突破。其强大的上下文理解能力,使其成为解决NLI等任务的理想选择。

本篇博客将带领大家,通过PyTorch代码,一步步实现如何“微调(Fine-tuning)”一个预训练好的BERT模型,使其适应并高效地完成自然语言推断任务。我们将使用经典的SNLI(Stanford Natural Language Inference)数据集,并详细剖析从数据加载、模型构建到最终训练的全过程。
在这里插入图片描述

完整代码:[通过网盘分享的文件:自然语言推断:微调BERT.rar
链接: https://pan.baidu.com/s/1OxS-BU0MSOJXXB5wJA394w?pwd=8rc6 提取码: 8rc6
–来自百度网盘超级会员v6的分享]


加载预训练的BERT

微调的第一步,是加载一个已经在海量文本数据上(如维基百科)预训练好的BERT模型。这个预训练过程让BERT学会了通用的语言知识,我们要做的是在这个基础上,针对我们的特定任务(NLI)进行“微调”。

我们定义一个函数 load_pretrained_model 来加载模型及其对应的词汇表。词汇表(Vocabulary)是词元(token)到索引(index)的映射,是模型处理文本的基础。

import json
import os
import torch
import utils_for_vocab
import utils_for_model
import utils_for_traindef load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,num_heads, num_layers, dropout, max_len, devices):"""加载预训练的BERT模型和词汇表参数:pretrained_model (str): 预训练模型名称,用于构建数据目录路径num_hiddens (int): 隐藏层维度 [256]ffn_num_hiddens (int): 前馈网络隐藏层维度 [512]num_heads (int): 多头注意力机制的头数 [4]num_layers (int): Transformer层数 [2]dropout (float): dropout比例 [0.1]max_len (int): 最大序列长度 [512]devices (list): 可用的GPU设备列表返回:bert (BERTModel): 加载了预训练参数的BERT模型vocab (Vocab): 词汇表对象"""# 构建数据目录路径data_dir = pretrained_model + ".torch"# 定义空词表以加载预定义词表vocab = utils_for_vocab.Vocab()# 从JSON文件加载词汇表的索引到词汇的映射# vocab.idx_to_token: list,维度为 [vocab_size],存储索引到词汇的映射vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))# 构建词汇到索引的映射字典# vocab.token_to_idx: dict,存储词汇到索引的映射vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}# 创建BERT模型实例# bert: BERTModel对象,包含编码器和预训练任务头bert = utils_for_model.BERTModel(len(vocab),                    # vocab_size: 词汇表大小 [vocab_size]num_hiddens,                   # num_hiddens: 隐藏层维度 [256]norm_shape=[256],              # norm_shape: 层归一化的形状 [256]ffn_num_input=256,             # ffn_num_input: 前馈网络输入维度 [256]ffn_num_hiddens=ffn_num_hiddens,  # ffn_num_hiddens: 前馈网络隐藏层维度 [512]num_heads=4,                   # num_heads: 多头注意力头数 [4]num_layers=2,                  # num_layers: Transformer层数 [2]dropout=0.2,                   # dropout: dropout比例 [0.2]max_len=max_len,               # max_len: 最大序列长度 [512]key_size=256,                  # key_size: 注意力机制中key的维度 [256]query_size=256,                # query_size: 注意力机制中query的维度 [256]value_size=256,                # value_size: 注意力机制中value的维度 [256]hid_in_features=256,           # hid_in_features: 隐藏层输入特征维度 [256]mlm_in_features=256,           # mlm_in_features: 掩码语言模型输入特征维度 [256]nsp_in_features=256            # nsp_in_features: 下一句预测任务输入特征维度 [256])# 加载预训练的BERT模型参数# torch.load返回的是state_dict,包含模型的所有参数bert.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params')))return bert, vocab# 获取所有可用的GPU设备
# devices: list,包含可用GPU设备的列表
devices = utils_for_train.try_all_gpus()# 加载预训练的BERT模型和词汇表
# bert: BERTModel对象,已加载预训练参数
# vocab: Vocab对象,包含词汇表映射
bert, vocab = load_pretrained_model('bert.small',                # pretrained_model: 预训练模型名称num_hiddens=256,             # num_hiddens: 隐藏层维度 [256]ffn_num_hiddens=512,         # ffn_num_hiddens: 前馈网络隐藏层维度 [512]num_heads=4,                 # num_heads: 多头注意力头数 [4]num_layers=2,                # num_layers: Transformer层数 [2]dropout=0.1,                 # dropout: dropout比例 [0.1]max_len=512,                 # max_len: 最大序列长度 [512]devices=devices              # devices: 可用GPU设备列表
)

这里我们加载了一个小型的BERT模型(bert.small),它包含2个Transformer层,隐藏层维度为256。加载完成后,我们可以打印bert对象,查看其详细的模型结构。

bert
```输出的模型结构会非常详细,它清晰地展示了BERT的内部组件,包括词元嵌入(`token_embedding`)、片段嵌入(`segment_embedding`)、由多个编码器块(`EncoderBlock`)组成的编码器(`encoder`),以及用于预训练的MLM(`MaskLM`)和NSP(`NextSentencePred`)任务头。在微调阶段,我们主要关心的是`encoder`部分。## 微调BERT的数据集数据是模型训练的“养料”。对于NLI任务,我们需要将SNLI数据集处理成BERT能够理解的格式。### 数据读取与预处理首先,我们需要一个函数来读取SNLI数据集的原始文本文件。该文件是制表符分隔的,我们需要从中抽取出前提、假设和标签。```python
# 该函数位于 utils_for_data.py
def read_snli(data_dir, is_train):"""将SNLI数据集解析为前提、假设和标签"""# ... (代码见附录)

BERT处理成对的句子(如前提和假设)时,需要一种特殊的输入格式。两个句子被拼接在一起,并用特殊标记隔开:
[CLS] 前提词元 [SEP] 假设词元 [SEP]

  • [CLS]:位于序列开头,它的最终隐藏状态被用作整个序列的聚合表示,通常用于分类任务。
  • [SEP]:用于分隔两个句子。

我们还需要一个“片段索引(Segment ID)”,用来区分哪个词元属于前提(标记为0),哪个属于假设(标记为1)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
图2: BERT处理句子对(如NLI任务)的输入格式图示。

get_tokens_and_segments 函数负责实现这个格式转换。

# 该函数位于 utils_for_data.py
def get_tokens_and_segments(tokens_a, tokens_b=None):"""获取输入序列的词元及其片段索引"""# ... (代码见附录)

构建PyTorch数据集

为了与PyTorch的 DataLoader 高效配合,我们创建一个自定义的Dataset类——SNLIBERTDataset。这个类封装了所有的数据预处理逻辑:

  1. 词元化(Tokenization): 将句子切分成词元。
  2. 格式化: 调用get_tokens_and_segments构建BERT输入格式。
  3. 截断(Truncation): 由于BERT输入有最大长度限制(如128或512),需要将过长的句子对进行截断。
  4. 填充(Padding): 将所有序列填充到相同的最大长度,以便进行批量处理。
  5. 数值化: 将词元转换为词汇表中的索引。

这个类还巧妙地使用了Python的multiprocessing库来并行处理数据,极大地加速了预处理过程。

import torch
import multiprocessing
import utils_for_data
import utils_for_vocabclass SNLIBERTDataset(torch.utils.data.Dataset):"""用于BERT模型的SNLI数据集处理类该类继承自torch.utils.data.Dataset,用于处理Stanford Natural Language Inference (SNLI)数据集,将其转换为适合BERT模型训练的格式。"""def __init__(self, dataset, max_len, vocab=None):"""初始化SNLI BERT数据集"""# 对前提和假设句子进行词元化处理all_premise_hypothesis_tokens = [[\p_tokens, h_tokens] for p_tokens, h_tokens in zip(\*[utils_for_vocab.tokenize([s.lower() for s in sentences])for sentences in dataset[:2]])]# 将标签转换为张量self.labels = torch.tensor(dataset[2])self.vocab = vocabself.max_len = max_len# 预处理所有的词元对,生成模型输入格式(self.all_token_ids, self.all_segments,self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)print('read ' + str(len(self.all_token_ids)) + ' examples')def _preprocess(self, all_premise_hypothesis_tokens):"""使用多进程预处理所有的前提-假设词元对"""pool = multiprocessing.Pool(4)  # 使用4个进程out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)all_token_ids = [token_ids for token_ids, segments, valid_len in out]all_segments = [segments for token_ids, segments, valid_len in out]valid_lens = [valid_len for token_ids, segments, valid_len in out]return (torch.tensor(all_token_ids, dtype=torch.long),torch
http://www.dtcms.com/a/278338.html

相关文章:

  • Cesium 9 ,Cesium 离线地图本地实现与服务器部署( Vue + Cesium 多项目共享离线地图切片部署实践 )
  • H264的帧内编码和帧间编码
  • 2025年睿抗机器人开发者大赛CAIP-编程技能赛本科组(省赛)解题报告 | 珂学家
  • Python 变量与简单输入输出:从零开始写你的第一个交互程序
  • 【Java入门到精通】(四)Java语法进阶
  • 动手学深度学习——线性回归的从零开始实现
  • 【记录】BLE|百度的旧蓝牙随身音箱手机能配对不能连接、电脑能连接不能使用的解决思路(Wireshark捕获并分析手机蓝牙报文)
  • 1.2.2 高级特性详解——AI教你学Django
  • 【图片识别改名】水印相机拍的照片如何将照片的名字批量改为水印内容?图片识别改名的详细步骤和注意事项
  • 【WPF】WPF 自定义控件 实战详解,含命令实现
  • 【零基础入门unity游戏开发——unity3D篇】3D光源之——unity6的新功能Adaptive Probe Volumes(APV)(自适应探针体积)
  • ACL流量控制实验
  • 深入了解linux系统—— 进程信号的产生
  • 客户端主机宕机,服务端如何处理 TCP 连接?详解
  • EasyExcel实现Excel文件导入导出
  • VScode链接服务器一直卡在下载vscode服务器,无法连接成功
  • C++之哈希表的基本介绍以及其自我实现(开放定址法版本)
  • 多客户端 - 服务器结构-实操
  • 史上最清楚!读者,写者问题(操作系统os)
  • 基于 Gitlab、Jenkins与Jenkins分布式、SonarQube 、Nexus 的 CiCd 全流程打造
  • SQL创建三个表
  • 从 JSON 到 Python 对象:一次通透的序列化与反序列化之旅
  • Dubbo高阶难题:异步转同步调用链上全局透传参数的丢失问题
  • Selenium动态网页爬虫编写与解释
  • 【微信小程序】
  • 当你在 Git 本地提交后,因权限不足无法推送到服务端,若想撤销本次提交,可以根据不同的需求选择合适的方法,下面为你介绍两种常见方式。
  • 清除 Android 手机 SIM 卡数据的4 种简单方法
  • 云手机常见问题解析:解决延迟、掉线等困扰
  • 云手机的多重用途:从游戏挂机到办公自动化
  • kafka的部署