强化学习入门三(SARSA)
SARSA算法详解
SARSA是强化学习中另一种经典的时序差分(TD)学习算法,与Q-Learning同属无模型(model-free)算法,但在更新策略上有显著差异。SARSA的名称来源于其更新公式中涉及的五个元素:状态(State)、动作(Action)、奖励(Reward)、下一状态(Next State)、下一动作(Next Action),即(S, A, R, S’, A’)。
SARSA与Q-Learning的核心区别
特性 | Q-Learning | SARSA |
---|---|---|
学习方式 | 异策略(Off-policy):学习最优策略,无论当前遵循什么策略 | 同策略(On-policy):学习并遵循同一个策略 |
更新依据 | 基于下一状态的最大Q值(不考虑实际会执行的下一动作) | 基于实际会执行的下一动作的Q值 |
适用场景 | 更关注最终结果,适合追求最大累积奖励的场景 | 更关注过程安全性,适合需要考虑执行路径的场景 |
SARSA的核心公式
SARSA的Q值更新公式如下:
Q(s,a) ← Q(s,a) + α[r + γ·Q(s',a') - Q(s,a)]
其中:
- α是学习率(0 < α ≤ 1)
- γ是折扣因子(0 ≤ γ ≤ 1)
- r是即时奖励
- s’是执行动作a后到达的新状态
- a’是在新状态s’下实际会执行的动作(这是与Q-Learning的关键区别)
SARSA算法流程
- 初始化Q表(Q(s,a)),通常为0或随机小值
- 初始化状态s,根据当前策略选择动作a
- 当s不是终止状态时:
a. 执行动作a,获得奖励r和新状态s’
b. 根据当前策略选择新状态s’下的动作a’
c. 使用SARSA更新公式更新Q(s,a)
d. 将状态和动作更新为s’和a’ - 重复步骤2-3,直到Q表收敛
SARSA在机器人控制中的应用示例
以机器人避障导航为例,说明SARSA的应用。与Q-Learning不同,SARSA更适合这类需要考虑路径安全性的任务,因为它会考虑实际要执行的下一步动作,从而避免选择"看似最优但中间步骤危险"的路径。
import numpy as np
import matplotlib.pyplot as pltclass NavigationEnv:"""机器人导航环境"""def __init__(self):# 6x6网格世界: 0-空地, 1-障碍物, 2-目标self.grid = [[0, 0, 0, 1, 0, 0],[0, 1, 0, 1, 0, 1],[0, 1, 0, 0, 0, 0],[0, 1, 1, 1, 1, 0],[0, 0, 0, 0, 1, 0],[0, 1, 0, 0, 0, 2]]self.rows = 6self.cols = 6self.reset()def reset(self):"""重置环境,回到起点"""self.robot_pos = [0, 0] # 起点位置return tuple(self.robot_pos)def step(self, action):"""执行动作,返回新状态、奖励和是否结束"""# 动作: 0-上, 1-右, 2-下, 3-左row, col = self.robot_posnew_row, new_col = row, coldone = False# 根据动作计算新位置if action == 0: # 上new_row -= 1elif action == 1: # 右new_col += 1elif action == 2: # 下new_row += 1elif action == 3: # 左new_col -= 1# 检查是否撞墙或越界if (new_row < 0 or new_row >= self.rows or new_col < 0 or new_col >= self.cols or self.grid[new_row][new_col] == 1):# 撞墙惩罚reward = -5done = False # 撞墙不结束,只是惩罚else:# 移动到新位置self.robot_pos = [new_row, new_col]new_row, new_col = self.robot_pos# 检查是否到达目标if self.grid[new_row][new_col] == 2:reward = 100 # 到达目标的奖励done = Trueelse:# 每步轻微惩罚,鼓励最短路径reward = -1done = Falsereturn tuple(self.robot_pos), reward, donedef render(self):"""可视化当前环境状态"""for i in range(self.rows):for j in range(self.cols):if [i, j] == self.robot_pos:print("R", end=" ") # 机器人elif self.grid[i][j] == 1:print("#", end=" ") # 障碍物elif self.grid[i][j] == 2:print("G", end=" ") # 目标else:print(".", end=" ") # 空地print()print()class SARSA_Agent:"""基于SARSA算法的智能体"""def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):self.env = envself.alpha = alpha # 学习率self.gamma = gamma # 折扣因子self.epsilon = epsilon # ε-贪婪策略参数# 初始化Q表self.q_table = {}for i in range(env.rows):for j in range(env.cols):self.q_table[(i, j)] = [0.0, 0.0, 0.0, 0.0] # 四个动作的Q值def choose_action(self, state):"""基于ε-贪婪策略选择动作"""if np.random.uniform(0, 1) < self.epsilon:# 随机选择动作(探索)return np.random.choice(4)else:# 选择当前Q值最大的动作(利用)return np.argmax(self.q_table[state])def learn(self, state, action, reward, next_state, next_action):"""使用SARSA更新公式更新Q值"""# 当前Q值current_q = self.q_table[state][action]# 下一状态和动作的Q值(这是与Q-Learning的关键区别)next_q = self.q_table[next_state][next_action]# SARSA更新公式new_q = current_q + self.alpha * (reward + self.gamma * next_q - current_q)self.q_table[state][action] = new_qdef train(self, episodes=1000):"""训练智能体"""rewards = [] # 记录每回合的总奖励steps = [] # 记录每回合的步数for episode in range(episodes):state = self.env.reset()action = self.choose_action(state) # 选择初始动作total_reward = 0step = 0done = Falsewhile not done:# 执行动作,获取反馈next_state, reward, done = self.env.step(action)# 选择下一动作(SARSA需要这一步)next_action = self.choose_action(next_state)# 更新Q值self.learn(state, action, reward, next_state, next_action)total_reward += rewardstate, action = next_state, next_action # 转移到下一状态和动作step += 1# 防止无限循环if step > 200:breakrewards.append(total_reward)steps.append(step)# 每100回合打印一次进度if (episode + 1) % 100 == 0:print(f"Episode {episode+1}/{episodes}, Total Reward: {total_reward:.2f}, Steps: {step}")return rewards, stepsdef test(self):"""测试训练好的智能体"""state = self.env.reset()self.env.render()done = Falsestep = 0while not done and step < 100:action = np.argmax(self.q_table[state]) # 只使用利用,不探索state, _, done = self.env.step(action)self.env.render()step += 1# 主程序
if __name__ == "__main__":# 创建环境和智能体env = NavigationEnv()agent = SARSA_Agent(env, alpha=0.1, gamma=0.9, epsilon=0.1)# 训练智能体print("开始训练...")rewards, steps = agent.train(episodes=1000)# 绘制训练曲线plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(rewards)plt.title("每回合总奖励")plt.xlabel("回合数")plt.ylabel("总奖励")plt.subplot(1, 2, 2)plt.plot(steps)plt.title("每回合步数")plt.xlabel("回合数")plt.ylabel("步数")plt.tight_layout()plt.show()# 测试训练好的智能体print("测试训练好的智能体:")agent.test()
代码解析
上述代码实现了基于SARSA算法的机器人避障导航系统,主要包含两个核心类:
-
NavigationEnv类:定义了包含障碍物的导航环境
- 6x6网格世界,包含空地、障碍物和目标点
- 提供环境交互接口(reset、step、render)
- 奖励机制:到达目标(+100)、撞墙(-5)、每步移动(-1)
-
SARSA_Agent类:实现SARSA算法
- 维护Q表存储状态-动作价值
- choose_action():基于ε-贪婪策略选择动作
- learn():使用SARSA公式更新Q值(需要next_action参数)
- train():多回合训练过程
- test():验证训练效果
SARSA的关键特性分析
-
同策略学习:
SARSA在学习过程中遵循的策略与它要优化的策略是同一个,这使得它学习到的策略更符合实际执行情况,尤其在需要考虑安全性的场景中更具优势。 -
对路径的关注:
由于SARSA考虑实际会执行的下一步动作,它倾向于学习更"保守"的路径。例如在避障任务中,SARSA可能会选择远离障碍物的路径,即使这不是最短路径,而Q-Learning可能会选择距离障碍物更近的最短路径。 -
探索与利用的平衡:
与Q-Learning类似,SARSA也使用ε-贪婪策略平衡探索和利用,但由于同策略特性,其探索行为会直接影响学习目标。
SARSA的适用场景
-
需要安全保障的机器人控制:如自动驾驶、工业机器人操作等,这些场景中过程安全性比单纯追求最优结果更重要。
-
连续决策问题:如机器人导航、游戏AI等需要一系列连贯动作的任务。
-
部分可观测环境:在环境信息不完全的情况下,SARSA的同策略特性使其能更好地适应实际执行的策略。
总之,SARSA是一种注重执行过程的强化学习算法,在需要考虑动作序列连贯性和安全性的任务中表现优异,是Q-Learning的重要补充。