【matlab】SARSA算法及示例代码
参考链接1:Sarsa算法
参考链接2:强化学习:时间差分(TD)(SARSA算法和Q-Learning算法)(看不懂算我输专栏)——手把手教你入门强化学习(六)
参考链接3:【强化学习】Sarsa+Sarsa-lambda(Sarsa(λ))算法详解
SARSA(State-Action-Reward-State-Action)算法是一种用于强化学习的算法,它属于时序差分(Temporal-Difference,TD)学习方法。
算法原理:基本思想:SARSA算法的核心是通过智能体(Agent)与环境(Environment)的交互来学习最优的行动策略。智能体在每个时间步根据当前策略选择一个动作,执行该动作后会获得环境的反馈,包括新的状态和奖励。然后,智能体会根据这些信息来更新其对当前状态和动作的价值估计,从而逐步改进策略。
示例:
以下是一个简单的SARSA算法的MATLAB示例代码,用于解决一个简单的网格世界问题。
在这个问题中,智能体需要从起点移动到终点,同时避免障碍物。
问题描述
网格世界:一个5x5的网格,智能体从左上角(1,1)开始,目标是到达右下角(5,5)。
动作:智能体可以向上、向下、向左、向右移动。
奖励:到达目标位置获得+10的奖励,撞到障碍物或边界获得-1的奖励,其他情况下获得0奖励。
障碍物:假设障碍物坐标为[3,2]、[3,3]、[3,4]、[3,5]。
代码:
主代码:
% SARSA算法示例代码
function main()
clc
clear
close all
% 以下是一个简单的SARSA算法的MATLAB示例代码,用于解决一个简单的网格世界问题。
% 在这个问题中,智能体需要从起点移动到终点,同时避免障碍物。
% 问题描述
% 网格世界:一个5x5的网格,智能体从左上角(1,1)开始,目标是到达右下角(5,5)。
% 动作:智能体可以向上、向下、向左、向右移动。
% 奖励:到达目标位置获得+10的奖励,撞到障碍物或边界获得-1的奖励,其他情况下获得0奖励。
% 障碍物:假设障碍物坐标为[3,2]、[3,3]、[3,4]、[3,5]。%%
% 初始化参数
grid_size = 5; % 网格大小
num_states = grid_size * grid_size; % 状态总数
num_actions = 4; % 动作总数(上、下、左、右)
alpha = 0.1; % 学习率
gamma = 0.9; % 折扣因子
epsilon = 0.1; % 探索概率
num_episodes = 1000; % 训练的总轮数% 定义动作
actions = [0, 1; 1, 0; 0, -1; -1, 0]; % 上、右、下、左% 初始化Q表
Q = zeros(num_states, num_actions);% 定义障碍物和目标位置
obstacle = {[3,1],[3,3],[3,4],[3,5]};%必须用花括号,最终结果中应不包含obstacle内的点
goal = [5, 5];%%
% 训练SARSA算法
for episode = 1:num_episodes% 初始化状态state = [1, 1];state_index = state_to_index(state, grid_size);% 选择初始动作action = choose_action(state_index, Q, epsilon,num_actions);while ~isequal(state, goal)% 执行动作,观察下一个状态和奖励[next_state, reward] = take_action(state, action, grid_size, obstacle, goal,actions);next_state_index = state_to_index(next_state, grid_size);% 选择下一个动作next_action = choose_action(next_state_index, Q, epsilon,num_actions);% 更新Q值Q(state_index, action) = Q(state_index, action) + ...alpha * (reward + gamma * Q(next_state_index, next_action) - Q(state_index, action));% 更新状态和动作state = next_state;state_index = next_state_index;action = next_action;end
end% 打印Q表
disp('Q表:');
disp(Q);%%
% 测试策略
state = [1, 1];
while ~isequal(state, goal)state_index = state_to_index(state, grid_size);[~, action] = max(Q(state_index, :));[next_state, ~] = take_action(state, action, grid_size, obstacle, goal,actions);fprintf('从位置 (%d, %d) 执行动作 %d 到达位置 (%d, %d)\n', ...state(1), state(2), action, next_state(1), next_state(2));state = next_state;
end
fprintf('到达目标位置 (%d, %d)\n', goal(1), goal(2));
take_action.m
% 执行动作,返回下一个状态和奖励
function [next_state, reward] = take_action(state, action, grid_size, obstacle, goal,actions)next_state = state + actions(action, :);isInCell = false;for i = 1:length(obstacle)if isequal(obstacle{i}, next_state)isInCell = true;break; % 找到后退出循环endend% 检查是否超出边界if next_state(1) < 1 || next_state(1) > grid_size || next_state(2) < 1 || next_state(2) > grid_sizereward = -1;next_state = state; % 保持在原位置% 检查是否撞到障碍物elseif isInCell%isequal(next_state, obstacle)reward = -1;% 检查是否到达目标elseif isequal(next_state, goal)reward = 10;elsereward = 0;end
end
choose_action.m
% 选择动作(epsilon-greedy策略)
function action = choose_action(state_index, Q, epsilon,num_actions)if rand < epsilonaction = randi(num_actions); % 随机选择动作else[~, action] = max(Q(state_index, :)); % 选择最优动作end
end
index_to_state.m
% 将一维索引转换为二维坐标
function state = index_to_state(state_index, grid_size)state(2) = floor((state_index - 1) / grid_size) + 1;state(1) = mod(state_index - 1, grid_size) + 1;
end
state_to_index.m
% 将状态从二维坐标转换为一维索引
function state_index = state_to_index(state, grid_size)state_index = (state(2) - 1) * grid_size + state(1);
end