当前位置: 首页 > news >正文

基于“动手学强化学习”的知识点(一):第 14 章 SAC 算法(gym版本 >= 0.26)

第 14 章 SAC 算法(gym版本 >= 0.26)

  • 摘要
  • SAC 算法(连续)
  • SAC 算法(离散)

摘要

本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习!


对应动手学强化学习——SAC 算法


SAC 算法(连续)

# -*- coding: utf-8 -*-


import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utils


class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        '''
        作用:保存动作幅度的界限,便于后续对动作做缩放。
        数值例子:若 action_bound=2,最终动作将会在 [-2, 2] 范围内。
        '''
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        '''
        作用:使用上面计算得到的 mu 和 std 构造正态分布对象 dist。
        数值例子:
        - 这时构造的分布为 𝑁(0.8,0.474^2)。
        '''
        dist = Normal(mu, std)
        '''
        作用:从正态分布中采样,但采用“重参数化采样”(rsample),以便后续能对采样过程进行梯度反传。
        数值例子:
        - 例如,若采样时随机变量 ε 从标准正态分布中取到 0.3,则采样值为 0.8 + 0.474 * 0.3 ≈ 0.8 + 0.1422 = 0.9422。
        '''
        normal_sample = dist.rsample()  # rsample()是重参数化采样
        '''作用:计算刚采样值在原始正态分布下的对数概率密度。'''
        log_prob = dist.log_prob(normal_sample)
        '''
        作用:对采样的原始动作进行 tanh 激活,将其映射到 (-1, 1) 范围内,保证动作平滑且有界。
        数值例子:
        - 对于采样值 0.9422,torch.tanh(0.9422) ≈ 0.737。
        '''
        action = torch.tanh(normal_sample)
        # 计算tanh_normal分布的对数概率密度
        '''
        作用:由于经过了 tanh 非线性变换,原来的对数概率密度需要进行修正(Jacobian 修正项),这里用公式
        logp_action=logp_normal−log(1−tanh(action)^2+ϵ)
        注意:实际应用中,通常是对 normal_sample 进行修正,写法可能略有不同,但这里的目标是一致的——补偿 tanh 变换带来的概率密度变换。
        '''
        log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)
        action = action * self.action_bound
        return action, log_prob


class QValueNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)
    
    
class SACContinuous:
    ''' 处理连续动作的SAC算法 '''
    """
    解释:
    - 定义一个名为 SACContinuous 的类,用来实现针对连续动作的 Soft Actor-Critic 算法。
    """
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device):
        """
        定义构造函数,接收一系列超参数,分别代表状态维度、隐藏层神经元个数、动作维度、动作界限、
        各网络的学习率、目标熵、软更新参数、折扣因子和设备。
        """
        '''
        # 策略网络
        使用前面定义的 PolicyNetContinuous 构造函数生成策略网络(actor),
        并将该网络放到指定设备上(例如 CPU 或 GPU)。
        '''
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim, action_bound).to(device)  
        '''
        # 第一个Q网络
        创建第一个 Q 网络,用于评估(状态,动作)对的价值,同样放到指定设备。
        '''
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device) 
        '''
        # 第二个Q网络
        创建第二个 Q 网络,与第一个结构相同,用于双重估计,帮助缓解过估计问题。
        '''
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)  
        '''
        # 第一个目标Q网络
        构造第一个目标 Q 网络,其结构与 critic_1 相同,用于计算目标值(TD目标),以便实现平滑更新。
        '''
        self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device) 
        '''
        # 第二个目标Q网络
        构造第二个目标 Q 网络,其结构与 critic_2 相同,用于目标值计算。
        '''
        self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)  
        '''
        # 令目标Q网络的初始参数和Q网络一样
        将 critic_1 网络的所有参数复制到 target_critic_1 中,使二者初始时完全一致。
        将 critic_2 网络的所有参数复制到 target_critic_2 中,使二者初始时完全一致。
        '''
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        '''
        使用 Adam 优化器为策略网络分配优化器,学习率为 actor_lr。
        '''
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        '''
        为 critic_1 分配 Adam 优化器,学习率为 critic_lr。
        为 critic_2 分配 Adam 优化器,学习率为 critic_lr。
        '''
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)
        # 使用alpha的log值,可以使训练结果比较稳定
        '''
        创建一个标量张量,用于存储温度参数 alpha 的对数值。初始值设为 log(0.01) ≈ -4.6052。
        这样做有助于稳定训练,因为直接优化正数会带来数值不稳定问题。
        '''
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        '''
        设置该张量的 requires_grad 属性为 True,表示在反向传播时会计算关于 log_alpha 的梯度,
        从而能更新温度参数。
        '''
        self.log_alpha.requires_grad = True  # 可以对alpha求梯度
        '''
        为 log_alpha 创建一个 Adam 优化器,学习率为 alpha_lr。
        注意优化器接收的是一个包含 log_alpha 的列表。
        '''
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
        '''保存目标熵参数,这个值用于指导策略更新时保持足够的探索性。'''
        self.target_entropy = target_entropy  # 目标熵的大小
        '''保存折扣因子,用于计算未来奖励的折现值。'''
        self.gamma = gamma
        '''保存软更新系数 tau,用于更新目标网络的参数。'''
        self.tau = tau
        '''保存设备信息,便于后续将数据和模型都放到同一设备上。'''
        self.device = device

    def take_action(self, state):
        """定义一个方法,根据当前状态输出一个动作(供环境交互时调用)。"""
        '''
        将传入的状态(例如一个列表或数组)转换为 PyTorch 张量,并在外面加一层列表以增加 batch 维度,
        然后将其放到指定设备上。
        state = [1,2,3,4]
        state = torch.tensor([state], dtype=torch.float).to("cuda")
        print(state) # tensor([[1., 2., 3., 4.]], device='cuda:0')
        state1 = [1,2,3,4]
        state1 = torch.tensor(state1, dtype=torch.float).to("cuda")
        print(state1) # tensor([1., 2., 3., 4.], device='cuda:0')
        state2 = [1,2,3,4]
        state2 = torch.tensor(state2, dtype=torch.float).unsqueeze(0).to("cuda")
        print(state2) # tensor([[1., 2., 3., 4.]], device='cuda:0')
        '''
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        '''
        解释:
        - 将状态输入 actor 网络,得到输出。由于 actor 的 forward 返回的是一个元组(动作、对数概率),
          这里取第一个元素(动作部分)。
        数值例子:
        - 假设 actor 返回 (tensor([[0.737]]), tensor([[-0.45]])),则 action = tensor([[0.737]]);
          再取 [0] 后得到单个样本的动作张量 tensor([0.737])。
        '''
        action = self.actor(state)[0]
        '''
        解释:
        - 将动作张量转换为 Python 标量,并放入列表后返回。
        数值例子:
        - action.item() 会返回 0.737,最终返回 [0.737]。这样可以适应环境要求动作为列表格式的情况。
        '''
        return [action.item()]

    def calc_target(self, rewards, next_states, dones):  # 计算目标Q值
        """
        定义一个方法,利用下一时刻状态、奖励和 done 标志计算 TD 目标(目标 Q 值),用于 critic 网络的回归训练。
        """
        '''
        对所有下一状态(通常是一个 batch),利用 actor 网络计算下一时刻动作和其对应的对数概率。
        数值例子:假设 next_states 有 2 个样本,每个样本状态为 3 维;actor 返回
        - next_actions = tensor([[1.2], [0.8]])
        - log_prob = tensor([[-0.5], [-0.6]])
        '''
        next_actions, log_prob = self.actor(next_states)
        '''
        计算熵项,实际上熵等于负的对数概率。
        数值例子:如果 log_prob = tensor([[-0.5], [-0.6]]),则 entropy = tensor([[0.5], [0.6]])。
        '''
        entropy = -log_prob
        '''使用目标网络1计算给定下一状态和对应动作的 Q 值。'''
        q1_value = self.target_critic_1(next_states, next_actions)
        '''使用目标网络2计算给定下一状态和对应动作的 Q 值。'''
        q2_value = self.target_critic_2(next_states, next_actions)
        '''
        解释:
        - 首先,取两个目标 Q 值的最小值(用来降低过估计风险);
        - 然后加上温度参数 alpha(由 self.log_alpha.exp() 得到)乘以熵项,这一项鼓励探索。
        数值例子:
        - 对第一样本:min(2.0, 2.5) = 2.0,且 self.log_alpha.exp() 计算为 exp(-4.6052) ≈ 0.01;
          熵为 0.5,则 next_value = 2.0 + 0.01 * 0.5 = 2.0 + 0.005 = 2.005。
        - 对第二样本:min(3.0, 3.5) = 3.0,熵为 0.6,
          则 next_value = 3.0 + 0.01 * 0.6 = 3.0 + 0.006 = 3.006。
        '''
        next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy
        '''
        计算 TD 目标:
                                td_target = 𝑟 + 𝛾 × next_value × (1−done)
        当 done 为 1(表示回合结束)时,不再折扣未来奖励。
        '''
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        """定义一个方法,用于对目标网络参数做软更新。传入当前网络和对应的目标网络。"""
        '''遍历目标网络和当前网络中对应的每一对参数(权重和偏置)。'''
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            '''
            对每个参数做软更新:
                                𝜃target←(1−𝜏)𝜃target+𝜏𝜃
                                θtarget←(1−τ)θ target+τθ
            这可以平滑地将目标网络参数向当前网络参数靠拢。
            '''
            param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)

    def update(self, transition_dict):
        """
        定义一个方法,根据从 replay buffer 中采样的转换数据(transition)更新 actor、critic 网络以及温度参数 alpha。"""
        '''将 transition_dict 中的状态数据转换为浮点型张量,并放到指定设备上。'''
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        '''
        同理,将动作数据转换为张量,并通过 view(-1, 1) 调整形状为 (batch_size, 1)(即每个动作为一个标量)。
        数值例子:若 transition_dict['actions'] = [1.0, 0.5],转换后形状为 (2, 1)。
        '''
        actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1, 1).to(self.device)
        '''
        将奖励数据转换为形状为 (batch_size, 1) 的张量。
        数值例子:若 rewards = [1.0, -0.5],则转换后为形状 (2, 1)。
        '''
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        '''
        将下一时刻的状态数据转换为张量。
        数值例子:例如 next_states = [[1.1, 0.4, -0.1], [0.2, 0.0, 0.9]],形状 (2, 3)。
        '''
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        '''
        将 done 标志(0 或 1)转换为形状为 (batch_size, 1) 的张量,用于指示回合是否结束。
        数值例子:若 dones = [0, 1],则转换后为 tensor([[0.0], [1.0]])。
        '''
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)
        # 和之前章节一样,对倒立摆环境的奖励进行重塑以便训练
        '''
        对奖励进行归一化或重塑,使其数值范围更适合训练。对于倒立摆(或类似)环境,
        原始奖励可能范围较大,这里将所有奖励平移 8.0 后除以 8.0。
        数值例子:
        - 如果原始 reward = -8.0,则 ( -8.0 + 8.0) / 8.0 = 0;
        - 如果 reward = 0,则变为 1;
        - 如果 reward = 8,则变为 2。
        '''
        rewards = (rewards + 8.0) / 8.0

        # 更新两个Q网络
        '''
        调用前面定义的 calc_target 方法,根据重塑后的奖励、下一状态和 done 标志计算 TD 目标。
        '''
        td_target = self.calc_target(rewards, next_states, dones)
        '''
        计算 critic_1 的均方误差(MSE)损失。
        - 调用 self.critic_1(states, actions) 得到当前 Q 值估计;
        - 使用 td_target.detach() 表示目标值不参与梯度计算;
        - 用 MSE 损失函数计算误差,再取平均。
        数值例子:
        - 假设 critic_1 输出 Q 值为 2.5,td_target 为 2.98,则误差为 (2.5−2.98)^2≈0.2304;
          对 batch 求均值。
        '''
        critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))
        '''清空 critic_1 优化器中所有累积的梯度,防止梯度累加。'''
        self.critic_1_optimizer.zero_grad()
        '''对 critic_1 损失进行反向传播,计算每个参数的梯度。'''
        critic_1_loss.backward()
        '''更新 critic_1 网络的参数,根据之前计算的梯度和设定的学习率进行一步更新。'''
        self.critic_1_optimizer.step()
        '''清空 critic_2 优化器中所有累积的梯度,防止梯度累加。'''
        self.critic_2_optimizer.zero_grad()
        '''对 critic_2 损失进行反向传播,计算每个参数的梯度。'''
        critic_2_loss.backward()
        '''更新 critic_2 网络的参数,根据之前计算的梯度和设定的学习率进行一步更新。'''
        self.critic_2_optimizer.step()

        # 更新策略网络
        '''使用当前 actor 网络,根据当前状态生成一组新的动作及其对数概率,用于策略更新。'''
        new_actions, log_prob = self.actor(states)
        '''计算熵项,即负的对数概率。'''
        entropy = -log_prob
        '''用当前 critic_1 网络评估新生成动作的 Q 值。'''
        q1_value = self.critic_1(states, new_actions)
        '''用当前 critic_2 网络评估新生成动作的 Q 值。'''
        q2_value = self.critic_2(states, new_actions)
        '''
        计算策略网络(actor)的损失。
        - 第一项:−𝛼 × entropy 用于鼓励策略探索;
        - 第二项:−min(𝑞1, 𝑞2) 表示希望选择高价值动作;
        - 取均值作为整个 batch 的损失。
        '''
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value))
        '''对 actor 网络进行梯度清零、反向传播和参数更新。'''
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新alpha值
        '''
        计算温度参数 alpha 的损失。
        - (entropy - self.target_entropy) 表示当前熵与目标熵之间的偏差;
        - 用 detach() 阻断梯度传递给 entropy(即仅更新 alpha);
        - 乘以当前的 𝛼 = exp{log(𝛼);
        - 取均值作为总体损失。
        '''
        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        '''清空 log_alpha 的梯度、反向传播损失并更新 log_alpha 参数。'''
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()
        '''
        调用之前定义的 soft_update 方法,对两个目标 Q 网络分别做软更新,
        使得目标网络参数慢慢跟随当前 Q 网络的更新。
        '''
        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)
    
env_name = 'Pendulum-v1'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # 动作最大值
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)

actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # 软更新参数
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,
                      actor_lr, critic_lr, alpha_lr, target_entropy, tau,
                      gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)    
    
    
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

SAC 算法(离散)

# -*- coding: utf-8 -*-



import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utils


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class QValueNet(torch.nn.Module):
    ''' 只有一层隐藏层的Q网络 '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
    
class SAC:
    ''' 处理离散动作的SAC算法 '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 alpha_lr, target_entropy, tau, gamma, device):
        # 策略网络
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        # 第一个Q网络
        self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        # 第二个Q网络
        self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_1 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)  # 第一个目标Q网络
        self.target_critic_2 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)  # 第二个目标Q网络
        # 令目标Q网络的初始参数和Q网络一样
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # 使用alpha的log值,可以使训练结果比较稳定
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # 可以对alpha求梯度
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # 目标熵的大小
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    # 计算目标Q值,直接用策略网络的输出概率进行期望计算
    def calc_target(self, rewards, next_states, dones):
        next_probs = self.actor(next_states)
        next_log_probs = torch.log(next_probs + 1e-8)
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        q1_value = self.target_critic_1(next_states)
        q2_value = self.target_critic_2(next_states)
        min_qvalue = torch.sum(next_probs * torch.min(q1_value, q2_value),
                               dim=1,
                               keepdim=True)
        next_value = min_qvalue + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)  # 动作不再是float类型
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # 更新两个Q网络
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_q_values = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(
            F.mse_loss(critic_1_q_values, td_target.detach()))
        critic_2_q_values = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(
            F.mse_loss(critic_2_q_values, td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # 更新策略网络
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        # 直接根据概率计算熵
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)  #
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value),
                               dim=1,
                               keepdim=True)  # 直接根据概率计算期望
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新alpha值
        alpha_loss = torch.mean(
            (entropy - target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)    
        
        
actor_lr = 1e-3
critic_lr = 1e-2
alpha_lr = 1e-2
num_episodes = 200
hidden_dim = 128
gamma = 0.98
tau = 0.005  # 软更新参数
buffer_size = 10000
minimal_size = 500
batch_size = 64
target_entropy = -1
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = SAC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, alpha_lr,
            target_entropy, tau, gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)



episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

相关文章:

  • 基本的WinDbg调试指令
  • SEO优先级矩阵:有限资源下的ROI最大化决策模型
  • 科技快讯 | “垃圾短信”可以被识别了;阿里正式推出AI旗舰应用;OpenAI深夜发布全新Agent工具
  • python数据分析文件夹篇--pandas,openpyxl,xlwings三种方法批量创建、 复制、删除工作表
  • JAVA中的多态性以及它在实际编程中的作用
  • 1141. 【贪心算法】排队打水
  • 【2025最新版】如何将fnm与node.js安装在D盘?【保姆级安装及人性话理解教程】
  • git submodule
  • 疗养院管理系统设计与实现(代码+数据库+LW)
  • 动态规划习题代码题解
  • 本地部署量化满血版本deepseek的Ktransformer清华方案的硬件配置
  • 【linux驱动开发】创建proc文件系统中的目录和文件实现
  • win10 win+shift+s 无法立即连续截图 第二次截图需要等很久
  • [RA-L 2023] Coco-LIC:基于非均匀 B 样条的连续时间紧密耦合 LiDAR-惯性-相机里程计
  • API自动化测试实战:Postman + Newman/Pytest的深度解析
  • 深度学习中学习率调整策略
  • java实现智能家居控制系统——入门版
  • vue3怎么和大模型交互?
  • spring security学习入门指引
  • Spring框架详解(IOC容器-上)
  • 古稀之年的设计家吴国欣:重拾水彩,触摸老上海文脉
  • 淮安市车桥中学党总支书记王习元逝世,终年51岁
  • 新版城市规划体检评估解读:把城市安全韧性摆在更加突出位置
  • 人民日报整版聚焦:外贸产品拓内销提速增量,多地加快推动内外贸一体化
  • 泽连斯基已离开土耳其安卡拉
  • 李强:把做强国内大循环作为推动经济行稳致远的战略之举