第69课 分类任务: 基于BERT训练情感分类与区别二分类本质思考
【GPT入门】第69课 分类任务: 基于BERT训练情感分类与区别二分类本质思考
- 1. 方案设计
- 2. 数据介绍
-
- 2.1 下载微博情感分类数据
- 2.2 数据样例
- 3. 训练
-
- 3.1 代码
- 3.2 日志
- 3.3 加载模型检查效果
- 4. 总结
-
-
- 1. 核心共性:二分类与多分类的本质一致
- 2. 关键差异:仅在“全连接层输出维度”与“损失函数”
- 总结
-
1. 方案设计
2. 数据介绍
2.1 下载微博情感分类数据
命令行下载:
huggingface-cli download --repo-type dataset seamew/Weibo --local-dir /root/autodl-tmp/xxzh/bert/weibo_data2
seamew/Weibo
2.2 数据样例
7,怎么送?
7,为节目忽悠呗~
7,"老头子:老太婆,我想求你一件事。"
7,太平天国刑律规定,男女别营,授受不亲,夫妻之间私自约会,按律处斩.
7,回复,帖子里面一个是231的破解版本,一个是231的完全镜头解密文件,用同步助手等第三方PC端软件上传到指定的目录里,就能使用了。
2,[1/2]三八妇女节快乐[哈哈]怎么刚好在这天我们就九个月了呢?
7,现在的国债法定上限为14.3万亿美元,预计将在7月前后冲击这一上限,如果国会不能立法提高该上限,则美国国债将出现违约。
2,"[晕][晕][晕][晕] 其实我还是刚出生得婴儿啦[害羞] [害羞] [害羞] ,[六一快乐] 各位六一快乐[六一快乐] !"
4,死圣母
7,鼓励、支持企业跨地区重组并购,加强产业资源整合。
7,我才发现原来是奔G……
7,敢问你还能活到下一个11年11月11日吗!
7,我想看到 我在寻找 那所谓的爱情的美好
1,"别说血性,你们能有点幽默感就算对得起生你们的女人了."
7,形象需要重塑~~
7,本次选举的一大亮点是鼓励党外人士和自荐候选人参选国会代表。
3,下午真的痛晕了。
2,集体的力量,我感受到了集体的力量就是:再累,为了不掉队也要坚持,还有在险处有人提醒或帮扶一把,再就是发现自己不是那么差劲,[哈哈] 一路上捡拾垃圾,发现自己眼神真好使,那么小的糖纸被树枝压着都能看见,不过我只负责看不管捡,这次,真的有些累了,下次,下次......
3,不是我不能原谅你,而是我无法原谅我自己。
2,老天不负苦心人昂!!!
5,各种店铺,市场都关了,连洗澡堂和药店也关了,整个商业停摆了,连瓶酱油都买不到!
7,岁月总是一晃就那么多年,小时候不顾一切的想要离开家,长大了又不顾一切的想回去。
2,我反应一向很迟钝哈~~嘿嘿~~加油加油加油~~~明天要破釜沉舟的奋战哈~~~为了你·我这个胆小菇都变成不怕死不怕毒的大蘑菇;啦~~~~~!!!
2,哈哈哈哈~
0,求TB链接~ 还有那本《洋溢幸福的青苔小世界》小妖同学不要太喜欢哦!
7,那些纠结李赫宰进M的傻瓜们~就算赫宰进了M 他也会好好照顾自己的。
3,我的世界,你不在乎;你的世界,我被驱逐。
7,亲,快来“分担”一下吧!
7,#素食主义#
3,回到宿舍连续拉肚子,崩溃…
1,再扣除出让金和土增税又如何呢?
7,这时候的认真和专业,在成品里会展现的淋漓尽致!
4,看到这个帖子我就想说句:TMD!
7,加油吧…
7,细看则是山谷幽深、云气缭绕、万仞深涧。
7,好吧,我偷懒了,手脚神马的……
3. 训练
3.1 代码
import os
import numpy as np
import torch
from datasets import load_dataset, DatasetDict
import transformers
from transformers import (AutoModelForSequenceClassification,AutoTokenizer,TrainingArguments,Trainer,DataCollatorWithPadding
)# 检查transformers版本
transformers_version = transformers.__version__
print(f"Transformers版本: {transformers_version}")
eval_strategy_param = "eval_strategy" if transformers_version >= "4.17.0" else "evaluation_strategy"
save_strategy_param = "save_strategy" if transformers_version >= "4.17.0" else "save_strategy"# 设备配置
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")# ----------------------
# 配置参数
# ----------------------
class Config:model_path = "/root/autodl-tmp/models_xxzh/bert-base-chinese/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea"# 本地数据集路径(关键修改)data_dir = "/root/autodl-tmp/xxzh/bert/data_weibo"train_file = os.path.join(data_dir, "train.csv")test_file = os.path.join(data_dir, "test.csv")val_file = os.path.join(data_dir, "validation.csv")# 训练参数output_dir = "./bert_weibo_sentiment_local"num_labels = 8batch_size = 200learning_rate = 2e-5num_train_epochs = 10max_length = 128logging_dir = "./logs/weibo_local"logging_steps = 100eval_strategy = "epoch"save_strategy = "epoch"save_total_limit = 3device = device# ----------------------
# 数据加载与预处理 - 核心修改
# ----------------------
def load_and_preprocess_data(config):# 1. 检查文件是否存在required_files = [config.train_file, config.test_file, config.val_file]for file in required_files:if not os.path.exists(file):raise FileNotFoundError(f"数据集文件不存在: {file}")# 2. 使用load_dataset的csv格式加载(按要求修改)dataset_train = load_dataset(path="csv", data_files=config.train_file, split="train")dataset_test = load_dataset(path="csv", data_files=config.test_file, split="train")dataset_val = load_dataset(path="csv", data_files=config.val_file, split="train")# 3. 构建DatasetDictdataset = DatasetDict({"train": dataset_train,"test": dataset_test,"validation": dataset_val})print(f"数据集结构: {dataset}")print(f"数据集样例: {dataset['train'][0]}") # 打印第一条数据,确认列名是否正确# 4. 加载分词器tokenizer = AutoTokenizer.from_pretrained(config.model_path)# 5. 预处理函数(注意:需与CSV中的文本列名匹配)def preprocess_function(examples):# 假设CSV中的文本列名为'review',标签列为'label'# 若实际列名不同(如'text'),需修改此处return tokenizer(examples["text"],truncation=True,max_length=config.max_length,padding="max_length")# 6. 应用预处理tokenized_dataset = dataset.map(preprocess_function, batched=True)# 7. 格式化标签列tokenized_dataset = tokenized_dataset.rename_column("label", "labels")tokenized_dataset.set_format("torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])return tokenized_dataset, tokenizer# ----------------------
# 评估指标(不变)
# ----------------------
def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return {"accuracy": (predictions == labels).mean()}# ----------------------
# 训练模型(不变)
# ----------------------
def train_model(config):tokenize