Causal MDPs and Causal POMDPs
1. From Standard MDPs to Causal MDPs
1.1 Standard MDP Recap
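A standard MDP is defined by the 5-tuple $(\mathcal{S}, \mathcal{A}, P, R, \gamma)$:
- $\mathcal{S}$: state space
- $\mathcal{A}$: action space
- $P(s' \mid s, a)$: transition probability
- $R(s, a)$: reward function
- $\gamma \in [0, 1)$: discount factor
Core assumption: state transitions satisfy the Markov property, i.e. $s_{t+1}$ depends only on $(s_t, a_t)$: $P(s_{t+1} \mid s_0, a_0, \dots, s_t, a_t) = P(s_{t+1} \mid s_t, a_t)$.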
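To make the tuple concrete, here is a minimal sketch of a finite MDP stored as NumPy arrays. The layout `P[a, s, s']`, `R[s, a]` and the helper names `make_toy_mdp` and `sample_step` are illustrative assumptions; the Markov property shows up as the fact that `sample_step` needs only the current `(s, a)`, not the history.

```python
import numpy as np


def make_toy_mdp(n_states: int = 3, n_actions: int = 2, seed: int = 0):
    """Random finite MDP: P[a, s, s'] = P(s' | s, a), R[s, a], and gamma."""
    rng = np.random.default_rng(seed)
    P = rng.random((n_actions, n_states, n_states))
    P /= P.sum(axis=-1, keepdims=True)  # each row is a distribution over s'
    R = rng.standard_normal((n_states, n_actions))
    gamma = 0.9
    return P, R, gamma


def sample_step(P: np.ndarray, R: np.ndarray, s: int, a: int, rng) -> tuple:
    """Markov step: s' is drawn from P(. | s, a) alone; no history is needed."""
    s_next = int(rng.choice(P.shape[-1], p=P[a, s]))
    return s_next, float(R[s, a])


P, R, gamma = make_toy_mdp()
rng = np.random.default_rng(1)
s_next, r = sample_step(P, R, s=0, a=1, rng=rng)
```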
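1.2 Causal Shortcomings of the Standard MDP
The standard MDP has three key causal shortcomings:
| Shortcoming | Description | Consequence |
|---|---|---|
| Confounding | The state may carry non-causal information | Spurious correlations lead to wrong decisions |
| Action semantics | Actions are not distinguished from observed variables | Interventions are conflated with observations |
| Transition mechanism | The transition function is a black box | No causal interpretability |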
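1.3 Introducing the Causal MDP
Definition: a causal MDP (CMDP) is a 7-tuple $(\mathcal{S}, \mathcal{A}, \mathcal{G}, P_{\mathrm{do}}, R, \gamma, \mathcal{C})$:
| Symbol | Meaning |
|---|---|
| $\mathcal{S}$ | State space |
| $\mathcal{A}$ | Action space |
| $\mathcal{G}$ | Causal structure graph |
| $P_{\mathrm{do}}(s' \mid s, a)$ | Causal transition function |
| $R(s, a)$ | Reward function |
| $\gamma$ | Discount factor |
| $\mathcal{C}$ | Set of causal constraints |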
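2. Causal Structure Graphs and MDPs
2.1 Definition of the Causal Graph
Definition: the causal structure graph $\mathcal{G} = (V, E)$ is a directed acyclic graph (DAG), where:
- $V$ is the set of nodes (states and actions)
- $E \subseteq V \times V$ is the set of causal edges
Edge types:
| Edge type | Notation | Meaning |
|---|---|---|
| State causal flow | $s_t \to s_{t+1}$ | Dependence of the state on its history |
| Action causal effect | $a_t \to s_{t+1}$ | Direct effect of the action on the state |
| Confounding | $u \to a_t$, $u \to s_{t+1}$ | Unobserved confounding variable $u$ |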
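2.2 Causal Decomposition of the State
Assume the state space factorizes as:
$$\mathcal{S} = \mathcal{S}_c \times \mathcal{S}_e, \qquad s = (s^c, s^e)$$
where:
- $s^c$: the causal state (directly determined by its parent nodes)
- $s^e$: the environment state (determined by external factors)
Causal state update equation:
$$s^c_{t+1} = f_c\big(\mathrm{Pa}_{\mathcal{G}}(s^c_{t+1}),\, \varepsilon_{t+1}\big)$$
where $\mathrm{Pa}_{\mathcal{G}}(s^c_{t+1})$ are the parents of $s^c_{t+1}$ in $\mathcal{G}$ (components of $s_t$ and $a_t$) and $\varepsilon_{t+1}$ is independent exogenous noise.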
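2.3 The Causal Markov Condition
Theorem (causal Markov condition): in a causal graph $\mathcal{G}$, every node $X$ is conditionally independent of all of its non-descendants given its parents $\mathrm{Pa}_{\mathcal{G}}(X)$. Equivalently, the joint distribution factorizes as $P(x_1, \dots, x_n) = \prod_i P\big(x_i \mid \mathrm{Pa}_{\mathcal{G}}(x_i)\big)$.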
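As a small numerical check of the causal Markov condition, the sketch below builds a hypothetical binary DAG (`U → S'`, `S → A`, `S → S'`, `A → S'`; the graph and all probabilities are made up for illustration), forms the joint distribution as the product of each node's conditional given its parents, and confirms that `A` is independent of its non-descendant `U` given its parent `S`.

```python
import itertools

import numpy as np

# Hypothetical binary DAG: U -> S', S -> A, S -> S', A -> S'
p_u = np.array([0.7, 0.3])                    # P(U)
p_s = np.array([0.4, 0.6])                    # P(S)
p_a_given_s = np.array([[0.8, 0.2],           # P(A | S=0)
                        [0.3, 0.7]])          # P(A | S=1)
rng = np.random.default_rng(0)
p_next = rng.random((2, 2, 2, 2))             # indexed [u, s, a, s']
p_next /= p_next.sum(axis=-1, keepdims=True)  # P(S' | U, S, A)

# Causal Markov factorization: joint = product of P(x_i | Pa(x_i))
joint = np.zeros((2, 2, 2, 2))                # indexed [u, s, a, s']
for u, s, a, s2 in itertools.product(range(2), repeat=4):
    joint[u, s, a, s2] = p_u[u] * p_s[s] * p_a_given_s[s, a] * p_next[u, s, a, s2]

# A should be independent of its non-descendant U given its parent S
p_usa = joint.sum(axis=3)                     # P(u, s, a)
p_a_given_su = p_usa / p_usa.sum(axis=2, keepdims=True)
```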
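3. Causal Transition Functions
3.1 From Observational Distributions to Causal Mechanisms
The standard MDP uses the observational distribution:
$$P(s_{t+1} \mid s_t, a_t)$$
The causal MDP uses the causal mechanism, i.e. the interventional distribution:
$$P(s_{t+1} \mid s_t, \mathrm{do}(a_t))$$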
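3.2 The do-Operator and Transitions
Definition: the causal transition function $P_{\mathrm{do}}$ satisfies:
$$P_{\mathrm{do}}(s' \mid s, a) := P(s' \mid \mathrm{do}(a), s) = \prod_{i} P\big(s'_i \mid \mathrm{Pa}_{\mathcal{G}}(s'_i)\big)\Big|_{A = a}$$
where $\mathrm{Pa}_{\mathcal{G}}(s'_i)$ are the parents of the next-state component $s'_i$ in the causal graph (assumed observed, i.e. contained in $(s, a)$), with the action node set to $a$ by the intervention.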
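3.3 Identification Conditions
Theorem (identifiability of causal transitions): $P(s' \mid \mathrm{do}(a), s)$ can be identified from observational data if any of the following holds:
- Backdoor criterion: a set $Z$ containing no descendant of $A$ blocks every backdoor path from $A$ to $S'$; then $P(s' \mid \mathrm{do}(a), s) = \sum_z P(s' \mid s, a, z)\, P(z \mid s)$
- Front-door criterion: a set $M$ intercepts every directed path from $A$ to $S'$ (so there is no direct edge $A \to S'$), there is no unblocked backdoor path from $A$ to $M$, and every backdoor path from $M$ to $S'$ is blocked by $A$
- do-calculus: a do-free expression can be derived with the three rules of do-calculus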
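The backdoor criterion can be illustrated with a hypothetical discrete model in which an observed variable `Z` confounds the action and the next state (`Z → A`, `Z → S'`, `A → S'`, with the state `s` held fixed). The numbers are invented for illustration; the sketch estimates `P(s'=1 | do(a))` from the observational joint alone and compares it with the naive conditional `P(s'=1 | a)`.

```python
import numpy as np

# Hypothetical toy model (state s held fixed): Z -> A, Z -> S', A -> S'
p_z = np.array([0.5, 0.5])                  # P(z)
p_a_given_z = np.array([[0.8, 0.2],         # P(a | z=0)
                        [0.2, 0.8]])        # P(a | z=1)
p_s1_given_az = np.array([[0.1, 0.5],       # P(s'=1 | a=0, z)
                          [0.3, 0.7]])      # P(s'=1 | a=1, z)

# Observational quantities, indexed [z, a]
joint_s1 = p_z[:, None] * p_a_given_z * p_s1_given_az.T  # P(z, a, s'=1)
p_za = p_z[:, None] * p_a_given_z                         # P(z, a)


def observational(a: int) -> float:
    """P(s'=1 | a): conditioning on a also shifts the distribution of z."""
    return float(joint_s1[:, a].sum() / p_za[:, a].sum())


def backdoor(a: int) -> float:
    """P(s'=1 | do(a)) = sum_z P(s'=1 | a, z) P(z), using only the joint."""
    p_s1_hat = joint_s1[:, a] / p_za[:, a]  # P(s'=1 | a, z)
    return float((p_s1_hat * p_za.sum(axis=1)).sum())
```

With these numbers `backdoor(1)` is 0.5 while `observational(1)` is 0.62: conditioning on `a = 1` raises the probability of `z = 1` from 0.5 to 0.8, which inflates the naive estimate.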
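3.4 Causal vs. Observational Transitions
```
┌──────────────────────────────────────────────────────────────┐
│ Causal transition vs. observational transition               │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│ Observational transition:                                    │
│   P(s'|s, a) = Σ_u P(s'|s, a, u) P(u|s, a)                   │
│                                         ↑                    │
│   conditioning on a also updates the belief about u          │
│                                                              │
│ Causal transition:                                           │
│   P(s'|do(a), s) = Σ_u P(s'|s, a, u) P(u|s)                  │
│                                          ↑                   │
│   do(a) cuts the edge u → a, so u keeps its prior P(u|s)     │
│                                                              │
│ Key difference:                                              │
│   the causal transition removes the spurious dependence      │
│   between a and u; the direct effect of a on s' is kept      │
│                                                              │
└──────────────────────────────────────────────────────────────┘
```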
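A quick Monte Carlo sketch of the comparison above, under assumed toy dynamics `s' = a + 2u + noise` with a binary hidden confounder `u` and a behaviour policy that copies `u` 90% of the time (all of this is hypothetical). Conditioning on `a = 1` selects mostly `u = 1`, while `do(a = 1)` leaves `u` at its prior.

```python
import numpy as np


def simulate(n: int = 200_000, seed: int = 0):
    """Toy confounded transition (hypothetical): u -> a, u -> s', a -> s'."""
    rng = np.random.default_rng(seed)
    u = rng.integers(0, 2, size=n)                  # hidden confounder
    # Behaviour policy copies u 90% of the time, so a carries information about u
    a_obs = np.where(rng.random(n) < 0.9, u, 1 - u)
    noise = rng.normal(0.0, 0.1, size=n)
    s_next_obs = a_obs + 2 * u + noise
    # Intervention do(a=1): the edge u -> a is cut, u keeps its prior
    s_next_do = 1 + 2 * u + noise
    e_obs = s_next_obs[a_obs == 1].mean()           # estimates E[s' | a=1]
    e_do = s_next_do.mean()                         # estimates E[s' | do(a=1)]
    return e_obs, e_do


e_obs, e_do = simulate()
```

Analytically, $\mathbb{E}[s' \mid a=1] = 1 + 2 \cdot 0.9 = 2.8$ whereas $\mathbb{E}[s' \mid \mathrm{do}(a=1)] = 1 + 2 \cdot 0.5 = 2.0$; the simulated averages approximate these values.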
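4. Causal Constraints and Safe Policies
4.1 Constraint Types
The causal constraints form a set:
$$\mathcal{C} = \{c_1, c_2, \dots, c_K\}$$
Each constraint is a function of causal effects:
$$c_k\big(P(\cdot \mid \mathrm{do}(a), s)\big) \le \kappa_k$$
| Constraint type | Form | Meaning |
|---|---|---|
| Causal inequality | $P(s' \mid \mathrm{do}(a), s) \geq \epsilon$ | Under the intervention, $s'$ is reached with probability at least $\epsilon$ |
| Counterfactual constraint | $V^{\pi}_{\mathrm{cf}}(s) \le \kappa$ | The counterfactual value does not exceed a threshold |
| Intervention constraint | $\big\lvert \mathbb{E}[f(s') \mid \mathrm{do}(a_1), s] - \mathbb{E}[f(s') \mid \mathrm{do}(a_2), s] \big\rvert \le \delta$ | Bounds the average causal effect between actions |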
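A minimal sketch of checking the causal-inequality and intervention constraints, assuming the interventional distributions are available as an array `P_do[a, s, s']` and that the outcome of interest is a vector `f` over next states; the function names and the toy numbers are hypothetical. The counterfactual constraint needs a full structural causal model and is not covered here.

```python
import numpy as np


def causal_inequality_ok(P_do: np.ndarray, s: int, a: int,
                         s_target: int, eps: float) -> bool:
    """Causal inequality: P(s_target | do(a), s) >= eps."""
    return bool(P_do[a, s, s_target] >= eps)


def average_causal_effect(P_do: np.ndarray, s: int, a1: int, a2: int,
                          f: np.ndarray) -> float:
    """E[f(s') | do(a1), s] - E[f(s') | do(a2), s]."""
    return float(P_do[a1, s] @ f - P_do[a2, s] @ f)


def intervention_constraint_ok(P_do: np.ndarray, s: int, a1: int, a2: int,
                               f: np.ndarray, delta: float) -> bool:
    """Intervention constraint: |average causal effect| <= delta."""
    return abs(average_causal_effect(P_do, s, a1, a2, f)) <= delta


# Toy example: 2 actions, 1 state, 3 possible next states
P_do = np.array([[[0.2, 0.5, 0.3]],
                 [[0.6, 0.3, 0.1]]])
f = np.array([0.0, 1.0, 2.0])
```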
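4.2 Constrained Causal MDPs
Definition: a constrained CMDP is a CMDP whose policy must satisfy the constraints:
$$\max_{\pi}\; J_R(\pi) = \mathbb{E}_{\pi}\Big[\sum_{t=0}^{\infty} \gamma^t R(s_t, a_t)\Big] \quad \text{s.t.} \quad J_{C_k}(\pi) \le \kappa_k,\quad k = 1, \dots, K$$
where the constraint function $J_{C_k}$ can be:
- an expected cumulative cost
- a causal-effect constraint
- a counterfactual risk constraint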
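4.3 Lagrangian Relaxation
The constraints are handled with Lagrange multipliers:
$$\mathcal{L}(\pi, \lambda) = J_R(\pi) - \sum_{k=1}^{K} \lambda_k \big(J_{C_k}(\pi) - \kappa_k\big), \qquad \max_{\pi}\, \min_{\lambda \ge 0}\, \mathcal{L}(\pi, \lambda)$$
Primal-dual update, with the multipliers projected onto $\lambda \ge 0$:
$$\lambda_k \leftarrow \Big[\lambda_k + \alpha_\lambda \big(J_{C_k}(\pi_\theta) - \kappa_k\big)\Big]_+, \qquad \theta \leftarrow \theta + \alpha_\theta\, \nabla_\theta \mathcal{L}(\pi_\theta, \lambda)$$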
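```python
import numpy as np


def constrained_policy_update(theta, reward_grad, cost_values, cost_grads,
                              lambda_vec, kappa, alpha_theta, alpha_lambda):
    """
    One primal-dual step on L(θ, λ) = J_R(θ) - Σ_k λ_k (J_Ck(θ) - κ_k).
    theta:       policy parameters, shape (d,)
    reward_grad: estimate of ∇θ J_R, shape (d,)
    cost_values: estimates of J_Ck, shape (K,)
    cost_grads:  estimates of ∇θ J_Ck stacked row-wise, shape (K, d)
    kappa:       constraint thresholds κ_k, shape (K,)
    """
    # Dual step: raise λ_k when constraint k is violated, then project onto λ >= 0
    lambda_new = np.maximum(0.0, lambda_vec + alpha_lambda * (cost_values - kappa))
    # Primal step: gradient ascent on the Lagrangian
    lagrangian_grad = reward_grad - lambda_new @ cost_grads
    theta_new = theta + alpha_theta * lagrangian_grad
    return theta_new, lambda_new
```
5. Causal POMDPs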
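5.1 Standard POMDP Recap
A POMDP is defined by the 7-tuple $(\mathcal{S}, \mathcal{A}, P, R, \Omega, O, \gamma)$; compared with the MDP it adds:
- $\Omega$: observation space
- $O(o \mid s', a)$: observation function
Belief state:
$$b_t(s) = P(s_t = s \mid o_{1:t}, a_{0:t-1}), \qquad b_{t+1}(s') = \eta\, O(o_{t+1} \mid s', a_t) \sum_{s} P(s' \mid s, a_t)\, b_t(s)$$
where $\eta$ is a normalizing constant.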
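The belief update above is a discrete Bayes filter. A minimal sketch, assuming tabular arrays `P[a, s, s']` and `O[a, s', o] = O(o | s', a)`; the array layout, the function name `belief_update`, and the toy sensor are illustrative:

```python
import numpy as np


def belief_update(b: np.ndarray, a: int, o: int,
                  P: np.ndarray, O: np.ndarray) -> np.ndarray:
    """b'(s') = η · O(o | s', a) · Σ_s P(s' | s, a) b(s)."""
    predicted = b @ P[a]                      # prediction: Σ_s b(s) P(s' | s, a)
    unnormalized = predicted * O[a, :, o]     # correction: weight by likelihood of o
    return unnormalized / unnormalized.sum()  # η normalizes to a distribution


# Toy example: one action that leaves the state unchanged, noisy sensor
P = np.array([np.eye(2)])
O = np.array([[[0.9, 0.1],
               [0.2, 0.8]]])
b_new = belief_update(np.array([0.5, 0.5]), a=0, o=0, P=P, O=O)
```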
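5.2 Definition of the Causal POMDP
Definition: the causal POMDP (C-POMDP) extends the POMDP with causal structure:
| New component | Symbol |
|---|---|
| Causal graph over states, actions, and observations | $\mathcal{G}$ |
| Causal transition function | $P(s' \mid \mathrm{do}(a), s)$ |
| Initial causal belief | $b^c_0$ |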
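5.3 Causal Belief State
Definition: the causal belief state $b^c_t$ is a joint belief over the latent causal state and the confounders:
$$b^c_t(s^c, u) = P(s^c_t = s^c,\, u_t = u \mid o_{1:t}, a_{0:t-1})$$
where:
- $s^c$: the causal state (observable at least in part)
- $u$: the unobserved confounder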
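5.4 Causal Observation Model
The observation function factorizes as:
$$O(o \mid s) = O_c(o^c \mid s^c)\, O_e(o^e \mid s^e), \qquad o = (o^c, o^e)$$
where:
- $s^c$: causal state
- $s^e$: environment state
Causal conditional independence of observations:
$$o^c \perp\!\!\!\perp s^e \mid s^c, \qquad o^e \perp\!\!\!\perp s^c \mid s^e$$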
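A minimal sketch of the factored observation model with made-up 2×2 tables for `O_c` and `O_e` (all values hypothetical). Building `O` as an outer product makes both conditional independences hold by construction; the check verifies that the distribution of `o^c` does not change with `s^e`.

```python
import numpy as np

# Hypothetical factored observation model
O_c = np.array([[0.9, 0.1],    # O_c(o^c | s^c = 0)
                [0.2, 0.8]])   # O_c(o^c | s^c = 1)
O_e = np.array([[0.7, 0.3],    # O_e(o^e | s^e = 0)
                [0.4, 0.6]])   # O_e(o^e | s^e = 1)

# O[s^c, s^e, o^c, o^e] = O_c(o^c | s^c) * O_e(o^e | s^e)
O = np.einsum("ik,jl->ijkl", O_c, O_e)

# o^c ⫫ s^e | s^c: marginalizing o^e leaves a table that ignores s^e
marg_oc = O.sum(axis=3)        # indexed [s^c, s^e, o^c]
```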
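6. Formalizing Value Functions
6.1 Causal Value Functions
Standard value function:
$$V^{\pi}(s) = \mathbb{E}_{\pi, P}\Big[\sum_{t=0}^{\infty} \gamma^t R(s_t, a_t) \,\Big|\, s_0 = s\Big]$$
Causal value function (trajectories are generated by intervening with $\pi$, i.e. $s_{t+1} \sim P(\cdot \mid \mathrm{do}(a_t), s_t)$):
$$V^{\pi}_{\mathrm{do}}(s) = \mathbb{E}_{a_t \sim \pi(\cdot \mid s_t),\; s_{t+1} \sim P(\cdot \mid \mathrm{do}(a_t), s_t)}\Big[\sum_{t=0}^{\infty} \gamma^t R(s_t, a_t) \,\Big|\, s_0 = s\Big]$$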
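6.2 Causal Bellman Equation
Standard Bellman equation:
$$V^{\pi}(s) = \sum_{a} \pi(a \mid s) \Big[R(s, a) + \gamma \sum_{s'} P(s' \mid s, a)\, V^{\pi}(s')\Big]$$
Causal Bellman equation:
$$V^{\pi}_{\mathrm{do}}(s) = \sum_{a} \pi(a \mid s) \Big[R(s, a) + \gamma \sum_{s'} P(s' \mid \mathrm{do}(a), s)\, V^{\pi}_{\mathrm{do}}(s')\Big]$$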
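For a finite state space the causal Bellman equation is linear in $V^{\pi}_{\mathrm{do}}$, so policy evaluation reduces to solving $(I - \gamma P_\pi) V = R_\pi$ with $P_\pi(s, s') = \sum_a \pi(a \mid s) P(s' \mid \mathrm{do}(a), s)$ and $R_\pi(s) = \sum_a \pi(a \mid s) R(s, a)$. A sketch, assuming tabular arrays `P_do[a, s, s']`, `R[s, a]`, `pi[s, a]` (names and the toy example are illustrative):

```python
import numpy as np


def evaluate_policy_causal(P_do: np.ndarray, R: np.ndarray,
                           pi: np.ndarray, gamma: float) -> np.ndarray:
    """Solve V = R_pi + gamma * P_pi V exactly, with
    P_do[a, s, s'] = P(s' | do(a), s), R[s, a], pi[s, a] = π(a | s)."""
    n_states = R.shape[0]
    P_pi = np.einsum("sa,ast->st", pi, P_do)  # Σ_a π(a|s) P(s' | do(a), s)
    R_pi = (pi * R).sum(axis=1)               # Σ_a π(a|s) R(s, a)
    return np.linalg.solve(np.eye(n_states) - gamma * P_pi, R_pi)


# Toy example: both actions keep the state fixed; the policy picks action s in state s
P_do = np.stack([np.eye(2), np.eye(2)])
R = np.array([[1.0, 0.0],
              [0.0, 2.0]])
pi = np.eye(2)
V = evaluate_policy_causal(P_do, R, pi, gamma=0.5)
```

Both states are absorbing here, so $V(s) = R_\pi(s) / (1 - \gamma)$, which gives `[2.0, 4.0]`.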
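6.3 Causal Optimality Equations
Optimal value function:
$$V^{*}_{\mathrm{do}}(s) = \max_{a} \Big[R(s, a) + \gamma \sum_{s'} P(s' \mid \mathrm{do}(a), s)\, V^{*}_{\mathrm{do}}(s')\Big]$$
Optimal policy:
$$\pi^{*}(s) = \arg\max_{a} \Big[R(s, a) + \gamma \sum_{s'} P(s' \mid \mathrm{do}(a), s)\, V^{*}_{\mathrm{do}}(s')\Big]$$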
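The optimality equations can be solved by value iteration on the interventional transition tensor. A minimal sketch, assuming tabular arrays `P_do[a, s, s']` and `R[s, a]`; the two-state example (action 0 stays, action 1 switches states, and only staying in state 1 pays 1) is made up for illustration:

```python
import numpy as np


def causal_value_iteration(P_do: np.ndarray, R: np.ndarray, gamma: float,
                           tol: float = 1e-10, max_iter: int = 10_000):
    """Iterate V(s) <- max_a [R(s, a) + gamma Σ_s' P(s' | do(a), s) V(s')]."""
    V = np.zeros(R.shape[0])
    for _ in range(max_iter):
        Q = R + gamma * np.einsum("ast,t->sa", P_do, V)
        V_new = Q.max(axis=1)
        converged = np.max(np.abs(V_new - V)) < tol
        V = V_new
        if converged:
            break
    # Greedy policy with respect to the converged values
    Q = R + gamma * np.einsum("ast,t->sa", P_do, V)
    return V, Q.argmax(axis=1)


P_do = np.stack([np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])])
R = np.array([[0.0, 0.0],
              [1.0, 0.0]])
V_star, pi_star = causal_value_iteration(P_do, R, gamma=0.9)
```

In this example $V^*(1) = 1 / (1 - 0.9) = 10$ and $V^*(0) = 0.9 \cdot 10 = 9$, so the greedy policy switches in state 0 and stays in state 1.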
7. Algorithms and Implementation
7.1 Causal Q-Learning
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from typing import Tuple


class CausalQNetwork(nn.Module):
    """
    Causal Q-network.
    Besides Q(s, a) it learns a transition model f(s, a) ≈ E[s' | do(a), s].
    The model is only interventional if it is trained on unconfounded
    actions, e.g. actions chosen by the agent's own policy from s alone.
    """
    def __init__(self, state_dim: int, action_dim: int,
                 hidden_dim: int = 128, n_causal_factors: int = 8):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.n_causal_factors = n_causal_factors
        # State encoder
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Causal factor extractor
        self.causal_extractor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_causal_factors)
        )
        # Q-value head: input is [state encoding, causal factors, one-hot action]
        self.q_estimator = nn.Sequential(
            nn.Linear(hidden_dim + n_causal_factors + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # Causal transition model: one state-sized effect vector per causal factor
        self.causal_transition_model = nn.Sequential(
            nn.Linear(hidden_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_causal_factors * state_dim)
        )

    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        """Q(s, a) for a one-hot action; returns (batch, 1)."""
        s_enc = self.state_encoder(state)
        c_factors = self.causal_extractor(s_enc)
        combined = torch.cat([s_enc, c_factors, action], dim=-1)
        return self.q_estimator(combined)

    def q_all_actions(self, state: Tensor) -> Tensor:
        """Q(s, a) for every discrete action; returns (batch, action_dim)."""
        batch = state.shape[0]
        states = state.repeat_interleave(self.action_dim, dim=0)
        actions = torch.eye(self.action_dim, device=state.device).repeat(batch, 1)
        return self(states, actions).view(batch, self.action_dim)

    def predict_causal_transition(self, state: Tensor, action: Tensor) -> Tensor:
        """
        Predict the next state under do(a).
        Returns: (batch, state_dim)
        """
        s_enc = self.state_encoder(state)
        combined = torch.cat([s_enc, action], dim=-1)
        causal_effect = self.causal_transition_model(combined)
        # Reshape to (batch, n_causal_factors, state_dim) and average over factors
        effect = causal_effect.view(-1, self.n_causal_factors, self.state_dim)
        effect_scale = effect.mean(dim=1)
        # Residual update: s' = s + Δ(s, a)
        return state + effect_scale

    def compute_counterfactual_q(self, state: Tensor, action: Tensor,
                                 reward: Tensor, gamma: float,
                                 target_network: 'CausalQNetwork') -> Tensor:
        """
        TD target built on the predicted interventional next state:
        y = r + γ · max_a' Q_target(ŝ', a'),  with ŝ' = f(s, do(a))
        """
        with torch.no_grad():
            predicted_next_state = self.predict_causal_transition(state, action)
            next_q = target_network.q_all_actions(predicted_next_state)
            return reward + gamma * next_q.max(dim=-1, keepdim=True).values


class CausalQLearning:
    """Causal Q-learning (DQN-style, one transition per update)."""
    def __init__(self, state_dim: int, action_dim: int,
                 hidden_dim: int = 128,
                 lr: float = 1e-3,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_decay: float = 0.995,
                 epsilon_min: float = 0.01,
                 target_update_every: int = 100):
        self.q_network = CausalQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = CausalQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.target_update_every = target_update_every
        self.training_step = 0

    def select_action(self, state: Tensor) -> int:
        """ε-greedy action selection; state has shape (1, state_dim)."""
        if torch.rand(1).item() < self.epsilon:
            return int(torch.randint(0, self.q_network.action_dim, (1,)).item())
        with torch.no_grad():
            return int(self.q_network.q_all_actions(state).argmax(dim=-1).item())

    def update(self, state: Tensor, action: int, reward: float,
               next_state: Tensor, done: bool) -> float:
        """One gradient step on a single transition; states are (1, state_dim)."""
        action_tensor = F.one_hot(torch.tensor([action]),
                                  self.q_network.action_dim).float()
        reward_t = torch.tensor([[float(reward)]])
        # 1. Fit the transition model to the observed next state. Because the
        #    agent's policy depends on s alone, s' is a sample of P(s' | do(a), s).
        predicted_next = self.q_network.predict_causal_transition(state, action_tensor)
        model_loss = F.mse_loss(predicted_next, next_state)
        # 2. TD target from the predicted next state, scored by the target network
        if done:
            target_q = reward_t
        else:
            target_q = self.q_network.compute_counterfactual_q(
                state, action_tensor, reward_t, self.gamma, self.target_network)
        current_q = self.q_network(state, action_tensor)
        loss = F.mse_loss(current_q, target_q) + model_loss
        # Backpropagate with gradient clipping
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()
        # Decay ε
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        # Periodically sync the target network
        self.training_step += 1
        if self.training_step % self.target_update_every == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        return loss.item()
```
7.2 Causal Policy Gradient
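```python
class CausalPolicyGradient:
    """
    Causal policy gradient (actor-critic).
    Advantages are estimated with GAE, bootstrapping from the predicted
    interventional next state ŝ' = f(s, do(a)) instead of the logged s'.
    """
    def __init__(self, state_dim: int, action_dim: int,
                 hidden_dim: int = 128, lr: float = 3e-4):
        self.action_dim = action_dim
        self.policy_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic Q(s, a) together with the causal transition model f(s, a)
        self.value_net = CausalQNetwork(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(
            list(self.policy_net.parameters()) + list(self.value_net.parameters()),
            lr=lr
        )

    def _one_hot(self, actions: Tensor) -> Tensor:
        """Integer actions (T,) -> one-hot (T, action_dim)."""
        return torch.eye(self.action_dim)[actions.long()]

    def _state_value(self, states: Tensor) -> Tensor:
        """V(s) = Σ_a π(a|s) Q(s, a), looping over the discrete actions; returns (T,)."""
        eye = torch.eye(self.action_dim)
        q = torch.stack([self.value_net(states, eye[a].expand(len(states), -1)).squeeze(-1)
                         for a in range(self.action_dim)], dim=-1)
        return (self.policy_net(states) * q).sum(dim=-1)

    def compute_causal_advantage(self, states: Tensor, actions: Tensor,
                                 rewards: Tensor, dones: Tensor,
                                 gamma: float = 0.99,
                                 lambda_gae: float = 0.95) -> Tuple[Tensor, Tensor]:
        """
        Causal GAE (Generalized Advantage Estimation).
        Shapes: states (T, state_dim), actions (T,) int, rewards (T,), dones (T,) in {0, 1}.
        """
        with torch.no_grad():
            # Predicted next states under do(a_t)
            predicted_next = self.value_net.predict_causal_transition(
                states, self._one_hot(actions))
            values = self._state_value(states)
            next_values = self._state_value(predicted_next)
            # TD errors δ_t = r_t + γ V(ŝ_{t+1}) (1 - done_t) - V(s_t)
            td_errors = rewards + gamma * next_values * (1 - dones) - values
            advantages = torch.zeros_like(td_errors)
            gae = torch.zeros(())
            for t in reversed(range(len(td_errors))):
                gae = td_errors[t] + gamma * lambda_gae * gae * (1 - dones[t])
                advantages[t] = gae
            # λ-returns: regression targets for Q(s_t, a_t)
            returns = advantages + values
        return advantages, returns

    def update(self, states: Tensor, actions: Tensor, rewards: Tensor,
               next_states: Tensor, dones: Tensor) -> Tuple[float, float]:
        """One update on a trajectory segment."""
        advantages, returns = self.compute_causal_advantage(states, actions, rewards, dones)
        a_onehot = self._one_hot(actions)
        # Policy-gradient loss uses log π(a_t | s_t), not π(a_t | s_t)
        action_probs = self.policy_net(states)
        selected = action_probs.gather(1, actions.long().unsqueeze(1)).squeeze(1)
        policy_loss = -(torch.log(selected + 1e-8) * advantages).mean()
        # Critic loss on Q(s_t, a_t)
        q_values = self.value_net(states, a_onehot).squeeze(-1)
        value_loss = nn.functional.mse_loss(q_values, returns)
        # Transition-model loss on the logged next states
        predicted_next = self.value_net.predict_causal_transition(states, a_onehot)
        model_loss = nn.functional.mse_loss(predicted_next, next_states)
        total_loss = (policy_loss + 0.5 * value_loss + model_loss
                      - 0.01 * entropy(action_probs))
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
        self.optimizer.step()
        return policy_loss.item(), value_loss.item()


def entropy(probs: Tensor) -> Tensor:
    """Mean policy entropy over the batch."""
    return -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()
```
7.3 Constrained CMDP Optimization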
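```python
class ConstrainedCMDP:
    """
    Constrained CMDP solver (primal-dual on the Lagrangian).
    The critic q_network (Q-values and causal transition model) is held fixed;
    only the policy and the Lagrange multiplier are updated.
    """
    def __init__(self, q_network: CausalQNetwork,
                 hidden_dim: int = 128,
                 constraint_threshold: float = 0.1,
                 lr_pi: float = 1e-4,
                 lr_lambda: float = 1e-3):
        self.q_network = q_network
        self.action_dim = q_network.action_dim
        self.constraint_threshold = constraint_threshold
        self.policy_net = nn.Sequential(
            nn.Linear(q_network.state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, q_network.action_dim),
            nn.Softmax(dim=-1)
        )
        # Lagrange multiplier (created before the optimizer that updates it)
        self.lambda_vec = nn.Parameter(torch.tensor([0.0]))
        self.optimizer_pi = optim.Adam(self.policy_net.parameters(), lr=lr_pi)
        self.optimizer_lambda = optim.Adam([self.lambda_vec], lr=lr_lambda)

    def _predict_all_actions(self, states: Tensor) -> Tuple[Tensor, Tensor]:
        """Q(s, a) as (T, A) and predicted ŝ' = f(s, do(a)) as (T, A, d) for every action."""
        eye = torch.eye(self.action_dim)
        with torch.no_grad():
            q = torch.stack([self.q_network(states, eye[a].expand(len(states), -1)).squeeze(-1)
                             for a in range(self.action_dim)], dim=-1)
            preds = torch.stack([self.q_network.predict_causal_transition(
                                     states, eye[a].expand(len(states), -1))
                                 for a in range(self.action_dim)], dim=1)
        return q, preds

    def compute_causal_constraint(self, probs: Tensor, preds: Tensor) -> Tensor:
        """
        Causal constraint: variance of the predicted next state under the policy,
        Var_{a~π(·|s)}[f(s, do(a))], averaged over states; it should stay below
        the threshold, which favours actions with stable causal effects.
        A deterministic transition model returns the same prediction on repeated
        calls, so the spread is taken over the policy's action distribution,
        which also keeps the constraint differentiable with respect to π.
        """
        mean = (probs.unsqueeze(-1) * preds).sum(dim=1)          # (T, d)
        sq_dev = ((preds - mean.unsqueeze(1)) ** 2).sum(dim=-1)  # (T, A)
        return (probs * sq_dev).sum(dim=-1).mean()

    def update(self, states: Tensor) -> Tuple[float, float]:
        """One primal-dual step on a batch of states."""
        q, preds = self._predict_all_actions(states)
        probs = self.policy_net(states)
        constraint_value = self.compute_causal_constraint(probs, preds)
        # 1. Dual ascent on λ: minimize -λ (c - κ), then project onto λ >= 0
        lambda_loss = -self.lambda_vec * (constraint_value.detach() - self.constraint_threshold)
        self.optimizer_lambda.zero_grad()
        lambda_loss.sum().backward()
        self.optimizer_lambda.step()
        with torch.no_grad():
            self.lambda_vec.clamp_(min=0.0)
        # 2. Policy step on the Lagrangian: maximize E_π[Q] - λ c
        expected_q = (probs * q).sum(dim=-1).mean()
        policy_loss = -expected_q + self.lambda_vec.detach() * constraint_value
        self.optimizer_pi.zero_grad()
        policy_loss.sum().backward()
        self.optimizer_pi.step()
        return constraint_value.item(), self.lambda_vec.item()
```
8. Case Study: Causal GridWorld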
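8.1 Environment Setup
Consider a simplified 4×4 GridWorld in which the transitions out of some states are affected by a confounder:
```
+---+---+---+---+
| S |   |   |   |    S: start state (0)
+---+---+---+---+    G: goal state (15)
|   | U | U |   |    U: confounded states (5, 6, 9, 10),
+---+---+---+---+       whose observed transitions are
|   | U | U |   |       partly randomized by the confounder
+---+---+---+---+
|   |   |   | G |
+---+---+---+---+
```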
8.2 Causal Structure
```
          Confounder U
               │
               ▼
State S ──────────────→ State S'
   │                       ▲
   ▼                       │
Action A ──────────────────┘
```
8.3 Code Implementation
```python
import numpy as np
from typing import Tuple


class CausalGridWorld:
    """
    Causal GridWorld: an MDP whose observed transitions are distorted
    by a confounder in part of the grid.
    """
    # Action order matches the names: 0 up, 1 down, 2 left, 3 right
    ACTION_DELTAS = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    def __init__(self, size: int = 4):
        self.size = size
        self.n_states = size * size
        self.n_actions = 4  # up, down, left, right
        # State index -> (row, col)
        self.state_to_pos = {i: (i // size, i % size) for i in range(self.n_states)}
        # Goal state (bottom-right corner)
        self.goal_state = self.n_states - 1
        # Confounded region (the centre cells when size=4)
        self.confounded_states = [5, 6, 9, 10]
        # Causal transition P(s'|do(a), s): unaffected by the confounder
        self.causal_transition_prob = self._init_causal_transition()
        # Observed transition: includes the confounding effect
        self.observed_transition_prob = self._init_observed_transition()

    def _init_causal_transition(self) -> np.ndarray:
        """Causal transition tensor P[a, s, s'] = P(s'|do(a), s)."""
        P = np.zeros((self.n_actions, self.n_states, self.n_states))
        for a, (dr, dc) in enumerate(self.ACTION_DELTAS):
            for s in range(self.n_states):
                row, col = self.state_to_pos[s]
                nr, nc = row + dr, col + dc
                if 0 <= nr < self.size and 0 <= nc < self.size:
                    P[a, s, nr * self.size + nc] = 0.9
                    P[a, s, s] = 0.1  # small chance that the move fails
                else:
                    P[a, s, s] = 1.0  # bumping into a wall: stay in place
        return P

    def _init_observed_transition(self) -> np.ndarray:
        """Observed transition tensor, with the confounding effect mixed in."""
        P = self.causal_transition_prob.copy()
        uniform = np.ones(self.n_states) / self.n_states
        random_prob = 0.3  # strength of the confounding effect
        for s in self.confounded_states:
            for a in range(self.n_actions):
                # The confounder partially randomizes the observed transition
                P[a, s] = (1 - random_prob) * P[a, s] + random_prob * uniform
        return P

    def do_action(self, state: int, action: int) -> np.ndarray:
        """do-operation: return the interventional distribution P(·|do(a), s)."""
        return self.causal_transition_prob[action, state]

    def reward(self, next_state: int) -> float:
        """Reward received on arriving in next_state."""
        return 1.0 if next_state == self.goal_state else -0.01

    def step(self, state: int, action: int,
             use_causal: bool = False) -> Tuple[int, float, bool]:
        """
        Take one transition.
        Args:
            state: current state
            action: action
            use_causal: True samples the causal transition, False the observed one
        """
        if use_causal:
            probs = self.do_action(state, action)
        else:
            probs = self.observed_transition_prob[action, state]
        next_state = int(np.random.choice(self.n_states, p=probs))
        done = next_state == self.goal_state
        return next_state, self.reward(next_state), done

    def compute_causal_effect(self, state: int, action_a: int,
                              action_b: int) -> float:
        """Effect of do(a) vs. do(b) at `state`, as a total variation distance."""
        dist_a = self.do_action(state, action_a)
        dist_b = self.do_action(state, action_b)
        return 0.5 * float(np.sum(np.abs(dist_a - dist_b)))

    def identify_optimal_policy(self, use_causal: bool = True, gamma: float = 0.95,
                                tol: float = 1e-8, max_iter: int = 1000) -> np.ndarray:
        """
        Value iteration, planning with P(s'|do(a), s) if use_causal else the
        observed transition:
        Q(s, a) = Σ_s' P(s'|s, a) [r(s') + γ V(s') (1 - done(s'))]
        """
        P = self.causal_transition_prob if use_causal else self.observed_transition_prob
        r = np.array([self.reward(s) for s in range(self.n_states)])
        not_terminal = np.ones(self.n_states)
        not_terminal[self.goal_state] = 0.0  # no future value after reaching the goal
        V = np.zeros(self.n_states)
        for _ in range(max_iter):
            Q = P @ (r + gamma * V * not_terminal)  # (n_actions, n_states)
            V_new = Q.max(axis=0)
            converged = np.max(np.abs(V_new - V)) < tol
            V = V_new
            if converged:
                break
        Q = P @ (r + gamma * V * not_terminal)
        return Q.argmax(axis=0)


def compare_policies():
    """Compare the causal and observed models and the policies they induce."""
    np.random.seed(42)
    env = CausalGridWorld(size=4)
    print("=" * 60)
    print("Causal GridWorld analysis")
    print("=" * 60)
    # Analyse the confounded region
    print("\nConfounded-region analysis:")
    for s in env.confounded_states:
        pos = env.state_to_pos[s]
        print(f"\nState {s} (position {pos}):")
        for a in range(env.n_actions):
            causal = env.do_action(s, a)
            observed = env.observed_transition_prob[a, s]
            diff = np.sum(np.abs(causal - observed))
            print(f"  action {a}: L1 gap between causal and observed = {diff:.4f}")
    # Interventional effect between pairs of actions (total variation distance)
    print("\n\nInterventional effect analysis:")
    test_state = 5  # inside the confounded region
    for a1 in range(env.n_actions):
        for a2 in range(a1 + 1, env.n_actions):
            effect = env.compute_causal_effect(test_state, a1, a2)
            print(f"  TV(do(A={a1}), do(A={a2})) = {effect:.4f}")
    # Policies planned with each transition model
    pi_causal = env.identify_optimal_policy(use_causal=True)
    pi_observed = env.identify_optimal_policy(use_causal=False)
    print("\nCausal policy:\n", pi_causal.reshape(env.size, env.size))
    print("Observational policy:\n", pi_observed.reshape(env.size, env.size))
    print("States where they differ:", np.flatnonzero(pi_causal != pi_observed).tolist())


if __name__ == "__main__":
    compare_policies()
```
9. Summary
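Key takeaways
- Advantages of causal MDPs: they model the causal mechanism explicitly and support interventional and counterfactual reasoning
- Identifying causal transitions: when the effect is identifiable (backdoor, front-door, or do-calculus), it can be computed from observational data
- Handling constraints: causal constraints are handled with Lagrangian (primal-dual) methods
- Causal POMDPs: extend causal inference to partially observable environments
Relation to the standard MDP
| Aspect | Standard MDP | Causal MDP |
|---|---|---|
| Transition function | $P(s' \mid s, a)$ | $P(s' \mid \mathrm{do}(a), s)$ |
| Value function | $V^{\pi}(s)$ | $V^{\pi}_{\mathrm{do}}(s)$ |
| Optimality | Optimal for the observational transition model | Optimal under interventions (causally stable) |
| Generalization | Within the training distribution | Potentially across environments that share the causal mechanism |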