策略梯度定理深度解析
策略梯度方法直接对策略进行优化,是现代强化学习的核心技术之一。
1. 策略梯度目标
1.1 平均值函数目标
1.2 起始状态目标
1.3 平均奖励目标
2. 策略梯度定理
2.1 定理陈述
定理(策略梯度定理):
对于可微策略 ,策略梯度为:
或等价形式:
2.2 证明(基于起始状态目标)
步骤1:对值函数求导
步骤2:使用对数梯度恒等式
步骤3:Bellman方程代入
对 求导,利用 :
步骤4:递归展开
通过递归展开,得到:
其中 是轨迹。
2.3 轨迹视角
3. REINFORCE算法
3.1 蒙特卡洛策略梯度
def REINFORCE(env, policy, optimizer, num_episodes):
for episode in range(num_episodes):
trajectory = collect_episode(env, policy)
G = 0
for t in reversed(range(len(trajectory))):
s, a, r = trajectory[t]
G = r + gamma * G # 计算回报
# 策略梯度更新
log_prob = policy.log_prob(s, a)
loss = -log_prob * G # 最大化回报 = 最小化负回报
optimizer.zero_grad()
loss.backward()
optimizer.step()3.2 梯度估计器
3.3 收敛性条件
定理:若满足以下条件,REINFORCE几乎必然收敛到局部最优策略:
- 学习率 满足 ,
- 策略可微且满足正则性条件
4. 方差缩减技术
4.1 基线(Baseline)
减去基线函数 不改变期望:
证明:
4.2 最优基线
最小化方差的最优基线:
实际中常用 。
4.3 优势函数替换
使用优势函数 可进一步降低方差。
5. Actor-Critic架构
5.1 基本思想
用Critic网络近似值函数,Actor网络更新策略:
| 组件 | 输出 | 目标 |
|---|---|---|
| Actor () | 策略分布 | 最大化期望回报 |
| Critic () | 值函数估计 | 最小化TD误差 |
5.2 策略梯度更新
其中 由Critic估计。
5.3 Critic更新
# TD(0)更新
delta = r + gamma * V_phi(s_next) - V_phi(s)
phi = phi + beta * delta * grad_V_phi(s)6. 自然策略梯度
6.1 Fisher信息矩阵
6.2 自然梯度更新
6.3 KL散度约束
自然梯度等价于在策略分布的黎曼流形上进行最陡下降,相邻策略间的KL散度受约束:
7. 信任域方法
7.1 信任域策略优化(TRPO)
7.2 共轭梯度求解
使用共轭梯度法高效求解约束优化问题。
7.3 线搜索
为保证约束满足,执行线搜索:
for alpha in [1, 0.5, 0.25, ...]:
theta_new = theta + alpha * delta_theta
if KL_check(theta, theta_new) and improvement():
theta = theta_new
break8. GAE优势估计
8.1 n步优势估计
8.2 GAE定义
物理意义:
- :(TD(0))
- :(蒙特卡洛)
8.3 偏差-方差权衡
| 偏差 | 方差 | |
|---|---|---|
| 低 | 低偏差 | 高方差 |
| 高 | 高偏差 | 低方差 |
9. 与深度学习的联系
9.1 策略网络架构
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, x):
return self.net(x)
def log_prob(self, x, action):
logits = self.forward(x)
dist = Categorical(logits=logits)
return dist.log_prob(action)9.2 PyTorch实现
def update_policy(policy, optimizer, states, actions, returns):
log_probs = policy.log_prob(states, actions)
# 使用返回作为基线
loss = -(log_probs * returns).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()10. 现代变体
10.1 PPO (Proximal Policy Optimization)
引入裁剪机制限制策略更新幅度:
其中 。
10.2 AWR (Advantage-Weighted Regression)
11. 参考文献
相关主题:MDP数学基础 | PPO全局收敛性理论 | 无折扣策略梯度理论