当前位置: 首页 > news >正文

强化学习入门三(SARSA)

SARSA算法详解

SARSA是强化学习中另一种经典的时序差分(TD)学习算法,与Q-Learning同属无模型(model-free)算法,但在更新策略上有显著差异。SARSA的名称来源于其更新公式中涉及的五个元素:状态(State)、动作(Action)、奖励(Reward)、下一状态(Next State)、下一动作(Next Action),即(S, A, R, S’, A’)。

SARSA与Q-Learning的核心区别
特性Q-LearningSARSA
学习方式异策略(Off-policy):学习最优策略,无论当前遵循什么策略同策略(On-policy):学习并遵循同一个策略
更新依据基于下一状态的最大Q值(不考虑实际会执行的下一动作)基于实际会执行的下一动作的Q值
适用场景更关注最终结果,适合追求最大累积奖励的场景更关注过程安全性,适合需要考虑执行路径的场景
SARSA的核心公式

SARSA的Q值更新公式如下:

Q(s,a) ← Q(s,a) + α[r + γ·Q(s',a') - Q(s,a)]

其中:

  • α是学习率(0 < α ≤ 1)
  • γ是折扣因子(0 ≤ γ ≤ 1)
  • r是即时奖励
  • s’是执行动作a后到达的新状态
  • a’是在新状态s’下实际会执行的动作(这是与Q-Learning的关键区别)
SARSA算法流程
  1. 初始化Q表(Q(s,a)),通常为0或随机小值
  2. 初始化状态s,根据当前策略选择动作a
  3. 当s不是终止状态时:
    a. 执行动作a,获得奖励r和新状态s’
    b. 根据当前策略选择新状态s’下的动作a’
    c. 使用SARSA更新公式更新Q(s,a)
    d. 将状态和动作更新为s’和a’
  4. 重复步骤2-3,直到Q表收敛

SARSA在机器人控制中的应用示例

以机器人避障导航为例,说明SARSA的应用。与Q-Learning不同,SARSA更适合这类需要考虑路径安全性的任务,因为它会考虑实际要执行的下一步动作,从而避免选择"看似最优但中间步骤危险"的路径。

import numpy as np
import matplotlib.pyplot as pltclass NavigationEnv:"""机器人导航环境"""def __init__(self):# 6x6网格世界: 0-空地, 1-障碍物, 2-目标self.grid = [[0, 0, 0, 1, 0, 0],[0, 1, 0, 1, 0, 1],[0, 1, 0, 0, 0, 0],[0, 1, 1, 1, 1, 0],[0, 0, 0, 0, 1, 0],[0, 1, 0, 0, 0, 2]]self.rows = 6self.cols = 6self.reset()def reset(self):"""重置环境,回到起点"""self.robot_pos = [0, 0]  # 起点位置return tuple(self.robot_pos)def step(self, action):"""执行动作,返回新状态、奖励和是否结束"""# 动作: 0-上, 1-右, 2-下, 3-左row, col = self.robot_posnew_row, new_col = row, coldone = False# 根据动作计算新位置if action == 0:  # 上new_row -= 1elif action == 1:  # 右new_col += 1elif action == 2:  # 下new_row += 1elif action == 3:  # 左new_col -= 1# 检查是否撞墙或越界if (new_row < 0 or new_row >= self.rows or new_col < 0 or new_col >= self.cols or self.grid[new_row][new_col] == 1):# 撞墙惩罚reward = -5done = False  # 撞墙不结束,只是惩罚else:# 移动到新位置self.robot_pos = [new_row, new_col]new_row, new_col = self.robot_pos# 检查是否到达目标if self.grid[new_row][new_col] == 2:reward = 100  # 到达目标的奖励done = Trueelse:# 每步轻微惩罚,鼓励最短路径reward = -1done = Falsereturn tuple(self.robot_pos), reward, donedef render(self):"""可视化当前环境状态"""for i in range(self.rows):for j in range(self.cols):if [i, j] == self.robot_pos:print("R", end=" ")  # 机器人elif self.grid[i][j] == 1:print("#", end=" ")  # 障碍物elif self.grid[i][j] == 2:print("G", end=" ")  # 目标else:print(".", end=" ")  # 空地print()print()class SARSA_Agent:"""基于SARSA算法的智能体"""def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):self.env = envself.alpha = alpha  # 学习率self.gamma = gamma  # 折扣因子self.epsilon = epsilon  # ε-贪婪策略参数# 初始化Q表self.q_table = {}for i in range(env.rows):for j in range(env.cols):self.q_table[(i, j)] = [0.0, 0.0, 0.0, 0.0]  # 四个动作的Q值def choose_action(self, state):"""基于ε-贪婪策略选择动作"""if np.random.uniform(0, 1) < self.epsilon:# 随机选择动作(探索)return np.random.choice(4)else:# 选择当前Q值最大的动作(利用)return np.argmax(self.q_table[state])def learn(self, state, action, reward, next_state, next_action):"""使用SARSA更新公式更新Q值"""# 当前Q值current_q = self.q_table[state][action]# 下一状态和动作的Q值(这是与Q-Learning的关键区别)next_q = self.q_table[next_state][next_action]# SARSA更新公式new_q = current_q + self.alpha * (reward + self.gamma * next_q - current_q)self.q_table[state][action] = new_qdef train(self, episodes=1000):"""训练智能体"""rewards = []  # 记录每回合的总奖励steps = []    # 记录每回合的步数for episode in range(episodes):state = self.env.reset()action = self.choose_action(state)  # 选择初始动作total_reward = 0step = 0done = Falsewhile not done:# 执行动作,获取反馈next_state, reward, done = self.env.step(action)# 选择下一动作(SARSA需要这一步)next_action = self.choose_action(next_state)# 更新Q值self.learn(state, action, reward, next_state, next_action)total_reward += rewardstate, action = next_state, next_action  # 转移到下一状态和动作step += 1# 防止无限循环if step > 200:breakrewards.append(total_reward)steps.append(step)# 每100回合打印一次进度if (episode + 1) % 100 == 0:print(f"Episode {episode+1}/{episodes}, Total Reward: {total_reward:.2f}, Steps: {step}")return rewards, stepsdef test(self):"""测试训练好的智能体"""state = self.env.reset()self.env.render()done = Falsestep = 0while not done and step < 100:action = np.argmax(self.q_table[state])  # 只使用利用,不探索state, _, done = self.env.step(action)self.env.render()step += 1# 主程序
if __name__ == "__main__":# 创建环境和智能体env = NavigationEnv()agent = SARSA_Agent(env, alpha=0.1, gamma=0.9, epsilon=0.1)# 训练智能体print("开始训练...")rewards, steps = agent.train(episodes=1000)# 绘制训练曲线plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(rewards)plt.title("每回合总奖励")plt.xlabel("回合数")plt.ylabel("总奖励")plt.subplot(1, 2, 2)plt.plot(steps)plt.title("每回合步数")plt.xlabel("回合数")plt.ylabel("步数")plt.tight_layout()plt.show()# 测试训练好的智能体print("测试训练好的智能体:")agent.test()

代码解析

上述代码实现了基于SARSA算法的机器人避障导航系统,主要包含两个核心类:

  1. NavigationEnv类:定义了包含障碍物的导航环境

    • 6x6网格世界,包含空地、障碍物和目标点
    • 提供环境交互接口(reset、step、render)
    • 奖励机制:到达目标(+100)、撞墙(-5)、每步移动(-1)
  2. SARSA_Agent类:实现SARSA算法

    • 维护Q表存储状态-动作价值
    • choose_action():基于ε-贪婪策略选择动作
    • learn():使用SARSA公式更新Q值(需要next_action参数)
    • train():多回合训练过程
    • test():验证训练效果

SARSA的关键特性分析

  1. 同策略学习
    SARSA在学习过程中遵循的策略与它要优化的策略是同一个,这使得它学习到的策略更符合实际执行情况,尤其在需要考虑安全性的场景中更具优势。

  2. 对路径的关注
    由于SARSA考虑实际会执行的下一步动作,它倾向于学习更"保守"的路径。例如在避障任务中,SARSA可能会选择远离障碍物的路径,即使这不是最短路径,而Q-Learning可能会选择距离障碍物更近的最短路径。

  3. 探索与利用的平衡
    与Q-Learning类似,SARSA也使用ε-贪婪策略平衡探索和利用,但由于同策略特性,其探索行为会直接影响学习目标。

SARSA的适用场景

  1. 需要安全保障的机器人控制:如自动驾驶、工业机器人操作等,这些场景中过程安全性比单纯追求最优结果更重要。

  2. 连续决策问题:如机器人导航、游戏AI等需要一系列连贯动作的任务。

  3. 部分可观测环境:在环境信息不完全的情况下,SARSA的同策略特性使其能更好地适应实际执行的策略。

总之,SARSA是一种注重执行过程的强化学习算法,在需要考虑动作序列连贯性和安全性的任务中表现优异,是Q-Learning的重要补充。

http://www.dtcms.com/a/296502.html

相关文章:

  • 专题:2025微短剧行业生态构建与跨界融合研究报告|附100+份报告PDF汇总下载
  • LeetCode 1695.删除子数组的最大得分:滑动窗口(哈希表)
  • 07 RK3568 Debian11 网络优先级
  • “抓了个寂寞”:一次实时信息采集的意外和修复
  • 网络基础19--OSPF路由协议(上)
  • 基于QT(C++)实现(图形界面)通讯录系统
  • JavaJSP
  • 【SpringAI实战】FunctionCalling实现企业级自定义智能客服
  • Qt 调用ocx的详细步骤
  • 单片机学习课程
  • 数据推荐丨海天瑞声7月数据集上新啦!
  • 海外红人营销的下一站:APP出海如何布局虚拟网红与UGC生态?
  • idea监控本地堆栈
  • Redis分布式锁的学习(八)
  • 无源域自适应综合研究【2】
  • Qt连接MySql数据库
  • SAP B1 DTW成功登录后点击下一步提示没有权限读取清单
  • QML 模型
  • 阿里云SLS未开启索引时无法查询日志内容
  • 11.事务
  • 【GoLang#1】:Go 语言概述(背景 | 环境配置 | 特点 | 学习)
  • Redis单线程模型(含面试题)
  • pytorch常用函数
  • 【MySQL数据库备份与恢复1】二进制日志,mysqlbinlog
  • Linux Wlan 无线网络驱动开发-scan协议全流程详解
  • 企业安全基石:解锁等保测评的战略价值
  • 循环神经网络--LSTM模型
  • 15.2 DeepSpeed显存优化实战:7B大模型训练资源从84GB压缩到10GB!
  • 11-day08文本匹配
  • Cisco 主模式配置