当前位置: 首页 > news >正文

Transformer实战(15)——使用PyTorch微调Transformer语言模型

Transformer实战(15)——使用PyTorch微调Transformer语言模型

    • 0. 前言
    • 1. 从零开始微调预训练模型
    • 2. 单步前向传播和反向传播
    • 3. 训练循环
    • 小结
    • 系列链接

0. 前言

在本节中,我们将全面剖析预训练 Transformer 模型的微调过程。相比于依赖高级 API (如 Trainer )的便捷封装,本节聚焦于使用 PyTorch 手动构建训练管道——包括模型加载、优化器配置、前向传播、反向传播、损失计算以及自定义数据集与数据加载器的实现。通过对单步前向与反向传播的演示,再到完整的 epoch 循环与验证流程,将深入理解 AdamW 优化器在 Transformer 微调中的优势,以及如何将批数据高效地送入模型进行训练和评估。

1. 从零开始微调预训练模型

接下来,我们将从零开始微调预训练模型,以了解其背后的运行机制。

(1) 首先,加载模型进行微调。本节我们选择 DistilBert,它是 BERT 的一个轻量、快速的版本:

from transformers import DistilBertForSequenceClassification
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')

(2) 要对模型进行微调,需要将其设置为训练模式:

model.train()

(3) 接下来,加载分词器:

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('bert-base-uncased')

(4) 由于 Trainer 类为我们组织了整个过程,因此在之前的 IMDb 情感分类模型训练中,我们并未涉及优化和其他训练设置。在本节中,我们需要自己实例化优化器,选择 AdamW 优化器,它是 Adam 算法的一种改进实现,解决了权重衰减问题。研究表明,AdamW 比使用 Adam 训练的模型能够产生更好的训练损失和验证损失优化表现,因此在 Transformer 训练过程中被广泛使用:

from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=1e-3)

2. 单步前向传播和反向传播

为了从零开始设计微调过程,我们必须理解如何实现单步前向传播和反向传播。将单个批次的数据通过 Transformer 层并获取输出,即前向传播。然后,使用输出和真实标签计算损失,并根据损失更新模型权重,即反向传播。

(1) 接收一个批次中的三个句子及其对应的标签,并执行前向传播。最后,模型会自动计算损失:

import torch
texts= ["this is a good example","this is a bad example","this is a good one"]
labels= [1,0,1]
labels = torch.tensor(labels).unsqueeze(0)
encoding = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
outputs
# SequenceClassifierOutput(loss=tensor(0.6841, grad_fn=<NllLossBackward0>), logits=tensor([[ 0.0428,  0.1283],
#        [ 0.0652,  0.2260],
#        [-0.0192,  0.1239]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

模型接受由分词器生成的 input_idsattention_mask 作为输入,并使用真实标签计算损失。输出包含损失和 logitsloss.backward() 会通过使用输入和标签对模型进行评估来计算张量的梯度。optimizer.step() 执行一次优化步骤,并使用已计算的梯度更新权重,这就是反向传播。将这些代码放入一个循环中时,还会添加 optimizer.zero_grad() 清除所有参数的梯度。在循环开始时调用这一操作非常重要,否则可能会积累多个步骤的梯度。输出的第二个张量是 logits,在深度学习中,logits(logistic units) 是神经网络结构的最后一层,由实数预测值组成。在分类问题中,logits 需要通过 softmax 函数转换为概率;而对于回归任务,则只需进行简单的归一化处理。

(2) 模型只会输出 logits,而不会计算损失。以下示例展示了如何手动计算交叉熵损失:

from torch.nn import functional
labels = torch.tensor([1,0,1])
outputs = model(input_ids, attention_mask=attention_mask)
loss = functional.cross_entropy(outputs.logits, labels)
loss.backward()
optimizer.step()
loss
# tensor(0.5494, grad_fn=<NllLossBackward0>)

3. 训练循环

通过以上内容,我们了解了如何将批数据通过网络进行单步前向传播。接下来,设计循环以批次形式遍历整个数据集,进行多个 epoch 的训练。

(1) 首先实现 Dataset,它是 torch.Dataset 的子类,继承了成员变量和函数,并实现了 __init__()__getitem__() 这两个抽象函数:

from torch.utils.data import Dataset
class MyDataset(Dataset):def __init__(self, encodings, labels):self.encodings = encodingsself.labels = labelsdef __getitem__(self, idx):item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}item['labels'] = torch.tensor(self.labels[idx])return itemdef __len__(self):return len(self.labels)

(2) 接下来,使用 sst2 (Stanford Sentiment Treebank v2) 数据集,来微调模型进行情感分析。我们还将加载与 sst2 对应的评估指标:

import datasets
from datasets import load_dataset
sst2= load_dataset("glue","sst2")
from evaluate import load
metric = load("glue", "sst2")

(3) 提取句子和对应的标签:

texts=sst2['train']['sentence']
labels=sst2['train']['label']
val_texts=sst2['validation']['sentence']
val_labels=sst2['validation']['label']

(4) 将数据集通过分词器处理,并实例化 MyDataset 对象,使 DistiBert 模型能够处理这些数据:

train_dataset= MyDataset(tokenizer(texts, truncation=True, padding=True), labels)
val_dataset=  MyDataset(tokenizer(val_texts, truncation=True, padding=True), val_labels)

(5) 接下来,实例化一个 DataLoader 类,它提供了按加载顺序遍历数据样本的接口,也有助于批处理和内存固定:

from torch.utils.data import DataLoadertrain_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader =  DataLoader(val_dataset, batch_size=16, shuffle=True)

(6) 检测设备并定义 AdamW 优化器:

from transformers import  AdamWdevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)

(7) 我们已经了解了如何实现前向传播,在这个过程中,批数据通过神经网络通过神经网络的每一层,从第一个层到最后一个层,依次处理,并经过激活函数后传递到下一个层。为了在多个 epoch 中遍历整个数据集,需要嵌套循环:外循环用于 epoch,而内循环则用于处理每个批次。每个 epoch 由两个模块组成;一个用于训练,另一个用于评估。在训练循环中调用了 model.train(),而在评估循环中调用了 model.eval()。通过相应的指标对象来跟踪模型的表现:

for epoch in range(3):model.train()for batch in train_loader:optimizer.zero_grad()input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs[0]loss.backward()optimizer.step()model.eval()for batch in val_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)predictions=outputs.logits.argmax(dim=-1)  metric.add_batch(predictions=predictions,references=batch["labels"],)eval_metric = metric.compute()print(f"epoch {epoch}: {eval_metric}")

微调后的模型,达到了大约 90.94% 的准确率:

epoch 0: {'accuracy': 0.9048165137614679}
epoch 1: {'accuracy': 0.8944954128440367}
epoch 2: {'accuracy': 0.9094036697247706}

小结

本节通过 DistilBert 的微调,介绍了使用 PyToch 从加载预训练模型、手动构建优化器,到实现前向与反向传播、管理梯度,展示了使用 PyTorch 微调 Transformer 语言模型的完整流程。我们演示了如何定义 DatasetDataLoader,并利用 SST-2 数据集完成多轮训练与评估,最终取得约 90.94% 的验证准确率。

系列链接

Transformer实战(1)——词嵌入技术详解
Transformer实战(2)——循环神经网络详解
Transformer实战(3)——从词袋模型到Transformer:NLP技术演进
Transformer实战(4)——从零开始构建Transformer
Transformer实战(5)——Hugging Face环境配置与应用详解
Transformer实战(6)——Transformer模型性能评估
Transformer实战(7)——datasets库核心功能解析
Transformer实战(8)——BERT模型详解与实现
Transformer实战(9)——Transformer分词算法详解
Transformer实战(10)——生成式语言模型 (Generative Language Model, GLM)
Transformer实战(11)——从零开始构建GPT模型
Transformer实战(12)——基于Transformer的文本到文本模型
Transformer实战(13)——从零开始训练GPT-2语言模型
Transformer实战(14)——微调Transformer语言模型用于文本分类

http://www.dtcms.com/a/352894.html

相关文章:

  • centos 判断一个对象是文件还是文件夹
  • HarmonyOS 高效数据存储全攻略:从本地优化到分布式实战
  • 财务报表怎么做?财务常用的报表软件都有哪些
  • vscode 调试 指定 python文件 运行路径
  • IO 字符流 【详解】| Java 学习日志 | 第 13 天
  • npm run start 的整个过程
  • LeetCode 刷题【54. 螺旋矩阵】
  • 共享云服务器替代传统电脑做三维设计会卡顿吗
  • Spring Boot 启动失败:循环依赖排查到懒加载配置的坑
  • 手写MyBatis第37弹: 深入MyBatis MapperProxy:揭秘SQL命令类型与动态方法调用的完美适配
  • 特征降维-特征组合
  • YOLO 目标检测:数据集构建(LabelImg 实操)、评估指标(mAP/IOU)、 NMS 后处理
  • Java全栈开发工程师的面试实战:从基础到微服务
  • 科普 | 5G支持的WWC架构是个啥(2)?
  • Android系统框架知识系列(十七):Telephony Service - 移动通信核心引擎深度解析
  • 5G NR学习笔记 预编码(precoding)和波束赋形(beamforming)
  • DAY 58 经典时序预测模型2
  • 不用伪基站也能攻破5G?Sni5Gect框架如何实现“隐形攻击”
  • spire.doc在word中生成公式
  • OpenCV实战1.信用卡数字识别
  • 第1.7节:机器学习 vs 深度学习 vs 强化学习
  • 20.19 LoRA微调Whisper终极指南:3步实现中文语音识别错误率直降37.8%
  • Apifox 8 月更新|新增测试用例、支持自定义请求示例代码、提升导入/导出 OpenAPI/Swagger 数据的兼容性
  • TDengine与StarRocks在技术架构和适用场景上有哪些主要区别?
  • 【C++】set 容器的使用
  • 面试记录6 c++开发工程师
  • 【PostgreSQL内核学习:通过 ExprState 提升哈希聚合与子计划执行效率】
  • 前端漏洞(下)- URL跳转漏洞
  • buuctf——web刷题第四页
  • Ansible模块实战,操作技巧