当前位置: 首页 > news >正文

【matlab】SARSA算法及示例代码

参考链接1:Sarsa算法
参考链接2:强化学习:时间差分(TD)(SARSA算法和Q-Learning算法)(看不懂算我输专栏)——手把手教你入门强化学习(六)
参考链接3:【强化学习】Sarsa+Sarsa-lambda(Sarsa(λ))算法详解

SARSA(State-Action-Reward-State-Action)算法是一种用于强化学习的算法,它属于时序差分(Temporal-Difference,TD)学习方法。

算法原理:基本思想:SARSA算法的核心是通过智能体(Agent)与环境(Environment)的交互来学习最优的行动策略。智能体在每个时间步根据当前策略选择一个动作,执行该动作后会获得环境的反馈,包括新的状态和奖励。然后,智能体会根据这些信息来更新其对当前状态和动作的价值估计,从而逐步改进策略。

示例:

以下是一个简单的SARSA算法的MATLAB示例代码,用于解决一个简单的网格世界问题。
在这个问题中,智能体需要从起点移动到终点,同时避免障碍物。
问题描述
网格世界:一个5x5的网格,智能体从左上角(1,1)开始,目标是到达右下角(5,5)。
动作:智能体可以向上、向下、向左、向右移动。
奖励:到达目标位置获得+10的奖励,撞到障碍物或边界获得-1的奖励,其他情况下获得0奖励。
障碍物:假设障碍物坐标为[3,2]、[3,3]、[3,4]、[3,5]。

代码:
主代码:

% SARSA算法示例代码
function main()
clc
clear
close all
% 以下是一个简单的SARSA算法的MATLAB示例代码,用于解决一个简单的网格世界问题。
% 在这个问题中,智能体需要从起点移动到终点,同时避免障碍物。
% 问题描述
%     网格世界:一个5x5的网格,智能体从左上角(1,1)开始,目标是到达右下角(5,5)。
%     动作:智能体可以向上、向下、向左、向右移动。
%     奖励:到达目标位置获得+10的奖励,撞到障碍物或边界获得-1的奖励,其他情况下获得0奖励。
%     障碍物:假设障碍物坐标为[3,2]、[3,3]、[3,4]、[3,5]。%%
% 初始化参数
grid_size = 5; % 网格大小
num_states = grid_size * grid_size; % 状态总数
num_actions = 4; % 动作总数(上、下、左、右)
alpha = 0.1; % 学习率
gamma = 0.9; % 折扣因子
epsilon = 0.1; % 探索概率
num_episodes = 1000; % 训练的总轮数% 定义动作
actions = [0, 1; 1, 0; 0, -1; -1, 0]; % 上、右、下、左% 初始化Q表
Q = zeros(num_states, num_actions);% 定义障碍物和目标位置
obstacle = {[3,1],[3,3],[3,4],[3,5]};%必须用花括号,最终结果中应不包含obstacle内的点
goal = [5, 5];%%
% 训练SARSA算法
for episode = 1:num_episodes% 初始化状态state = [1, 1];state_index = state_to_index(state, grid_size);% 选择初始动作action = choose_action(state_index, Q, epsilon,num_actions);while ~isequal(state, goal)% 执行动作,观察下一个状态和奖励[next_state, reward] = take_action(state, action, grid_size, obstacle, goal,actions);next_state_index = state_to_index(next_state, grid_size);% 选择下一个动作next_action = choose_action(next_state_index, Q, epsilon,num_actions);% 更新Q值Q(state_index, action) = Q(state_index, action) + ...alpha * (reward + gamma * Q(next_state_index, next_action) - Q(state_index, action));% 更新状态和动作state = next_state;state_index = next_state_index;action = next_action;end
end% 打印Q表
disp('Q表:');
disp(Q);%%
% 测试策略
state = [1, 1];
while ~isequal(state, goal)state_index = state_to_index(state, grid_size);[~, action] = max(Q(state_index, :));[next_state, ~] = take_action(state, action, grid_size, obstacle, goal,actions);fprintf('从位置 (%d, %d) 执行动作 %d 到达位置 (%d, %d)\n', ...state(1), state(2), action, next_state(1), next_state(2));state = next_state;
end
fprintf('到达目标位置 (%d, %d)\n', goal(1), goal(2));

take_action.m

% 执行动作,返回下一个状态和奖励
function [next_state, reward] = take_action(state, action, grid_size, obstacle, goal,actions)next_state = state + actions(action, :);isInCell = false;for i = 1:length(obstacle)if isequal(obstacle{i}, next_state)isInCell = true;break; % 找到后退出循环endend% 检查是否超出边界if next_state(1) < 1 || next_state(1) > grid_size || next_state(2) < 1 || next_state(2) > grid_sizereward = -1;next_state = state; % 保持在原位置% 检查是否撞到障碍物elseif isInCell%isequal(next_state, obstacle)reward = -1;% 检查是否到达目标elseif isequal(next_state, goal)reward = 10;elsereward = 0;end
end

choose_action.m

% 选择动作(epsilon-greedy策略)
function action = choose_action(state_index, Q, epsilon,num_actions)if rand < epsilonaction = randi(num_actions); % 随机选择动作else[~, action] = max(Q(state_index, :)); % 选择最优动作end
end

index_to_state.m

% 将一维索引转换为二维坐标
function state = index_to_state(state_index, grid_size)state(2) = floor((state_index - 1) / grid_size) + 1;state(1) = mod(state_index - 1, grid_size) + 1;
end

state_to_index.m

% 将状态从二维坐标转换为一维索引
function state_index = state_to_index(state, grid_size)state_index = (state(2) - 1) * grid_size + state(1);
end

文章转载自:

http://D2RXAAWO.gpryk.cn
http://U474nQeo.gpryk.cn
http://4WxMNIaY.gpryk.cn
http://HkstNdpx.gpryk.cn
http://LxRvJklL.gpryk.cn
http://p2NuNIsb.gpryk.cn
http://NijKp5sg.gpryk.cn
http://7n9WxJuD.gpryk.cn
http://aBSMJfOt.gpryk.cn
http://8Jm8wwHE.gpryk.cn
http://6fOd0zfc.gpryk.cn
http://4hLmYl6n.gpryk.cn
http://7yuYzcqI.gpryk.cn
http://VE5TYHiU.gpryk.cn
http://XmJjSvLb.gpryk.cn
http://hnSVeiMz.gpryk.cn
http://HlI71Gck.gpryk.cn
http://gqFxiqdz.gpryk.cn
http://ULGaQHBa.gpryk.cn
http://IqsL74Ob.gpryk.cn
http://3oB7O12Y.gpryk.cn
http://mZtPNX0Z.gpryk.cn
http://DgHbztqQ.gpryk.cn
http://k64W0Zsg.gpryk.cn
http://BBwrURu2.gpryk.cn
http://bnXnjRto.gpryk.cn
http://fVRRkfUT.gpryk.cn
http://ebG5hKlt.gpryk.cn
http://whXfPz85.gpryk.cn
http://5XtSAvWV.gpryk.cn
http://www.dtcms.com/a/365881.html

相关文章:

  • 服务器搭建日记(十二):创建专用用户通过 Navicat 远程连接 MySQL
  • 红外人体感应(PIR)传感器介绍
  • Linux磁盘inode使用率打满问题处理方案
  • 硬盘 (FOREIGN) Slot:Unconfigured Bad
  • 41. 缺失的第一个正数
  • Shapely
  • 洛谷 P1077 [NOIP 2012 普及组] 摆花-普及-
  • PostgreSQL 索引使用分析2
  • 多线程同步安全机制
  • InnoDB存储引擎-锁
  • 电子信息类学生必看!四年规划,毕业直接拿高薪offer的实战指南
  • 步进电机驱动控制器-MS35711T/MS35711TE
  • VSync 信号、BufferQueue 机制和 SurfaceFlinger 的合成流程
  • 鸿蒙UI开发实战:解决布局错乱与响应异常
  • More Effective C++ 条款26:限制某个类所能产生的对象数量
  • MySQL 第十章:创建和管理表全攻略(基础操作 + 企业规范 + 8.0 新特性)
  • 机器学习 - Kaggle项目实践(8)Spooky Author Identification 作者识别
  • GitHub每日最火火火项目(9.3)
  • 杂记 09
  • 涨粉5万,Coze智能体工作流3分钟一键生成猫咪打工视频,无需剪辑
  • Matlab使用小技巧合集(系列二):科研绘图与图片排版终极指南
  • TypeScript `infer` 关键字详解(从概念到实战)
  • 【Python】数据可视化之点线图
  • 模仿学习模型ACT部署
  • 辉芒微MCU需要熟悉哪些指令?这15条核心指令与入门要点必须掌握
  • Linux gzip 命令详解:从基础到高级用法
  • Python基础(①①Ctypes)
  • C 内存对齐踩坑记录
  • 【随手记】vscode中C语言满足KR风格的方法
  • Elasticsearch核心数据类型