当前位置: 首页 > news >正文

分类模型(BERT)训练全流程

使用BERT实现分类模型的完整训练流程

BERT (Bidirectional Encoder Representations from Transformers) 是一种强大的预训练语言模型,在各种NLP任务中表现出色。下面我将详细梳理使用BERT实现文本分类模型的完整训练过程。

1. 准备工作

1.1 环境配置

pip install transformers torch tensorflow pandas sklearn

1.2 选择BERT版本

  • BERT-base (110M参数)
  • BERT-large (340M参数)
  • 中文BERT (如bert-base-chinese)
  • 领域特定BERT (如BioBERT, SciBERT)

2. 数据准备

2.1 数据格式

text,label
"这个电影很好看",1
"产品体验很差",0
...

2.2 数据预处理

import pandas as pd
from sklearn.model_selection import train_test_split# 读取数据
df = pd.read_csv('data.csv')# 划分训练集和验证集
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)# 查看类别分布
print(train_df['label'].value_counts())

3. 使用Transformers库加载BERT

3.1 导入必要组件

from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
from torch.utils.data import Dataset, DataLoader

3.2 初始化Tokenizer和Model

# 选择预训练模型
MODEL_NAME = 'bert-base-chinese'  # 中文模型# 加载分词器
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)# 加载模型
model = BertForSequenceClassification.from_pretrained(MODEL_NAME,num_labels=len(train_df['label'].unique()),  # 类别数量output_attentions=False,output_hidden_states=False
)# 移至GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

4. 创建数据集和数据加载器

4.1 自定义Dataset类

class TextDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = str(self.texts[idx])label = self.labels[idx]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'labels': torch.tensor(label, dtype=torch.long)}

4.2 创建数据加载器

MAX_LEN = 128  # BERT最大输入长度
BATCH_SIZE = 32def create_data_loader(df, tokenizer, max_len, batch_size):ds = TextDataset(texts=df['text'].to_numpy(),labels=df['label'].to_numpy(),tokenizer=tokenizer,max_len=max_len)return DataLoader(ds,batch_size=batch_size,num_workers=4)train_data_loader = create_data_loader(train_df, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(val_df, tokenizer, MAX_LEN, BATCH_SIZE)

5. 训练准备

5.1 设置优化器和学习率调度器

EPOCHS = 3
LEARNING_RATE = 2e-5optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHSscheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=0,num_training_steps=total_steps
)loss_fn = torch.nn.CrossEntropyLoss().to(device)

5.2 训练函数

def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):model = model.train()losses = []correct_predictions = 0for d in data_loader:input_ids = d["input_ids"].to(device)attention_mask = d["attention_mask"].to(device)labels = d["labels"].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask,labels=labels)loss = outputs.losslogits = outputs.logits_, preds = torch.max(logits, dim=1)correct_predictions += torch.sum(preds == labels)losses.append(loss.item())loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()scheduler.step()optimizer.zero_grad()return correct_predictions.double() / n_examples, np.mean(losses)

5.3 评估函数

def eval_model(model, data_loader, loss_fn, device, n_examples):model = model.eval()losses = []correct_predictions = 0with torch.no_grad():for d in data_loader:input_ids = d["input_ids"].to(device)attention_mask = d["attention_mask"].to(device)labels = d["labels"].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask,labels=labels)loss = outputs.losslogits = outputs.logits_, preds = torch.max(logits, dim=1)correct_predictions += torch.sum(preds == labels)losses.append(loss.item())return correct_predictions.double() / n_examples, np.mean(losses)

6. 训练循环

from collections import defaultdict
import numpy as nphistory = defaultdict(list)
best_accuracy = 0for epoch in range(EPOCHS):print(f'Epoch {epoch + 1}/{EPOCHS}')print('-' * 10)train_acc, train_loss = train_epoch(model,train_data_loader,loss_fn,optimizer,device,scheduler,len(train_df))print(f'Train loss {train_loss} accuracy {train_acc}')val_acc, val_loss = eval_model(model,val_data_loader,loss_fn,device,len(val_df))print(f'Val loss {val_loss} accuracy {val_acc}')print()history['train_acc'].append(train_acc)history['train_loss'].append(train_loss)history['val_acc'].append(val_acc)history['val_loss'].append(val_loss)if val_acc > best_accuracy:torch.save(model.state_dict(), 'best_model_state.bin')best_accuracy = val_acc

7. 模型评估与预测

7.1 加载最佳模型

model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
model.load_state_dict(torch.load('best_model_state.bin'))
model = model.to(device)

7.2 预测函数

def get_predictions(model, data_loader):model = model.eval()review_texts = []predictions = []prediction_probs = []real_values = []with torch.no_grad():for d in data_loader:texts = d["text"]input_ids = d["input_ids"].to(device)attention_mask = d["attention_mask"].to(device)labels = d["labels"].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask)_, preds = torch.max(outputs.logits, dim=1)probs = torch.softmax(outputs.logits, dim=1)review_texts.extend(texts)predictions.extend(preds)prediction_probs.extend(probs)real_values.extend(labels)predictions = torch.stack(predictions).cpu()prediction_probs = torch.stack(prediction_probs).cpu()real_values = torch.stack(real_values).cpu()return review_texts, predictions, prediction_probs, real_values

7.3 生成分类报告

from sklearn.metrics import classification_report, confusion_matrixy_review_texts, y_pred, y_pred_probs, y_test = get_predictions(model, val_data_loader)print(classification_report(y_test, y_pred))

8. 模型保存与部署

8.1 保存整个模型

model.save_pretrained("./my_bert_classifier")
tokenizer.save_pretrained("./my_bert_classifier")

8.2 创建预测API示例

from flask import Flask, request, jsonify
import torchapp = Flask(__name__)# 加载模型和tokenizer
model = BertForSequenceClassification.from_pretrained('./my_bert_classifier')
tokenizer = BertTokenizer.from_pretrained('./my_bert_classifier')
model.eval()@app.route('/predict', methods=['POST'])
def predict():data = request.get_json()text = data['text']encoded_text = tokenizer.encode_plus(text,max_length=128,add_special_tokens=True,return_token_type_ids=False,padding='max_length',return_attention_mask=True,return_tensors='pt',)input_ids = encoded_text['input_ids']attention_mask = encoded_text['attention_mask']with torch.no_grad():output = model(input_ids, attention_mask)_, prediction = torch.max(output.logits, dim=1)prob = torch.softmax(output.logits, dim=1)return jsonify({'prediction': prediction.item(),'probability': prob[0][prediction.item()].item()})if __name__ == '__main__':app.run(host='0.0.0.0', port=5000)

关键注意事项

  1. 学习率选择:BERT通常使用很小的学习率(2e-5到5e-5)
  2. 批量大小:根据GPU内存选择,通常16-64
  3. 训练轮次:BERT微调通常3-5个epoch就足够
  4. 序列长度:根据任务调整MAX_LEN,太长会浪费计算资源
  5. 类别不平衡:可使用class_weight参数调整损失函数
  6. GPU使用:建议使用CUDA加速训练

通过以上流程,您可以完整地实现一个基于BERT的文本分类模型,从数据准备到训练评估,最后到部署应用。

http://www.dtcms.com/a/292909.html

相关文章:

  • IO复用(多路转接)
  • c语言学习(days08)
  • 对比学习 | 软标签损失计算
  • 安科瑞工商业光储充新能源电站ACCU-100M微电网协调控制器
  • MyBatis-Plus 分页实战
  • 目前主流的AI深度学习框架对Windows和Linux的支持哪个更好
  • 单细胞转录组学+空间转录组的整合及思路
  • 一个不起眼的问题,导致插件加载失败
  • python中 tqdm ,itertuples 是什么
  • 学习软件测试的第十九天
  • ​Eyeriss 架构中的访存行为解析(腾讯元宝)
  • Java学习----Redis集群
  • SHAP的升级版:可解释性框架Alibi的相关介绍(一)
  • L1与L2正则化:核心差异全解析
  • RabbitMQ03——面试题
  • DOM/事件高级
  • haprox七层代理
  • 医院如何实现节能降耗?
  • <另一种思维:语言模型如何展现人类的时间认知>读后总结
  • 【上市公司变量测量】Python+FactSet Revere全球供应链数据库,测度供应链断裂与重构变量——丁浩员等(2024)《经济研究》复现
  • Day28| 122.买卖股票的最佳时机 II、55. 跳跃游戏、45.跳跃游戏 II、1005.K次取反后最大化的数组和
  • Spring AI Alibaba + JManus:从架构原理到生产落地的全栈实践——一篇面向 Java 架构师的 20 分钟深度阅读
  • MSTP实验
  • 深入理解 Qt 中的 QImage 与 QPixmap:底层机制、差异、优化策略全解析
  • 集训Demo5
  • 代码检测SonarQube+Git安装和规范
  • 从FDTD仿真到光学神经网络:机器学习在光子器件设计中的前沿应用工坊
  • Matlab学习笔记:界面使用
  • 【数据结构初阶】--栈和队列(二)
  • CanOpen--SDO 数据帧分析