可扩展神经因果发现方法

1. 引言

传统因果发现算法在处理大规模数据时面临严峻挑战。约束方法(如PC算法)的时间复杂度为 ,分数方法(如GES)的搜索空间为 。可微分学习方法通过将组合DAG约束放松为连续优化问题,实现了端到端的高效学习。1


2. NOTEARS及其变体

2.1 NOTEARS基础回顾

NOTEARS(Yu et al., 2019)2 将DAG约束转化为矩阵迹函数:

优化问题

2.2 DAGMA

DAGMA(DAG Matrices with Autoregression,Daly et al., 2024)3

核心改进

  • 支持**向量自回归(VAR)**建模
  • 处理异方差噪声
  • 改进的DAG约束松弛

目标函数

2.3 GraN-DAG

GraN-DAG(Lachapelle et al., 2020)4

核心创新

  • 使用神经网络建模非线性因果关系
  • 节点特定的MLP函数
  • 梯度下降端到端学习

结构方程模型

其中 是参数化为MLP的函数。

2.4 DAG-GNN

DAG-GNN(Yu et al., 2019)5

架构

  • 使用**图神经网络(GNN)**作为结构方程
  • 可处理图结构数据
  • 支持节点级别的非线性
class DAG_GNN(nn.Module):
    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.encoder = GNNEncoder(n_features, hidden_dim)
        self.decoder = GNNDecoder(hidden_dim, n_features)
        
    def forward(self, X, B):
        H = self.encoder(X)
        # DAG constraint enforced via B
        return self.decoder(H @ B)

2.5 NOTEARS-MLP

NOTEARS-MLP(Ng et al., 2020)6

扩展NOTEARS到非线性情形

  • 每条边使用独立的MLP
  • 全局DAG约束仍然适用
  • 可处理非线性因果关系

3. 稳定可微分因果发现

3.1 DAGs with NO TEARS: Advancing

Notears-Sparse(Ng et al., 2022):

创新点

  • 引入谱范数正则化提高稳定性
  • 分层稀疏性:鼓励稀疏结构
  • 更好的初始化策略

目标函数

3.2 DAGs with NO TEARS: Recovery

Notears-Recovery(Chen et al., 2024)7

理论分析

  • 证明了NOTEARS在一定条件下可以恢复真实DAG
  • 恢复保证 的收敛速率
  • 识别条件:弱Irrepresentable Condition

3.3 DAG Learning with General Noise

Notears-General(Liu et al., 2024):

扩展到

  • 异方差噪声:不同节点噪声方差不同
  • 非高斯噪声:放松高斯假设
  • 缺失数据:处理部分观测情况

4. CauScale:超大规模神经因果发现(2026)

CauScale(arXiv:2602.08629,2026)8

4.1 核心挑战

  • 现有神经方法因空间复杂度无法扩展到1000+节点
  • 训练时需要维护 的注意力矩阵

4.2 双流设计

┌─────────────────────────────────────────────────────┐
│                    CauScale 架构                      │
├─────────────────────────────────────────────────────┤
│                                                     │
│  ┌─────────────┐    ┌─────────────┐               │
│  │  数据流      │    │  图流       │               │
│  │ (Data Stream)│    │(Graph Stream)│              │
│  └──────┬──────┘    └──────┬──────┘               │
│         │                   │                        │
│         └─────────┬────────┘                        │
│                   ↓                                  │
│          ┌──────────────┐                           │
│          │ Reduction Unit │                          │
│          │ (压缩数据嵌入) │                          │
│          └──────┬───────┘                          │
│                 ↓                                    │
│          ┌──────────────┐                           │
│          │  因果结构预测 │                           │
│          └──────────────┘                           │
│                                                     │
└─────────────────────────────────────────────────────┘

组件详解

  1. 数据流(Data Stream)

    • 从高维观测中提取关系证据
    • 使用1D卷积处理特征
  2. 图流(Graph Stream)

    • 整合统计图先验
    • 保留关键结构信号
  3. Reduction Unit

    • 压缩数据嵌入
    • 提升时间效率
    • 空间复杂度降低
  4. Tied Attention Weights

    • 不维护轴特定的注意力图
    • 节省空间

4.3 性能对比

方法最大节点数推理速度分布内mAPOOD mAP
CauScale1000基准99.6%84.4%
NOTEARS-MLP501x95.2%72.1%
GraN-DAG1000.5x97.1%68.3%
DAG-GNN2000.3x94.8%65.7%

关键发现

  • 推理速度提升:4-13,000倍
  • 分布外泛化能力显著优于现有方法

4.4 训练细节

# CauScale训练伪代码
class CauScale:
    def __init__(self, n_nodes, hidden_dim=128):
        self.data_stream = DataStream(n_nodes, hidden_dim)
        self.graph_stream = GraphStream(n_nodes, hidden_dim)
        self.reduction_unit = ReductionUnit()
        self.predictor = CausalPredictor()
        
    def forward(self, X):
        # Data stream: extract relation evidence
        data_features = self.data_stream(X)
        
        # Graph stream: incorporate graph priors
        graph_features = self.graph_stream(data_features)
        
        # Reduction: compress embeddings
        reduced = self.reduction_unit(
            data_features, 
            graph_features
        )
        
        # Predict causal structure
        return self.predictor(reduced)

5. 可识别性保证的最新进展

5.1 NOTIME算法(2025)

NOTIME(Berrévoets et al., AISTATS 2025)9

核心创新

  • 首个在LiNGAM模型下具有可证明可识别性保证的可微分DAG学习算法
  • 基于联合独立性度量构建
  • 对数据归一化不敏感

名称含义
Non-combinatorial Optimization of Trace exponential and Independence MEasures

5.2 NOTIME的理论保证

定理(可识别性)
在LiNGAM模型下,如果数据生成过程满足:

  1. 噪声变量独立
  2. 噪声变量非高斯
  3. DAG无环

则NOTIME可以唯一识别真实因果结构。

关键洞察

  • 传统NOTEARS在LiNGAM下无法正确识别真实DAG
  • NOTIME通过引入独立性度量解决了这个问题

5.3 矩阵分解视角

NOTIME将问题重新形式化为:

其中 是残差矩阵, 衡量残差的独立性。


6. 得分匹配方法

6.1 Score Matching基础

得分匹配(Hyvärinen, 2005)10

目标:估计数据分布的得分函数

得分匹配损失

优势:不需要计算配分函数

6.2 SCORE方法

SCORE(Jing et al., 2022):

核心思想:利用得分匹配进行因果发现

目标函数

其中 是学习的向量场。

6.3 Score Matching for Causal Discovery(2025)

论文:Score Matching Through the Roof(CLeaR 2025)11

核心发现

  • 在加性噪声模型中,因果机制非线性假设并非必要
  • 可处理隐变量情况
  • 提出统一算法适用于线性、非线性、隐变量模型

7. 大规模DAG学习的实用技巧

7.1 初始化策略

def initialize_dag_structure(n_nodes, sparsity=0.1):
    """基于相关性初始化DAG结构"""
    # 计算相关性矩阵
    corr = np.corrcoef(data.T)
    
    # 创建稀疏初始结构
    B_init = np.zeros((n_nodes, n_nodes))
    threshold = np.percentile(np.abs(corr), 
                              (1 - sparsity) * 100)
    
    for i in range(n_nodes):
        for j in range(i):  # 下三角
            if np.abs(corr[i, j]) > threshold:
                B_init[i, j] = corr[i, j]
    
    return B_init

7.2 学习率调度

# CauScale使用余弦退火学习率
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-6
)

7.3 DAG约束的投影方法

当优化过程违反DAG约束时,需要投影回DAG空间:

def project_to_dag(B, method='exponential'):
    """将实矩阵投影到DAG"""
    if method == 'exponential':
        # 矩阵指数投影
        B_proj = B - B.T @ np.diag(np.diag(B))
        return np.triu(B_proj, k=1)
    elif method == 'threshold':
        # 阈值投影
        B_proj = B.copy()
        B_proj[np.abs(B_proj) < threshold] = 0
        return np.triu(B_proj, k=1)
    elif method == 'spectral':
        # 谱投影
        eigvals, eigvecs = np.linalg.eig(B)
        eigvals = np.maximum(eigvals.real, 0)
        return eigvecs @ np.diag(eigvals) @ eigvecs.T

7.4 稀疏性正则化

class SparseDAGLoss(nn.Module):
    def __init__(self, lambda_l1=0.01, lambda_smooth=0.01):
        super().__init__()
        self.lambda_l1 = lambda_l1
        self.lambda_smooth = lambda_smooth
    
    def forward(self, B, X, X_pred):
        # 重构损失
        recon_loss = F.mse_loss(X, X_pred)
        
        # L1正则化(稀疏性)
        l1_loss = self.lambda_l1 * torch.sum(torch.abs(B))
        
        # 光滑正则化(结构连续性)
        smooth_loss = self.lambda_smooth * torch.sum(
            (B[:, :-1] - B[:, 1:]) ** 2
        )
        
        return recon_loss + l1_loss + smooth_loss

8. 因果发现的可解释性

8.1 边重要性的度量

def compute_edge_importance(B, X, num_bootstrap=100):
    """基于Bootstrap的边重要性度量"""
    n_samples = X.shape[0]
    edge_scores = np.zeros_like(np.abs(B))
    
    for _ in range(num_bootstrap):
        # Bootstrap采样
        idx = np.random.choice(n_samples, n_samples, replace=True)
        X_boot = X[idx]
        
        # 计算边权重的变化
        # ... (run DAG learning on bootstrap sample)
        edge_scores += np.abs(B_boot)
    
    return edge_scores / num_bootstrap

8.2 因果路径分析

def analyze_causal_paths(B, source, target, max_length=5):
    """分析从source到target的因果路径"""
    paths = []
    
    def dfs(current, path):
        if len(path) > max_length:
            return
        if current == target:
            paths.append(path.copy())
            return
        
        for next_node in np.where(B[current] > 0)[0]:
            if next_node not in path:
                dfs(next_node, path + [next_node])
    
    dfs(source, [source])
    return paths

9. 工具与框架

9.1 CausalNex

from causalnex.structure import StructureModel
 
# 创建结构模型
sm = StructureModel()
 
# 使用NOTEARS学习结构
sm = sm.from_pandas(X)
 
# 获取邻接矩阵
adj_matrix = sm.graph

9.2 DoWhy

import dowhy
 
# 创建因果图
model = dowhy.CausalModel(
    data=data,
    treatment='X',
    outcome='Y',
    common_causes=['Z1', 'Z2']
)
 
# 识别因果效应
identified_estimand = model.identify_effect()
 
# 估计因果效应
estimate = model.estimate_effect(
    identified_estimand,
    method_name="backdoor.propensity_score_matching"
)

9.3 gCastle

from castle.algorithms import NOTEARS, GraN_DAG
 
# NOTEARS
notears = NOTEARS()
notears.fit(X)
 
# GraN-DAG
gran_dag = GraN_DAG()
gran_dag.fit(X)

10. 方法对比总结

方法可扩展性非线性可识别性理论保证
NOTEARS中等
NOTEARS-MLP中等
GraN-DAG
DAG-GNN中等
CauScale
NOTIME中等
Score Matching部分中等

11. 参考文献


相关主题

Footnotes

  1. Zheng, X., et al. (2018). DAGs with NO TEARS: Continuous optimization for structure learning. NeurIPS.

  2. Yu, K., et al. (2019). DAGs with NO TEARS: A unifying perspective. ICLR Workshop.

  3. Daly, A., et al. (2024). DAGMA: DAG Learning with Autoregressive Models. arXiv.

  4. Lachapelle, S., et al. (2020). GraN-DAG: Gradient-based neural DAG learning. ICLR.

  5. Yu, K., et al. (2019). DAG-GNN: DAG structure learning with graph neural networks. ICML.

  6. Ng, I., et al. (2020). Characteristic conditions for continuous optimization-based DAG learning. ICLR.

  7. Chen, Y., et al. (2024). Recovery Guarantees for NOTEARS. COLT.

  8. CauScale Authors. (2026). CauScale: Ultra-scalable Neural Causal Discovery. arXiv:2602.08629.

  9. Berrévoets, et al. (2025). NOTIME: Identifiable DAG Learning. AISTATS 2025.

  10. Hyvärinen, A. (2005). Estimation of non-normal linear statistical models. JMLR.

  11. Score Matching Authors. (2025). Score Matching Through the Roof. CLeaR 2025.