KL Divergence and Trust Region Constraints in Policy Optimization

An exploration of how Kullback-Leibler divergence prevents catastrophic forgetting and ensures stable convergence in Deep Reinforcement Learning. We analyze the transition from vanilla policy gradients to trust region methods.

In Reinforcement Learning (RL), we often seek to optimize a policy $\pi_{\theta}$ to maximize the expected cumulative reward. However, the objective landscape is often volatile, and its gradient must be estimated from a limited batch of sampled trajectories. If we perform a standard gradient ascent update with a large learning rate, the policy may change drastically. In a high-dimensional parameter space, a small step in $\theta$ can produce a large shift in the resulting probability distribution $\pi_{\theta}(a|s)$. Such a policy collapse often causes a catastrophic drop in performance: the agent abandons previously learned, stable behaviors in favor of a noisy, high-variance update, and because its next batch of data is then collected by the degraded policy, it may struggle to recover.
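To make the collapse concrete, here is a minimal NumPy sketch on a hypothetical three-armed bandit. The reward values, the single lucky sample, and both learning rates are illustrative assumptions, not taken from any benchmark. It builds one REINFORCE-style gradient estimate from that sample and applies it with a small and a large learning rate.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

# Hypothetical three-armed bandit: true mean reward of each action.
true_rewards = np.array([1.0, 0.2, 0.0])
theta = np.zeros(3)                  # logits of a uniform initial policy
pi = softmax(theta)
baseline = pi @ true_rewards         # expected reward of the current policy (0.4); a simplifying assumption

# A single noisy sample: action 1 was drawn and, by luck, returned
# reward 1.0 even though its true mean is only 0.2.
a, observed_reward = 1, 1.0
grad_log_pi = -pi                    # gradient of log softmax w.r.t. logits is e_a - pi
grad_log_pi[a] += 1.0
g_hat = (observed_reward - baseline) * grad_log_pi   # high-variance REINFORCE estimate

for lr in (0.5, 50.0):
    pi_new = softmax(theta + lr * g_hat)
    print(lr, pi_new.round(3), "expected reward:", round(float(pi_new @ true_rewards), 3))
```

With these numbers, the small step lowers the expected reward only slightly (from 0.4 to about 0.38), whereas the large step drives it down to about 0.2 and leaves the best action with a probability of roughly $10^{-13}$, so the agent will almost never sample it again to correct the mistake.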

To quantify the difference between two probability distributions, we use the Kullback-Leibler (KL) divergence. Intuitively, $D_{KL}(P \parallel Q)$ measures the expected extra "surprise" (information) incurred when distribution $Q$ is used to approximate distribution $P$. In policy optimization, we use it to measure how far the updated policy $\pi_{\theta}$ has moved from the old policy $\pi_{\theta_{old}}$. KL divergence is not a true distance metric, but unlike the Euclidean distance between the parameter vectors $\theta_{old}$ and $\theta$, it operates in the space of distributions, which makes it a far more meaningful measure of how much the agent's behavior has actually changed.
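The following NumPy sketch illustrates the gap between parameter distance and distribution distance. It assumes a two-action softmax policy parameterized directly by its logits (an illustrative choice) and applies the identical parameter step from two different starting points. The Euclidean size of the step is the same in both cases, yet the KL divergence it induces differs by many orders of magnitude.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-probabilities of a categorical distribution.
    z = logits - np.max(logits)
    return z - np.log(np.sum(np.exp(z)))

def kl_from_logits(logits_p, logits_q):
    """D_KL(P || Q) for categorical distributions given by logits."""
    logp = log_softmax(logits_p)
    logq = log_softmax(logits_q)
    return float(np.sum(np.exp(logp) * (logp - logq)))

step = np.array([1.0, -1.0])        # identical parameter update in both cases
theta_a = np.array([0.0, 0.0])      # uncertain policy: pi = [0.5, 0.5]
theta_b = np.array([10.0, -10.0])   # nearly deterministic policy

print(np.linalg.norm(step))                     # same Euclidean step, sqrt(2)
print(kl_from_logits(theta_a, theta_a + step))  # about 0.434 nats
print(kl_from_logits(theta_b, theta_b + step))  # tiny, on the order of 1e-9
```

The uncertain policy moves from $[0.5, 0.5]$ to roughly $[0.88, 0.12]$, while the near-deterministic policy barely changes, which is exactly the information a parameter-space distance cannot see.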

Mathematically, for discrete action spaces, the KL divergence between two distributions $P$ and $Q$ is defined as: $$D_{KL}(P \parallel Q) = \sum_{x \in \mathcal{X}} P(x) \log \frac{P(x)}{Q(x)}$$ This is the expected value, under $P$, of the log-ratio $\log \frac{P(x)}{Q(x)}$. It is always non-negative and equals zero only when $P = Q$. Note that KL divergence is asymmetric: in general, $D_{KL}(P \parallel Q) \neq D_{KL}(Q \parallel P)$. In RL, we typically measure $D_{KL}(\pi_{\theta_{old}} \parallel \pi_{\theta})$ to ensure the new policy does not deviate too far from the behavior that generated the current data samples.
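A direct translation of the formula into NumPy might look like the sketch below. The function name `kl_divergence` is our own, it uses the natural logarithm (so results are in nats), and the example distributions are arbitrary, chosen only to show the asymmetry.

```python
import numpy as np

def kl_divergence(p, q):
    """D_KL(P || Q) = sum_x P(x) log(P(x) / Q(x)) for discrete distributions.

    Terms with P(x) = 0 contribute 0 by convention; if Q(x) = 0 where
    P(x) > 0, the divergence is infinite.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    if np.any(q[mask] == 0):
        return float("inf")
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

P = [0.9, 0.1]
Q = [0.5, 0.5]
print(kl_divergence(P, Q))  # about 0.368 nats
print(kl_divergence(Q, P))  # about 0.511 nats: the divergence is asymmetric
print(kl_divergence(P, P))  # 0.0
```

If you prefer a library routine, SciPy's `scipy.stats.entropy(p, q)` computes the same relative entropy (after normalizing its inputs).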

The primary role of KL divergence in policy updates is to define a "trust region." Rather than relying on a global learning rate, we constrain the update so that the average KL divergence over the state distribution $\rho_{\pi_{\theta_{old}}}$ visited by the old policy remains below a threshold $\delta$: $$E_{s \sim \rho_{\pi_{\theta_{old}}}} \left[ D_{KL}\big(\pi_{\theta_{old}}(\cdot|s) \parallel \pi_{\theta}(\cdot|s)\big) \right] \le \delta$$ This constraint keeps the update "safe": by bounding how much the policy's action distributions can change, it keeps the new policy close enough to the old one that the local approximation of the objective function, which is estimated from trajectories sampled under $\pi_{\theta_{old}}$, remains a good approximation for the new policy.
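As a rough illustration of enforcing this constraint on a batch of sampled states, the sketch below assumes a hypothetical linear-softmax policy $\pi(a|s) = \operatorname{softmax}(W s)$ and repeatedly halves a proposed update until the empirical mean KL falls below $\delta$. This is a simplification: TRPO's actual line search also checks that the surrogate objective improves, and the random states here merely stand in for states sampled under the old policy. All function names are our own.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax along the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def mean_kl(W_old, W_new, states):
    """Average D_KL(pi_old(.|s) || pi_new(.|s)) over a batch of states.

    Hypothetical linear-softmax policy: pi(a|s) = softmax(W s), with W of
    shape (n_actions, state_dim) and states of shape (N, state_dim).
    """
    logp_old = log_softmax(states @ W_old.T)
    logp_new = log_softmax(states @ W_new.T)
    p_old = np.exp(logp_old)
    return float(np.mean(np.sum(p_old * (logp_old - logp_new), axis=-1)))

def backtrack_to_trust_region(W_old, full_step, states, delta, shrink=0.5, max_iter=20):
    """Shrink a proposed update until the mean-KL constraint is satisfied."""
    frac = 1.0
    for _ in range(max_iter):
        W_new = W_old + frac * full_step
        if mean_kl(W_old, W_new, states) <= delta:
            return W_new, frac
        frac *= shrink
    # No acceptable step found: keep the old policy.
    return W_old.copy(), 0.0

rng = np.random.default_rng(0)
states = rng.normal(size=(256, 4))              # stand-in for states visited by pi_old
W_old = rng.normal(scale=0.1, size=(3, 4))
big_step = rng.normal(scale=5.0, size=(3, 4))   # deliberately oversized update
W_new, frac = backtrack_to_trust_region(W_old, big_step, states, delta=0.01)
print(frac, mean_kl(W_old, W_new, states))
```

Because the KL grows roughly quadratically with the step size for small steps, each halving cuts the divergence by about a factor of four, so only a handful of iterations are typically needed.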

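This conceptual framework is most famously realized in Trust Region Policy Optimization (TRPO) and Proximal Policy Optimization (PPO). TRPO approximately solves the constrained optimization problem using the natural policy gradient, in which the ordinary gradient is preconditioned by the inverse of the Fisher Information Matrix $F$. The Fisher matrix is the Hessian of the KL divergence evaluated at the old parameters, $F = \nabla_{\theta}^2 D_{KL}(\pi_{\theta_{old}} \parallel \pi_{\theta}) \big|_{\theta = \theta_{old}}$, so to second order $D_{KL} \approx \frac{1}{2} \Delta\theta^{\top} F \Delta\theta$. Moving along the natural gradient direction $\Delta \theta \propto F^{-1} \nabla_{\theta} J(\theta)$ therefore lets the agent take a step of fixed size in the space of distributions rather than the space of parameters. TRPO chooses the step length so that this quadratic approximation of the KL equals $\delta$: $$\Delta \theta = \sqrt{\frac{2\delta}{g^{\top} F^{-1} g}} \, F^{-1} g, \qquad g = \nabla_{\theta} J(\theta)$$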
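The natural-gradient step can be sketched for the simplest case: a single-state softmax policy over three actions with hypothetical expected rewards. For a categorical policy parameterized by its logits, the Fisher matrix is $\operatorname{diag}(\pi) - \pi\pi^{\top}$. In the spirit of how TRPO handles large networks, the sketch never forms $F$ explicitly; it solves $F x = g$ with conjugate gradient using only Fisher-vector products, then rescales $x$ by $\sqrt{2\delta / g^{\top}x}$ so that $\frac{1}{2}\Delta\theta^{\top} F \Delta\theta = \delta$. The reward values and $\delta = 0.01$ are illustrative assumptions.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def fisher_vector_product(pi, v):
    # Fisher matrix of a categorical policy w.r.t. its logits is
    # F = diag(pi) - pi pi^T; apply it to v without building F.
    return pi * v - pi * np.dot(pi, v)

def conjugate_gradient(fvp, b, iters=10, tol=1e-10):
    """Approximately solve F x = b using only Fisher-vector products."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    rr = r @ r
    for _ in range(iters):
        if rr < tol:
            break
        Fp = fvp(p)
        alpha = rr / (p @ Fp)
        x += alpha * p
        r -= alpha * Fp
        rr_new = r @ r
        p = r + (rr_new / rr) * p
        rr = rr_new
    return x

rewards = np.array([1.0, 0.8, 0.0])     # hypothetical expected reward per action
theta_old = np.zeros(3)
pi = softmax(theta_old)
g = pi * (rewards - pi @ rewards)       # exact gradient of J = sum_a pi(a) r(a) w.r.t. logits

delta = 0.01
x = conjugate_gradient(lambda v: fisher_vector_product(pi, v), g)  # x approximates F^{-1} g
step_size = np.sqrt(2 * delta / (g @ x))                           # makes 0.5 * dtheta^T F dtheta = delta
theta_new = theta_old + step_size * x

pi_new = softmax(theta_new)
kl = np.sum(pi * np.log(pi / pi_new))
print(kl)   # close to delta
```

Because the exact KL differs from its quadratic approximation only at third order, the measured divergence lands very close to the target (about 0.0097 for these numbers). The softmax Fisher matrix is singular (shifting all logits equally leaves $\pi$ unchanged), but conjugate gradient still works here because $g$ sums to zero and therefore lies in the range of $F$.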

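Modern methods like PPO simplify this by replacing the hard constraint with either a clipped surrogate objective or a KL penalty. With a KL penalty, the objective to maximize becomes $L(\theta) = E_t \left[ \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}_t - \beta \, D_{KL}\big(\pi_{\theta_{old}}(\cdot|s_t) \parallel \pi_{\theta}(\cdot|s_t)\big) \right]$, where $\hat{A}_t$ is the estimated advantage and $\beta$ governs the strength of the penalty. The clipped variant instead maximizes $E_t \left[ \min\big(r_t(\theta)\hat{A}_t, \operatorname{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\big) \right]$ with $r_t(\theta) = \pi_{\theta}(a_t|s_t) / \pi_{\theta_{old}}(a_t|s_t)$, which removes the incentive to push the probability ratio outside $[1-\epsilon, 1+\epsilon]$. Either way, the policy is discouraged from changing too abruptly between rounds of data collection. By constraining updates via KL divergence or a proxy for it, these methods trade a little per-step progress for stability, balancing aggressive improvement against reliable behavior, and this is what allows deep neural network policies to converge far more reliably in complex environments.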
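Both PPO variants fit in a few lines of NumPy. The sketch below follows the objectives as stated in the PPO paper, including its heuristic rule for adapting $\beta$ (halve it when the measured KL falls below $d_{targ}/1.5$, double it when the KL exceeds $1.5\,d_{targ}$). The function names, the two-state batch, and the advantage values are hypothetical, and a real implementation would compute these quantities in an autodiff framework so it can take gradients with respect to $\theta$.

```python
import numpy as np

def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    """Mean over t of min(r_t A_t, clip(r_t, 1 - eps, 1 + eps) A_t)."""
    ratio = np.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * advantages
    return float(np.mean(np.minimum(unclipped, clipped)))

def ppo_kl_penalty_objective(logp_new, logp_old, advantages, probs_old, probs_new, beta):
    """Mean over t of r_t A_t - beta * KL(pi_old(.|s_t) || pi_new(.|s_t))."""
    ratio = np.exp(logp_new - logp_old)
    kl = np.sum(probs_old * np.log(probs_old / probs_new), axis=-1)
    return float(np.mean(ratio * advantages - beta * kl))

def update_beta(beta, measured_kl, kl_target):
    # Adaptive rule from the PPO paper: double beta if the KL overshoots
    # the target by more than 1.5x, halve it if it undershoots by 1.5x.
    if measured_kl > 1.5 * kl_target:
        return beta * 2.0
    if measured_kl < kl_target / 1.5:
        return beta / 2.0
    return beta

# Hypothetical batch: action distributions at two states, the actions taken,
# and their estimated advantages.
probs_old = np.array([[0.5, 0.5], [0.2, 0.8]])
probs_new = np.array([[0.6, 0.4], [0.2, 0.8]])
actions = np.array([0, 1])
adv = np.array([1.0, -0.5])
idx = np.arange(len(actions))
logp_old = np.log(probs_old[idx, actions])
logp_new = np.log(probs_new[idx, actions])

print(ppo_clip_objective(logp_new, logp_old, adv))
print(ppo_kl_penalty_objective(logp_new, logp_old, adv, probs_old, probs_new, beta=1.0))
```

In the clipped version, `np.minimum` makes the objective pessimistic: once the ratio leaves $[1-\epsilon, 1+\epsilon]$ in the direction the advantage favors, further movement earns nothing extra, while movement in the unfavorable direction is still fully penalized. A value of `eps=0.2` is a common choice.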