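Causal Inverse Reinforcement Learning and Constraint Inference

1. Review of Inverse Reinforcement Learning Basics

1.1 The Standard IRL Problem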

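Forward problem: given an MDP $(S, A, P, R, \gamma)$, find the optimal policy $\pi^*$.

Inverse problem: given an MDP without its reward, $(S, A, P, \gamma)$, and expert demonstrations $\mathcal{D}$, recover the reward function $R$.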

┌─────────────────────────────────────────────────────────────────┐
│                 Inverse Reinforcement Learning                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   Forward:  (S, A, P, γ) + R   →  π*                            │
│             ↑                                                   │
│             │                                                   │
│             ▼                                                   │
│   Inverse:  (S, A, P, γ) + π*  →  R                             │
│                                                                 │
│   Issue: many R typically make π* optimal (non-identifiability) │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
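
The non-identifiability noted in the diagram can be seen on a toy problem. The sketch below is illustrative only: the two-state MDP, its transition table `P`, and the helper `greedy_policy` are assumptions made for this example. It runs value iteration under one reward and under a positively scaled and shifted copy of it, and both produce the same greedy policy.

```python
from typing import Dict, List, Tuple

# Hypothetical 2-state, 2-action deterministic MDP: P[s][a] is the next state.
P: Dict[int, Dict[int, int]] = {0: {0: 0, 1: 1}, 1: {0: 0, 1: 1}}
GAMMA = 0.9


def greedy_policy(reward: Dict[Tuple[int, int], float], n_iters: int = 500) -> List[int]:
    """Run value iteration, then return the greedy action for each state."""
    values = {s: 0.0 for s in P}
    for _ in range(n_iters):
        values = {
            s: max(reward[(s, a)] + GAMMA * values[P[s][a]] for a in P[s])
            for s in P
        }
    policy = []
    for s in sorted(P):
        q = {a: reward[(s, a)] + GAMMA * values[P[s][a]] for a in P[s]}
        policy.append(max(q, key=q.get))
    return policy


# A reward that pays 1 for choosing action 1 in either state.
R1 = {(0, 0): 0.0, (0, 1): 1.0, (1, 0): 0.0, (1, 1): 1.0}
# A positively scaled and shifted copy of R1.
R2 = {sa: 2.0 * r + 5.0 for sa, r in R1.items()}

policy_r1 = greedy_policy(R1)
policy_r2 = greedy_policy(R2)
print(policy_r1, policy_r2)
```

Observing the policy alone therefore cannot tell `R1` from `R2`, which is the ambiguity that any IRL method has to resolve with extra assumptions.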

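1.2 Limitations of Standard IRL

| Limitation | Description | Consequence |
|---|---|---|
| Reward non-identifiability | Multiple reward functions can explain the same policy | The true reward cannot be recovered |
| Confounding | Demonstrations may be influenced by confounders | Spurious correlations are learned |
| No causal account | Causal actions cannot be told apart from coincidental ones | Poor generalization |
| Missing constraints | The expert is assumed to be fully rational | Safety constraints cannot be handled |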

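1.3 Why Causal IRL Is Needed

Core ideas of causal IRL

  1. Causal reward recovery: learn the causal structure of the reward rather than surface correlations
  2. Constraint inference: infer safety constraints and preferences from demonstrations
  3. Counterfactual correction: correct the bias introduced by confounders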
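
To make the confounding point concrete, the following sketch compares a naive conditional estimate with a back-door-adjusted one. Everything here is hypothetical: a binary hidden context `u` influences both whether the expert takes action `a=1` and the outcome `y`, and the probability tables are invented for illustration. Back-door adjustment is used as one standard correction; the text above does not prescribe a specific method.

```python
# Hypothetical probability tables; u is a binary confounder.
P_U = {0: 0.5, 1: 0.5}                      # P(u)
P_A1_GIVEN_U = {0: 0.2, 1: 0.8}             # P(a=1 | u)
P_Y1_GIVEN_AU = {                           # P(y=1 | a, u): y ignores a here
    (0, 0): 0.1, (1, 0): 0.1,
    (0, 1): 0.9, (1, 1): 0.9,
}


def p_a_given_u(a, u):
    p1 = P_A1_GIVEN_U[u]
    return p1 if a == 1 else 1.0 - p1


def naive_p_y1(a):
    """P(y=1 | a): conditioning on a weights u by P(u | a)."""
    joint = {u: P_U[u] * p_a_given_u(a, u) for u in P_U}
    p_a = sum(joint.values())
    return sum(P_Y1_GIVEN_AU[(a, u)] * joint[u] / p_a for u in P_U)


def adjusted_p_y1(a):
    """P(y=1 | do(a)) via back-door adjustment: u is weighted by P(u)."""
    return sum(P_Y1_GIVEN_AU[(a, u)] * P_U[u] for u in P_U)


naive_effect = naive_p_y1(1) - naive_p_y1(0)
causal_effect = adjusted_p_y1(1) - adjusted_p_y1(0)
print(f"naive: {naive_effect:.2f}, adjusted: {causal_effect:.2f}")
```

The action has no effect on the outcome in these tables, yet the naive contrast is large because the expert tends to act when `u=1`. A reward fitted to the naive contrast would credit the action for something the context caused.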

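2. Causal Constraint Inference

2.1 The Constraint Inference Problem

Definition: given expert demonstrations $\mathcal{D}$ and an environment model, infer a constraint set $\mathcal{C}$ such that the expert policy $\pi_E$ satisfies these constraints and is explained as $\pi_E = \pi^*_{\mathcal{C}}$,

where $\pi^*_{\mathcal{C}}$ is the optimal policy that satisfies the constraints $\mathcal{C}$.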

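2.2 Constraint Types

| Constraint type | Mathematical form | Example |
|---|---|---|
| Safety constraint | $P(\text{safe} \mid do(\pi)) \geq 1-\epsilon$ | |
| Efficiency constraint | | Task completion time does not exceed $K$ |
| Preference constraint | $P(a_1 \mid s) > P(a_2 \mid s)$ | |
| Causal constraint | | Actions must have a causal effect |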
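
A safety constraint of the form $P(\text{safe} \mid do(\pi)) \geq 1-\epsilon$ can be checked empirically by rolling out the policy and counting safe episodes. The sketch below assumes a hypothetical list of per-episode boolean safety flags; `satisfies_safety_constraint` is an illustrative helper, and a point estimate like this ignores sampling error.

```python
from typing import Sequence


def satisfies_safety_constraint(safe_flags: Sequence[bool], epsilon: float) -> bool:
    """Return True if the empirical safe fraction is at least 1 - epsilon."""
    if not safe_flags:
        raise ValueError("need at least one rollout")
    safe_fraction = sum(safe_flags) / len(safe_flags)
    return safe_fraction >= 1.0 - epsilon


# Hypothetical outcome of 20 rollouts of a policy: 19 safe, 1 unsafe.
rollouts = [True] * 19 + [False]
ok_loose = satisfies_safety_constraint(rollouts, epsilon=0.10)   # needs >= 0.90
ok_strict = satisfies_safety_constraint(rollouts, epsilon=0.01)  # needs >= 0.99
```

With a safe fraction of 0.95, the same policy passes the looser threshold and fails the stricter one, so the choice of $\epsilon$ is part of the constraint itself.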

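2.3 Mathematical Framework for Constraint Inference

Optimization objective for constraint inference: the inferred constraint set minimizes the sum of two terms:

  • a constraint-violation loss
  • a regularization term

The constraint-violation loss is built from an indicator of whether each constraint is satisfied, and penalizes the cases where it is not.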
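
As a rough illustration of an objective of this shape, the sketch below scores a candidate constraint set by the fraction of demonstrated state-action pairs that violate it, plus a penalty on the number of constraints. The indicator-based loss, the cardinality regularizer, and the weight `lam` are assumptions chosen for the example; the text does not fix these choices.

```python
from typing import Callable, List, Tuple

StateAction = Tuple[float, float]            # (state, action), both scalars here
Constraint = Callable[[float, float], bool]  # returns True when satisfied


def violation_loss(constraints: List[Constraint], demos: List[StateAction]) -> float:
    """Fraction of (constraint, demonstration pair) checks that fail."""
    if not constraints or not demos:
        return 0.0
    failures = sum(not c(s, a) for c in constraints for s, a in demos)
    return failures / (len(constraints) * len(demos))


def objective(constraints: List[Constraint], demos: List[StateAction],
              lam: float = 0.1) -> float:
    """Violation loss plus a simple cardinality regularizer."""
    return violation_loss(constraints, demos) + lam * len(constraints)


# Hypothetical 1-D demonstrations and two candidate constraints.
demos = [(0.0, 0.5), (1.0, 0.4), (2.0, -0.3), (3.0, 0.9)]
speed_limit = lambda s, a: abs(a) <= 1.0   # satisfied by every demonstration
forward_only = lambda s, a: a >= 0.0       # violated by one demonstration

loss_a = objective([speed_limit], demos)
loss_b = objective([speed_limit, forward_only], demos)
```

Taken alone, this toy objective is minimized by the empty constraint set, which is one reason a constraint loss is usually also evaluated on non-expert trajectories.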


3. Inverse Constrained Reinforcement Learning (ICRL)

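3.1 Basic ICRL Framework

ICRL (Inverse Constrained RL) infers the reward function and the constraints jointly, with the policy restricted to $\Pi_{\mathcal{C}}$, the set of policies that satisfy the constraints $\mathcal{C}$.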

3.2 ICRL Algorithm Flow

┌─────────────────────────────────────────────────────────────────┐
│                       ICRL Algorithm Flow                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   1. Initialize the reward function R and the constraints C     │
│                                                                 │
│   2. Alternating optimization loop:                             │
│      ┌──────────────────────────────────────────────────────┐   │
│      │  a) Given R and C, solve constrained RL for π_C      │   │
│      │                                                      │   │
│      │  b) Given π_C, update C to minimize violations       │   │
│      │                                                      │   │
│      │  c) Given C, update R to maximize demo likelihood    │   │
│      └──────────────────────────────────────────────────────┘   │
│                                                                 │
│   3. Repeat until convergence                                   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

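3.3 Constraint Update Rule

Projected gradient descent

def update_constraints(constraints, demonstrations, policy, alpha=0.01):
    """
    Update the constraint parameters with one projected gradient step.

    compute_constraint_violations and project_to_constraints are
    placeholders that the surrounding code must supply; alpha is the
    step size.
    """
    # Violation signal used as the descent direction
    violations = compute_constraint_violations(constraints, demonstrations)
    
    # Gradient descent step (minimizes violation)
    constraints = constraints - alpha * violations
    
    # Project back onto the feasible set
    constraints = project_to_constraints(constraints)
    
    return constraints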
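
For a concrete, runnable instance of projected gradient descent, the sketch below minimizes a simple quadratic over the non-negative orthant. The objective, the step size `alpha`, and the helper names are illustrative assumptions; they stand in for the violation gradient and the feasible set, which the pseudocode above leaves unspecified.

```python
from typing import List


def gradient(theta: List[float], target: List[float]) -> List[float]:
    """Gradient of 0.5 * ||theta - target||^2 with respect to theta."""
    return [t - g for t, g in zip(theta, target)]


def project_nonnegative(theta: List[float]) -> List[float]:
    """Euclidean projection onto the set {theta : theta_i >= 0}."""
    return [max(0.0, t) for t in theta]


def projected_gradient_descent(theta: List[float], target: List[float],
                               alpha: float = 0.1, n_steps: int = 200) -> List[float]:
    for _ in range(n_steps):
        grad = gradient(theta, target)
        theta = [t - alpha * g for t, g in zip(theta, grad)]  # descent step
        theta = project_nonnegative(theta)                     # projection step
    return theta


# The unconstrained minimum is `target`; its negative entry is infeasible.
theta_star = projected_gradient_descent([1.0, 1.0], target=[2.0, -1.0])
print([round(t, 3) for t in theta_star])
```

The first coordinate converges to its unconstrained optimum, while the second is held at the boundary by the projection.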

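4. Mathematical Foundations of Causal IRL

4.1 Causal Reward Function

Definition: a causal reward function $R_c$ satisfies

$$R_c(s, a) = f(\mathrm{CE}(s, a))$$

where $\mathrm{CE}(s, a)$ is the causal effect on the state transition.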
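
A small numeric sketch may help fix the idea of a reward that is a function of a causal effect. The assumptions are loud ones: the dynamics are a hypothetical known linear model, the causal effect is taken as a contrast against a baseline action, and `f` is the absolute value, a scalar analogue of the norm used in the implementation in section 6.

```python
def expected_next_state(state: float, action: float) -> float:
    """Hypothetical known dynamics: E[s' | s, do(a)] = s + 0.5 * a."""
    return state + 0.5 * action


def causal_effect(state: float, action: float, baseline_action: float = 0.0) -> float:
    """Interventional contrast between `action` and a baseline action."""
    return expected_next_state(state, action) - expected_next_state(state, baseline_action)


def causal_reward(state: float, action: float) -> float:
    """R_c(s, a) = f(CE(s, a)) with f = absolute value."""
    return abs(causal_effect(state, action))


reward_active = causal_reward(3.0, 2.0)   # the action shifts the next state
reward_idle = causal_reward(3.0, 0.0)     # identical to the baseline action
```

An action that changes nothing relative to the baseline receives zero causal reward, regardless of how often it co-occurs with good outcomes in the data.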

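4.2 Optimization Objective of Causal IRL

Causal maximum-entropy IRL restricts the maximum-entropy IRL objective to rewards $R \in \mathcal{R}_c$,

where $\mathcal{R}_c$ is the space of causal reward functions.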
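
The maximum-entropy part of this objective has a well-known gradient: expert feature expectations minus model feature expectations. The sketch below fits it on a hypothetical, fully enumerated trajectory space; the features, the expert data, and the linear reward are all assumptions, and the restriction to a causal reward space is not modeled here.

```python
import math
from typing import List, Tuple

# Hypothetical trajectory space: each trajectory is summarized by a feature vector.
TRAJ_FEATURES: List[List[float]] = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Hypothetical expert data: indices into TRAJ_FEATURES.
EXPERT_TRAJS = [0, 0, 2, 1]


def trajectory_probs(theta: List[float]) -> List[float]:
    """MaxEnt model: P(tau) is proportional to exp(theta . f(tau))."""
    scores = [sum(t * f for t, f in zip(theta, feats)) for feats in TRAJ_FEATURES]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [w / z for w in weights]


def feature_expectation(probs: List[float]) -> List[float]:
    dim = len(TRAJ_FEATURES[0])
    return [sum(p * f[d] for p, f in zip(probs, TRAJ_FEATURES)) for d in range(dim)]


def fit_maxent(n_steps: int = 2000, lr: float = 0.5) -> Tuple[List[float], List[float]]:
    n = len(EXPERT_TRAJS)
    expert_probs = [EXPERT_TRAJS.count(i) / n for i in range(len(TRAJ_FEATURES))]
    mu_expert = feature_expectation(expert_probs)
    theta = [0.0, 0.0]
    for _ in range(n_steps):
        mu_model = feature_expectation(trajectory_probs(theta))
        # Log-likelihood gradient: expert minus model feature expectations.
        theta = [t + lr * (e - m) for t, e, m in zip(theta, mu_expert, mu_model)]
    return theta, mu_expert


theta_hat, mu_expert = fit_maxent()
mu_model = feature_expectation(trajectory_probs(theta_hat))
```

At convergence the model's feature expectations match the expert's, which is the stationarity condition of the maximum-entropy likelihood.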

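4.3 Identifying Causal Constraints

Theorem (identifiability of causal constraints): suppose the causal graph $\mathcal{G}$ is known. Then the constraint set $\mathcal{C}$ can be identified from expert demonstrations if:

  1. Sufficiency: $\mathcal{C}$ fully explains the expert's behavior
  2. Minimality: $\mathcal{C}$ is the smallest set that satisfies sufficiency
  3. Causal completeness: every causally relevant constraint is in $\mathcal{C}$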
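
One way to read the sufficiency and minimality conditions operationally is as a search for the smallest subset of candidate constraints that admits every expert state-action pair and rules out every pair the expert is known to avoid. The brute-force sketch below follows that reading, which is an interpretation made for this example; the candidate constraints and data are hypothetical, and causal completeness is not modeled.

```python
from itertools import combinations
from typing import Callable, Dict, List, Optional, Tuple

Pair = Tuple[float, float]  # (state, action)


def minimal_sufficient_set(
    candidates: Dict[str, Callable[[float, float], bool]],
    expert_pairs: List[Pair],
    avoided_pairs: List[Pair],
) -> Optional[Tuple[str, ...]]:
    """Smallest subset of candidates that admits every expert pair and
    rules out every avoided pair; None if no subset works."""
    names = sorted(candidates)
    for size in range(len(names) + 1):          # try small subsets first
        for subset in combinations(names, size):
            admits_expert = all(
                candidates[n](s, a) for n in subset for s, a in expert_pairs
            )
            excludes_avoided = all(
                any(not candidates[n](s, a) for n in subset)
                for s, a in avoided_pairs
            )
            if admits_expert and excludes_avoided:
                return subset
    return None


candidates = {
    "speed<=1": lambda s, a: abs(a) <= 1.0,
    "speed<=2": lambda s, a: abs(a) <= 2.0,
    "no_reverse": lambda s, a: a >= 0.0,
}
expert_pairs = [(0.0, 0.5), (1.0, 0.9)]
avoided_pairs = [(0.0, 1.5), (1.0, -0.2)]
result = minimal_sufficient_set(candidates, expert_pairs, avoided_pairs)
```

No single candidate excludes both avoided pairs, so the search settles on the two-constraint set; `speed<=2` is consistent with the expert but explains nothing and is left out.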

5. Preference-Based IRL

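5.1 Preference Learning Framework

Preference data

$$\mathcal{D}_{\text{pref}} = \{(s_i, a_i^1, a_i^2, y_i)\}_{i=1}^{N}$$

where $y_i \in \{0, 1\}$, and $y_i = 1$ means that $a_i^1$ is preferred over $a_i^2$.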

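5.2 Bradley-Terry Model

Preference probability

$$P(a_1 \succ a_2 \mid s) = \frac{\exp\left(\theta^\top \phi(s, a_1)\right)}{\exp\left(\theta^\top \phi(s, a_1)\right) + \exp\left(\theta^\top \phi(s, a_2)\right)}$$

where $\phi$ is the state-action feature map and $\theta$ is the preference parameter vector.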
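
The Bradley-Terry probability reduces to a sigmoid of the score difference, which the following sketch computes for a linear score $\theta^\top \phi$. The feature vectors and parameter values are hypothetical, and `bt_nll` is an illustrative helper for the corresponding negative log-likelihood.

```python
import math
from typing import List, Sequence, Tuple


def bt_preference_prob(theta: Sequence[float], phi_1: Sequence[float],
                       phi_2: Sequence[float]) -> float:
    """P(a1 preferred over a2 | s) under a linear Bradley-Terry model."""
    r1 = sum(t * x for t, x in zip(theta, phi_1))
    r2 = sum(t * x for t, x in zip(theta, phi_2))
    return 1.0 / (1.0 + math.exp(-(r1 - r2)))


def bt_nll(theta: Sequence[float],
           comparisons: List[Tuple[Sequence[float], Sequence[float], int]]) -> float:
    """Mean negative log-likelihood over (phi_1, phi_2, label) triples,
    where label 1 means the first action was preferred."""
    total = 0.0
    for phi_1, phi_2, label in comparisons:
        p = bt_preference_prob(theta, phi_1, phi_2)
        total -= math.log(p) if label == 1 else math.log(1.0 - p)
    return total / len(comparisons)


theta = [1.0, -0.5]
p_equal = bt_preference_prob(theta, [0.2, 0.4], [0.2, 0.4])   # identical features
p_better = bt_preference_prob(theta, [1.0, 0.0], [0.0, 0.0])  # score gap of 1
nll = bt_nll(theta, [([1.0, 0.0], [0.0, 0.0], 1)])
```

Identical features give a probability of one half, and only the score difference matters, so adding a constant to every score leaves the preferences unchanged.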

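5.3 Causal Preference Learning

Causal preference model: the preference probability is built from two quantities:

  • the causal effect of each action
  • the counterfactual contrast between the two actions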

6. PyTorch Implementation

6.1 Causal Constraint Inferrer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from typing import Tuple, List, Optional, Dict
import numpy as np

# Causal-effect estimator reused from an earlier module
from causal_effect_estimator import CausalEffectEstimator
 
class CausalConstraintInferrer(nn.Module):
    """
    Causal constraint inferrer.
    Learns causal constraints from expert demonstrations.
    """
    def __init__(self, state_dim: int, action_dim: int,
                 n_constraints: int = 4,
                 hidden_dim: int = 128):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.n_constraints = n_constraints
        
        # State encoder
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Action encoder
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim // 2),
            nn.ReLU()
        )
        
        # Constraint scoring network
        self.constraint_scorer = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_constraints),
            nn.Sigmoid()  # outputs constraint-satisfaction probabilities
        )
        
        # Causal-effect estimator (imported at module level above)
        self.causal_estimator = CausalEffectEstimator(
            state_dim, action_dim, hidden_dim
        )
    
    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        """
        Predict the probability that each constraint is satisfied.
        
        Returns:
            constraint_probs: (batch, n_constraints) satisfaction probability per constraint
        """
        s_enc = self.state_encoder(state)
        a_enc = self.action_encoder(action)
        
        combined = torch.cat([s_enc, a_enc], dim=-1)
        return self.constraint_scorer(combined)
    
    def compute_constraint_loss(self, states: Tensor, actions: Tensor,
                               expert_mask: Optional[Tensor] = None) -> Tensor:
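        """
        Compute the constraint-violation loss.
        
        Pushes expert actions to satisfy the constraints and non-expert
        actions to violate them. expert_mask is 1 for expert samples and
        0 otherwise.
        """
        constraint_probs = self.forward(states, actions)
        
        if expert_mask is not None:
            # Expert actions should satisfy the constraints: penalize (1 - p)
            expert_loss = (1 - constraint_probs) * expert_mask.unsqueeze(1)
            # Non-expert actions may violate the constraints: penalize p
            non_expert_loss = constraint_probs * (1 - expert_mask).unsqueeze(1)
            loss = (expert_loss + non_expert_loss).mean()
        else:
            # Unsupervised variant: encourage constraint satisfaction
            loss = -(constraint_probs.mean())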
        
        return loss
    
    def infer_constraints(self, demonstrations: List[Dict]) -> Dict[str, float]:
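        """
        Infer constraints from demonstrations.

        Returns the mean satisfaction probability of each constraint
        over all demonstrated state-action pairs.
        """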
        self.eval()
        
        with torch.no_grad():
            constraint_satisfactions = {i: [] for i in range(self.n_constraints)}
            
            for demo in demonstrations:
                states = torch.FloatTensor(demo["states"])
                actions = torch.FloatTensor(demo["actions"])
                
                probs = self.forward(states, actions)
                
                for c in range(self.n_constraints):
                    constraint_satisfactions[c].extend(probs[:, c].tolist())
        
        return {
            f"constraint_{c}": np.mean(satisfactions) 
            for c, satisfactions in constraint_satisfactions.items()
        }
 
 
class CausalIRL:
    """
    Causal inverse reinforcement learning.
    Combines constraint inference with causal reward learning.
    """
    def __init__(self, state_dim: int, action_dim: int,
                 n_constraints: int = 4,
                 hidden_dim: int = 256,
                 lr: float = 1e-3,
                 gamma: float = 0.99):
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.n_constraints = n_constraints
        self.gamma = gamma
        
        # Causal reward network
        self.reward_network = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # Constraint inferrer
        self.constraint_inferrer = CausalConstraintInferrer(
            state_dim, action_dim, n_constraints, hidden_dim
        )
        
        # Causal-effect estimator
        self.causal_estimator = CausalEffectEstimator(
            state_dim, action_dim, hidden_dim
        )
        
        # Discriminator (separates expert demonstrations from the learned policy)
        self.discriminator = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # Optimizer over all sub-networks
        self.optimizer = optim.Adam(
            list(self.reward_network.parameters()) +
            list(self.constraint_inferrer.parameters()) +
            list(self.causal_estimator.parameters()) +
            list(self.discriminator.parameters()),
            lr=lr
        )
    
    def compute_causal_reward(self, state: Tensor, 
                             action: Tensor) -> Tensor:
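        """
        Compute the causal reward.
        R_c(s, a) = f(CE(s, a))
        """
        # Estimate the causal effect
        effect_mean, effect_var = self.causal_estimator(state, action)
        
        # Use the magnitude of the causal effect as the reward
        causal_reward = torch.norm(effect_mean, dim=-1, keepdim=True)
        
        # Combine with the network prediction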
        sa_combined = torch.cat([state, action], dim=-1)
        network_reward = self.reward_network(sa_combined)
        
        return causal_reward + 0.1 * network_reward
    
    def compute_constraint_reward(self, state: Tensor,
                                  action: Tensor) -> Tensor:
        """
        Compute the constraint reward (penalizes constraint-violating actions).
        """
        constraint_probs = self.constraint_inferrer(state, action)
        
        # Low satisfaction probability = high penalty
        constraint_penalty = (1 - constraint_probs).mean(dim=-1, keepdim=True)
        
        return -constraint_penalty
    
    def update(self, demonstrations: List[Dict],
              generated_trajectories: List[Dict],
              lambda_constraint: float = 0.5) -> Dict[str, float]:
        """
        Update the causal IRL model.
        
        Args:
            demonstrations: list of expert demonstrations
            generated_trajectories: trajectories generated by the learned policy
            lambda_constraint: weight of the constraint loss
        """
        self.optimizer.zero_grad()
        
        total_loss = 0.0
        losses = {}
        
        # 1. Adversarial loss: train the discriminator to separate expert and generated trajectories
        expert_loss = 0.0
        generated_loss = 0.0
        
        for demo in demonstrations:
            states = torch.FloatTensor(demo["states"])
            actions = torch.FloatTensor(demo["actions"])
            
            sa_pairs = torch.cat([states, actions], dim=-1)
            expert_loss -= torch.log(self.discriminator(sa_pairs) + 1e-8).mean()
        
        for traj in generated_trajectories:
            states = torch.FloatTensor(traj["states"])
            actions = torch.FloatTensor(traj["actions"])
            
            sa_pairs = torch.cat([states, actions], dim=-1)
            generated_loss -= torch.log(1 - self.discriminator(sa_pairs) + 1e-8).mean()
        
        adversarial_loss = expert_loss + generated_loss
        losses["adversarial"] = adversarial_loss.item()
        
        # 2. Causal reward consistency loss
        reward_consistency_loss = 0.0
        for demo in demonstrations:
            states = torch.FloatTensor(demo["states"])
            actions = torch.FloatTensor(demo["actions"])
            
            # Expert actions should receive a high causal reward
            causal_reward = self.compute_causal_reward(states, actions)
            reward_consistency_loss -= causal_reward.mean()
        
        reward_consistency_loss /= max(len(demonstrations), 1)
        losses["reward_consistency"] = reward_consistency_loss.item()
        
        # 3. Constraint inference loss
        constraint_loss = 0.0
        for demo in demonstrations:
            states = torch.FloatTensor(demo["states"])
            actions = torch.FloatTensor(demo["actions"])
            
            # Mark every sample as an expert action
            expert_mask = torch.ones(len(demo["states"]))
            constraint_loss += self.constraint_inferrer.compute_constraint_loss(
                states, actions, expert_mask
            )
        
        constraint_loss /= max(len(demonstrations), 1)
        losses["constraint"] = constraint_loss.item()
        
        # Total loss
        total_loss = (adversarial_loss + 
                     reward_consistency_loss + 
                     lambda_constraint * constraint_loss)
        
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.optimizer.param_groups[0]["params"], 1.0)
        self.optimizer.step()
        
        return losses
    
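    def predict_reward(self, state: Tensor, action: Tensor,
                       lambda_constraint: float = 0.5) -> Tensor:
        """Predict the total reward: causal reward plus weighted constraint reward."""
        causal_reward = self.compute_causal_reward(state, action)
        # compute_constraint_reward already returns a non-positive penalty term
        constraint_reward = self.compute_constraint_reward(state, action)
        return causal_reward + lambda_constraint * constraint_reward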

6.2 ICRL Constraint Inference Implementation

class ICRLConstraintInference:
    """
    Inverse Constrained RL constraint inference.
    Implements the ICRL algorithm.
    """
    def __init__(self, env, irl_module: CausalIRL,
                 lambda_constraints: float = 1.0,
                 constraint_lr: float = 1e-2):
        
        self.env = env
        self.irl = irl_module
        self.lambda_constraints = lambda_constraints
        self.constraint_lr = constraint_lr
        
        # Constraint parameters: one weight per constraint of the inferrer
        n_constraints = irl_module.constraint_inferrer.n_constraints
        self.constraint_weights = nn.Parameter(torch.zeros(1, n_constraints))
        
        # Constraint optimizer
        self.constraint_optimizer = optim.Adam(
            [self.constraint_weights], lr=constraint_lr
        )
    
    def compute_constraint_violation(self, state: Tensor, 
                                     action: Tensor) -> Tensor:
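        """
        Compute the degree of constraint violation.
        
        Returns:
            violation: scalar constraint violation
        """
        constraint_probs = self.irl.constraint_inferrer(state, action)
        
        # Violation = 1 - satisfaction probability
        violations = 1 - constraint_probs
        
        # Weighted sum over constraints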
        weighted_violations = violations * torch.softmax(
            self.constraint_weights, dim=-1
        )
        
        return weighted_violations.sum(dim=-1).mean()
    
    def update_constraints(self, demonstrations: List[Dict],
                          n_iterations: int = 100) -> List[float]:
        """
        Iteratively update the constraint parameters.
        """
        violations_history = []
        
        for iteration in range(n_iterations):
            total_violation = 0.0
            
            for demo in demonstrations:
                states = torch.FloatTensor(demo["states"])
                actions = torch.FloatTensor(demo["actions"])
                
                # Compute the violation on this demonstration
                violation = self.compute_constraint_violation(states, actions)
                
                # Gradient descent step (minimizes violation)
                self.constraint_optimizer.zero_grad()
                violation.backward()
                self.constraint_optimizer.step()
                
                total_violation += violation.item()
            
            avg_violation = total_violation / max(len(demonstrations), 1)
            violations_history.append(avg_violation)
            
            # Project onto the non-negative orthant
            with torch.no_grad():
                self.constraint_weights.clamp_(min=0)
        
        return violations_history
    
    def combined_update(self, demonstrations: List[Dict],
                       generated_trajectories: List[Dict],
                       n_irl_iterations: int = 5,
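                       n_constraint_iterations: int = 10) -> Dict[str, list]:
        """
        Alternate between IRL updates and constraint updates.
        
        1. Fix the constraints, update IRL
        2. Fix IRL, update the constraints
        """
        all_losses = {"irl": [], "constraint": []}
        
        for outer_iter in range(n_irl_iterations):
            # Step 1: fix the constraints, update IRL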
            for _ in range(3):
                irl_losses = self.irl.update(
                    demonstrations, generated_trajectories,
                    self.lambda_constraints
                )
                all_losses["irl"].append(irl_losses)
            
            # Step 2: fix IRL, update the constraints
            constraint_violations = self.update_constraints(
                demonstrations, n_constraint_iterations
            )
            all_losses["constraint"].extend(constraint_violations)
        
        return all_losses
 
 
class PreferenceBasedCausalIRL:
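    """
    Preference-based causal IRL.
    Learns a causal reward and constraints from preference data.
    """
    def __init__(self, state_dim: int, action_dim: int,
                 hidden_dim: int = 256):
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # Causal reward network. CausalRewardNetwork is not defined in this
        # section; it is assumed to map (states, actions) to a reward and to
        # expose a `causal_estimator` attribute.
        self.reward_network = CausalRewardNetwork(state_dim, action_dim, hidden_dim)
        
        # Preference discriminator
        self.preference_discriminator = nn.Sequential(
            nn.Linear(state_dim + 2 * action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # Optimizer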
        self.optimizer = optim.Adam(self.reward_network.parameters(), lr=1e-3)
    
    def bradley_terry_loss(self, states: Tensor, 
                          action_pairs: Tensor, 
                          preferences: Tensor) -> Tensor:
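        """
        Bradley-Terry preference loss.
        
        Args:
            states: states, shape (batch, state_dim)
            action_pairs: action pairs, shape (batch, 2, action_dim)
            preferences: labels, shape (batch,); 1 if the first action is preferred, 0 if the second
        """
        a1, a2 = action_pairs[:, 0], action_pairs[:, 1]
        
        # Rewards of the two actions
        r1 = self.reward_network(states, a1)
        r2 = self.reward_network(states, a2)
        
        # The reward difference is the logit of P(a1 preferred over a2)
        diff = r1 - r2
        
        # Binary cross-entropy on the logits (numerically more stable than
        # a sigmoid followed by binary_cross_entropy)
        loss = F.binary_cross_entropy_with_logits(
            diff.squeeze(-1),
            preferences.float()
        )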
        
        return loss
    
    def counterfactual_preference_loss(self, states: Tensor,
                                      action_pairs: Tensor,
                                      preferences: Tensor) -> Tensor:
        """
        Counterfactual preference loss.
        """
        a1, a2 = action_pairs[:, 0], action_pairs[:, 1]
        
        # Causal effect of each action relative to the other
        ce1 = self.reward_network.causal_estimator.estimate_causal_effect(
            states, a1, a2
        )
        ce2 = self.reward_network.causal_estimator.estimate_causal_effect(
            states, a2, a1
        )
        
        # Preferences should agree with the causal effects:
        # if a1 is preferred, CE(a1) should exceed CE(a2)
        ce_diff = ce1 - ce2
        
        loss = F.margin_ranking_loss(
            ce_diff.squeeze(),
            torch.zeros_like(ce_diff.squeeze()),
            preferences.float() * 2 - 1,  # map {0, 1} to {-1, +1}
            margin=0.5
        )
        
        return loss
    
    def update(self, preferences: Dict) -> float:
        """
        Update the preference-based IRL model.
        """
        states = torch.FloatTensor(preferences["states"])
        action_pairs = torch.FloatTensor(preferences["action_pairs"])
        prefs = torch.LongTensor(preferences["preferences"])
        
        # Standard preference loss
        bt_loss = self.bradley_terry_loss(states, action_pairs, prefs)
        
        # Counterfactual preference loss
        cf_loss = self.counterfactual_preference_loss(
            states, action_pairs, prefs
        )
        
        # Total loss
        loss = bt_loss + 0.5 * cf_loss
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def predict_reward(self, state: Tensor, action: Tensor) -> Tensor:
        """Predict the reward."""
        return self.reward_network(state, action)

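7. Variants of Causal Constraint Inference

7.1 Safe IRL

Safe IRL focuses on inferring safety constraints.

7.2 Inverse Reward Design (IRD)

IRD infers preferences from uncertainty about the reward function.

7.3 Reward Inference from Demonstrations

RIDE infers rewards by comparing trajectories.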


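8. Practical Applications

8.1 Autonomous Driving

Scenario: learn safety constraints from human driving data

Constraint inference:
- Invariant constraints: always stay in lane, never run a red light
- Variable constraints: following distance, lane-change timing

Causal analysis:
- Braking is causally related to stopping
- Steering angle is causally related to trajectory curvature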

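8.2 Surgical Robots

Scenario: learn surgical constraints from surgeon demonstrations

Constraint inference:
- Safety constraints: do not damage specific tissue
- Precision constraints: positional accuracy of the surgical instrument
- Efficiency constraints: operating time

Causal analysis:
- Instrument motion is causally related to tissue response
- Applied force is causally related to tissue deformation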

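8.3 Industrial Automation

Scenario: learn robot constraints from the operations of skilled workers

Constraint inference:
- Efficiency constraints: task completion time
- Quality constraints: product quality
- Safety constraints: safe distance from workers

Causal analysis:
- Actions are causally related to product quality
- Actions are causally related to energy consumption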

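9. Convergence and Theoretical Guarantees

9.1 Convergence of ICRL

Theorem (ICRL convergence): the ICRL algorithm converges under the following conditions:

  1. The constraint space is convex
  2. The constraint-violation loss is convex
  3. The learning rate satisfies the Robbins-Monro conditions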
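
The Robbins-Monro conditions require step sizes $\alpha_t$ with $\sum_t \alpha_t = \infty$ and $\sum_t \alpha_t^2 < \infty$; $\alpha_t = 1/t$ is the textbook example. The sketch below applies that schedule to a toy stochastic-approximation problem, estimating a mean from noisy samples, where the iterate coincides with the running sample mean. The problem and names are illustrative assumptions and say nothing about how fast ICRL itself converges.

```python
import random
from typing import Sequence


def robbins_monro_mean(samples: Sequence[float]) -> float:
    """Iterate x_t = x_{t-1} + (1/t) * (sample_t - x_{t-1})."""
    x = 0.0
    for t, sample in enumerate(samples, start=1):
        alpha_t = 1.0 / t  # sum of alpha_t diverges, sum of alpha_t^2 converges
        x += alpha_t * (sample - x)
    return x


rng = random.Random(0)  # fixed seed so the run is reproducible
samples = [3.0 + rng.gauss(0.0, 1.0) for _ in range(10_000)]
estimate = robbins_monro_mean(samples)
sample_mean = sum(samples) / len(samples)
```

A constant step size would keep the iterate jittering around the target, while a schedule that decays too fast could stall before reaching it; the two conditions rule out both failure modes.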

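9.2 PAC Bound for Causal IRL

Theorem (sample complexity of causal IRL): with probability at least $1-\delta$, causal IRL converges to an $\epsilon$-optimal reward and constraint set within a finite number of samples.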


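10. Summary

Key points

  1. Constraint inference: learn implicit constraints from expert demonstrations rather than explicit rewards
  2. Causal IRL: strengthen IRL with causal effects and counterfactual reasoning
  3. Preference learning: infer the causal preference structure from preference comparisons
  4. Safety: ensure that the learned policy satisfies the inferred safety constraints

Algorithm comparison

| Method | Strengths | Weaknesses | Suitable for |
|---|---|---|---|
| Standard IRL | Solid theoretical foundation | Non-identifiability | Settings where the reward is recoverable |
| ICRL | Handles constraints | Computationally expensive | Safety-critical applications |
| Causal IRL | Strong generalization | Requires causal assumptions | Cross-environment transfer |
| Preference IRL | Data-efficient | Requires preference labels | Learning from human feedback |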
