当前位置: 首页 > news >正文

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务

      • 1. 认为主流程code
      • 2. NLP 对话和预测基本均属于分类任务详细见
      • 3. Tensorborad

1. 认为主流程code

import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

###制定参数 --model TextRNN
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  #TextCNN, TextRNN,
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter,writer)

RNN


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

在这里插入图片描述
TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

相关文章:

  • 【教学类-43-14】 20240103 (4宫格数独:正确版:576套) 不重复的基础模板数量:576套
  • 工作中redis相关知识总结
  • Javaweb之Mybatis的基础操作的详细解析
  • Docker安装Superset
  • CUMT--Java复习--核心类
  • 影子价格 Shadow Price
  • JUC原子操作类
  • 【C程序设计】C函数
  • 华为鸿蒙应用--文件管理工具(鸿蒙工具)-ArkTs
  • JavaSE语法之十五:异常(超全!!!)
  • 服务器运行状况监控工具
  • 单挑力扣(LeetCode)SQL题:180. 连续出现的数字(难度:中等)
  • 用idea跑起十多年前的项目
  • PHP序列化总结3--反序列化的简单利用及案例分析
  • Linux系统:引导过程与服务控制
  • 深入理解ArkTS:Harmony OS 应用开发语言 TypeScript 的基础语法和关键特性
  • [C语言]时间戳
  • Unity3D Shader Graph 使用 DDXY 节点达到抗锯齿的原理详解
  • 【量化】蜘蛛网策略复现
  • uniapp原生插件 - android原生插件打包流程 ( 避坑指南一)
  • 烤肉店从泔水桶内捞出肉串再烤?西安未央区市监局:停业整顿
  • 英德宣布开发射程超2000公里导弹,以防务合作加强安全、促进经济
  • 舱位已排到月底,跨境电商忙补货!美线订单大增面临爆舱,6月运价或翻倍
  • 国家卫生健康委通报关于肖某引发舆情事件调查处置进展情况
  • 俄方代表团抵达土耳其,俄乌直接谈判有望于当地时间上午重启
  • 泰山、华海、中路等山东险企综合成本率均超100%,承保业务均亏损