当前位置: 首页 > news >正文

从代码学习深度学习 - 情感分析:使用循环神经网络 PyTorch版

文章目录

  • 前言
  • 1. 加载与预处理数据集
    • 数据读取与词元化
    • 构建词汇表
    • 截断、填充与数据迭代器
  • 2. 构建循环神经网络模型
    • 双向RNN模型(BiRNN)详解
    • 权重初始化
  • 3. 加载预训练词向量
    • 构建词向量加载器
    • 将预训练向量注入模型
  • 4. 训练与评估模型
    • 定义训练函数
    • 可视化训练过程
  • 5. 模型预测
    • 编写预测函数
    • 实例测试
  • 6. 总结


前言

在信息爆炸的时代,从海量的文本数据中提取有价值的信息变得至关重要。无论是电商网站的商品评论、社交媒体上的用户反馈,还是新闻文章中的观点倾向,理解文本背后的情感色彩——即情感分析——都有着广泛的应用。

循环神经网络(RNN)由于其对序列数据的强大建模能力,天然地适用于处理文本这类具有时序特征的数据。在本篇博客中,我们将从零开始,使用PyTorch框架构建一个基于双向循环神经网络(Bi-RNN)的情感分析模型。我们不仅会详细讲解数据预处理、模型构建、训练评估的全过程,还将引入预训练的GloVe词向量来提升模型的性能。

这篇博客的目标是“从代码学习深度学习”。因此,我们将完整地展示每一个模块的代码,并配以详尽的解释,力求让读者不仅能看懂代码,更能理解每一行代码背后的原理和设计思想。无论您是深度学习初学者,还是希望系统学习PyTorch在自然语言处理中应用的开发者,相信都能从中获益。

让我们一起踏上这场代码与思想的探索之旅吧!

完整代码:下载链接


1. 加载与预处理数据集

任何成功的NLP项目都始于坚实的数据处理。我们的任务是分析IMDb电影评论的情感,这是一个经典的二分类问题(正面/负面)。在这一步,我们将完成从原始文本文件到PyTorch数据迭代器的全部转换过程。

主逻辑由load_data_imdb函数驱动,它调用了一系列辅助函数来完成任务。

# 情感分析:使用循环神经网络.ipynbimport torch
import utils_for_data
from torch import nnbatch_size = 64
train_iter, test_iter, vocab = utils_for_data.load_data_imdb(batch_size)

上面的代码是我们的入口,它调用utils_for_data.load_data_imdb来获取训练/测试数据迭代器和词汇表。现在,让我们深入utils_for_data.pyutils_for_vocab.py,看看这一切是如何实现的。

数据读取与词元化

首先,我们需要从压缩包中读取IMDb数据集的文本和标签。read_imdb函数负责遍历指定目录,读取每个评论文件并为其打上正面(1)或负面(0)的标签。

# utils_for_data.pyimport os
import zipfile
import tarfile
import utils_for_vocab
import torch.utils.data as data
import torchdef extract(name, folder=None):"""下载并解压zip/tar文件参数:name (str): 要解压的文件名/路径,维度: [字符串]folder (str, optional): 指定的文件夹名称,维度: [字符串] 或 None返回:str: 解压后的目录路径,维度: [字符串]"""base_dir = os.path.dirname(name)data_dir, ext = os.path.splitext(name)if ext == '.zip':fp = zipfile.ZipFile(name, 'r')elif ext in ('.tar', '.gz'):fp = tarfile.open(name, 'r')else:assert False, '只有zip/tar文件可以被解压缩'fp.extractall(base_dir)fp.close()return os.path.join(base_dir, folder) if folder else data_dirdef read_imdb(data_dir, is_train):"""读取IMDb评论数据集文本序列和标签参数:data_dir (str): 数据集根目录路径is_train (bool): 是否读取训练集,True为训练集,False为测试集返回:tuple: (data, labels)data (list): 评论文本列表,维度为 [样本数量]labels (list): 标签列表,维度为 [样本数量],1表示正面评价,0表示负面评价"""data = []labels = []for label in ('pos', 'neg'):folder_name = os.path.join(data_dir, 'train' if is_train else 'test', label)for file in os.listdir(folder_name):file_path = os.path.join(folder_name, file)with open(file_path, 'rb') as f:review = f.read().decode('utf-8').replace('\n', '')data.append(review)labels.append(1 if label == 'pos' else 0)return data, labels

拿到原始文本后,我们需要将其分解为模型可以理解的基本单元——词元(Token)。这个过程称为词元化(Tokenization)。tokenize函数可以按单词或字符进行分割。

# utils_for_vocab.pyimport torch
import torch.utils.data
from collections import Counterdef tokenize(lines, token='word'):"""将文本行拆分为单词或字符词元参数:lines (list): 文本行列表,维度: [行数],每个元素为字符串token (str): 词元化类型,维度: [标量],'word'表示按单词分割,'char'表示按字符分割返回:tokenized_lines (list): 词元化后的文本,维度: [行数 × 词元数],嵌套列表结构"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print('错误:未知词元类型:' + token)

构建词汇表

计算机无法直接处理文本,我们需要将词元映射为数字索引。Vocab类就是为此设计的。它会统计所有词元的频率,并只保留那些出现频率高于min_freq的词元,其余的都归为未知词元<unk>。这不仅能减小词汇表的大小,还能过滤掉噪音。

# utils_for_vocab.pydef count_corpus(tokens):"""统计词元出现频率参数:tokens (list): 词元列表,维度: [词元数] 或 [序列数 × 词元数](嵌套列表)返回:counter (Counter): 词元频率统计对象,键为词元,值为出现次数"""if len(tokens) == 0 or isinstance(tokens[0], list):tokens = [token for line in tokens for token in line]return Counter(tokens)class Vocab:"""文本词汇表类,用于管理词元到索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化词汇表参数:tokens (list): 词元列表,维度: [词元数] 或 [序列数 × 词元数]min_freq (int): 最小词频阈值,维度: [标量],低于此频率的词元将被忽略reserved_tokens (list): 保留词元列表,维度: [保留词元数],如特殊标记"""if tokens is None:tokens = []if reserved_tokens is None:reserved_tokens = []counter = count_corpus(tokens)self._token_freqs = sorted(counter.items()<
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.dtcms.com/a/258216.html

相关文章:

  • 【运维系列】Plane 开源项目安装和配置指南
  • 爬虫004----网页解析库
  • css 文字跳跃动画
  • prometheus 配置邮件告警
  • iostat中的util原理
  • 大模型项目实战:业务场景和解决方案
  • 数学:关于向量计算的三角形法则
  • GoAdmin代码生成器实践
  • 中断控制与实现
  • APP测试-APP启动耗时
  • Android 9.0(API 28)后字重设置
  • LeetCode热题100—— 35. 搜索插入位置
  • ubuntu22.04修改IP地址
  • 战略调整频繁,如何快速重构项目组合
  • Spring Boot整合FreeMarker全攻略
  • 基于STM32的快递箱的设计
  • 对人工智能的厌倦感是真实存在的,而且它给品牌带来的损失远不止是参与度的下降
  • Android edge-to-edge兼容适配
  • Git 子模块 (Submodule) 完全使用指南
  • 【Vue】 keep-alive缓存组件实战指南
  • AI智能化高效办公:WPS AI全场景深度应用指南
  • MySQL之SQL性能优化策略
  • LayUI的table实现行上传图片+mvc
  • PyTorch topk() 用法详解:取最大值
  • CI/CD GitHub Actions配置流程
  • mongoose解析http字段值
  • 【LLaMA-Factory 实战系列】三、命令行篇 - YAML 配置与高效微调 Qwen2.5-VL
  • 走近科学IT版:FreeBSD系统下ThinkPad键盘突然按不出b和n键了!
  • Android中Navigation使用介绍
  • QT Creator的快捷键设置 复制当前行 ctrl+d 删除当前行 ctrl +y,按照 AS设置