Design and Development of a Stock Trading Agent System
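Abstract
This paper describes the design and implementation of a Python-based stock trading agent system. The system integrates several modules into a complete, automated decision pipeline for stock trading: data acquisition, preprocessing, feature engineering, machine-learning model training, trading-strategy formulation and risk control. It adopts a reinforcement-learning framework combined with deep neural networks. The aim is to learn trading strategies from historical data and to adapt to changing market conditions. The paper covers the system architecture, the key implementation techniques, performance evaluation and directions for future improvement. It is intended as a practical reference for research and practice in quantitative trading.
Keywords: stock trading agent, quantitative trading, reinforcement learning, Python, machine learning
1. Introduction
1.1 Research Background
Artificial intelligence is developing rapidly and financial market data is increasingly plentiful. As a result, intelligent algorithms are being applied more and more widely in investment. Traditional stock analysis relies mainly on technical indicators and fundamental analysis. Modern quantitative trading instead uses the data-processing power of computers to uncover market regularities through mathematical models and statistical methods. In recent years in particular, deep learning and reinforcement learning have shown considerable promise for two tasks: forecasting financial time series and optimizing trading strategies.
1.2 Significance
Developing a stock trading agent system has both theoretical and practical value:
- Trading efficiency: an automated system can monitor the market continuously throughout trading hours and respond quickly to trading signals. It also removes emotional bias from trading decisions.
- Pattern discovery: machine-learning algorithms can uncover complex patterns in large volumes of historical data that humans would struggle to identify.
- Risk control: the system can compute risk metrics in real time and strictly enforce stop-loss rules, reducing investment risk.
- Strategy validation: strategies can be backtested on historical data to assess their effectiveness before any capital is put at risk in live trading.
1.3 Related Work
Quantitative trading developed earlier outside China. Hedge funds such as Renaissance Technologies and Two Sigma have long applied machine learning to live trading. In China, many quantitative trading platforms and research teams have emerged in recent years. Deep-learning models such as LSTM and Transformer have achieved some success in stock-price prediction. Reinforcement learning is also widely used for portfolio optimization and trade execution.
2. System Design
2.1 System Architecture
The stock trading agent system uses a modular design with the following main components: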
```
+-------------------+     +-------------------+     +-------------------+
| Data Acquisition  | --> |  Data Processing  | --> |Feature Engineering|
+-------------------+     +-------------------+     +-------------------+
                                                              |
                                                              v
+-------------------+     +-------------------+     +-------------------+
|  Model Training   | <-- | Strategy Decision | --> |   Risk Control    |
+-------------------+     +-------------------+     +-------------------+
                                                              |
                                                              v
                                                    +-------------------+
                                                    |  Trade Execution  |
                                                    +-------------------+
```
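2.2 Technology Stack
- Programming language: Python 3.8+ (rich ecosystem for finance and machine learning)
- Data acquisition: Tushare/AkShare (Chinese A-share data), yfinance (international data)
- Data processing: Pandas, NumPy
- Machine learning: Scikit-learn, TensorFlow/PyTorch
- Reinforcement learning: Stable Baselines3, Ray RLlib
- Visualization: Matplotlib, Plotly, PyQt5
- Backtesting: Backtrader, Zipline
- Concurrency: Asyncio, Celery
2.3 System Workflow
- Data collection and storage
- Data cleaning and preprocessing
- Feature extraction and selection
- Model training and validation
- Strategy generation and optimization
- Risk control and execution
- Performance evaluation and feedback
3. Data Module Implementation
3.1 Data Acquisition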
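```python
import os
from datetime import datetime, timedelta

import akshare as ak
import pandas as pd


class DataFetcher:
    def __init__(self, cache_dir="./data_cache"):
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def get_stock_daily(self, symbol, start_date, end_date, adjust="hfq"):
        """Fetch daily bars for a single stock.

        :param symbol: stock code; stock_zh_a_daily takes a code with an exchange
                       prefix such as "sh600519" (verify against the AkShare docs)
        :param start_date: start date, e.g. "20200101" (AkShare uses YYYYMMDD)
        :param end_date: end date, same format as start_date
        :param adjust: price adjustment: "" = unadjusted, "hfq" = backward-adjusted,
                       "qfq" = forward-adjusted
        :return: DataFrame, or None if the request fails
        """
        cache_file = os.path.join(
            self.cache_dir, f"{symbol}_{start_date}_{end_date}_{adjust}.pkl")
        # Serve repeated requests from the local pickle cache
        if os.path.exists(cache_file):
            return pd.read_pickle(cache_file)
        try:
            df = ak.stock_zh_a_daily(symbol=symbol, start_date=start_date,
                                     end_date=end_date, adjust=adjust)
            df.to_pickle(cache_file)
            return df
        except Exception as e:
            print(f"Error fetching data for {symbol}: {e}")
            return None

    def get_index_daily(self, symbol, start_date, end_date):
        """Fetch index data."""
        # To be implemented along the same lines as get_stock_daily
        pass

    def get_financial_report(self, symbol, report_type="balance"):
        """Fetch financial-statement data."""
        # To be implemented
        pass
```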
3.2 Data Preprocessing
```python
import pandas as pd


class DataPreprocessor:
    @staticmethod
    def clean_data(df):
        """Clean raw daily bars."""
        df = df.copy()
        # Use the date column as the index when the data source returns one
        if 'date' in df.columns:
            df = df.set_index('date')
        # Drop rows with missing values
        df = df.dropna()
        # Drop invalid rows (non-positive prices or volume). Trimming the top 1%
        # of each column would also discard genuine price highs and leave gaps.
        cols = ['open', 'high', 'low', 'close', 'volume']
        df = df[(df[cols] > 0).all(axis=1)]
        # Normalise and sort the date index
        df.index = pd.to_datetime(df.index)
        return df.sort_index()

    @staticmethod
    def add_technical_indicators(df):
        """Add technical indicators."""
        df = df.copy()
        # Moving averages
        for window in (5, 10, 20):
            df[f'ma{window}'] = df['close'].rolling(window=window).mean()
        # Bollinger Bands
        df['upper_band'], df['middle_band'], df['lower_band'] = \
            DataPreprocessor._bollinger_bands(df['close'])
        # MACD
        df['macd'], df['macd_signal'], df['macd_hist'] = \
            DataPreprocessor._macd(df['close'])
        # RSI
        df['rsi'] = DataPreprocessor._rsi(df['close'], period=14)
        return df.dropna()

    @staticmethod
    def _bollinger_bands(series, window=20, num_std=2):
        rolling_mean = series.rolling(window=window).mean()
        rolling_std = series.rolling(window=window).std()
        upper_band = rolling_mean + rolling_std * num_std
        lower_band = rolling_mean - rolling_std * num_std
        return upper_band, rolling_mean, lower_band

    @staticmethod
    def _macd(series, fast=12, slow=26, signal=9):
        ema_fast = series.ewm(span=fast, adjust=False).mean()
        ema_slow = series.ewm(span=slow, adjust=False).mean()
        macd_line = ema_fast - ema_slow
        signal_line = macd_line.ewm(span=signal, adjust=False).mean()
        macd_hist = macd_line - signal_line
        return macd_line, signal_line, macd_hist

    @staticmethod
    def _rsi(series, period=14):
        # RSI with simple rolling averages of gains and losses (Cutler's variant)
        delta = series.diff(1)
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)
        avg_gain = gain.rolling(window=period).mean()
        avg_loss = loss.rolling(window=period).mean()
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))
```
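The `_rsi` helper above averages gains and losses with a simple rolling mean. Wilder's original RSI works differently. It seeds the averages with the simple mean of the first `period` price changes and then smooths them recursively, so its values differ slightly from the rolling version. Below is a minimal NumPy sketch of Wilder's smoothing, offered as an alternative and not as the system's method. The function name `wilder_rsi` is our own choice, and so is the neutral value of 50 for a perfectly flat window.

```python
import numpy as np


def wilder_rsi(close, period=14):
    """RSI with Wilder's smoothing; the first `period` entries are NaN."""
    close = np.asarray(close, dtype=float)
    deltas = np.diff(close)
    gains = np.clip(deltas, 0.0, None)
    losses = np.clip(-deltas, 0.0, None)
    rsi = np.full(close.shape, np.nan)
    if len(deltas) < period:
        return rsi  # not enough history for a single value

    # Seed with the simple mean of the first `period` price changes
    avg_gain = gains[:period].mean()
    avg_loss = losses[:period].mean()
    for i in range(period, len(close)):
        if i > period:
            # Wilder's recursive smoothing (alpha = 1 / period)
            avg_gain = (avg_gain * (period - 1) + gains[i - 1]) / period
            avg_loss = (avg_loss * (period - 1) + losses[i - 1]) / period
        if avg_loss == 0:
            # No losses: 100 if prices rose, 50 (neutral) for a flat window
            rsi[i] = 100.0 if avg_gain > 0 else 50.0
        else:
            rsi[i] = 100.0 - 100.0 / (1.0 + avg_gain / avg_loss)
    return rsi
```

Every defined value is 100 for a steadily rising series and 0 for a steadily falling one. In pandas, `ewm(alpha=1/period, adjust=False)` gives a close approximation, but it seeds the average with the first value instead of a simple mean.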
3.3 Feature Engineering
```python
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.preprocessing import MinMaxScaler, StandardScaler


class FeatureEngineer:
    def __init__(self, n_features=20):
        self.n_features = n_features
        self.scaler = StandardScaler()
        self.selector = SelectKBest(score_func=f_regression, k=n_features)
        self.selected_features = None

    @staticmethod
    def _fill_missing(X):
        # Forward-fill only: a backward fill would copy future values into
        # the past (look-ahead bias); any leading gaps become 0
        return X.ffill().fillna(0)

    def fit_transform(self, X, y):
        """Fit scaling and feature selection on the training data."""
        X = self._fill_missing(X)
        # Feature scaling
        X_scaled = self.scaler.fit_transform(X)
        # Feature selection
        X_selected = self.selector.fit_transform(X_scaled, y)
        self.selected_features = X.columns[self.selector.get_support()]
        return X_selected

    def transform(self, X):
        """Apply the fitted transformations to new data."""
        X = self._fill_missing(X)
        return self.selector.transform(self.scaler.transform(X))

    def get_feature_names(self):
        """Names of the selected features."""
        return self.selected_features

    @staticmethod
    def create_lagged_features(df, columns, lags=5):
        """Create lagged features."""
        new_df = df.copy()
        for col in columns:
            for lag in range(1, lags + 1):
                new_df[f'{col}_lag_{lag}'] = df[col].shift(lag)
        return new_df.dropna()

    @staticmethod
    def create_rolling_features(df, columns, windows=(5, 10, 20)):
        """Create rolling-window statistics."""
        new_df = df.copy()
        for col in columns:
            for window in windows:
                rolling = df[col].rolling(window)
                new_df[f'{col}_rolling_mean_{window}'] = rolling.mean()
                new_df[f'{col}_rolling_std_{window}'] = rolling.std()
                new_df[f'{col}_rolling_max_{window}'] = rolling.max()
                new_df[f'{col}_rolling_min_{window}'] = rolling.min()
        return new_df.dropna()
```
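The section does not say which target `y` is passed to `fit_transform`. One common choice, assumed here, is the forward return over a horizon of `h` bars. The label for the last `h` rows is not yet known, so those rows must be dropped. The train/test split should also be chronological, so that the scaler and selector are fitted on the training slice only. The NumPy sketch below shows this; `forward_returns`, `align_features_and_labels` and `chronological_split` are hypothetical helpers.

```python
import numpy as np


def forward_returns(close, horizon=1):
    """r_t = close[t + horizon] / close[t] - 1 for every t with a future price."""
    if horizon < 1:
        raise ValueError("horizon must be >= 1")
    close = np.asarray(close, dtype=float)
    return close[horizon:] / close[:-horizon] - 1.0


def align_features_and_labels(X, close, horizon=1):
    """Drop the last `horizon` feature rows, whose forward return is unknown."""
    y = forward_returns(close, horizon)
    return np.asarray(X)[: len(y)], y


def chronological_split(X, y, test_ratio=0.2):
    """Split without shuffling: the test set is strictly later than the training set."""
    split = int(len(y) * (1 - test_ratio))
    return X[:split], X[split:], y[:split], y[split:]


close = [100.0, 110.0, 121.0, 133.1, 146.41]
X = np.arange(10, dtype=float).reshape(5, 2)  # one feature row per bar
X_aligned, y = align_features_and_labels(X, close, horizon=1)
X_train, X_test, y_train, y_test = chronological_split(X_aligned, y, test_ratio=0.25)
print(y.round(4), X_train.shape, X_test.shape)
```

With DataFrame inputs, `FeatureEngineer` would then be fitted with `fit_transform` on the training slice alone and applied to the later slice with `transform`.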
4. Model Module Implementation
4.1 Supervised Learning Models
```python
import lightgbm as lgb
import xgboost as xgb
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import GridSearchCV, TimeSeriesSplit
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR


class StockPredictor:
    def __init__(self, model_type='lgbm'):
        self.model_type = model_type
        self.model = None
        self.feature_importance = None

    def train(self, X_train, y_train, cv_folds=5):
        """Tune hyper-parameters with time-ordered CV and keep the best model."""
        # Each TimeSeriesSplit validation fold lies after its training data
        tscv = TimeSeriesSplit(n_splits=cv_folds)
        if self.model_type == 'random_forest':
            param_grid = {'n_estimators': [100, 200], 'max_depth': [None, 10, 20],
                          'min_samples_split': [2, 5]}
            base_model = RandomForestRegressor(random_state=42)
        elif self.model_type == 'xgboost':
            param_grid = {'n_estimators': [100, 200], 'max_depth': [3, 6, 9],
                          'learning_rate': [0.01, 0.1]}
            base_model = xgb.XGBRegressor(random_state=42)
        elif self.model_type == 'lgbm':
            param_grid = {'n_estimators': [100, 200], 'max_depth': [5, 10],
                          'learning_rate': [0.01, 0.1], 'num_leaves': [31, 63]}
            base_model = lgb.LGBMRegressor(random_state=42)
        elif self.model_type == 'mlp':
            param_grid = {'hidden_layer_sizes': [(50,), (100,), (50, 50)],
                          'activation': ['relu', 'tanh'],
                          'learning_rate_init': [0.001, 0.01]}
            base_model = MLPRegressor(random_state=42, max_iter=1000)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type}")

        grid_search = GridSearchCV(estimator=base_model, param_grid=param_grid,
                                   cv=tscv, scoring='neg_mean_squared_error',
                                   n_jobs=-1, verbose=1)
        grid_search.fit(X_train, y_train)
        self.model = grid_search.best_estimator_

        # Store feature importances (tree models only); fall back to column
        # indices when X_train is a NumPy array rather than a DataFrame
        if hasattr(self.model, 'feature_importances_'):
            names = getattr(X_train, 'columns', range(X_train.shape[1]))
            self.feature_importance = dict(zip(names, self.model.feature_importances_))
        # best_score_ is the negated mean squared error of the best configuration
        return grid_search.best_score_

    def predict(self, X):
        """Predict targets for X."""
        return self.model.predict(X)

    def evaluate(self, X_test, y_test):
        """Evaluate the model on held-out data."""
        y_pred = self.predict(X_test)
        return {'mse': mean_squared_error(y_test, y_pred),
                'mae': mean_absolute_error(y_test, y_pred),
                'r2': r2_score(y_test, y_pred)}
```
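`TimeSeriesSplit` guarantees that each validation fold comes after its training data. Suppose the target is a forward return over `h` bars. Then the last `h` training labels overlap the first validation bars, so a gap of `h` samples is often left between the two; scikit-learn's `TimeSeriesSplit` accepts a `gap` argument for this. The generator below is a hypothetical pure-Python sketch of the same expanding-window layout, written out so the indices are explicit.

```python
def expanding_window_splits(n_samples, n_splits, gap=0):
    """Yield (train_idx, test_idx) pairs; every fold trains only on earlier data."""
    test_size = n_samples // (n_splits + 1)
    if test_size == 0:
        raise ValueError("too few samples for the requested number of splits")
    first_test_start = n_samples - n_splits * test_size
    for test_start in range(first_test_start, n_samples, test_size):
        train_end = test_start - gap  # leave `gap` samples unused before the test fold
        yield list(range(train_end)), list(range(test_start, test_start + test_size))


for train_idx, test_idx in expanding_window_splits(10, n_splits=3, gap=1):
    print(train_idx, test_idx)
```

The training window grows with each fold while the test window slides forward. This mirrors how the model would be retrained over time in production.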
4.2 Reinforcement Learning Environment
```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class StockTradingEnv(gym.Env):
    """Single-stock trading environment.

    Written against the Gymnasium API expected by Stable-Baselines3 2.x:
    reset() returns (obs, info) and step() returns
    (obs, reward, terminated, truncated, info).
    """
    metadata = {'render_modes': ['human']}

    def __init__(self, df, initial_balance=100000, commission=0.0025,
                 invalid_action_penalty=1000.0):
        super().__init__()
        self.df = df  # numeric indicator columns only, including 'close'
        self.initial_balance = initial_balance
        self.commission = commission
        self.invalid_action_penalty = invalid_action_penalty
        # Action space: buy (0), hold (1), sell (2)
        self.action_space = spaces.Discrete(3)
        # Observation: technical indicators + balance, shares held, average cost
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(len(self.df.columns) + 3,), dtype=np.float32)
        self.reset()

    def reset(self, seed=None, options=None):
        """Reset the account and return (observation, info)."""
        super().reset(seed=seed)
        self.balance = float(self.initial_balance)
        self.shares_held = 0.0
        self.avg_cost = 0.0
        self.total_profit = 0.0
        self.current_step = 0
        self.prev_value = float(self.initial_balance)
        return self._next_observation(), {}

    def _next_observation(self):
        """Current indicator row plus account state."""
        obs = self.df.iloc[self.current_step].to_numpy(dtype=np.float64)
        account_info = np.array([self.balance, self.shares_held, self.avg_cost])
        return np.append(obs, account_info).astype(np.float32)

    def _take_action(self, action):
        """Execute a trade; return False if the action was impossible."""
        current_price = self.df.iloc[self.current_step]['close']
        if action == 0:  # buy: invest all available cash (fractional shares allowed)
            if self.balance <= 0:
                return False
            shares_bought = self.balance / (current_price * (1 + self.commission))
            self.avg_cost = ((self.avg_cost * self.shares_held + current_price * shares_bought)
                             / (self.shares_held + shares_bought))
            self.shares_held += shares_bought
            self.balance = 0.0  # cost including commission equals the whole balance
        elif action == 2:  # sell: close the whole position
            if self.shares_held <= 0:
                return False
            revenue = self.shares_held * current_price * (1 - self.commission)
            self.balance += revenue
            self.total_profit += revenue - self.shares_held * self.avg_cost
            self.shares_held = 0.0
            self.avg_cost = 0.0
        return True

    def step(self, action):
        """Apply an action and advance one bar."""
        valid = self._take_action(action)
        self.current_step += 1
        terminated = self.current_step >= len(self.df) - 1

        current_price = self.df.iloc[self.current_step]['close']
        portfolio_value = self.balance + self.shares_held * current_price
        # Reward = change in portfolio value over this step; cumulative P&L
        # would credit early gains again at every later step
        reward = portfolio_value - self.prev_value
        self.prev_value = portfolio_value
        # Penalise impossible actions (buying without cash, selling without shares)
        if not valid:
            reward -= self.invalid_action_penalty

        obs = self._next_observation()
        return obs, float(reward), terminated, False, {'portfolio_value': portfolio_value}

    def render(self):
        """Print the current account state."""
        current_price = self.df.iloc[self.current_step]['close']
        portfolio_value = self.balance + self.shares_held * current_price
        print(f"Step: {self.current_step}")
        print(f"Balance: {self.balance:.2f}")
        print(f"Shares held: {self.shares_held:.2f} (Avg Cost: {self.avg_cost:.2f})")
        print(f"Current Price: {current_price:.2f}")
        print(f"Portfolio Value: {portfolio_value:.2f}")
        print(f"Total Profit: {self.total_profit:.2f}")
```
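Transaction costs weigh heavily on a policy that can switch between fully invested and all cash at every step. The environment applies its proportional commission on both sides. So an all-in buy followed by an immediate sell at the same price returns a fraction (1 - c) / (1 + c) of the starting cash. The sketch below reproduces that arithmetic with the default `commission=0.0025`. `round_trip_value` and `break_even_price` are hypothetical helpers. Real A-share fees (commission, stamp duty, transfer fees) are structured differently.

```python
def round_trip_value(balance, buy_price, sell_price, commission):
    """Cash after an all-in buy at buy_price and a full sell at sell_price.

    Uses the same fractional-share, proportional-commission arithmetic as
    StockTradingEnv._take_action.
    """
    shares = balance / (buy_price * (1 + commission))
    return shares * sell_price * (1 - commission)


def break_even_price(buy_price, commission):
    """Sell price at which a round trip neither gains nor loses cash."""
    return buy_price * (1 + commission) / (1 - commission)


value = round_trip_value(100_000, 10.0, 10.0, 0.0025)
print(round(value, 2))  # prints 99501.25
print(break_even_price(10.0, 0.0025))
```

At this rate each round trip costs about 0.5% of capital. The price must therefore rise by roughly that much before switching in and out of the stock pays off.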
4.3 Reinforcement Learning Agent
```python
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv

ALGORITHMS = {'ppo': PPO, 'a2c': A2C, 'dqn': DQN}


class TradingAgent:
    def __init__(self, env, algorithm='ppo'):
        self.env = DummyVecEnv([lambda: env])
        self.algorithm = algorithm
        if algorithm == 'ppo':
            self.model = PPO('MlpPolicy', self.env, verbose=1,
                             learning_rate=3e-4, n_steps=2048, batch_size=64,
                             n_epochs=10, gamma=0.99, gae_lambda=0.95,
                             clip_range=0.2, ent_coef=0.0)
        elif algorithm == 'a2c':
            self.model = A2C('MlpPolicy', self.env, verbose=1,
                             learning_rate=7e-4, n_steps=5, gamma=0.99,
                             gae_lambda=1.0, ent_coef=0.0)
        elif algorithm == 'dqn':
            # DDPG only supports continuous (Box) action spaces, whereas the
            # trading environment uses Discrete(3); DQN handles discrete actions
            self.model = DQN('MlpPolicy', self.env, verbose=1,
                             learning_rate=1e-3, buffer_size=100000,
                             batch_size=100, gamma=0.99)
        else:
            raise ValueError(f"Unknown algorithm: {algorithm}")

    def train(self, total_timesteps=100000, callback=None):
        """Train the agent."""
        self.model.learn(total_timesteps=total_timesteps, callback=callback)

    def predict(self, obs):
        """Return the deterministic (greedy) action."""
        action, _ = self.model.predict(obs, deterministic=True)
        return action

    def save(self, path):
        """Save the model."""
        self.model.save(path)

    def load(self, path):
        """Load a model saved with the same algorithm."""
        self.model = ALGORITHMS[self.algorithm].load(path, env=self.env)


class TensorboardCallback(BaseCallback):
    """Log portfolio value and reward to TensorBoard at every step."""

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.portfolio_values = []

    def _on_step(self) -> bool:
        # The environment reports the portfolio value in its step info dict
        info = self.locals['infos'][0]
        if 'portfolio_value' in info:
            self.portfolio_values.append(info['portfolio_value'])
            self.logger.record('portfolio/value', info['portfolio_value'])
        # Reward of the first (only) environment
        self.logger.record('portfolio/reward', float(self.locals['rewards'][0]))
        return True
```
5. Trading Strategy Module
5.1 Strategy Base Class
```python
from abc import ABC, abstractmethod

import numpy as np  # used by the strategy subclasses below


class TradingStrategy(ABC):
    def __init__(self, data_handler, initial_cash=100000):
        self.data_handler = data_handler
        self.positions = {}
        self.cash = initial_cash  # starting capital
        self.portfolio_value = []
        self.trade_history = []

    @abstractmethod
    def generate_signals(self):
        """Return trading signals as {symbol: 'BUY' | 'SELL' | 'HOLD'}."""

    def execute_trades(self, signals):
        """Execute trades at the latest price (commission and slippage are ignored)."""
        for symbol, signal in signals.items():
            current_price = self.data_handler.get_latest_price(symbol)
            if signal == 'BUY' and symbol not in self.positions:
                # Allocate 10% of the remaining cash to each new position
                shares = int(self.cash * 0.1 / current_price)
                if shares > 0:
                    self.cash -= shares * current_price
                    self.positions[symbol] = {'shares': shares, 'avg_price': current_price}
                    self.trade_history.append({
                        'symbol': symbol,
                        'action': 'BUY',
                        'shares': shares,
                        'price': current_price,
                        'timestamp': self.data_handler.get_latest_timestamp(),
                    })
            elif signal == 'SELL' and symbol in self.positions:
                position = self.positions.pop(symbol)
                revenue = position['shares'] * current_price
                self.cash += revenue
                self.trade_history.append({
                    'symbol': symbol,
                    'action': 'SELL',
                    'shares': position['shares'],
                    'price': current_price,
                    'timestamp': self.data_handler.get_latest_timestamp(),
                    'profit': revenue - position['shares'] * position['avg_price'],
                })

    def update_portfolio_value(self):
        """Mark the portfolio to market and append the value to the equity curve."""
        positions_value = sum(pos['shares'] * self.data_handler.get_latest_price(sym)
                              for sym, pos in self.positions.items())
        total_value = self.cash + positions_value
        self.portfolio_value.append(total_value)
        return total_value
```
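`execute_trades` rounds the order size down to whole shares. On the Shanghai and Shenzhen main boards, however, buy orders are placed in board lots of 100 shares. A closer approximation therefore rounds down to a multiple of the lot size. The helper below is hypothetical. Its lot size and optional commission rate are assumptions to adjust per market segment.

```python
import math


def lot_rounded_shares(cash, fraction, price, lot_size=100, commission=0.0):
    """Largest multiple of `lot_size` affordable with `fraction` of `cash`."""
    budget = cash * fraction
    max_shares = budget / (price * (1 + commission))
    return math.floor(max_shares / lot_size) * lot_size


print(lot_rounded_shares(100_000, 0.1, 37.0))   # prints 200
print(lot_rounded_shares(100_000, 0.1, 150.0))  # prints 0
```

Replacing `int(self.cash * 0.1 / current_price)` with such a call would leave the rest of the method unchanged. The existing `shares > 0` check already skips stocks whose single lot exceeds the budget.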
5.2 Mean-Reversion Strategy
```python
class MeanReversionStrategy(TradingStrategy):
    def __init__(self, data_handler, lookback=20, z_threshold=2.0):
        super().__init__(data_handler)
        self.lookback = lookback
        self.z_threshold = z_threshold

    def generate_signals(self):
        signals = {}
        for symbol in self.data_handler.symbols:
            prices = np.asarray(
                self.data_handler.get_historical_prices(symbol, self.lookback), dtype=float)
            if len(prices) < self.lookback:
                continue
            # Compare the latest price with the mean/std of the preceding bars
            current_price = prices[-1]
            mean_price = np.mean(prices[:-1])
            std_price = np.std(prices[:-1])
            if std_price == 0:
                continue
            z_score = (current_price - mean_price) / std_price
            if z_score < -self.z_threshold:
                signals[symbol] = 'BUY'   # far below the mean: expect a rebound
            elif z_score > self.z_threshold and symbol in self.positions:
                signals[symbol] = 'SELL'  # far above the mean: take profit
            else:
                signals[symbol] = 'HOLD'
        return signals
```
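In formula form, let n = `lookback`, k = `z_threshold` and P_t the latest price. The strategy takes the mean and the population standard deviation of the n - 1 preceding prices, which is what `np.mean` and `np.std` compute on `prices[:-1]`:

$$
\mu_t = \frac{1}{n-1}\sum_{i=1}^{n-1} P_{t-i}, \qquad
\sigma_t = \sqrt{\frac{1}{n-1}\sum_{i=1}^{n-1}\left(P_{t-i}-\mu_t\right)^2}, \qquad
z_t = \frac{P_t-\mu_t}{\sigma_t}
$$

It buys when z_t < -k and sells an open position when z_t > k. For example, take lookback = 21, so that there are 20 preceding prices, alternating between 10 and 12. Then μ_t = 11 and σ_t = 1, and a latest price of 8 gives z_t = -3, which triggers a buy at the default k = 2.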
5.3 Momentum Strategy
```python
class MomentumStrategy(TradingStrategy):
    def __init__(self, data_handler, lookback=20, hold_period=5):
        super().__init__(data_handler)
        self.lookback = lookback
        self.hold_period = hold_period
        self.holding_periods = {}

    def generate_signals(self):
        signals = {}
        # Age open positions and exit those held for hold_period bars
        for symbol in list(self.holding_periods):
            self.holding_periods[symbol] += 1
            if self.holding_periods[symbol] >= self.hold_period:
                signals[symbol] = 'SELL'
                del self.holding_periods[symbol]

        for symbol in self.data_handler.symbols:
            if symbol in self.positions:
                continue
            prices = np.asarray(
                self.data_handler.get_historical_prices(symbol, self.lookback), dtype=float)
            if len(prices) < self.lookback:
                continue
            returns = np.diff(prices) / prices[:-1]
            momentum = np.prod(1 + returns) - 1  # cumulative return over the window
            if momentum > 0.1:  # 10% momentum threshold
                signals[symbol] = 'BUY'
                self.holding_periods[symbol] = 0
        return signals
```
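The compounded product in `MomentumStrategy` telescopes, because each factor 1 + r_i equals P_{i+1} / P_i. The momentum is therefore just the return from the first to the last price in the window, which spans `lookback - 1` daily changes. The small NumPy check below confirms the identity on made-up prices.

```python
import numpy as np

prices = np.array([10.0, 10.5, 10.2, 11.0, 11.5])  # illustrative prices only
returns = np.diff(prices) / prices[:-1]

momentum_compounded = np.prod(1 + returns) - 1  # as computed in MomentumStrategy
momentum_direct = prices[-1] / prices[0] - 1    # telescoped form

print(momentum_compounded, momentum_direct)
```

The direct form needs no intermediate return array and avoids accumulating rounding error across the product.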
5.4 Machine-Learning Strategy
```python
class MLBasedStrategy(TradingStrategy):
    def __init__(self, data_handler, model, threshold=0.5):
        super().__init__(data_handler)
        self.model = model
        # The threshold is on the scale of the model's target, e.g. 0.01 means
        # a 1% predicted return if the model predicts fractional returns
        self.threshold = threshold
        self.prediction_history = []

    def generate_signals(self):
        signals = {}
        # Features must be transformed exactly as during training
        latest_features = self.data_handler.get_latest_features()
        for symbol, features in latest_features.items():
            # Predict the future return from the latest feature vector
            prediction = self.model.predict([features])[0]
            self.prediction_history.append(prediction)
            if prediction > self.threshold and symbol not in self.positions:
                signals[symbol] = 'BUY'
            elif prediction < -self.threshold and symbol in self.positions:
                signals[symbol] = 'SELL'
            else:
                signals[symbol] = 'HOLD'
        return signals
```
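The strategies, the risk manager and the backtest engine all rely on a `data_handler`. It must expose `symbols`, `update`, `get_trading_dates`, `get_latest_price`, `get_latest_timestamp`, `get_historical_prices` and `get_latest_features`, but its implementation is not shown in this section. Below is a minimal in-memory sketch of that interface. The class body and the `feature_columns` argument are ours, and we assume each DataFrame has a `DatetimeIndex` and a `close` column. The key property is that every query sees only rows up to the current date, which keeps look-ahead bias out of the backtest.

```python
import pandas as pd


class DataHandler:
    """Hypothetical in-memory data handler that replays daily bars without look-ahead."""

    def __init__(self, data, feature_columns=None):
        # data: {symbol: DataFrame with a DatetimeIndex and at least a 'close' column}
        self.data = {symbol: df.sort_index() for symbol, df in data.items()}
        self.symbols = list(self.data)
        self.feature_columns = feature_columns  # None = use every column as a feature
        self.current_date = None

    def get_trading_dates(self, start_date, end_date):
        """All dates on which any symbol traded, within [start_date, end_date]."""
        idx = pd.DatetimeIndex([])
        for df in self.data.values():
            idx = idx.union(df.index)
        mask = (idx >= pd.Timestamp(start_date)) & (idx <= pd.Timestamp(end_date))
        return list(idx[mask])

    def update(self, date):
        """Advance the simulated clock to `date`."""
        self.current_date = pd.Timestamp(date)

    def _visible(self, symbol):
        # Only rows up to and including the current date can be seen
        return self.data[symbol].loc[: self.current_date]

    def get_latest_price(self, symbol):
        # Raises IndexError if the symbol has no data yet
        return float(self._visible(symbol)['close'].iloc[-1])

    def get_latest_timestamp(self):
        return self.current_date

    def get_historical_prices(self, symbol, n):
        """The last n closing prices (fewer if the history is shorter)."""
        return self._visible(symbol)['close'].iloc[-n:].to_numpy()

    def get_latest_features(self):
        """Latest feature vector for every symbol that already has data."""
        features = {}
        for symbol in self.symbols:
            visible = self._visible(symbol)
            if visible.empty:
                continue
            row = visible.iloc[-1]
            if self.feature_columns is not None:
                row = row[self.feature_columns]
            features[symbol] = row.to_numpy(dtype=float)
        return features
```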
6. Risk Control Module
```python
import numpy as np


class RiskManager:
    def __init__(self, max_position_size=0.1, max_portfolio_risk=0.2, stop_loss=0.05):
        """
        :param max_position_size: maximum weight of a single position in the portfolio
        :param max_portfolio_risk: maximum portfolio risk (stored, not yet enforced below)
        :param stop_loss: stop-loss threshold as a fraction of the entry price
        """
        self.max_position_size = max_position_size
        self.max_portfolio_risk = max_portfolio_risk
        self.stop_loss = stop_loss

    def check_position_size(self, strategy, symbol, price, quantity):
        """Check that a new position stays within the size limit."""
        position_value = price * quantity
        # Value the portfolio directly: strategy.update_portfolio_value() would
        # also append a spurious point to the equity curve
        positions_value = sum(pos['shares'] * strategy.data_handler.get_latest_price(sym)
                              for sym, pos in strategy.positions.items())
        portfolio_value = strategy.cash + positions_value
        return position_value <= portfolio_value * self.max_position_size

    def check_stop_loss(self, strategy):
        """Return the symbols whose loss has reached the stop-loss threshold."""
        symbols_to_sell = []
        for symbol, position in strategy.positions.items():
            current_price = strategy.data_handler.get_latest_price(symbol)
            pnl = (current_price - position['avg_price']) / position['avg_price']
            if pnl <= -self.stop_loss:
                symbols_to_sell.append(symbol)
        return symbols_to_sell

    def calculate_var(self, portfolio_returns, confidence_level=0.95):
        """Historical Value at Risk, as a positive fraction of portfolio value."""
        if len(portfolio_returns) < 50:  # require at least 50 observations
            return 0
        return -np.percentile(portfolio_returns, 100 * (1 - confidence_level))

    def calculate_max_drawdown(self, portfolio_values):
        """Largest peak-to-trough decline, as a fraction of the peak."""
        peak = -np.inf
        max_drawdown = 0
        for value in portfolio_values:
            peak = max(peak, value)
            max_drawdown = max(max_drawdown, (peak - value) / peak)
        return max_drawdown
```
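Both risk measures above have compact vectorized equivalents in NumPy. The running peak is `np.maximum.accumulate`, and historical VaR is the negated lower percentile of the returns. The sketch below keeps the same conventions: drawdown as a fraction of the peak, and VaR as a positive fraction of portfolio value. The function names are hypothetical.

```python
import numpy as np


def max_drawdown(values):
    """Largest peak-to-trough decline as a fraction of the running peak."""
    values = np.asarray(values, dtype=float)
    peaks = np.maximum.accumulate(values)
    return float(np.max((peaks - values) / peaks))


def historical_var(returns, confidence_level=0.95):
    """Loss not exceeded with probability `confidence_level` (empirical distribution)."""
    return float(-np.percentile(returns, 100 * (1 - confidence_level)))


equity = [100, 120, 90, 130, 65]
print(max_drawdown(equity))  # prints 0.5
print(historical_var(np.linspace(-0.05, 0.05, 101)))
```

In the equity example the running peak is 130 and the trough 65, so the drawdown is 65 / 130 = 0.5.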
7. Backtesting and Evaluation
7.1 Backtest Engine
```python
import numpy as np


class BacktestEngine:
    def __init__(self, data_handler, strategy, risk_manager):
        self.data_handler = data_handler
        self.strategy = strategy
        self.risk_manager = risk_manager
        self.results = None

    def run(self, start_date, end_date):
        """Run the backtest bar by bar."""
        dates = self.data_handler.get_trading_dates(start_date, end_date)
        portfolio_values = []
        returns = []
        for date in dates:
            # Advance the data to the current bar
            self.data_handler.update(date)
            # Generate signals
            signals = self.strategy.generate_signals()
            # Risk check: stop-loss exits override strategy signals
            for symbol in self.risk_manager.check_stop_loss(self.strategy):
                signals[symbol] = 'SELL'
            # Execute trades
            self.strategy.execute_trades(signals)
            # Mark the portfolio to market
            current_value = self.strategy.update_portfolio_value()
            portfolio_values.append(current_value)
            # Daily return
            if len(portfolio_values) > 1:
                returns.append(portfolio_values[-1] / portfolio_values[-2] - 1)

        self.results = {'dates': dates,
                        'portfolio_values': portfolio_values,
                        'returns': returns,
                        'trades': self.strategy.trade_history}
        return self.results

    def evaluate(self):
        """Compute performance metrics for the last backtest."""
        if not self.results:
            raise ValueError("Backtest not run yet")
        returns = np.array(self.results['returns'])
        portfolio_values = np.array(self.results['portfolio_values'])
        trades = self.results['trades']

        # Return metrics (252 trading days per year)
        total_return = portfolio_values[-1] / portfolio_values[0] - 1
        annualized_return = (1 + total_return) ** (252 / len(portfolio_values)) - 1
        daily_std = np.std(returns) if len(returns) > 1 else 0.0
        volatility = daily_std * np.sqrt(252)
        # Sharpe ratio with a zero risk-free rate: mean / std of daily returns, annualised
        sharpe_ratio = np.mean(returns) / daily_std * np.sqrt(252) if daily_std > 0 else 0

        # Risk metrics
        var_95 = self.risk_manager.calculate_var(returns, 0.95)
        max_drawdown = self.risk_manager.calculate_max_drawdown(portfolio_values)

        # Trade statistics over closed (SELL) trades, which carry a 'profit' field
        closed = [t for t in trades if 'profit' in t]
        wins = [t['profit'] for t in closed if t['profit'] > 0]
        losses = [t['profit'] for t in closed if t['profit'] <= 0]
        win_rate = len(wins) / len(closed) if closed else 0
        avg_win = np.mean(wins) if wins else 0
        avg_loss = np.mean(losses) if losses else 0
        gross_loss = -sum(losses)
        profit_factor = sum(wins) / gross_loss if gross_loss > 0 else np.inf

        return {'total_return': total_return,
                'annualized_return': annualized_return,
                'volatility': volatility,
                'sharpe_ratio': sharpe_ratio,
                'var_95': var_95,
                'max_drawdown': max_drawdown,
                'num_trades': len(trades),
                'win_rate': win_rate,
                'avg_win': avg_win,
                'avg_loss': avg_loss,
                'profit_factor': profit_factor}
```
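The `sharpe_ratio` in `evaluate` assumes a zero risk-free rate. When a benchmark rate is available, the usual definition works in three steps. Subtract the rate from each period's return, divide the mean excess return by the standard deviation of the excess returns, and annualize with the square root of the number of periods per year. The sketch below assumes daily returns and 252 trading days. It also converts the annual rate to a daily one by simple division, which is an approximation. It uses the sample standard deviation (`ddof=1`).

```python
import numpy as np


def annualized_sharpe(daily_returns, risk_free_annual=0.0, periods_per_year=252):
    """Annualized Sharpe ratio of periodic returns in excess of a risk-free rate."""
    r = np.asarray(daily_returns, dtype=float)
    if len(r) < 2:
        return 0.0  # the sample standard deviation needs at least two points
    excess = r - risk_free_annual / periods_per_year
    std = excess.std(ddof=1)
    if np.isclose(std, 0.0):
        return 0.0  # constant returns: the ratio is undefined
    return float(excess.mean() / std * np.sqrt(periods_per_year))


print(annualized_sharpe([0.02, 0.0]))
print(annualized_sharpe([0.02, 0.0], risk_free_annual=0.03))
```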
7.2 Visualization
```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


class Visualizer:
    @staticmethod
    def plot_portfolio(backtest_results):
        """Plot portfolio value over time."""
        dates = backtest_results['dates']
        values = backtest_results['portfolio_values']
        plt.figure(figsize=(12, 6))
        plt.plot(dates, values)
        plt.title('Portfolio Value Over Time')
        plt.xlabel('Date')
        plt.ylabel('Portfolio Value')
        plt.grid(True)
        plt.show()

    @staticmethod
    def plot_drawdown(backtest_results):
        """Plot the drawdown curve."""
        dates = backtest_results['dates']
        values = backtest_results['portfolio_values']
        # Drawdown relative to the running peak
        peak = -np.inf
        drawdowns = []
        for value in values:
            peak = max(peak, value)
            drawdowns.append((peak - value) / peak)
        plt.figure(figsize=(12, 6))
        plt.plot(dates, drawdowns)
        plt.title('Drawdown Over Time')
        plt.xlabel('Date')
        plt.ylabel('Drawdown')
        plt.grid(True)
        plt.show()

    @staticmethod
    def plot_returns_distribution(backtest_results):
        """Plot the distribution of daily returns."""
        returns = backtest_results['returns']
        plt.figure(figsize=(12, 6))
        sns.histplot(returns, kde=True, bins=50)
        plt.title('Distribution of Daily Returns')
        plt.xlabel('Daily Return')
        plt.ylabel('Frequency')
        plt.grid(True)
        plt.show()
```
8. System Integration and Deployment
8.1 Main Control System
```python
import time
from datetime import datetime


class TradingSystem:
    def __init__(self, config):
        self.config = config
        self.data_fetcher = DataFetcher()
        self.data_handler = None
        self.strategy = None
        self.risk_manager = None
        self.backtest_engine = None
        self.realtime_trading = False
        self.last_update = None

    def initialize(self, mode='backtest'):
        """Initialise all system components."""
        # Data-handling settings
        symbols = self.config['symbols']
        start_date = self.config['start_date']
        end_date = self.config['end_date']

        # Fetch and preprocess the data
        all_data = {}
        for symbol in symbols:
            df = self.data_fetcher.get_stock_daily(symbol, start_date, end_date)
            if df is not None:
                df = DataPreprocessor.clean_data(df)
                df = DataPreprocessor.add_technical_indicators(df)
                all_data[symbol] = df
        self.data_handler = DataHandler(all_data)

        # Strategy
        if self.config['strategy'] == 'mean_reversion':
            self.strategy = MeanReversionStrategy(
                self.data_handler,
                lookback=self.config.get('lookback', 20),
                z_threshold=self.config.get('z_threshold', 2.0))
        elif self.config['strategy'] == 'momentum':
            self.strategy = MomentumStrategy(
                self.data_handler,
                lookback=self.config.get('lookback', 20),
                hold_period=self.config.get('hold_period', 5))
        elif self.config['strategy'] == 'ml_based':
            model = self.load_model(self.config['model_path'])
            self.strategy = MLBasedStrategy(
                self.data_handler,
                model=model,
                threshold=self.config.get('threshold', 0.5))

        # Risk management
        self.risk_manager = RiskManager(
            max_position_size=self.config.get('max_position_size', 0.1),
            max_portfolio_risk=self.config.get('max_portfolio_risk', 0.2),
            stop_loss=self.config.get('stop_loss', 0.05))

        # Backtest engine
        self.backtest_engine = BacktestEngine(
            self.data_handler, self.strategy, self.risk_manager)

        if mode == 'realtime':
            self.realtime_trading = True

    def run_backtest(self):
        """Run the backtest."""
        if not self.backtest_engine:
            raise ValueError("Backtest engine not initialized")
        results = self.backtest_
```