因果发现基准与评估方法

1. 引言

因果发现算法的评估面临独特挑战:ground truth因果结构通常未知,评估只能在合成数据或半真实数据上进行。1

1.1 评估挑战

挑战描述
Ground Truth缺失真实因果关系难以获得
方法假设不同不同方法基于不同假设
数据生成偏差合成数据可能偏离真实场景
分布偏移实验室条件≠实际应用

1.2 评估框架

因果发现评估框架:

┌──────────────────────────────────────────────────┐
│                   数据生成                        │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐        │
│  │ 合成数据 │  │ 半真实  │  │ 真实数据 │        │
│  │ (DAG)   │  │ (扰动)  │  │ (专家)   │        │
│  └────┬────┘  └────┬────┘  └────┬────┘        │
└────────┼────────────┼────────────┼─────────────┘
         │            │            │
         ↓            ↓            ↓
┌──────────────────────────────────────────────────┐
│               因果发现算法                        │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐        │
│  │ 约束方法 │  │ 分数方法 │  │ 深度方法 │        │
│  └────┬────┘  └────┬────┘  └────┬────┘        │
└────────┼────────────┼────────────┼─────────────┘
         │            │            │
         ↓            ↓            ↓
┌──────────────────────────────────────────────────┐
│                   评估指标                        │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐        │
│  │ 结构精度 │  │ 效应估计 │  │ 计算效率 │        │
│  └─────────┘  └─────────┘  └─────────┘        │
└──────────────────────────────────────────────────┘

2. 评估指标

2.1 结构精度指标

Structural Hamming Distance (SHD)

定义:与真实因果图(或CPDAG)的边差异数

其中 表示对称差集运算。

计算

  • 缺失的边:每条计1
  • 多余的边:每条计1
  • 方向错误的边:计1(CPDAG情况)
def calculate_shd(true_graph, pred_graph, is_dag=True):
    """
    计算SHD
    
    Args:
        true_graph: 真实邻接矩阵
        pred_graph: 预测邻接矩阵
        is_dag: 是否为DAG(否则为CPDAG)
    
    Returns:
        shd: Structural Hamming Distance
    """
    if is_dag:
        # DAG情况:直接比较
        diff = np.sum(np.abs(true_graph - pred_graph) > 0.5)
    else:
        # CPDAG情况:考虑等价类
        true_skeleton = (np.abs(true_graph) > 0.5).astype(int)
        pred_skeleton = (np.abs(pred_graph) > 0.5).astype(int)
        
        # 骨架差异
        skeleton_diff = np.sum(
            np.abs(true_skeleton - pred_skeleton) > 0
        )
        
        # 方向差异(仅在骨架正确时计算)
        direction_diff = 0
        correct_skeleton = (true_skeleton == pred_skeleton)
        direction_diff = np.sum(
            (true_graph != pred_graph) & correct_skeleton
        ) / 2  # 每条边计一次
        
        diff = skeleton_diff + direction_diff
    
    return diff

F1-Score

定义:边恢复的精确率和召回率的调和平均

其中:

def calculate_f1(true_graph, pred_graph, threshold=0.5):
    """计算边F1-Score"""
    true_edges = set(zip(*np.where(np.abs(true_graph) > threshold)))
    pred_edges = set(zip(*np.where(np.abs(pred_graph) > threshold)))
    
    tp = len(true_edges & pred_edges)
    fp = len(pred_edges - true_edges)
    fn = len(true_edges - pred_edges)
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    return f1, precision, recall

Matthews Correlation Coefficient (MCC)

定义:考虑所有四种预测情况(TP/TN/FP/FN)的平衡指标

优势:在类别不平衡时比F1更可靠

def calculate_mcc(true_graph, pred_graph, threshold=0.5):
    """计算MCC"""
    true_binary = (np.abs(true_graph) > threshold).astype(int)
    pred_binary = (np.abs(pred_graph) > threshold).astype(int)
    
    # 展平为1D
    y_true = true_binary.flatten()
    y_pred = pred_binary.flatten()
    
    tp = np.sum((y_true == 1) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    
    numerator = tp * tn - fp * fn
    denominator = np.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    )
    
    mcc = numerator / denominator if denominator > 0 else 0
    
    return mcc

IoU (Intersection over Union)

定义:边集合的Jaccard相似度

2.2 因果效应估计指标

Average Treatment Effect (ATE) 误差

定义:因果效应估计的相对误差

其中 是真实ATE, 是估计ATE。

def calculate_ate_error(true_ate, estimated_ate):
    """计算ATE相对误差"""
    if np.abs(true_ate) < 1e-6:
        return np.abs(estimated_ate - true_ate)
    return np.abs(estimated_ate - true_ate) / np.abs(true_ate)

Precision in Estimation of Heterogeneous Effects (PEHE)

定义:异质性因果效应的估计精度

其中 是个体因果效应。

2.3 计算效率指标

指标描述
运行时间算法执行所需时间
内存峰值最大内存占用
参数数量需要调整的超参数数量
可扩展性运行时间随数据规模的增长

3. 基准测试数据集

3.1 经典基准数据集

Sachs数据集

特性
节点数11
边数17
数据类型蛋白质信号
观测数~750
特点真实生物数据,存在未观测混杂

Tuebingen因果基准

特性描述
规模多种规模(10-100节点)
类型合成+半真实
噪声多种噪声水平
用途方法比较

3.2 综合评估基准

OCDB (Open Causal Discovery Benchmark)

OCDB(Open Causal Discovery Benchmark,arXiv:2406.04598)2

目标:提供公平、全面的因果发现方法评估

特点

  • 基于真实数据的基准
  • 评估指标:因果结构和因果效应
  • 考虑无向边,公平比较DAG和CPDAG

基准设置

class OCDBBenchmark:
    def __init__(self, dataset='synthetic'):
        self.datasets = {
            'linear_gaussian': LinearGaussianGenerator(),
            'nonlinear': NonlinearGenerator(),
            'hidden_confounder': HiddenConfounderGenerator(),
            'intervention': InterventionGenerator()
        }
    
    def evaluate(self, method, dataset_name):
        """评估因果发现方法"""
        # 生成数据
        X, true_graph, true_effect = self.datasets[dataset_name].generate()
        
        # 运行方法
        pred_graph = method.fit_predict(X)
        
        # 计算指标
        metrics = {
            'shd': calculate_shd(true_graph, pred_graph),
            'f1': calculate_f1(true_graph, pred_graph)[0],
            'mcc': calculate_mcc(true_graph, pred_graph),
            'ate_error': calculate_ate_error(true_effect, pred_effect)
        }
        
        return metrics

CausalTime基准(ICLR 2024)

CausalTime3

目标:时序因果发现的综合评估

特点

  • 基于动力学因果模型生成
  • 包含分布偏移场景
  • 覆盖连续、离散、混合数据
class CausalTimeBenchmark:
    def __init__(self):
        self.causal_mechanisms = [
            'linear', 'polynomial', 'sigmoid', 'sinusoidal'
        ]
        self.noise_types = ['gaussian', 'laplace', 'uniform']
    
    def generate(self, n_vars, n_samples, mechanism='linear'):
        """生成时序因果数据"""
        # 1. 生成因果图
        adj_matrix = generate_random_dag(n_vars, edge_prob=0.2)
        
        # 2. 生成时间序列
        time_series = generate_time_series(
            adj_matrix,
            n_samples=n_samples,
            mechanism=mechanism,
            time_lags=[1, 2, 3]
        )
        
        return time_series, adj_matrix

3.3 大规模时序基准

CausalRivers基准(ICLR 2025)

CausalRivers4

特点

  • 最大规模:666节点(德国东部)、494节点(巴伐利亚)
  • 真实数据:水文、气象数据
  • 时间跨度:2019-2023年,15分钟分辨率
  • 分布偏移:包含洪水事件等极端场景
CausalRivers基准统计:

┌─────────────────┬────────────┬────────────┐
│     特性        │  德国东部   │   巴伐利亚  │
├─────────────────┼────────────┼────────────┤
│ 节点数          │    666     │    494     │
│ 时间步数        │   ~140,000 │   ~140,000 │
│ 时间跨度        │  2019-2023 │  2019-2023 │
│ 分辨率          │   15分钟    │   15分钟   │
│ 事件数          │    15      │     12     │
│ 分布偏移        │    洪水     │   季节变化 │
└─────────────────┴────────────┴────────────┘

4. 数据生成方法

4.1 DAG生成

Erdős–Rényi模型

def generate_er_dag(n_nodes, edge_prob):
    """生成Erdős–Rényi随机DAG"""
    # 随机上三角矩阵
    B = np.triu(np.random.rand(n_nodes, n_nodes) < edge_prob, k=1)
    
    # 确保为DAG(通过拓扑排序重排)
    perm = np.random.permutation(n_nodes)
    B = B[perm][:, perm]
    
    return B

Scale-Free模型

def generate_scale_free_dag(n_nodes, alpha=2.5):
    """生成无标度DAG(Barabási-Albert模型)"""
    # 初始种子
    G = nx.DiGraph()
    G.add_edge(0, 1)
    
    # 逐步添加节点
    for new_node in range(2, n_nodes):
        # 计算度偏好概率
        degrees = np.array([G.in_degree(i) + G.out_degree(i) 
                          for i in range(new_node)])
        probs = degrees ** alpha
        probs /= probs.sum()
        
        # 选择连接目标
        n_edges = np.random.poisson(3) + 1
        targets = np.random.choice(
            new_node, size=min(n_edges, new_node), 
            replace=False, p=probs
        )
        
        # 添加边
        for target in targets:
            G.add_edge(new_node, target)
    
    # 转换为邻接矩阵
    return nx.to_numpy_array(G, nodelist=range(n_nodes))

4.2 数据生成过程

线性高斯SEM

def generate_linear_gaussian_data(B, n_samples, noise_std=1.0):
    """生成线性高斯数据"""
    n_nodes = B.shape[0]
    
    # 确保B是严格下三角(DAG)
    B = np.tril(B, k=-1)
    
    # 拓扑排序
    order = np.argsort([0 if np.sum(B[:, i]) == 0 else 1 
                        for i in range(n_nodes)])
    B = B[order][:, order]
    
    # 生成数据
    X = np.zeros((n_samples, n_nodes))
    for i in range(n_nodes):
        parents = np.where(B[:i, i])[0]
        noise = np.random.randn(n_samples) * noise_std
        X[:, i] = X[:, parents] @ B[parents, i] + noise
    
    # 恢复原始顺序
    X = X[:, np.argsort(order)]
    
    return X

非线性SEM

def generate_nonlinear_data(B, n_samples, noise_std=1.0):
    """生成非线性数据"""
    n_nodes = B.shape[0]
    B = np.tril(B, k=-1)
    
    X = np.zeros((n_samples, n_nodes))
    for i in range(n_nodes):
        parents = np.where(B[:i, i])[0]
        
        if len(parents) == 0:
            X[:, i] = np.random.randn(n_samples) * noise_std
        elif len(parents) == 1:
            # 单父节点:非线性变换
            X[:, i] = np.tanh(X[:, parents[0]]) + \
                      np.random.randn(n_samples) * noise_std
        else:
            # 多父节点:加性组合
            effect = np.sum([
                np.sin(X[:, p]) for p in parents
            ], axis=0)
            X[:, i] = effect + np.random.randn(n_samples) * noise_std
    
    return X

5. 评估协议

5.1 标准化评估流程

class CausalDiscoveryEvaluator:
    def __init__(self, metrics=['shd', 'f1', 'mcc', ' runtime']):
        self.metrics = metrics
        self.results = {}
    
    def run_evaluation(self, method, datasets, n_runs=10):
        """标准化评估流程"""
        all_results = []
        
        for dataset_name, generator in datasets.items():
            run_results = []
            
            for run in range(n_runs):
                # 生成数据
                X, true_graph = generator.generate()
                
                # 运行方法
                start_time = time.time()
                pred_graph = method.fit_predict(X)
                runtime = time.time() - start_time
                
                # 计算指标
                metrics = self._compute_metrics(true_graph, pred_graph)
                metrics['runtime'] = runtime
                
                run_results.append(metrics)
            
            # 汇总结果
            dataset_results = self._aggregate_results(run_results)
            all_results.append(dataset_results)
        
        return self._format_results(all_results)
    
    def _compute_metrics(self, true_graph, pred_graph):
        """计算所有指标"""
        metrics = {}
        
        if 'shd' in self.metrics:
            metrics['shd'] = calculate_shd(true_graph, pred_graph)
        
        if 'f1' in self.metrics:
            f1, prec, rec = calculate_f1(true_graph, pred_graph)
            metrics['f1'] = f1
            metrics['precision'] = prec
            metrics['recall'] = rec
        
        if 'mcc' in self.metrics:
            metrics['mcc'] = calculate_mcc(true_graph, pred_graph)
        
        return metrics
    
    def _aggregate_results(self, results):
        """聚合多次运行结果"""
        import pandas as pd
        df = pd.DataFrame(results)
        
        return {
            'mean': df.mean().to_dict(),
            'std': df.std().to_dict(),
            'median': df.median().to_dict()
        }

5.2 分布外评估

def out_of_distribution_evaluation(method, train_generator, test_configs):
    """分布外评估"""
    results = {}
    
    # 训练数据
    X_train, G_train = train_generator.generate()
    method.fit(X_train)
    
    # 不同分布的测试数据
    for config_name, test_generator in test_configs.items():
        X_test, G_test = test_generator.generate()
        pred_graph = method.predict(X_test)
        
        # 计算分布内和分布外性能
        results[config_name] = {
            'in_dist_shd': calculate_shd(G_train, method.predict(X_train)),
            'ood_shd': calculate_shd(G_test, pred_graph),
            'degradation': calculate_shd(G_test, pred_graph) - \
                           calculate_shd(G_train, method.predict(X_train))
        }
    
    return results

6. 基准测试最佳实践

6.1 评估检查清单

evaluation_checklist = """
因果发现算法评估检查清单:
 
□ 数据生成
  □ DAG结构多样性(链、树、树冠、无标度)
  □ 噪声类型(高斯、非高斯、异方差)
  □ 样本量(n=100, 500, 1000, 5000)
  □ 变量数(p=10, 50, 100, 500)
  
□ 性能指标
  □ 结构精度(SHD, F1, MCC, IoU)
  □ 因果效应估计(ATE误差, PEHE)
  □ 置信区间报告
  □ 统计显著性检验
  
□ 鲁棒性评估
  □ 噪声水平变化
  □ 缺失数据
  □ 分布偏移
  □ 运行时间/内存
  
□ 可重复性
  □ 随机种子报告
  □ 代码开源
  □ 超参数设置
"""

6.2 可视化评估结果

import matplotlib.pyplot as plt
 
def plot_performance_comparison(results, methods, datasets):
    """绘制性能对比图"""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # SHD对比
    axes[0, 0].bar(range(len(methods)), 
                    [results[m]['shd'] for m in methods])
    axes[0, 0].set_xticks(range(len(methods)))
    axes[0, 0].set_xticklabels(methods, rotation=45)
    axes[0, 0].set_ylabel('SHD (lower is better)')
    axes[0, 0].set_title('Structural Accuracy')
    
    # F1对比
    axes[0, 1].bar(range(len(methods)),
                   [results[m]['f1'] for m in methods])
    axes[0, 1].set_xticks(range(len(methods)))
    axes[0, 1].set_xticklabels(methods, rotation=45)
    axes[0, 1].set_ylabel('F1 (higher is better)')
    axes[0, 1].set_title('Edge Recovery')
    
    # 运行时间
    axes[1, 0].bar(range(len(methods)),
                    [results[m]['runtime'] for m in methods])
    axes[1, 0].set_xticks(range(len(methods)))
    axes[1, 0].set_xticklabels(methods, rotation=45)
    axes[1, 0].set_ylabel('Runtime (seconds)')
    axes[1, 0].set_title('Computational Efficiency')
    
    # 可扩展性
    for n_vars in [10, 50, 100]:
        runtimes = [results[m]['scalability'][n_vars] for m in methods]
        axes[1, 1].plot(range(len(methods)), runtimes, 
                        marker='o', label=f'p={n_vars}')
    
    axes[1, 1].set_xticks(range(len(methods)))
    axes[1, 1].set_xticklabels(methods, rotation=45)
    axes[1, 1].set_ylabel('Runtime (seconds)')
    axes[1, 1].set_title('Scalability')
    axes[1, 1].legend()
    
    plt.tight_layout()
    plt.savefig('causal_discovery_comparison.png', dpi=300)
    plt.show()

7. 基准测试工具

7.1 CausalBench库

# 安装
# pip install causalbench
 
from causalbench import CausalBenchmark
 
benchmark = CausalBenchmark(
    methods=['pc', 'ges', 'notears', 'gran_dag'],
    datasets=['linear_gaussian', 'nonlinear', 'hidden_confounder'],
    metrics=['shd', 'f1', 'runtime']
)
 
results = benchmark.run(n_runs=10)
benchmark.save_results('results.json')
benchmark.plot_comparison()

7.2 gCastle

from castle.algorithms import PC, GES, NOTEARS
from castle.common import Benchmark
from castle.datasets import DAG, IIDSimulation
 
# 创建基准
bm = Benchmark()
 
# 添加方法
bm.add_method('PC', PC())
bm.add_method('GES', GES())
bm.add_method('NOTEARS', NOTEARS())
 
# 生成数据
dataset = IIDSimulation(N=1000, 
                        n_edges=20,
                        n_nodes=20,
                        method='ER',
                        sem_type='gauss')
 
# 运行基准
results = bm.evaluate(dataset)
 
# 可视化
bm.plot_results()

8. 基准测试总结

8.1 推荐基准设置

数据规模推荐基准
小规模(n<20)Sachs, Tuebingen, 合成线性
中等规模(20<n<100)OCDB, 合成非线性
大规模(n>100)CausalRivers, 大规模合成
时序数据CausalTime, CausalRivers

8.2 方法选择建议

方法选择决策树:

数据是时序数据?
├── 是 → PCMCI, CausalFormer, TS-CausalNN
└── 否
    ├── 样本量小(<1000)?
    │   ├── 是 → PC, FCI
    │   └── 否
    │       ├── 需要理论保证?
    │       │   ├── 是 → LiNGAM, NOTIME
    │       │   └── 否
    │       │       ├── 非线性关系?
    │       │       │   ├── 是 → GraN-DAG, NOTEARS-MLP
    │       │       │   └── 否 → NOTEARS, DAG-GNN
    │       │       └── 需要大规模?
    │       │           └── 是 → CauScale

9. 参考文献


相关主题

Footnotes

  1. Spirtes, P., et al. (2000). Causation, Prediction, and Search. MIT Press.

  2. OCDB Authors. (2024). Open Causal Discovery Benchmark. arXiv:2406.04598.

  3. CausalTime Authors. (2024). CausalTime: Benchmark for Temporal Causal Discovery. ICLR 2024.

  4. CausalRivers Authors. (2025). CausalRivers: Large-scale Temporal Causal Discovery Benchmark. ICLR 2025.