[Study Notes] Reinforcement Learning from Principles to Practice
Video link: https://www.youtube.com/watch?v=D0ylO5qzIv0
Environment setup
pip uninstall opencv-python opencv-contrib-python opencv-python-headless
pip install opencv-python
pip install gymnasium
pip install "gymnasium[toy-text]"
Tabular reinforcement learning methods
The frozen lake game: FrozenLake-v1
import gymnasium as gym


class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('FrozenLake-v1',
                       render_mode='rgb_array',
                       is_slippery=False)
        super().__init__(env)
        self.env = env

    def reset(self):
        # drop the info dict, return only the observation
        state, _ = self.env.reset()
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        over = terminated or truncated

        # reward shaping: -1 for every step, -100 for falling into a hole
        if not over:
            reward = -1
        if over and reward == 0:
            reward = -100

        return state, reward, over

    def show(self):
        from matplotlib import pyplot as plt
        plt.figure(figsize=(3, 3))
        plt.imshow(self.env.render())
        plt.savefig("rl_frozenlake")
        # plt.show()


env = MyWrapper()
env.reset()
action = 1
next_state, reward, over = env.step(action)
print(next_state, reward, over)
env.show()

"""
Output:
4 -1 False
"""

Q-Learning and SARSA
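Both methods learn a table Q of shape (16, 4) with the temporal-difference update Q(s, a) += lr * (target - Q(s, a)); they differ only in the target. Q-Learning bootstraps from the best action in the next state, target = r + gamma * max_a' Q(s', a'), while SARSA bootstraps from the action the behavior policy actually takes next, target = r + gamma * Q(s', a'). A minimal sketch of the two targets, with gamma = 0.9 as in the code below and a made-up transition:

import numpy as np

Q = np.zeros((16, 4))                        # hypothetical Q table, for illustration only
state, action, reward, next_state = 0, 1, -1, 4
next_action = 1                              # the action actually taken in next_state
gamma, lr = 0.9, 0.1

q_learning_target = reward + gamma * Q[next_state].max()    # off-policy
sarsa_target = reward + gamma * Q[next_state, next_action]  # on-policy

Q[state, action] += lr * (q_learning_target - Q[state, action])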
#import gym
import gymnasium as gym
from IPython import display
import random
import numpy as np


class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('FrozenLake-v1',
                       render_mode='rgb_array',
                       is_slippery=False)
        super().__init__(env)
        self.env = env

    def reset(self):
        state, _ = self.env.reset()
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        over = terminated or truncated

        # reward shaping: -1 for every step, -100 for falling into a hole
        if not over:
            reward = -1
        if over and reward == 0:
            reward = -100

        return state, reward, over

    def show(self):
        from matplotlib import pyplot as plt
        plt.figure(figsize=(3, 3))
        plt.imshow(self.env.render())
        plt.savefig("rl_ch03_frozenlake")
        plt.show()


class Pool:
    # replay pool of (state, action, reward, next_state, over) tuples
    def __init__(self):
        self.pool = []

    def __len__(self):
        return len(self.pool)

    def __getitem__(self, i):
        return self.pool[i]

    def update(self):
        # collect at least 200 new transitions per call
        old_len = len(self.pool)
        while len(self.pool) - old_len < 200:
            self.pool.extend(play()[0])
        # print("len:", len(pool), old_len)

        # keep only the most recent 10000 transitions
        self.pool = self.pool[-1_0000:]

    def sample(self):
        return random.choice(self.pool)


def play(show=False):
    # play one epsilon-greedy episode; return the transitions and the total reward
    data = []
    reward_sum = 0

    state = env.reset()
    over = False
    while not over:
        action = Q[state].argmax()
        if random.random() < 0.1:
            action = env.action_space.sample()
        # action: 0: left, 1: down, 2: right, 3: up

        next_state, reward, over = env.step(action)

        data.append((state, action, reward, next_state, over))
        reward_sum += reward

        state = next_state

        if show:
            display.clear_output(wait=True)
            env.show()

    return data, reward_sum


def train():
    for epoch in range(1000):
        pool.update()

        for i in range(200):
            state, action, reward, next_state, over = pool.sample()

            value = Q[state, action]

            # Q-Learning
            target = reward + Q[next_state].max() * 0.9
            # SARSA
            # next_action = Q[next_state].argmax()
            # target = reward + Q[next_state, next_action] * 0.9

            update = (target - value) * 0.1
            Q[state, action] += update

        if epoch % 100 == 0:
            print(epoch, len(pool), play()[-1])


def train_sarsa():
    for epoch in range(1000):
        for (state, action, reward, next_state, over) in play()[0]:
            value = Q[state, action]

            # note: the next action is chosen greedily from Q here rather than
            # taken from the recorded trajectory
            next_action = Q[next_state].argmax()
            target = reward + Q[next_state, next_action] * 0.9

            update = (target - value) * 0.02
            Q[state, action] += update

        if epoch % 100 == 0:
            print(epoch, play()[-1])


env = MyWrapper()
Q = np.zeros((16, 4))
pool = Pool()

"""
env.reset()
action = 2 # 0: left, 1: down, 2: right, 3: up
next_state, reward, over = env.step(action)
print(next_state, reward, over)
env.show()
""""""
data, reward = play()
print(len(data), data, reward)
""""""
pool.update()
print(len(pool), pool[0])
"""train()
print(Q)
data, reward = play()
print(len(data), data, reward)
exit()
play(True)[-1]

train_sarsa()
Output:
Q:
[[  -4.154149     -3.50461      -3.50461      -4.154149  ]
 [  -4.15414906 -100.           -2.7829       -3.50502366]
 [  -3.50542088   -1.981        -3.50572718   -2.78303427]
 [  -2.78332601  -89.05810109   -3.50993427   -3.50647229]
 [  -3.50461      -2.7829     -100.           -4.154149  ]
 [   0.            0.            0.            0.        ]
 [ -98.80274848   -1.09        -99.36373146   -2.78301387]
 [   0.            0.            0.            0.        ]
 [  -2.7829     -100.           -1.981        -3.50461   ]
 [  -2.7829       -1.09         -1.09       -100.        ]
 [  -1.981        -0.1        -100.           -1.981     ]
 [   0.            0.            0.            0.        ]
 [   0.            0.            0.            0.        ]
 [-100.           -1.09         -0.1          -1.981     ]
 [  -1.09         -0.1           1.           -1.09      ]
 [   0.            0.            0.            0.        ]]

play:
6 [(0, np.int64(1), -1, 4, False), (4, np.int64(1), -1, 8, False), (8, np.int64(2), -1, 9, False), (9, np.int64(1), -1, 13, False), (13, np.int64(2), -1, 14, False), (14, np.int64(2), 1, 15, True)] -4
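The final rollout reaches the goal in 6 steps with a total shaped reward of -4 (five -1 step penalties plus +1 at the goal). To see how reliable the learned greedy policy is, a small evaluation sketch (a hypothetical helper, not from the video) can replay many purely greedy episodes and report the success rate:

def evaluate(episodes=100):
    # replay the greedy policy with no exploration, using env and Q from above
    wins = 0
    for _ in range(episodes):
        state = env.reset()
        over = False
        while not over:
            action = Q[state].argmax()
            state, reward, over = env.step(action)
        # with the shaped rewards, reaching the goal ends the episode with reward 1
        wins += int(reward == 1)
    return wins / episodes

print("greedy success rate:", evaluate())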
