当前位置: 首页 > news >正文

RL代码实践 02——策略迭代

目录

一、问题描述

二、问题分析和解决

1、策略迭代算法


一、问题描述

悬崖漫步

灰色格子代表悬崖,一旦进入就游戏失败;

绿色格子是终点,一旦进入就游戏成功;

白色格子是通路。

 

已知对于每个格子,可选的动作有4种,分别往上下左右走。

奖惩规则如下:

  • 普通步长:-1(鼓励少走步)
  • 撞墙:-2(略高于普通步长,减少撞墙)
  • 终点:+50,悬崖陷阱:-100(明确目标和风险)

求每个格子的策略,使可以找到通往终点的最优路径。

二、问题分析和解决

这是一个有模型的情况:

对于特定的状态(格子)s,采取特定的动作a后,能到达的下一状态已知且概率已知,

获得的奖励已知且概率已知。

 

可以用策略迭代算法或者值迭代算法

1、策略迭代算法

首先初始化策略,主要分为两个阶段:

(1)策略评估(Policy Evaluation)

需要计算出各个状态的value。

有两个方法:一个是根据贝尔曼公式求解,另一个是用迭代求解。

这里用迭代算法求解,直到value收敛(np.abs(values - old_values) < theta))才结束。

具体来说,

即对于第k轮迭代,

某状态的 state value = 各个动作的q值 * 各个动作的概率 之和,

而q值  action value =  reward + next_state_value * 0.9。

注意此时next_state_value是上一轮,即k-1轮的值。

(2)策略改进(Policy Improvement)

get_pi() 使用贪心策略(np.argmax)选择每个状态的最优动作,生成确定性策略(概率1赋予最优动作)。

也可以采用用随机性策略:若多个动作的 Q 值相同,可均分概率。(但是值迭代算法一般默认greedy,不用随机性策略)

 

外层循环交替执行策略评估和策略改进,直到策略稳定(pi == old_pi)。

# 悬崖漫步
import numpy as np# 设置(获取)格子状态
def get_state(row, col):if row!=3 or col==0:return 'ground' # 通路if row==3 and col==11:return 'terminal' # 终点return 'trap' # 悬崖for row in range(4):for col in range(12):if get_state(row, col)=='ground':print('o', end=' ') # 不换行,用空字符结尾elif get_state(row,col)=='terminal':print('p', end=' ')else:print('x', end=' ')print()# 在特定s做特定a,求得到的下一s和r
def move(row, col, action):# 如果当前状态已经是掉进悬崖或者到达终点,直接返回(因为游戏结束了,不会再有状态转移)if get_state(row,col) in ['terminal', 'trap']:return row, col, 0  # 让它待在原地不能移动,原地不动没有奖惩# 状态转移if action == 0: # 向上走row-=1elif action == 1: # 向下走row+=1elif action == 2: # 向左走col-=1elif action == 3: # 向右走col+=1# 注意限制不能走出地图外面去out = 0 # 标记是否出界if row<0 or row>3 or col<0 or col>11:out = 1row = max(0,row)row = min(3,row)col = max(0,col)col = min(11,col)# 获得奖励reward = -1 # 普通步长if get_state(row, col)=='trap':reward = -100 # 陷阱if get_state(row, col)=='terminal':reward = 50 # 终点if out==1:reward = -2 # 出界(撞墙)return row, col, reward# 初始化state value table
values = np.zeros((4,12))
# 初始化q-table
q_table = np.zeros((4,12,4))
# 初始化策略(每个格子下采取动作的概率)
pi = np.ones((4,12,4))*0.25# 计算q
def get_q(row, col, action):# 当前rewardnext_row, next_col, reward = move(row, col, action)# 下一状态的valuenext_state_value = values[next_row, next_col]# s,a对应的action valuereturn reward + next_state_value * 0.9# 计算q-table(选择所有动作的可能性)
def get_q_table():new_q_table = np.zeros((4,12,4))# 遍历所有格子for row in range(4):for col in range(12):# 对于特定格子(状态),四个动作的q值for action in range(4):new_q_table[row, col, action] = get_q(row, col, action)return new_q_table# policy evaluation(value update)
def get_values():new_values = np.zeros((4,12))# 遍历所有格子for row in range(4):for col in range(12):# 终止状态价值为0if get_state(row, col) in ['terminal', 'trap']:new_values[row, col] = 0else:# 该状态的value = 各个动作的q值 * 各个动作的概率 之和new_values[row,col] = np.sum(q_table[row, col] * pi[row, col])return new_values# policy improvement(policy update)
def get_pi():new_pi = np.zeros((4,12,4))# 遍历所有格子for row in range(4):for col in range(12):# 终止状态无需策略if get_state(row, col) in ['terminal', 'trap']:continue# # 该状态下,有最大q值的动作有几个# max_q = np.max(q_table[row,col])# count = np.sum(q_table[row,col]==max_q)# # 让这些动作均分概率,其它为0# for action in range(4):#     if q_table[row,col,action]==max_q:#         new_pi[row, col, action]=1/count#     else:#         new_pi[row, col, action] =0# greedya = np.argmax(q_table[row, col])for action in range(4):if action == a:new_pi[row, col, action] = 1else:new_pi[row, col, action] = 0return new_pi# 循环迭代策略评估和策略提升,寻找最优解
# 增加收敛判断,可提前终止
theta = 1e-6  # 收敛阈值
for _ in range(100):old_pi = pi.copy()  # 保存旧策略old_values = values.copy()# 策略评估:直到价值函数收敛while True:q_table = get_q_table()values = get_values()if np.all(np.abs(values - old_values) < theta):break  # 价值函数收敛,结束本轮评估old_values = values.copy()  # 注意:要更新旧价值,继续迭代# 策略提升:生成新策略pi = get_pi()if np.all(pi == old_pi):break  # 策略不再变化时终止# 打印结果
for row in range(4):for col in range(12):state = get_state(row,col)if state == 'terminal':print('🚩', end=' ')  # 终点elif state == 'trap':print('🪨', end=' ')  # 悬崖else:action = np.argmax(pi[row, col])if action == 0:print('⬆️',end=' ')elif action == 1:print('⬇️', end=' ')elif action == 2:print('👈',end=' ')else:print('👉',end=' ')print()

注意:

如果当前状态已经是掉进悬崖或者到达终点(终止状态),

则在状态转移move函数中,直接返回原地状态和reward=0(因为游戏结束了,不会再有状态转移,让它待在原地不动,没有奖惩);

在get_value函数中,终止状态的value为0(因为后续一直待在原地,奖惩一直为0);

在get_pi函数中,终止状态无需策略,后面打印时打印陷阱或终点的图标即可。

http://www.dtcms.com/a/325144.html

相关文章:

  • ai生成完成后语音通知
  • Starlink卫星终端对星策略是终端自主执行的还是网管中心调度的?
  • 如何部署图床系统 完整教程
  • python魔法属性__module__与__class__介绍
  • 学习numpy详解
  • Shell脚本-其他变量定义
  • 全面了解机器语言之kmeans
  • Redis缓存穿透、缓存击穿、缓存雪崩
  • Mock与Stub
  • 组合期权:水平价差
  • day29 消息队列
  • CST支持对哪些模型进行特征模仿真?分别有哪些用于特征模分析的求解器?
  • 信号处理函数中调用printf时,遇到中断为什么容易导致缓冲区损坏?
  • 介绍一下线程的生命周期及状态?
  • 化工设备健康管理解决方案:基于多物理场监测的智能化技术实现
  • 【系统分析师】软件需求工程——第11章学习笔记(上)
  • 堆(Java实现)
  • 大数据架构演变之路
  • [激光原理与应用-222]:机械 - 3D设计与2D设计的异同比较
  • 赋值运算符指南
  • GoBy 工具安装 | Windows 操作系统安装 GoBy
  • 某市智慧社区企业管理平台原型设计:数据驱动的社区治理新路径
  • 常用hook钩子函数
  • 设备活动审计技术方案解析
  • WSL创建虚拟机配置VNC
  • Linux系统编程——进程控制
  • 编程基础之多维数组——计算鞍点
  • 六、RuoYi-Cloud-Plus OSS文件上传配置
  • [Python 基础课程]常用函数
  • 数学与应用数学专业大学如何规划?就业前景怎么样?