Understanding Policy Gradients in Reinforcement Learning

Policy gradient methods are a fundamental class of reinforcement learning algorithms that optimize a parameterized policy directly, by gradient ascent on expected return. Value-based methods, by contrast, learn a value function and derive a policy from it.

The Core Idea

The objective in reinforcement learning is to maximize expected return:

$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$$

where $\theta$ represents the policy parameters, $\tau$ is a trajectory, and $R(\tau)$ is the return.
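
To make the objective concrete, here is a minimal sketch of estimating $J(\theta)$ by averaging returns over sampled trajectories. It assumes $R(\tau)$ is the discounted sum of rewards with discount factor `gamma`, as in the REINFORCE code in this article; the function names `discounted_return` and `estimate_objective` are hypothetical.

```python
def discounted_return(rewards, gamma=0.99):
    """Return sum_t gamma^t * r_t for one trajectory's reward list."""
    G = 0.0
    # Accumulate from the last reward backwards: G_t = r_t + gamma * G_{t+1}
    for r in reversed(rewards):
        G = r + gamma * G
    return G


def estimate_objective(reward_sequences, gamma=0.99):
    """Monte Carlo estimate of J(theta): the mean return over sampled trajectories."""
    returns = [discounted_return(rewards, gamma) for rewards in reward_sequences]
    return sum(returns) / len(returns)
```

With rewards `[1, 1, 1]` and `gamma=0.5`, `discounted_return` gives 1 + 0.5 + 0.25 = 1.75. The estimate is only as good as the number of trajectories sampled from the current policy.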

Policy Gradient Theorem

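The key insight of the policy gradient theorem is that the gradient of the expected return can itself be written as an expectation over trajectories, so it can be estimated from sampled rollouts without differentiating through the environment dynamics: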

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t) \, R(\tau)\right]$$

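This tells us how to adjust the policy parameters to increase expected return: raise the log-probability of actions taken in high-return trajectories and lower it for actions taken in low-return ones.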
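
One way to see where this expression comes from is the log-derivative trick. The sketch below assumes the trajectory density factors into an initial-state distribution $\rho$, the policy, and transition dynamics $P$; the symbols $\rho$, $P$, and $p_\theta(\tau)$ are introduced here for illustration and do not appear elsewhere in this article.

```latex
\begin{align*}
p_\theta(\tau) &= \rho(s_0) \prod_{t=0}^{T} \pi_\theta(a_t \mid s_t)\, P(s_{t+1} \mid s_t, a_t) \\
\nabla_\theta J(\theta) &= \nabla_\theta \int p_\theta(\tau)\, R(\tau)\, d\tau
  = \int p_\theta(\tau)\, \nabla_\theta \log p_\theta(\tau)\, R(\tau)\, d\tau \\
&= \mathbb{E}_{\tau \sim \pi_\theta}\left[ \nabla_\theta \log p_\theta(\tau)\, R(\tau) \right] \\
\nabla_\theta \log p_\theta(\tau) &= \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)
\end{align*}
```

The second line uses $\nabla_\theta p_\theta = p_\theta \nabla_\theta \log p_\theta$. The last line holds because $\rho$ and $P$ do not depend on $\theta$, so their log-terms vanish under the gradient; this is why no model of the dynamics is needed.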

REINFORCE Algorithm

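The REINFORCE algorithm is the simplest policy gradient method. Here’s a basic PyTorch implementation. It weights each log-probability by the discounted return from that step onward (the reward-to-go) rather than by the full-trajectory return, and it normalizes returns within each trajectory; both are common variance-reduction choices.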

import torch
import torch.nn as nn

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, state):
        return self.network(state)

def reinforce_update(policy, trajectories, optimizer, gamma=0.99):
    """Update policy using REINFORCE algorithm."""
    policy_loss = []

    for trajectory in trajectories:
        rewards = trajectory['rewards']
        log_probs = trajectory['log_probs']

        # Compute discounted returns
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)

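        returns = torch.tensor(returns, dtype=torch.float32)
        # Normalize returns to reduce variance; skip single-step
        # trajectories, where std() would be NaN
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)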

        # Compute policy gradient
        for log_prob, G in zip(log_probs, returns):
            policy_loss.append(-log_prob * G)

    optimizer.zero_grad()
    loss = torch.stack(policy_loss).sum()
    loss.backward()
    optimizer.step()

    return loss.item()
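
The update function above expects each trajectory to be a dict with `'rewards'` and `'log_probs'` entries. A minimal sketch of producing one is shown here. `CoinFlipEnv` and `collect_trajectory` are hypothetical names made up for this illustration, and the toy environment stands in for whatever environment you actually use; the small `nn.Sequential` policy is likewise just a stand-in with a softmax output.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class CoinFlipEnv:
    """Hypothetical toy environment: reward 1 when the action matches the observed bit."""

    def __init__(self, horizon=5, seed=0):
        self.horizon = horizon
        self.gen = torch.Generator().manual_seed(seed)

    def _observe(self):
        # Draw a random bit and expose it as a 1-dimensional float state
        self.bit = int(torch.randint(0, 2, (1,), generator=self.gen).item())
        return torch.tensor([float(self.bit)])

    def reset(self):
        self.t = 0
        return self._observe()

    def step(self, action):
        reward = 1.0 if action == self.bit else 0.0
        self.t += 1
        done = self.t >= self.horizon
        return self._observe(), reward, done


def collect_trajectory(policy, env):
    """Roll out one episode, recording rewards and log-probs of the sampled actions."""
    rewards, log_probs = [], []
    state, done = env.reset(), False
    while not done:
        probs = policy(state)                     # action probabilities (softmax output)
        dist = Categorical(probs=probs)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))   # stays attached to the autograd graph
        state, reward, done = env.step(action.item())
        rewards.append(reward)
    return {"rewards": rewards, "log_probs": log_probs}


torch.manual_seed(0)
policy = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 2), nn.Softmax(dim=-1))
trajectory = collect_trajectory(policy, CoinFlipEnv())
```

A list of such trajectories can then be passed to `reinforce_update` together with an optimizer such as `torch.optim.Adam(policy.parameters(), lr=3e-4)`. The log-probabilities must not be detached, since the loss backpropagates through them.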

Variance Reduction

A key challenge with policy gradients is high variance in gradient estimates. Several techniques help:

  1. Baseline subtraction: Subtract a baseline $b(s_t)$ from the returns. Because the baseline does not depend on the action, the expected gradient is unchanged
  2. Advantage estimation: Use $A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$ instead of raw returns
  3. Generalized Advantage Estimation (GAE): Takes an exponentially weighted average of n-step advantage estimates, with a parameter $\lambda$ that trades bias against variance
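
The three techniques above are closely related, which a short GAE sketch makes visible. This is an illustrative implementation, assuming `values[t]` holds an estimate of $V(s_t)$ from some critic (not shown) and that `last_value` is 0 for a terminal state; the function name `compute_gae` and the default `lam=0.95` are assumptions, not prescriptions.

```python
def compute_gae(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    """Generalized Advantage Estimation for one trajectory.

    rewards[t] is the reward after acting in s_t, values[t] estimates V(s_t),
    and last_value estimates V of the state after the final step (0 if terminal).
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of future TD errors
        gae = delta + gamma * lam * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages
```

Setting `lam=1` recovers the discounted return minus the baseline $V(s_t)$ (low bias, high variance), while `lam=0` reduces to the one-step TD error (higher bias, lower variance).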

Modern Extensions

Policy gradients have evolved into several powerful algorithms:

  • TRPO (Trust Region Policy Optimization): Constrains policy updates to a trust region
  • PPO (Proximal Policy Optimization): Simpler alternative to TRPO with clipped objectives
  • A3C (Asynchronous Advantage Actor-Critic): Parallel training with advantage estimates
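
To illustrate what "clipped objectives" means for PPO, here is a per-sample sketch of the clipped surrogate. It assumes `ratio` is the probability ratio between the new and old policy for the sampled action and `advantage` is an advantage estimate; the function name and the default `clip_eps=0.2` are illustrative choices.

```python
def ppo_clip_objective(ratio, advantage, clip_eps=0.2):
    """Per-sample PPO clipped surrogate (to be maximized).

    ratio = pi_new(a|s) / pi_old(a|s); advantage is an estimate of A(s, a).
    """
    # Restrict the ratio to [1 - eps, 1 + eps]
    clipped_ratio = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    # Take the pessimistic (smaller) of the unclipped and clipped terms
    return min(ratio * advantage, clipped_ratio * advantage)
```

For a positive advantage, the objective stops growing once the ratio exceeds `1 + clip_eps`, so there is no incentive to move the policy further in one update. This plays a role similar to TRPO's trust region without a constrained optimization step.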

Practical Considerations

When implementing policy gradients:

  • Start with small learning rates (e.g., $3 \times 10^{-4}$)
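  • Use a baseline or advantage estimate to reduce variance
  • Monitor policy entropy to ensure exploration; a rapid drop toward zero suggests the policy has become nearly deterministic too early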
  • Consider PPO for most practical applications
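
Entropy monitoring is straightforward for a discrete policy. The following sketch assumes `probs` is the list of action probabilities the policy outputs for one state; the helper name `policy_entropy` is hypothetical.

```python
import math


def policy_entropy(probs):
    """Shannon entropy (in nats) of a discrete action distribution."""
    # Skip zero-probability actions, whose contribution is defined as 0
    return -sum(p * math.log(p) for p in probs if p > 0)
```

A uniform policy over $n$ actions has the maximum entropy $\log n$, and a deterministic policy has entropy 0, so logging the average of this quantity over visited states gives a rough picture of how much the policy is still exploring.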

Conclusion

Policy gradients provide a principled way to optimize policies directly from reward signals. While they require careful tuning and variance reduction, they enable learning in complex domains where value-based methods struggle, such as problems with continuous or very large action spaces.

The mathematical elegance of the policy gradient theorem continues to inspire new algorithms, making this a vibrant area of research.