当前位置: 首页 > news >正文

【NLP入门系列四】评论文本分类入门案例

在这里插入图片描述

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

博主简介:努力学习的22级本科生一枚 🌟​;探索AI算法,C++,go语言的世界;在迷茫中寻找光芒​🌸
博客主页:羊小猪~~-CSDN博客
内容简介:这一篇是NLP的入门项目,以AG_NEW新闻数据为例。
🌸箴言🌸:去寻找理想的“天空“”之城
上一篇内容:【NLP入门系列三】NLP文本嵌入(以Embedding和EmbeddingBag为例)-CSDN博客
​💁​​💁​​💁​​💁​: 如果在conda安装环境,由于nlp的核心包是torchtext,所以如果把握不好就重新创建一虚拟环境(小编的“难忘”经历)

文章目录

    • 1、准备
      • 数据加载
      • 构建词表
    • 2、生成数据批次和迭代器
    • 3、定义与模型
      • 模型定义
      • 创建模型
    • 4、创建训练和评估函数
      • 训练函数
      • 评估函数
      • 创建超参数
    • 5、模型训练
    • 6、结果展示
    • 7、预测

🤔 思路

在这里插入图片描述

1、准备

AG News 数据集(也叫 AG’s Corpus or AG News Dataset),这是一个广泛用于自然语言处理(NLP)任务中的文本分类数据集


基本信息:

  • 全称:AG News
  • 来源:来源于 AG’s corpus,由 A. Godin 在 2005 年构建。
  • 用途:主要用于短文本多类别分类任务
  • 语言:英文
  • 类别数:4 类新闻主题
  • 训练样本数:120,000 条
  • 测试样本数:7,600 条

类别标签(共 4 类)

标签含义
1World (世界)
2Sports (体育)
3Business (商业)
4Science and Technology (科技)

数据加载

import torch
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader 
import torchtext 
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator# 检查设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
# 加载本地数据
train_df = pd.read_csv("./data/train.csv")
test_df = pd.read_csv("./data/test.csv")# 合并标题和描述数据
train_df["text"] = train_df["Title"] + " " + train_df["Description"]
test_df["text"] = test_df["Title"] + " " + test_df["Description"]# 查看数据格式
train_df
Class IndexTitleDescriptiontext
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...Wall St. Bears Claw Back Into the Black (Reute...
13Carlyle Looks Toward Commercial Aerospace (Reu...Reuters - Private investment firm Carlyle Grou...Carlyle Looks Toward Commercial Aerospace (Reu...
23Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\ab...Oil and Economy Cloud Stocks' Outlook (Reuters...
33Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\f...Iraq Halts Oil Exports from Main Southern Pipe...
43Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...Oil prices soar to all-time record, posing new...
...............
1199951Pakistan's Musharraf Says Won't Quit as Army C...KARACHI (Reuters) - Pakistani President Perve...Pakistan's Musharraf Says Won't Quit as Army C...
1199962Renteria signing a top-shelf dealRed Sox general manager Theo Epstein acknowled...Renteria signing a top-shelf deal Red Sox gene...
1199972Saban not going to Dolphins yetThe Miami Dolphins will put their courtship of...Saban not going to Dolphins yet The Miami Dolp...
1199982Today's NFL gamesPITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...Today's NFL games PITTSBURGH at NY GIANTS Time...
1199992Nets get Carter from RaptorsINDIANAPOLIS -- All-Star Vince Carter was trad...Nets get Carter from Raptors INDIANAPOLIS -- A...

120000 rows × 4 columns

构建词表

# 定义 Dataset
class AGNewsDataset(Dataset):def __init__(self, dataframe):self.labels = dataframe['Class Index'].tolist()  # 列表数据self.texts = dataframe['text'].tolist()def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.labels[idx], self.texts[idx]# 加载数据
train_dataset = AGNewsDataset(train_df)
test_dataset = AGNewsDataset(test_df)# 构建词表
tokenizer = get_tokenizer("basic_english")  # 英文数据,设置英文分词def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)  # 构建词表# 构建词表,设置索引
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])print("Vocab size:", len(vocab))
Vocab size: 95804
# 查看这些单词所在词典的索引
vocab(['here', 'is', 'an', 'example'])  
[475, 21, 30, 5297]
'''  
标签,原始是字符串类型,现在要转换成 数字 类型
文本数字化,需要一个函数进行转换(vocab)
'''
text_pipline = lambda x : vocab(tokenizer(x))  # 先分词。在数字化
label_pipline = lambda x : int(x) - 1   # 标签转化为数字# 举例
text_pipline('here is the an example')
[475, 21, 2, 30, 5297]

2、生成数据批次和迭代器

# 采用embeddingbag嵌入方式,故需要构建数据,包括长度、标签、偏移量
''' 
数据格式:长度(~, 1)
标签:一维
偏移量:一维
'''
def collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_label, _text) in batch:# 标签列表,注意字符串转换成数字label_list.append(label_pipline(_label))# 文本列表,注意要转入tensro数据temp_text = torch.tensor(text_pipline(_text), dtype=torch.int64)text_list.append(temp_text)# 偏移量offsets.append(temp_text.size(0))# 全部转变成tensor变量label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)return label_list.to(device), text_list.to(device), offsets.to(device)# 数据加载
batch_size = 16
train_dl = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)test_dl = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)

3、定义与模型

模型定义

class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super().__init__()self.embeddingBag = nn.EmbeddingBag(vocab_size,  # 词典大小embed_dim,   # 嵌入维度sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()# 初始化权重def init_weights(self):initrange = 0.5self.embeddingBag.weight.data.uniform_(-initrange, initrange)  # 初始化权重范围self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()  # 偏置置为0def forward(self, text, offsets):embedding = self.embeddingBag(text, offsets)return self.fc(embedding)

创建模型

# 查看数据类别
train_df.groupby('Class Index').count()
TitleDescriptiontext
Class Index
1300003000030000
2300003000030000
3300003000030000
4300003000030000
class_num = 4
vocab_len = len(vocab)
embed_dim = 64  # 嵌入到64维度中
model = TextModel(vocab_size=vocab_len, embed_dim=embed_dim, num_class=class_num).to(device=device)

4、创建训练和评估函数

训练函数

def train(model, dataset, optimizer, loss_fn):size = len(dataset.dataset)num_batch = len(dataset)train_acc = 0train_loss = 0for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict_label = model(text, offset)loss = loss_fn(predict_label, label)# 求导与反向传播optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (predict_label.argmax(1) == label).sum().item()train_loss += loss.item()train_acc /= size train_loss /= num_batchreturn train_acc, train_loss

评估函数

def test(model, dataset, loss_fn):size = len(dataset.dataset)batch_size = len(dataset)test_acc, test_loss = 0, 0with torch.no_grad():for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict = model(text, offset)loss = loss_fn(predict, label) test_acc += (predict.argmax(1) == label).sum().item()test_loss += loss.item()test_acc /= size test_loss /= batch_sizereturn test_acc, test_loss

创建超参数

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01)  # 动态调整学习率

5、模型训练

import copyepochs = 10train_acc, train_loss, test_acc, test_loss = [], [], [], []best_acc = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(model, train_dl, optimizer, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)model.eval()epoch_test_acc, epoch_test_loss = test(model, test_dl, loss_fn)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)if best_acc is not None and epoch_test_acc > best_acc:# 动态调整学习率scheduler.step()best_acc = epoch_test_accbest_model = copy.deepcopy(model)  # 保存模型# 当前学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,  epoch_test_acc*100, epoch_test_loss, lr))# 保存最佳模型到文件
path = './best_model.pth'
torch.save(best_model.state_dict(), path) # 保存模型参数
Epoch: 1, Train_acc:79.9%, Train_loss:0.562, Test_acc:86.9%, Test_loss:0.392, Lr:5.00E-01
Epoch: 2, Train_acc:89.7%, Train_loss:0.313, Test_acc:88.9%, Test_loss:0.346, Lr:5.00E-01
Epoch: 3, Train_acc:91.2%, Train_loss:0.269, Test_acc:89.6%, Test_loss:0.329, Lr:5.00E-01
Epoch: 4, Train_acc:92.0%, Train_loss:0.243, Test_acc:89.8%, Test_loss:0.319, Lr:5.00E-01
Epoch: 5, Train_acc:92.6%, Train_loss:0.224, Test_acc:90.2%, Test_loss:0.315, Lr:5.00E-03
Epoch: 6, Train_acc:93.3%, Train_loss:0.207, Test_acc:90.6%, Test_loss:0.297, Lr:5.00E-03
Epoch: 7, Train_acc:93.4%, Train_loss:0.204, Test_acc:90.7%, Test_loss:0.295, Lr:5.00E-03
Epoch: 8, Train_acc:93.4%, Train_loss:0.203, Test_acc:90.7%, Test_loss:0.294, Lr:5.00E-03
Epoch: 9, Train_acc:93.4%, Train_loss:0.202, Test_acc:90.8%, Test_loss:0.293, Lr:5.00E-03
Epoch:10, Train_acc:93.4%, Train_loss:0.201, Test_acc:90.7%, Test_loss:0.293, Lr:5.00E-03

6、结果展示

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epoch_length = range(epochs)plt.figure(figsize=(12, 3))plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuaray')
plt.plot(epoch_length, test_acc, label='Test Accuaray')
plt.legend(loc='lower right')
plt.title('Accurary')plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.show()


在这里插入图片描述

7、预测

model.load_state_dict(torch.load("./best_model.pth"))
model.eval() # 模型评估# 测试句子
test_sentence = "This is a news about Technology"# 转换为 token
token_ids = vocab(tokenizer(test_sentence))   # 切割分词--> 词典序列
text = torch.tensor(token_ids, dtype=torch.long).to(device)  # 转化为tensor
offsets = torch.tensor([0], dtype=torch.long).to(device)# 测试,注意:不需要反向求导
with torch.no_grad():output = model(text, offsets)predicted_label = output.argmax(1).item()# 输出结果
class_names = ["World", "Sports", "Business", "Science and Technology"]
print(f"预测类别: {class_names[predicted_label]}")
预测类别: Science and Technology
http://www.dtcms.com/a/265852.html

相关文章:

  • 设计模式-观察者模式、命令模式
  • Java连接阿里云MaxCompute例
  • Qt宝藏库:20+实用开源项目合集
  • NV133NV137美光固态闪存NV147NV148
  • Git协作开发:feature分支、拉取最新并合并
  • 这才叫窗口查询!TDEngine官方文档没讲透的实战玩法
  • ModbusRTU转Profinet网关在工业自动化中的应用与价值
  • 50天50个小项目 (Vue3 + Tailwindcss V4) ✨ | DragNDrop(拖拽占用组件)
  • 力扣 hot100 Day33
  • 快速搭建大模型web对话环境指南(open-webUI)
  • 双向链表的实现
  • [创业之路-468]:企业经营层 - 使用“市场-需求-竞争”三维模型筛选细分市场(市场维度、客户需求维度、竞争维度)
  • JavaEE-Linux环境部署
  • Java 核心技术与框架实战十八问
  • 专题:2025即时零售与各类人群消费行为洞察报告|附400+份报告PDF、原数据表汇总下载
  • 模拟IC设计提高系列6-Library导入与新建Library
  • 微信小程序41~50
  • 区块链(私有链搭建和实现)
  • 【C++】访问者模式
  • PHP语法基础篇(八):超全局变量
  • 鸿蒙应用开发:从网络获取数据
  • UE5中的AnimNotify
  • KDD 2025 | 地理定位中的群体智能:一个多智能体大型视觉语言模型协同框架
  • rabbitmq 与 Erlang 的版本对照表 win10 安装方法
  • SPLADE 在稀疏向量搜索中的原理与应用详解
  • MCP 传输机制(Streamable HTTP)
  • 多线程知识
  • 21、MQ常见问题梳理
  • 映射阿里云OSS(对象存储服务)
  • [创业之路-467]:企业经营层 - 《营销管理》的主要内容、核心思想以及对创业者的启示