当前位置: 首页 > news >正文

经典深度学习模型——LSTM【原理解释 代码(以2025年美赛C题为例)】

一、LSTM 模型介绍

1. 基本概念

LSTM(Long Short-Term Memory,长短期记忆网络)是一种 循环神经网络(RNN) 的改进模型,用于解决 RNN 的长序列训练中 梯度消失或梯度爆炸 问题。LSTM 通过引入 门控机制,能够有效捕捉长期依赖信息。

2. 模型结构

LSTM 单元主要包括以下部分:

  1. 输入门 iti_tit:控制当前输入 xtx_txt 对单元状态的影响。
  2. 遗忘门 ftf_tft:控制前一时刻状态 Ct−1C_{t-1}Ct1 的保留或遗忘。
  3. 输出门 oto_tot:控制当前单元状态 CtC_tCt 对输出 hth_tht 的影响。
  4. 候选状态 C~t\tilde{C}_tC~t:对新信息的候选状态。
  5. 单元状态 CtC_tCt:贯穿时间的内部记忆,存储长期信息。
  6. 隐藏状态 hth_tht:对外输出,作为下一时刻输入的一部分。
    在这里插入图片描述

3. 数学公式

设输入为 xtx_txt,前一隐藏状态为 ht−1h_{t-1}ht1,前一单元状态为 Ct−1C_{t-1}Ct1,则 LSTM 的公式如下:

  1. 遗忘门

ft=σ(Wf⋅[ht−1,xt]+bf) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

  1. 输入门

it=σ(Wi⋅[ht−1,xt]+bi) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

C~t=tanh⁡(WC⋅[ht−1,xt]+bC) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

  1. 单元状态更新

Ct=ft⊙Ct−1+it⊙C~t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ftCt1+itC~t

  1. 输出门和隐藏状态

ot=σ(Wo⋅[ht−1,xt]+bo) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

ht=ot⊙tanh⁡(Ct) h_t = o_t \odot \tanh(C_t) ht=ottanh(Ct)

其中:

  • σ\sigmaσ 是 sigmoid 激活函数
  • ⊙\odot 表示逐元素乘法
  • Wf,Wi,WC,WoW_f, W_i, W_C, W_oWf,Wi,WC,Wobf,bi,bC,bob_f, b_i, b_C, b_obf,bi,bC,bo 为模型可学习参数

4. 输入与输出

  • 输入:时间序列数据 [x1,x2,...,xT][x_1, x_2, ..., x_T][x1,x2,...,xT],每个 xtx_txt 可以是多维特征(如运动员人数、历史奖牌数等)。

  • 输出

    • 单步预测:当前时刻输出一个值(如某国金牌数)。
    • 多步预测:整个序列预测(如未来奥运奖牌预测)。

二、案例分析:用 LSTM 预测奥运奖牌数

一、数据与特征设计

1. 输入特征(每个国家,每届奥运会)

特征说明类型
金牌数(历史)历届奥运会该国金牌数连续数值
总奖牌数(历史)历届奥运会总奖牌数连续数值
东道主效应是否为东道主(0/1)二值
运动员数量该国参赛运动员总数连续数值
优势项目数量该国在过去若干届奥运中获奖的项目数量连续数值
伟大教练效应教练影响力指标(历史金牌贡献等)连续数值

这些特征可以通过 summerOly_medal_counts.csvsummerOly_hosts.csvsummerOly_athletes.csvsummerOly_programs.csv 等文件提取。


2. 数据处理流程

  1. 时间序列构建:按国家整理历史金牌、总奖牌序列。
  2. 特征归一化:对连续特征使用 MinMaxScalerStandardScaler
  3. 序列切片:设定 seq_length,使用前 seq_length 届奥运数据预测下一届。
  4. 缺失值处理:未获奖年份填 0,缺少运动员数填均值。

二、模型设计

1. 模型结构

输入序列(seq_length 届,每届 n_features)│LSTM层 (64单元, tanh)│Dropout层 (MC Dropout, rate=0.2)│Dense层 (输出每个国家 金牌+总奖牌数)│输出
  • 输入形状(batch_size, seq_length, n_countries * n_features_per_country)
  • 输出形状(batch_size, n_countries * 2),每个国家输出金牌和总奖牌

2. 多输出设计

  • 金牌和总奖牌作为 多输出回归,可以同时预测。
  • 损失函数:MSE(均方误差)。

3. 不确定性估计

  • MC Dropout

    • 训练时使用 Dropout
    • 预测时开启 Dropout,进行多次采样
    • 计算均值和标准差,得到预测区间
  • 多模型训练(可选):训练多个模型取平均和方差


三、Python 框架示例

import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
import matplotlib.pyplot as plt# ----------------------
# 1. 数据准备
# ----------------------
# 示例读取数据
medal_df = pd.read_csv('summerOly_medal_counts.csv')
host_df = pd.read_csv('summerOly_hosts.csv')
athlete_df = pd.read_csv('summerOly_athletes.csv')countries = ['USA','CHN','GBR','FRA','JPN']
years = sorted(medal_df['Year'].unique())
seq_length = 5
n_features = 4  # 金牌, 总奖牌, 东道主, 运动员数
n_countries = len(countries)# 初始化特征矩阵
features = np.zeros((len(years), n_countries, n_features))
for i, year in enumerate(years):for j, country in enumerate(countries):row = medal_df[(medal_df['Year']==year) & (medal_df['Country']==country)]features[i,j,0] = row['Gold'].values[0] if not row.empty else 0features[i,j,1] = row['Total'].values[0] if not row.empty else 0features[i,j,2] = 1 if host_df[host_df['Year']==year]['HostCountry'].values[0]==country else 0ath_row = athlete_df[(athlete_df['Year']==year) & (athlete_df['Country']==country)]features[i,j,3] = ath_row['AthleteCount'].values[0] if not ath_row.empty else 0# 归一化
scaler = MinMaxScaler()
features_scaled = scaler.fit_transform(features.reshape(len(years), -1))
features_scaled = features_scaled.reshape(len(years), n_countries, n_features)# 构建训练集
X, y = [], []
for i in range(seq_length, len(years)):X.append(features_scaled[i-seq_length:i])y.append(features_scaled[i,:, :2])
X, y = np.array(X), np.array(y)
X_train = X.reshape(X.shape[0], seq_length, -1)
y_train = y.reshape(y.shape[0], -1)# ----------------------
# 2. 模型构建
# ----------------------
def build_model(dropout_rate=0.2):model = Sequential()model.add(LSTM(64, activation='tanh', input_shape=(seq_length, n_countries*n_features)))model.add(Dropout(dropout_rate))model.add(Dense(n_countries*2))model.compile(optimizer='adam', loss='mse')return modelmodel = build_model()
model.fit(X_train, y_train, epochs=200, batch_size=1, verbose=1)# ----------------------
# 3. MC Dropout 预测不确定性
# ----------------------
def mc_dropout_predict(model, X, n_samples=100):preds = []for _ in range(n_samples):pred = model(X, training=True).numpy()preds.append(pred)preds = np.array(preds)return preds.mean(axis=0), preds.std(axis=0)last_seq = X_train[-1].reshape(1, seq_length, -1)
mean_scaled, std_scaled = mc_dropout_predict(model, last_seq, n_samples=100)# 逆归一化
pred_mean_full = np.concatenate([mean_scaled, np.zeros((1, n_countries*(n_features-2)))], axis=1)
pred_mean = scaler.inverse_transform(pred_mean_full)[:, :n_countries*2].reshape(n_countries,2)
pred_std_full = np.concatenate([std_scaled, np.zeros((1, n_countries*(n_features-2)))], axis=1)
pred_std = scaler.inverse_transform(pred_std_full)[:, :n_countries*2].reshape(n_countries,2)# 输出预测结果和不确定性
for i,country in enumerate(countries):print(f"{country}: 金牌={pred_mean[i,0]:.1f}±{pred_std[i,0]:.1f}, 总奖牌={pred_mean[i,1]:.1f}±{pred_std[i,1]:.1f}")# ----------------------
# 4. 可视化预测区间
# ----------------------
fig, ax = plt.subplots(figsize=(10,6))
x = np.arange(n_countries)
ax.bar(x-0.15, pred_mean[:,0], width=0.3, yerr=pred_std[:,0], label='Gold', capsize=5)
ax.bar(x+0.15, pred_mean[:,1], width=0.3, yerr=pred_std[:,1], label='Total', capsize=5)
ax.set_xticks(x)
ax.set_xticklabels(countries)
ax.set_ylabel('Predicted Medal Count')
ax.set_title('2025 Predicted Olympic Medals with Uncertainty')
ax.legend()
plt.show()

四、特点总结

  1. 多特征输入:金牌、总奖牌、东道主、运动员数,可扩展到优势项目和伟大教练效应
  2. 多输出:同时预测金牌和总奖牌数
  3. 不确定性估计:MC Dropout 提供预测均值 ± 标准差
  4. 可视化:柱状图展示预测值及不确定性区间
  5. 模型评估:可使用 MSE、MAE 或排名误差(预测排名与实际排名差)

结构总结

层级作用
输入层时间序列特征(历史金牌、东道主、运动员年龄等)
LSTM 层捕捉历史依赖关系,提取时间序列模式
全连接层输出每个国家奖牌预测值
输出金牌数 / 总奖牌数(多输出)
http://www.dtcms.com/a/331001.html

相关文章:

  • FreeRTOS-C语言指针笔记
  • 【入门级-C++程序设计:13、STL 模板:栈(stack)、队 列(queue)、 链 表(list)、 向 量(vector) 等容器】
  • gitlab的ci/cd变量如何批量添加
  • 【P81 10-7】OpenCV Python【实战项目】——车辆识别、车流统计(图像/视频加载、图像运算与处理、形态学、轮廓查找、车辆统计及显示)
  • 智能清扫新纪元:有鹿机器人如何用AI点亮我们的城市角落
  • Streamlit实现Qwen对话机器人
  • CVPR 2025 | 机器人操控 | RoboGround:用“掩码”中介表示,让机器人跨场景泛化更聪明
  • GaussDB数据库架构师修炼(十六) 如何选择磁盘
  • Helm-K8s包管理(三)新建、编辑一个Chart
  • k8s+isulad 重装
  • Seata学习(三):Seata AT模式练习
  • CMake语法与Bash语法的区别
  • 解剖HashMap的put <三> JDK1.8
  • 会议系统进程池管理:初始化、通信与状态同步详解
  • 从 Notion 的水土不服到 Codes 的本土突围:研发管理工具的适性之道​
  • Apache 虚拟主机配置冲突导致 404 错误的排查总结
  • [机器学习]08-基于逻辑回归模型的鸢尾花数据集分类
  • AXI GPIO 2——ZYNQ学习笔记
  • 力扣top100(day03-02)--图论
  • Java 技术栈中间件优雅停机方案设计与实现全景图
  • 【JavaEE】多线程 -- 线程状态
  • 数据结构之顺序表相关算法题
  • 【数据分享】351个地级市农业相关数据(2013-2022)-有缺失值
  • linux中date命令
  • SAP-ABAP:SAP消息系统深度解析:架构设计与企业级应用实践
  • Wireshark中捕获的大量UDP数据
  • 23.Linux : ftp服务及配置详解
  • (论文速读)DiffusionDet - 扩散模型在目标检测中的开创性应用
  • AI搜索重构下的GEO优化服务商格局观察
  • 李沐-第六章-LeNet训练中的pycharm jupyter-notebook Animator类的显示问题