当前位置: 首页 > news >正文

pytorch LSTM 结构详解

最近项目用到了LSTM ,但是对LSTM 的输入输出不是很理解,对此,我详细查找了lstm 的资料

import torch.nn as nnclass LSTMModel(nn.Module):def __init__(self, input_size=1, hidden_size=50, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, 1)  # 1 表示预测输出变量为1def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)out, _ = self.lstm(x, (h0, c0))out = self.fc(out[:, -1, :])return out # out 形状为(batch_size,1)
  • input_size=1:输入特征的维度,适用于单变量时间序列。

  • hidden_size=50:LSTM 隐藏层的维度,决定了模型的记忆能力。

  • num_layers=2:堆叠的 LSTM 层数,增加层数可以提升模型的表达能力。

  • batch_first=True:指定输入和输出的张量形状为 (batch_size, seq_len, input_size)

  • self.fc:一个全连接层,将 LSTM 的输出映射到最终的预测值。

  • batch_size 表示批次、seq_len 表示窗口大小、input_size 表示输入尺寸,单变量输入为1 ,多变量要基于个数变化

  • 初始化隐藏状态和细胞状态

    • h0c0 分别表示初始的隐藏状态和细胞状态,形状为 (num_layers, batch_size, hidden_size)

    • 在每次前向传播时,初始化为零张量。

  • LSTM 层处理

    • self.lstm(x, (h0, c0)):将输入 x 和初始状态传入 LSTM 层,输出 out 和新的状态。

    • out 的形状为 (batch_size, seq_len, hidden_size),包含了每个时间步的输出。

  • 全连接层映射

    • out[:, -1, :]:提取序列中最后一个时间步的输出。

    • self.fc(...):将提取的输出通过全连接层,得到最终的预测结果。

相关文章:

  • 安卓新建项目时,Gradle下载慢下载如何用国内的镜像
  • 【博客系统】博客系统第四弹:令牌技术
  • 【python深度学习】Day34 GPU训练及类的call方法
  • 智能指针
  • 科研经验贴:AI领域的研究方向总结
  • DAO模式
  • Java转Go日记(五十六):gin 渲染
  • 提高 Maven 项目的编译效率
  • 大厂技术大神远程 3 年,凌晨 1 点到 6 点竟开会 77 次。同事一脸震惊,网友:身体还扛得住吗?
  • matlab时间反转镜算法
  • Appium+python自动化(四)- 如何查看程序所占端口号和IP
  • 动态防御体系实战:AI如何重构DDoS攻防逻辑
  • 交安安全员:交通工程安全领域的关键角色
  • DB-GPT扩展自定义Agent配置说明
  • 同为科技领军智能电源分配单元技术,助力物联网与计量高质量发展
  • Linux安装Nginx并配置转发
  • WPF性能优化之延迟加载(解决页面卡顿问题)
  • 园区/小区执法仪部署指南:ZeroNews低成本+高带宽方案”
  • 实时操作系统革命:实时Linux驱动的智能时代底层重构
  • EasyExcel使用
  • 微信网站结构/怎么让百度收录我的网站
  • 网站访问量什么意思/产品软文范例800字
  • 深圳城乡和建设局网站首页/站长统计幸福宝下载
  • 彩票引流推广方法/亚马逊关键词优化怎么做
  • 网站建设添加汉语/东莞营销推广公司
  • 网站在工信部备案如何做/网络营销的基本方法有哪些