当前位置: 首页 > news >正文

LSTM 学习笔记 之pytorch调包每个参数的解释

0、 LSTM 原理

整理优秀的文章
LSTM入门例子:根据前9年的数据预测后3年的客流(PyTorch实现)
[干货]深入浅出LSTM及其Python代码实现
整理视频
李毅宏手撕LSTM
[双语字幕]吴恩达深度学习deeplearning.ai

1 Pytorch 代码

这里直接调用了nn.lstm

 self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # utilize the LSTM model in torch.nn

下面作为初学者解释一下里面的3个参数
input_size: 这个就是输入的向量的长度or 维度,如一个单词可能占用20个维度。
hidden_size: 这个是隐藏层,其实我感觉有点全连接的意思,这个层的维度影响LSTM 网络输入的维度,换句话说,LSTM接收的数据维度不是输入什么维度就是什么维度,而是经过了隐藏层,做了一个维度的转化。
num_layers: 这里就是说堆叠了几个LSMT 结构。

2 网络定义

class LstmRNN(nn.Module):
    """
        Parameters:
        - input_size: feature size
        - hidden_size: number of hidden units
        - output_size: number of output
        - num_layers: layers of LSTM to stack
    """

    def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):
        super().__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # utilize the LSTM model in torch.nn
        self.forwardCalculation = nn.Linear(hidden_size, output_size)

    def forward(self, _x):
        x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)
        s, b, h = x.shape  # x is output, size (seq_len, batch, hidden_size)
        x = x.view(s * b, h)
        x = self.forwardCalculation(x)
        x = x.view(s, b, -1)
        return x

3 网络初始化

我们定义一个网络导出onnx ,观察 网络的具体结构

INPUT_FEATURES_NUM = 100
OUTPUT_FEATURES_NUM = 13
lstm_model = LstmRNN(INPUT_FEATURES_NUM, 16, output_size=OUTPUT_FEATURES_NUM, num_layers=2)  # 16 hidden units
print(lstm_model)
save_onnx_path= "weights/lstm_16.onnx"
input_data = torch.randn(1,150,100)

input_names = ["images"] + ["called_%d" % i for i in range(2)]
output_names = ["prob"]
torch.onnx.export(
    lstm_model,
    input_data,
    save_onnx_path,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    opset_version=12
    )

在这里插入图片描述
可以看到 LSTM W 是1x64x100;这个序列150没有了 是不是说150序列是一次一次的送的呢,所以在网络中没有体现;16是hidden,LSTM里面的W是64,这里存在一个4倍的关系。
我想这个关系和LSTM的3个门(输入+输出+遗忘+C^)有联系。
在这里插入图片描述
在这里插入图片描述
这里输出我们设置的13,如图 onnx 网络结构可视化显示也是13,至于这个150,或许就是输入有150个词,输出也是150个词吧。

在这里插入图片描述
至于LSTM的层数设置为2,则表示有2个LSTM堆叠。
在这里插入图片描述

4 网络提取

另外提取 网络方便看 每一层的维度,代码如下。

import onnx
from onnx import helper, checker
from onnx import TensorProto
import re
import argparse
model = "./weights/lstm_16.onnx"
output_model_path = "./weights/lstm_16_e.onnx"

onnx_model = onnx.load(model)
#Flatten
onnx.utils.extract_model(model, output_model_path, ['images'],['prob'])

相关文章:

  • python自动化测试之统一请求封装及通过文件实现接口关联
  • 传感器篇(一)——深度相机
  • 第一章嵌入式系统概论考点10互联网
  • 基于Spring Security 6的OAuth2 系列之十五 - 高级特性--客户端认证方式
  • 机器学习实战之基于随机森林的气温预测
  • 设计模式——职责链模式
  • Maven 中的 `<dependencyManagement>` 标签及其高级用法
  • centos7安装vscode
  • MySql从入门到精通
  • qt 控件的焦点事件
  • 共享设备管理难?MDM助力Kiosk模式一键部署
  • P2704 [NOI2001] 炮兵阵地
  • 血压高吃哪些水果比较好喵?
  • VM ubuntu20.04 虚拟机与主机之间不能互相复制的解决
  • Deepseek R1模型本地化部署+API接口调用详细教程:释放AI生产力
  • 云原生时代的后端开发:架构、工具与最佳实践
  • 6 Flink Table 和相关概念
  • TCP可靠传输的ARQ协议
  • 20250214在ubuntu20.04下使用obs studio录制外挂的1080p的USB摄像头【下载安装】
  • vm vitualbox和主机ssh连接,使用net 和仅主机网卡连接
  • 财政部党组召开2025年巡视工作会议暨第一轮巡视动员部署会
  • 海运港口股掀涨停潮!回应关税下调利好,有货代称美线舱位爆了
  • 深圳市政协原副主席王幼鹏被“双开”
  • 工行回应两售出金条疑似有杂质:情况不属实,疑似杂质应为金条售出后的外部附着物
  • 上海发布首份直播电商行业自律公约,禁止虚假宣传、商业诋毁
  • 2025年度上海市住房城乡建设管理委工程系列中级职称评审工作启动