1. 研究背景与问题定义

1.1 因果推断的现状

传统因果推断方法需要针对特定问题设计专门的估计器1

  • 需要领域专家定义因果图
  • 假设检验需要大量数据
  • 难以迁移到新领域

1.2 因果基础模型的愿景

因果基础模型(Causal Foundation Models, CFM)旨在:

通过**摊销推断(Amortized Inference)**实现跨领域的通用因果推断

1.3 部分图结构的价值

实际应用中,我们往往拥有部分已知的因果知识:

# 部分图结构示例
# 已知:饮食 → 胆固醇 → 心脏病
# 未知:其他因素的关系
 
partial_graph = {
    'diet': ['cholesterol'],
    'cholesterol': ['heart_disease'],
    'exercise': None,  # 未知
    'stress': None,   # 未知
}

2. 技术框架

2.1 核心思想

CFM的核心是利用部分图结构指导学习1

  1. 已知边:提供可靠的因果路径
  2. 未知边:通过数据学习
  3. 联合优化:同时利用先验和数据

2.2 架构设计

┌─────────────────────────────────────────────────────────────────────────┐
│                         CFM 整体架构                                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  输入:                                                                  │
│    - 部分因果图 G_partial                                              │
│    - 观测数据 D                                                        │
│                                                                          │
│    │                                                                     │
│    ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    因果图编码器                                      │    │
│  │                                                                 │    │
│  │   已知边: 直接传播因果信息                                        │    │
│  │   未知边: 学习表示                                                │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│    │                                                                     │
│    ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    摊销推断网络                                    │    │
│  │                                                                 │    │
│  │   输入: (x, G_partial)                                          │    │
│  │   输出: 因果效应估计                                              │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│    │                                                                     │
│    ▼                                                                     │
│  输出: 因果效应 CATE(x)                                               │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

3. 核心算法

3.1 图结构编码

class CausalGraphEncoder(nn.Module):
    """
    因果图编码器
    """
    def __init__(self, node_dim, hidden_dim):
        super().__init__()
        
        # 节点表示
        self.node_embedding = nn.Embedding(num_nodes, node_dim)
        
        # 边类型编码
        self.edge_type_embed = nn.Embedding(3, node_dim)  # 已知/未知/无
        
        # 图神经网络
        self.gnn = GraphAttentionNetwork(node_dim, hidden_dim)
        
    def forward(self, adj_matrix, edge_types):
        """
        Args:
            adj_matrix: 邻接矩阵 [N, N]
            edge_types: 边类型 [N, N] (0=无, 1=已知, 2=未知)
        """
        # 节点嵌入
        h = self.node_embedding.weight  # [N, node_dim]
        
        # 边类型嵌入
        edge_emb = self.edge_type_embed(edge_types)  # [N, N, node_dim]
        
        # GNN处理
        h = self.gnn(h, adj_matrix, edge_emb)
        
        return h

3.2 摊销推断网络

class AmortizedCausalInference(nn.Module):
    """
    摊销因果推断网络
    """
    def __init__(self, node_dim, hidden_dim, num_treatments):
        super().__init__()
        
        # 图编码器
        self.graph_encoder = CausalGraphEncoder(node_dim, hidden_dim)
        
        # 推断头
        self.treatment_encoder = nn.Linear(node_dim, hidden_dim)
        self.outcome_head = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(self, x, adj_matrix, edge_types, treatment, confounders):
        """
        Args:
            x: 协变量 [B, N, d]
            adj_matrix: 邻接矩阵
            edge_types: 边类型
            treatment: 处理变量 [B]
            confounders: 混杂因素 [B, N, d_c]
        """
        # 编码图结构
        graph_emb = self.graph_encoder(adj_matrix, edge_types)  # [N, hidden]
        
        # 聚合处理和混杂因素信息
        h_t = self.treatment_encoder(graph_emb[treatment])  # [B, hidden]
        h_c = confounders.mean(dim=1)  # [B, d_c]
        
        # 协变量表示
        h_x = x.mean(dim=1)  # [B, d]
        
        # 预测因果效应
        cate = self.outcome_head(torch.cat([h_x, h_t, h_c], dim=-1))
        
        return cate

4. 理论分析

4.1 识别条件

定理(部分图识别):设 是部分因果图,则因果效应可识别,如果:

  1. 处理变量到结果的所有后门路径都被阻断
  2. 前门路径可识别
  3. 未观察到的混杂通过部分图得到约束

4.2 泛化保证

定理(跨域泛化):设源域和目标域共享部分图结构 ,则:

其中 与部分图的知识量成反比。

5. 实验结果

5.1 基准测试

IHDP数据集

方法ATE误差置信区间覆盖
线性DAG0.1285%
DAG-GNN0.0888%
CausalForest0.0990%
CFM0.0594%

5.2 跨域泛化

从模拟数据迁移到真实数据

方法同分布跨分布
线性回归0.150.42
DAG-GNN0.080.25
CFM0.050.12

6. 总结

6.1 主要贡献

  1. 部分图利用:充分利用已知因果知识
  2. 摊销推断:快速因果效应估计
  3. 跨域泛化:从部分图结构获益

6.2 局限性

  1. 图结构质量:依赖部分图的准确性
  2. 计算复杂度:GNN计算开销
  3. 假设依赖:需要因果充分性假设

参考文献

Footnotes

  1. Causal Foundation Models with Partial Graphs, arXiv:2602.14972 2