当前位置: 首页 > news >正文

多源信息融合智能投资【“图神经网络+强化学习“的融合架构】【低配显卡正常运行】

模型架构思路简述

本模型采用"图神经网络+强化学习"的融合架构,核心思路是通过多源信息融合进行智能投资决策,并实现决策可解释性。架构设计分为三个关键层次:


1. 信息融合层(图神经网络核心)
  • 节点特征:每只股票作为图节点,包含:
    • 单股历史行情(价格、成交量等)
    • 基本面数据(PE、PB等)
  • 边关系:构建三种连接关系
    • 同行业关联(行业分类)
    • 竞争关系(业务相似度)
    • 供应链关系(上下游企业)
  • 全局特征:独立输入层处理
    • 大盘指数(上证、深证)
    • 舆情情感分析
    • 政策事件向量
    • 宏观经济指标

处理流程

个股特征
GNN卷积层
行业关系
竞争关系
全局特征
特征融合
决策输出

2. 决策生成层(强化学习核心)
  • 双头输出机制

    • 动作头:生成4类操作
      • 0: 持有
      • 1: 买入
      • 2: 卖出
      • 3: 做空
    • 解释头:输出5类预设原因
      1. 技术指标信号
      2. 行业趋势变化
      3. 政策利好/利空
      4. 市场情绪转向
      5. 估值水平异动
  • 状态-动作映射

    股票特征 + 图结构 + 全局特征 => 联合嵌入 => 动作概率分布
    

3. 训练优化层
  • REINFORCE策略梯度
    • 奖励设计:投资组合收益率变化
    • 探索机制:熵正则化项
  • 多目标优化
    • 主目标:最大化累积奖励
    • 辅助目标:动作原因可解释性

创新设计亮点
  1. 时空特征融合

    • 空间维度:GNN捕捉股票间关联
    • 时间维度:LSTM处理历史序列(可扩展)
  2. 可解释性机制

    • 预设原因标签与特征关联
    • 决策时可输出具体依据
  3. 轻量化设计

    • 参数共享:所有股票共用同一GNN
    • 维度控制:适配4GB显存限制
  4. 市场适应能力

    • 全局特征通道响应政策变化
    • 行业关系图动态更新机制

工作流程
市场环境GNN处理器策略网络交易系统提供多源数据生成联合特征向量输出动作+原因执行交易指令返回收益反馈策略梯度更新市场环境GNN处理器策略网络交易系统

该架构实现了从市场信息感知到投资决策生成的端到端学习,同时满足可解释性要求,为A股市场的复杂动态环境提供了适配性强的解决方案。

代码实现

程序运行输出

Epoch 1/7, Loss: -8439.5590, Reward: -74.6602
Epoch 2/7, Loss: -7382.1190, Reward: 171.4712
Epoch 3/7, Loss: -7427.5151, Reward: -20.3877
Epoch 4/7, Loss: -6279.8277, Reward: 107.2520
Epoch 5/7, Loss: -5892.7854, Reward: -17.9236
Epoch 6/7, Loss: -7332.6072, Reward: 21.7527
Epoch 7/7, Loss: -7767.7623, Reward: 28.9665
测试轮次奖励: 0.5648
测试轮次奖励: 0.9061
测试轮次奖励: 0.3593
测试轮次奖励: -2.5653
测试轮次奖励: 0.4737
平均测试奖励: -0.0523

代码正文

###
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import gym
from gym import spaces# ==================== 常量定义 ====================
N_STOCKS = 50  # 股票数量
N_FEATURES = 20  # 每只股票的特征维度
GLOBAL_FEAT_DIM = 10  # 全局特征维度
GNN_HIDDEN_DIM = 64  # GNN隐藏层维度
FC_HIDDEN_DIM = 128  # 全连接层隐藏维度
ACTION_DIM = 4  # 动作维度 (0:持有, 1:买入, 2:卖出, 3:做空)
N_EPOCHS = 7  # 训练轮数
BATCH_SIZE = 32  # 批次大小
LR = 1e-3  # 学习率
GAMMA = 0.99  # 折扣因子
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'# ==================== 图神经网络模型 ====================
class GNNPolicy(nn.Module):def __init__(self):super(GNNPolicy, self).__init__()# 图卷积层self.conv1 = GCNConv(N_FEATURES, GNN_HIDDEN_DIM)self.conv2 = GCNConv(GNN_HIDDEN_DIM, GNN_HIDDEN_DIM)# 全局特征处理器self.global_fc = nn.Sequential(nn.Linear(GLOBAL_FEAT_DIM, FC_HIDDEN_DIM),nn.ReLU())# 动作决策层self.action_head = nn.Sequential(nn.Linear(GNN_HIDDEN_DIM + FC_HIDDEN_DIM, FC_HIDDEN_DIM),nn.ReLU(),nn.Linear(FC_HIDDEN_DIM, ACTION_DIM))# 原因解释层self.reason_head = nn.Sequential(nn.Linear(GNN_HIDDEN_DIM + FC_HIDDEN_DIM, FC_HIDDEN_DIM),nn.ReLU(),nn.Linear(FC_HIDDEN_DIM, 5)  # 5种预设原因类型)def forward(self, x, edge_index, global_feat):# 图神经网络处理x = F.relu(self.conv1(x, edge_index))  # (N_STOCKS, GNN_HIDDEN_DIM)x = F.relu(self.conv2(x, edge_index))  # (N_STOCKS, GNN_HIDDEN_DIM)# 全局特征处理global_feat = self.global_fc(global_feat)  # (BATCH_SIZE, FC_HIDDEN_DIM)global_feat = global_feat.unsqueeze(1).repeat(1, N_STOCKS, 1)  # (BATCH_SIZE, N_STOCKS, FC_HIDDEN_DIM)# 拼接特征x = x.unsqueeze(0)  # (1, N_STOCKS, GNN_HIDDEN_DIM)combined = torch.cat([x, global_feat], dim=-1)  # (BATCH_SIZE, N_STOCKS, GNN_HIDDEN_DIM+FC_HIDDEN_DIM)# 生成动作和原因action_logits = self.action_head(combined)  # (BATCH_SIZE, N_STOCKS, ACTION_DIM)reason_logits = self.reason_head(combined)  # (BATCH_SIZE, N_STOCKS, 5)return action_logits, reason_logits# ==================== 强化学习环境 ====================
class StockTradingEnv(gym.Env):def __init__(self):super(StockTradingEnv, self).__init__()# 动作空间: 每只股票4种动作self.action_space = spaces.MultiDiscrete([ACTION_DIM] * N_STOCKS)# 状态空间: 股票特征+全局特征self.observation_space = spaces.Dict({'graph': spaces.Box(low=-np.inf, high=np.inf, shape=(N_STOCKS, N_FEATURES)),'global_feat': spaces.Box(low=-np.inf, high=np.inf, shape=(GLOBAL_FEAT_DIM,)),'edge_index': spaces.Box(low=0, high=N_STOCKS-1, shape=(2, 100), dtype=np.int64)})# 初始化状态self.reset()def reset(self):# 生成随机图数据 (模拟股票关系)self.x = torch.randn(N_STOCKS, N_FEATURES)  # 股票特征self.edge_index = torch.randint(0, N_STOCKS, (2, 100))  # 随机边# 全局特征 (模拟市场环境)self.global_feat = torch.randn(GLOBAL_FEAT_DIM)return {'graph': self.x.numpy(),'global_feat': self.global_feat.numpy(),'edge_index': self.edge_index.numpy()}def step(self, actions):# 模拟奖励计算 (实际应用中替换为真实市场数据)rewards = np.random.randn(N_STOCKS) * 0.1# 模拟新状态self.x = self.x + torch.randn_like(self.x) * 0.1self.global_feat = self.global_feat + torch.randn_like(self.global_feat) * 0.05# 构建返回数据obs = {'graph': self.x.numpy(),'global_feat': self.global_feat.numpy(),'edge_index': self.edge_index.numpy()}done = False  # 简化处理,实际需要定义终止条件info = {}  # 附加信息return obs, sum(rewards), done, info# ==================== REINFORCE 算法 ====================
class REINFORCE:def __init__(self, policy):self.policy = policy.to(DEVICE)self.optimizer = optim.Adam(self.policy.parameters(), lr=LR)self.writer = SummaryWriter()def select_action(self, state):graph_data = state['graph']edge_index = state['edge_index']global_feat = state['global_feat']# 转换为张量x_tensor = torch.tensor(graph_data, dtype=torch.float).to(DEVICE)edge_tensor = torch.tensor(edge_index, dtype=torch.long).to(DEVICE)global_tensor = torch.tensor(global_feat, dtype=torch.float).unsqueeze(0).to(DEVICE)# 前向传播action_logits, reason_logits = self.policy(x_tensor, edge_tensor, global_tensor)# 采样动作action_probs = F.softmax(action_logits, dim=-1)  # (1, N_STOCKS, ACTION_DIM)actions = torch.multinomial(action_probs.view(-1, ACTION_DIM), 1).view(1, N_STOCKS)# 采样原因reason_probs = F.softmax(reason_logits, dim=-1)  # (1, N_STOCKS, 5)reasons = torch.multinomial(reason_probs.view(-1, 5), 1).view(1, N_STOCKS)return actions.cpu().numpy()[0], reasons.cpu().numpy()[0], action_logitsdef update_policy(self, rewards, log_probs, entropies):returns = []R = 0# 计算折扣回报for r in reversed(rewards):R = r + GAMMA * Rreturns.insert(0, R)returns = torch.tensor(returns).to(DEVICE)returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # 标准化# 计算损失policy_loss = []for log_prob, R in zip(log_probs, returns):policy_loss.append(-log_prob * R)# 修复: 使用stack代替cat处理标量张量policy_loss = torch.stack(policy_loss).sum()entropy_loss = torch.stack(entropies).sum()  # 熵正则化total_loss = policy_loss - 0.01 * entropy_loss# 反向传播self.optimizer.zero_grad()total_loss.backward()self.optimizer.step()return total_loss.item()def train(self, env, epochs=N_EPOCHS):for epoch in range(epochs):state = env.reset()done = Falserewards = []log_probs = []entropies = []step_count = 0max_steps = 10001  # 限制每轮最大步数以控制内存 #强化学习: 每轮步数 太大 会吃掉很多 显存 , 现在这样要500MB显存while not done and step_count < max_steps:step_count += 1# 选择动作actions, reasons, action_logits = self.select_action(state)# 执行动作next_state, reward, done, _ = env.step(actions)# 计算对数概率和熵action_probs = F.softmax(action_logits, dim=-1)  # (1, N_STOCKS, ACTION_DIM)# 确保索引张量与概率张量维度匹配action_indices = torch.tensor(actions).view(1, N_STOCKS, 1).to(DEVICE)# 收集所选动作的概率selected_probs = action_probs.gather(2, action_indices)  # (1, N_STOCKS, 1)# 计算对数概率 (对每个股票独立计算)log_prob = torch.log(selected_probs).sum()# 计算熵entropy = -(action_probs * torch.log(action_probs + 1e-8)).sum()# 存储数据rewards.append(reward)log_probs.append(log_prob)entropies.append(entropy)# 更新状态state = next_state# 更新策略loss = self.update_policy(rewards, log_probs, entropies)# TensorBoard记录self.writer.add_scalar('Loss/train', loss, epoch)self.writer.add_scalar('Reward/train', sum(rewards), epoch)print(f'Epoch {epoch+1}/{N_EPOCHS}, Loss: {loss:.4f}, Reward: {sum(rewards):.4f}')# 手动释放内存torch.cuda.empty_cache()def test(self, env, n_episodes=5):  # 减少测试轮数total_rewards = []for _ in range(n_episodes):state = env.reset()done = Falseepisode_reward = 0step_count = 0max_steps = 5  # 限制测试步数while not done and step_count < max_steps:step_count += 1actions, reasons, _ = self.select_action(state)next_state, reward, done, _ = env.step(actions)# 输出交易决策和原因for i, (action, reason) in enumerate(zip(actions, reasons)):action_str = ['持有', '买入', '卖出', '做空'][action]reason_str = ['技术指标看涨','行业趋势向好','政策利好','市场情绪积极','估值合理'][reason]# print(f"股票{i}: {action_str} - 原因: {reason_str}")episode_reward += rewardstate = next_statetotal_rewards.append(episode_reward)print(f'测试轮次奖励: {episode_reward:.4f}')# 手动释放内存torch.cuda.empty_cache()avg_reward = sum(total_rewards) / n_episodesprint(f'平均测试奖励: {avg_reward:.4f}')self.writer.add_scalar('Reward/test', avg_reward, 0)def visualize_model(self, dummy_input):# 修改为纯张量输入格式x, edge_index, global_feat = dummy_inputself.writer.add_graph(self.policy, (x, edge_index, global_feat))# ==================== 主程序 ====================
if __name__ == "__main__":# 初始化环境和策略env = StockTradingEnv()policy = GNNPolicy()agent = REINFORCE(policy)# 创建纯张量格式的虚拟输入dummy_x = torch.randn(N_STOCKS, N_FEATURES).to(DEVICE)dummy_edge_index = torch.randint(0, N_STOCKS, (2, 100)).to(DEVICE)dummy_global = torch.randn(1, GLOBAL_FEAT_DIM).to(DEVICE)# 可视化模型结构agent.visualize_model((dummy_x, dummy_edge_index, dummy_global))# 训练模型agent.train(env)# 测试模型agent.test(env)# 关闭TensorBoardagent.writer.close()
###
http://www.dtcms.com/a/296973.html

相关文章:

  • 模拟退火算法 (Simulated Annealing, SA)简介
  • JavaWeb学习打卡14(JSP内置对象及作用域)
  • ARM汇编常见伪指令及其用法示例
  • IntelliJ IDEA中管理多版本Git子模块的完整指南
  • 智慧工厂网络升级:新型 SD-WAN 技术架构与应用解析
  • 商场导航软件:3D+AI 基于Deepseek 模型的意图识别技术解析
  • BacNet 是什么?跟 LoRaWAN 的关系是什么?
  • 将JS字节流转化为对象
  • 西安交通大学XJTU 通信/信息工程大三和部分大四 实验和课程答案
  • C++哪些运算符不能被重载?
  • kubernetes集群中部署CoreDNS服务
  • day46day47 通道注意力
  • 一种基于单片机控制的太阳能电池板系统设计
  • 集训Demo6
  • 挖掘录屏宝藏:Screenity 深度解析与使用指南
  • 《计算机网络》实验报告八 加密、数字签名与证书
  • pytest测试框架
  • AUTOSAR进阶图解==>AUTOSAR_SWS_BSWGeneral
  • 【Vue学习笔记】状态管理:Pinia 与 Vuex 的使用方法与对比【附有完整案例】
  • 网络安全入门第一课:信息收集实战手册(2)
  • C语言-指针[变量指针与指针变量]
  • Java 集合框架之----ArrayList
  • Effective Modern C++ 条款16:保证const成员函数的线程安全性
  • 网址收集总结
  • 【硬件-笔试面试题】硬件/电子工程师,笔试面试题-17,(知识点:PCB布线,传输线阻抗影响因素)
  • 第一二章笔记
  • [ComfyUI] --ComfyUI 是什么?比 Stable Diffusion WebUI 强在哪?
  • 【STM32项目】智能台灯
  • 无人机保养指南
  • 深入解析Hadoop NameNode的Full GC问题、堆外内存泄漏及元数据分治策略