因果强化学习基础
1. 为什么需要因果强化学习?
1.1 传统RL的核心问题
传统强化学习基于相关性驱动的决策范式,存在三大根本性缺陷:
| 问题 | 描述 | 具体表现 |
|---|---|---|
| 分布偏移脆弱性 | 训练与测试环境分布不同导致性能骤降 | 游戏AI在更换皮肤后失效 |
| 虚假相关性 | 模型可能利用环境中的偶然关联 | 自动驾驶依赖天空颜色判断红灯 |
| 缺乏可解释性 | 决策过程是黑盒的 | 医疗决策系统无法解释诊断依据 |
1.2 因果推断的启示
Judea Pearl的因果阶梯理论为解决上述问题提供了理论基础:
┌─────────────────────────────────────────────────────────────────┐
│ 因果阶梯 (Causal Hierarchy) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 第3层: 反事实层 (Counterfactual) │
│ "如果我没这么做,会发生什么?" │
│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │
│ │
│ 第2层: 干预层 (Intervention) │
│ "如果我这么做,会发生什么?" │
│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │
│ │
│ 第1层: 关联层 (Association) ← 传统ML/RL │
│ "观察到什么?" │
│ │
└─────────────────────────────────────────────────────────────────┘
传统RL仅在第1层运作,而因果RL旨在攀登至第2、3层。
1.3 因果RL的核心优势
┌────────────────────────────────────────────────────────────────┐
│ 因果RL vs 传统RL │
├────────────────────────────────────────────────────────────────┤
│ │
│ 传统RL: │
│ π(a|s) = P(a | s) ← 观察分布 │
│ │
│ 因果RL: │
│ π(a|do(X=x), s) = P(a | do(X=x), s) ← 干预分布 │
│ │
│ 关键区别: │
│ - 因果模型能够区分相关性与因果性 │
│ - 能够预测干预的效果 │
│ - 能够进行反事实推理 │
│ │
└────────────────────────────────────────────────────────────────┘
2. 因果马尔可夫假设与强化学习
2.1 因果马尔可夫假设(CMH)
定义:在因果图 中,给定父节点 ,节点 条件独立于其非后代节点。
2.2 在MDP中的应用
考虑一个MDP ,其对应的因果图:
时间步 t:
S_t ─────→ S_{t+1} ─────→ S_{t+2}
↗ ↘ ↗
A_t R_t A_{t+1}
因果MDP假设:
- 状态转移 由因果机制决定
- 奖励 是因果效应的函数
2.3 状态因果分解
假设状态空间可以分解为:
其中:
- :因果相关状态(直接影响转移和奖励)
- :因果无关状态(仅作为观测噪声)
定理:若环境满足因果马尔可夫假设,则最优策略仅需依赖 而非完整的 。
3. do-calculus在强化学习中的应用
3.1 do-操作符基础
do-操作符表示干预(Intervention),与条件概率有本质区别:
| 符号 | 含义 | 类比 |
|---|---|---|
| 观察到 时 的概率 | ”看到” | |
| 强制设置 时 的概率 | ”做” |
3.2 do-calculus三条规则
设 为变量集合, 为因果图:
规则1(移除观测):
规则2(行动-观察交换):
规则3(忽略后天干预):
3.3 策略干预效应
考虑策略 作为对动作的干预:
关键洞察:do-calculus允许我们计算不同策略的因果效应,而不仅仅是观察分布。
4. 因果价值函数
4.1 传统价值函数回顾
4.2 因果价值函数
因果可达性(因果):
其中 是因果效应函数:
4.3 反事实价值函数
反事实Q函数:
这衡量的是实际动作 与策略 推荐动作的反事实差异。
4.4 因果优势函数
因果优势函数仅考虑动作对因果相关状态的影响。
5. 因果探索与奖励设计
5.1 因果探索问题
传统探索:基于不确定性的探索(UCB、Boltzmann)
因果探索:关注哪些动作对环境有因果效应
┌─────────────────────────────────────────────────────────────────┐
│ 因果探索 vs 传统探索 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 传统探索(信息增益): │
│ - 哪个动作减少对未来预测的不确定性? │
│ │
│ 因果探索(效应发现): │
│ - 哪个动作会导致状态的实际变化? │
│ - 动作与状态之间是否存在因果关系? │
│ │
└─────────────────────────────────────────────────────────────────┘
5.2 因果奖励机制
反事实奖励(Counterfactual Reward):
其中 是反事实效应:
5.3 因果感知的Upper Confidence Bound
Causal-UCB:
其中 是因果重要性权重:
6. 形式化定义
6.1 因果MDP定义
定义:因果MDP是一个七元组 :
| 符号 | 含义 |
|---|---|
| 状态空间 | |
| 动作空间 | |
| 状态-动作因果图 | |
| 因果转移函数:$P_c(s’ | |
| 奖励函数: | |
| 折扣因子 | |
| 初始状态分布 |
6.2 因果最优策略
定义:策略 是因果最优的,当且仅当:
其中 是满足因果约束的策略空间。
7. PyTorch实现
7.1 因果价值函数估计器
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class CausalValueEstimator(nn.Module):
"""
因果价值函数估计器
支持因果干预估计和反事实价值计算
"""
def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
super().__init__()
self.state_dim = state_dim
self.action_dim = action_dim
# 因果状态编码器
self.causal_encoder = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# 动作编码器
self.action_encoder = nn.Sequential(
nn.Linear(action_dim, hidden_dim),
nn.ReLU()
)
# 因果效应估计器
self.effect_estimator = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1) # 输出因果效应
)
# 价值估计器
self.value_estimator = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
# Q值估计器
self.q_estimator = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state: Tensor, action: Tensor) -> Tensor:
"""计算因果Q值"""
s_enc = self.causal_encoder(state)
a_enc = self.action_encoder(action)
combined = torch.cat([s_enc, a_enc], dim=-1)
return self.q_estimator(combined)
def compute_causal_effect(self, state: Tensor, action: Tensor,
counterfactual_action: Tensor) -> Tensor:
"""
计算反事实效应
PE(s, a, a') = ||P(s'|do(a),s) - P(s'|do(a'),s)||
"""
# 实际动作的因果编码
actual_enc = torch.cat([
self.causal_encoder(state),
self.action_encoder(action)
], dim=-1)
# 反事实动作的因果编码
cf_enc = torch.cat([
self.causal_encoder(state),
self.action_encoder(counterfactual_action)
], dim=-1)
# 估计效应
actual_effect = self.effect_estimator(actual_enc)
cf_effect = self.effect_estimator(cf_enc)
# 反事实效应作为差异
return torch.abs(actual_effect - cf_effect)
def causal_advantage(self, state: Tensor, action: Tensor,
policy_actions: Tensor) -> Tensor:
"""
计算因果优势函数
考虑动作对因果相关状态的影响
"""
q_sa = self.forward(state, action)
# 策略动作的平均Q值
q_policy = torch.mean(self.forward(state, policy_actions), dim=-1, keepdim=True)
return q_sa - q_policy
class CausalRewardCalculator:
"""
因果奖励计算器
结合标准奖励和反事实效应
"""
def __init__(self, lambda_cf: float = 0.1):
self.lambda_cf = lambda_cf
def compute_reward(self, state: Tensor, action: Tensor, next_state: Tensor,
value_estimator: CausalValueEstimator) -> Tensor:
"""
计算增强的因果奖励
R_cf = R(s,a,s') + λ * Σ PE(s,a,a')
"""
# 标准奖励(这里简化处理)
base_reward = torch.norm(next_state - state, dim=-1, keepdim=True)
# 反事实惩罚
# 假设我们有所有可能动作的反事实状态
# 这里简化为与零动作的差异
zero_action = torch.zeros_like(action)
cf_effect = value_estimator.compute_causal_effect(state, action, zero_action)
return base_reward + self.lambda_cf * cf_effect
def causal_ucb_action_selection(q_values: Tensor, counts: Tensor,
t: int, phi: Tensor,
c: float = 1.0) -> Tensor:
"""
因果UCB动作选择
a_t = argmax[ Q(s,a) + c * sqrt(ln t / N(s,a)) * φ(s,a) ]
"""
# UCB项
ucb_bonus = c * torch.sqrt(torch.log(torch.tensor(t, dtype=torch.float32)) / (counts + 1e-8))
# 因果重要性加权
weighted_bonus = ucb_bonus * phi
return q_values + weighted_bonus7.2 因果MDP环境示例
import numpy as np
from typing import Dict, Tuple, Optional
class CausalMDP:
"""
简单因果MDP环境
用于演示因果转移机制
"""
def __init__(self, n_states: int = 5, n_actions: int = 3,
causal_strength: float = 0.8):
self.n_states = n_states
self.n_actions = n_actions
self.causal_strength = causal_strength
# 因果转移矩阵
# P(s' | do(a), s) - 不依赖于其他变量
self.causal_transition = self._initialize_causal_transition()
# 混淆转移矩阵
# P(s' | s) - 可能被混淆变量影响
self.confounded_transition = self._initialize_confounded_transition()
def _initialize_causal_transition(self) -> np.ndarray:
"""初始化因果转移矩阵"""
P = np.zeros((self.n_actions, self.n_states, self.n_states))
for a in range(self.n_actions):
for s in range(self.n_states):
# 每个动作有自己的转移偏好
probs = np.random.dirichlet(np.ones(self.n_states) * 0.5)
# 添加动作特定的偏移
probs[a] += self.causal_strength
probs = probs / probs.sum()
P[a, s] = probs
return P
def _initialize_confounded_transition(self) -> np.ndarray:
"""初始化混淆转移矩阵"""
return np.zeros((self.n_states, self.n_states))
def step(self, state: int, action: int,
use_causal: bool = True) -> Tuple[int, float, bool]:
"""
执行一步转移
Args:
state: 当前状态
action: 执行的动作
use_causal: 是否使用因果转移(True)或混淆转移(False)
"""
if use_causal:
# 因果转移:P(s' | do(a), s)
probs = self.causal_transition[action, state]
else:
# 混淆转移:使用观察分布
probs = self.confounded_transition[state]
next_state = np.random.choice(self.n_states, p=probs)
# 奖励:到达目标状态(state 0)获得高奖励
reward = 1.0 if next_state == 0 else 0.0
done = next_state == 0
return next_state, reward, done
def do_action(self, action: int, state: int) -> np.ndarray:
"""
执行do操作,返回干预分布 P(s' | do(a), s)
"""
return self.causal_transition[action, state]
def compute_counterfactual_effect(self, state: int,
action_a: int, action_b: int) -> float:
"""
计算动作a和b的反事实效应
"""
dist_a = self.do_action(action_a, state)
dist_b = self.do_action(action_b, state)
# 使用TV距离度量效应
return 0.5 * np.sum(np.abs(dist_a - dist_b))
def demonstrate_causal_vs_confounded():
"""
演示因果转移与混淆转移的区别
"""
np.random.seed(42)
# 创建因果MDP
env = CausalMDP(n_states=5, n_actions=3, causal_strength=0.7)
state = 2
print("=" * 60)
print(f"状态: {state}")
print("=" * 60)
# 观察策略
random_policy = np.ones(env.n_actions) / env.n_actions
print("\n1. 观察分布(可能被混淆):")
observed_next_state = []
for _ in range(1000):
action = np.random.choice(env.n_actions, p=random_policy)
next_s, _, _ = env.step(state, action, use_causal=False)
observed_next_state.append(next_s)
observed_dist = np.bincount(observed_next_state, minlength=env.n_states) / 1000
print(f" P(s'|s) ≈ {observed_dist}")
print("\n2. 因果干预分布(do操作):")
for action in range(env.n_actions):
causal_dist = env.do_action(action, state)
print(f" P(s'|do(A={action}), s) = {causal_dist}")
print("\n3. 反事实效应:")
for a1 in range(env.n_actions):
for a2 in range(a1 + 1, env.n_actions):
effect = env.compute_counterfactual_effect(state, a1, a2)
print(f" PE(A={a1}, A'={a2}) = {effect:.4f}")
if __name__ == "__main__":
demonstrate_causal_vs_confounded()8. 应用场景
8.1 自动驾驶
| 问题 | 因果RL解决方案 |
|---|---|
| 天气变化导致感知漂移 | 因果状态分解,过滤混淆因素 |
| 罕见场景泛化 | 因果探索,快速学习因果结构 |
| 事故责任认定 | 可解释的因果决策链 |
8.2 医疗决策
| 问题 | 因果RL解决方案 |
|---|---|
| 治疗方案选择 | 因果效应估计,预测干预结果 |
| 患者亚群差异 | 分层因果模型 |
| 数据稀缺 | 因果迁移学习 |
8.3 机器人控制
| 问题 | 因果RL解决方案 |
|---|---|
| 物理参数变化 | 因果策略迁移 |
| 人机协作 | 因果意图识别与响应 |
| 故障恢复 | 因果反事实推理 |
9. 相关工作
9.1 理论工作
| 论文 | 年份 | 贡献 |
|---|---|---|
| Causal Markov Decision Processes | 2007 | CMDP形式化 |
| Causal Discovery for Reinforcement Learning | 2020 | 因果发现与RL结合 |
| Unifying Causal RL: Survey and Taxonomy | 2025 | 统一框架与分类法 |
9.2 算法工作
| 论文 | 年份 | 贡献 |
|---|---|---|
| CausalRL | 2020 | 因果奖励函数 |
| Counterfactual RL | 2021 | 反事实价值估计 |
| Causal Exploration | 2022 | 因果探索策略 |
10. 总结
核心要点
- 因果RL的核心思想:从相关性驱动转向因果机制驱动
- do-calculus的桥梁作用:连接观察分布与干预分布
- 因果价值函数:考虑动作的因果效应而非表面相关性
- 因果探索:发现动作与环境之间的真实因果关系
与传统RL的关键区别
| 方面 | 传统RL | 因果RL |
|---|---|---|
| 决策基础 | 观察分布 $P(a | s)$ |
| 泛化能力 | 分布内 | 跨环境因果迁移 |
| 可解释性 | 黑盒 | 因果链条透明 |
| 探索策略 | 信息增益 | 因果效应发现 |
下一步
- 因果MDP与CMDP - 深入形式化框架
- 因果探索策略 - 因果感知的探索算法
- 因果逆RL - 从观察中推断因果奖励和约束