RL【8】:Value Function Approximation
Series Contents
Fundamental Tools
RL【1】:Basic Concepts
RL【2】:Bellman Equation
RL【3】:Bellman Optimality Equation
Algorithm
RL【4】:Value Iteration and Policy Iteration
RL【5】:Monte Carlo Learning
RL【6】:Stochastic Approximation and Stochastic Gradient Descent
Method
RL【7-1】:Temporal-difference Learning
RL【7-2】:Temporal-difference Learning
Table of Contents
- Series Contents
- Fundamental Tools
- Algorithm
- Method
- Preface
- Algorithm for state value estimation
- Objective function
- Optimization algorithms
- Selection of function approximators
- Theoretical analysis
- Sarsa & Q-learning with function approximation
- Sarsa with function approximation
- Q-learning with function approximation
- Deep Q-learning
- Summary
Preface
This series records my study notes for Professor Shiyu Zhao's course "Mathematical Foundations of Reinforcement Learning" on Bilibili. For the course itself, please refer to:
Bilibili videos: "Mathematical Foundations of Reinforcement Learning: from scratch to a thorough understanding (completed)"
GitHub course materials: Book-Mathematical-Foundation-of-Reinforcement-Learning
Algorithm for state value estimation
Objective function
Formal introduction
- Let $v_\pi(s)$ and $\hat v(s,w)$ be the true state value and a function that approximates it.
- Our goal is to find an optimal $w$ so that $\hat v(s,w)$ can best approximate $v_\pi(s)$ for every $s$.
- This is a policy evaluation problem. Later it will be extended to policy improvement.
- To find the optimal $w$, we need two steps:
  - The first step is to define an objective function.
  - The second step is to derive algorithms that optimize the objective function.
Problem background: value estimation with function approximation
- In practical reinforcement learning, the state space can be very large (or even continuous), so it is infeasible to store a separate $v_\pi(s)$ for every state.
- We therefore use a function approximator (e.g., a linear function or a neural network) $\hat v(s,w)$ to approximate the true $v_\pi(s)$.
- Goal: find a parameter vector $w$ such that $\hat v(s,w)$ is as close as possible to $v_\pi(s)$.
Objective function
$$J(w)=\mathbb{E}\left[(v_\pi(S)-\hat v(S,w))^2\right].$$
- Our goal is to find the best $w$ that minimizes $J(w)$.
- The expectation is with respect to the random variable $S\in\mathcal{S}$.

Several ways to define the probability distribution of $S$:
- The first way is to use a uniform distribution.
  - That is, treat all states as equally important by setting the probability of each state to $1/|\mathcal{S}|$.
  - In this case, the objective function becomes
    $$J(w) = \mathbb{E}[(v_\pi(S) - \hat v(S, w))^2] = \frac{1}{|\mathcal{S}|} \sum_{s \in \mathcal{S}} (v_\pi(s) - \hat v(s, w))^2.$$
  - Drawback: the states may not be equally important. For example, some states may be rarely visited by the policy, so this choice does not reflect the real dynamics of the Markov process under the given policy.
- The second way is to use the stationary distribution.
  - The stationary distribution is an important concept that will be used frequently in this course. In short, it describes the long-run behavior of a Markov process.
  - Let $\{d_\pi(s)\}_{s \in \mathcal{S}}$ denote the stationary distribution of the Markov process under policy $\pi$. By definition, $d_\pi(s) \ge 0$ and $\sum_{s \in \mathcal{S}} d_\pi(s) = 1$.
  - The objective function can be rewritten as
    $$J(w) = \mathbb{E}[(v_\pi(S) - \hat v(S, w))^2] = \sum_{s \in \mathcal{S}} d_\pi(s)(v_\pi(s) - \hat v(s, w))^2.$$
  - This function is a weighted squared error.
  - Since more frequently visited states have higher values of $d_\pi(s)$, their weights in the objective function are also higher than those of rarely visited states.
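The difference between the two weightings can be made concrete with a small numerical sketch. The Python snippet below (all numbers are made-up toy values, not taken from the course) evaluates the same approximation error under a uniform distribution and under a hypothetical stationary distribution.

```python
import numpy as np

# Toy example: 4 states with made-up true values and approximations.
v_pi = np.array([1.0, 2.0, 3.0, 4.0])     # true state values v_pi(s) (assumed)
v_hat = np.array([1.5, 2.0, 2.5, 3.0])    # approximated values v_hat(s, w)

squared_err = (v_pi - v_hat) ** 2

# Uniform weighting: every state counts equally.
J_uniform = squared_err.mean()

# Stationary weighting: frequently visited states count more (assumed d_pi).
d_pi = np.array([0.1, 0.2, 0.3, 0.4])
J_stationary = np.dot(d_pi, squared_err)

print(J_uniform)     # 0.375
print(J_stationary)  # 0.5, because the largest error sits on the most-visited state
```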
Choosing the state distribution: two options
Key question: **with respect to which distribution of the state $S$ is the expectation $\mathbb{E}$ taken?** This determines which states the trained approximator will be more accurate on.
- Uniform distribution
  Approach: assume all states are equally important and assign each the same probability:
  $$P(S=s) = \frac{1}{|\mathcal{S}|}.$$
  The objective function becomes:
  $$J(w) = \frac{1}{|\mathcal{S}|}\sum_{s\in\mathcal{S}} \big(v_\pi(s) - \hat v(s,w)\big)^2.$$
  Pros:
  - Simple and intuitive; every state is treated equally.
  Cons:
  - Unrealistic: some states occur rarely in practice (e.g., rare scenes in a game), and forcing the approximator to fit them well wastes model capacity.
  - It does not reflect the true dynamics of the Markov process under policy $\pi$.
- Stationary distribution
  Approach: use the distribution $d_\pi(s)$ of states after the environment has been running under policy $\pi$ for a long time (i.e., the stationary distribution).
  The objective function becomes:
  $$J(w) = \sum_{s \in \mathcal{S}} d_\pi(s)\,\big(v_\pi(s) - \hat v(s,w)\big)^2.$$
  Pros:
  - Closer to reality: an agent running the policy visits some states frequently and others almost never.
  - The estimates are more accurate on those high-frequency states, which improves actual performance.
  Cons:
  - Low-frequency states may be fitted poorly, which matters when some rarely visited states are nevertheless important.

Stationary distribution:
- Distribution: a distribution over states.
- Stationary: long-run behavior.
- Summary: after the agent runs for a long time following a policy, the probability that the agent is at any given state is described by this distribution.
Basic concepts related to the stationary distribution (steady-state distribution)
1. Definition of the stationary distribution
In a Markov process or Markov decision process (MDP), the agent keeps transitioning between states over time.
The stationary distribution is the probability distribution over states after the agent has been running for a sufficiently long time.
Mathematically, if $\{d_\pi(s)\}_{s \in \mathcal{S}}$ is the stationary distribution under policy $\pi$, then
$$d_\pi(s') = \sum_{s \in \mathcal{S}} d_\pi(s) \sum_{a \in \mathcal{A}} \pi(a|s) P(s'|s,a),$$
and
$$\sum_{s \in \mathcal{S}} d_\pi(s) = 1, \quad d_\pi(s) \ge 0.$$
- Related concepts
  - Distribution
    - Literally, the probability distribution of some variable.
    - Here it is the state distribution: the probability that the agent is in each state $s \in \mathcal{S}$.
  - Stationary
    - Refers to long-run stability.
    - As $t \to \infty$, the state distribution converges to a fixed value and no longer fluctuates over time.
    - In other words, the state distribution has converged to an equilibrium.
  - Steady-state distribution / limiting distribution
    - Synonyms: the stationary distribution is also called the steady-state distribution or the limiting distribution.
    - Both names emphasize that it is a stable distribution in the long-run limit.
- Role in reinforcement learning
  - Value function approximation
    In value function approximation, the objective function is defined as an expectation:
    $$J(w) = \mathbb{E}_{S \sim d_\pi}\big[(v_\pi(S) - \hat v(S,w))^2\big],$$
    - where $d_\pi$ is the stationary distribution.
    This assigns larger weights to frequently visited states, which better matches how the policy actually behaves.
  - Policy gradient
    In policy gradient methods, the performance objective is usually written as
    $$J(\pi) = \sum_s d_\pi(s) \sum_a \pi(a|s) q_\pi(s,a).$$
    Here $d_\pi(s)$ is the long-run probability of being in state $s$ under policy $\pi$.
    The stationary distribution is therefore a core component of policy gradient methods.
- Intuition
  - After the agent has followed policy $\pi$ for a long time:
    - frequently visited states have higher probability under $d_\pi(s)$;
    - rarely visited states have probability close to $0$.
  - So $d_\pi(s)$ reflects which states actually matter under that policy.
- Summary:
  - The stationary distribution describes the long-run probability of visiting each state under policy $\pi$.
  - It is also called the steady-state distribution or the limiting distribution.
  - It plays a key role in value function approximation and policy gradient methods because it determines the weight of each state in the optimization.
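As a quick illustration, the stationary distribution of a given policy can be approximated numerically by iterating the state-transition equation above until it stops changing. The sketch below uses a small made-up transition matrix $P_\pi$ (policy and environment already combined); it is only an assumed example for demonstration, not course material.

```python
import numpy as np

def stationary_distribution(P_pi: np.ndarray, tol: float = 1e-10) -> np.ndarray:
    """Power iteration d^T <- d^T P_pi until the state distribution converges."""
    n = P_pi.shape[0]
    d = np.full(n, 1.0 / n)          # start from the uniform distribution
    while True:
        d_next = d @ P_pi            # one step of the Markov chain
        if np.max(np.abs(d_next - d)) < tol:
            return d_next
        d = d_next

# Made-up 3-state transition matrix under some policy pi (rows sum to 1).
P_pi = np.array([
    [0.5, 0.4, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.3, 0.6],
])

d_pi = stationary_distribution(P_pi)
print(d_pi, d_pi.sum())  # d_pi satisfies d_pi = d_pi @ P_pi and sums to 1
```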
Optimization algorithms
Gradient Descent for Value Function Approximation
- While we have the objective function, the next step is to optimize it.
- To minimize the objective function $J(w)$, we can use the gradient-descent algorithm:
  $$w_{k+1} = w_k - \alpha_k \nabla_w J(w_k)$$
- The true gradient is:
  $$\begin{aligned} \nabla_w J(w) &= \nabla_w \mathbb{E}[(v_\pi(S) - \hat v(S,w))^2] \\ &= \mathbb{E}[\nabla_w (v_\pi(S) - \hat v(S,w))^2] \\ &= 2\,\mathbb{E}[(v_\pi(S) - \hat v(S,w))(-\nabla_w \hat v(S,w))] \\ &= -2\,\mathbb{E}[(v_\pi(S) - \hat v(S,w))\nabla_w \hat v(S,w)] \end{aligned}$$
- The true gradient above involves the calculation of an expectation.

Stochastic Gradient
- We can use the stochastic gradient to replace the true gradient:
  $$w_{t+1} = w_t + \alpha_t (v_\pi(s_t) - \hat v(s_t, w_t)) \nabla_w \hat v(s_t, w_t),$$
  - where $s_t$ is a sample of $S$. Here, the factor $2$ is absorbed into $\alpha_t$.
- This algorithm is not implementable because it requires the true state value $v_\pi$, which is exactly the unknown to be estimated.
- We can replace $v_\pi(s_t)$ with an approximation so that the algorithm becomes implementable.

Monte Carlo and TD Learning with Function Approximation
- First, Monte Carlo learning with function approximation:
  Let $g_t$ be the discounted return starting from $s_t$ in the episode. Then $g_t$ can be used to approximate $v_\pi(s_t)$. The algorithm becomes:
  $$w_{t+1} = w_t + \alpha_t (g_t - \hat v(s_t, w_t)) \nabla_w \hat v(s_t, w_t).$$
- Second, TD learning with function approximation:
  In the spirit of TD learning, $r_{t+1} + \gamma \hat v(s_{t+1}, w_t)$ can be viewed as an approximation of $v_\pi(s_t)$. The algorithm becomes:
  $$w_{t+1} = w_t + \alpha_t [r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t)] \nabla_w \hat v(s_t, w_t).$$
In-depth explanation
Why gradient descent?
We have an objective function:
$$J(w)=\mathbb{E}[(v_\pi(S)-\hat v(S,w))^2].$$
The goal is to minimize the error between the true state value $v_\pi(S)$ and the approximation $\hat v(S,w)$.
This is essentially a regression problem: fit a function $\hat v$ to approximate the true $v_\pi$.
We can therefore use the most common optimization method, gradient descent:
$$w_{k+1} = w_k - \alpha_k \nabla_w J(w_k).$$
- That is, each step updates the parameter $w$ so that $J(w)$ gradually decreases.
The meaning of the true gradient
Expanding with the chain rule:
$$\nabla_w J(w) = -2\,\mathbb{E}[(v_\pi(S) - \hat v(S,w)) \nabla_w \hat v(S,w)].$$
Interpretation:
- The error term $(v_\pi(S) - \hat v(S,w))$ measures the gap between the prediction and the true value.
- Multiplying it by $\nabla_w \hat v(S,w)$ tells us how to adjust $w$ to shrink that gap.
- The sign means: if the prediction is smaller than the true value, increase $\hat v$; otherwise decrease it.
This is exactly the same as standard supervised regression.
Why stochastic gradient?
The problem is:
- the true gradient $\nabla_w J(w)$ involves an expectation over all states $S$;
- this is usually infeasible because the state space is huge and $v_\pi(S)$ is unknown.
So we use SGD (stochastic gradient descent) instead:
$$w_{t+1} = w_t + \alpha_t (v_\pi(s_t) - \hat v(s_t,w_t)) \nabla_w \hat v(s_t,w_t),$$
- where $s_t$ is a sampled state.
- Benefit: a single sample is enough for one update, so the cost is low.
- Problem: it still requires $v_\pi(s_t)$, which is exactly the unknown quantity we want to estimate.
How can $v_\pi(s_t)$ be replaced?
Since $v_\pi(s_t)$ cannot be obtained directly, we need a quantity that approximates it:
- Monte Carlo learning with function approximation
  Use the full episode return $g_t$ as an unbiased estimate of $v_\pi(s_t)$.
  Update rule:
  $$w_{t+1} = w_t + \alpha_t (g_t - \hat v(s_t, w_t)) \nabla_w \hat v(s_t, w_t).$$
  Intuition:
  - $g_t$ is the cumulative reward obtained by starting from $s_t$ and following the episode to the end.
  - Replace $v_\pi(s_t)$ with $g_t$ and then take a gradient step.
  - Drawback: updates can only happen after the whole episode ends, and the variance is high.
- TD learning with function approximation
  Use the TD target $r_{t+1} + \gamma \hat v(s_{t+1}, w_t)$ to approximate $v_\pi(s_t)$.
  Update rule:
  $$w_{t+1} = w_t + \alpha_t \big[r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t)\big] \nabla_w \hat v(s_t, w_t).$$
  Intuition:
  - No need to wait for the whole trajectory: only one reward $r_{t+1}$ plus the prediction at the next state is needed.
  - This is bootstrapping: existing estimates are used to assist the update.
  - Pros: online learning, fast updates, low variance; con: it may introduce bias.
Pseudocode: TD learning with function approximation
- Initialization: a function $\hat v(s,w)$ that is differentiable in $w$; an initial parameter vector $w_0$.
- Aim: approximate the true state values of a given policy $\pi$.
- For each episode generated following the policy $\pi$, do
  - For each step $(s_t, r_{t+1}, s_{t+1})$, do
    - In the general case,
      $$w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big] \nabla_w \hat v(s_t, w_t)$$
    - In the linear case,
      $$w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \phi^T(s_{t+1}) w_t - \phi^T(s_t) w_t \big] \phi(s_t)$$
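A minimal sketch of this pseudocode for the linear case, assuming the transitions $(s_t, r_{t+1}, s_{t+1})$ are already available as lists and that a feature function `phi` is given (both are assumptions for illustration, not part of the original pseudocode):

```python
import numpy as np

def td_linear(episodes, phi, num_features, alpha=0.05, gamma=0.9):
    """TD(0) with linear function approximation: v_hat(s, w) = phi(s)^T w."""
    w = np.zeros(num_features)
    for episode in episodes:                      # each episode follows the policy pi
        for s, r, s_next in episode:              # each step (s_t, r_{t+1}, s_{t+1})
            td_error = r + gamma * phi(s_next) @ w - phi(s) @ w
            w = w + alpha * td_error * phi(s)     # linear case: grad of v_hat is phi(s)
    return w

# Toy usage with one-hot features over 3 states (made-up data).
phi = lambda s: np.eye(3)[s]
episodes = [[(0, 1.0, 1), (1, 0.0, 2), (2, 1.0, 0)]] * 100
w = td_linear(episodes, phi, num_features=3)
print(w)  # with one-hot features, w(s) behaves exactly like a tabular value estimate
```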
Selection of function approximators
- Function selection
  - The first approach, widely used in the past, is to use a linear function
    $$\hat v(s, w) = \phi^T(s) w$$
    Here, $\phi(s)$ is the feature vector, which can be a polynomial basis, a Fourier basis, etc.
  - The second approach, widely used nowadays, is to use a neural network as a nonlinear function approximator. The input of the NN is the state, the output is $\hat v(s,w)$, and the network parameter is $w$.
- TD-Linear
  - In the linear case where $\hat v(s, w) = \phi^T(s) w$, we have
    $$\nabla_w \hat v(s, w) = \phi(s).$$
  - Substituting this gradient into the TD algorithm
    $$w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big] \nabla_w \hat v(s_t, w_t)$$
    yields
    $$w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \phi^T(s_{t+1}) w_t - \phi^T(s_t) w_t \big] \phi(s_t),$$
    which is the algorithm of TD learning with linear function approximation (TD-Linear).
- Disadvantages and advantages of linear function approximation
  - Disadvantages:
    - It is difficult to select appropriate feature vectors.
  - Advantages:
    - The theoretical properties of the TD algorithm in the linear case are much better understood than in the nonlinear case.
    - Linear function approximation is still powerful in the sense that the tabular representation is merely a special case of it.
- Tabular representation as a special case of linear function approximation
  We next show that the tabular representation is a special case of linear function approximation.
  - First, consider the special feature vector for state $s$:
    $$\phi(s) = e_s \in \mathbb{R}^{|\mathcal{S}|},$$
    where $e_s$ is a vector whose $s$-th entry is $1$ and whose other entries are $0$.
  - In this case,
    $$\hat v(s, w) = e_s^T w = w(s),$$
    where $w(s)$ is the $s$-th entry of $w$.
- Connection with tabular TD
  - Recall that the TD-Linear algorithm is
    $$w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \phi^T(s_{t+1}) w_t - \phi^T(s_t) w_t \big] \phi(s_t).$$
  - When $\phi(s_t) = e_{s_t}$, the above algorithm becomes
    $$w_{t+1} = w_t + \alpha_t \big( r_{t+1} + \gamma w_t(s_{t+1}) - w_t(s_t) \big) e_{s_t}.$$
    This is a vector equation that only updates the $s_t$-th entry of $w_t$.
  - Multiplying both sides by $e_{s_t}^T$ gives
    $$w_{t+1}(s_t) = w_t(s_t) + \alpha_t \big( r_{t+1} + \gamma w_t(s_{t+1}) - w_t(s_t) \big),$$
    which is exactly the tabular TD algorithm.
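To make the equivalence concrete, the short sketch below (toy numbers, assumed purely for illustration) performs one TD update both in tabular form and in TD-Linear form with one-hot features, and checks that the two coincide.

```python
import numpy as np

alpha, gamma = 0.1, 0.9
num_states = 4
s, r, s_next = 1, 1.0, 2          # one made-up transition (s_t, r_{t+1}, s_{t+1})

# Tabular TD(0) update on a value table.
v = np.array([0.5, 0.2, 0.8, 0.0])
v_tab = v.copy()
v_tab[s] += alpha * (r + gamma * v_tab[s_next] - v_tab[s])

# TD-Linear update with one-hot features phi(s) = e_s, so v_hat(s, w) = w(s).
w = v.copy()
phi = lambda state: np.eye(num_states)[state]
td_error = r + gamma * phi(s_next) @ w - phi(s) @ w
w = w + alpha * td_error * phi(s)

print(np.allclose(v_tab, w))      # True: tabular TD is TD-Linear with one-hot features
```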
Method selection
- Linear function or neural network?
  Linear approximation: $\hat v(s,w)=\phi(s)^\top w$
  A hand-designed feature map $\phi(s)$ (polynomial, Fourier, tile coding, one-hot, ...) projects the state into a low-dimensional space, and then a weight vector $w$ is learned.
  - When to prefer it:
    - the state space is small or can be represented well by features;
    - interpretability and convergence guarantees matter (especially in the on-policy case);
    - compute or data is limited and a robust, low-variance learner is needed.
  Nonlinear approximation (NN): learn $\hat v(s,w)=\text{NN}(s;w)$ directly.
  - When to prefer it:
    - the raw state is high-dimensional or highly nonlinear (images, text, complex sensors);
    - the goal is end-to-end deep RL;
    - there is enough data and compute, and the tuning cost of less stable training is acceptable.
- The essence of TD-Linear: semi-gradient + projected fixed point
  In the linear case $\nabla_w \hat v(s,w)=\phi(s)$. Substituting into the TD(0) update:
  $$w_{t+1}=w_t+\alpha_t\Big[r_{t+1}+\gamma\,\hat v(s_{t+1},w_t)-\hat v(s_t,w_t)\Big]\,\phi(s_t).$$
  This is a semi-gradient method: the TD target $r_{t+1}+\gamma \hat v(s_{t+1},w_t)$ is treated as a constant when differentiating with respect to $w$ (no gradient flows through the next-state value $\hat v$); differentiating through it would give a "full-gradient" algorithm, which in practice tends to be less stable.
  From a geometric viewpoint, it solves the projected Bellman equation:
  $$\Phi w \approx \Pi_{d_\pi}\, \mathcal{T}_\pi(\Phi w),$$
  - where the column space of $\Phi$ is the function class spanned by the features, and $\Pi_{d_\pi}$ is the least-squares projection weighted by the stationary distribution $d_\pi$.
  - Meaning: the environment pushes you toward $\mathcal{T}_\pi v$, but you can only stay inside the representable subspace, so the update is projected back onto it.
  Convergence (on-policy, linear, suitable step sizes): $\mathcal{T}_\pi$ is a contraction mapping, and under fairly general conditions semi-gradient TD(0) converges to the projected fixed point above.
- Why is the tabular case a special case of linear approximation?
  Take one-hot features $\phi(s)=e_s$. Then
  $$\hat v(s,w)=e_s^\top w = w(s),$$
  - i.e., one parameter per state.
  Substituting back into TD-Linear:
  $$w_{t+1}=w_t+\alpha_t\big(r_{t+1}+\gamma w_t(s_{t+1})-w_t(s_t)\big)e_{s_t}.$$
  Only the $s_t$-th component is updated; multiplying on the left by $e_{s_t}^\top$ gives
  $$w_{t+1}(s_t)=w_t(s_t)+\alpha_t\big(r_{t+1}+\gamma w_t(s_{t+1})-w_t(s_t)\big),$$
  which is exactly tabular TD(0).
  Conclusion: tabular = linear approximation + one-hot features.
- Relation to the objective function $J(w)$
  In the linear on-policy case, semi-gradient TD does not directly minimize
  $$J(w)=\mathbb{E}_{S\sim d_\pi}\big[(v_\pi(S)-\hat v(S,w))^2\big],$$
  - but instead approaches the projected Bellman solution. In general the two are not the same, yet on many problems this solution is both computable and effective.
  If you really want to minimize $J(w)$ itself, you need access to $v_\pi$ or an MC-return approximation of it, which brings you back to the MC-with-function-approximation update
  $$w_{t+1}=w_t+\alpha_t\,(g_t-\hat v(s_t,w_t))\,\phi(s_t),$$
  - which has higher variance but targets $J(w)$ directly.
Theoretical analysis
- The algorithm
  $$w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big] \nabla_w \hat v(s_t, w_t)$$
  does not minimize the following objective function:
  $$J(w) = \mathbb{E}\big[ ( v_\pi(S) - \hat v(S, w) )^2 \big]$$
- Different objective functions
  - Objective function 1: true value error
    $$J_E(w) = \mathbb{E}\big[ ( v_\pi(S) - \hat v(S, w) )^2 \big] = \| \hat v(w) - v_\pi \|^2_D$$
  - Objective function 2: Bellman error
    $$J_{BE}(w) = \| \hat v(w) - (r_\pi + \gamma P_\pi \hat v(w)) \|^2_D \doteq \| \hat v(w) - T_\pi(\hat v(w)) \|^2_D$$
    where
    $$T_\pi(x) \doteq r_\pi + \gamma P_\pi x$$
  - Objective function 3: projected Bellman error
    $$J_{PBE}(w) = \| \hat v(w) - M T_\pi(\hat v(w)) \|^2_D$$
    where $M$ is a projection matrix.
The gaps between the objectives
True value error
$$J_E(w) = \mathbb{E}\big[(v_\pi(S) - \hat v(S, w))^2\big] = \| \hat v(w) - v_\pi \|^2_D$$
- Meaning: directly minimize the gap between the approximation $\hat v(s,w)$ and the true value function $v_\pi(s)$.
- Ideal objective: the most natural and intuitive target (analogous to supervised learning).
- Problem: $v_\pi(s)$ is unknown; it can only be approximated indirectly through samples and the Bellman equation, so this objective cannot be minimized directly.
Bellman error
$$J_{BE}(w) = \| \hat v(w) - T_\pi(\hat v(w)) \|^2_D$$
where
$$T_\pi(x) = r_\pi + \gamma P_\pi x$$
Meaning: measures how inconsistent $\hat v(s,w)$ is with the Bellman equation.
- The fixed point of the Bellman equation is $v_\pi$.
- If $\hat v(w)$ lies in a "perfect" function space, minimizing the Bellman error yields the true value function.
- Problem: when the approximator (a linear function or a neural network) cannot represent $v_\pi$ exactly, directly minimizing the Bellman error can lead to poor or unstable solutions.
Projected Bellman error
$$J_{PBE}(w) = \| \hat v(w) - M T_\pi(\hat v(w)) \|^2_D$$
- Here $M$ is a projection matrix that maps the Bellman update $T_\pi(\hat v)$ back into the approximation space.
- Meaning: since the approximation space (e.g., a linear function space) is limited, $\hat v(w)$ generally cannot satisfy the Bellman equation exactly. We therefore require the projected Bellman update to be as close as possible to $\hat v(w)$.
- Essence: find the approximate solution within the function space that is closest to the Bellman fixed point.
- Importance: this is what TD-Linear actually optimizes. The TD update implicitly performs this projection, so the convergence point minimizes the projected Bellman error rather than the true value error.
Why does TD-Linear correspond to the projected Bellman error?
The TD update
$$w_{t+1} = w_t + \alpha_t \big[ r_{t+1} + \gamma \hat v(s_{t+1}, w_t) - \hat v(s_t, w_t) \big]\nabla_w \hat v(s_t, w_t)$$
uses the TD target $r_{t+1} + \gamma \hat v(s_{t+1}, w_t)$, which effectively projects the Bellman update $T_\pi(\hat v)$ back into the approximation space.
Therefore TD does not directly minimize $J_E(w)$ or $J_{BE}(w)$; it minimizes the projected Bellman error.
Summary:
- True value error: the ideal objective, but not directly computable.
- Bellman error: measures inconsistency with the Bellman equation, but can be unstable under function approximation.
- Projected Bellman error: the objective TD actually minimizes; it gives the most reasonable solution within the approximation space and comes with convergence guarantees.
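In the linear case, the projected Bellman fixed point can also be computed in closed form by solving $\Phi^\top D(\Phi - \gamma P_\pi \Phi)\,w = \Phi^\top D r_\pi$, a standard consequence of the projected Bellman equation above. The sketch below uses made-up toy data purely to illustrate that computation and to compare the result against the true value function.

```python
import numpy as np

# Toy 3-state Markov reward process under some policy pi (all numbers made up).
P_pi = np.array([[0.5, 0.5, 0.0],
                 [0.0, 0.5, 0.5],
                 [0.5, 0.0, 0.5]])
r_pi = np.array([0.0, 1.0, 2.0])
gamma = 0.9

# Stationary distribution d_pi (left eigenvector for eigenvalue 1) and D = diag(d_pi).
eigvals, eigvecs = np.linalg.eig(P_pi.T)
d_pi = np.real(eigvecs[:, np.argmax(np.real(eigvals))])
d_pi = d_pi / d_pi.sum()
D = np.diag(d_pi)

# Two hand-crafted features per state: the representable subspace is 2-dimensional.
Phi = np.array([[1.0, 0.0],
                [1.0, 1.0],
                [1.0, 2.0]])

# TD fixed point: Phi^T D (I - gamma * P_pi) Phi w = Phi^T D r_pi.
A = Phi.T @ D @ (np.eye(3) - gamma * P_pi) @ Phi
b = Phi.T @ D @ r_pi
w_td = np.linalg.solve(A, b)

# Compare with the true value function v_pi = (I - gamma * P_pi)^{-1} r_pi.
v_pi = np.linalg.solve(np.eye(3) - gamma * P_pi, r_pi)
print("TD fixed-point values:", Phi @ w_td)
print("True values          :", v_pi)
```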
Sarsa & Q-learning with function approximation
Sarsa with function approximation
Sarsa algorithm
- So far, we have only considered the problem of state value estimation; that is, we hope
  $$\hat v \approx v_\pi$$
- To search for optimal policies, we need to estimate action values.
- The Sarsa algorithm with value function approximation is
  $$w_{t+1} = w_t + \alpha_t \Big[r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t)\Big] \nabla_w \hat q(s_t, a_t, w_t).$$
- This is the same as the algorithm introduced earlier in this lecture except that $\hat v$ is replaced by $\hat q$.

Pseudocode: Sarsa with function approximation
- Aim: search for a policy that can lead the agent to the target from an initial state-action pair $(s_0, a_0)$.
- For each episode, do
  - If the current $s_t$ is not the target state, do
    - Take action $a_t$ following $\pi_t(s_t)$, generate $r_{t+1}, s_{t+1}$, and then take action $a_{t+1}$ following $\pi_t(s_{t+1})$
    - Value update (parameter update):
      $$w_{t+1} = w_t + \alpha_t \Big[r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t)\Big]\nabla_w \hat q(s_t, a_t, w_t)$$
    - Policy update:
      $$\pi_{t+1}(a|s_t) = \begin{cases} 1 - \dfrac{\varepsilon}{|\mathcal{A}(s_t)|}\,(|\mathcal{A}(s_t)| - 1) & \text{if } a = \arg\max_{a' \in \mathcal{A}(s_t)} \hat q(s_t, a', w_{t+1}) \\ \dfrac{\varepsilon}{|\mathcal{A}(s_t)|} & \text{otherwise} \end{cases}$$
Sarsa with function approximation
Formula:
$$w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t).$$
Meaning:
Here $\hat q(s, a, w)$ is the approximated action-value function, parameterized by $w$ (e.g., a linear function or a neural network).
The TD target uses the action $a_{t+1}$ that is actually executed at the next step:
$$r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t)$$
This makes Sarsa an on-policy algorithm:
- the behavior policy $\pi$ is used both to generate the data (choosing $a_t$ and $a_{t+1}$)
- and to update the value function.
The update direction is determined by the TD error:
$$\delta_t = r_{t+1} + \gamma \hat q(s_{t+1}, a_{t+1}, w_t) - \hat q(s_t, a_t, w_t)$$
- and the parameters of the current estimate $\hat q(s_t, a_t, w_t)$ are then corrected along the gradient.
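A compact sketch of the pseudocode above with a linear q-function $\hat q(s,a,w)=\phi(s,a)^\top w$ and an ε-greedy policy. The environment interface `env.reset()` / `env.step(a)` and the feature function `phi` are assumptions for illustration, not part of the course material.

```python
import numpy as np

def epsilon_greedy(q_values, eps):
    """Exploratory policy: random action with prob. eps, greedy action otherwise."""
    if np.random.rand() < eps:
        return np.random.randint(len(q_values))
    return int(np.argmax(q_values))

def sarsa_linear(env, phi, num_features, num_actions,
                 episodes=500, alpha=0.1, gamma=0.9, eps=0.1):
    w = np.zeros(num_features)
    q = lambda s, a: phi(s, a) @ w                       # q_hat(s, a, w)
    for _ in range(episodes):
        s = env.reset()
        a = epsilon_greedy([q(s, b) for b in range(num_actions)], eps)
        done = False
        while not done:
            s_next, r, done = env.step(a)                # assumed env interface
            a_next = epsilon_greedy([q(s_next, b) for b in range(num_actions)], eps)
            target = r if done else r + gamma * q(s_next, a_next)   # on-policy target
            td_error = target - q(s, a)
            w = w + alpha * td_error * phi(s, a)         # grad of linear q_hat is phi(s, a)
            s, a = s_next, a_next
    return w
```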
Q-learning with function approximation
Q-learning algorithm
- Similar to Sarsa, tabular Q-learning can also be extended to the case of value function approximation.
- The q-value update rule is
  $$w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t),$$
  which is the same as Sarsa except that $\hat q(s_{t+1}, a_{t+1}, w_t)$ is replaced by $\max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t)$.

Pseudocode: Q-learning with function approximation (on-policy version)
- Initialization: initial parameter vector $w_0$; initial policy $\pi_0$; small $\varepsilon > 0$.
- Aim: search for a good policy that can lead the agent to the target from an initial state-action pair $(s_0, a_0)$.
- For each episode, do
  - If the current $s_t$ is not the target state, do
    - Take action $a_t$ following $\pi_t(s_t)$, and generate $r_{t+1}, s_{t+1}$
    - Value update (parameter update):
      $$w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t)$$
    - Policy update:
      $$\pi_{t+1}(a|s_t) = \begin{cases} 1 - \dfrac{\varepsilon}{|\mathcal{A}(s_t)|}\,(|\mathcal{A}(s_t)| - 1) & \text{if } a = \arg\max_{a' \in \mathcal{A}(s_t)} \hat q(s_t, a', w_{t+1}) \\ \dfrac{\varepsilon}{|\mathcal{A}(s_t)|} & \text{otherwise} \end{cases}$$
Q-learning with function approximation
Formula:
$$w_{t+1} = w_t + \alpha_t \Big[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t) \Big] \nabla_w \hat q(s_t, a_t, w_t).$$
Meaning:
As before, $\hat q(s, a, w)$ is the parameterized Q-function.
The TD target takes the maximum over all possible actions at the next state:
$$r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t)$$
This makes Q-learning an off-policy algorithm:
- the behavior policy can be exploratory, e.g., $\epsilon$-greedy;
- but the update assumes the agent always takes the "best action", because of the $\max$.
The update is still driven by the TD error:
$$\delta_t = r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat q(s_{t+1}, a, w_t) - \hat q(s_t, a_t, w_t)$$
- followed by a gradient step on the parameters.
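Relative to the Sarsa sketch above, the only change is the TD target: a max over next-state actions instead of the value of the action actually taken. A minimal, assumed-interface version of that single update step:

```python
import numpy as np

def q_learning_linear_update(w, phi, s, a, r, s_next, num_actions,
                             alpha=0.1, gamma=0.9, done=False):
    """One Q-learning step with a linear q_hat(s, a, w) = phi(s, a)^T w."""
    q = lambda state, action: phi(state, action) @ w
    # Off-policy target: max over all actions at the next state.
    target = r if done else r + gamma * max(q(s_next, b) for b in range(num_actions))
    td_error = target - q(s, a)
    return w + alpha * td_error * phi(s, a)
```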
Deep Q-learning
Objective function
- Definition
  - Deep Q-learning aims to minimize the objective (loss) function
    $$J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right],$$
    where $(S, A, R, S')$ are random variables.
  - This is actually the Bellman optimality error.
  - That is because
    $$q(s, a) = \mathbb{E}\Big[ R_{t+1} + \gamma \max_{a \in \mathcal{A}(S_{t+1})} q(S_{t+1}, a) \,\Big|\, S_t = s, A_t = a \Big], \quad \forall s, a$$
  - Hence the value of
    $$R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)$$
    should be zero in the expectation sense.
- How to minimize the objective function? Gradient descent!
  - In this objective function
    $$J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right],$$
    the parameter $w$ appears not only in $\hat{q}(S, A, w)$ but also in
    $$y \doteq R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w).$$
  - For the sake of simplicity, we can assume that the $w$ in $y$ is fixed (at least for a while) when we calculate the gradient.
The relationship between the Deep Q-learning objective and the Bellman optimality error:
The Deep Q-learning objective function
$$J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right]$$
- where:
  - $R$: the reward obtained after taking action $A$ in the current state;
  - $S'$: the next state;
  - $\hat{q}(S, A, w)$: the Q-value approximation produced by a neural network with parameters $w$;
  - $\max_{a \in \mathcal{A}(S')}$: the Q-value of the best action at the next state.
- This objective is a mean squared error (MSE) that measures the gap between the network output $\hat{q}(S, A, w)$ and the target value $R + \gamma \max_{a} \hat{q}(S', a, w)$.
Why does it correspond to the Bellman optimality error?
The Bellman optimality equation defines the recursion satisfied by the optimal Q-values:
$$q(s, a) = \mathbb{E}\Big[ R_{t+1} + \gamma \max_{a' \in \mathcal{A}(S_{t+1})} q(S_{t+1}, a') \,\Big|\, S_t = s, A_t = a \Big]$$
In other words, if $\hat{q}$ were optimal, then
$$R + \gamma \max_{a} \hat{q}(S', a, w) - \hat{q}(S, A, w) = 0$$
- would hold in the expectation sense.
In practice, however, $\hat{q}$ is only an approximation and cannot satisfy the Bellman equation exactly.
- We therefore define this residual as the Bellman optimality error and approach the optimal Q-values by minimizing it.
Why is the optimization tricky?
In the loss function
$$J(w) = \mathbb{E}\left[\Big(R + \gamma \max_{a} \hat{q}(S', a, w) - \hat{q}(S, A, w)\Big)^2\right]$$
- the parameter $w$ appears both in the prediction $\hat{q}(S, A, w)$ and in the target value $R + \gamma \max_{a} \hat{q}(S', a, w)$.
- This makes the gradient complicated, because we would have to differentiate through both the prediction and the target.
To simplify the computation, DQN uses a fixed target network:
- for a period of time, the parameters $w$ inside the target term $R + \gamma \max_{a} \hat{q}(S', a, w)$ are held fixed;
- only the current Q-network's parameters are updated.
This avoids the complication of propagating gradients through the target.
Two networks
- Introduction
  - One is a main network representing $\hat q(s,a,w)$.
  - The other is a target network $\hat q(s,a,w_T)$.
  - The objective function in this case degenerates to
    $$J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)^2\Big],$$
    where $w_T$ is the target network parameter.
- Gradient with a fixed target network
  - When $w_T$ is fixed, the gradient of $J$ can be easily obtained as
    $$\nabla_w J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)\nabla_w \hat q(S,A,w)\Big].$$
  - The basic idea of deep Q-learning is to use the gradient-descent algorithm to minimize this objective function.
In DQN, if a single network $\hat q(s,a,w)$ is used both to compute the targets and to be updated, training becomes unstable: the TD target and the prediction depend on the same network, so parameter updates interfere with each other.
Solution:
- Use two networks:
  - Main network: $\hat q(s,a,w)$, the one being trained.
  - Target network: $\hat q(s,a,w_T)$, used to produce relatively stable target values.
  - The parameters $w_T$ are synchronized from $w$ periodically (e.g., copied every $C$ steps).
  - As a result, the target values do not change with every update of $w$, which reduces training oscillation.
Objective function:
$$J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)^2\Big]$$
- Current Q-value estimate: $\hat q(S,A,w)$
- Target Q-value (TD target): $R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T)$
Gradient-descent update:
$$\nabla_w J = \mathbb{E}\Big[\Big(R+\gamma \max_{a\in \mathcal{A}(S')} \hat q(S',a,w_T) - \hat q(S,A,w)\Big)\nabla_w \hat q(S,A,w)\Big]$$
- With $w_T$ fixed, the gradient is straightforward and is not perturbed by the target changing at the same time.
Summary
- The two networks (main & target) solve the problem of unstable target values.
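The two-network logic can be sketched without any deep-learning library by keeping two parameter vectors for a linear $\hat q$ and syncing them every $C$ updates. The feature function `phi`, the stream of transition samples, and the value of C are assumptions used only for illustration.

```python
import numpy as np

def dqn_style_updates(samples, phi, num_features, num_actions,
                      alpha=0.05, gamma=0.9, C=100):
    """Two parameter vectors: w is trained, w_T only produces TD targets."""
    w = np.zeros(num_features)       # main network parameters
    w_T = w.copy()                   # target network parameters
    for t, (s, a, r, s_next, done) in enumerate(samples):
        # TD target computed with the *target* parameters w_T.
        q_next = max(phi(s_next, b) @ w_T for b in range(num_actions))
        y_T = r if done else r + gamma * q_next
        # Gradient step on the *main* parameters w only.
        td_error = y_T - phi(s, a) @ w
        w = w + alpha * td_error * phi(s, a)
        # Periodically sync the target network.
        if (t + 1) % C == 0:
            w_T = w.copy()
    return w
```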
Two techniques
- First technique: two networks, a main network and a target network
  - Why is it used?
    - The mathematical reason was explained above when we calculated the gradient.
  - Implementation details:
    - Let $w$ and $w_T$ denote the parameters of the main and target networks, respectively. They are set to be the same initially.
    - In every iteration, we draw a mini-batch of samples $\{(s,a,r,s')\}$ from the replay buffer (explained below).
    - The inputs of the networks are the state $s$ and the action $a$. The target output is
      $$y_T \doteq r + \gamma \max_{a\in \mathcal{A}(s')} \hat q(s',a,w_T).$$
    - Then, we directly minimize the TD error, also called the loss function,
      $$(y_T - \hat q(s,a,w))^2$$
      over the mini-batch $\{(s,a,y_T)\}$.
- Second technique: experience replay
  - Question: what is experience replay?
  - Answer:
    - After we have collected some experience samples, we do NOT use them in the order they were collected.
    - Instead, we store them in a set called the replay buffer $\mathcal{B} \doteq \{(s,a,r,s')\}$.
    - Every time we train the neural network, we draw a mini-batch of random samples from the replay buffer.
    - The draw of samples, i.e., experience replay, should follow a uniform distribution (why?).
  - Question: why is experience replay necessary in deep Q-learning, and why must the replay follow a uniform distribution?
  - Answer: the answers lie in the objective function.
    $$J = \mathbb{E}\left[ \left( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w) \right)^2 \right]$$
    - $(S, A) \sim d$: the state-action pair $(S, A)$ is an index and is treated as a single random variable.
    - $R \sim p(R|S,A)$, $S' \sim p(S'|S,A)$: $R$ and $S'$ are determined by the system model.
    - The distribution of the state-action pair $(S, A)$ is assumed to be uniform.
    - However, the samples are not collected uniformly, because they are generated consecutively by certain policies.
    - To break the correlation between consecutive samples, we can use the experience replay technique and draw samples uniformly from the replay buffer.
    - This is the mathematical reason why experience replay is necessary and why it must be uniform.
Experience replay
Problem:
- If we update the network with interaction data in the order it was collected, consecutive samples are strongly correlated (e.g., $s_t, s_{t+1}, s_{t+2}$), which violates the random-sampling assumption and can make training unstable or even divergent.
Solution:
- Introduce a replay buffer $\mathcal{B}$ that stores past experiences $(s,a,r,s')$.
- At each training step, instead of using the most recent data directly, draw a random mini-batch to break the correlation between samples.
Mathematical explanation:
The objective function is
$$J = \mathbb{E}\left[ \left( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w) \right)^2 \right]$$
- $(S,A) \sim d$: $(S,A)$ is treated as a single random variable.
- In theory, the distribution of $(S,A)$ is assumed to be uniform.
- In practice, however, the collected data is generated by the current policy and is not uniformly distributed (it may concentrate on certain regions).
What experience replay achieves:
- Breaks sample correlation (avoiding biased gradient updates).
- Approximates uniform sampling, bringing the empirical distribution of $(S,A)$ closer to the assumed uniform distribution.
- Improves data efficiency (the same sample can be reused many times).
Summary
- Experience replay addresses the problems of sample correlation and distribution mismatch.
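A minimal replay buffer sketch, using only the Python standard library; the buffer capacity, batch size, and the made-up transitions in the usage example are arbitrary illustrative choices.

```python
import random
from collections import deque

class ReplayBuffer:
    """Stores (s, a, r, s_next, done) tuples and returns uniformly sampled mini-batches."""
    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)   # old samples are dropped automatically

    def push(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size=32):
        # Uniform sampling breaks the temporal correlation between consecutive samples.
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

# Usage sketch: fill with made-up transitions, then draw a mini-batch.
buffer = ReplayBuffer()
for t in range(1000):
    buffer.push(s=t % 5, a=t % 2, r=1.0, s_next=(t + 1) % 5, done=False)
batch = buffer.sample(batch_size=32)
```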
Revisit the tabular case:
- Question: why doesn't tabular Q-learning require experience replay?
  - Answer: there is no uniform-distribution requirement.
- Question: why does deep Q-learning involve a distribution?
  - Answer: the objective function in the deep case is a scalar average over all $(S, A)$.
    The tabular case does not involve any distribution of $S$ or $A$.
    The algorithm in the tabular case aims to solve a set of equations, one for each $(s,a)$ (the Bellman optimality equation).
- Question: can we use experience replay in tabular Q-learning?
  - Answer: yes, we can, and it is more sample efficient (why?).
Why tabular Q-learning does not need experience replay while deep Q-learning does
- Characteristics of tabular Q-learning
  Storage: every state-action pair $(s,a)$ has its own table entry $Q(s,a)$.
  Update: updates are local and only affect the current $(s,a)$:
  $$Q(s,a) \leftarrow Q(s,a) + \alpha \Big[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \Big]$$
  No distribution requirement:
  - In the tabular method we are essentially solving a system of equations (the Bellman optimality equations). As long as every $(s,a)$ is visited, the method converges to the optimal $Q^*$ regardless of whether the sampling distribution is uniform.
  - Hence there is no need to make the sampling distribution uniform, and therefore no need for experience replay to break sample correlation.
- Characteristics of deep Q-learning
  Storage: the Q-values are not stored in a table but approximated by a neural network:
  $$Q(s,a;w) \approx Q^*(s,a)$$
  - The parameters $w$ are shared, so a single update affects the estimates of all $(s,a)$, not just one table entry.
  Objective function:
  The deep Q-learning objective is a mean squared error (MSE):
  $$J(w) = \mathbb{E}\Big[\big(r + \gamma \max_{a'} Q(s',a';w) - Q(s,a;w)\big)^2\Big]$$
  Note: the expectation is taken with respect to the distribution of state-action pairs $(s,a)$.
  Distribution issue:
  - If the training samples are highly correlated (e.g., drawn consecutively from the same episode), the network overfits the local trajectory and the gradient estimates are strongly biased.
  - The objective implicitly assumes that $(S,A)$ is i.i.d., but in an actual RL environment the samples are sequentially correlated.
- Why deep Q-learning needs experience replay
  - Experience replay does two things:
    - Breaks correlation: uniform sampling from the replay buffer removes the sequential correlation and approximately satisfies the i.i.d. assumption.
    - Improves sample efficiency: the same sample can be reused in many updates instead of being discarded after one use.
  - Mathematical explanation:
    Without experience replay, the expectation is effectively taken with
    $$(S,A) \sim d_\pi,$$
    and this distribution $d_\pi$ depends strongly on the current policy and trajectory, which makes the gradient estimate unstable.
    With a replay buffer and uniform sampling, the sampling distribution becomes approximately uniform, which stabilizes training.
- Comparison summary
  - Tabular Q-learning: updates are local; no uniformity of the sampling distribution is required; covering all $(s,a)$ suffices for convergence.
  - Deep Q-learning: updates are global and depend on the expectation in the objective function; experience replay keeps the sample distribution approximately uniform and avoids biased gradients.
Pseudocode: Deep Q-learning (off-policy version)
- Aim: learn an optimal target network to approximate the optimal action values from experience samples generated by a behavior policy $\pi_b$.
- Store the experience samples generated by $\pi_b$ in a replay buffer $\mathcal{B} = \{(s,a,r,s')\}$
- For each iteration, do
  - Uniformly draw a mini-batch of samples from $\mathcal{B}$
  - For each sample $(s,a,r,s')$, calculate the target value as
    $$y_T = r + \gamma \max_{a \in \mathcal{A}(s')} \hat q(s',a,w_T),$$
    where $w_T$ is the parameter of the target network
  - Update the main network to minimize
    $$(y_T - \hat q(s,a,w))^2$$
    using the mini-batch $\{(s,a,y_T)\}$
  - Set $w_T = w$ every $C$ iterations
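Putting the pieces together, here is a hedged PyTorch sketch of one training iteration of this pseudocode. The network sizes, the optimizer, and the tensor format of the mini-batch drawn from the replay buffer are illustrative assumptions, not the canonical DQN implementation.

```python
import torch
import torch.nn as nn

def make_q_net(state_dim, num_actions):
    # Small MLP: input is the state, output is q_hat(s, a, w) for every action a.
    return nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(),
                         nn.Linear(64, num_actions))

state_dim, num_actions, gamma, C = 4, 2, 0.99, 100
q_net = make_q_net(state_dim, num_actions)            # main network, parameters w
target_net = make_q_net(state_dim, num_actions)       # target network, parameters w_T
target_net.load_state_dict(q_net.state_dict())        # initially w_T = w
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

def train_step(batch, step):
    s, a, r, s_next, done = batch                     # float/long tensors from the buffer (assumed)
    with torch.no_grad():                             # targets use w_T, no gradient through them
        q_next = target_net(s_next).max(dim=1).values
        y_T = r + gamma * (1.0 - done) * q_next
    q_sa = q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
    loss = nn.functional.mse_loss(q_sa, y_T)          # (y_T - q_hat(s, a, w))^2 over the batch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % C == 0:                                 # sync target network every C iterations
        target_net.load_state_dict(q_net.state_dict())
    return loss.item()
```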
Summary
From tabular methods (tabular Q-learning) to function approximation (Sarsa / Q-learning with function approximation) and then to deep reinforcement learning (DQN), the core idea is always to minimize a Bellman error. The differences are:
- tabular methods solve the equations directly and do not depend on the sample distribution;
- function approximation requires gradient descent and projection;
- deep Q-learning additionally stabilizes the neural-network approximator with a target network and experience replay.