当前位置: 首页 > news >正文

第69课 分类任务: 基于BERT训练情感分类与区别二分类本质思考

【GPT入门】第69课 分类任务: 基于BERT训练情感分类与区别二分类本质思考

  • 1. 方案设计
  • 2. 数据介绍
    • 2.1 下载微博情感分类数据
    • 2.2 数据样例
  • 3. 训练
    • 3.1 代码
    • 3.2 日志
    • 3.3 加载模型检查效果
  • 4. 总结
      • 1. 核心共性:二分类与多分类的本质一致
      • 2. 关键差异:仅在“全连接层输出维度”与“损失函数”
      • 总结

1. 方案设计

2. 数据介绍

2.1 下载微博情感分类数据

命令行下载:
huggingface-cli download --repo-type dataset seamew/Weibo --local-dir /root/autodl-tmp/xxzh/bert/weibo_data2

seamew/Weibo
在这里插入图片描述

2.2 数据样例

7,怎么送?
7,为节目忽悠呗~
7,"老头子:老太婆,我想求你一件事。"
7,太平天国刑律规定,男女别营,授受不亲,夫妻之间私自约会,按律处斩.
7,回复,帖子里面一个是231的破解版本,一个是231的完全镜头解密文件,用同步助手等第三方PC端软件上传到指定的目录里,就能使用了。
2,[1/2]三八妇女节快乐[哈哈]怎么刚好在这天我们就九个月了呢?
7,现在的国债法定上限为14.3万亿美元,预计将在7月前后冲击这一上限,如果国会不能立法提高该上限,则美国国债将出现违约。
2,"[晕][晕][晕][晕] 其实我还是刚出生得婴儿啦[害羞] [害羞] [害羞] ,[六一快乐] 各位六一快乐[六一快乐] !"
4,死圣母
7,鼓励、支持企业跨地区重组并购,加强产业资源整合。
7,我才发现原来是奔G……
7,敢问你还能活到下一个11年11月11日吗!
7,我想看到 我在寻找 那所谓的爱情的美好
1,"别说血性,你们能有点幽默感就算对得起生你们的女人了."
7,形象需要重塑~~
7,本次选举的一大亮点是鼓励党外人士和自荐候选人参选国会代表。
3,下午真的痛晕了。
2,集体的力量,我感受到了集体的力量就是:再累,为了不掉队也要坚持,还有在险处有人提醒或帮扶一把,再就是发现自己不是那么差劲,[哈哈] 一路上捡拾垃圾,发现自己眼神真好使,那么小的糖纸被树枝压着都能看见,不过我只负责看不管捡,这次,真的有些累了,下次,下次......
3,不是我不能原谅你,而是我无法原谅我自己。
2,老天不负苦心人昂!!!
5,各种店铺,市场都关了,连洗澡堂和药店也关了,整个商业停摆了,连瓶酱油都买不到!
7,岁月总是一晃就那么多年,小时候不顾一切的想要离开家,长大了又不顾一切的想回去。
2,我反应一向很迟钝哈~~嘿嘿~~加油加油加油~~~明天要破釜沉舟的奋战哈~~~为了你·我这个胆小菇都变成不怕死不怕毒的大蘑菇;啦~~~~~!!!
2,哈哈哈哈~
0,求TB链接~ 还有那本《洋溢幸福的青苔小世界》小妖同学不要太喜欢哦!
7,那些纠结李赫宰进M的傻瓜们~就算赫宰进了M 他也会好好照顾自己的。
3,我的世界,你不在乎;你的世界,我被驱逐。
7,亲,快来“分担”一下吧!
7,#素食主义#
3,回到宿舍连续拉肚子,崩溃…
1,再扣除出让金和土增税又如何呢?
7,这时候的认真和专业,在成品里会展现的淋漓尽致!
4,看到这个帖子我就想说句:TMD!
7,加油吧…
7,细看则是山谷幽深、云气缭绕、万仞深涧。
7,好吧,我偷懒了,手脚神马的……

3. 训练

3.1 代码

import os
import numpy as np
import torch
from datasets import load_dataset, DatasetDict
import transformers
from transformers import (AutoModelForSequenceClassification,AutoTokenizer,TrainingArguments,Trainer,DataCollatorWithPadding
)# 检查transformers版本
transformers_version = transformers.__version__
print(f"Transformers版本: {transformers_version}")
eval_strategy_param = "eval_strategy" if transformers_version >= "4.17.0" else "evaluation_strategy"
save_strategy_param = "save_strategy" if transformers_version >= "4.17.0" else "save_strategy"# 设备配置
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")# ----------------------
# 配置参数
# ----------------------
class Config:model_path = "/root/autodl-tmp/models_xxzh/bert-base-chinese/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea"# 本地数据集路径(关键修改)data_dir = "/root/autodl-tmp/xxzh/bert/data_weibo"train_file = os.path.join(data_dir, "train.csv")test_file = os.path.join(data_dir, "test.csv")val_file = os.path.join(data_dir, "validation.csv")# 训练参数output_dir = "./bert_weibo_sentiment_local"num_labels = 8batch_size = 200learning_rate = 2e-5num_train_epochs = 10max_length = 128logging_dir = "./logs/weibo_local"logging_steps = 100eval_strategy = "epoch"save_strategy = "epoch"save_total_limit = 3device = device# ----------------------
# 数据加载与预处理 - 核心修改
# ----------------------
def load_and_preprocess_data(config):# 1. 检查文件是否存在required_files = [config.train_file, config.test_file, config.val_file]for file in required_files:if not os.path.exists(file):raise FileNotFoundError(f"数据集文件不存在: {file}")# 2. 使用load_dataset的csv格式加载(按要求修改)dataset_train = load_dataset(path="csv", data_files=config.train_file, split="train")dataset_test = load_dataset(path="csv", data_files=config.test_file, split="train")dataset_val = load_dataset(path="csv", data_files=config.val_file, split="train")# 3. 构建DatasetDictdataset = DatasetDict({"train": dataset_train,"test": dataset_test,"validation": dataset_val})print(f"数据集结构: {dataset}")print(f"数据集样例: {dataset['train'][0]}")  # 打印第一条数据,确认列名是否正确# 4. 加载分词器tokenizer = AutoTokenizer.from_pretrained(config.model_path)# 5. 预处理函数(注意:需与CSV中的文本列名匹配)def preprocess_function(examples):# 假设CSV中的文本列名为'review',标签列为'label'# 若实际列名不同(如'text'),需修改此处return tokenizer(examples["text"],truncation=True,max_length=config.max_length,padding="max_length")# 6. 应用预处理tokenized_dataset = dataset.map(preprocess_function, batched=True)# 7. 格式化标签列tokenized_dataset = tokenized_dataset.rename_column("label", "labels")tokenized_dataset.set_format("torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])return tokenized_dataset, tokenizer# ----------------------
# 评估指标(不变)
# ----------------------
def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return {"accuracy": (predictions == labels).mean()}# ----------------------
# 训练模型(不变)
# ----------------------
def train_model(config):tokenize
http://www.dtcms.com/a/391210.html

相关文章:

  • Mysql杂志(二十)——MyISAM索引结构与B树B+树
  • Java 大视界 -- 基于 Java 的大数据实时流处理在金融高频交易数据分析中的应用
  • BonkFun 推出 USD1:Meme 币玩法的新入口
  • flutter在包含ListVIew的滚动列表页面中监听手势
  • Redis 三种集群模式详解
  • 打开hot100
  • Ant-Design Table中使用 AStatisticCountdown倒计时,鼠标在表格上移动时倒计时被重置
  • Linux crontab 定时任务工具使用
  • 阿里云RDS mysql8数据本地恢复,与本地主从同步(容器中)
  • 记录一次mysql启动失败问题解决
  • LeetCode算法练习:35.搜索插入位置
  • (1) 为什么推荐tauri框架
  • 嵌入式面试高频(八)!!!C++语言(嵌入式八股文,嵌入式面经)
  • Spring AI开发指导-工具调用
  • Linux 基本命令超详细解释第二期 | touch | cat | more | cp | mv | rm | which | find
  • [x-cmd] 安装指南
  • Altium Designer(AD24)原理图Move移动功能详细介绍图文教程
  • 部署java程序,服务器报403 Forbidden 问题的终极解决方案
  • 【LeetCode】链表经典问题解析:环形、回文与相交
  • 电磁超材料及其领域应用优势
  • STM32与Modbus RTU协议实战开发指南-fc3ab6a453
  • ArrayList 与 LinkedList 深度对比:从原理到场景的全方位解析
  • Ubuntu和windows复制粘贴互通
  • 银行回单 OCR 识别:财务自动化的 “数据入口“
  • 深兰科技陈海波的AI破局之道:打造软硬一体综合竞争力|《中国经营报》专访
  • 面试经验之mysql高级问答深度解析
  • 高质量票据识别数据集:1000张收据图像+2141个商品标注,支持OCR模型训练与文档理解研究
  • 嵌入式音视频开发——FFmpeg入门
  • MySQL索引篇---B+树在索引中的工作原理
  • 强化学习训练-数据处理