
Solving the CartPole Problem with DQN (Reinforcement Learning)

Introduction to the CartPole environment

CartPole is a classic benchmark environment in reinforcement learning. It was first introduced in OpenAI's Gym library and is still widely used in Gymnasium, Gym's successor.

The core task is as follows:
a pole is attached by a hinge to a cart that can move left and right along a one-dimensional track. The agent's goal is to keep the pole from falling over by choosing, at every step, to push the cart to the left or to the right.

Environment setup

  • Observation (state space): at every time step the environment returns a real-valued vector of length 4 containing:

    1. the cart position
    2. the cart velocity
    3. the pole angle relative to vertical
    4. the pole angular velocity
  • Action space: two discrete actions:

    • 0: push the cart to the left
    • 1: push the cart to the right
  • Reward function
    The agent receives a reward of +1 for every step on which the pole has not fallen.

  • Termination condition (done)
    The episode ends when the pole tilts past a threshold angle from vertical or the cart moves beyond the track boundary. A minimal interaction loop using this interface is sketched right after this list.
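Before wiring up DQN, it helps to see the raw Gymnasium interface in isolation. The following standalone sketch (independent of the project code below) runs a single episode with a random policy and prints the episode return:

import gymnasium as gym

env = gym.make('CartPole-v1')
obs, info = env.reset(seed=0)   # obs = [cart position, cart velocity, pole angle, pole angular velocity]
total_reward = 0.0
while True:
    action = env.action_space.sample()                            # 0 = push left, 1 = push right
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward                                        # +1 per step while the pole is up
    if terminated or truncated:                                   # pole fell, cart out of bounds, or time limit
        break
print(total_reward)
env.close()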

Code structure

📦 cart_pole
├── 📂 agent_dqn
│   ├── 📂 algorithm
│   │   ├── 📄 __init__.py
│   │   └── 📄 algorithm.py
│   ├── 📂 conf
│   │   ├── 📄 __init__.py
│   │   └── 📄 conf.py
│   ├── 📂 feature
│   │   ├── 📄 __init__.py
│   │   ├── 📄 monitor.py
│   │   └── 📄 processor.py
│   ├── 📂 model
│   │   ├── 📄 __init__.py
│   │   └── 📄 model.py
│   ├── 📂 workflow
│   │   ├── 📄 __init__.py
│   │   └── 📄 train_workflow.py
│   ├── 📄 __init__.py
│   └── 📄 agent.py
└── 📄 train_test.py

algorithm

import math
import random
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from cart_pole.agent_dqn.conf.conf import Config
from cart_pole.agent_dqn.model.model import DQN
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.feature.processor import Processor


class Algorithm:
    def __init__(self, device, monitor: Monitor):
        self.device = device
        self.monitor = monitor
        self.capacity = Config.MEMORY_SIZE
        self.memory = []
        self.push_count = 0
        self.epsilon = Config.EPSILON_MAX
        self.epsilon_max = Config.EPSILON_MAX
        self.epsilon_min = Config.EPSILON_MIN
        self.epsilon_decay = Config.EPSILON_DECAY
        # Initialize the policy (online) network
        self.model = DQN(Config.DIM_OF_OBSERVATION, Config.DIM_OF_ACTION).to(device)
        # Initialize the target network
        self.target_model = DQN(Config.DIM_OF_OBSERVATION, Config.DIM_OF_ACTION).to(device)
        # Sync the target network with the policy network and freeze it
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        # Optimizer
        self.optimizer = optim.Adam(params=self.model.parameters(), lr=Config.LR)
        self.predict_count = 0
        self.train_count = 0

    def memory_push(self, experience) -> None:
        """Add an experience to the replay memory.

        While the memory is below its capacity, experiences are appended;
        afterwards the oldest experience is overwritten in circular order.
        """
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size: int):
        """Draw a random batch of `batch_size` experiences from the replay memory."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size: int) -> bool:
        """Return True once the memory holds at least `batch_size` experiences."""
        return len(self.memory) >= batch_size

    def learn(self, list_sample_data):
        # Convert the sampled experiences into tensors
        states, actions, next_states, dones, rewards = Processor.extract_tensors(list_sample_data, self.device)
        # Compute target Q-values with the target network
        self.target_model.eval()
        with torch.no_grad():
            # A next state whose maximum element is exactly 0 is treated as terminal
            # (the all-zero "final state" convention) and contributes a value of 0.
            final_states_locations = next_states.flatten(start_dim=1) \
                .max(dim=1)[0].eq(0).type(torch.bool)
            non_final_states_locations = ~final_states_locations
            non_final_states = next_states[non_final_states_locations]
            batch_size = next_states.shape[0]
            values = torch.zeros(batch_size).to(self.device)
            values[non_final_states_locations] = self.target_model(non_final_states).max(dim=1)[0].detach()
            # unsqueeze so the target has the same [batch, 1] shape as current_q_values
            target_q_values = rewards + Config.GAMMA * values.unsqueeze(-1)
        # Q-values predicted by the policy network for the actions actually taken
        current_q_values = self.model(states).gather(dim=1, index=actions)
        # Compute the loss
        loss = F.mse_loss(current_q_values, target_q_values)
        # Backpropagate
        loss.backward()
        # Apply the gradient update
        self.optimizer.step()
        # Reset the gradients
        self.optimizer.zero_grad()
        self.train_count += 1
        # Periodically refresh the target network
        if self.train_count % Config.TARGET_UPDATE_INTERVAL == 0:
            self.update_target_q()
        # Report the loss to the monitor
        if self.train_count % Config.LOG_UPDATE_INTERVAL == 0:
            self.monitor.add_loss_info(loss.detach().item())

    def predict(self, obs, exploit_flag=False):
        # Exponentially decayed epsilon
        self.epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * \
            math.exp(-1. * self.predict_count * self.epsilon_decay)
        # Count the prediction steps taken so far
        self.predict_count += 1
        # Epsilon-greedy action selection
        if not exploit_flag and np.random.rand() < self.epsilon:
            return np.random.randint(Config.DIM_OF_ACTION)
        else:
            with torch.no_grad():
                obs = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
                q_values = self.model(obs)
                return q_values.argmax().item()

    def update_target_q(self):
        self.target_model.load_state_dict(self.model.state_dict())
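learn() above implements the standard one-step DQN target r + GAMMA * max_a' Q_target(s', a'), while predict() follows an epsilon-greedy policy with an exponentially decayed epsilon. With the values from Config (EPSILON_MAX = 1, EPSILON_MIN = 0.01, EPSILON_DECAY = 0.001), the decay schedule can be inspected with this small standalone sketch:

import math

EPSILON_MAX, EPSILON_MIN, EPSILON_DECAY = 1.0, 0.01, 0.001

def epsilon_at(step):
    # Same formula as Algorithm.predict, expressed as a function of predict_count
    return EPSILON_MIN + (EPSILON_MAX - EPSILON_MIN) * math.exp(-step * EPSILON_DECAY)

for step in (0, 1000, 3000, 5000):
    print(step, round(epsilon_at(step), 3))   # 1.0, 0.374, 0.059, 0.017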

conf

from collections import namedtuple


class Config:
    DIM_OF_OBSERVATION = 4
    DIM_OF_ACTION = 2
    EPSILON_MAX = 1
    EPSILON_MIN = 0.01
    EPSILON_DECAY = 0.001
    GAMMA = 0.999
    LR = 0.001
    SEED = 234
    MEMORY_SIZE = 100000
    NUM_EPISODES = 1000
    TARGET_UPDATE_INTERVAL = 10
    LOG_UPDATE_INTERVAL = 1
    BATCH_SIZE = 256
    Experience = namedtuple('Experience',
                            ('state', 'action', 'next_state', 'done', 'reward'))
    ENV_RENDER_MODE = 'rgb_array'
    NUM_FOURIER_BASE = 1
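The Experience namedtuple defined in Config is the unit stored in replay memory. A quick illustration (the numbers below are made up):

exp = Config.Experience(state=[0.01, 0.0, 0.03, 0.0], action=1,
                        next_state=[0.01, 0.2, 0.03, -0.3], done=False, reward=1.0)
print(exp.action, exp.reward)   # 1 1.0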

processor

import torch
import numpy as np
from typing import NamedTuple
from cart_pole.agent_dqn.conf.conf import Config


class Processor:
    @staticmethod
    def extract_tensors(experiences: NamedTuple, device):
        """Accept a batch of Experiences, transpose it into an Experience of
        batches, and convert each field into a tensor on the given device."""
        # Convert a batch of Experiences into an Experience of batches
        batch = Config.Experience(*zip(*experiences))
        # np.array(...) first, to avoid the slow tensor-from-list-of-ndarrays path
        t_states = torch.tensor(np.array(batch.state)).to(device)
        t_actions = torch.tensor(batch.action).unsqueeze(-1).to(device)
        t_next_state = torch.tensor(np.array(batch.next_state)).to(device)
        t_rewards = torch.tensor(batch.reward).unsqueeze(-1).to(device)
        t_dones = torch.tensor(batch.done).float().unsqueeze(-1).to(device)
        return t_states, t_actions, t_next_state, t_dones, t_rewards
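The key step is Config.Experience(*zip(*experiences)), which transposes a list of per-step Experience tuples into a single Experience whose fields are whole batches. A standalone illustration of that transposition:

# Each row is one (state, action, next_state, done, reward) transition
batch = [('s0', 0, 's1', False, 1.0),
         ('s1', 1, 's2', True, 1.0)]
print(list(zip(*batch)))
# [('s0', 's1'), (0, 1), ('s1', 's2'), (False, True), (1.0, 1.0)]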

monitor

import matplotlib.pyplot as plt
import seaborn as sns
import warnings

warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['axes.unicode_minus'] = False  # fix rendering of the minus sign


class Monitor:
    def __init__(self):
        self.loss_log = []
        self.epsilon_log = []
        self.reward_log = []
        self.episode_duration_log = []

    def add_loss_info(self, loss):
        """Append a new loss value to the monitor."""
        self.loss_log.append(loss)

    def add_epsilon_info(self, epsilon):
        """Append a new epsilon value to the monitor."""
        self.epsilon_log.append(epsilon)

    def add_reward_info(self, reward):
        """Append the total reward of a finished episode."""
        self.reward_log.append(reward)

    def add_duration_info(self, duration):
        """Append the duration (number of steps) of a finished episode."""
        self.episode_duration_log.append(duration)

    def plot_loss(self):
        """Plot the loss curve."""
        plt.figure()
        plt.plot(self.loss_log)
        plt.xlabel('iteration')
        plt.ylabel('loss')
        plt.title('TD error / loss')
        plt.show()

    def plot_epsilon(self):
        """Plot the epsilon curve."""
        plt.figure()
        plt.plot(self.epsilon_log)
        plt.xlabel('episode')
        plt.ylabel('epsilon')
        plt.title('Epsilon Variation with Episode')
        plt.show()

    def plot_reward(self):
        """Plot the reward curve."""
        plt.figure()
        plt.plot(self.reward_log)
        plt.xlabel('episode')
        plt.ylabel('reward')
        plt.title('Reward Variation with Episode')
        plt.show()

    def plot_all_log(self):
        """Plot the loss, episode-duration and epsilon curves on one figure."""
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 6))
        # Loss curve
        ax1.plot(self.loss_log)
        ax1.set_xlabel('iteration')
        ax1.set_ylabel('TD error / loss')
        ax1.set_title('Loss Variation with Iteration')
        ax1.grid(True)
        # Episode-duration curve
        ax2.plot(self.episode_duration_log)
        ax2.set_xlabel('episode')
        ax2.set_ylabel('step')
        ax2.set_title('Step Variation with Episode')
        ax2.grid(True)
        # Epsilon curve
        ax3.plot(self.epsilon_log)
        ax3.set_xlabel('episode')
        ax3.set_ylabel('epsilon')
        ax3.set_title('Epsilon Variation with Episode')
        ax3.grid(True)
        plt.tight_layout()
        plt.show()

model

import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    def __init__(self, num_state_features, num_actions):
        super().__init__()
        # Fully connected layers mapping the 4-dimensional state to Q-values
        # for the two possible actions (push left, push right)
        self.fc1 = nn.Linear(in_features=num_state_features, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=128)
        self.out = nn.Linear(in_features=128, out_features=num_actions)

    def forward(self, t):
        # The input is already a flat state vector, so no flattening is needed
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = F.relu(self.fc3(t))
        # No activation on the output layer: raw Q-values
        t = self.out(t)
        return t
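The network maps a 4-dimensional state to one Q-value per action, so a batch of shape [B, 4] comes out as [B, 2]. A quick standalone shape check (illustrative only):

import torch
from cart_pole.agent_dqn.model.model import DQN

net = DQN(num_state_features=4, num_actions=2)
q = net(torch.randn(8, 4))   # a batch of 8 CartPole states
print(q.shape)               # torch.Size([8, 2]): one Q-value per action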

train_workflow

import time
from tqdm import tqdm
from cart_pole.agent_dqn.agent import Agent
from cart_pole.agent_dqn.conf.conf import Config
from itertools import count


def run_episodes(num_episodes, env, agent: Agent, exploit_flag=False):
    for episode in tqdm(range(num_episodes)):
        # Reset the task and get the initial state
        state = env.reset(seed=Config.SEED)[0]
        for duration in count():
            action = agent.predict(state, exploit_flag=exploit_flag)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            agent.algorithm.memory_push(Config.Experience(state, action, next_state, done, reward))
            state = next_state
            if done:
                agent.monitor.add_duration_info(duration)
                # Log the current epsilon once per episode so the epsilon curve is not empty
                agent.monitor.add_epsilon_info(agent.algorithm.epsilon)
                break
            # Learn from a random batch once the replay memory is large enough
            if agent.algorithm.can_provide_sample(Config.BATCH_SIZE):
                sample_data = agent.algorithm.sample(Config.BATCH_SIZE)
                agent.learn(sample_data)

agent

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from cart_pole.agent_dqn.algorithm.algorithm import Algorithm
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.feature.processor import Processor

warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['axes.unicode_minus'] = False  # fix rendering of the minus sign


class Agent:
    def __init__(self, device, monitor: Monitor):
        self.device = device
        self.monitor = monitor
        self.algorithm = Algorithm(device, monitor)

    def predict(self, obs, exploit_flag=False):
        return self.algorithm.predict(obs, exploit_flag=exploit_flag)

    def learn(self, list_sample_data):
        self.algorithm.learn(list_sample_data)

    def save_model(self, path=None, id="1"):
        pass

    def load_model(self, path=None, id="1"):
        pass

train_test

import torch
from cart_pole.agent_dqn.agent import Agent
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.workflow.train_workflow import *
import gymnasium as gym

if __name__ == "__main__":
    env = gym.make('CartPole-v1', render_mode="rgb_array").unwrapped
    monitor = Monitor()
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device("cpu")
    agent = Agent(device, monitor)
    run_episodes(2000, env, agent)
    monitor.plot_all_log()
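After training, a purely greedy rollout is one way to sanity-check the learned policy. The snippet below is a hypothetical follow-up, not part of the original script: it reuses the env and agent created above, and exploit_flag=True makes Algorithm.predict skip the exploration branch and always pick the argmax-Q action. The step cap is needed because the unwrapped env has no time limit.

# Hypothetical greedy evaluation (not in the original script)
state = env.reset()[0]
steps, done = 0, False
while not done and steps < 500:
    action = agent.predict(state, exploit_flag=True)
    state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    steps += 1
print(f"greedy episode lasted {steps} steps")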
