Kullback–Leibler (KL) Divergence and its Role in Constraining Policy Updates

An exploration of how relative entropy serves as a critical regularization mechanism in reinforcement learning. This lesson details the transition from vanilla policy gradients to trust-region methods.

At its core, the Kullback–Leibler (KL) divergence, also called relative entropy, measures how one probability distribution differs from a second, reference probability distribution. In machine learning, we often think of it as the 'information gain' achieved when moving from a prior belief to a posterior one. Intuitively, if a distribution $P$ represents the true behavior of an environment and a distribution $Q$ represents our model of it, the KL divergence $D_{KL}(P \| Q)$ tells us how much extra 'surprise' we incur, on average, when we use $Q$ to approximate $P$. It is always non-negative and equals zero only when the two distributions coincide. Unlike a standard distance metric, KL divergence is asymmetric: in general $D_{KL}(P \| Q) \neq D_{KL}(Q \| P)$, meaning the cost of approximating $P$ with $Q$ is not the same as the cost of approximating $Q$ with $P$.

Mathematically, for discrete probability distributions, the KL divergence is defined as the expected value of the logarithmic difference between the two distributions. For distributions $P$ and $Q$ over a space $\mathcal{X}$, it is expressed as: $$D_{KL}(P \| Q) = \sum_{x \in \mathcal{X}} P(x) \log \frac{P(x)}{Q(x)}$$ In the continuous case, the summation is replaced by an integral: $$D_{KL}(P \| Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} \, dx$$ Since $\log(A/B) = \log A - \log B$, we can see that the KL divergence is equivalent to the difference between the cross-entropy of $P$ and $Q$ and the entropy of $P$ itself: $D_{KL}(P \| Q) = H(P, Q) - H(P)$.
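
As a quick numerical check of the definitions above, the following sketch computes the discrete KL divergence with NumPy and confirms both the asymmetry and the identity $D_{KL}(P \| Q) = H(P, Q) - H(P)$. The function names and the two example distributions are illustrative choices, and the code assumes $Q(x) > 0$ wherever $P(x) > 0$.

```python
import numpy as np

def entropy(p):
    """Shannon entropy H(P) in nats; outcomes with P(x) = 0 contribute 0."""
    p = np.asarray(p, dtype=float)
    mask = p > 0
    return float(-np.sum(p[mask] * np.log(p[mask])))

def cross_entropy(p, q):
    """Cross-entropy H(P, Q) in nats (assumes Q(x) > 0 wherever P(x) > 0)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(-np.sum(p[mask] * np.log(q[mask])))

def kl_divergence(p, q):
    """D_KL(P || Q) in nats for discrete distributions given as arrays."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

# Hypothetical example distributions over three outcomes.
p = np.array([0.5, 0.4, 0.1])
q = np.array([0.8, 0.1, 0.1])

forward = kl_divergence(p, q)  # D_KL(P || Q)
reverse = kl_divergence(q, p)  # D_KL(Q || P)

print(f"D_KL(P||Q) = {forward:.4f}, D_KL(Q||P) = {reverse:.4f}")
print("identity holds:", np.isclose(forward, cross_entropy(p, q) - entropy(p)))
```

The two printed divergences differ, which is the asymmetry in action, and the final line checks the cross-entropy identity up to floating-point tolerance.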

In Reinforcement Learning (RL), we optimize a policy $\pi_{\theta}(a|s)$ to maximize the expected return $J(\theta)$. A naive approach is the vanilla policy gradient method, where we update parameters via $\theta_{t+1} = \theta_t + \alpha \nabla_{\theta} J(\theta_t)$ with step size $\alpha$. However, this can lead to the 'collapse' of the policy. Because the gradient is estimated from samples, a single overly large update in parameter space can cause a catastrophic drop in performance. This collapse occurs because a small change in $\theta$ can cause a massive change in the resulting distribution $\pi_{\theta}$. Since the next batch of data is collected by the updated policy, one bad step can push the agent into a region of the state space from which it cannot recover.
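
To make the mismatch between parameter space and distribution space concrete, here is a small sketch using a one-dimensional Gaussian policy whose mean is the only parameter being moved. The Gaussian form and the specific numbers are assumptions chosen for illustration; the closed-form KL divergence between two univariate Gaussians is standard.

```python
import math

def kl_gaussian(mu_p, sigma_p, mu_q, sigma_q):
    """D_KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2)) in nats (closed form)."""
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
            - 0.5)

step = 0.1  # identical change in the mean parameter for both policies

# A wide (exploratory) policy barely changes as a distribution...
kl_wide = kl_gaussian(0.0, 1.0, step, 1.0)

# ...while a narrow (nearly deterministic) policy changes drastically.
kl_narrow = kl_gaussian(0.0, 0.01, step, 0.01)

print(f"sigma = 1.00: KL = {kl_wide:.4f}")
print(f"sigma = 0.01: KL = {kl_narrow:.4f}")
```

Both calls move the mean by the same amount, yet the resulting KL divergences differ by a factor of $10^4$, because for equal variances the expression reduces to $(\Delta\mu)^2 / (2\sigma^2)$.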

To solve this, we introduce KL divergence as a constraint on the update. Instead of constraining the change in the parameters $\theta$ (which is often meaningless since the mapping from $\theta$ to the distribution is non-linear), we constrain the change in the distribution itself. We seek to find a new policy $\pi_{\theta'}$ that maximizes the objective $J(\theta')$ such that the KL divergence between the old policy $\pi_{\theta}$ and the new policy $\pi_{\theta'}$ remains below a threshold $\delta$: $$\max_{\theta'} J(\theta') \quad \text{subject to} \quad D_{KL}(\pi_{\theta}(\cdot|s) \| \pi_{\theta'}(\cdot|s)) \leq \delta$$ This ensures that the 'step size' is measured in the space of probability distributions, effectively creating a 'Trust Region'.
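
One simple way to enforce such a constraint, sketched here for a single-state softmax policy, is to shrink a proposed step until the KL divergence falls under $\delta$. This is an illustrative backtracking scheme rather than the exact procedure of any particular algorithm: `constrained_step`, `shrink`, and `max_backtracks` are hypothetical names, and the sketch checks only the constraint, not whether the objective improves.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    z = logits - np.max(logits)
    e = np.exp(z)
    return e / e.sum()

def kl(p, q):
    """D_KL(P || Q) for strictly positive discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

def constrained_step(theta, direction, delta, shrink=0.5, max_backtracks=30):
    """Shrink a proposed update until KL(old policy || new policy) <= delta."""
    old_policy = softmax(theta)
    scale = 1.0
    for _ in range(max_backtracks):
        candidate = theta + scale * direction
        if kl(old_policy, softmax(candidate)) <= delta:
            return candidate
        scale *= shrink
    return theta  # no acceptable step found; keep the old policy

# Hypothetical example: three actions and an aggressive proposed update.
theta = np.zeros(3)
direction = np.array([5.0, 0.0, -5.0])
delta = 0.01

new_theta = constrained_step(theta, direction, delta)
print("KL of full step:   ", kl(softmax(theta), softmax(theta + direction)))
print("KL of trusted step:", kl(softmax(theta), softmax(new_theta)))
```

The full step would change the policy far more than $\delta$ allows, while the accepted step stays inside the trust region regardless of how large the proposed update was in parameter space.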

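The most prominent implementation of this concept is Trust Region Policy Optimization (TRPO). TRPO uses a first-order (linear) approximation of the objective and a second-order (quadratic) approximation of the KL constraint. The quadratic term is governed by the Fisher Information Matrix $F$, which is the Hessian of the KL divergence evaluated at the current parameters. Writing $g = \nabla_{\theta} J(\theta_t)$, the approximate problem has a closed-form solution, and the update rule approximately follows: $$\theta_{t+1} = \theta_t + \sqrt{\frac{2\delta}{g^{\top} F^{-1} g}} \, F^{-1} g$$ The direction $F^{-1} g$ is the natural gradient, which makes the update invariant, to first order, to the parametrization of the policy. The underlying theory provides a monotonic improvement guarantee; the practical algorithm approximates it and adds a backtracking line search to verify that the KL constraint and the improvement in the surrogate objective actually hold.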
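
The natural-gradient step $\sqrt{2\delta / (g^{\top} F^{-1} g)} \, F^{-1} g$, with $g$ the policy gradient and $F$ the Fisher Information Matrix, can be sketched in a few lines for a softmax policy over logits. This is a minimal illustration under stated assumptions, not TRPO itself: it forms $F$ explicitly, adds a small `damping` term (an assumption, needed because the softmax Fisher matrix is singular), and uses the exact gradient of a hypothetical three-armed bandit instead of a sampled estimate.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    z = logits - np.max(logits)
    e = np.exp(z)
    return e / e.sum()

def kl(p, q):
    """D_KL(P || Q) for strictly positive discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

def fisher_softmax(theta):
    """Fisher matrix of a categorical softmax policy w.r.t. its logits."""
    p = softmax(theta)
    return np.diag(p) - np.outer(p, p)

def natural_gradient_step(theta, g, delta, damping=1e-3):
    """Return theta + sqrt(2*delta / (g^T F^-1 g)) * F^-1 g."""
    # Damping keeps F invertible: adding a constant to every logit leaves
    # the policy unchanged, so the undamped matrix has a zero eigenvalue.
    F = fisher_softmax(theta) + damping * np.eye(len(theta))
    x = np.linalg.solve(F, g)  # x = F^{-1} g without forming the inverse
    step_size = np.sqrt(2.0 * delta / (g @ x))
    return theta + step_size * x

# Hypothetical bandit: expected reward J = sum_a pi(a) * r(a).
theta = np.zeros(3)
rewards = np.array([1.0, 0.5, 0.0])
policy = softmax(theta)
g = policy * (rewards - policy @ rewards)  # exact dJ/dtheta for softmax logits

delta = 0.01
new_theta = natural_gradient_step(theta, g, delta)
achieved_kl = kl(policy, softmax(new_theta))
print(f"target delta = {delta}, achieved KL = {achieved_kl:.5f}")
```

Because the step is derived from a quadratic approximation, the achieved KL divergence lands close to $\delta$ rather than exactly on it, which is one reason a line search is typically added on top.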

Finally, the practical implementation of this constraint evolved into Proximal Policy Optimization (PPO). PPO avoids the heavy second-order computation involving the Fisher Information Matrix by using a clipped surrogate objective. While not a hard KL constraint, clipping removes the incentive for updates that push the probability ratio $r_t(\theta) = \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{old}}(a|s)}$ outside a small interval $[1 - \epsilon, 1 + \epsilon]$ around 1. This mimics the behavior of the KL constraint by limiting the divergence between the current and previous policies, thereby maintaining stability while significantly reducing computational overhead.
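
A minimal NumPy sketch of the clipped surrogate objective follows. It assumes per-sample log-probabilities under the new and old policies and per-sample advantage estimates (the `advantages` array and the clip range `epsilon` are inputs not specified in the text above; the value 0.2 is a commonly used setting, taken here as an assumption).

```python
import numpy as np

def ppo_clip_objective(logp_new, logp_old, advantages, epsilon=0.2):
    """Mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A) over samples."""
    ratio = np.exp(logp_new - logp_old)  # r_t(theta), computed in log space
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # The elementwise minimum makes the objective a pessimistic bound:
    # moving the ratio further outside the clip range earns nothing extra.
    return float(np.mean(np.minimum(unclipped, clipped)))

# Hypothetical batch of two samples with ratios 1.5 and 0.5.
logp_old = np.log(np.array([0.2, 0.2]))
logp_new = np.log(np.array([0.3, 0.1]))
advantages = np.array([1.0, -1.0])

value = ppo_clip_objective(logp_new, logp_old, advantages)
print(f"clipped surrogate objective: {value:.4f}")
```

In the first sample the ratio exceeds $1 + \epsilon$ with a positive advantage, so its contribution is capped; in the second the ratio falls below $1 - \epsilon$ with a negative advantage, so the clipped (more pessimistic) term is used. In both cases the gradient with respect to the ratio vanishes, which is how clipping discourages large policy changes without an explicit KL computation.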