当前位置: 首页 > news >正文

基于bert-base-chinese的外卖评论情绪分类项目

  • 从 CSV 中读取文本和标签,构建 Dataset。

  • random_split 划分训练集和测试集。

  • 用 Tokenizer 将文本编码为 BERT 输入,collate_fn 批量处理并 padding。

  • DataLoader 按 batch 输出训练数据。

  • 模型基于预训练 BERT,只训练分类层(fc)。

  • 训练流程:前向 → 损失 → 反向 → 更新 fc → 验证。

from transformers import BertTokenizer, BertModel
import os
import torch
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset, random_split, DataLoaderCURRENT_DIR = os.path.dirname(os.path.abspath(__file__))class OutSellDataset(Dataset):def __init__(self, filepath):self.dataset = pd.read_csv(filepath)def __len__(self):return len(self.dataset)def __getitem__(self, i):text = self.dataset.text[i]label = self.dataset.label[i]return text, label# 创建分类模型
class Model(nn.Module):def __init__(self,bert):super().__init__()self.bert=bertself.fc = torch.nn.Linear(in_features=768, out_features=2)# 冻结 BERT 参数,只训练分类层for param in self.bert.parameters():param.requires_grad = Falsedef forward(self, input_ids, attention_mask, token_type_ids):out = self.bert(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,)# 对抽取的特征只取第1个字的结果做分类即可logits = self.fc(out.last_hidden_state[:, 0])return logitsif __name__ == "__main__":ds_path = os.path.join(CURRENT_DIR, 'datasets', "waimai.csv")df = pd.read_csv(ds_path)print(df.head())print(df.info())print(df.label.value_counts())print(df.text.head())datasets = OutSellDataset(ds_path)print(datasets[0])# 数据集划分train_size = int(len(datasets) * 0.8)test_size = len(datasets) - train_size# 随机划分数据集generator = torch.Generator().manual_seed(42)  # 设置随机种子train_dataset, test_dataset = random_split(datasets, [train_size, test_size])print(len(train_dataset), len(test_dataset))# 预训练词典tokenizer = BertTokenizer.from_pretrained("bert-base-chinese",cache_dir=os.path.join(CURRENT_DIR, 'chinese'),do_lower_case=False,)# 自定义数据整理器def collate_fn(data):sents = [i[0] for i in data]labels = [i[1] for i in data]# 编码data = tokenizer(text=sents,truncation=True,padding="max_length",max_length=500,return_tensors="pt",return_length=True,)# input_ids:编码之后的数字# attention_mask:补零的位置是0, 其他位置是1input_ids = data["input_ids"]attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]labels = torch.LongTensor(labels)# # 把数据移动到计算设备上# input_ids = input_ids.to(device)# attention_mask = attention_mask.to(device)# token_type_ids = token_type_ids.to(device)# labels = labels.to(device)return input_ids, attention_mask, token_type_ids, labels# 加载器构建train_dl = DataLoader(dataset=train_dataset,batch_size=16,collate_fn=collate_fn,shuffle=True,drop_last=True,)test_dl = DataLoader(dataset=test_dataset,batch_size=16,collate_fn=collate_fn,# shuffle=True,drop_last=True,)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 预训练模型bert = BertModel.from_pretrained("bert-base-chinese",cache_dir=os.path.join(CURRENT_DIR, 'model'),).to(device)# 打印参数量print(sum(p.numel() for p in bert.parameters()))# 输出input_ids, attention_mask, token_type_ids, labels = next(iter(train_dl))input_ids = input_ids.to(device)attention_mask = attention_mask.to(device)token_type_ids = token_type_ids.to(device)# out = bert(#     input_ids=input_ids,#     attention_mask=attention_mask,#     token_type_ids=token_type_ids,# )# # 批大小,序列长度,隐藏维度# print(out.last_hidden_state.size())# # print(f"批大小,序列长度,隐藏维度\n{out.last_hidden_state}")## # 模型冻结# for param in bert.parameters():#     param.requires_grad = False# 使用分类模型logmodel = Model(bert).to(device)output = logmodel(input_ids, attention_mask, token_type_ids)print(output)# 定义超参数loss_fn = nn.CrossEntropyLoss()# 只训练 fc层optimizer = torch.optim.Adam(logmodel.fc.parameters(), lr=2e-4, eps=1e-8)# 微调# optimizer = torch.optim.Adam(logmodel.parameters(), lr=2e-4, eps=1e-8)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)# 定义训练函数def train(dataloader):logmodel.train()total_acc, total_count, total_loss = 0, 0, 0for input_ids, mask, type_ids, label in tqdm(dataloader, desc="Training", leave=False):# 移动到 deviceinput_ids, mask, type_ids, label = input_ids.to(device), mask.to(device), type_ids.to(device), label.to(device)optimizer.zero_grad()predicted_label = logmodel(input_ids, token_type_ids=type_ids, attention_mask=mask)loss = loss_fn(predicted_label, label)loss.backward()optimizer.step()total_acc += (predicted_label.argmax(1) == label).sum().item()total_count += label.size(0)total_loss += loss.item() * label.size(0)# 返回平均 loss 和准确率return total_loss / total_count, total_acc / total_countdef test(dataloader):logmodel.eval()total_acc, total_count, total_loss = 0, 0, 0with torch.no_grad():for input_ids, mask, type_ids, label in tqdm(dataloader, desc="Testing", leave=False):input_ids, mask, type_ids, label = input_ids.to(device), mask.to(device), type_ids.to(device), label.to(device)predicted_label = logmodel(input_ids, token_type_ids=type_ids, attention_mask=mask)loss = loss_fn(predicted_label, label)total_acc += (predicted_label.argmax(1) == label).sum().item()total_count += label.size(0)total_loss += loss.item() * label.size(0)return total_loss / total_count, total_acc / total_count# 开始训练epochs = 2train_loss=[]train_acc=[]test_loss=[]test_acc=[]for epoch in range(epochs):epoch_loss, epoch_acc = train(train_dl)epoch_test_loss, epoch_test_acc = test(test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)scheduler.step()template=("epoch:{:2d},train_loss:{:.5f},train_acc:{:.1f}%,""test_loss:{:.5f},test_acc:{:.1f}%")print(template.format(epoch+1,epoch_loss,epoch_acc*100,epoch_test_loss,epoch_test_acc*100,))print("Done!")

http://www.dtcms.com/a/486105.html

相关文章:

  • OpenSSL EVP编程介绍
  • 网站服务器组建中国国际贸易网站
  • 上新!功夫系列高通量DPU卡 CONFLUX®-2200P 全新升级,带宽升 40% IOPS提60%,赋能多业务场景。
  • Spring Boot 3零基础教程,properties文件中配置和类的属性绑定,笔记14
  • 以数据智能重构 OTC 连锁增长逻辑,覆盖网络与合作生态双维赛跑
  • 【推荐100个unity插件】基于节点的程序化无限地图生成器 —— MapMagic 2
  • 71_基于深度学习的布料瑕疵检测识别系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
  • 工控机做网站服务器网络模块
  • Mac——文件夹压缩的简便方法
  • Playwright自动化实战一
  • 电商网站开发面临的技术问题做seo网站诊断书怎么做
  • 【Qt】QTableWidget 自定义排序功能实现
  • WPF 疑点汇总2.HorizontalAlignment和 HorizontalContentAlignment
  • 【Qt】3.认识 Qt Creator 界面
  • 垂直网站建设付费小说网站怎么做
  • PDFBox - PDDocument 与 byte 数组、PDF 加密
  • 【Pytorch】分类问题交叉熵
  • 如何轻松删除 realme 手机中的联系人
  • Altium Designer怎么制作自己的集成库?AD如何制作自己的原理图库和封装库并打包生成库文件?AD集成库制作好后如何使用丨AD集成库使用方法
  • Jackson是什么
  • 代码实例:Python 爬虫抓取与解析 JSON 数据
  • 襄阳建设网站首页百度知识营销
  • 山东住房和城乡建设厅网站电话开发软件都有哪些
  • AbMole| Yoda1( M9372;GlyT2-IN-1; Yoda 1)
  • LLM监督微调SFT实战指南(Qwen3-0.6B-Base)
  • 【基础算法】多源 BFS
  • *@UI 视角下主程序与子程序的菜单页面架构及关联设计
  • Virtio 半虚拟化技术解析
  • 网站设计怎么好看律师做网络推广哪个网站好
  • 用commons vfs 框架 替换具体的sftp 实现