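Causal Exploration Strategies

1. Revisiting the Exploration Problem

1.1 Limitations of Traditional Exploration

Traditional exploration strategies are driven by information gain, that is, by reducing uncertainty: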

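Strategy | Principle | Limitation
ε-greedy | Random exploration | Untargeted and inefficient
UCB | Optimistic upper bound | Considers only value-function uncertainty
Boltzmann | Probabilistic exploration | Depends on accurate Q-value estimates
Bayesian | Posterior sampling | Computationally expensive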

The fundamental problem: these strategies target predictive uncertainty, not causal effects.
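As a point of reference, three of the baselines in the table can be sketched for a single state with a vector of Q-values. This is a minimal illustration under assumed names (`epsilon_greedy`, `ucb`, `boltzmann_probs` are hypothetical helpers, and the Q-values are made up). Note that none of them asks whether an action actually changes the state.

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action, else the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

def ucb(q_values, counts, t, c=1.0):
    """Greedy action under an optimistic bonus that shrinks with the visit count."""
    counts = np.maximum(np.asarray(counts, dtype=float), 1e-8)
    bonus = c * np.sqrt(np.log(t + 1.0) / counts)
    return int(np.argmax(np.asarray(q_values, dtype=float) + bonus))

def boltzmann_probs(q_values, temperature=1.0):
    """Softmax over Q-values; a lower temperature gives greedier sampling."""
    z = (np.asarray(q_values, dtype=float) - np.max(q_values)) / temperature
    p = np.exp(z)
    return p / p.sum()

rng = np.random.default_rng(0)
q = [0.2, 1.0, 0.5]                                   # made-up Q-values
greedy = epsilon_greedy(q, epsilon=0.0, rng=rng)      # epsilon = 0 is purely greedy
optimistic = ucb(q, counts=[50, 50, 0], t=100)        # the unvisited action gets a large bonus
probs = boltzmann_probs(q, temperature=0.5)
```

All three rules score actions from value estimates and visit counts alone, which is the limitation the table points at.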

1.2 The Core Idea of Causal Exploration

Causal exploration asks: which actions have a real causal effect on the environment?

┌─────────────────────────────────────────────────────────────────┐
│                 A shift in the exploration goal                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   Traditional exploration:                                      │
│   "Which action reduces my predictive uncertainty?"             │
│                                                                 │
│   Causal exploration:                                           │
│   "Which action actually changes the state?"                    │
│   "Is there a real causal link between action and state?"       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

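1.3 Three Advantages of Causal Exploration

Advantage | Description | Example
Efficiency | Explores causal relations directly and avoids spurious correlations | Distinguishing sky color (spurious) from brake lights (causal)
Generalization | Causal knowledge transfers across environments | After learning "braking causes stopping", the agent adapts to brake lights of any color
Interpretability | The exploration process can be explained | "Braking is explored because it is causally related to stopping"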

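2. Causal Effect Discovery

2.1 Definition of the Causal Effect

Definition: the causal effect of executing action $a$ in state $s$ is

$$\mathrm{CE}(s, a) = \mathbb{E}[S' \mid do(A = a), S = s] - \mathbb{E}_{a' \sim \pi_{\text{rand}}}\big[\mathbb{E}[S' \mid do(A = a'), S = s]\big]$$

where $\pi_{\text{rand}}$ is a random policy.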
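A small worked example may help. It assumes the form stated in the section 6.1 docstring, CE(s, a) = E[S' | do(a), s] - E[S' | do(random), s], with the random policy taken to be uniform. The table `expected_next` and its numbers are made up for illustration.

```python
import numpy as np

def causal_effect(expected_next, s, a):
    """CE(s, a): expected next state under do(a) minus the average over uniform random actions.

    expected_next[s, a] is a hypothetical table of E[S' | do(a), s] vectors.
    """
    baseline = expected_next[s].mean(axis=0)   # E[S' | do(random), s] for a uniform policy
    return expected_next[s, a] - baseline

# One state, three actions, two-dimensional next state (made-up numbers).
expected_next = np.array([[[0.0, 0.0],     # action 0: no change
                           [0.0, 0.0],     # action 1: no change
                           [3.0, 0.0]]])   # action 2: moves the first coordinate

ce_noop = causal_effect(expected_next, s=0, a=0)
ce_move = causal_effect(expected_next, s=0, a=2)
```

Here the baseline is [1, 0], so the moving action has effect [2, 0] and the no-op has effect [-1, 0]. The no-op is not exactly zero because the reference is the random policy, not a "do nothing" action.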

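2.2 Estimating the Causal Effect

Sampling-based estimate, with $a'_j$ drawn from the random policy $\pi_{\text{rand}}$:

$$\widehat{\mathrm{CE}}(s, a) = \frac{1}{N}\sum_{i=1}^{N} s'_i - \frac{1}{N}\sum_{j=1}^{N} \tilde{s}'_j, \qquad s'_i \sim P(\cdot \mid do(a), s), \quad \tilde{s}'_j \sim P(\cdot \mid do(a'_j), s)$$

Model-based estimate:

$$\widehat{\mathrm{CE}}(s, a) = f_\theta(s, a) - \mathbb{E}_{a' \sim \pi_{\text{rand}}}\big[f_\theta(s, a')\big]$$

where $f_\theta$ is the learned causal effect predictor.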
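The sampling-based estimate can be sketched as a Monte Carlo difference of means, assuming access to a simulator that can be reset to state `s`. The names `sample_based_ce` and `toy_step`, and the toy dynamics, are assumptions for illustration only.

```python
import numpy as np

def sample_based_ce(step_fn, s, a, n_actions, n_samples=1000, seed=0):
    """Monte Carlo estimate of E[S' | do(a), s] - E[S' | do(random), s]."""
    rng = np.random.default_rng(seed)
    # Mean next state when the action is forced to a
    do_a = np.mean([step_fn(s, a, rng) for _ in range(n_samples)], axis=0)
    # Mean next state when the action is drawn uniformly at random
    do_rand = np.mean(
        [step_fn(s, int(rng.integers(n_actions)), rng) for _ in range(n_samples)],
        axis=0,
    )
    return do_a - do_rand

def toy_step(s, a, rng):
    """Hypothetical dynamics: action 1 shifts the state by +1, action 0 does nothing."""
    shift = 1.0 if a == 1 else 0.0
    return s + shift + rng.normal(0.0, 0.01)

ce = sample_based_ce(toy_step, s=0.0, a=1, n_actions=2)
```

With two equally likely actions the random baseline moves the state by about 0.5 on average, so the estimate for action 1 should land near 0.5.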

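2.3 Uncertainty of the Causal Effect

Effect variance: the variance of the estimated causal effect of action $a$ in state $s$ serves as the uncertainty measure.

High uncertainty means:

  • The effect of action $a$ on the state is unstable
  • More exploration is needed to understand the causal mechanism
  • Undiscovered causal paths may exist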
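One simple way to make this concrete is to repeat the same intervention and measure how much the observed effect varies. This is a sketch under assumptions: `effect_variance` and `noisy_step` are hypothetical, and the empirical variance is only one possible uncertainty measure.

```python
import numpy as np

def effect_variance(step_fn, s, a, n_samples=500, seed=0):
    """Empirical variance of the state change observed under repeated do(a) in state s."""
    rng = np.random.default_rng(seed)
    effects = np.array([step_fn(s, a, rng) - s for _ in range(n_samples)])
    return effects.var(axis=0)

def noisy_step(s, a, rng):
    """Hypothetical dynamics: action 0 has a deterministic effect, action 1 a noisy one."""
    noise = rng.normal(0.0, 1.0) if a == 1 else 0.0
    return s + 1.0 + noise

var_stable = effect_variance(noisy_step, s=0.0, a=0)
var_unstable = effect_variance(noisy_step, s=0.0, a=1)
```

Both actions shift the state by 1 on average, but only action 1 has a high effect variance, so it is the one a causal explorer would keep probing.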

3. Causal Exploration Algorithms

3.1 Causal-UCB

Causal Upper Confidence Bound (Causal-UCB)

Core idea: exploration reward = value-function upper bound + causal effect uncertainty
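The core idea can be sketched as a scoring rule. The exact bound is not given here, so this sketch assumes a UCB1-style count bonus for the value upper bound and adds the standard deviation of the causal effect as the uncertainty term. The function name and the weights `c` and `lam` are hypothetical.

```python
import numpy as np

def causal_ucb_action(q_values, counts, ce_std, t, c=1.0, lam=0.5):
    """Pick argmax of Q + count-based bonus + lam * causal-effect uncertainty."""
    q_values = np.asarray(q_values, dtype=float)
    counts = np.maximum(np.asarray(counts, dtype=float), 1e-8)
    ce_std = np.asarray(ce_std, dtype=float)
    value_bonus = c * np.sqrt(np.log(t + 1.0) / counts)   # assumed UCB1-style term
    scores = q_values + value_bonus + lam * ce_std
    return int(np.argmax(scores)), scores

# Two actions with equal value and equal visit counts; only the causal uncertainty differs.
action, scores = causal_ucb_action(
    q_values=[1.0, 1.0], counts=[10, 10], ce_std=[0.0, 2.0], t=20
)
```

A plain UCB rule would be indifferent between these two actions. The causal term breaks the tie in favor of the action whose effect is less well understood.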

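3.2 Causal Thompson Sampling

def causal_thompson_sampling(state, causal_model, estimate_q_value, actions, lambda_ce=0.1):
    # estimate_q_value(state, action) and the candidate `actions` are supplied by the caller;
    # lambda_ce weights the sampled causal effect against the Q-value
    # 1. Sample causal effects from the posterior
    ce_samples = causal_model.sample_causal_effect(state)

    # 2. Compute the sampled causal Q-value of every action
    sampled_values = {}
    for action in actions:
        q_value = estimate_q_value(state, action)
        ce_value = ce_samples[action]
        sampled_values[action] = q_value + lambda_ce * ce_value

    # 3. Select the action with the highest sampled value
    return max(sampled_values, key=sampled_values.get)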

3.3 Causal Information Gain

Causal Information Gain (CIG)

CIG measures how much our knowledge of the causal mechanism increases after executing action $a$.
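One common way to quantify such a gain is the expected drop in entropy of a belief over candidate causal hypotheses. The sketch below assumes that reading, with a finite set of hypotheses and discrete outcomes. The function names and the two-hypothesis example are illustrative, not a definition taken from this document.

```python
import numpy as np

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    p = np.asarray(p, dtype=float)
    nz = p[p > 0]
    return float(-(nz * np.log(nz)).sum())

def causal_information_gain(prior, likelihood):
    """Expected entropy reduction over hypotheses after observing the outcome of do(a).

    prior[h]         : current belief in causal hypothesis h
    likelihood[h, o] : P(outcome o | do(a), hypothesis h)
    """
    prior = np.asarray(prior, dtype=float)
    likelihood = np.asarray(likelihood, dtype=float)
    p_outcome = prior @ likelihood                     # marginal P(o | do(a))
    expected_posterior_entropy = 0.0
    for o, p_o in enumerate(p_outcome):
        if p_o > 0:
            posterior = prior * likelihood[:, o] / p_o  # Bayes update
            expected_posterior_entropy += p_o * entropy(posterior)
    return entropy(prior) - expected_posterior_entropy

prior = [0.5, 0.5]
cig_informative = causal_information_gain(prior, [[1.0, 0.0], [0.0, 1.0]])
cig_uninformative = causal_information_gain(prior, [[0.5, 0.5], [0.5, 0.5]])
```

An action whose outcome differs across hypotheses yields a gain of ln 2 here, while an action whose outcome is the same under both hypotheses yields none.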


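4. Counterfactual Reward Mechanism

4.1 Definition of the Counterfactual Reward

The counterfactual reward is an augmented reward that incorporates counterfactual reasoning:

$$R_{cf}(s, a, s') = R(s, a, s') + \lambda_{ce}\,\mathrm{CE}(s, a) + \lambda_{pe}\,\mathrm{PE}(s, a, \pi(s))$$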

4.2 Counterfactual Effect (PE)

The counterfactual effect $\mathrm{PE}(s, a, a')$ measures the causal difference between action $a$ and an alternative action $a'$.

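4.3 Design Principles for the Counterfactual Reward

Principle | Description | Implementation
Causally oriented | The reward should correlate positively with the causal effect
Counterfactual contrast | Encourage exploring causally distinct actions
Sparsity penalty | Penalize actions that have no causal effect

4.4 Complete Formula

$$R_{cf}(s, a, s') = R(s, a, s') + \lambda_{ce}\,\mathrm{CE}(s, a) + \lambda_{pe}\,\mathrm{PE}(s, a, \pi(s)) + \lambda_{H}\,\mathcal{H}\big(\pi(\cdot \mid s)\big)$$

where the last term encourages action diversity.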
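A short arithmetic sketch of how the terms combine, assuming the form in the section 6.1 docstring (base reward plus weighted CE and PE) with an added policy-entropy bonus as the diversity term. The function is hypothetical, its default weights mirror those of `CounterfactualRewardCalculator`, and the input numbers are made up.

```python
import numpy as np

def counterfactual_reward(base_reward, ce, pe, policy_probs,
                          lambda_ce=0.1, lambda_pe=0.05, lambda_entropy=0.01):
    """base + lambda_ce * CE + lambda_pe * PE + lambda_entropy * H(pi)."""
    probs = np.asarray(policy_probs, dtype=float)
    entropy = float(-(probs * np.log(probs + 1e-8)).sum())   # policy entropy in nats
    return base_reward + lambda_ce * ce + lambda_pe * pe + lambda_entropy * entropy

# Same transition scored under a deterministic and a uniform two-action policy.
r_deterministic = counterfactual_reward(1.0, ce=2.0, pe=4.0, policy_probs=[1.0, 0.0])
r_uniform = counterfactual_reward(1.0, ce=2.0, pe=4.0, policy_probs=[0.5, 0.5])
```

With these numbers the deterministic policy gets 1 + 0.2 + 0.2 = 1.4, and the uniform policy gets slightly more because of the entropy bonus.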


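5. Mathematical Framework of Causal Exploration

5.1 Optimization Objective

Objective: maximize the cumulative causal effect while minimizing the exploration cost.

5.2 Complexity of Causal Exploration

Lemma: if the environment satisfies the causal Markov assumption, then the sample complexity of learning the causal structure is bounded in terms of the VC dimension of the causal graph.

5.3 PAC Bound for Causal Exploration

Theorem (causal exploration PAC bound): with probability at least $1 - \delta$, the causal exploration policy $\pi$ satisfies a bound that involves the causal estimation error.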


6. PyTorch Implementation

6.1 Causal Exploration Model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Tuple, Optional, List
import numpy as np
 
class CausalEffectEstimator(nn.Module):
    """
    Causal effect estimator.
    Learns P(S' | do(A), S) rather than P(S' | A, S).
    """
    def __init__(self, state_dim: int, action_dim: int,
                 hidden_dim: int = 256, n_samples: int = 32):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.n_samples = n_samples

        # Shared state encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )

        # Action encoder
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim // 2),
            nn.ReLU()
        )

        # Causal effect predictor:
        # outputs the parameters of the state distribution after the intervention do(a)
        self.effect_predictor = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim * 2)  # mean and log-variance
        )

        # Scalar variance head for uncertainty (defined here but not used by forward())
        self.variance_estimator = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Softplus()  # keeps the variance positive
        )

    def forward(self, state: Tensor, action: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Predict the causal effect.

        Args:
            state: (batch, state_dim)
            action: (batch, action_dim), e.g. one-hot

        Returns:
            effect_mean: mean of the effect
            effect_var: variance of the effect (uncertainty)
        """
        s_enc = self.encoder(state)
        a_enc = self.action_encoder(action)

        combined = torch.cat([s_enc, a_enc], dim=-1)

        # Predict the parameters of the effect distribution
        effect_params = self.effect_predictor(combined)
        effect_mean = effect_params[:, :self.state_dim]
        effect_log_var = effect_params[:, self.state_dim:]
        # Clamp the log-variance so that exp() stays numerically stable
        effect_var = torch.exp(torch.clamp(effect_log_var, -10, 10))

        return effect_mean, effect_var

    def estimate_causal_effect(self, state: Tensor, action: Tensor,
                               random_action: Tensor) -> Tensor:
        """
        Estimate the causal effect.
        CE(s, a) = E[S' | do(a), s] - E[S' | do(random), s]
        Returns the norm of the difference, shape (batch, 1).
        """
        # Effect of action a
        effect_a_mean, _ = self.forward(state, action)

        # Effect of the random reference action
        effect_random_mean, _ = self.forward(state, random_action)

        # Causal effect = magnitude of the difference
        return torch.norm(effect_a_mean - effect_random_mean, dim=-1, keepdim=True)

    def compute_effect_uncertainty(self, state: Tensor,
                                   actions: Tensor) -> Tensor:
        """
        Compute the uncertainty of the causal effect, used to guide exploration.

        Args:
            state: (batch, state_dim)
            actions: (1 or batch, n_actions, action_dim) candidate actions

        Returns:
            (batch, n_actions) mean predicted variance per candidate action
        """
        uncertainties = []
        for a in range(actions.shape[1]):
            # Select candidate a and broadcast it across the state batch
            action = actions[:, a, :].expand(state.shape[0], -1)
            _, var = self.forward(state, action)
            uncertainties.append(var.mean(dim=-1))

        return torch.stack(uncertainties, dim=1)
 
 
class CausalExplorer:
    """
    Causal explorer.
    Combines causal effect estimates and their uncertainty for directed exploration.
    """
    def __init__(self, effect_estimator: CausalEffectEstimator,
                 q_network: nn.Module,
                 action_dim: int,
                 lambda_causal: float = 0.1,
                 lambda_explore: float = 0.5,
                 epsilon_start: float = 1.0,
                 epsilon_end: float = 0.01,
                 epsilon_decay: float = 0.995):

        self.effect_estimator = effect_estimator
        self.q_network = q_network
        self.action_dim = action_dim
        self.lambda_causal = lambda_causal
        self.lambda_explore = lambda_explore
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        # Per-action visit counts (the small constant avoids division by zero)
        self.action_counts = np.zeros(action_dim) + 1e-8

    def _record(self, action: int) -> None:
        """Update the visit count and decay epsilon after every selection."""
        self.action_counts[action] += 1
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

    def select_action(self, state: Tensor,
                     training: bool = True) -> Tuple[int, dict]:
        """
        Select an action by causal exploration.

        Returns:
            action: the selected action
            info: extra information (for analysis)
        """
        if training and np.random.rand() < self.epsilon:
            # Random exploration. Epsilon has to decay in this branch as well:
            # with epsilon_start = 1.0 the branch is always taken at first, so
            # decaying only in the greedy branch would never reduce epsilon.
            action = int(np.random.randint(self.action_dim))
            self._record(action)
            return action, {"mode": "random"}

        with torch.no_grad():
            state_batch = state.unsqueeze(0) if state.dim() == 1 else state
            candidates = torch.eye(self.action_dim, device=state_batch.device)

            # Uncertainty of every candidate action, computed once
            unc = self.effect_estimator.compute_effect_uncertainty(
                state_batch, candidates.unsqueeze(0)
            )
            uncertainties = unc[0].tolist()

            # Q-value and causal effect of every action
            q_values = []
            causal_effects = []
            for a in range(self.action_dim):
                # Q-value
                action_onehot = candidates[a:a + 1]
                q = self.q_network(state_batch, action_onehot)
                q_values.append(q.item())

                # Causal effect relative to one randomly drawn reference action
                ref = np.random.randint(self.action_dim)
                random_action = candidates[ref:ref + 1]
                ce = self.effect_estimator.estimate_causal_effect(
                    state_batch, action_onehot, random_action
                )
                causal_effects.append(ce.item())

            # Normalization statistics
            q_mean, q_std = np.mean(q_values), np.std(q_values) + 1e-8
            ce_mean, ce_std = np.mean(causal_effects), np.std(causal_effects) + 1e-8
            total_count = self.action_counts.sum()

            # Combined score
            scores = []
            for a in range(self.action_dim):
                q_norm = (q_values[a] - q_mean) / q_std
                ce_norm = (causal_effects[a] - ce_mean) / ce_std

                # UCB term
                ucb = np.sqrt(np.log(total_count + 1) / self.action_counts[a])

                score = q_norm + self.lambda_causal * ce_norm + self.lambda_explore * ucb
                scores.append(score)

            action = int(np.argmax(scores))
            self._record(action)

            return action, {
                "mode": "causal_ucb",
                "q_values": q_values,
                "causal_effects": causal_effects,
                "uncertainties": uncertainties,
                "scores": scores
            }
 
 
class CounterfactualRewardCalculator:
    """
    Counterfactual reward calculator.
    """
    def __init__(self, effect_estimator: CausalEffectEstimator,
                 lambda_ce: float = 0.1,
                 lambda_pe: float = 0.05,
                 lambda_entropy: float = 0.01):

        self.effect_estimator = effect_estimator
        self.lambda_ce = lambda_ce
        self.lambda_pe = lambda_pe
        self.lambda_entropy = lambda_entropy

    def compute_reward(self, state: Tensor, action: Tensor,
                      next_state: Tensor,
                      policy_action_dist: Optional[Tensor] = None) -> Tensor:
        """
        Compute the counterfactually augmented reward.

        R_cf = R(s,a,s') + λ_ce * CE(s,a) + λ_pe * PE(s,a,π(s)) + λ_entropy * H(π(·|s))

        `action` is a (batch, action_dim) one-hot float tensor and
        `policy_action_dist` a (batch, action_dim) probability tensor.
        """
        # Base reward (simplified to the magnitude of the state change)
        base_reward = torch.norm(next_state - state, dim=-1, keepdim=True)

        # 1. Causal effect term: one random one-hot reference action per row
        random_idx = torch.randint(
            action.shape[1], (action.shape[0], 1), device=action.device
        )
        random_action = torch.zeros_like(action).scatter_(1, random_idx, 1.0)

        ce = self.effect_estimator.estimate_causal_effect(
            state, action, random_action
        )

        # 2. Counterfactual effect term: difference between the taken action and
        #    each alternative action, weighted by the policy's probability of it
        pe = torch.zeros_like(ce)
        if policy_action_dist is not None:
            for a in range(action.shape[1]):
                counterfactual = torch.zeros_like(action)
                counterfactual[:, a] = 1.0

                pe_a = self.effect_estimator.estimate_causal_effect(
                    state, action, counterfactual
                )
                pe = pe + pe_a * policy_action_dist[:, a:a + 1]

        # 3. Policy entropy term (encourages action diversity)
        entropy_bonus = torch.zeros_like(ce)
        if policy_action_dist is not None:
            entropy_bonus = -(policy_action_dist * torch.log(policy_action_dist + 1e-8)).sum(
                dim=1, keepdim=True
            )

        # Combined reward; the signs follow the docstring formula above
        reward = (base_reward + self.lambda_ce * ce
                  + self.lambda_pe * pe + self.lambda_entropy * entropy_bonus)

        return reward

6.2 Causal Exploration DQN

import random

import torch.optim as optim


class CausalDQN:
    """
    Causal deep Q-network.
    A DQN variant that explores using causal effects.
    """
    def __init__(self, state_dim: int, action_dim: int,
                 hidden_dim: int = 256,
                 lr: float = 1e-3,
                 gamma: float = 0.99,
                 target_update_freq: int = 100,
                 exploration_steps: int = 1000,
                 batch_size: int = 64):

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.target_update_freq = target_update_freq
        self.exploration_steps = exploration_steps
        self.batch_size = batch_size
        self.total_steps = 0

        # Q-networks. CausalQNetwork is assumed to map
        # (state, one-hot action) to a (batch, 1) Q-value.
        self.q_network = CausalQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = CausalQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        # Causal effect estimator
        self.causal_estimator = CausalEffectEstimator(
            state_dim, action_dim, hidden_dim
        )

        # Explorer
        self.explorer = CausalExplorer(
            self.causal_estimator, self.q_network, action_dim
        )

        # One optimizer over both networks
        self.optimizer = optim.Adam(
            list(self.q_network.parameters()) +
            list(self.causal_estimator.parameters()),
            lr=lr
        )

        # Experience replay
        self.replay_buffer = ReplayBuffer(capacity=100000)

    def _one_hot(self, actions: Tensor) -> Tensor:
        """Convert (batch, 1) action indices to (batch, action_dim) one-hot floats."""
        return F.one_hot(actions.squeeze(-1), self.action_dim).float()

    def update_causal_model(self, states: Tensor, actions: Tensor,
                           next_states: Tensor) -> float:
        """
        Update the causal effect model.
        Learns P(S' | do(A), S). `actions` holds (batch, 1) action indices.
        """
        actions_onehot = self._one_hot(actions)

        # Random actions as the counterfactual reference
        random_idx = torch.randint(
            0, self.action_dim, (actions.shape[0], 1), device=actions.device
        )
        random_actions = self._one_hot(random_idx)

        # Predict the effects of do(actions) and do(random_actions)
        pred_actions, var_actions = self.causal_estimator(states, actions_onehot)
        pred_random, _ = self.causal_estimator(states, random_actions)

        # Causal effect loss
        causal_effect_target = next_states - states
        ce_loss = F.mse_loss(
            (pred_actions - pred_random),
            causal_effect_target.detach()
        )

        # Variance regularization: penalizes large predicted variance
        # (encourages deterministic effects)
        var_loss = 0.01 * torch.log(var_actions + 1e-8).mean()

        causal_loss = ce_loss + var_loss

        self.optimizer.zero_grad()
        causal_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.causal_estimator.parameters(), 1.0
        )
        self.optimizer.step()

        return causal_loss.item()

    def update_q_network(self, batch: dict) -> float:
        """
        Update the Q-network.
        """
        states = batch["states"]
        actions_onehot = self._one_hot(batch["actions"])
        rewards = batch["rewards"]
        next_states = batch["next_states"]
        dones = batch["dones"]

        # Compute the TD target
        with torch.no_grad():
            # Evaluate the target network for every next action and take the max
            next_qs = []
            for a in range(self.action_dim):
                a_onehot = torch.zeros_like(actions_onehot)
                a_onehot[:, a] = 1.0
                next_qs.append(self.target_network(next_states, a_onehot))
            next_q = torch.cat(next_qs, dim=1).max(dim=1, keepdim=True).values
            td_target = rewards + self.gamma * next_q * (1 - dones)

        # Current Q-value of the action that was taken
        current_q = self.q_network(states, actions_onehot)

        # Q-learning loss
        q_loss = F.mse_loss(current_q, td_target)

        self.optimizer.zero_grad()
        q_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10.0)
        self.optimizer.step()

        return q_loss.item()

    def train_step(self, state: np.ndarray, action: int,
                   reward: float, next_state: np.ndarray,
                   done: bool) -> dict:
        """
        Run one training step.
        """
        self.total_steps += 1

        # Store the transition
        self.replay_buffer.push(state, action, reward, next_state, done)

        # Update the causal model
        if len(self.replay_buffer) > self.batch_size:
            batch = self.replay_buffer.sample(self.batch_size)
            causal_loss = self.update_causal_model(
                batch["states"], batch["actions"], batch["next_states"]
            )
        else:
            causal_loss = 0.0

        # Update the Q-network once enough experience has been collected
        if len(self.replay_buffer) > self.exploration_steps:
            batch = self.replay_buffer.sample(self.batch_size)
            q_loss = self.update_q_network(batch)
        else:
            q_loss = 0.0

        # Sync the target network
        if self.total_steps % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        return {
            "causal_loss": causal_loss,
            "q_loss": q_loss,
            "epsilon": self.explorer.epsilon
        }

    def select_action(self, state: np.ndarray,
                     training: bool = True) -> Tuple[int, dict]:
        """
        Select an action.
        """
        state_tensor = torch.FloatTensor(state)
        return self.explorer.select_action(state_tensor, training)


class ReplayBuffer:
    """
    Experience replay buffer (fixed-capacity ring buffer).
    """
    def __init__(self, capacity: int):
        self.buffer = []
        self.capacity = capacity
        self.position = 0

    def push(self, state: np.ndarray, action: int,
            reward: float, next_state: np.ndarray, done: bool):
        """Add a transition, overwriting the oldest one when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)

        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size: int) -> dict:
        """Sample a batch uniformly at random."""
        batch = random.sample(self.buffer, batch_size)

        states = torch.FloatTensor(np.array([b[0] for b in batch]))
        actions = torch.LongTensor([[b[1]] for b in batch])       # (batch, 1) indices
        rewards = torch.FloatTensor([[b[2]] for b in batch])
        next_states = torch.FloatTensor(np.array([b[3] for b in batch]))
        dones = torch.FloatTensor([[float(b[4])] for b in batch])

        return {
            "states": states,
            "actions": actions,
            "rewards": rewards,
            "next_states": next_states,
            "dones": dones
        }

    def __len__(self):
        return len(self.buffer)
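CausalDQN relies on a CausalQNetwork class that this section does not define. A minimal sketch of what it could look like is given below, assuming it takes a state and a one-hot action and returns one Q-value per row, which is how `CausalExplorer.select_action` calls it. The architecture is an assumption, not the author's definition.

```python
import torch
import torch.nn as nn
from torch import Tensor


class CausalQNetwork(nn.Module):
    """Hypothetical Q(s, a) network: state and one-hot action in, (batch, 1) Q-value out."""

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state: Tensor, action_onehot: Tensor) -> Tensor:
        # Concatenate along the feature dimension and score the pair
        return self.net(torch.cat([state, action_onehot], dim=-1))


# Usage: score five (state, action) pairs at once.
q_net = CausalQNetwork(state_dim=4, action_dim=3, hidden_dim=16)
example_actions = torch.eye(3)[torch.tensor([0, 1, 2, 0, 1])]   # one-hot rows
example_q = q_net(torch.zeros(5, 4), example_actions)
```

Scoring one action at a time keeps the interface identical to the estimator's `forward(state, action)`, at the cost of one forward pass per candidate action.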

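7. Convergence Analysis

7.1 Regret Bound for Causal Exploration

Theorem (Causal-UCB regret bound): given the true causal effect, the cumulative regret of Causal-UCB satisfies a bound that depends on the number of actions and the causal estimation error.

7.2 PAC Property of Causal Exploration

Theorem (causal PAC exploration): with probability at least $1 - \delta$, the causal explorer finds an $\epsilon$-optimal policy within a bounded number of steps.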


8. Experiments and Analysis

8.1 Benchmark Environment

import numpy as np
import pandas as pd


class CausalExplorationBenchmark:
    """
    Benchmark environment for causal exploration.
    Used to evaluate different exploration strategies.

    The helper methods _init_causal_transition, _add_confounding, reset,
    step, select_action and compute_causal_effect are not shown here.
    """
    def __init__(self, n_states: int = 10, n_actions: int = 4,
                 causal_strength: float = 0.8,
                 confounding: float = 0.3):

        self.n_states = n_states
        self.n_actions = n_actions

        # Causal transition (strong causal effect)
        self.causal_transition = self._init_causal_transition(causal_strength)

        # Confounded transition (adds confounding factors)
        self.confounded_transition = self._add_confounding(
            self.causal_transition, confounding
        )

    def evaluate_strategy(self, strategy: str, n_episodes: int = 1000) -> dict:
        """
        Evaluate an exploration strategy.

        Args:
            strategy: 'random', 'ucb', 'thompson', 'causal_ucb', 'causal_thompson'
            n_episodes: number of evaluation episodes
        """
        rewards = []
        causal_discoveries = []

        for episode in range(n_episodes):
            state = self.reset()
            episode_reward = 0

            for step in range(100):
                # Select an action
                action = self.select_action(strategy, state)

                # Execute it (using the confounded transition)
                next_state, reward, done = self.step(state, action, use_causal=False)

                # Accumulate the reward
                episode_reward += reward

                # Causal discovery is only measured for the causal strategies
                if strategy.startswith('causal'):
                    ce = self.compute_causal_effect(state, action)
                    causal_discoveries.append(ce)

                state = next_state
                if done:
                    break

            rewards.append(episode_reward)

        return {
            "mean_reward": np.mean(rewards),
            "std_reward": np.std(rewards),
            "mean_causal_discovery": np.mean(causal_discoveries) if causal_discoveries else 0
        }

    def compare_strategies(self) -> pd.DataFrame:
        """Compare all strategies."""
        strategies = ['random', 'ucb', 'thompson',
                     'causal_ucb', 'causal_thompson']

        results = []
        for strategy in strategies:
            metrics = self.evaluate_strategy(strategy)
            results.append({
                "Strategy": strategy,
                "Mean Reward": metrics["mean_reward"],
                "Std Reward": metrics["std_reward"],
                "Causal Discovery": metrics["mean_causal_discovery"]
            })

        return pd.DataFrame(results)
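The benchmark calls `_init_causal_transition` and `_add_confounding` without showing them. One possible reading is sketched below as standalone functions. It is an assumption, not the benchmark's actual code: each action puts `causal_strength` of its probability mass on a dedicated target state, and confounding is modeled, as a simplification, by mixing in an action-independent background transition.

```python
import numpy as np

def init_causal_transition(n_states, n_actions, causal_strength):
    """P[s, a, s']: action a sends state s to (s + a + 1) % n_states with extra mass."""
    P = np.full((n_states, n_actions, n_states), (1.0 - causal_strength) / n_states)
    for s in range(n_states):
        for a in range(n_actions):
            P[s, a, (s + a + 1) % n_states] += causal_strength
    return P

def add_confounding(P, confounding):
    """Mix every row with the action-averaged transition, which ignores the action."""
    background = P.mean(axis=1, keepdims=True)    # shape (n_states, 1, n_states)
    return (1.0 - confounding) * P + confounding * background

causal_P = init_causal_transition(n_states=10, n_actions=4, causal_strength=0.8)
confounded_P = add_confounding(causal_P, confounding=0.3)
fully_confounded_P = add_confounding(causal_P, confounding=1.0)
```

At `confounding=1.0` every action yields the same next-state distribution, so no action has any causal effect left to discover. Intermediate values weaken the effect without removing it.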

8.2 Expected Results

Strategy | Mean reward | Causal discovery | Convergence speed
Random0
UCB
Thompson
Causal-UCB
Causal-Thompson

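9. Practical Applications

9.1 Robot Control

Scenario: a robotic arm grasping objects made of different materials

Problems with traditional exploration:
- Random exploration wastes time
- The agent may learn a wrong policy (e.g., inferring object position from background color)

Causal exploration approach:
- Quickly identifies which actions have a causal effect (e.g., closing the gripper vs. changing the background)
- Generalizes to new objects (causal knowledge transfers)

9.2 Recommender Systems

Scenario: learning user preferences

Problems with traditional exploration:
- May exploit spurious correlations (e.g., a user clicks a video because of its cover color)
- Hard to generalize to new users

Causal exploration approach:
- Separates users' true preferences (causal relations) from surface features (confounders)
- Transfers recommendation policies across users

9.3 Medical Treatment

Scenario: personalized treatment plans

Problems with traditional exploration:
- Hard to distinguish the treatment effect from spontaneous recovery
- Confounders (age, sex) interfere with decisions

Causal exploration approach:
- Estimates the causal effect of a treatment (ATE)
- Tailors treatment to different patient subgroups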

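10. Summary

Key Points

  1. Causal vs. traditional exploration: focus on causal effects rather than predictive uncertainty
  2. Counterfactual reward: strengthens the reward signal by combining the do-operator with counterfactual reasoning
  3. Causal-UCB: adds the variance of the causal effect to the upper confidence bound
  4. Convergence guarantees: PAC theory and regret bounds provide theoretical support

Algorithm Comparison

Algorithm | Causal effect | Uncertainty | Computational complexity
ε-greedy
UCB
Causal UCB
Thompson
Causal Thompson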
