当前位置: 首页 > news >正文

深度学习笔记40_中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、我的环境

1.语言环境:Python 3.8

2.编译器:Pycharm

3.深度学习环境:

  • torch==1.12.1+cu113
  • torchvision==0.13.1+cu113

二、导入数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os
import PIL
import pathlib
import warnings

warnings.filterwarnings("ignore")  # suppress warning messages

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import pandas as pd

# Load the custom Chinese dataset: tab-separated, no header row.
# From the printed head below, column 0 holds the text and column 1 the label.
train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
print(train_data.head())

结果:

                       0              1
0      还有双鸭山到淮阴的汽车票吗13号的   Travel-Query
1                从这里怎么回家   Travel-Query
2       随便播放一首专辑阁楼里的佛里的歌     Music-Play
3              给看一下墓王之王嘛  FilmTele-Play
4  我想看挑战两把s686打突变团竞的游戏视频     Video-Play

三、构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

# Chinese word segmentation: jieba.lcut returns the token list of a string.
tokenizer = jieba.lcut


def coustom_data_iter(texts, labels):
    """Yield ``(text, label)`` pairs from two parallel sequences."""
    for x, y in zip(texts, labels):
        yield x, y


def yield_tokens(data_iter):
    """Yield the token list of each text in an iterable of ``(text, label)`` pairs."""
    for text, _ in data_iter:
        yield tokenizer(text)


# The vocabulary is built from the loaded DataFrame; train_iter must exist
# before build_vocab_from_iterator consumes it (it is exhausted afterwards).
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
# Any token not in the vocabulary is mapped to the index of "<unk>".
vocab.set_default_index(vocab["<unk>"])

print(vocab(['我', '想', '看', '和平', '精英', '上', '战神', '必备', '技巧', '的', '游戏', '视频']))

结果:[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]

# Label vocabulary. Sorted so the label -> index mapping is identical on every
# run: iteration order of a set of str depends on PYTHONHASHSEED.
label_name = sorted(set(train_data[1].values[:]))


def text_pipeline(x):
    """Tokenize raw text *x* and map each token to its vocabulary index."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Map a label string *x* to its integer class index in ``label_name``."""
    return label_name.index(x)


print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))
结果:[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
4

生成数据批次和迭代器

from torch.utils.data import DataLoader


def collate_batch(batch):
    """Merge a list of ``(text, label)`` samples into EmbeddingBag inputs.

    Returns ``(text_list, label_list, offsets)``, all moved to ``device``:
    every sample's int64 token ids concatenated into one 1-D tensor, the
    int64 label indices, and the start position of each sample inside the
    concatenated token tensor.
    """
    labels = []
    token_tensors = []
    lengths = []
    for sample_text, sample_label in batch:
        labels.append(label_pipeline(sample_label))
        ids = torch.tensor(text_pipeline(sample_text), dtype=torch.int64)
        token_tensors.append(ids)
        lengths.append(ids.size(0))  # number of tokens in this sample

    label_tensor = torch.tensor(labels, dtype=torch.int64)
    flat_tokens = torch.cat(token_tensors)
    # Exclusive prefix sum of the lengths: sample i starts at sum(lengths[:i]).
    start_positions = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)
    return flat_tokens.to(device), label_tensor.to(device), start_positions.to(device)


# Data loader usage example
dataloader = DataLoader(train_iter,
                        batch_size=8,
                        shuffle=False,
                        collate_fn=collate_batch)

定义模型

from torch import nn


class TextClassificationModel(nn.Module):
    """Bag-of-embeddings classifier: an EmbeddingBag followed by one Linear layer.

    Args:
        vocab_size: number of rows in the embedding table (vocabulary size).
        embed_dim: embedding dimension.
        num_class: number of output classes (columns of the returned logits).
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # EmbeddingBag pools (default mode "mean") the embeddings of each
        # sequence delimited by `offsets` into a single vector.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw both weight matrices from U(-0.5, 0.5) and zero the classifier bias."""
        bound = 0.5
        for weight in (self.embedding.weight, self.fc.weight):
            weight.data.uniform_(-bound, bound)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return class logits, one row per bag described by *offsets*."""
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)

定义实例

# Model sizes derived from the data, then the model itself on `device`.
em_size = 64                  # embedding dimension
vocab_size = len(vocab)       # vocabulary entries, including "<unk>"
num_class = len(label_name)   # number of distinct labels
model = TextClassificationModel(
    vocab_size=vocab_size,
    embed_dim=em_size,
    num_class=num_class,
).to(device)

定义训练函数与评估函数

import time


def train(dataloader):
    """Train ``model`` for one pass over *dataloader*.

    Each batch is a ``(text, label, offsets)`` triple such as the one
    ``collate_batch`` returns. Uses the module-level ``model``,
    ``optimizer`` and ``criterion``; the progress line also reads the
    module-level ``epoch`` set by the training loop. Every ``log_interval``
    batches it prints the accuracy and mean per-sample loss of the batches
    seen since the previous print, then resets those counters.
    """
    model.train()  # switch to training mode
    total_acc, train_loss, total_count = 0, 0.0, 0
    log_interval = 50

    for idx, (text, label, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)

        optimizer.zero_grad()  # clear gradients from the previous step
        loss = criterion(predicted_label, label)  # gap between output and ground-truth labels
        loss.backward()  # backpropagation
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip total gradient norm to 0.1
        optimizer.step()

        batch_size = label.size(0)
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        # Assumes criterion uses reduction='mean' (the CrossEntropyLoss default):
        # weight the batch mean by its size so train_loss / total_count is the
        # per-sample mean loss, not that mean divided by the batch size.
        train_loss += loss.item() * batch_size
        total_count += batch_size

        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(
                      epoch, idx, len(dataloader),
                      total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0.0, 0


def evaluate(dataloader):
    """Return ``(accuracy, mean_per_sample_loss)`` of ``model`` on *dataloader*.

    Runs in eval mode under ``torch.no_grad()``, using the module-level
    ``model`` and ``criterion``.

    Raises:
        ValueError: if *dataloader* yields no samples.
    """
    model.eval()  # switch to evaluation mode
    total_acc, total_loss, total_count = 0, 0.0, 0
    with torch.no_grad():
        for text, label, offsets in dataloader:
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            batch_size = label.size(0)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            # Weight the batch-mean loss by batch size (see train()) so a
            # smaller final batch does not skew the average.
            total_loss += loss.item() * batch_size
            total_count += batch_size
    if total_count == 0:
        raise ValueError("evaluate() received an empty dataloader")
    return total_acc / total_count, total_loss / total_count

训练模型

import time

from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Hyperparameters
EPOCHS = 10  # number of epochs
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Each scheduler.step() multiplies the learning rate by gamma=0.1
# (step_size is an int number of scheduler steps).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
total_accu = None  # best validation accuracy seen so far

# Build the dataset from the (text, label) columns of the loaded DataFrame.
train_iter = zip(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

# 80/20 split. The validation size is the remainder, so the two lengths always
# sum to len(train_dataset) as random_split requires (int(0.8n) + int(0.2n)
# can fall one short).
num_train = int(len(train_dataset) * 0.8)
num_valid = len(train_dataset) - num_train
split_train_, split_valid_ = random_split(train_dataset, [num_train, num_valid])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
# Evaluation metrics do not depend on sample order, so no shuffling is needed.
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=False, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    # Learning rate used during this epoch (read before any decay below).
    lr = optimizer.param_groups[0]['lr']
    if total_accu is not None and total_accu > val_acc:
        # Validation accuracy fell below the best so far: decay the LR.
        scheduler.step()
    else:
        total_accu = val_acc

    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(
              epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))
    print('-' * 69)

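 结果(注:下面逐 batch 日志的格式为 "Batch [i/N], Loss, Accuracy",与上文 train() 中打印的 "| epoch … | train_acc … train_loss …" 格式不同,应出自另一版本的训练代码):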

Batch [50/152], Loss: 0.0340, Accuracy: 0.4203
Batch [100/152], Loss: 0.0235, Accuracy: 0.5851
Batch [150/152], Loss: 0.0309, Accuracy: 0.6572
---------------------------------------------------------------------
| epoch 1 | time: 0.55s | valid_acc 0.814 valid_loss 0.012 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0104, Accuracy: 0.8165
Batch [100/152], Loss: 0.0099, Accuracy: 0.8215
Batch [150/152], Loss: 0.0092, Accuracy: 0.8329
---------------------------------------------------------------------
| epoch 2 | time: 0.44s | valid_acc 0.855 valid_loss 0.008 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0068, Accuracy: 0.8790
Batch [100/152], Loss: 0.0065, Accuracy: 0.8778
Batch [150/152], Loss: 0.0064, Accuracy: 0.8809
---------------------------------------------------------------------
| epoch 3 | time: 0.44s | valid_acc 0.874 valid_loss 0.007 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0050, Accuracy: 0.9105
Batch [100/152], Loss: 0.0051, Accuracy: 0.9101
Batch [150/152], Loss: 0.0048, Accuracy: 0.9130
---------------------------------------------------------------------
| epoch 4 | time: 0.44s | valid_acc 0.882 valid_loss 0.006 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0039, Accuracy: 0.9366
Batch [100/152], Loss: 0.0039, Accuracy: 0.9339
Batch [150/152], Loss: 0.0038, Accuracy: 0.9350
---------------------------------------------------------------------
| epoch 5 | time: 0.44s | valid_acc 0.896 valid_loss 0.006 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0028, Accuracy: 0.9519
Batch [100/152], Loss: 0.0030, Accuracy: 0.9517
Batch [150/152], Loss: 0.0030, Accuracy: 0.9494
---------------------------------------------------------------------
| epoch 6 | time: 0.44s | valid_acc 0.898 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0025, Accuracy: 0.9580
Batch [100/152], Loss: 0.0024, Accuracy: 0.9616
Batch [150/152], Loss: 0.0024, Accuracy: 0.9609
---------------------------------------------------------------------
| epoch 7 | time: 0.44s | valid_acc 0.902 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0018, Accuracy: 0.9764
Batch [100/152], Loss: 0.0019, Accuracy: 0.9739
Batch [150/152], Loss: 0.0019, Accuracy: 0.9724
---------------------------------------------------------------------
| epoch 8 | time: 0.44s | valid_acc 0.900 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0015, Accuracy: 0.9810
Batch [100/152], Loss: 0.0014, Accuracy: 0.9817
Batch [150/152], Loss: 0.0014, Accuracy: 0.9818
---------------------------------------------------------------------
| epoch 9 | time: 0.49s | valid_acc 0.906 valid_loss 0.005 | lr 0.500000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0013, Accuracy: 0.9831
Batch [100/152], Loss: 0.0013, Accuracy: 0.9831
Batch [150/152], Loss: 0.0014, Accuracy: 0.9825
---------------------------------------------------------------------
| epoch 10 | time: 0.54s | valid_acc 0.906 valid_loss 0.005 | lr 0.500000
---------------------------------------------------------------------

四、预测

def predict(text, text_pipeline):
    """Return the predicted class index for one raw input string *text*.

    *text_pipeline* maps the string to a list of vocabulary indices; the
    whole string is treated as a single bag (offsets ``[0]``). Uses the
    module-level ``model``. Input tensors are created on the same device as
    the model's parameters, so this works whether the model is on CPU or GPU.
    """
    # Inference mode; matters as soon as dropout/batch-norm layers are added.
    model.eval()
    model_device = next(model.parameters()).device
    with torch.no_grad():
        token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64,
                                 device=model_device)
        offsets = torch.tensor([0], dtype=torch.int64, device=model_device)
        output = model(token_ids, offsets)
        return output.argmax(1).item()


# ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"

model = model.to("cpu")
print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])
该文本的类别是:Travel-Query

总结: 

  1. 语料库(原始文本):

    来源包括维基百科、网页文本、新闻资讯及内部文本。
  2. 文本清洗:

    清洗原始文本,包括去除标点符号和特殊字符。整个文本分类流程的目的,是将原始文本数据转化为可用于模型训练的数值化向量,再通过深度学习模型进行文本分类。
    • 分词:

      使用jieba分词工具对清洗后的文本进行分词处理。
    • 建模:

      采用不同的模型进行文本建模,包括循环神经网络(RNN)、卷积神经网络(CNN)、门控循环单元(GRU)和长短期记忆网络(LSTM)。本文代码采用的是更简单的 nn.EmbeddingBag + 全连接层模型。
    • 文本向量化:

      将分词后的文本转换为向量表示,方法包括TF-IDF和Word2vec。本文代码则先用词典(vocab)把词映射为索引,再由 EmbeddingBag 在训练中学习词向量。

相关文章:

  • 数字智慧方案6189丨智慧应急综合解决方案(46页PPT)(文末有下载方式)
  • n8n 使用 AI Agent 和 MCP 社区节点
  • 树与二叉树完全解析:从基础到应用
  • 4.27-5.4学习周报
  • 如何实现服务的自动扩缩容(Auto Scaling)
  • 1️⃣7️⃣three.js_OrbitControls相机控制器
  • 溯因推理思维——AI与思维模型【92】
  • 【免费】2007-2021年上市公司对外投资数据
  • 数字世界的“私人车道“:网络切片如何用Python搭建专属通信高速路?
  • P2196 [NOIP 1996 提高组] 挖地雷
  • Python爬虫基础总结
  • 【算法】动态规划专题一 斐波那契数列模型 1-4
  • SQL基础全面指南:从CRUD操作到高级特性实战
  • GC9D01 和 GC9A01两种TFT 液晶显示驱动芯片
  • IntelliJ IDEA
  • Socat 用法详解:网络安全中的瑞士军刀
  • 依赖倒置原则
  • Kotlin 基础
  • 软件性能测试报告:办公软件性能如何满足日常工作需求?
  • 文章一《人工智能学习框架入门指南》
  • 叙利亚多地遭以色列空袭
  • 泽连斯基:美乌矿产协议将提交乌拉达批准
  • 五一当天1372对新人在沪喜结连理,涉外婚姻登记全市铺开
  • 魔都眼|买买买,老铺黄金新店开业被挤爆:有人排队5小时
  • 湖南新宁一矿厂排水管破裂,尾砂及积水泄漏至河流,当地回应
  • 长三角铁路今日预计发送旅客420万人次,有望创单日客发量新高