因果MDP与因果POMDP

1. 从标准MDP到因果MDP

1.1 标准MDP回顾

标准MDP由五元组 定义:

  • :状态空间
  • :动作空间
  • :转移概率
  • :奖励函数
  • :折扣因子

核心假设:状态转移遵循马尔可夫性质,即 仅依赖于

1.2 标准MDP的因果缺陷

标准MDP存在三个关键的因果缺陷:

缺陷描述后果
混淆因素状态可能包含非因果信息虚假相关性导致错误决策
动作表示无法区分动作与观察混淆干预与观察
转移机制黑盒转移函数缺乏因果可解释性

1.3 因果MDP的引入

定义:因果MDP(CMDP)是一个七元组

符号含义
状态空间
动作空间
因果结构图
因果转移函数
奖励函数
折扣因子
因果约束集合

2. 因果结构图与MDP

2.1 因果图的定义

定义:因果结构图 是一个有向无环图(DAG),其中:

  • 是节点集合(状态和动作)
  • 是因果边的集合

边类型

边类型表示含义
状态因果流状态的历史依赖
动作因果效应动作对状态的直接影响
混淆因素未观测的混杂变量

2.2 状态因果分解

假设状态空间可以分解为:

其中:

  • 因果状态(由父节点直接决定)
  • 环境状态(由外部因素决定)

因果状态更新方程

2.3 因果马尔可夫条件

定理(因果马尔可夫条件):在因果图 下,给定父节点 条件独立于所有非后代节点。


3. 因果转移函数

3.1 从观察分布到因果机制

标准MDP使用观察分布

因果MDP使用因果机制

3.2 do-操作与转移

定义:因果转移函数 满足:

其中 在因果图中的父节点。

3.3 识别条件

定理(因果转移可识别):若以下条件之一成立,则 可从观察数据识别:

  1. 后门路径阻断:存在集合 阻断所有从 的后门路径
  2. 前门准则:存在集合 使得 ,且无直接边
  3. do-calculus可判定:通过do-calculus三条规则可推导出可识别表达式

3.4 因果转移 vs 观察转移

┌─────────────────────────────────────────────────────────────────┐
│                  因果转移 vs 观察转移                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│   观察转移:                                                       │
│   P(s'|s, a) = Σ_u P(s'|s, a, u) P(u|s)                        │
│                 ↑                                               │
│                 包含混淆因素u的影响                               │
│                                                                  │
│   因果转移:                                                       │
│   P(s'|do(a), s) = Σ_u P(s'|s, do(a), u) P(u|s)                │
│                    = Σ_u P(s'|s, u) P(u|s)  ← do移除a的影响     │
│                                                                  │
│   关键区别:                                                       │
│   因果转移排除了动作a对混淆因素u的间接效应                        │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

4. 因果约束与安全策略

4.1 约束类型

因果约束 可以表示为:

每条约束是关于因果效应的函数:

约束类型形式示例
因果不等式$P(s’do(a), s) \geq \epsilon$
反事实约束反事实价值不超过阈值
干预约束动作间的平均因果效应

4.2 约束满足的MDP

定义:约束CMDP(Constrained CMDP)是满足约束的CMDP:

其中约束函数 可以是:

  • 期望累积成本
  • 因果效应约束
  • 反事实风险约束

4.3 拉格朗日松弛

使用拉格朗日乘子法处理约束:

投影梯度下降

def constrained_policy_update(policy, rewards, constraints, lambda_vec, alpha):
    """
    约束策略更新
    """
    # 计算无约束梯度
    policy_gradient = compute_policy_gradient(rewards)
    
    # 计算约束违反梯度
    constraint_gradient = compute_constraint_gradient(constraints)
    
    # 拉格朗日更新
    lambda_new = relu(lambda_vec + alpha * (constraint_gradient - kappa))
    
    # 策略更新
    new_policy = policy + policy_gradient - lambda_new * constraint_gradient
    
    return project_to_constraints(new_policy), lambda_new

5. 因果POMDP

5.1 标准POMDP回顾

POMDP由七元组 定义:

  • :观测空间
  • :观测函数

信念状态

5.2 因果POMDP的定义

定义:因果POMDP(-POMDP)扩展了POMDP,加入因果结构:

新增组件含义
状态-动作-观测的因果图
因果转移函数
初始因果信念

5.3 因果信念状态

定义:因果信念状态 是对潜在因果状态混淆因素的联合信念:

其中:

  • :可观测的因果状态
  • :未观测的混淆因素

5.4 因果观测模型

观测函数分解为:

其中:

  • :因果状态
  • :环境状态

观测因果条件独立


6. 价值函数的形式化

6.1 因果价值函数

标准价值函数

因果价值函数

6.2 因果贝尔曼方程

标准贝尔曼方程

因果贝尔曼方程

6.3 因果最优方程

最优价值函数

最优策略


7. 算法与实现

7.1 因果Q学习

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from typing import Dict, Tuple, Optional
 
class CausalQNetwork(nn.Module):
    """
    因果Q网络
    学习因果转移函数而非观察转移
    """
    def __init__(self, state_dim: int, action_dim: int, 
                 hidden_dim: int = 128, n_causal_factors: int = 8):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.n_causal_factors = n_causal_factors
        
        # 状态编码器
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # 因果因子提取器
        self.causal_extractor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_causal_factors)
        )
        
        # Q值估计器
        self.q_estimator = nn.Sequential(
            nn.Linear(hidden_dim + n_causal_factors + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # 因果转移模型
        self.causal_transition_model = nn.Sequential(
            nn.Linear(hidden_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_causal_factors * state_dim)  # 输出因果转移参数
        )
    
    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        """计算Q值"""
        s_enc = self.state_encoder(state)
        c_factors = self.causal_extractor(s_enc)
        
        combined = torch.cat([s_enc, c_factors, action], dim=-1)
        return self.q_estimator(combined)
    
    def predict_causal_transition(self, state: Tensor, 
                                 action: Tensor) -> Tensor:
        """
        预测因果转移
        返回: (batch, state_dim) 因果转移后的状态预测
        """
        s_enc = self.state_encoder(state)
        combined = torch.cat([s_enc, action], dim=-1)
        
        # 预测因果效应
        causal_effect = self.causal_transition_model(combined)
        
        # 重塑为状态维度的缩放因子
        effect = causal_effect.view(-1, self.n_causal_factors, self.state_dim)
        effect_scale = torch.mean(effect, dim=1)  # 聚合因果因子
        
        # 应用因果效应到状态
        return state + effect_scale
    
    def compute_counterfactual_q(self, state: Tensor, 
                                action: Tensor, 
                                next_state: Tensor,
                                reward: Tensor,
                                gamma: float,
                                target_network: 'CausalQNetwork') -> Tensor:
        """
        计算反事实Q值
        Q_cf(s,a) = R + γ * max_a' E[P(s'|do(a),s)]
        """
        # 预测因果转移
        predicted_next_state = self.predict_causal_transition(state, action)
        
        # 计算反事实优势
        with torch.no_grad():
            next_q = target_network(state, action)
            target_q = reward + gamma * next_q
        
        return target_q
 
 
class CausalQLearning:
    """
    因果Q学习算法
    """
    def __init__(self, state_dim: int, action_dim: int,
                 hidden_dim: int = 128,
                 lr: float = 1e-3,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_decay: float = 0.995,
                 epsilon_min: float = 0.01):
        
        self.q_network = CausalQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = CausalQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        
        self.training_step = 0
    
    def select_action(self, state: Tensor) -> int:
        """ε-贪心动作选择"""
        if torch.rand(1).item() < self.epsilon:
            return torch.randint(0, self.q_network.action_dim, (1,)).item()
        
        with torch.no_grad():
            q_values = []
            for a in range(self.q_network.action_dim):
                action = torch.zeros(1, self.q_network.action_dim)
                action[0, a] = 1.0
                q = self.q_network(state, action)
                q_values.append(q)
            return torch.argmax(torch.cat(q_values)).item()
    
    def update(self, state: Tensor, action: int, 
               reward: Tensor, next_state: Tensor, done: bool):
        """
        更新Q网络
        """
        # 准备动作tensor
        action_tensor = torch.zeros(1, self.q_network.action_dim)
        action_tensor[0, action] = 1.0
        
        # 计算目标Q值(使用因果转移预测)
        predicted_next = self.q_network.predict_causal_transition(state, action_tensor)
        
        with torch.no_grad():
            if done:
                target_q = reward
            else:
                # 选择下一个动作
                next_action = self.select_action(next_state)
                next_action_tensor = torch.zeros(1, self.q_network.action_dim)
                next_action_tensor[0, next_action] = 1.0
                
                target_q = reward + self.gamma * self.q_network(
                    predicted_next, next_action_tensor
                )
        
        # 计算当前Q值
        current_q = self.q_network(state, action_tensor)
        
        # MSE损失
        loss = nn.MSELoss()(current_q, target_q)
        
        # 反向传播
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()
        
        # 更新epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        
        # 定期更新目标网络
        self.training_step += 1
        if self.training_step % 100 == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        
        return loss.item()

7.2 因果策略梯度

class CausalPolicyGradient:
    """
    因果策略梯度算法
    使用因果优势函数估计
    """
    def __init__(self, state_dim: int, action_dim: int, 
                 hidden_dim: int = 128, lr: float = 3e-4):
        
        self.policy_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        
        self.value_net = CausalQNetwork(state_dim, action_dim, hidden_dim)
        
        self.optimizer = optim.Adam(
            list(self.policy_net.parameters()) + list(self.value_net.parameters()),
            lr=lr
        )
    
    def compute_causal_advantage(self, states: Tensor, actions: Tensor,
                                 rewards: Tensor, next_states: Tensor,
                                 dones: Tensor, gamma: float = 0.99,
                                 lambda_gae: float = 0.95) -> Tuple[Tensor, Tensor]:
        """
        计算因果GAE(Generalized Advantage Estimation)
        
        考虑因果转移而非观察转移
        """
        with torch.no_grad():
            # 预测因果转移后的状态
            predicted_next = self.value_net.predict_causal_transition(
                states, actions
            )
            
            # 因果V值
            values = self.value_net(states, actions)
            next_values = self.value_net(predicted_next, actions)
            
            # TD误差
            td_errors = rewards + gamma * next_values * (1 - dones) - values
            
            # GAE
            advantages = torch.zeros_like(td_errors)
            gae = 0
            for t in reversed(range(len(td_errors))):
                gae = td_errors[t] + gamma * lambda_gae * gae * (1 - dones[t])
                advantages[t] = gae
        
        returns = advantages + values.detach()
        return advantages, returns
    
    def update(self, states: Tensor, actions: Tensor, 
              rewards: Tensor, next_states: Tensor, dones: Tensor):
        """
        策略更新
        """
        # 计算因果优势
        advantages, returns = self.compute_causal_advantage(
            states, actions, rewards, next_states, dones
        )
        
        # 策略损失
        action_probs = self.policy_net(states)
        action_indices = actions.unsqueeze(1)
        selected_probs = torch.gather(action_probs, 1, action_indices).squeeze()
        
        # 策略梯度损失
        policy_loss = -(selected_probs * advantages.detach()).mean()
        
        # 值函数损失
        values = self.value_net(states, actions)
        value_loss = nn.MSELoss()(values, returns)
        
        # 总损失
        total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy(action_probs)
        
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
        self.optimizer.step()
        
        return policy_loss.item(), value_loss.item()
 
 
def entropy(probs: Tensor) -> Tensor:
    """计算策略熵"""
    return -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()

7.3 CMDP约束优化

class ConstrainedCMDP:
    """
    约束CMDP求解器
    使用投影梯度法处理因果约束
    """
    def __init__(self, q_network: CausalQNetwork,
                 constraint_threshold: float = 0.1,
                 lr_pi: float = 1e-4,
                 lr_lambda: float = 1e-3):
        
        self.q_network = q_network
        self.constraint_threshold = constraint_threshold
        
        self.policy_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        
        self.optimizer_pi = optim.Adam(self.policy_net.parameters(), lr=lr_pi)
        self.optimizer_lambda = optim.Adam([self.lambda_vec], lr=lr_lambda)
        
        # 拉格朗日乘子
        self.lambda_vec = nn.Parameter(torch.tensor([0.0]))
    
    def compute_causal_constraint(self, states: Tensor, 
                                 actions: Tensor) -> Tensor:
        """
        计算因果约束值
        
        约束: 因果转移的方差应小于阈值
        这鼓励策略选择更稳定的因果动作
        """
        with torch.no_grad():
            # 预测多次因果转移以估计方差
            predictions = []
            for _ in range(5):
                pred = self.q_network.predict_causal_transition(states, actions)
                predictions.append(pred)
            
            predictions = torch.stack(predictions)
            
            # 计算预测的方差
            variance = torch.var(predictions, dim=0).mean()
            
            return variance
    
    def update(self, states: Tensor, actions: Tensor,
              rewards: Tensor, next_states: Tensor, dones: Tensor):
        """
        约束策略更新
        """
        # 1. 计算因果约束
        constraint_value = self.compute_causal_constraint(states, actions)
        
        # 2. 计算约束违反
        constraint_violation = torch.relu(constraint_value - self.constraint_threshold)
        
        # 3. 更新拉格朗日乘子(梯度上升)
        lambda_loss = -self.lambda_vec * (constraint_value - self.constraint_threshold)
        self.optimizer_lambda.zero_grad()
        lambda_loss.backward()
        self.optimizer_lambda.step()
        
        # 确保lambda非负
        with torch.no_grad():
            self.lambda_vec.clamp_(min=0)
        
        # 4. 更新策略
        q_values = self.q_network(states, actions)
        policy_loss = -q_values.mean() + self.lambda_vec * constraint_violation
        
        self.optimizer_pi.zero_grad()
        policy_loss.backward()
        self.optimizer_pi.step()
        
        return constraint_value.item(), self.lambda_vec.item()

8. 实例分析:因果GridWorld

8.1 环境设置

考虑一个简化的GridWorld,其中某些状态转移受混淆因素影响:

+-------------------+
|  S  |     |  G   |    S: 起始状态
|-------------------|    G: 目标状态
|     | [U] |     |    U: 混淆区域
|-------------------|    [U]: 混淆因素影响此区域
|     |     |     |
+-------------------+

8.2 因果结构

混淆因素 U
     ↓
状态 S ──────→ 状态 S'
     ↓           ↑
动作 A          |
     ↓           |
     └───────────┘

8.3 代码实现

import numpy as np
from typing import Tuple, List
 
class CausalGridWorld:
    """
    因果GridWorld环境
    包含混淆因素的MDP
    """
    def __init__(self, size: int = 4):
        self.size = size
        self.n_states = size * size
        self.n_actions = 4  # 上、下、左、右
        
        # 状态坐标
        self.state_to_pos = {i: (i // size, i % size) for i in range(self.n_states)}
        
        # 目标状态
        self.goal_state = self.n_states - 1
        
        # 混淆区域
        self.confounded_states = [5, 6, 9, 10]
        
        # 因果转移概率(不受混淆影响)
        self.causal_transition_prob = self._init_causal_transition()
        
        # 观察转移概率(受混淆影响)
        self.observed_transition_prob = self._init_observed_transition()
    
    def _init_causal_transition(self) -> np.ndarray:
        """初始化因果转移矩阵 P(s'|do(a), s)"""
        P = np.zeros((self.n_actions, self.n_states, self.n_states))
        
        for a in range(self.n_actions):
            for s in range(self.n_states):
                probs = np.zeros(self.n_states)
                
                # 计算目标位置
                row, col = self.state_to_pos[s]
                dr, dc = [(0, 1), (0, -1), (-1, 0), (1, 0)][a]
                nr, nc = row + dr, col + dc
                
                # 检查边界
                if 0 <= nr < self.size and 0 <= nc < self.size:
                    next_s = nr * self.size + nc
                    probs[next_s] = 0.9
                    probs[s] = 0.1  # 小的失败概率
                else:
                    probs[s] = 1.0  # 撞墙,保持原状态
                
                P[a, s] = probs
        
        return P
    
    def _init_observed_transition(self) -> np.ndarray:
        """初始化观察转移矩阵(包含混淆效应)"""
        P = self.causal_transition_prob.copy()
        
        # 在混淆区域添加混淆效应
        for s in self.confounded_states:
            for a in range(self.n_actions):
                # 混淆因素使得转移随机化
                random_prob = 0.3
                uniform = np.ones(self.n_states) / self.n_states
                
                P[a, s] = (1 - random_prob) * P[a, s] + random_prob * uniform
        
        return P
    
    def do_action(self, state: int, action: int) -> np.ndarray:
        """
        执行do操作,返回因果转移分布
        """
        return self.causal_transition_prob[action, state]
    
    def step(self, state: int, action: int, 
             use_causal: bool = False) -> Tuple[int, float, bool]:
        """
        执行一步转移
        
        Args:
            state: 当前状态
            action: 动作
            use_causal: True则使用因果转移,False则使用观察转移
        """
        if use_causal:
            probs = self.do_action(state, action)
        else:
            probs = self.observed_transition_prob[action, state]
        
        next_state = np.random.choice(self.n_states, p=probs)
        
        # 奖励
        reward = 1.0 if next_state == self.goal_state else -0.01
        
        # 完成
        done = next_state == self.goal_state
        
        return next_state, reward, done
    
    def compute_causal_effect(self, state: int, action_a: int, 
                            action_b: int) -> float:
        """
        计算动作a和b的因果效应
        """
        dist_a = self.do_action(state, action_a)
        dist_b = self.do_action(state, action_b)
        
        # 使用总变差距离
        return 0.5 * np.sum(np.abs(dist_a - dist_b))
    
    def identify_optimal_policy(self) -> np.ndarray:
        """
        识别因果最优策略
        使用因果转移而非观察转移
        """
        # 简化版本:使用值迭代
        V = np.zeros(self.n_states)
        policy = np.zeros(self.n_states, dtype=int)
        
        for _ in range(1000):
            for s in range(self.n_states):
                if s == self.goal_state:
                    continue
                
                q_values = []
                for a in range(self.n_actions):
                    # 使用因果转移
                    next_probs = self.do_action(s, a)
                    q_a = np.sum(next_probs * (self.observed_transition_prob[a, s] * 0 + 
                                               [1.0 if i == self.goal_state else -0.01 
                                                for i in range(self.n_states)]))
                    q_values.append(q_a)
                
                best_a = np.argmax(q_values)
                V[s] = max(q_values)
                policy[s] = best_a
        
        return policy
 
 
def compare_policies():
    """
    比较因果策略和观察策略
    """
    np.random.seed(42)
    
    env = CausalGridWorld(size=4)
    
    print("=" * 60)
    print("因果GridWorld分析")
    print("=" * 60)
    
    # 分析混淆区域
    print("\n混淆区域分析:")
    for s in env.confounded_states:
        pos = env.state_to_pos[s]
        print(f"\n状态 {s} (位置 {pos}):")
        
        for a in range(env.n_actions):
            causal = env.do_action(s, a)
            observed = env.observed_transition_prob[a, s]
            
            diff = np.sum(np.abs(causal - observed))
            print(f"  动作 {a}: 因果-观察差异 = {diff:.4f}")
    
    # 计算反事实效应
    print("\n\n反事实效应分析:")
    test_state = 5  # 混淆区域
    for a1 in range(env.n_actions):
        for a2 in range(a1 + 1, env.n_actions):
            effect = env.compute_causal_effect(test_state, a1, a2)
            print(f"  PE(A={a1}, A'={a2}) = {effect:.4f}")
 
 
if __name__ == "__main__":
    compare_policies()

9. 总结

核心要点

  1. 因果MDP的优势:明确建模因果机制,支持干预和反事实推理
  2. 因果转移识别:通过do-calculus从观察数据中识别因果效应
  3. 约束处理:通过拉格朗日方法处理因果约束
  4. 因果POMDP:处理部分可观测环境中的因果推断

与标准MDP的关系

方面标准MDP因果MDP
转移函数$P(s’s,a)$
价值函数
最优性局部最优因果稳定
泛化能力分布内跨环境

下一步


参考文献