因果强化学习基础

1. 为什么需要因果强化学习?

1.1 传统RL的核心问题

传统强化学习基于相关性驱动的决策范式,存在三大根本性缺陷:

问题描述具体表现
分布偏移脆弱性训练与测试环境分布不同导致性能骤降游戏AI在更换皮肤后失效
虚假相关性模型可能利用环境中的偶然关联自动驾驶依赖天空颜色判断红灯
缺乏可解释性决策过程是黑盒的医疗决策系统无法解释诊断依据

1.2 因果推断的启示

Judea Pearl的因果阶梯理论为解决上述问题提供了理论基础:

┌─────────────────────────────────────────────────────────────────┐
│                     因果阶梯 (Causal Hierarchy)                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│   第3层: 反事实层 (Counterfactual)                               │
│   "如果我没这么做,会发生什么?"                                   │
│   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━        │
│                                                                  │
│   第2层: 干预层 (Intervention)                                   │
│   "如果我这么做,会发生什么?"                                     │
│   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━        │
│                                                                  │
│   第1层: 关联层 (Association) ← 传统ML/RL                        │
│   "观察到什么?"                                                 │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

传统RL仅在第1层运作,而因果RL旨在攀登至第2、3层。

1.3 因果RL的核心优势

┌────────────────────────────────────────────────────────────────┐
│                    因果RL vs 传统RL                              │
├────────────────────────────────────────────────────────────────┤
│                                                                 │
│   传统RL:                                                         │
│   π(a|s) = P(a | s)           ← 观察分布                       │
│                                                                 │
│   因果RL:                                                         │
│   π(a|do(X=x), s) = P(a | do(X=x), s)  ← 干预分布             │
│                                                                 │
│   关键区别:                                                      │
│   - 因果模型能够区分相关性与因果性                               │
│   - 能够预测干预的效果                                           │
│   - 能够进行反事实推理                                           │
│                                                                 │
└────────────────────────────────────────────────────────────────┘

2. 因果马尔可夫假设与强化学习

2.1 因果马尔可夫假设(CMH)

定义:在因果图 中,给定父节点 ,节点 条件独立于其非后代节点。

2.2 在MDP中的应用

考虑一个MDP ,其对应的因果图:

时间步 t:
                                    
    S_t ─────→ S_{t+1} ─────→ S_{t+2}
              ↗      ↘              ↗
            A_t      R_t           A_{t+1}

因果MDP假设

  • 状态转移 由因果机制决定
  • 奖励 是因果效应的函数

2.3 状态因果分解

假设状态空间可以分解为:

其中:

  • 因果相关状态(直接影响转移和奖励)
  • 因果无关状态(仅作为观测噪声)

定理:若环境满足因果马尔可夫假设,则最优策略仅需依赖 而非完整的


3. do-calculus在强化学习中的应用

3.1 do-操作符基础

do-操作符表示干预(Intervention),与条件概率有本质区别:

符号含义类比
观察到 的概率”看到”
强制设置 的概率”做”

3.2 do-calculus三条规则

为变量集合, 为因果图:

规则1(移除观测)

规则2(行动-观察交换)

规则3(忽略后天干预)

3.3 策略干预效应

考虑策略 作为对动作的干预:

关键洞察:do-calculus允许我们计算不同策略的因果效应,而不仅仅是观察分布。


4. 因果价值函数

4.1 传统价值函数回顾

4.2 因果价值函数

因果可达性(因果

其中 因果效应函数

4.3 反事实价值函数

反事实Q函数

这衡量的是实际动作 与策略 推荐动作的反事实差异

4.4 因果优势函数

因果优势函数仅考虑动作对因果相关状态的影响。


5. 因果探索与奖励设计

5.1 因果探索问题

传统探索:基于不确定性的探索(UCB、Boltzmann)

因果探索:关注哪些动作对环境有因果效应

┌─────────────────────────────────────────────────────────────────┐
│                     因果探索 vs 传统探索                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│   传统探索(信息增益):                                          │
│   - 哪个动作减少对未来预测的不确定性?                            │
│                                                                  │
│   因果探索(效应发现):                                          │
│   - 哪个动作会导致状态的实际变化?                                │
│   - 动作与状态之间是否存在因果关系?                              │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

5.2 因果奖励机制

反事实奖励(Counterfactual Reward)

其中 反事实效应

5.3 因果感知的Upper Confidence Bound

Causal-UCB

其中 因果重要性权重


6. 形式化定义

6.1 因果MDP定义

定义:因果MDP是一个七元组

符号含义
状态空间
动作空间
状态-动作因果图
因果转移函数:$P_c(s’
奖励函数:
折扣因子
初始状态分布

6.2 因果最优策略

定义:策略 是因果最优的,当且仅当:

其中 是满足因果约束的策略空间。


7. PyTorch实现

7.1 因果价值函数估计器

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
 
class CausalValueEstimator(nn.Module):
    """
    因果价值函数估计器
    支持因果干预估计和反事实价值计算
    """
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # 因果状态编码器
        self.causal_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # 动作编码器
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.ReLU()
        )
        
        # 因果效应估计器
        self.effect_estimator = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # 输出因果效应
        )
        
        # 价值估计器
        self.value_estimator = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # Q值估计器
        self.q_estimator = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        """计算因果Q值"""
        s_enc = self.causal_encoder(state)
        a_enc = self.action_encoder(action)
        combined = torch.cat([s_enc, a_enc], dim=-1)
        return self.q_estimator(combined)
    
    def compute_causal_effect(self, state: Tensor, action: Tensor, 
                            counterfactual_action: Tensor) -> Tensor:
        """
        计算反事实效应
        PE(s, a, a') = ||P(s'|do(a),s) - P(s'|do(a'),s)||
        """
        # 实际动作的因果编码
        actual_enc = torch.cat([
            self.causal_encoder(state),
            self.action_encoder(action)
        ], dim=-1)
        
        # 反事实动作的因果编码
        cf_enc = torch.cat([
            self.causal_encoder(state),
            self.action_encoder(counterfactual_action)
        ], dim=-1)
        
        # 估计效应
        actual_effect = self.effect_estimator(actual_enc)
        cf_effect = self.effect_estimator(cf_enc)
        
        # 反事实效应作为差异
        return torch.abs(actual_effect - cf_effect)
    
    def causal_advantage(self, state: Tensor, action: Tensor, 
                        policy_actions: Tensor) -> Tensor:
        """
        计算因果优势函数
        考虑动作对因果相关状态的影响
        """
        q_sa = self.forward(state, action)
        
        # 策略动作的平均Q值
        q_policy = torch.mean(self.forward(state, policy_actions), dim=-1, keepdim=True)
        
        return q_sa - q_policy
 
 
class CausalRewardCalculator:
    """
    因果奖励计算器
    结合标准奖励和反事实效应
    """
    def __init__(self, lambda_cf: float = 0.1):
        self.lambda_cf = lambda_cf
    
    def compute_reward(self, state: Tensor, action: Tensor, next_state: Tensor,
                      value_estimator: CausalValueEstimator) -> Tensor:
        """
        计算增强的因果奖励
        
        R_cf = R(s,a,s') + λ * Σ PE(s,a,a')
        """
        # 标准奖励(这里简化处理)
        base_reward = torch.norm(next_state - state, dim=-1, keepdim=True)
        
        # 反事实惩罚
        # 假设我们有所有可能动作的反事实状态
        # 这里简化为与零动作的差异
        zero_action = torch.zeros_like(action)
        cf_effect = value_estimator.compute_causal_effect(state, action, zero_action)
        
        return base_reward + self.lambda_cf * cf_effect
 
 
def causal_ucb_action_selection(q_values: Tensor, counts: Tensor, 
                                t: int, phi: Tensor,
                                c: float = 1.0) -> Tensor:
    """
    因果UCB动作选择
    
    a_t = argmax[ Q(s,a) + c * sqrt(ln t / N(s,a)) * φ(s,a) ]
    """
    # UCB项
    ucb_bonus = c * torch.sqrt(torch.log(torch.tensor(t, dtype=torch.float32)) / (counts + 1e-8))
    
    # 因果重要性加权
    weighted_bonus = ucb_bonus * phi
    
    return q_values + weighted_bonus

7.2 因果MDP环境示例

import numpy as np
from typing import Dict, Tuple, Optional
 
class CausalMDP:
    """
    简单因果MDP环境
    用于演示因果转移机制
    """
    def __init__(self, n_states: int = 5, n_actions: int = 3, 
                 causal_strength: float = 0.8):
        self.n_states = n_states
        self.n_actions = n_actions
        self.causal_strength = causal_strength
        
        # 因果转移矩阵
        # P(s' | do(a), s) - 不依赖于其他变量
        self.causal_transition = self._initialize_causal_transition()
        
        # 混淆转移矩阵
        # P(s' | s) - 可能被混淆变量影响
        self.confounded_transition = self._initialize_confounded_transition()
    
    def _initialize_causal_transition(self) -> np.ndarray:
        """初始化因果转移矩阵"""
        P = np.zeros((self.n_actions, self.n_states, self.n_states))
        for a in range(self.n_actions):
            for s in range(self.n_states):
                # 每个动作有自己的转移偏好
                probs = np.random.dirichlet(np.ones(self.n_states) * 0.5)
                # 添加动作特定的偏移
                probs[a] += self.causal_strength
                probs = probs / probs.sum()
                P[a, s] = probs
        return P
    
    def _initialize_confounded_transition(self) -> np.ndarray:
        """初始化混淆转移矩阵"""
        return np.zeros((self.n_states, self.n_states))
    
    def step(self, state: int, action: int, 
             use_causal: bool = True) -> Tuple[int, float, bool]:
        """
        执行一步转移
        
        Args:
            state: 当前状态
            action: 执行的动作
            use_causal: 是否使用因果转移(True)或混淆转移(False)
        """
        if use_causal:
            # 因果转移:P(s' | do(a), s)
            probs = self.causal_transition[action, state]
        else:
            # 混淆转移:使用观察分布
            probs = self.confounded_transition[state]
        
        next_state = np.random.choice(self.n_states, p=probs)
        
        # 奖励:到达目标状态(state 0)获得高奖励
        reward = 1.0 if next_state == 0 else 0.0
        
        done = next_state == 0
        
        return next_state, reward, done
    
    def do_action(self, action: int, state: int) -> np.ndarray:
        """
        执行do操作,返回干预分布 P(s' | do(a), s)
        """
        return self.causal_transition[action, state]
    
    def compute_counterfactual_effect(self, state: int, 
                                      action_a: int, action_b: int) -> float:
        """
        计算动作a和b的反事实效应
        """
        dist_a = self.do_action(action_a, state)
        dist_b = self.do_action(action_b, state)
        
        # 使用TV距离度量效应
        return 0.5 * np.sum(np.abs(dist_a - dist_b))
 
 
def demonstrate_causal_vs_confounded():
    """
    演示因果转移与混淆转移的区别
    """
    np.random.seed(42)
    
    # 创建因果MDP
    env = CausalMDP(n_states=5, n_actions=3, causal_strength=0.7)
    
    state = 2
    
    print("=" * 60)
    print(f"状态: {state}")
    print("=" * 60)
    
    # 观察策略
    random_policy = np.ones(env.n_actions) / env.n_actions
    
    print("\n1. 观察分布(可能被混淆):")
    observed_next_state = []
    for _ in range(1000):
        action = np.random.choice(env.n_actions, p=random_policy)
        next_s, _, _ = env.step(state, action, use_causal=False)
        observed_next_state.append(next_s)
    
    observed_dist = np.bincount(observed_next_state, minlength=env.n_states) / 1000
    print(f"   P(s'|s) ≈ {observed_dist}")
    
    print("\n2. 因果干预分布(do操作):")
    for action in range(env.n_actions):
        causal_dist = env.do_action(action, state)
        print(f"   P(s'|do(A={action}), s) = {causal_dist}")
    
    print("\n3. 反事实效应:")
    for a1 in range(env.n_actions):
        for a2 in range(a1 + 1, env.n_actions):
            effect = env.compute_counterfactual_effect(state, a1, a2)
            print(f"   PE(A={a1}, A'={a2}) = {effect:.4f}")
 
 
if __name__ == "__main__":
    demonstrate_causal_vs_confounded()

8. 应用场景

8.1 自动驾驶

问题因果RL解决方案
天气变化导致感知漂移因果状态分解,过滤混淆因素
罕见场景泛化因果探索,快速学习因果结构
事故责任认定可解释的因果决策链

8.2 医疗决策

问题因果RL解决方案
治疗方案选择因果效应估计,预测干预结果
患者亚群差异分层因果模型
数据稀缺因果迁移学习

8.3 机器人控制

问题因果RL解决方案
物理参数变化因果策略迁移
人机协作因果意图识别与响应
故障恢复因果反事实推理

9. 相关工作

9.1 理论工作

论文年份贡献
Causal Markov Decision Processes2007CMDP形式化
Causal Discovery for Reinforcement Learning2020因果发现与RL结合
Unifying Causal RL: Survey and Taxonomy2025统一框架与分类法

9.2 算法工作

论文年份贡献
CausalRL2020因果奖励函数
Counterfactual RL2021反事实价值估计
Causal Exploration2022因果探索策略

10. 总结

核心要点

  1. 因果RL的核心思想:从相关性驱动转向因果机制驱动
  2. do-calculus的桥梁作用:连接观察分布与干预分布
  3. 因果价值函数:考虑动作的因果效应而非表面相关性
  4. 因果探索:发现动作与环境之间的真实因果关系

与传统RL的关键区别

方面传统RL因果RL
决策基础观察分布 $P(as)$
泛化能力分布内跨环境因果迁移
可解释性黑盒因果链条透明
探索策略信息增益因果效应发现

下一步


参考文献