
RL【8】:Value Function Approximation

Series Table of Contents

Fundamental Tools

RL【1】:Basic Concepts
RL【2】:Bellman Equation
RL【3】:Bellman Optimality Equation

Algorithm

RL【4】:Value Iteration and Policy Iteration
RL【5】:Monte Carlo Learning
RL【6】:Stochastic Approximation and Stochastic Gradient Descent

Method

RL【7-1】:Temporal-difference Learning
RL【7-2】:Temporal-difference Learning


Table of Contents

  • Series Table of Contents
    • Fundamental Tools
    • Algorithm
    • Method
  • Preface
  • Algorithm for state value estimation
    • Objective function
    • Optimization algorithms
    • Selection of function approximators
    • Theoretical analysis
  • Sarsa & Q-learning with function approximation
    • Sarsa with function approximation
    • Q-learning with function approximation
  • Deep Q-learning
  • Summary


Preface

This series records my study notes for Professor Shiyu Zhao's course "Mathematical Foundations of Reinforcement Learning" on Bilibili. For the original course content, see:
Bilibili videos: 【强化学习的数学原理】课程:从零开始到透彻理解(完结)
GitHub course materials: Book-Mathematical-Foundation-of-Reinforcement-Learning


Algorithm for state value estimation

Objective function

Formal introduction

  • Let $v_\pi(s)$ be the true state value and $\hat v(s,w)$ a parameterized function used to approximate it.
  • Our goal is to find an optimal $w$ so that $\hat v(s,w)$ can best approximate $v_\pi(s)$ for every $s$.
  • This is a policy evaluation problem. Later we will extend it to policy improvement.
  • To find the optimal $w$, we need two steps.
    • The first step is to define an objective function.
    • The second step is to derive algorithms that optimize the objective function.

Background: value estimation with function approximation

  • In practice, the state space can be very large (or even continuous), so we cannot store a separate $v_\pi(s)$ for every state.
  • We therefore use a function approximator (e.g., a linear function or a neural network) $\hat v(s,w)$ to approximate the true $v_\pi(s)$.
  • Goal: find parameters $w$ such that $\hat v(s,w)$ is as close as possible to $v_\pi(s)$.

Objective function

$J(w)=\mathbb{E}\left[(v_\pi(S)-\hat v(S,w))^2\right].$

  • Our goal is to find the best $w$ that minimizes $J(w)$.
  • The expectation is with respect to the random variable $S\in\mathcal{S}$.

Several ways to define the probability distribution of $S$

  • The first way is to use a uniform distribution.
    • That is, treat all states as equally important by setting the probability of each state to $1/|\mathcal{S}|$.

    • In this case, the objective function becomes

      $J(w) = \mathbb{E}[(v_\pi(S) - \hat v(S, w))^2] = \frac{1}{|\mathcal{S}|} \sum_{s \in \mathcal{S}} (v_\pi(s) - \hat v(s, w))^2.$

    • Drawback:

      • The states may not be equally important. For example, some states may be rarely visited by a policy. Hence, this way does not consider the real dynamics of the Markov process under the given policy.
  • The second way is to use the stationary distribution.
    • The stationary distribution is an important concept that will be used frequently in this course. In short, it describes the long-run behavior of a Markov process.

    • Let $\{d_\pi(s)\}_{s \in \mathcal{S}}$ denote the stationary distribution of the Markov process under policy $\pi$. By definition, $d_\pi(s) \ge 0$ and $\sum_{s \in \mathcal{S}} d_\pi(s) = 1$.

    • The objective function can be rewritten as

      $J(w) = \mathbb{E}[(v_\pi(S) - \hat v(S, w))^2] = \sum_{s \in \mathcal{S}} d_\pi(s)(v_\pi(s) - \hat v(s, w))^2.$

    • This is a weighted squared error.

    • Since more frequently visited states have larger $d_\pi(s)$, their weights in the objective function are higher than those of rarely visited states.

Choosing the distribution of the state: two options

Key question: with respect to which distribution of $S$ is the expectation $\mathbb{E}$ taken? The choice determines which states the trained approximator is more accurate on.

  1. Uniform distribution
    • Approach: treat all states as equally important and assign each state the same probability:

      $P(S=s) = \frac{1}{|\mathcal{S}|}.$

    • The objective function becomes:

      $J(w) = \frac{1}{|\mathcal{S}|}\sum_{s\in\mathcal{S}} \big(v_\pi(s) - \hat v(s,w)\big)^2.$

    • Advantages

      • Simple and intuitive; every state is treated equally.
    • Disadvantages

      • Not realistic. Some states occur very rarely (e.g., rare scenes in a game); forcing the approximator to fit them well wastes model capacity.
      • It does not reflect the real dynamics of the Markov process under policy $\pi$.
  2. Stationary distribution
    • Approach: use the probability $d_\pi(s)$ of being in each state after the environment has run under policy $\pi$ for a long time (i.e., the stationary distribution).

    • The objective function becomes:

      $J(w) = \sum_{s \in \mathcal{S}} d_\pi(s)\,\big(v_\pi(s) - \hat v(s,w)\big)^2.$

    • Advantages

      • Closer to reality: when the agent actually runs, it encounters some states frequently and others almost never.
      • The estimates on these high-frequency states are more accurate, which improves performance at execution time.
    • Disadvantages

      • Rarely visited states may be fitted poorly, which matters when some important states occur only rarely.

Stationary Distribution:

  • Distribution: Distribution of the state
  • Stationary: Long-run behavior
  • Summary: after the agent runs a long time following a policy, the probability that the agent is at any state can be described by this distribution.

Basic concepts related to the stationary distribution

1. Definition of the stationary distribution

  • In a Markov process or a Markov decision process (MDP), the agent keeps transitioning between states over time.

  • The stationary distribution is the probability distribution over states after the agent has run for a sufficiently long time. (A small numerical sketch of computing it is given after this list.)

  • Mathematically, if $\{d_\pi(s)\}_{s \in \mathcal{S}}$ is the stationary distribution under policy $\pi$, then

    $d_\pi(s') = \sum_{s \in \mathcal{S}} d_\pi(s) \sum_{a \in \mathcal{A}} \pi(a|s) P(s'|s,a),$

    • and

      $\sum_{s \in \mathcal{S}} d_\pi(s) = 1, \quad d_\pi(s) \ge 0.$

  1. Related basic concepts
    1. Distribution
      • Literally, the probability distribution of some variable.
      • Here it is the state distribution: the probability that the agent is in each state $s \in \mathcal{S}$.
    2. Stationary
      • Refers to the long-run, stable regime.
      • As $t \to \infty$, the state distribution converges to a fixed value and no longer fluctuates over time.
      • In other words, the state distribution has reached an equilibrium.
    3. Steady-state distribution / limiting distribution
      • Synonyms: the stationary distribution is also called the steady-state distribution or the limiting distribution.
      • The emphasis is that it is a stable distribution in the long-run, limiting sense.
  2. Its role in reinforcement learning
    1. Value function approximation
      • In value function approximation, the objective function is defined as an expectation:

        $J(w) = \mathbb{E}_{S \sim d_\pi}\big[(v_\pi(S) - \hat v(S,w))^2\big],$

        • where $d_\pi$ is the stationary distribution.
      • Frequently visited states therefore receive larger weights, which better matches how the policy actually behaves.

    2. Policy gradient
      • In policy gradient methods, the performance objective is usually written as:

        $J(\pi) = \sum_s d_\pi(s) \sum_a \pi(a|s) q_\pi(s,a).$

      • Here $d_\pi(s)$ is the long-run probability that the agent is in state $s$ under policy $\pi$.

      • The stationary distribution is therefore a core ingredient of policy gradient methods.

    3. Intuition
      • After the agent has followed a policy $\pi$ for a long time:
        • frequently visited states have higher probability under $d_\pi(s)$;
        • rarely visited states have probability close to $0$.
      • So $d_\pi(s)$ reflects which states actually matter under that policy.
  3. Summary:
    • The stationary distribution describes the long-run probability of visiting each state under policy $\pi$.
    • It is also called the steady-state distribution or the limiting distribution.
    • It plays a key role in value function approximation and policy gradient methods, because it determines the weight of each state in the optimization.

Optimization algorithms

Gradient Descent for Value Function Approximation

  • Now that we have the objective function, the next step is to optimize it.

  • To minimize the objective function $J(w)$, we can use the gradient-descent algorithm:

    $w_{k+1} = w_k - \alpha_k \nabla_w J(w_k)$

  • The true gradient is:

    $\nabla_w J(w) = \nabla_w \mathbb{E}[(v_\pi(S) - \hat v(S,w))^2]$

    $= \mathbb{E}[\nabla_w (v_\pi(S) - \hat v(S,w))^2]$

    $= 2\mathbb{E}[(v_\pi(S) - \hat v(S,w))(-\nabla_w \hat v(S,w))]$

    $= -2\mathbb{E}[(v_\pi(S) - \hat v(S,w))\nabla_w \hat v(S,w)]$

  • The true gradient above involves the calculation of an expectation.

Stochastic Gradient

  • We can use the stochastic gradient to replace the true gradient:

    $w_{t+1} = w_t + \alpha_t (v_\pi(s_t) - \hat v(s_t, w_t)) \nabla_w \hat v(s_t, w_t),$

    • where $s_t$ is a sample of $S$. Here, the coefficient $2$ has been merged into $\alpha_t$.
  • This algorithm is not implementable because it requires the true state value $v_\pi$, which is exactly the unknown quantity to be estimated.

  • We can replace $v_\pi(s_t)$ with an approximation so that the algorithm becomes implementable.

Monte Carlo and TD Learning with Function Approximation

  • First, Monte Carlo learning with function approximation

    Let $g_t$ be the discounted return starting from $s_t$ in the episode. Then $g_t$ can be used to approximate $v_\pi(s_t)$. The algorithm becomes:

    $w_{t+1} = w_t + \alpha_t (g_t - \hat v(s_t, w_t)) \nabla_w \hat v(s_t, w_t).$

  • Second, TD learning with function approximation

    In the spirit of TD learning, $r_{t+1} + \gamma \hat v(s_{t+1}, w_t)$ can be viewed as an approximation of $v_\pi(s_t)$. Then the algorithm becomes:

    $w_{t+1} = w_t + \alpha_t [r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t)] \nabla_w \hat v(s_t, w_t).$

Detailed explanation

  1. Why gradient descent?

    • We have an objective function:

      $J(w)=\mathbb{E}[(v_\pi(S)-\hat v(S,w))^2].$

    • The goal is to minimize the error between the true state value $v_\pi(S)$ and the approximation $\hat v(S,w)$.

    • This is essentially a regression problem: fit a function $\hat v$ to approximate the true $v_\pi$.

    • We can therefore use the most common optimization method, gradient descent:

      $w_{k+1} = w_k - \alpha_k \nabla_w J(w_k).$

      • That is, each step updates the parameters $w$ so that $J(w)$ gradually decreases.
  2. Meaning of the true gradient

    • Expanding with the chain rule:

      $\nabla_w J(w) = -2\mathbb{E}[(v_\pi(S) - \hat v(S,w)) \nabla_w \hat v(S,w)].$

    • Interpretation:

      • The error term $(v_\pi(S) - \hat v(S,w))$ is the gap between the prediction and the true value.
      • Multiplying by $\nabla_w \hat v(S,w)$ tells us how to adjust the parameters $w$ to shrink that gap.
      • The sign means: if the prediction is smaller than the true value, increase $\hat v$; otherwise decrease it.
    • This is exactly the same as regression in standard supervised learning.

  3. Why stochastic gradient?

    • The problem is:

      • the true gradient $\nabla_w J(w)$ involves an expectation over all states $S$;
      • this is usually infeasible because the state space is huge and $v_\pi(S)$ is unknown.
    • We therefore use SGD (stochastic gradient descent) instead:

      $w_{t+1} = w_t + \alpha_t (v_\pi(s_t) - \hat v(s_t,w_t)) \nabla_w \hat v(s_t,w_t),$

      • where $s_t$ is a sampled state.
        • Benefit: a single sample is enough for an update, so the cost is low.
        • Problem: it still requires $v_\pi(s_t)$, which is exactly the unknown we want to estimate.
  4. How to replace $v_\pi(s_t)$?

    Since $v_\pi(s_t)$ is not directly available, we need a quantity that approximates it:

    1. Monte Carlo learning with function approximation
      • Use the full episode return $g_t$ as an unbiased estimate of $v_\pi(s_t)$.

      • Update rule:

        $w_{t+1} = w_t + \alpha_t (g_t - \hat v(s_t, w_t)) \nabla_w \hat v(s_t, w_t).$

      • Intuition:

        • $g_t$ is the cumulative reward obtained by following the episode from $s_t$ to the end.
        • Substitute $g_t$ for $v_\pi(s_t)$, then take a gradient step.
        • Drawback: we must wait for the whole trajectory to finish before updating, and the variance is large.
    2. TD learning with function approximation
      • Use the TD target $r_{t+1} + \gamma \hat v(s_{t+1}, w_t)$ to approximate $v_\pi(s_t)$.

      • Update rule:

        $w_{t+1} = w_t + \alpha_t \big[r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t)\big] \nabla_w \hat v(s_t, w_t).$

      • Intuition:

        • No need to wait for the whole trajectory; only the one-step reward $r_{t+1}$ plus the prediction at the next state is used.
        • This is bootstrapping: existing estimates are used to help the update.
        • Advantages: online learning, fast updates, low variance; drawback: it may introduce bias.

Pseudocode: TD learning with function approximation

  • Initialization: A function $\hat v(s,w)$ that is differentiable in $w$. Initial parameter $w_0$.
  • Aim: Approximate the true state values of a given policy $\pi$.
  • For each episode generated following the policy $\pi$, do
    • For each step $(s_t, r_{t+1}, s_{t+1})$, do
      • In the general case,

        $w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big] \nabla_w \hat v(s_t, w_t)$

      • In the linear case,

        $w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \phi^T(s_{t+1}) w_t - \phi^T(s_t) w_t \big] \phi(s_t)$
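Below is a minimal Python sketch of the pseudocode above for the linear case. It assumes episodes are supplied as lists of $(s_t, r_{t+1}, s_{t+1})$ transitions and that `phi` is a user-chosen feature map; both are illustrative assumptions, not part of the original pseudocode.

```python
import numpy as np

def td_linear(episodes, phi, feature_dim, alpha=0.05, gamma=0.9):
    """Semi-gradient TD(0) with linear approximation v_hat(s, w) = phi(s)^T w.

    episodes : iterable of trajectories, each a list of (s, r, s_next) tuples
    phi      : callable mapping a state to a feature vector of length feature_dim
    """
    w = np.zeros(feature_dim)
    for episode in episodes:
        for s, r, s_next in episode:
            v_s = phi(s) @ w
            v_next = phi(s_next) @ w
            td_error = r + gamma * v_next - v_s
            w = w + alpha * td_error * phi(s)   # gradient of v_hat w.r.t. w is phi(s)
    return w
```

With one-hot features this reduces to tabular TD(0), as discussed in the next subsection.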

Selection of function approximators

  1. Function selection

    • The first approach, which was widely used before, is to use a linear function

      $\hat v(s, w) = \phi^T(s) w$

      Here, $\phi(s)$ is the feature vector, which can be a polynomial basis, a Fourier basis, etc.

    • The second approach, which is widely used nowadays, is to use a neural network as a nonlinear function approximator. The input of the NN is the state, the output is $\hat v(s,w)$, and the network parameter is $w$.

  2. TD-Linear

    • In the linear case where $\hat v(s, w) = \phi^T(s) w$, we have

      $\nabla_w \hat v(s, w) = \phi(s).$

    • Substituting the gradient into the TD algorithm

      $w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big] \nabla_w \hat v(s_t, w_t)$

    • yields

      $w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \phi^T(s_{t+1}) w_t - \phi^T(s_t) w_t \big] \phi(s_t),$

    • which is the algorithm of TD learning with linear function approximation (TD-Linear).

  3. Disadvantages and Advantages of linear function approximation

    • Disadvantages of linear function approximation:
      • Difficult to select appropriate feature vectors.
    • Advantages of linear function approximation:
      • The theoretical properties of the TD algorithm in the linear case can be much better understood than in the nonlinear case.
      • Linear function approximation is still powerful in the sense that the tabular representation is merely a special case of linear function approximation.
  4. Tabular representation as a special case of linear function approximation

    We next show that the tabular representation is a special case of linear function approximation.

    • First, consider the special feature vector for state $s$:

      $\phi(s) = e_s \in \mathbb{R}^{|\mathcal{S}|},$

      • where $e_s$ is a vector whose $s$-th entry is $1$ and whose other entries are $0$.

    • In this case,

      $\hat v(s, w) = e_s^T w = w(s),$

      • where $w(s)$ is the $s$-th entry of $w$.
  5. Connection with tabular TD

    • Recall that the TD-Linear algorithm is

      $w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \phi^T(s_{t+1}) w_t - \phi^T(s_t) w_t \big] \phi(s_t).$

    • When $\phi(s_t) = e_{s_t}$, the above algorithm becomes

      $w_{t+1} = w_t + \alpha_t \big( r_{t+1} + \gamma w_t(s_{t+1}) - w_t(s_t) \big) e_{s_t}.$

      • This is a vector equation that only updates the $s_t$-th entry of $w_t$.
    • Multiplying both sides by $e_{s_t}^T$ gives

      $w_{t+1}(s_t) = w_t(s_t) + \alpha_t \big( r_{t+1} + \gamma w_t(s_{t+1}) - w_t(s_t) \big),$

    • which is exactly the tabular TD algorithm.
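To see this reduction concretely, the following sketch runs the TD-Linear update with one-hot features $\phi(s)=e_s$ on a few made-up transitions; the learned vector $w$ then plays the role of the state-value table. All numbers are illustrative.

```python
import numpy as np

num_states, alpha, gamma = 4, 0.1, 0.9

def one_hot(s):
    """phi(s) = e_s, the s-th canonical basis vector."""
    e = np.zeros(num_states)
    e[s] = 1.0
    return e

w = np.zeros(num_states)                           # one parameter per state
episode = [(0, 0.0, 1), (1, 0.0, 2), (2, 1.0, 3)]  # made-up (s, r, s') transitions

for s, r, s_next in episode:
    td_error = r + gamma * (one_hot(s_next) @ w) - (one_hot(s) @ w)
    w = w + alpha * td_error * one_hot(s)          # only the s-th entry changes

# Equivalent tabular form of the same update:
#   w[s] += alpha * (r + gamma * w[s_next] - w[s])
print(w)                                           # w(s) acts as the tabular estimate v(s)
```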

Choosing an approximator

  1. Linear model or neural network?
    • Linear approximation: $\hat v(s,w)=\phi(s)^\top w$

      Hand-crafted features $\phi(s)$ (polynomial, Fourier, tile coding, one-hot, ...) map the state to a low-dimensional vector, and only a weight vector $w$ is learned.

      • When to prefer it
        • The state space is small or can be represented well by features;
        • interpretability and convergence guarantees matter (especially in the on-policy setting);
        • compute or data is limited and a robust, low-variance learner is needed.
    • Nonlinear approximation (NN): learn $\hat v(s,w)=\text{NN}(s;w)$ directly.

      • When to prefer it
        • The raw state is high-dimensional or nonlinear (images, text, complex sensors);
        • the goal is end-to-end deep RL;
        • enough data and compute are available, and the tuning cost of less stable training is acceptable.
  2. The essence of TD-Linear: semi-gradient + projected fixed point
    • In the linear case $\nabla_w \hat v(s,w)=\phi(s)$. Substituting into the TD(0) update:

      $w_{t+1}=w_t+\alpha_t\Big[r_{t+1}+\gamma\,\hat v(s_{t+1},w_t)-\hat v(s_t,w_t)\Big]\,\phi(s_t).$

    • This is a semi-gradient method: the TD target $r_{t+1}+\gamma \hat v(s_{t+1},w_t)$ is treated as a constant when differentiating with respect to $w$ (no gradient flows into $\hat v$ at the next state). Differentiating through the target gives a "full-gradient" algorithm, which in practice tends to be less stable.

    • Geometrically, it solves the projected Bellman equation

      $\Phi w \approx \Pi_{d_\pi}\, \mathcal T_\pi(\Phi w),$

      • where the column space of $\Phi$ is the function class spanned by the features and $\Pi_{d_\pi}$ is the least-squares projection weighted by the stationary distribution $d_\pi$.
      • Meaning: the environment pushes you toward $\mathcal T_\pi v$, but you can only stay in the representable subspace, so the result is projected back onto it.
    • Convergence (on-policy, linear, suitable step sizes): $\mathcal T_\pi$ is a contraction, and under mild conditions semi-gradient TD(0) converges to the projected fixed point above.

  3. Why is the tabular case a special case of linear approximation?
    • Take one-hot features $\phi(s)=e_s$. Then

      $\hat v(s,w)=e_s^\top w = w(s),$

      • i.e., one parameter per state.
    • Substituting back into TD-Linear:

      $w_{t+1}=w_t+\alpha_t\big(r_{t+1}+\gamma w_t(s_{t+1})-w_t(s_t)\big)e_{s_t}.$

    • Only the $s_t$-th component is updated; multiplying both sides by $e_{s_t}^\top$ gives

      $w_{t+1}(s_t)=w_t(s_t)+\alpha_t\big(r_{t+1}+\gamma w_t(s_{t+1})-w_t(s_t)\big),$

    • which is exactly tabular TD(0).

    • Conclusion: tabular = linear approximation + one-hot features.

  4. Relation to the objective function J(w)
    • In the linear, on-policy case, semi-gradient TD does not directly minimize

      $J(w)=\mathbb E_{S\sim d_\pi}\big[(v_\pi(S)-\hat v(S,w))^2\big],$

      • but instead approaches the projected Bellman solution. The two are generally different, yet on many problems this solution is both computable and effective.
    • If you really want to minimize $J(w)$, you need access to $v_\pi$ or a Monte Carlo return that approximates it, which brings you back to the MC-with-function-approximation update

      $w_{t+1}=w_t+\alpha_t\,(g_t-\hat v(s_t,w_t))\,\phi(s_t),$

      • which has higher variance but targets $J(w)$ directly.

Theoretical analysis

  1. The algorithm

    $w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big] \nabla_w \hat v(s_t, w_t)$

    does not minimize the following objective function:

    $J(w) = \mathbb{E}\big[ ( v_\pi(S) - \hat v(S, w) )^2 \big]$

  2. Different objective functions

    • Objective function 1: True value error

      $J_E(w) = \mathbb{E}\big[ ( v_\pi(S) - \hat v(S, w) )^2 \big] = \| \hat v(w) - v_\pi \|^2_D$

    • Objective function 2: Bellman error

      $J_{BE}(w) = \| \hat v(w) - (r_\pi + \gamma P_\pi \hat v(w)) \|^2_D \doteq \| \hat v(w) - T_\pi(\hat v(w)) \|^2_D$

      • where

        $T_\pi(x) \doteq r_\pi + \gamma P_\pi x$

    • Objective function 3: Projected Bellman error

      $J_{PBE}(w) = \| \hat v(w) - M T_\pi(\hat v(w)) \|^2_D$

      • where $M$ is a projection matrix.

The gap between these objectives

  1. True value error

    $J_E(w) = \mathbb{E}\big[(v_\pi(S) - \hat v(S, w))^2\big] = \| \hat v(w) - v_\pi \|^2_D$

    • Meaning: directly minimize the gap between the approximate value function $\hat v(s,w)$ and the true value function $v_\pi(s)$.
    • The ideal objective: the most natural and intuitive one (similar to supervised learning).
    • Problem: we do not know $v_\pi(s)$; we can only approximate it indirectly through samples and the Bellman equation, so this objective cannot be minimized directly.
  2. Bellman error

    $J_{BE}(w) = \| \hat v(w) - T_\pi(\hat v(w)) \|^2_D$

    • where

      $T_\pi(x) = r_\pi + \gamma P_\pi x$

    • Meaning: measures how much $\hat v(s,w)$ violates the Bellman equation.

      • The fixed point of the Bellman equation is $v_\pi$.
      • If $\hat v(w)$ lies in a "perfect" function space, minimizing the Bellman error recovers the true value function.
      • Problem: when the function approximator (e.g., a linear function or a neural network) cannot represent $v_\pi$ exactly, directly minimizing the Bellman error may lead to poor solutions.
  3. Projected Bellman error

    $J_{PBE}(w) = \| \hat v(w) - M T_\pi(\hat v(w)) \|^2_D$

    • where $M$ is a projection matrix that projects the Bellman update $T_\pi(\hat v)$ back into the approximation space.
    • Meaning: since the approximation space (e.g., a linear function space) is limited, we cannot guarantee that $\hat v(w)$ satisfies the Bellman equation exactly. We therefore require the projected Bellman update to be as close as possible to $\hat v(w)$.
    • Essence: find the approximate solution in the function space that is closest to the Bellman fixed point.
    • Importance: this is the objective that TD-Linear actually optimizes. The TD update implicitly performs this projection, so its convergence point minimizes the projected Bellman error rather than the true value error. (A small numerical sketch of the three errors is given after this list.)
  4. Why does TD-Linear correspond to the projected Bellman error?

    • TD update:

      $w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big]\nabla_w \hat v(s_t, w_t)$

    • The update uses the TD target $r_{t+1} + \gamma \hat v(s_{t+1}, w_t)$, which amounts to projecting the Bellman update $T_\pi(\hat v)$ back into the approximation space.

    • Therefore, TD does not directly minimize $J_E(w)$ or $J_{BE}(w)$; it minimizes the projected Bellman error.

  5. Summary

    • True value error: the ideal objective, but it cannot be computed directly.
    • Bellman error: measures inconsistency with the Bellman equation, but can be unstable under function approximation.
    • Projected Bellman error: the objective TD actually minimizes; it yields the most reasonable solution within the approximation space and guarantees convergence.
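To make the three objectives concrete, the sketch below evaluates $J_E$, $J_{BE}$, and $J_{PBE}$ for a small, made-up Markov reward process with a hypothetical feature matrix $\Phi$. Here $D=\mathrm{diag}(d_\pi)$ and the projection matrix is $M=\Phi(\Phi^\top D\Phi)^{-1}\Phi^\top D$; all numbers are illustrative assumptions, not taken from the lecture.

```python
import numpy as np

# Made-up 3-state Markov reward process under a fixed policy.
P_pi = np.array([[0.5, 0.5, 0.0],
                 [0.0, 0.5, 0.5],
                 [0.5, 0.0, 0.5]])
r_pi = np.array([0.0, 0.0, 1.0])
gamma = 0.9

# Stationary distribution d_pi (left eigenvector for eigenvalue 1) and weight matrix D.
eigvals, eigvecs = np.linalg.eig(P_pi.T)
d_pi = np.real(eigvecs[:, np.argmin(np.abs(eigvals - 1.0))])
d_pi = d_pi / d_pi.sum()
D = np.diag(d_pi)

# True values and a 2-dimensional linear feature matrix Phi (hypothetical features).
v_pi = np.linalg.solve(np.eye(3) - gamma * P_pi, r_pi)
Phi = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
M = Phi @ np.linalg.inv(Phi.T @ D @ Phi) @ Phi.T @ D   # projection onto span(Phi) in ||.||_D

def weighted_sq_norm(x):
    return float(x @ D @ x)

w = np.zeros(2)                                  # any candidate parameter vector
v_hat = Phi @ w
T_v_hat = r_pi + gamma * P_pi @ v_hat            # Bellman operator applied to v_hat

J_E = weighted_sq_norm(v_hat - v_pi)             # true value error
J_BE = weighted_sq_norm(v_hat - T_v_hat)         # Bellman error
J_PBE = weighted_sq_norm(v_hat - M @ T_v_hat)    # projected Bellman error
print(J_E, J_BE, J_PBE)
```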

Sarsa & Q-learning with function approximation

Sarsa with function approximation

Sarsa algorithm

  • So far, we have only considered the problem of state value estimation. That is, we hope

    $\hat v \approx v_\pi$

  • To search for optimal policies, we need to estimate action values.

  • The Sarsa algorithm with value function approximation is

    $w_{t+1} = w_t + \alpha_t \Big[r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t)\Big] \nabla_w \hat q(s_t, a_t, w_t).$

  • This is the same as the state-value algorithm introduced earlier in this lecture, except that $\hat v$ is replaced by $\hat q$.

Pseudocode: Sarsa with function approximation

  • Aim: Search for a policy that can lead the agent to the target from an initial state-action pair $(s_0, a_0)$.
  • For each episode, do
    • If the current $s_t$ is not the target state, do
      • Take action $a_t$ following $\pi_t(s_t)$, generate $r_{t+1}, s_{t+1}$, and then take action $a_{t+1}$ following $\pi_t(s_{t+1})$

      • Value update (parameter update):

        $w_{t+1} = w_t + \alpha_t \Big[r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t)\Big]\nabla_w \hat q(s_t, a_t, w_t)$

      • Policy update:

        $\pi_{t+1}(a|s_t) = 1 - \frac{\varepsilon}{|\mathcal{A}(s_t)|} (|\mathcal{A}(s_t)| - 1) \quad \text{if } a = \arg\max_{a \in \mathcal{A}(s_t)} \hat q(s_t, a, w_{t+1})$

        $\pi_{t+1}(a|s_t) = \frac{\varepsilon}{|\mathcal{A}(s_t)|} \quad \text{otherwise}$

Sarsa with function approximation

  • Update rule:

    $w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t).$

  • Interpretation (a minimal sketch with linear features follows this list):

    • Here $\hat q(s, a, w)$ is an approximate action-value function represented by parameters $w$ (e.g., a linear function or a neural network).

    • The TD target uses the action $a_{t+1}$ that is actually executed at the next step:

      $r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t)$

    • This means Sarsa is an on-policy algorithm:

      • the behavior policy $\pi$ is used both to generate data (choosing $a_t$ and $a_{t+1}$),
      • and to update the value function.
    • The update direction is determined by the TD error:

      $\delta_t = r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t)$

      • and a gradient correction is then applied to the parameters of the current $\hat q(s_t, a_t, w_t)$.
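A minimal sketch of Sarsa with linear features $\hat q(s,a,w)=\phi(s,a)^\top w$ is given below. The environment interface (`env.reset()` returning a state, `env.step(a)` returning `(s_next, r, done)`) and the feature map `phi` are assumptions made for illustration only.

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """Greedy action with prob. 1 - eps + eps/|A|, otherwise a uniformly random action."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

def sarsa_linear(env, phi, feature_dim, num_actions,
                 alpha=0.05, gamma=0.9, epsilon=0.1, episodes=500, seed=0):
    """Semi-gradient Sarsa with q_hat(s, a, w) = phi(s, a)^T w."""
    rng = np.random.default_rng(seed)
    w = np.zeros(feature_dim)

    def q(s, a):
        return phi(s, a) @ w

    for _ in range(episodes):
        s = env.reset()
        a = epsilon_greedy([q(s, b) for b in range(num_actions)], epsilon, rng)
        done = False
        while not done:
            s_next, r, done = env.step(a)
            if done:
                target = r                       # no bootstrap at the terminal step
            else:
                a_next = epsilon_greedy([q(s_next, b) for b in range(num_actions)],
                                        epsilon, rng)
                target = r + gamma * q(s_next, a_next)   # TD target uses the executed a_next
            w = w + alpha * (target - q(s, a)) * phi(s, a)
            if not done:
                s, a = s_next, a_next
    return w
```

The ε-greedy choice here plays the role of the policy-update step in the pseudocode above.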

Q-learning with function approximation

Q-learning algorithm

  • Similar to Sarsa, tabular Q-learning can also be extended to the case of value function approximation.

  • The q-value update rule is

    $w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t),$

  • which is the same as Sarsa except that $\hat q(s_{t+1}, a_{t+1}, w_t)$ is replaced by $\max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t)$.

Pseudocode: Q-learning with function approximation (on-policy version)

  • Initialization: Initial parameter vector $w_0$. Initial policy $\pi_0$. Small $\varepsilon > 0$.
  • Aim: Search for a good policy that can lead the agent to the target from an initial state-action pair $(s_0, a_0)$.
  • For each episode, do
    • If the current $s_t$ is not the target state, do
      • Take action $a_t$ following $\pi_t(s_t)$, and generate $r_{t+1}, s_{t+1}$

      • Value update (parameter update):

        $w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t)$

      • Policy update:

        $\pi_{t+1}(a|s_t) = 1 - \frac{\varepsilon}{|\mathcal{A}(s_t)|} (|\mathcal{A}(s_t)| - 1) \quad \text{if } a = \arg\max_{a \in \mathcal{A}(s_t)} \hat q(s_t, a, w_{t+1})$

        $\pi_{t+1}(a|s_t) = \frac{\varepsilon}{|\mathcal{A}(s_t)|} \quad \text{otherwise}$

Q-learning with function approximation

  • Update rule:

    $w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t).$

  • Interpretation (a minimal sketch of the update step follows this list):

    • As before, $\hat q(s, a, w)$ is a parameterized Q function.

    • The TD target uses the maximum over all actions available in the next state:

      $r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t)$

    • This means Q-learning is an off-policy algorithm:

      • the behavior policy can be exploratory, e.g., $\epsilon$-greedy;
      • but the update assumes the agent always takes the "optimal action", because of the $\max$.
    • The update is still driven by the TD error:

      $\delta_t = r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t)$

      • followed by the parameter update.
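Relative to the Sarsa sketch above, only the TD target changes. Below is a hedged sketch of a single update step, reusing the same hypothetical feature-map convention `phi(s, a)`:

```python
import numpy as np

def q_learning_update(w, phi, s, a, r, s_next, done, num_actions,
                      alpha=0.05, gamma=0.9):
    """One semi-gradient Q-learning step with q_hat(s, a, w) = phi(s, a)^T w."""
    q_sa = phi(s, a) @ w
    if done:
        target = r                       # terminal transition: no bootstrap
    else:
        # max over the actions available in the next state
        target = r + gamma * max(phi(s_next, b) @ w for b in range(num_actions))
    return w + alpha * (target - q_sa) * phi(s, a)
```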

Deep Q-learning

Objective function

  • Definition
    • Deep Q-learning aims to minimize the objective (loss) function:

      $J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right],$

      • where $(S, A, R, S')$ are random variables.
    • This is actually the Bellman optimality error.

    • That is because

      $q(s, a) = \mathbb{E}\Big[ R_{t+1} + \gamma \max_{a \in \mathcal{A}(S_{t+1})} q(S_{t+1}, a) \,\Big|\, S_t = s, A_t = a \Big], \quad \forall s, a$

      • The value of

        $R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)$

      • should be zero in the expectation sense.

  • How to minimize the objective function? Gradient descent!
    • In this objective function

      $J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right],$

      • the parameter $w$ appears not only in $\hat{q}(S, A, w)$ but also in

        $y \doteq R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w).$

    • For simplicity, we can assume that $w$ in $y$ is fixed (at least for a while) when we calculate the gradient.

The relation between the Deep Q-learning objective and the Bellman optimality error

  1. The Deep Q-learning objective

    $J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right]$

    • where:
      • $R$: the reward obtained after taking action $A$ in the current state;
      • $S'$: the next state;
      • $\hat{q}(S, A, w)$: the Q value approximated by a neural network with parameters $w$;
      • $\max_{a \in \mathcal{A}(S')}$: the Q value of the best action in the next state.
    • This objective is a mean squared error (MSE): it measures the gap between the network output $\hat{q}(S, A, w)$ and the target value $R + \gamma \max_{a} \hat{q}(S', a, w)$.
  2. Why does it correspond to the Bellman optimality error?

    • The Bellman optimality equation defines the recursion satisfied by the optimal Q values:

      $q(s, a) = \mathbb{E}\Big[ R_{t+1} + \gamma \max_{a' \in \mathcal{A}(S_{t+1})} q(S_{t+1}, a') \,\Big|\, S_t = s, A_t = a \Big]$

    • In other words, if $\hat{q}$ were optimal, then

      $R + \gamma \max_{a} \hat{q}(S', a, w) - \hat{q}(S, A, w) = 0$

      • would hold exactly in the sense of expectation.
    • In practice, however, $\hat{q}$ is only an approximation, so it cannot satisfy the Bellman equation exactly.

      • We therefore define this residual as the Bellman optimality error and approach the optimal Q values by minimizing it.
  3. Why is the optimization tricky?

    • In the loss function

      $J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right]$

      • the parameter $w$ appears both in the current estimate $\hat{q}(S, A, w)$ and in the target value $R + \gamma \max_{a} \hat{q}(S', a, w)$.
      • This makes the gradient computation complicated, because we would have to differentiate through both the prediction and the target.
    • To simplify the computation, DQN uses a fixed target network:

      • for a period of time, the parameters $w$ inside the target $R + \gamma \max_{a} \hat{q}(S', a, w)$ are held fixed;
      • only the parameters of the current Q network are updated.
    • This avoids the complication of propagating gradients through the target.

Two networks

  • Introduction
    • One is a main network representing $\hat q(s,a,w)$.

    • The other is a target network $\hat q(s,a,w_T)$.

    • The objective function in this case degenerates to

      $J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)^2\Big],$

      • where $w_T$ is the target network parameter.
  • Gradient with fixed target network
    • When $w_T$ is fixed, the gradient of $J$ can be easily obtained as

      $\nabla_w J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)\nabla_w \hat q(S,A,w)\Big].$

    • The basic idea of deep Q-learning is to use the gradient-descent algorithm to minimize the objective function.

In DQN, if a single network $\hat q(s,a,w)$ is used both to estimate the Q values and to update the parameters, training becomes unstable: the target value (TD target) and the prediction depend on the same network, so parameter updates interfere with each other.

Solution:

  • Introduce two networks:
    1. Main network: $\hat q(s,a,w)$, used for learning and parameter updates.
    2. Target network: $\hat q(s,a,w_T)$, used to produce relatively stable target values.
      • The parameters $w_T$ are periodically synchronized from $w$ (e.g., copied every $C$ steps).
  • The target values then no longer change with every update of $w$, which reduces training oscillation.

Objective function:

$J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)^2\Big]$

  • Current Q estimate: $\hat q(S,A,w)$
  • Target Q value (TD target): $R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T)$

Gradient-descent update:

$\nabla_w J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)\nabla_w \hat q(S,A,w)\Big]$

  • When $w_T$ is fixed, the gradient is straightforward and is not perturbed by simultaneous updates of the target.

Summary

  • The two networks (main & target) solve the problem of unstable target values.

Two techniques

  1. First technique: Two networks, a main network and a target network
    • Why is it used?
      • The mathematical reason was explained when we calculated the gradient.
    • Implementation details:
      • Let $w$ and $w_T$ denote the parameters of the main and target networks, respectively. They are set to be the same initially.
      • In every iteration, we draw a mini-batch of samples $\{(s,a,r,s')\}$ from the replay buffer (explained below).
      • The inputs of the networks include the state $s$ and the action $a$.
        • The target output is

          $y_T \doteq r + \gamma \max_{a\in \mathcal{A}(s')} \hat q(s',a,w_T).$

        • Then, we directly minimize the TD error, i.e., the loss function

          $(y_T - \hat q(s,a,w))^2$

          • over the mini-batch $\{(s,a,y_T)\}$.
  2. Another technique: Experience replay
    • Question: What is experience replay?

    • Answer:

      • After we have collected some experience samples, we do NOT use them in the order they were collected.
      • Instead, we store them in a set, called the replay buffer $\mathcal{B} \doteq \{(s,a,r,s')\}$.
      • Every time we train the neural network, we draw a mini-batch of random samples from the replay buffer.
      • The draw of samples, called experience replay, should follow a uniform distribution (why?).
    • Question: Why is experience replay necessary in deep Q-learning? Why must the replay follow a uniform distribution?

    • Answer: The answers lie in the objective function.

      $J = \mathbb{E}\left[ \left( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w) \right)^2 \right]$

      • $(S, A) \sim d$: $(S, A)$ is an index and is treated as a single random variable.
      • $R \sim p(R|S,A)$, $S' \sim p(S'|S,A)$: $R$ and $S'$ are determined by the system model.
      • The distribution of the state-action pair $(S, A)$ is assumed to be uniform.
      • However, the samples are not uniformly collected, because they are generated sequentially by certain policies.
      • To break the correlation between consecutive samples, we can use the experience replay technique and draw samples uniformly from the replay buffer.
      • This is the mathematical reason why experience replay is necessary and why it must be uniform.

Experience replay

Problem:

  • If interaction data are used in the order they are collected, consecutive samples are strongly correlated (e.g., $s_t, s_{t+1}, s_{t+2}$). This violates the random-sampling assumption and can make training unstable or even divergent.

Solution:

  • Introduce a replay buffer $\mathcal{B}$ that stores past experiences $(s,a,r,s')$.
  • During training, instead of using only the most recent data, draw a random mini-batch to break the correlation between samples.

Mathematical explanation:

  • The objective function is:

    $J = \mathbb{E}\left[ \left( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w) \right)^2 \right]$

    • $(S,A) \sim d$: $(S,A)$ is treated as a single random variable.
    • In theory, the distribution of $(S,A)$ should be uniform.
    • In practice, however, the collected data are generated by the current policy and are not uniformly distributed (they may concentrate in certain regions).

What experience replay does (a minimal replay-buffer sketch follows below):

  1. Breaks sample correlation (avoiding biased gradient updates).
  2. Approximates uniform sampling, making the empirical distribution of $(S,A)$ closer to the uniform distribution assumed in theory.
  3. Improves data efficiency (the same sample can be reused multiple times).

Summary

  • Experience replay addresses the problems of sample correlation and distribution mismatch.
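A minimal replay-buffer sketch matching this description: transitions are stored as they arrive and mini-batches are drawn uniformly at random. The capacity and the storage format are illustrative choices, not something prescribed by the lecture.

```python
import random
from collections import deque

class ReplayBuffer:
    """Stores (s, a, r, s_next, done) tuples and samples them uniformly."""

    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)     # oldest samples are evicted first

    def push(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)   # uniform, without replacement
        s, a, r, s_next, done = zip(*batch)
        return list(s), list(a), list(r), list(s_next), list(done)

    def __len__(self):
        return len(self.buffer)
```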

Revisit the tabular case:

  • Question: Why does tabular Q-learning not require experience replay?

    • Answer: There is no uniform-distribution requirement.
  • Question: Why does deep Q-learning involve a distribution?

    • Answer: The objective function in the deep case is a scalar average over all $(S, A)$.

      The tabular case does not involve any distribution of $S$ or $A$.

      The algorithm in the tabular case aims to solve a set of equations, one for each $(s,a)$ (the Bellman optimality equation).

  • Question: Can we use experience replay in tabular Q-learning?

    • Answer: Yes, we can. It is also more sample-efficient (why?).

Why tabular Q-learning does not need experience replay, while deep Q-learning does

  1. Characteristics of tabular Q-learning
    • Storage: every state-action pair $(s,a)$ has its own table entry $Q(s,a)$.

    • Update: the update is local and only affects the current $(s,a)$:

      $Q(s,a) \leftarrow Q(s,a) + \alpha \Big[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \Big]$

    • No distribution requirement:

      • In the tabular method we are essentially "solving equations" (the Bellman system). As long as every $(s,a)$ is visited, the method converges to the optimal $Q^*$ regardless of whether the sampling distribution is uniform.
      • Hence there is no need to make the sampling distribution uniform, and thus no need for experience replay to break sample correlation.
  2. Characteristics of deep Q-learning
    • Storage: the Q values are not stored in a table but approximated by a neural network:

      $Q(s,a;w) \approx Q^*(s,a)$

      • The parameters $w$ are shared, so a single update affects the estimates for all $(s,a)$, not just one table entry.
    • Objective function:

      • The deep Q-learning objective is a mean squared error (MSE):

        $J(w) = \mathbb{E}\Big[\big(r + \gamma \max_{a'} Q(s',a';w) - Q(s,a;w)\big)^2\Big]$

      • Note that the expectation is taken over a distribution of state-action pairs $(s,a)$.

    • Distribution issue:

      • If the training samples are highly correlated (e.g., drawn consecutively from the same episode), the network overfits the local trajectory and the gradient estimates are badly biased.
      • The objective implicitly assumes that $(S,A)$ is i.i.d., whereas samples in an RL environment are sequentially correlated.
  3. Why deep Q-learning needs experience replay
    • Experience replay does two things:
      1. Breaks correlation: sampling uniformly from the replay buffer removes the sequential correlation and approximately satisfies the i.i.d. assumption.
      2. Improves sample efficiency: the same sample can be used in multiple updates instead of being discarded after one use.
    • Mathematical explanation
      • Without experience replay, the expectation is computed with

        $(S,A) \sim d_\pi$

      • and this distribution $d_\pi$ depends strongly on the current policy and trajectory, making the gradient estimates unstable.

      • Sampling uniformly from the replay buffer approximates a near-uniform sampling distribution and stabilizes training.

  4. Comparison summary
    • Tabular Q-learning: updates are local; there is no uniformity requirement on the sampling distribution; covering all $(s,a)$ suffices for convergence.
    • Deep Q-learning: updates are global and depend on the expectation in the objective; experience replay is needed to keep the sample distribution approximately uniform and to avoid biased gradients.

Pseudocode: Deep Q-learning (off-policy version)

  • Aim: Learn an optimal target network to approximate the optimal action values from the experience samples generated by a behavior policy $\pi_b$.
  • Store the experience samples generated by $\pi_b$ in a replay buffer $\mathcal{B} = \{(s,a,r,s')\}$
    • For each iteration, do

      • Uniformly draw a mini-batch of samples from $\mathcal{B}$

      • For each sample $(s,a,r,s')$, calculate the target value as

        $y_T = r + \gamma \max_{a \in \mathcal{A}(s')} \hat q(s',a,w_T),$

        • where $w_T$ is the parameter of the target network
    • Update the main network to minimize

      $(y_T - \hat q(s,a,w))^2$

      • using the mini-batch $\{(s,a,y_T)\}$
    • Set $w_T = w$ every $C$ iterations
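Putting the pieces together, here is a compact PyTorch sketch of the pseudocode above. The environment interface (`env.reset()`, `env.step(a)` returning `(s_next, r, done)`), the network architecture, and all hyperparameters are assumptions for illustration; here the behavior policy is taken to be ε-greedy over the main network.

```python
import random
import numpy as np
import torch
import torch.nn as nn

def train_dqn(env, state_dim, num_actions, episodes=200, gamma=0.99,
              epsilon=0.1, batch_size=64, buffer_size=10_000, sync_every=100, lr=1e-3):
    """Off-policy deep Q-learning with a target network and uniform experience replay."""
    q_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, num_actions))
    target_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, num_actions))
    target_net.load_state_dict(q_net.state_dict())         # w_T = w initially
    optimizer = torch.optim.Adam(q_net.parameters(), lr=lr)
    buffer, step = [], 0

    for _ in range(episodes):
        s, done = env.reset(), False
        while not done:
            # epsilon-greedy behavior policy over the main network
            if random.random() < epsilon:
                a = random.randrange(num_actions)
            else:
                with torch.no_grad():
                    a = int(q_net(torch.tensor(s, dtype=torch.float32)).argmax())
            s_next, r, done = env.step(a)
            buffer.append((s, a, r, s_next, done))
            buffer = buffer[-buffer_size:]                  # drop the oldest samples
            s = s_next
            step += 1

            if len(buffer) >= batch_size:
                batch = random.sample(buffer, batch_size)   # uniform experience replay
                ss, aa, rr, ss2, dd = map(list, zip(*batch))
                ss = torch.tensor(np.array(ss), dtype=torch.float32)
                ss2 = torch.tensor(np.array(ss2), dtype=torch.float32)
                aa = torch.tensor(aa, dtype=torch.int64).unsqueeze(1)
                rr = torch.tensor(rr, dtype=torch.float32)
                dd = torch.tensor(dd, dtype=torch.float32)

                with torch.no_grad():                       # y_T uses the frozen target network
                    y = rr + gamma * (1 - dd) * target_net(ss2).max(dim=1).values
                q_sa = q_net(ss).gather(1, aa).squeeze(1)
                loss = nn.functional.mse_loss(q_sa, y)      # (y_T - q_hat(s, a, w))^2

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if step % sync_every == 0:                      # set w_T = w every C iterations
                target_net.load_state_dict(q_net.state_dict())
    return q_net
```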


Summary

From the tabular method (tabular Q-learning), to function approximation (Sarsa/Q-learning with function approximation), to deep reinforcement learning (DQN), the core idea is always to minimize a Bellman error. The differences are:

  • the tabular method solves the equations directly and does not depend on the sample distribution;
  • function approximation requires gradient descent and a projection;
  • deep Q-learning additionally stabilizes the neural network approximator with a target network and experience replay.
