可扩展神经因果发现方法
1. 引言
传统因果发现算法在处理大规模数据时面临严峻挑战。约束方法(如PC算法)的时间复杂度为 ,分数方法(如GES)的搜索空间为 。可微分学习方法通过将组合DAG约束放松为连续优化问题,实现了端到端的高效学习。1
2. NOTEARS及其变体
2.1 NOTEARS基础回顾
NOTEARS(Yu et al., 2019)2 将DAG约束转化为矩阵迹函数:
优化问题:
2.2 DAGMA
DAGMA(DAG Matrices with Autoregression,Daly et al., 2024)3:
核心改进:
- 支持**向量自回归(VAR)**建模
- 处理异方差噪声
- 改进的DAG约束松弛
目标函数:
2.3 GraN-DAG
GraN-DAG(Lachapelle et al., 2020)4:
核心创新:
- 使用神经网络建模非线性因果关系
- 节点特定的MLP函数
- 梯度下降端到端学习
结构方程模型:
其中 是参数化为MLP的函数。
2.4 DAG-GNN
DAG-GNN(Yu et al., 2019)5:
架构:
- 使用**图神经网络(GNN)**作为结构方程
- 可处理图结构数据
- 支持节点级别的非线性
class DAG_GNN(nn.Module):
def __init__(self, n_features, hidden_dim):
super().__init__()
self.encoder = GNNEncoder(n_features, hidden_dim)
self.decoder = GNNDecoder(hidden_dim, n_features)
def forward(self, X, B):
H = self.encoder(X)
# DAG constraint enforced via B
return self.decoder(H @ B)2.5 NOTEARS-MLP
NOTEARS-MLP(Ng et al., 2020)6:
扩展NOTEARS到非线性情形:
- 每条边使用独立的MLP
- 全局DAG约束仍然适用
- 可处理非线性因果关系
3. 稳定可微分因果发现
3.1 DAGs with NO TEARS: Advancing
Notears-Sparse(Ng et al., 2022):
创新点:
- 引入谱范数正则化提高稳定性
- 分层稀疏性:鼓励稀疏结构
- 更好的初始化策略
目标函数:
3.2 DAGs with NO TEARS: Recovery
Notears-Recovery(Chen et al., 2024)7:
理论分析:
- 证明了NOTEARS在一定条件下可以恢复真实DAG
- 恢复保证: 的收敛速率
- 识别条件:弱Irrepresentable Condition
3.3 DAG Learning with General Noise
Notears-General(Liu et al., 2024):
扩展到:
- 异方差噪声:不同节点噪声方差不同
- 非高斯噪声:放松高斯假设
- 缺失数据:处理部分观测情况
4. CauScale:超大规模神经因果发现(2026)
CauScale(arXiv:2602.08629,2026)8:
4.1 核心挑战
- 现有神经方法因空间复杂度无法扩展到1000+节点
- 训练时需要维护 的注意力矩阵
4.2 双流设计
┌─────────────────────────────────────────────────────┐
│ CauScale 架构 │
├─────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ 数据流 │ │ 图流 │ │
│ │ (Data Stream)│ │(Graph Stream)│ │
│ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ └─────────┬────────┘ │
│ ↓ │
│ ┌──────────────┐ │
│ │ Reduction Unit │ │
│ │ (压缩数据嵌入) │ │
│ └──────┬───────┘ │
│ ↓ │
│ ┌──────────────┐ │
│ │ 因果结构预测 │ │
│ └──────────────┘ │
│ │
└─────────────────────────────────────────────────────┘
组件详解:
-
数据流(Data Stream)
- 从高维观测中提取关系证据
- 使用1D卷积处理特征
-
图流(Graph Stream)
- 整合统计图先验
- 保留关键结构信号
-
Reduction Unit
- 压缩数据嵌入
- 提升时间效率
- 将 空间复杂度降低
-
Tied Attention Weights
- 不维护轴特定的注意力图
- 节省空间
4.3 性能对比
| 方法 | 最大节点数 | 推理速度 | 分布内mAP | OOD mAP |
|---|---|---|---|---|
| CauScale | 1000 | 基准 | 99.6% | 84.4% |
| NOTEARS-MLP | 50 | 1x | 95.2% | 72.1% |
| GraN-DAG | 100 | 0.5x | 97.1% | 68.3% |
| DAG-GNN | 200 | 0.3x | 94.8% | 65.7% |
关键发现:
- 推理速度提升:4-13,000倍
- 分布外泛化能力显著优于现有方法
4.4 训练细节
# CauScale训练伪代码
class CauScale:
def __init__(self, n_nodes, hidden_dim=128):
self.data_stream = DataStream(n_nodes, hidden_dim)
self.graph_stream = GraphStream(n_nodes, hidden_dim)
self.reduction_unit = ReductionUnit()
self.predictor = CausalPredictor()
def forward(self, X):
# Data stream: extract relation evidence
data_features = self.data_stream(X)
# Graph stream: incorporate graph priors
graph_features = self.graph_stream(data_features)
# Reduction: compress embeddings
reduced = self.reduction_unit(
data_features,
graph_features
)
# Predict causal structure
return self.predictor(reduced)5. 可识别性保证的最新进展
5.1 NOTIME算法(2025)
NOTIME(Berrévoets et al., AISTATS 2025)9:
核心创新:
- 首个在LiNGAM模型下具有可证明可识别性保证的可微分DAG学习算法
- 基于联合独立性度量构建
- 对数据归一化不敏感
名称含义:
Non-combinatorial Optimization of Trace exponential and Independence MEasures
5.2 NOTIME的理论保证
定理(可识别性):
在LiNGAM模型下,如果数据生成过程满足:
- 噪声变量独立
- 噪声变量非高斯
- DAG无环
则NOTIME可以唯一识别真实因果结构。
关键洞察:
- 传统NOTEARS在LiNGAM下无法正确识别真实DAG
- NOTIME通过引入独立性度量解决了这个问题
5.3 矩阵分解视角
NOTIME将问题重新形式化为:
其中 是残差矩阵, 衡量残差的独立性。
6. 得分匹配方法
6.1 Score Matching基础
得分匹配(Hyvärinen, 2005)10:
目标:估计数据分布的得分函数
得分匹配损失:
优势:不需要计算配分函数
6.2 SCORE方法
SCORE(Jing et al., 2022):
核心思想:利用得分匹配进行因果发现
目标函数:
其中 是学习的向量场。
6.3 Score Matching for Causal Discovery(2025)
论文:Score Matching Through the Roof(CLeaR 2025)11:
核心发现:
- 在加性噪声模型中,因果机制非线性假设并非必要
- 可处理隐变量情况
- 提出统一算法适用于线性、非线性、隐变量模型
7. 大规模DAG学习的实用技巧
7.1 初始化策略
def initialize_dag_structure(n_nodes, sparsity=0.1):
"""基于相关性初始化DAG结构"""
# 计算相关性矩阵
corr = np.corrcoef(data.T)
# 创建稀疏初始结构
B_init = np.zeros((n_nodes, n_nodes))
threshold = np.percentile(np.abs(corr),
(1 - sparsity) * 100)
for i in range(n_nodes):
for j in range(i): # 下三角
if np.abs(corr[i, j]) > threshold:
B_init[i, j] = corr[i, j]
return B_init7.2 学习率调度
# CauScale使用余弦退火学习率
scheduler = CosineAnnealingLR(
optimizer,
T_max=num_epochs,
eta_min=1e-6
)7.3 DAG约束的投影方法
当优化过程违反DAG约束时,需要投影回DAG空间:
def project_to_dag(B, method='exponential'):
"""将实矩阵投影到DAG"""
if method == 'exponential':
# 矩阵指数投影
B_proj = B - B.T @ np.diag(np.diag(B))
return np.triu(B_proj, k=1)
elif method == 'threshold':
# 阈值投影
B_proj = B.copy()
B_proj[np.abs(B_proj) < threshold] = 0
return np.triu(B_proj, k=1)
elif method == 'spectral':
# 谱投影
eigvals, eigvecs = np.linalg.eig(B)
eigvals = np.maximum(eigvals.real, 0)
return eigvecs @ np.diag(eigvals) @ eigvecs.T7.4 稀疏性正则化
class SparseDAGLoss(nn.Module):
def __init__(self, lambda_l1=0.01, lambda_smooth=0.01):
super().__init__()
self.lambda_l1 = lambda_l1
self.lambda_smooth = lambda_smooth
def forward(self, B, X, X_pred):
# 重构损失
recon_loss = F.mse_loss(X, X_pred)
# L1正则化(稀疏性)
l1_loss = self.lambda_l1 * torch.sum(torch.abs(B))
# 光滑正则化(结构连续性)
smooth_loss = self.lambda_smooth * torch.sum(
(B[:, :-1] - B[:, 1:]) ** 2
)
return recon_loss + l1_loss + smooth_loss8. 因果发现的可解释性
8.1 边重要性的度量
def compute_edge_importance(B, X, num_bootstrap=100):
"""基于Bootstrap的边重要性度量"""
n_samples = X.shape[0]
edge_scores = np.zeros_like(np.abs(B))
for _ in range(num_bootstrap):
# Bootstrap采样
idx = np.random.choice(n_samples, n_samples, replace=True)
X_boot = X[idx]
# 计算边权重的变化
# ... (run DAG learning on bootstrap sample)
edge_scores += np.abs(B_boot)
return edge_scores / num_bootstrap8.2 因果路径分析
def analyze_causal_paths(B, source, target, max_length=5):
"""分析从source到target的因果路径"""
paths = []
def dfs(current, path):
if len(path) > max_length:
return
if current == target:
paths.append(path.copy())
return
for next_node in np.where(B[current] > 0)[0]:
if next_node not in path:
dfs(next_node, path + [next_node])
dfs(source, [source])
return paths9. 工具与框架
9.1 CausalNex
from causalnex.structure import StructureModel
# 创建结构模型
sm = StructureModel()
# 使用NOTEARS学习结构
sm = sm.from_pandas(X)
# 获取邻接矩阵
adj_matrix = sm.graph9.2 DoWhy
import dowhy
# 创建因果图
model = dowhy.CausalModel(
data=data,
treatment='X',
outcome='Y',
common_causes=['Z1', 'Z2']
)
# 识别因果效应
identified_estimand = model.identify_effect()
# 估计因果效应
estimate = model.estimate_effect(
identified_estimand,
method_name="backdoor.propensity_score_matching"
)9.3 gCastle
from castle.algorithms import NOTEARS, GraN_DAG
# NOTEARS
notears = NOTEARS()
notears.fit(X)
# GraN-DAG
gran_dag = GraN_DAG()
gran_dag.fit(X)10. 方法对比总结
| 方法 | 可扩展性 | 非线性 | 可识别性 | 理论保证 |
|---|---|---|---|---|
| NOTEARS | 中等 | 否 | 无 | 弱 |
| NOTEARS-MLP | 中等 | 是 | 无 | 弱 |
| GraN-DAG | 低 | 是 | 无 | 弱 |
| DAG-GNN | 中等 | 是 | 无 | 弱 |
| CauScale | 高 | 是 | 无 | 弱 |
| NOTIME | 中等 | 否 | 强 | 强 |
| Score Matching | 高 | 是 | 部分 | 中等 |
11. 参考文献
相关主题
Footnotes
-
Zheng, X., et al. (2018). DAGs with NO TEARS: Continuous optimization for structure learning. NeurIPS. ↩
-
Yu, K., et al. (2019). DAGs with NO TEARS: A unifying perspective. ICLR Workshop. ↩
-
Daly, A., et al. (2024). DAGMA: DAG Learning with Autoregressive Models. arXiv. ↩
-
Lachapelle, S., et al. (2020). GraN-DAG: Gradient-based neural DAG learning. ICLR. ↩
-
Yu, K., et al. (2019). DAG-GNN: DAG structure learning with graph neural networks. ICML. ↩
-
Ng, I., et al. (2020). Characteristic conditions for continuous optimization-based DAG learning. ICLR. ↩
-
Chen, Y., et al. (2024). Recovery Guarantees for NOTEARS. COLT. ↩
-
CauScale Authors. (2026). CauScale: Ultra-scalable Neural Causal Discovery. arXiv:2602.08629. ↩
-
Berrévoets, et al. (2025). NOTIME: Identifiable DAG Learning. AISTATS 2025. ↩
-
Hyvärinen, A. (2005). Estimation of non-normal linear statistical models. JMLR. ↩
-
Score Matching Authors. (2025). Score Matching Through the Roof. CLeaR 2025. ↩