因果发现基准与评估方法
1. 引言
因果发现算法的评估面临独特挑战:ground truth因果结构通常未知,评估只能在合成数据或半真实数据上进行。1
1.1 评估挑战
| 挑战 | 描述 |
|---|---|
| Ground Truth缺失 | 真实因果关系难以获得 |
| 方法假设不同 | 不同方法基于不同假设 |
| 数据生成偏差 | 合成数据可能偏离真实场景 |
| 分布偏移 | 实验室条件≠实际应用 |
1.2 评估框架
因果发现评估框架:
┌──────────────────────────────────────────────────┐
│ 数据生成 │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 合成数据 │ │ 半真实 │ │ 真实数据 │ │
│ │ (DAG) │ │ (扰动) │ │ (专家) │ │
│ └────┬────┘ └────┬────┘ └────┬────┘ │
└────────┼────────────┼────────────┼─────────────┘
│ │ │
↓ ↓ ↓
┌──────────────────────────────────────────────────┐
│ 因果发现算法 │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 约束方法 │ │ 分数方法 │ │ 深度方法 │ │
│ └────┬────┘ └────┬────┘ └────┬────┘ │
└────────┼────────────┼────────────┼─────────────┘
│ │ │
↓ ↓ ↓
┌──────────────────────────────────────────────────┐
│ 评估指标 │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 结构精度 │ │ 效应估计 │ │ 计算效率 │ │
│ └─────────┘ └─────────┘ └─────────┘ │
└──────────────────────────────────────────────────┘
2. 评估指标
2.1 结构精度指标
Structural Hamming Distance (SHD)
定义:与真实因果图(或CPDAG)的边差异数
其中 表示对称差集运算。
计算:
- 缺失的边:每条计1
- 多余的边:每条计1
- 方向错误的边:计1(CPDAG情况)
def calculate_shd(true_graph, pred_graph, is_dag=True):
"""
计算SHD
Args:
true_graph: 真实邻接矩阵
pred_graph: 预测邻接矩阵
is_dag: 是否为DAG(否则为CPDAG)
Returns:
shd: Structural Hamming Distance
"""
if is_dag:
# DAG情况:直接比较
diff = np.sum(np.abs(true_graph - pred_graph) > 0.5)
else:
# CPDAG情况:考虑等价类
true_skeleton = (np.abs(true_graph) > 0.5).astype(int)
pred_skeleton = (np.abs(pred_graph) > 0.5).astype(int)
# 骨架差异
skeleton_diff = np.sum(
np.abs(true_skeleton - pred_skeleton) > 0
)
# 方向差异(仅在骨架正确时计算)
direction_diff = 0
correct_skeleton = (true_skeleton == pred_skeleton)
direction_diff = np.sum(
(true_graph != pred_graph) & correct_skeleton
) / 2 # 每条边计一次
diff = skeleton_diff + direction_diff
return diffF1-Score
定义:边恢复的精确率和召回率的调和平均
其中:
def calculate_f1(true_graph, pred_graph, threshold=0.5):
"""计算边F1-Score"""
true_edges = set(zip(*np.where(np.abs(true_graph) > threshold)))
pred_edges = set(zip(*np.where(np.abs(pred_graph) > threshold)))
tp = len(true_edges & pred_edges)
fp = len(pred_edges - true_edges)
fn = len(true_edges - pred_edges)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
return f1, precision, recallMatthews Correlation Coefficient (MCC)
定义:考虑所有四种预测情况(TP/TN/FP/FN)的平衡指标
优势:在类别不平衡时比F1更可靠
def calculate_mcc(true_graph, pred_graph, threshold=0.5):
"""计算MCC"""
true_binary = (np.abs(true_graph) > threshold).astype(int)
pred_binary = (np.abs(pred_graph) > threshold).astype(int)
# 展平为1D
y_true = true_binary.flatten()
y_pred = pred_binary.flatten()
tp = np.sum((y_true == 1) & (y_pred == 1))
tn = np.sum((y_true == 0) & (y_pred == 0))
fp = np.sum((y_true == 0) & (y_pred == 1))
fn = np.sum((y_true == 1) & (y_pred == 0))
numerator = tp * tn - fp * fn
denominator = np.sqrt(
(tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)
mcc = numerator / denominator if denominator > 0 else 0
return mccIoU (Intersection over Union)
定义:边集合的Jaccard相似度
2.2 因果效应估计指标
Average Treatment Effect (ATE) 误差
定义:因果效应估计的相对误差
其中 是真实ATE, 是估计ATE。
def calculate_ate_error(true_ate, estimated_ate):
"""计算ATE相对误差"""
if np.abs(true_ate) < 1e-6:
return np.abs(estimated_ate - true_ate)
return np.abs(estimated_ate - true_ate) / np.abs(true_ate)Precision in Estimation of Heterogeneous Effects (PEHE)
定义:异质性因果效应的估计精度
其中 是个体因果效应。
2.3 计算效率指标
| 指标 | 描述 |
|---|---|
| 运行时间 | 算法执行所需时间 |
| 内存峰值 | 最大内存占用 |
| 参数数量 | 需要调整的超参数数量 |
| 可扩展性 | 运行时间随数据规模的增长 |
3. 基准测试数据集
3.1 经典基准数据集
Sachs数据集
| 特性 | 值 |
|---|---|
| 节点数 | 11 |
| 边数 | 17 |
| 数据类型 | 蛋白质信号 |
| 观测数 | ~750 |
| 特点 | 真实生物数据,存在未观测混杂 |
Tuebingen因果基准
| 特性 | 描述 |
|---|---|
| 规模 | 多种规模(10-100节点) |
| 类型 | 合成+半真实 |
| 噪声 | 多种噪声水平 |
| 用途 | 方法比较 |
3.2 综合评估基准
OCDB (Open Causal Discovery Benchmark)
OCDB(Open Causal Discovery Benchmark,arXiv:2406.04598)2:
目标:提供公平、全面的因果发现方法评估
特点:
- 基于真实数据的基准
- 评估指标:因果结构和因果效应
- 考虑无向边,公平比较DAG和CPDAG
基准设置:
class OCDBBenchmark:
def __init__(self, dataset='synthetic'):
self.datasets = {
'linear_gaussian': LinearGaussianGenerator(),
'nonlinear': NonlinearGenerator(),
'hidden_confounder': HiddenConfounderGenerator(),
'intervention': InterventionGenerator()
}
def evaluate(self, method, dataset_name):
"""评估因果发现方法"""
# 生成数据
X, true_graph, true_effect = self.datasets[dataset_name].generate()
# 运行方法
pred_graph = method.fit_predict(X)
# 计算指标
metrics = {
'shd': calculate_shd(true_graph, pred_graph),
'f1': calculate_f1(true_graph, pred_graph)[0],
'mcc': calculate_mcc(true_graph, pred_graph),
'ate_error': calculate_ate_error(true_effect, pred_effect)
}
return metricsCausalTime基准(ICLR 2024)
CausalTime3:
目标:时序因果发现的综合评估
特点:
- 基于动力学因果模型生成
- 包含分布偏移场景
- 覆盖连续、离散、混合数据
class CausalTimeBenchmark:
def __init__(self):
self.causal_mechanisms = [
'linear', 'polynomial', 'sigmoid', 'sinusoidal'
]
self.noise_types = ['gaussian', 'laplace', 'uniform']
def generate(self, n_vars, n_samples, mechanism='linear'):
"""生成时序因果数据"""
# 1. 生成因果图
adj_matrix = generate_random_dag(n_vars, edge_prob=0.2)
# 2. 生成时间序列
time_series = generate_time_series(
adj_matrix,
n_samples=n_samples,
mechanism=mechanism,
time_lags=[1, 2, 3]
)
return time_series, adj_matrix3.3 大规模时序基准
CausalRivers基准(ICLR 2025)
CausalRivers4:
特点:
- 最大规模:666节点(德国东部)、494节点(巴伐利亚)
- 真实数据:水文、气象数据
- 时间跨度:2019-2023年,15分钟分辨率
- 分布偏移:包含洪水事件等极端场景
CausalRivers基准统计:
┌─────────────────┬────────────┬────────────┐
│ 特性 │ 德国东部 │ 巴伐利亚 │
├─────────────────┼────────────┼────────────┤
│ 节点数 │ 666 │ 494 │
│ 时间步数 │ ~140,000 │ ~140,000 │
│ 时间跨度 │ 2019-2023 │ 2019-2023 │
│ 分辨率 │ 15分钟 │ 15分钟 │
│ 事件数 │ 15 │ 12 │
│ 分布偏移 │ 洪水 │ 季节变化 │
└─────────────────┴────────────┴────────────┘
4. 数据生成方法
4.1 DAG生成
Erdős–Rényi模型
def generate_er_dag(n_nodes, edge_prob):
"""生成Erdős–Rényi随机DAG"""
# 随机上三角矩阵
B = np.triu(np.random.rand(n_nodes, n_nodes) < edge_prob, k=1)
# 确保为DAG(通过拓扑排序重排)
perm = np.random.permutation(n_nodes)
B = B[perm][:, perm]
return BScale-Free模型
def generate_scale_free_dag(n_nodes, alpha=2.5):
"""生成无标度DAG(Barabási-Albert模型)"""
# 初始种子
G = nx.DiGraph()
G.add_edge(0, 1)
# 逐步添加节点
for new_node in range(2, n_nodes):
# 计算度偏好概率
degrees = np.array([G.in_degree(i) + G.out_degree(i)
for i in range(new_node)])
probs = degrees ** alpha
probs /= probs.sum()
# 选择连接目标
n_edges = np.random.poisson(3) + 1
targets = np.random.choice(
new_node, size=min(n_edges, new_node),
replace=False, p=probs
)
# 添加边
for target in targets:
G.add_edge(new_node, target)
# 转换为邻接矩阵
return nx.to_numpy_array(G, nodelist=range(n_nodes))4.2 数据生成过程
线性高斯SEM
def generate_linear_gaussian_data(B, n_samples, noise_std=1.0):
"""生成线性高斯数据"""
n_nodes = B.shape[0]
# 确保B是严格下三角(DAG)
B = np.tril(B, k=-1)
# 拓扑排序
order = np.argsort([0 if np.sum(B[:, i]) == 0 else 1
for i in range(n_nodes)])
B = B[order][:, order]
# 生成数据
X = np.zeros((n_samples, n_nodes))
for i in range(n_nodes):
parents = np.where(B[:i, i])[0]
noise = np.random.randn(n_samples) * noise_std
X[:, i] = X[:, parents] @ B[parents, i] + noise
# 恢复原始顺序
X = X[:, np.argsort(order)]
return X非线性SEM
def generate_nonlinear_data(B, n_samples, noise_std=1.0):
"""生成非线性数据"""
n_nodes = B.shape[0]
B = np.tril(B, k=-1)
X = np.zeros((n_samples, n_nodes))
for i in range(n_nodes):
parents = np.where(B[:i, i])[0]
if len(parents) == 0:
X[:, i] = np.random.randn(n_samples) * noise_std
elif len(parents) == 1:
# 单父节点:非线性变换
X[:, i] = np.tanh(X[:, parents[0]]) + \
np.random.randn(n_samples) * noise_std
else:
# 多父节点:加性组合
effect = np.sum([
np.sin(X[:, p]) for p in parents
], axis=0)
X[:, i] = effect + np.random.randn(n_samples) * noise_std
return X5. 评估协议
5.1 标准化评估流程
class CausalDiscoveryEvaluator:
def __init__(self, metrics=['shd', 'f1', 'mcc', ' runtime']):
self.metrics = metrics
self.results = {}
def run_evaluation(self, method, datasets, n_runs=10):
"""标准化评估流程"""
all_results = []
for dataset_name, generator in datasets.items():
run_results = []
for run in range(n_runs):
# 生成数据
X, true_graph = generator.generate()
# 运行方法
start_time = time.time()
pred_graph = method.fit_predict(X)
runtime = time.time() - start_time
# 计算指标
metrics = self._compute_metrics(true_graph, pred_graph)
metrics['runtime'] = runtime
run_results.append(metrics)
# 汇总结果
dataset_results = self._aggregate_results(run_results)
all_results.append(dataset_results)
return self._format_results(all_results)
def _compute_metrics(self, true_graph, pred_graph):
"""计算所有指标"""
metrics = {}
if 'shd' in self.metrics:
metrics['shd'] = calculate_shd(true_graph, pred_graph)
if 'f1' in self.metrics:
f1, prec, rec = calculate_f1(true_graph, pred_graph)
metrics['f1'] = f1
metrics['precision'] = prec
metrics['recall'] = rec
if 'mcc' in self.metrics:
metrics['mcc'] = calculate_mcc(true_graph, pred_graph)
return metrics
def _aggregate_results(self, results):
"""聚合多次运行结果"""
import pandas as pd
df = pd.DataFrame(results)
return {
'mean': df.mean().to_dict(),
'std': df.std().to_dict(),
'median': df.median().to_dict()
}5.2 分布外评估
def out_of_distribution_evaluation(method, train_generator, test_configs):
"""分布外评估"""
results = {}
# 训练数据
X_train, G_train = train_generator.generate()
method.fit(X_train)
# 不同分布的测试数据
for config_name, test_generator in test_configs.items():
X_test, G_test = test_generator.generate()
pred_graph = method.predict(X_test)
# 计算分布内和分布外性能
results[config_name] = {
'in_dist_shd': calculate_shd(G_train, method.predict(X_train)),
'ood_shd': calculate_shd(G_test, pred_graph),
'degradation': calculate_shd(G_test, pred_graph) - \
calculate_shd(G_train, method.predict(X_train))
}
return results6. 基准测试最佳实践
6.1 评估检查清单
evaluation_checklist = """
因果发现算法评估检查清单:
□ 数据生成
□ DAG结构多样性(链、树、树冠、无标度)
□ 噪声类型(高斯、非高斯、异方差)
□ 样本量(n=100, 500, 1000, 5000)
□ 变量数(p=10, 50, 100, 500)
□ 性能指标
□ 结构精度(SHD, F1, MCC, IoU)
□ 因果效应估计(ATE误差, PEHE)
□ 置信区间报告
□ 统计显著性检验
□ 鲁棒性评估
□ 噪声水平变化
□ 缺失数据
□ 分布偏移
□ 运行时间/内存
□ 可重复性
□ 随机种子报告
□ 代码开源
□ 超参数设置
"""6.2 可视化评估结果
import matplotlib.pyplot as plt
def plot_performance_comparison(results, methods, datasets):
"""绘制性能对比图"""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# SHD对比
axes[0, 0].bar(range(len(methods)),
[results[m]['shd'] for m in methods])
axes[0, 0].set_xticks(range(len(methods)))
axes[0, 0].set_xticklabels(methods, rotation=45)
axes[0, 0].set_ylabel('SHD (lower is better)')
axes[0, 0].set_title('Structural Accuracy')
# F1对比
axes[0, 1].bar(range(len(methods)),
[results[m]['f1'] for m in methods])
axes[0, 1].set_xticks(range(len(methods)))
axes[0, 1].set_xticklabels(methods, rotation=45)
axes[0, 1].set_ylabel('F1 (higher is better)')
axes[0, 1].set_title('Edge Recovery')
# 运行时间
axes[1, 0].bar(range(len(methods)),
[results[m]['runtime'] for m in methods])
axes[1, 0].set_xticks(range(len(methods)))
axes[1, 0].set_xticklabels(methods, rotation=45)
axes[1, 0].set_ylabel('Runtime (seconds)')
axes[1, 0].set_title('Computational Efficiency')
# 可扩展性
for n_vars in [10, 50, 100]:
runtimes = [results[m]['scalability'][n_vars] for m in methods]
axes[1, 1].plot(range(len(methods)), runtimes,
marker='o', label=f'p={n_vars}')
axes[1, 1].set_xticks(range(len(methods)))
axes[1, 1].set_xticklabels(methods, rotation=45)
axes[1, 1].set_ylabel('Runtime (seconds)')
axes[1, 1].set_title('Scalability')
axes[1, 1].legend()
plt.tight_layout()
plt.savefig('causal_discovery_comparison.png', dpi=300)
plt.show()7. 基准测试工具
7.1 CausalBench库
# 安装
# pip install causalbench
from causalbench import CausalBenchmark
benchmark = CausalBenchmark(
methods=['pc', 'ges', 'notears', 'gran_dag'],
datasets=['linear_gaussian', 'nonlinear', 'hidden_confounder'],
metrics=['shd', 'f1', 'runtime']
)
results = benchmark.run(n_runs=10)
benchmark.save_results('results.json')
benchmark.plot_comparison()7.2 gCastle
from castle.algorithms import PC, GES, NOTEARS
from castle.common import Benchmark
from castle.datasets import DAG, IIDSimulation
# 创建基准
bm = Benchmark()
# 添加方法
bm.add_method('PC', PC())
bm.add_method('GES', GES())
bm.add_method('NOTEARS', NOTEARS())
# 生成数据
dataset = IIDSimulation(N=1000,
n_edges=20,
n_nodes=20,
method='ER',
sem_type='gauss')
# 运行基准
results = bm.evaluate(dataset)
# 可视化
bm.plot_results()8. 基准测试总结
8.1 推荐基准设置
| 数据规模 | 推荐基准 |
|---|---|
| 小规模(n<20) | Sachs, Tuebingen, 合成线性 |
| 中等规模(20<n<100) | OCDB, 合成非线性 |
| 大规模(n>100) | CausalRivers, 大规模合成 |
| 时序数据 | CausalTime, CausalRivers |
8.2 方法选择建议
方法选择决策树:
数据是时序数据?
├── 是 → PCMCI, CausalFormer, TS-CausalNN
└── 否
├── 样本量小(<1000)?
│ ├── 是 → PC, FCI
│ └── 否
│ ├── 需要理论保证?
│ │ ├── 是 → LiNGAM, NOTIME
│ │ └── 否
│ │ ├── 非线性关系?
│ │ │ ├── 是 → GraN-DAG, NOTEARS-MLP
│ │ │ └── 否 → NOTEARS, DAG-GNN
│ │ └── 需要大规模?
│ │ └── 是 → CauScale
9. 参考文献
相关主题
Footnotes
-
Spirtes, P., et al. (2000). Causation, Prediction, and Search. MIT Press. ↩
-
OCDB Authors. (2024). Open Causal Discovery Benchmark. arXiv:2406.04598. ↩
-
CausalTime Authors. (2024). CausalTime: Benchmark for Temporal Causal Discovery. ICLR 2024. ↩
-
CausalRivers Authors. (2025). CausalRivers: Large-scale Temporal Causal Discovery Benchmark. ICLR 2025. ↩