当前位置: 首页 > news >正文

基于FinRL深度强化学习框架的股票预测和回测交易

说明:项目仅限于科研研究

🧠 项目目标

通过使用 FinRL 强化学习框架,对 BTC历史数据进行训练,实现自动化的量化交易策略学习和回测比较,评估不同 DRL 模型(PPO、A2C、DDPG)的表现。


📂 项目结构和模块划分

1. 数据准备(convert_1m_to_4h, load_btc_4h_data)

✅ 输入数据:
  • 原始 BTC-USDT 1 分钟 K线数据(CSV格式)

🔁 数据重采样:
  • convert_1m_to_4h:将 1 分钟数据聚合为 4 小时数据

    • 聚合策略:

      • open: 第一个值

      • high: 最大值

      • low: 最小值

      • close: 最后一个值

      • volume: 总和

      • openinterest: 最后一个值

📖 数据加载:
  • load_btc_4h_data:加载已重采样的 4 小时数据,并标准化列名为 FinRL 所需格式(如 date, open, close, 等)


2. 特征工程(add_technical_indicators)

使用 FinRL 的内置 FeatureEngineer 处理器:

  • 添加技术指标列表 INDICATORS,例如:

    • MACD, RSI, CCI, ADX 等

  • 输出为增强的 DataFrame(包含技术指标列)


3. 环境创建(create_env)

构建强化学习交易环境 StockTradingEnv

  • 环境参数(env_kwargs)包括:

    • hmax: 最大交易单位

    • initial_amount: 初始资金

    • buy/sell_cost_pct: 手续费

    • tech_indicator_list: 技术指标列表

    • state_space: 状态维度

    • action_space: 动作空间维度(买/卖/持有)

  • 使用训练集构造训练环境 e_train_gym


4. 训练模型(train_agent)

训练强化学习模型,支持以下算法:

算法模块
PPOstable_baselines3.PPO
A2Cstable_baselines3.A2C
DDPGstable_baselines3.DDPG
  • 使用 FinRL 的 DRLAgent 封装训练逻辑

  • 设置训练步数 total_timesteps=10000

  • 日志输出到 ./log/


5. 回测与评估(run_backtest)

  • 使用训练好的模型对 测试集 进行预测

  • 评估指标:

    • 年化收益率

    • 年化波动率

    • 夏普比率

    • 最大回撤

    • 日均收益率等

  • 回测过程输出:

    • df_account_value:账户价值随时间变化

    • df_actions:模型执行的交易动作

    • perf_stats:评估指标

  • 可视化:

    • 保存每个模型的账户价值曲线图(如:results/ppo_account_value.png

    • 保存评估指标 CSV(如:results/ppo_stats.csv


6. 多模型比较(plot_comparison)

  • 将 PPO、A2C、DDPG 的账户价值曲线绘制在一张图上进行直观对比

  • 输出图表:results/model_comparison.png

  • 汇总所有模型的评估统计:

    • summary_statistics.csv:每个模型的关键指标并排展示


📊 输出成果汇总

文件名称内容描述
BTC-USDT_1m_4H.csv转换后的 4 小时 BTC 数据
results/ppo_account_value.pngPPO 模型账户价值走势图
results/ppo_stats.csvPPO 模型回测统计结果
results/model_comparison.png所有模型账户价值对比图
results/summary_statistics.csv所有模型的性能指标汇总(表格)


🧪 总结:流程概览图

数据导入(1m) ↓

4H 重采样 → 加载数据 → 特征工程(添加技术指标) ↓

划分训练集 / 测试集 ↓

创建训练环境(StockTradingEnv) ↓

循环训练多个模型(PPO, A2C, DDPG) ↓

模型回测(评估账户价值 + 性能指标) ↓

输出结果图表与统计文件

项目代码:

# btc_finrl_experiment.pyimport pandas as pd
import yfinance as yf
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3 import A2C, PPO, DDPG
from finrl.plot import backtest_stats, backtest_plot
from finrl.config import INDICATORS
from gym import register
import matplotlib.pyplot as plt
import datetime
import os
import numpy as npdef convert_1m_to_4h(input_path='../data/BTC-USDT_1m.csv', time_step='4H'):print(f"📥 正在读取 BTC 数据(无表头)文件:{input_path}")# 添加列名手动指定col_names = ['datetime', 'open', 'high', 'low', 'close', 'volume', 'openinterest']df = pd.read_csv(input_path, names=col_names, header=None, parse_dates=['datetime'])# 设置索引为 datetime,方便重采样df.set_index('datetime', inplace=True)print(f"🔁 正在将数据从 1 分钟 转换为 {time_step} 间隔...")df_resampled = df.resample(time_step).agg({'open': 'first','high': 'max','low': 'min','close': 'last','volume': 'sum','openinterest': 'last'})df_resampled.dropna(inplace=True)df_resampled.reset_index(inplace=True)# 如果没指定输出路径,自动生成base, ext = os.path.splitext(input_path)output_path = f"{base}_{time_step}.csv".replace(":", "")# 保存文件df_resampled.to_csv(output_path, index=False)print(f"✅ 已成功保存 {time_step} 数据到:{output_path}")return output_pathdef load_btc_4h_data(csv_path):print(f"📖 正在加载 4 小时 BTC 数据:{csv_path}")df = pd.read_csv(csv_path, parse_dates=['datetime'])# 重命名列,FinRL 更习惯用 Date 为时间列df.columns = ['date', 'open', 'high', 'low', 'close', 'volume', 'openinterest']df['tic'] = 'BTC'# 删除可能的空行df.dropna(inplace=True)print("✅ 加载成功,数据格式如下:")print(df.head())return df# -------------------- Step 1: Download BTC Data --------------------# def download_btc_data(cache_path='../data/btc.csv'):
#
#     if os.path.exists(cache_path):
#         print("📂 正在读取本地缓存数据...")
#         return pd.read_csv(cache_path, parse_dates=['Date'])
#     print("本地无缓存数据,正在从 CoinGecko 下载 BTC 历史数据...")
#     end = datetime.datetime.now()
#     start = end - datetime.timedelta(days=365 * 3)
#
#     try:
#         df = yf.download('BTC-USD', start=start, end=end, threads=False, auto_adjust=True)
#
#         # 检查是否成功下载
#         if df is None or df.empty or not isinstance(df, pd.DataFrame):
#             raise ValueError("❌ 下载失败,未获取到有效 BTC 数据。")
#
#         # 清理数据
#         df = df.reset_index()
#         df.columns = [c.replace(' ', '_') for c in df.columns]
#         df.dropna(inplace=True)
#         print("✅ BTC 数据下载成功!")
#         return df
#
#     except Exception as e:
#         print(f"❌ 发生错误:{e}")
#         print("📌 可能原因:网络不通、VPN干扰、或被限流")
#         print("👉 建议:1)切换网络;2)间隔一段时间再试;3)尝试使用本地缓存数据")
#         exit(1)# -------------------- Step 2: Feature Engineering --------------------def add_technical_indicators(df):from finrl.meta.preprocessor.preprocessors import FeatureEngineerfe = FeatureEngineer(use_technical_indicator=True, tech_indicator_list=INDICATORS)df = fe.preprocess_data(df)return df# -------------------- Step 3: Create Environment --------------------def create_env(df):# 股票数量(BTC 视为一只“股票”)stock_dimension = 1# 状态空间的大小:# 1项账户余额 + 2项每支股票的状态(持仓股数、当前价格) + 每支股票的技术指标数量state_space = 1 + 2 * stock_dimension + len(INDICATORS) * stock_dimension# 构造环境所需参数env_kwargs = {"hmax": 10,  # 每次交易最大买/卖的单位数量(如最多买10个单位BTC)"initial_amount": 100000,  # 初始账户余额"buy_cost_pct": [0.001],  # 买入手续费(千分之一),用列表形式支持多股票"sell_cost_pct": [0.001],  # 卖出手续费(千分之一)"state_space": state_space,  # 状态空间大小"stock_dim": stock_dimension,  # 股票维度(BTC 为1)"tech_indicator_list": INDICATORS,  # 技术指标列表(如MACD、RSI等)"action_space": stock_dimension,  # 动作空间大小(买、卖、持有每只股票)"reward_scaling": 1e-4,  # 奖励缩放因子,避免过大的值干扰学习"num_stock_shares": [0] * stock_dimension  # 初始每只股票持仓数量(BTC 持仓为 0)}# 初始化自定义的交易环境(基于 OpenAI Gym)e_train_gym = StockTradingEnv(df=df, **env_kwargs)# 返回环境对象和参数(便于评估和测试使用)return e_train_gym, env_kwargs# -------------------- Step 4: Train Agent --------------------def train_agent(e_train_gym, model_name, model_kwargs=None):# 初始化强化学习代理,传入训练环境 e_train_gymagent = DRLAgent(env=e_train_gym)# 获取指定算法名称的模型(如 'ppo', 'a2c', 'ddpg' 等),# 可通过 model_kwargs 传入模型参数配置model = agent.get_model(model_name, model_kwargs=model_kwargs, tensorboard_log="./log",)# 使用强化学习代理训练模型,设置训练步数(这里是10000步),# 并指定 TensorBoard 日志名称为模型名,方便训练过程监控trained_model = agent.train_model(model=model, tb_log_name=model_name, total_timesteps=10000)# 返回训练好的模型,后续可以用来回测或进一步操作return trained_model# -------------------- Step 5: Backtest --------------------def run_backtest(df, env_kwargs, trained_model, algo_name):# 创建测试环境,传入数据df和环境参数env_kwargs初始化StockTradingEnve_test_gym = StockTradingEnv(df=df, **env_kwargs)# 创建DRL代理实例,绑定测试环境agent = DRLAgent(env=e_test_gym)# 使用训练好的模型和测试环境进行预测,得到账户价值和交易动作的DataFramedf_account_value, df_actions = agent.DRL_prediction(model=trained_model, environment=e_test_gym)# 计算回测性能指标,传入账户价值的DataFrame,value_col_name指明列名perf_stats = backtest_stats(df_account_value, value_col_name="account_value")# 绘制账户价值随时间变化图fig, ax = plt.subplots()  # 新建图表和坐标轴对象# 将账户价值数据设为索引为日期后绘图,标题包含算法名称(大写)df_account_value.set_index("date")['account_value'].plot(ax=ax, title=f"Account Value - {algo_name.upper()}")plt.xlabel("Date")   # X轴标签为日期plt.ylabel("Value")  # Y轴标签为账户价值plt.grid(True)       # 显示网格线fig.tight_layout()   # 自动调整子图参数,使图像更紧凑不重叠# 如果结果目录不存在,则创建os.makedirs("results", exist_ok=True)# 保存账户价值走势图为PNG文件,文件名中包含算法名plt.savefig(f"results/{algo_name}_account_value.png")plt.close()  # 关闭图表,释放内存# 将回测性能指标转为DataFrame并转置,使指标成为列名stats_df = pd.DataFrame(perf_stats).T# 保存性能指标CSV文件stats_df.to_csv(f"results/{algo_name}_stats.csv")# 返回账户价值数据,交易动作数据,和性能指标return df_account_value, df_actions, perf_stats# -------------------- Compare Performance Plot --------------------def plot_comparison(results):plt.figure(figsize=(10, 6))for algo, data in results.items():df = data['account_value']plt.plot(df['date'], df['account_value'], label=algo.upper())plt.title("Account Value Comparison")plt.xlabel("Date")plt.ylabel("Portfolio Value")plt.legend()plt.grid(True)plt.tight_layout()plt.savefig("results/model_comparison.png")plt.close()# -------------------- Main Script --------------------def main():BTC1m_data_path = '../data/BTC-USDT_1m.csv'BTC4h_data_path = '../data/BTC-USDT_1m_4H.csv'print("[0] 下载 BTC 数据...")# df = download_btc_data()if BTC4h_data_path and os.path.exists(BTC4h_data_path):print(f"📂 已找到 4 小时数据本地缓存:{BTC4h_data_path}")else:print(f"📥 正在从 1 分钟数据转换为 4 小时数据...")BTC4h_data_path = convert_1m_to_4h(input_path=BTC1m_data_path, time_step='4H')df = load_btc_4h_data(csv_path=BTC4h_data_path)split_index = int(len(df) * 0.8)train_df = df[:split_index]  # 前 80% 数据用于训练test_df = df[split_index:]  # 后 20% 数据用于回测print("[2] 添加技术指标...")   #训练集和测试集都需要执行train_df = add_technical_indicators(train_df)test_df = add_technical_indicators(test_df)print("[3] 创建训练环境...")e_train_gym, env_kwargs = create_env(train_df)   # e_train_gym是使用训练集构建的训练环境,env_kwargs是环境参数results = {}for algo in ["ppo", "a2c", "ddpg"]:print(f"[4] 训练模型: {algo.upper()}...")model = train_agent(e_train_gym, algo)   # 使用e_train_gym训练模型,返回训练好的模型对象print(f"[5] 回测模型: {algo.upper()}...")df_account_value, df_actions, perf_stats = run_backtest(test_df, env_kwargs, model, algo)results[algo] = {"account_value": df_account_value,"actions": df_actions,"stats": perf_stats}print("\n[6] 所有模型训练与回测完成,生成比较图与CSV。")plot_comparison(results)# 保存所有评估结果汇总为一个 CSVsummary_stats = pd.DataFrame({algo.upper(): pd.DataFrame(stats).iloc[:, 0] for algo, stats in results.items()})summary_stats.to_csv("results/summary_statistics.csv")if __name__ == "__main__":main()

输出示例:

参考项目:

使用 FinRL 框架进行股票预测和回测交易-CSDN博客

FinRL项目深度解析:基于深度强化学习的投资组合分配实战指南-CSDN博客

http://www.dtcms.com/a/279331.html

相关文章:

  • 迁移学习:知识复用的智能迁移引擎 | 从理论到实践的跨域赋能范式
  • 什么是神经网络,常用的神经网络,如何训练一个神经网络
  • python 循环遍历取出偶数
  • 「日拱一码」027 深度学习库——PyTorch Geometric(PyG)
  • MCP基础知识二(实战通信方式之Streamable HTTP)
  • 【CTF学习】PWN基础工具的使用(binwalk、foremost、Wireshark、WinHex)
  • ewdyfdfytty
  • LangChain教程——文本嵌入模型
  • 20250714让荣品RD-RK3588开发板在Android13下长按关机
  • Debezium日常分享系列之:提升Debezium性能
  • 制造业实战:数字化集采如何保障千种备件“不断供、不积压”?
  • 16.避免使用裸 except
  • MFC扩展库BCGControlBar Pro v36.2新版亮点:可视化设计器升级
  • 计算机毕业设计Java轩辕购物商城管理系统 基于 SpringBoot 的轩辕电商商城管理系统 Java 轩辕购物平台管理系统设计与实现
  • 面向对象的设计模式
  • 【数据结构】树(堆)·上
  • js的局部变量和全局变量
  • 测试驱动开发(TDD)实战:在 Spring 框架实现中践行 “红 - 绿 - 重构“ 循环
  • Bash vs PowerShell | 从 CMD 到跨平台工具:Bash 与 PowerShell 的全方位对比
  • vue3 服务端渲染时请求接口没有等到数据,但是客户端渲染是请求接口又可以得到数据
  • 7.14 map | 内存 | 二维dp | 二维前缀和
  • python+Request提取cookie
  • 电脑升级Experience
  • python transformers笔记(Trainer类)
  • 代码随想录算法训练营第三十五天|416. 分割等和子集
  • LLM表征工程还有哪些值得做的地方
  • 内部文件审计:企业文件服务器审计对网络安全提升有哪些帮助?
  • 防火墙技术概述
  • Qt轮廓分析设计+算法+避坑
  • Redis技术笔记-主从复制、哨兵与持久化实战指南