策略梯度定理深度解析

策略梯度方法直接对策略进行优化,是现代强化学习的核心技术之一。

1. 策略梯度目标

1.1 平均值函数目标

1.2 起始状态目标

1.3 平均奖励目标

2. 策略梯度定理

2.1 定理陈述

定理(策略梯度定理):
对于可微策略 ,策略梯度为:

或等价形式:

2.2 证明(基于起始状态目标)

步骤1:对值函数求导

步骤2:使用对数梯度恒等式

步骤3:Bellman方程代入

求导,利用

步骤4:递归展开

通过递归展开,得到:

其中 是轨迹。

2.3 轨迹视角

3. REINFORCE算法

3.1 蒙特卡洛策略梯度

def REINFORCE(env, policy, optimizer, num_episodes):
    for episode in range(num_episodes):
        trajectory = collect_episode(env, policy)
        G = 0
        
        for t in reversed(range(len(trajectory))):
            s, a, r = trajectory[t]
            G = r + gamma * G  # 计算回报
            
            # 策略梯度更新
            log_prob = policy.log_prob(s, a)
            loss = -log_prob * G  # 最大化回报 = 最小化负回报
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

3.2 梯度估计器

3.3 收敛性条件

定理:若满足以下条件,REINFORCE几乎必然收敛到局部最优策略:

  1. 学习率 满足 ,
  2. 策略可微且满足正则性条件

4. 方差缩减技术

4.1 基线(Baseline)

减去基线函数 不改变期望:

证明

4.2 最优基线

最小化方差的最优基线:

实际中常用

4.3 优势函数替换

使用优势函数 可进一步降低方差。

5. Actor-Critic架构

5.1 基本思想

用Critic网络近似值函数,Actor网络更新策略:

组件输出目标
Actor ()策略分布最大化期望回报
Critic ()值函数估计最小化TD误差

5.2 策略梯度更新

其中 由Critic估计。

5.3 Critic更新

# TD(0)更新
delta = r + gamma * V_phi(s_next) - V_phi(s)
phi = phi + beta * delta * grad_V_phi(s)

6. 自然策略梯度

6.1 Fisher信息矩阵

6.2 自然梯度更新

6.3 KL散度约束

自然梯度等价于在策略分布的黎曼流形上进行最陡下降,相邻策略间的KL散度受约束:

7. 信任域方法

7.1 信任域策略优化(TRPO)

7.2 共轭梯度求解

使用共轭梯度法高效求解约束优化问题。

7.3 线搜索

为保证约束满足,执行线搜索:

for alpha in [1, 0.5, 0.25, ...]:
    theta_new = theta + alpha * delta_theta
    if KL_check(theta, theta_new) and improvement():
        theta = theta_new
        break

8. GAE优势估计

8.1 n步优势估计

8.2 GAE定义

物理意义

  • (TD(0))
  • (蒙特卡洛)

8.3 偏差-方差权衡

偏差方差
低偏差高方差
高偏差低方差

9. 与深度学习的联系

9.1 策略网络架构

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, x):
        return self.net(x)
    
    def log_prob(self, x, action):
        logits = self.forward(x)
        dist = Categorical(logits=logits)
        return dist.log_prob(action)

9.2 PyTorch实现

def update_policy(policy, optimizer, states, actions, returns):
    log_probs = policy.log_prob(states, actions)
    
    # 使用返回作为基线
    loss = -(log_probs * returns).mean()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

10. 现代变体

10.1 PPO (Proximal Policy Optimization)

引入裁剪机制限制策略更新幅度:

其中

10.2 AWR (Advantage-Weighted Regression)

11. 参考文献


相关主题MDP数学基础 | PPO全局收敛性理论 | 无折扣策略梯度理论