当前位置: 首页 > news >正文

2025-11-15 学习记录--Python-LSTM模型定义(PyTorch)

LSTM模型定义(PyTorch

  • LSTM(Long Short-Term Memory)长短期记忆网络
    是 RNN(循环神经网络)的一种改进版本,主要用来解决 时间序列预测需要记住过去信息的任务,例如:👇🏻
    • PM2.5 时间序列预测
    • 文本生成
    • 股票预测
    • 温度预测
    • 电力负载预测
  • 普通 RNN 的问题是运行久了就遗忘前面的信息(梯度消失),而 LSTM 通过 “门结构(gates)” 让网络能够选择:👇🏻
    • 记住(Keep)
    • 忘掉(Forget)
    • 更新(Update)
  • 这些信息。
# LSTM 模型定义(PyTorch)
# ---------------------------------------------------------
# 本文件实现一个简单的单层 LSTM 回归模型,用于预测下一小时 PM2.5。
# 输入维度: 24(窗口长度)
# 输出维度: 1(预测未来1小时 PM2.5)
# ---------------------------------------------------------import torch  # 导入 PyTorch 主包,用于张量运算与设备管理
import torch.nn as nn  # 导入神经网络模块的子包,习惯性重命名为 nn# 定义一个继承自 nn.Module 的 LSTM 模型类
class LSTMModel(nn.Module):def __init__(self, input_size=1, hidden_size=64, num_layers=1, dropout=0.0):super(LSTMModel, self).__init__()  # 调用父类构造函数,初始化模块内部状态# 定义一个 LSTM 层# input_size: 每个时间步的特征维度(这里每小时只有一个 PM2.5 值,所以是 1)# hidden_size: LSTM 隐藏态的维度(即每个时间步输出向量的长度)# num_layers: LSTM 堆叠层数(几层 LSTM 单元叠在一起)# batch_first=True: 输入/输出张量的形状为 (batch, seq_len, feature)# dropout: 当 num_layers>1 时,层间 dropout 的概率self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True,dropout=dropout)# 定义一个线性全连接层,把 LSTM 的最后隐藏态映射为预测值# 输入维度 hidden_size -> 输出维度 1(回归预测一个数)self.fc = nn.Linear(hidden_size, 1)def forward(self, x):# forward 定义前向传播逻辑,x 是模型输入张量# 期望 x.shape = (batch_size, seq_len=24, feature=1)out, _ = self.lstm(x)  # 把输入传入 LSTM,out 为每个时间步的输出(shape=(batch, seq_len, hidden_size))# 第二个返回值是 (h_n, c_n) —— 最后一个时间步的隐状态与细胞状态,这里用 _ 忽略它# 取 LSTM 输出序列中最后一个时间步的输出作为序列级特征# out[:, -1, :] 的形状为 (batch_size, hidden_size)out = out[:, -1, :]# 把最后时间步的隐藏向量通过全连接层映射为标量预测值# 最终 out 的形状为 (batch_size, 1)out = self.fc(out)return out  # 返回预测结果(未做激活,回归任务通常直接输出实数)
http://www.dtcms.com/a/613306.html

相关文章:

  • PLB-TV 4K+H.265 编码,无广告超流畅
  • Transformer结构完全解读:从Attention到LLM
  • 【ZeroRange WebRTC】REMB(Receiver Estimated Maximum Bitrate)技术深度分析
  • sharding-jdbc 绑定表
  • 郑州网站制作wordpress 密码失败
  • Dify-Token 应用实现
  • webRTC:流程和socket搭建信令服务器
  • PoA 如何把 CodexField 从“创作平台”推向“内容经济网络”
  • 厦门 外贸商城网站建设网站推广哪个好
  • 小米Java开发校园招聘面试题及参考答案
  • 哪个网站做头像比较好网片式防护围栏
  • LangChain Memory 使用示例
  • 【剑斩OFFER】算法的暴力美学——寻找数组的中心下标
  • APIs---Day01
  • 猪只行为状态识别与分类:基于YOLO13-C3k2-ESC模型的实现与优化_3
  • 宁波网站建设方案推广公司网站设计
  • [智能体设计模式] 第10章:模型上下文协议(MCP)
  • 使用docker-composer安装MySQL8、Redis7、minio脚本
  • linux的nginx版本升级
  • 支持selenium的chrome driver更新到142.0.7444.162
  • 【 Java八股文面试 | JVM篇 内存结构、类加载、垃圾回收与性能调优 】
  • 网站开发和前端是一样吗化妆品网站模板
  • Mujoco 机械臂进行 PBVS 基于位置的视觉伺服思路
  • 【玄机靶场】Crypto-常见编码
  • 360加固 APK 脱壳研究:安全工程师视角下的防护与还原原理解析
  • AI面试速记
  • ASC学习笔记0018:返回属性集实例的引用(如果此组件中存在)
  • SpringBoot中整合RabbitMQ(测试+部署上线 最完整)
  • 第15章 并发编程
  • 【高级机器学习】 13. 因果推断