当前位置: 首页 > news >正文

BGE-M3 文本情感分类实战:预训练模型微调,导出ONNX并测试

基于BGE-M3开源模型(由智谱 AI 开发的中英双语通用向量表征模型),实现一个简单的二分类情感分析任务(积极 / 消极文本判断)。通过微调预训练模型,展示 NLP 中经典的 “预训练 + 微调” 范式。并导出ONNX模型,进行测试,以适应生产环境。

微调步骤

  1. 加载预训练模型与分词器;
  2. 构建分类头并整合到模型中;
  3. 数据预处理与 DataLoader 构建;
  4. 定义优化器、损失函数与训练循环;
  5. 验证与推理。
    实际应用中,需根据数据规模和任务复杂度调整模型结构与训练策略。BGE-M3 在语义表征任务中表现优异,适合作为文本分类、语义搜索等下游任务的基础模型。

微调代码

from transformers import AutoModel, AutoTokenizer,  get_linear_schedule_with_warmup
import torch
from torch.optim import AdamW  
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
model = AutoModel.from_pretrained("BAAI/bge-m3")# 定义分类模型
class TextClassifier(nn.Module):def __init__(self, model, hidden_size=1024, num_classes=2):super().__init__()self.model = modelself.classifier = nn.Sequential(nn.Dropout(0.1),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, num_classes))def forward(self, inputs):# 获取模型输出outputs = self.model(**inputs)# 使用 [CLS] tokencls_embedding = outputs.last_hidden_state[:, 0, :]# 通过分类器logits = self.classifier(cls_embedding)return logits# 模拟数据集
class TextDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_length=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_length,padding='max_length',truncation=True,return_tensors='pt')print(torch.tensor(label, dtype=torch.long))return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}# 初始化分类器
classifier = TextClassifier(model)# 模拟数据
texts = ["这是一个积极的句子,充满了正能量。","这是一个消极的句子,感觉非常糟糕。","今天天气真好,阳光明媚。","这个电影太无聊了,浪费时间。","我喜欢这个产品,质量非常好。","这个服务太差劲了,非常不满意。","大模型对程序员来说是一个很好的工具。","大模型对初级开发者来说是一个坏消息。"
]
#  PyTorch 的交叉熵损失函数 nn.CrossEntropyLoss 要求标签必须是从 0 开始的连续整数(如 0、1、2...)
labels = [1, 0, 1, 0, 1, 0,1,0]  # 1表示积极,0表示消极# 划分训练集和验证集
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=42)# 创建数据加载器
train_dataset = TextDataset(train_texts, train_labels, tokenizer)
val_dataset = TextDataset(val_texts, val_labels, tokenizer)train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2)# 训练配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier.to(device)optimizer = AdamW(classifier.parameters(), lr=2e-5)
epochs = 3
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)# 训练循环
def train_epoch(model, dataloader, optimizer, scheduler, device):model.train()total_loss = 0for batch in tqdm(dataloader, desc="Training"):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)optimizer.zero_grad()outputs = model({'input_ids': input_ids,'attention_mask': attention_mask})loss = nn.CrossEntropyLoss()(outputs, labels)total_loss += loss.item()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()scheduler.step()avg_loss = total_loss / len(dataloader)return avg_loss# 验证循环
def evaluate(model, dataloader, device):model.eval()total_loss = 0correct_predictions = 0total_predictions = 0with torch.no_grad():for batch in tqdm(dataloader, desc="Evaluating"):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model({'input_ids': input_ids,'attention_mask': attention_mask})loss = nn.CrossEntropyLoss()(outputs, labels)total_loss += loss.item()_, predictions = torch.max(outputs, dim=1)correct_predictions += (predictions == labels).sum().item()total_predictions += labels.size(0)avg_loss = total_loss / len(dataloader)accuracy = correct_predictions / total_predictionsreturn avg_loss, accuracy# 训练模型
print("开始训练模型...")
for epoch in range(epochs):print(f"Epoch {epoch + 1}/{epochs}")train_loss = train_epoch(classifier, train_dataloader, optimizer, scheduler, device)val_loss, val_accuracy = evaluate(classifier, val_dataloader, device)print(f"训练损失: {train_loss:.4f}")print(f"验证损失: {val_loss:.4f}")print(f"验证准确率: {val_accuracy:.4f}")print("-" * 50)# 保存模型
torch.save(classifier.state_dict(), 'text_classifier.pth')
print("模型训练完成并保存!")# 推理with torch.no_grad():text = "这是一个测试句子,非常棒!"inputs = tokenizer(text, return_tensors="pt").to(device)logits = classifier(inputs)predictions = torch.argmax(logits, dim=1)print(f"预测类别: {predictions.item()}")print("-"*80)text = "大模型对初级开发者来说是一个坏消息。"inputs = tokenizer(text, return_tensors="pt").to(device)logits = classifier(inputs)predictions = torch.argmax(logits, dim=1)print(f"预测类别: {predictions.item()}")

加载微调权重模型预测

from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn as nn
import numpy as np# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 定义分类模型
class TextClassifier(nn.Module):def __init__(self, model, hidden_size=1024, num_classes=2):super().__init__()self.model = modelself.classifier = nn.Sequential(nn.Dropout(0.1),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, num_classes))def forward(self, inputs):# 获取模型输出outputs = self.model(**inputs)# 使用 [CLS] tokencls_embedding = outputs.last_hidden_state[:, 0, :]# 通过分类器logits = self.classifier(cls_embedding)return logits# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
base_model = AutoModel.from_pretrained("BAAI/bge-m3")# 初始化分类器
classifier = TextClassifier(base_model)
classifier.to(device)# 加载保存的模型权重
try:classifier.load_state_dict(torch.load('text_classifier.pth', map_location=device))print("模型加载成功!")
except FileNotFoundError:print("错误: 找不到模型文件 'text_classifier.pth'。请确保该文件在正确的路径下。")exit()# 设置为评估模式
classifier.eval()# 预测函数
def predict_sentiment(text):"""预测文本的情感极性(积极或消极)"""# 预处理文本inputs = tokenizer(text,add_special_tokens=True,max_length=128,padding='max_length',truncation=True,return_tensors='pt')# 将输入移至设备inputs = {k: v.to(device) for k, v in inputs.items()}# 进行预测with torch.no_grad():outputs = classifier(inputs)# 获取预测概率probabilities = torch.nn.functional.softmax(outputs, dim=1)# 获取预测类别predicted_class = torch.argmax(probabilities, dim=1).item()# 类别映射sentiment_map = {0: "消极", 1: "积极"}return {"text": text,"predicted_class": predicted_class,"sentiment": sentiment_map[predicted_class],"confidence": probabilities[0][predicted_class].item()}# 示例预测
if __name__ == "__main__":# 测试几个例子test_texts = ["这个产品真的太棒了,我非常满意!","这个服务太糟糕了,简直是浪费时间。","今天天气真好,适合出去散步。","这个电影很无聊,不推荐观看。"]for text in test_texts:result = predict_sentiment(text)print(f"文本: {result['text']}")print(f"情感: {result['sentiment']} ({result['confidence']:.4f})")print("-" * 50)

生产环境:导出ONNX,并测试

建议使用C++等生产环境语言,这里使用python演示

from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn as nn
import numpy as np
import onnx
import onnxruntime as ort
import os# 定义分类模型
class TextClassifier(nn.Module):def __init__(self, model, hidden_size=1024, num_classes=2):super().__init__()self.model = modelself.classifier = nn.Sequential(nn.Dropout(0.1),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, num_classes))def forward(self, input_ids, attention_mask):outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)cls_embedding = outputs.last_hidden_state[:, 0, :]logits = self.classifier(cls_embedding)return logits# 保存模型为ONNX格式
def export_to_onnx(model, tokenizer, output_path='onnx_models/text_classifier.onnx'):# 创建目录(如果不存在)os.makedirs(os.path.dirname(output_path), exist_ok=True)model.eval()text = "示例句子"inputs = tokenizer(text, return_tensors="pt")input_names = ['input_ids', 'attention_mask']output_names = ['logits']dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'sequence_length'},'attention_mask': {0: 'batch_size', 1: 'sequence_length'},'logits': {0: 'batch_size'}}torch.onnx.export(model,(inputs['input_ids'], inputs['attention_mask']),output_path,export_params=True,opset_version=14,do_constant_folding=True,input_names=input_names,output_names=output_names,dynamic_axes=dynamic_axes)print(f"模型已导出为ONNX格式: {output_path}")onnx.checker.check_model(output_path)print("ONNX模型验证通过!")return output_path# 使用ONNX模型进行推理
def onnx_inference(onnx_path, tokenizer, text, max_length=128):session = ort.InferenceSession(onnx_path)inputs = tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors="np")input_names = ['input_ids', 'attention_mask']onnx_inputs = {name: inputs[name] for name in input_names}outputs = session.run(None, onnx_inputs)logits = outputs[0]predictions = np.argmax(logits, axis=1)return predictions[0]# 主流程
if __name__ == "__main__":# 加载模型和分词器print("加载预训练模型和分词器...")tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")model = AutoModel.from_pretrained("BAAI/bge-m3")# 初始化分类器classifier = TextClassifier(model)# 加载已经训练好的权重try:classifier.load_state_dict(torch.load('text_classifier.pth'))print("已加载训练好的模型权重")except FileNotFoundError:print("警告: 未找到预训练权重,将使用随机初始化的模型")# 导出为ONNX格式print("\n开始导出模型为ONNX格式...")onnx_model = TextClassifier(model, hidden_size=1024, num_classes=2)onnx_model.load_state_dict(torch.load('text_classifier.pth'))onnx_model.to('cpu')output_path = 'onnx_models/text_classifier.onnx'onnx_path = export_to_onnx(onnx_model, tokenizer, output_path)# 使用ONNX模型进行推理测试print("\n使用ONNX模型进行推理测试:")test_texts = ["这是一个测试句子,非常棒!","大模型对初级开发者来说是一个坏消息。","这个餐厅的服务很糟糕,再也不会来了。","这本书真是太精彩了,值得一读。"]for text in test_texts:prediction = onnx_inference(output_path, tokenizer, text)sentiment = "积极" if prediction == 1 else "消极"print(f"文本: {text}")print(f"ONNX模型预测类别: {prediction} ({sentiment})\n")

结果

在这里插入图片描述

相关文章:

  • OpenCv高阶(十七)——dlib库安装、dlib人脸检测
  • Jeecg漏洞总结及tscan poc分享
  • Mujoco 学习系列(四)官方模型仓库 mujoco_menagerie
  • LangChain文档加载器实战:构建高效RAG数据流水线
  • 第八天的尝试
  • js中encodeURIComponent函数使用场景
  • 3.9/Q1,GBD数据库最新文章解读
  • FinalShell 密码在线解析方法(含完整源码与运行平台)
  • SQLServer与MySQL数据迁移案例解析
  • mysql日志文件binlog分析记录
  • 软考 系统架构设计师系列知识点之杂项集萃(69)
  • [Usaco2007 Dec]队列变换 题解
  • Python之web错误处理与异常捕获
  • LeRobot的机器人控制系统(下)
  • 有监督学习——决策树
  • 从3.7V/5V到7.4V,FP6291在应急供电智能门锁中的应用
  • 为什么mosquitto 禁用了 topic “#“后,无法使用主题中包含%c client_id了?
  • 【动手学深度学习】2.1. 数据操作
  • 技术篇-2.4.Python应用场景及开发工具安装
  • 如果验证集缺失或测试集缺失应该怎么办?
  • 张掖哪家公司做网站/网址生成短链接
  • 源代码网站培训/哈尔滨网络推广优化
  • 网站备案号被注销什么原因/广州seo网站管理
  • 义乌市建设局网站/友情链接的网站图片
  • 个人网站备案 备注/苏州seo培训
  • 如何自己做时时彩网站/最新域名ip地址