当前位置: 首页 > news >正文

从代码学习深度学习 - LSTM PyTorch版

文章目录

  • 前言
  • 一、数据加载与预处理
    • 1.1 代码实现
    • 1.2 功能解析
  • 二、LSTM介绍
    • 2.1 LSTM原理
    • 2.2 模型定义
      • 代码解析
  • 三、训练与预测
    • 3.1 训练逻辑
      • 代码解析
    • 3.2 可视化工具
      • 功能解析
      • 功能结果
  • 总结


前言

深度学习中的循环神经网络(RNN)及其变种长短期记忆网络(LSTM)在处理序列数据(如文本、时间序列等)方面表现出色。本篇博客将通过一个完整的PyTorch实现,带你从零开始学习如何使用LSTM进行文本生成任务。我们将基于H.G. Wells的《时间机器》数据集,逐步展示数据预处理、模型定义、训练与预测的全过程。通过代码和文字的结合,帮助你深入理解LSTM的实现细节及其在自然语言处理中的应用。

本文的代码分为四个主要部分:

  1. 数据加载与预处理(utils_for_data.py
  2. LSTM模型定义(Jupyter Notebook中的模型部分)
  3. 训练与预测逻辑(utils_for_train.py
  4. 可视化工具(utils_for_huitu.py

以下是详细的实现与解析。


一、数据加载与预处理

首先,我们需要加载《时间机器》数据集并进行预处理。以下是utils_for_data.py中的完整代码及其功能说明。

1.1 代码实现

import random
import re
import torch
from collections import Counter

def read_time_machine():
    """将时间机器数据集加载到文本行的列表中"""
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]

def tokenize(lines, token='word'):
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print(f'错误:未知词元类型:{
     token}')

def count_corpus(tokens):
    """统计词元的频率"""
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)

class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

def load_corpus_time_machine(max_tokens=-1):
    lines = read_time_machine()
    tokens = tokenize(lines, 'char')
    vocab = Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens > 0:
        corpus = corpus[:max_tokens]
    return corpus, vocab

def seq_data_iter_random(corpus, batch_size, num_steps):
    offset = random.randint(0, num_steps - 1)
    corpus = corpus[offset:]
    num_subseqs = (len(corpus) - 1) // num_steps
    initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
    random.shuffle(initial_indices)

    def data(pos):
        return corpus[pos:pos + num_steps]

    num_batches = num_subseqs // batch_size
    for i in range(0, batch_size * num_batches, batch_size):
        initial_indices_per_batch = initial_indices[i:i + batch_size]
        X = [data(j) for j in initial_indices_per_batch]
        Y = [data(j + 1) for j in initial_indices_per_batch]
        yield torch.tensor(X), torch.tensor(Y)

def seq_data_iter_sequential(corpus, batch_size, num_steps):
    offset = random.randint(0, num_steps)
    num_tokens = ((len(corpus) - offset - 1) // batch_size) *
http://www.dtcms.com/a/111523.html

相关文章:

  • 【硬件模块】数码管模块
  • 理解OSPF Stub区域和各类LSA特点
  • QEMU学习之路(5)— 从0到1构建Linux系统镜像
  • 【学习篇】fastapi接口定义学习
  • 19.TCP相关实验
  • 哈密尔顿路径(Hamiltonian Path)及相关算法题目
  • 前端快速入门学习3——CSS介绍与选择器
  • 第三季:挪威
  • 阿里Qwen 创建智能体,并实现ubantu系统中调用
  • 对用户登录设计测试用例
  • Transformer由入门到精通(一):基础知识
  • CSS快速上手
  • BUUCTF-web刷题篇(10)
  • 封装自己的api签名sdk
  • 数据结构 -- 图的存储
  • SpringBoot定时任务深度优化指南
  • ubuntu部署ollama+deepseek+open-webui
  • OpenCV 实现对形似宝马标的黄黑四象限标定位
  • 字符串移位包含问题
  • CExercise_1_4continue关键字在while循环和for循环中,实现的功能有什么区别?
  • Neo4j操作数据库(Cypher语法)
  • NO.61十六届蓝桥杯备战|基础算法-双指针|唯一的雪花|逛画展|字符串|丢手绢(C++)
  • 管理系统 UI 设计:提升企业办公效率的关键
  • (多看) CExercise_05_1函数_1.2计算base的exponent次幂
  • 花卉识别分类系统,Python/resnet18/pytorch
  • MySQL简介
  • 大钲资本押注儒拉玛特全球业务,累计交付超2500条自动化生产线儒拉玛特有望重整雄风,我以为它破产倒闭了,担心很多非标兄弟们失业
  • SpringBoot配置文件多环境开发
  • 空中无人机等动态目标识别2025.4.4
  • Nacos注册中心AP模式核心源码分析(单机模式)