因子图与消息传递现代理论
1 引言
因子图(Factor Graph)作为概率图模型的核心表示框架,为联合概率分布的分解与高效推断提供了统一的数学语言。自Kschischang等人于2001年系统化提出因子图理论以来1,消息传递算法已成为统计推断、机器学习和深度学习领域的基石性技术。
本文系统阐述因子图与消息传递的现代理论框架,深入分析和积算法与置信传播的数学本质,探讨高斯消息传递在线性系统中的精确推断能力,并揭示其与神经网络的深刻联系。
1.1 与现有内容的关系
本文建立在以下相关文档的理论基础之上:
- 因子图与消息传递算法 — 基础概念
- 因子图与置信传播的统一框架 — 置信传播框架
- 消息传递神经网络 — MPNN框架
- GNN消息传递机制深度解析 — GNN消息传递
1.2 符号约定
本文采用以下符号约定:
| 符号 | 含义 |
|---|---|
| 随机变量集合 | |
| 第 个变量 | |
| 第 个因子节点 | |
| 节点 的邻居集合 | |
| 从变量 到因子 的消息 | |
| 从因子 到变量 的消息 | |
| 配分函数(归一化常数) |
2 因子图基础回顾
2.1 因子图的数学定义
定义2.1(因子图):因子图是一个二分图 ,其中:
- 是变量节点集合
- 是因子节点集合
- 是连接变量与因子的边集合
每个变量节点 对应一个随机变量,每个因子节点 对应一个局部势函数(potential function)。
2.2 联合分布的因子分解
设 表示所有变量的集合, 表示与因子 相连的变量集合。因子图表示的联合分布为:
其中 是配分函数(partition function),定义为:
注:势函数 不必是归一化的概率分布,只需要是非负函数。归一化由配分函数 完成。
2.3 因子图表示能力的分析
因子图的表示能力源于其对条件独立性的精确编码。由因子图的结构可以直接读出变量间的条件独立性:
定理2.1(条件独立性):在因子图 中,给定变量集合 ,变量 与 条件独立当且仅当所有从 到 的路径都被 阻断。
这一性质使得因子图成为编码概率论中复杂依赖结构的利器。
2.4 与其他图模型的比较
因子图与两类经典图模型有着深刻的联系:
因子图的二分图结构是其独特优势:变量节点只与因子节点相连,因子节点只与变量节点相连。这种结构消除了歧义,使得消息传递算法的推导更加清晰。
3 和积算法(Sum-Product Algorithm)深度解析
3.1 算法目标与动机
和积算法(又称置信传播)旨在高效计算因子图中所有变量的边缘概率分布:
直接计算边缘分布的复杂度是指数级的(),而和积算法利用因子图的分解结构,将复杂度降低到与因子图的树宽成正比。
3.2 消息传递规则的形式化推导
3.2.1 从因子到变量的消息
设因子节点 连接到变量集合 。 向变量 传递的消息定义为:
推导:消息 表示在因子 的约束下,变量 应该携带的信息。这通过对因子势函数求边缘化得到,同时”吸收”了来自其他邻居变量的消息。
3.2.2 从变量到因子的消息
设变量节点 连接到因子集合 。 向因子 传递的消息定义为:
推导:消息 聚合了 从所有其他因子接收到的所有信息。由于这些因子对 的影响是独立的,消息是它们的乘积。
3.2.3 消息传递的计算图
┌─────────────────────────────────────────────────────────────────────────┐
│ 消息传递规则图示 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 因子节点 f_a │
│ │ │
│ ┌─────┼─────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ x_i x_j x_k ← 变量节点 │
│ │ │ │
│ │ │ │
│ ▼ ▼ │
│ μ_{f_a→x_i} μ_{f_a→x_k} │
│ │
│ 消息计算: │
│ μ_{f_a→x_j}(x_j) = Σ_{x_i,x_k} f_a(x_i,x_j,x_k) · μ_{x_i→f_a}(x_i) · μ_{x_k→f_a}(x_k) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
3.3 边缘概率计算
当消息传递完成后(树结构下为一次遍历),每个变量的边缘分布可以计算为:
定理3.1(和积算法的正确性):对于树结构的因子图,和积算法计算得到的边缘分布是精确的。
证明梗概:通过归纳法证明每条消息的正确性。基础情况是叶子节点的消息;归纳步骤假设所有已计算消息正确,证明相邻消息正确。
3.4 归一化常数的计算
配分函数 可以通过聚合根节点的消息来计算:
其中 是因子 的局部归一化常数。
3.5 对数域消息传递
为了数值稳定性,实际实现中常使用对数域的消息传递:
然而,加法的对数操作需要替换为 Log-Sum-Exp 操作:
3.6 最大乘积算法(Max-Product Algorithm)
MAP推断(最大后验概率推断)将求和替换为取最大值:
消息更新规则变为:
最大乘积算法与Viterbi算法有深刻联系,在序列标注问题中广泛应用。
4 置信传播(Belief Propagation)
4.1 置信消息的定义
定义4.1(置信):变量 的置信(belief)定义为对其边缘分布的近似:
类似地,因子 的置信为:
4.2 消息调度策略
消息传递的顺序(调度)对算法收敛速度有重要影响。
4.2.1 同步消息传递
所有消息同时更新:
优点:易于并行化
缺点:可能振荡,难以收敛
4.2.2 异步消息传递(洪水算法)
每次只更新一个节点的消息,顺序执行。
优点:通常收敛更快
缺点:难以并行化
4.2.3 残差消息传递(Residual Belief Propagation)
每次选择变化最大的消息优先更新:
优先更新具有最大残差的消息。
4.3 收敛性分析
4.3.1 树结构的收敛保证
定理4.1:对于树结构的因子图,和积算法在单次遍历后收敛到精确的边缘分布。
4.3.2 循环结构的挑战
当因子图存在环(cycles)时,消息传递可能:
- 收敛到近似解
- 振荡而不收敛
- 发散(数值不稳定)
4.3.3 收敛的充分条件
定理4.2(势函数有界性):如果所有因子势函数被常数 上下界,即 ,则循环置信传播(LBP)收敛。
直觉:归一化的消息空间是紧的,消息映射是连续的,因此存在不动点。
4.4 循环置信传播(Loopy Belief Propagation)
对于有环因子图,循环置信传播是一种近似推断方法:
class LoopyBeliefPropagation:
"""
循环置信传播实现
适用于有环图模型的近似推断
"""
def __init__(self, num_states, damping=0.5):
"""
Args:
num_states: 每个变量的状态数
damping: 阻尼因子 (0-1),用于加速收敛
"""
self.num_states = num_states
self.damping = damping
def run(self, factors, adjacency, max_iter=100, tol=1e-6):
"""
执行循环置信传播
Args:
factors: dict {factor_id: {'vars': [var_ids], 'potential': array}}
adjacency: dict {var_id: [factor_ids]}
max_iter: 最大迭代次数
tol: 收敛容忍度
Returns:
beliefs: dict {var_id: belief_array}
"""
num_vars = len(adjacency)
# 初始化消息
messages_f_to_x = {} # (factor_id, var_id) -> message
messages_x_to_f = {} # (var_id, factor_id) -> message
for _ in range(max_iter):
max_change = 0
# 更新所有因子到变量的消息
for f_id, factor in factors.items():
vars_in_factor = factor['vars']
potential = factor['potential']
for i, x_id in enumerate(vars_in_factor):
# 计算新消息
new_msg = self._compute_factor_to_var_message(
factor, i, vars_in_factor, messages_x_to_f
)
old_msg = messages_f_to_x.get((f_id, x_id),
torch.ones(self.num_states))
# 阻尼更新
damped_msg = (1 - self.damping) * new_msg + self.damping * old_msg
# 归一化
damped_msg = damped_msg / damped_msg.sum()
messages_f_to_x[(f_id, x_id)] = damped_msg
max_change = max(max_change,
torch.abs(damped_msg - old_msg).max().item())
# 更新所有变量到因子的消息
for x_id, factor_ids in adjacency.items():
for f_id in factor_ids:
# 收集来自其他因子的消息
other_msg = torch.ones(self.num_states)
for other_f_id in factor_ids:
if other_f_id != f_id:
if (other_f_id, x_id) in messages_f_to_x:
other_msg = other_msg * messages_f_to_x[(other_f_id, x_id)]
old_msg = messages_x_to_f.get((x_id, f_id),
torch.ones(self.num_states))
# 阻尼更新
damped_msg = (1 - self.damping) * other_msg + self.damping * old_msg
# 归一化
damped_msg = damped_msg / damped_msg.sum()
messages_x_to_f[(x_id, f_id)] = damped_msg
# 检查收敛
if max_change < tol:
print(f"Converged after {_+1} iterations")
break
# 计算最终信念
beliefs = {}
for x_id in range(num_vars):
belief = torch.ones(self.num_states)
for f_id in adjacency[x_id]:
if (f_id, x_id) in messages_f_to_x:
belief = belief * messages_f_to_x[(f_id, x_id)]
beliefs[x_id] = belief / belief.sum()
return beliefs
def _compute_factor_to_var_message(self, factor, var_idx, vars_in_factor, messages):
"""计算因子到变量的消息"""
potential = factor['potential']
num_vars = len(vars_in_factor)
# 对所有其他变量求和/积分
result = torch.zeros(self.num_states)
# 简化实现:假设二元因子
if num_vars == 2:
other_idx = 1 - var_idx
for s1 in range(self.num_states):
for s2 in range(self.num_states):
if var_idx == 0:
msg_val = potential[s1, s2]
else:
msg_val = potential[s2, s1]
# 乘以来自其他变量的消息
if (f"var_{other_idx}", f"factor_{factor['id']}") in messages:
other_msg = messages[(f"var_{other_idx}", f"factor_{factor['id']}")][s2]
msg_val = msg_val * other_msg
result[s1] += msg_val
return result4.5 置信传播的变体
| 方法 | 描述 | 适用范围 |
|---|---|---|
| 标准BP | 精确消息传递 | 树结构 |
| 循环BP | 迭代直到收敛 | 有环图 |
| 衰减BP | 阻尼因子稳定化 | 振荡问题 |
| Tree-Reweighted BP | 加权消息边界 | 低纠缠图 |
| Fractional BP | 分数加权 | 近似精度控制 |
5 高斯消息传递
5.1 线性高斯模型
高斯消息传递是处理线性高斯模型精确推断的强大工具。考虑以下模型:
其中均值 ,协方差 是正定矩阵。
5.2 高斯分布的消息表示
高斯分布可以用自然参数表示:
定义自然参数:
- 精度矩阵:
- 信息向量:
则高斯分布可以写为:
5.3 高斯消息传递规则
5.3.1 乘积操作
两个高斯分布的乘积仍是高斯分布:
意义:这正是变量节点的消息传递规则——乘积对应信息组合。
5.3.2 边际化操作
高斯分布的边际化也是高斯分布:
设 ,联合分布为:
则边缘分布为:
其中 , 是边缘精度矩阵。
意义:这正是因子节点的消息传递规则——边际化对应信息传递。
5.4 Kalman滤波器作为高斯BP
卡尔曼滤波器是高斯BP在时序模型中的特例。
5.4.1 状态空间模型
5.4.2 Kalman滤波的消息传递视角
时间步 t-1 时间步 t
│ │
▼ ▼
┌────────┐ ┌────────┐
│ 状态先验 │ │ 观测模型 │
│ x_{t-1}│ │ y_t │
└────────┘ └────────┘
│ ▲
│ μ_{t-1→t} │
▼ │
┌────────┐ │
│ 状态转移 │ │
│ A │ │
└────────┘ │
│ │
▼ │
┌────────┐ ┌────────┐
│ 预测分布 │ ──────────────────→ │ 更新分布 │
│ p(x_t) │ │ p(x_t|y_t) │
└────────┘ └────────┘
预测步(消息从 传到 ):
更新步(融合观测):
5.5 与变分推断的联系
高斯消息传递与变分推断中的平均场方法有深刻联系。
平均场假设下的变分分布:
变分消息更新(对数域):
对于高斯模型,这简化为矩匹配(moment matching):
6 因子图与神经网络的对应关系
6.1 MPNN的消息传递机制
消息传递神经网络(MPNN)与因子图消息传递有着形式上的深刻对应:
| MPNN组件 | 因子图BP组件 | 数学联系 |
|---|---|---|
| 消息函数 | 因子势函数 | 可学习的局部变换 |
| 聚合操作 | 求和操作 | 邻居信息组合 |
| 更新函数 | 置信计算 | 状态更新 |
| 迭代层数 | 消息传递步数 | 信息传播范围 |
6.2 消息传递视角下的GNN
以图注意力网络(GAT)为例,其消息传递可以解释为软化的因子图消息传递:
其中注意力系数 可以视为动态调整的势函数权重。
6.3 神经网络层与消息传递的统一
我们可以将神经网络层统一理解为参数化的消息传递:
关键区别:
| 方面 | 因子图BP | 神经网络 |
|---|---|---|
| 消息函数 | 由概率模型固定 | 可学习 |
| 目标 | 推断后验分布 | 最小化任务损失 |
| 优化 | 闭式/迭代 | 梯度下降 |
6.4 可微分置信传播
现代深度学习框架可以将BP嵌入神经网络,实现可微分的消息传递:
class DifferentiableMessagePassing(nn.Module):
"""
可微分消息传递层
将因子图BP的消息传递操作参数化
"""
def __init__(self, node_dim, hidden_dim, num_heads=4):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
# 消息函数(替代因子势函数)
self.message_net = nn.Sequential(
nn.Linear(node_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# 注意力网络(动态势函数)
self.attention_net = nn.Sequential(
nn.Linear(node_dim * 2, hidden_dim),
nn.LeakyReLU(),
nn.Linear(hidden_dim, num_heads)
)
# 更新函数
self.update_net = nn.Sequential(
nn.Linear(node_dim + hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU()
)
def forward(self, x, edge_index):
"""
Args:
x: 节点特征 (num_nodes, node_dim)
edge_index: 边索引 (2, num_edges)
Returns:
更新后的节点特征 (num_nodes, hidden_dim)
"""
row, col = edge_index # 源节点 -> 目标节点
# 计算消息
source_features = x[row] # (num_edges, node_dim)
target_features = x[col] # (num_edges, node_dim)
# 拼接源和目标特征
combined = torch.cat([source_features, target_features], dim=-1)
# 注意力权重
att_scores = self.attention_net(combined) # (num_edges, num_heads)
att_weights = F.softmax(att_scores, dim=0) # (num_edges, num_heads)
# 消息内容
messages = self.message_net(combined) # (num_edges, hidden_dim)
messages = messages.view(-1, self.num_heads, self.head_dim)
# 加权消息
weighted_messages = messages * att_weights.unsqueeze(-1) # (num_edges, heads, head_dim)
# 聚合消息(按头聚合)
aggregated = weighted_messages.sum(dim=0) # (num_nodes, hidden_dim)
# 更新节点特征
updated = self.update_net(torch.cat([x[:, :self.num_heads * self.head_dim], aggregated], dim=-1))
return updated6.5 神经网络反向传播作为消息传递
神经网络中的反向传播算法与因子图BP有深刻的数学联系:
正向传播(因子图视角):
- 每层是一个因子节点
- 神经元激活是变量
- 权重是因子势函数
反向传播(消息传递视角):
- 梯度从输出传回输入
- 链式法则对应消息组合
- 每层的局部梯度是消息函数
class BPAsMessagePassing:
"""
将反向传播解释为消息传递
展示BP与因子图BP的数学联系
"""
def forward_message(self, x, W):
"""
正向消息:x -> z = Wx + b
等价于因子节点的消息传递
"""
z = torch.matmul(x, W.T)
return z
def backward_message(self, grad_z, W):
"""
反向消息:∂L/∂x = (∂L/∂z) · W
等价于变量节点的消息传递
"""
grad_x = torch.matmul(grad_z, W)
return grad_x
def weight_gradient_message(self, grad_z, x):
"""
权重梯度:∂L/∂W = (∂L/∂z)^T · x
等价于因子节点的边际化
"""
grad_W = torch.matmul(grad_z.T, x)
return grad_W7 现代扩展
7.1 期望传播(Expectation Propagation)
期望传播(EP)由Minka提出,是BP的变分扩展,适用于难以精确边际化的因子。
7.1.1 基本思想
EP用指数族分布的乘积近似后验分布:
每个 是对应因子的瘦息分布(cavity distribution)。
7.1.2 EP更新规则
对于因子 ,EP的更新步骤:
- 构造瘦息分布:
- 计算 tilt 分布:
- 匹配矩:
其中 是指数族分布族。
7.1.3 高斯EP的实现
class ExpectationPropagation:
"""
期望传播实现(高斯情况)
"""
def __init__(self, num_vars, num_factors):
self.num_vars = num_vars
self.factors = {} # factor_id -> potential function
# 初始化瘦息分布参数
self.cavity_precision = {} # (factor_id, var_id) -> precision
self.cavity_mean = {} # (factor_id, var_id) -> mean
def cavity_update(self, factor_id, var_id):
"""
构造瘦息分布
q_{-a}(x) = Π_{b≠a} q_b(x)
对于高斯分布,乘积对应精度矩阵和均值向量的加法
"""
# 从当前信念中移除该因子的贡献
# 简化实现:假设单变量情况
pass
def moment_match(self, factor_id, var_id, tilt_dist):
"""
矩匹配
从tilt分布计算均值和方差,更新瘦息分布
"""
# 计算tilt分布的矩
new_mean = tilt_dist.mean()
new_var = tilt_dist.variance()
# 更新瘦息分布参数
return new_mean, new_var
def run(self, max_iter=100, tol=1e-6):
"""运行EP迭代"""
for iteration in range(max_iter):
max_change = 0
for factor_id, factor in self.factors.items():
for var_id in factor['vars']:
# 1. 计算瘦息分布
cavity = self.cavity_update(factor_id, var_id)
# 2. 计算tilt分布
tilt = self.compute_tilt_distribution(factor, cavity)
# 3. 矩匹配
new_mean, new_var = self.moment_match(factor_id, var_id, tilt)
max_change = max(max_change,
abs(new_mean - cavity.mean) +
abs(new_var - cavity.variance))
if max_change < tol:
print(f"EP converged after {iteration + 1} iterations")
break
return self.compute_posterior()7.2 变分消息传递
变分消息传递是变分推断与消息传递的结合,适用于大规模近似推断。
7.2.1 平均场变分推断
假设变分分布可分解:
最小化 等价于最大化ELBO:
7.2.2 变分消息更新
变量 的最优变分分布:
这正是消息传递框架中的消息计算!
class VariationalMessagePassing:
"""
变分消息传递
实现平均场变分推断的消息传递形式
"""
def __init__(self, num_vars, num_states):
self.num_vars = num_vars
self.num_states = num_states
# 变分参数
self.variational_params = nn.Parameter(
torch.randn(num_vars, num_states)
)
def compute_expected_log_potential(self, factor, q_dists):
"""
计算 E_q[log f_a(X_a)]
这是变分消息的核心计算
"""
vars_in_factor = factor['vars']
potential = factor['potential']
expected = 0.0
for state_config in itertools.product(range(self.num_states), repeat=len(vars_in_factor)):
# 计算该配置的log势能
log_pot = potential[state_config]
# 计算各变量的变分概率
for i, var_id in enumerate(vars_in_factor):
log_pot += torch.log(q_dists[var_id][state_config[i]])
expected += torch.exp(log_pot)
return expected
def variational_message_update(self, var_id, factors, q_dists):
"""
变量节点的变分消息更新
log q_i*(x_i) ∝ E_{q_-i}[log p(X)]
"""
new_log_q = torch.zeros(self.num_states)
for factor in factors:
if var_id not in factor['vars']:
continue
# 计算期望log势能
expected_log_pot = self.compute_expected_log_potential(factor, q_dists)
new_log_q += expected_log_pot
# 归一化
new_q = F.softmax(new_log_q, dim=-1)
return new_q
def run_variational_inference(self, factors, max_iter=100):
"""
运行变分推断
"""
q_dists = [torch.ones(self.num_states) / self.num_states
for _ in range(self.num_vars)]
for iteration in range(max_iter):
new_q_dists = []
for var_id in range(self.num_vars):
# 收集涉及该变量的因子
relevant_factors = [f for f in factors if var_id in f['vars']]
# 更新变分分布
new_q = self.variational_message_update(var_id, relevant_factors, q_dists)
new_q_dists.append(new_q)
# 计算变化
max_change = max(torch.abs(new_q - old_q).max()
for new_q, old_q in zip(new_q_dists, q_dists))
q_dists = new_q_dists
if max_change < 1e-6:
print(f"Converged after {iteration + 1} iterations")
break
return q_dists7.3 粒子消息传递
粒子消息传递(Particle Message Passing)使用蒙特卡洛采样近似消息,适用于复杂非共轭模型。
7.3.1 基本思想
用粒子集合 表示分布:
7.3.2 粒子消息更新
class ParticleMessagePassing:
"""
粒子消息传递
使用重要性采样近似消息传递
"""
def __init__(self, num_particles=100):
self.num_particles = num_particles
def sample_particles(self, proposal_dist, num_samples):
"""从提议分布采样粒子"""
samples = proposal_dist.sample((num_samples,))
return samples
def compute_importance_weights(self, target_log_prob, proposal_log_prob, samples):
"""
计算重要性权重
w ∝ p(x) / q(x)
"""
target_log_probs = target_log_prob(samples)
proposal_log_probs = proposal_log_prob(samples)
log_weights = target_log_probs - proposal_log_probs
weights = F.softmax(log_weights, dim=0)
return weights
def particle_message_update(self, factor, particles, weights):
"""
因子到变量的粒子消息更新
μ_{f→x}(x) ≈ Σ_s w_s f(x, x_-^{(s)}) δ(x - x^{(s)})
"""
# 计算每个粒子的势函数值
factor_values = factor.potential(particles) # (num_particles,)
# 加权
weighted_values = weights * factor_values
# 重采样(可选)
new_particles = self.resample(particles, weights)
return new_particles, weighted_values
def resample(self, particles, weights):
"""多项式重采样"""
num_particles = particles.shape[0]
indices = torch.multinomial(weights, num_particles, replacement=True)
return particles[indices]
def run(self, factors, num_iter=10):
"""运行粒子消息传递"""
# 初始化粒子
particles = {i: torch.randn(self.num_particles)
for i in range(num_vars)}
weights = {i: torch.ones(self.num_particles) / self.num_particles
for i in range(num_vars)}
for _ in range(num_iter):
# 消息传递迭代
for factor in factors:
# 更新涉及该因子的变量的粒子
pass
return particles, weights7.4 方法比较
| 方法 | 消息形式 | 适用范围 | 计算复杂度 |
|---|---|---|---|
| 精确BP | 闭式 | 树结构 | |
| 循环BP | 迭代 | 有环图 | 取决于收敛 |
| 期望传播 | 指数族近似 | 一般图 | 每步 |
| 变分消息传递 | 变分分布 | 大规模 | 可并行 |
| 粒子消息传递 | 粒子集合 | 非共轭模型 | 每步 |
8 PyTorch完整实现
8.1 因子图类定义
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import itertools
@dataclass
class VariableNode:
"""变量节点"""
id: int
name: str
num_states: int
domain: Optional[torch.Tensor] = None # 连续变量的取值范围
@dataclass
class FactorNode:
"""因子节点"""
id: int
name: str
variable_ids: List[int]
potential: torch.Tensor # 势函数(未归一化)
class FactorGraph(nn.Module):
"""
因子图类
支持构建、消息传递和推断
"""
def __init__(self, name="FactorGraph"):
super().__init__()
self.name = name
self.variables: Dict[int, VariableNode] = {}
self.factors: Dict[int, FactorNode] = {}
self.adjacency: Dict[int, List[int]] = {} # var_id -> [factor_ids]
self.factor_to_vars: Dict[int, List[int]] = {} # factor_id -> [var_ids]
# 消息缓存
self.messages_var_to_factor: Dict[Tuple[int, int], torch.Tensor] = {}
self.messages_factor_to_var: Dict[Tuple[int, int], torch.Tensor] = {}
def add_variable(self, var_id: int, name: str, num_states: int = None,
domain: torch.Tensor = None):
"""添加变量节点"""
if num_states is None and domain is None:
raise ValueError("Must specify either num_states or domain")
self.variables[var_id] = VariableNode(
id=var_id,
name=name,
num_states=num_states or len(domain),
domain=domain
)
self.adjacency[var_id] = []
def add_factor(self, factor_id: int, name: str, variable_ids: List[int],
potential: torch.Tensor):
"""
添加因子节点
Args:
potential: 势函数张量,维度与variable_ids对应
"""
self.factors[factor_id] = FactorNode(
id=factor_id,
name=name,
variable_ids=variable_ids,
potential=potential
)
self.factor_to_vars[factor_id] = variable_ids
for var_id in variable_ids:
if var_id not in self.adjacency:
self.adjacency[var_id] = []
self.adjacency[var_id].append(factor_id)
def get_variable(self, var_id: int) -> VariableNode:
return self.variables[var_id]
def get_factor(self, factor_id: int) -> FactorNode:
return self.factors[factor_id]8.2 和积算法实现
class SumProductAlgorithm:
"""
和积算法(Sum-Product Algorithm)实现
支持树结构和有环图的近似推断
"""
def __init__(self, graph: FactorGraph, damping: float = 0.0):
"""
Args:
graph: 因子图
damping: 阻尼因子 (0-1),用于LBP稳定化
"""
self.graph = graph
self.damping = damping
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def initialize_messages(self):
"""初始化消息"""
# 变量到因子的消息初始化为均匀分布
for var_id, var in self.graph.variables.items():
for factor_id in self.graph.adjacency[var_id]:
key = (var_id, factor_id)
self.graph.messages_var_to_factor[key] = torch.ones(
var.num_states, device=self.device
) / var.num_states
# 因子到变量的消息初始化为因子势函数
for factor_id, factor in self.graph.factors.items():
for var_id in factor.variable_ids:
key = (factor_id, var_id)
# 边缘化势函数到该变量
msg = self._marginalize_potential(factor, var_id)
self.graph.messages_factor_to_var[key] = msg
def _marginalize_potential(self, factor: FactorNode, target_var_id: int) -> torch.Tensor:
"""将势函数边缘化到目标变量"""
potential = factor.potential.to(self.device)
var_ids = factor.variable_ids
target_idx = var_ids.index(target_var_id)
# 对所有其他维度求和
axes = [i for i in range(len(var_ids)) if i != target_idx]
if axes:
marginal = torch.sum(potential, dim=axes)
else:
marginal = potential
# 归一化
marginal = marginal / marginal.sum()
return marginal
def compute_variable_to_factor_message(self, var_id: int,
factor_id: int) -> torch.Tensor:
"""计算变量到因子的消息"""
var = self.graph.variables[var_id]
# 消息是所有其他因子传来消息的乘积
msg = torch.ones(var.num_states, device=self.device)
for other_factor_id in self.graph.adjacency[var_id]:
if other_factor_id != factor_id:
key = (other_factor_id, var_id)
if key in self.graph.messages_factor_to_var:
msg = msg * self.graph.messages_factor_to_var[key]
# 归一化
msg = msg / msg.sum()
return msg
def compute_factor_to_variable_message(self, factor_id: int,
var_id: int) -> torch.Tensor:
"""计算因子到变量的消息"""
factor = self.graph.factors[factor_id]
var_ids = factor.variable_ids
var_idx = var_ids.index(var_id)
# 获取势函数
potential = factor.potential.to(self.device)
# 计算消息:边缘化势函数并乘以传入消息
# 简化实现:假设势函数维度不大
num_vars = len(var_ids)
num_states = self.graph.variables[var_id].num_states
if num_vars == 1:
# 一元因子
msg = potential
elif num_vars == 2:
# 二元因子
other_var_id = var_ids[1 - var_idx]
other_msg = self.graph.messages_var_to_factor[(other_var_id, factor_id)]
if var_idx == 0:
msg = torch.sum(potential * other_msg.unsqueeze(1), dim=1)
else:
msg = torch.sum(potential * other_msg.unsqueeze(0), dim=0)
else:
# 通用实现
msg = self._general_factor_message(potential, var_ids, var_idx)
# 归一化
msg = msg / msg.sum()
return msg
def _general_factor_message(self, potential: torch.Tensor,
var_ids: List[int],
target_idx: int) -> torch.Tensor:
"""通用因子消息计算(支持任意数量变量)"""
# 收集所有传入消息
incoming_messages = []
axes_to_sum = []
for i, var_id in enumerate(var_ids):
if i != target_idx:
msg = self.graph.messages_var_to_factor[(var_id, self.graph.factors[var_ids[0]].id)]
incoming_messages.append((i, msg))
axes_to_sum.append(i)
# 乘以势函数
result = potential.clone()
for axis, msg in incoming_messages:
# 为消息添加维度以便广播
shape = [1] * len(var_ids)
shape[axis] = -1
result = result * msg.view(shape)
# 求和边缘化
msg = torch.sum(result, dim=axes_to_sum)
return msg
def run_tree_bp(self) -> Dict[int, torch.Tensor]:
"""
在树结构图上运行和积算法
Returns:
beliefs: 每个变量的边缘分布
"""
# 找到根节点(选择第一个变量)
root_var_id = list(self.graph.variables.keys())[0]
# 计算节点顺序用于后序遍历
parent_map, order = self._get_traversal_order(root_var_id)
# 自底向上:计算向叶子方向的消息
for var_id in reversed(order):
for factor_id in self.graph.adjacency[var_id]:
if parent_map.get(var_id) != factor_id:
# 这是一个叶子方向的因子
pass
# 自顶向下:传递向根方向的消息
for var_id in order:
parent_factor = parent_map.get(var_id)
if parent_factor is not None:
# 计算从父因子到该变量的消息
msg = self.compute_factor_to_variable_message(parent_factor, var_id)
self.graph.messages_factor_to_var[(parent_factor, var_id)] = msg
# 计算信念
return self.compute_beliefs()
def _get_traversal_order(self, root_var_id: int) -> Tuple[Dict, List]:
"""获取遍历顺序和父子关系"""
parent_map = {root_var_id: None}
order = [root_var_id]
# BFS遍历
queue = [root_var_id]
while queue:
var_id = queue.pop(0)
for factor_id in self.graph.adjacency[var_id]:
for next_var_id in self.graph.factor_to_vars[factor_id]:
if next_var_id not in parent_map:
parent_map[next_var_id] = factor_id
order.append(next_var_id)
queue.append(next_var_id)
return parent_map, order
def run_loopy_bp(self, max_iter: int = 100, tol: float = 1e-6,
schedule: str = 'random') -> Dict[int, torch.Tensor]:
"""
在有环图上运行循环置信传播
Args:
max_iter: 最大迭代次数
tol: 收敛容忍度
schedule: 调度策略 ('random', 'residual', 'flooding')
Returns:
beliefs: 每个变量的边缘分布
"""
self.initialize_messages()
best_beliefs = None
best_energy = float('inf')
for iteration in range(max_iter):
max_change = 0.0
if schedule == 'random':
# 随机调度
var_ids = list(self.graph.variables.keys())
np.random.shuffle(var_ids)
for var_id in var_ids:
for factor_id in self.graph.adjacency[var_id]:
# 计算并更新消息
new_msg = self.compute_variable_to_factor_message(var_id, factor_id)
old_msg = self.graph.messages_var_to_factor.get(
(var_id, factor_id), torch.ones_like(new_msg)
)
# 阻尼
if self.damping > 0:
new_msg = ((1 - self.damping) * new_msg +
self.damping * old_msg)
change = torch.abs(new_msg - old_msg).max().item()
max_change = max(max_change, change)
self.graph.messages_var_to_factor[(var_id, factor_id)] = new_msg
# 更新反向消息
reverse_msg = self.compute_factor_to_variable_message(
factor_id, var_id
)
if self.damping > 0:
old_reverse = self.graph.messages_factor_to_var.get(
(factor_id, var_id), torch.ones_like(reverse_msg)
)
reverse_msg = ((1 - self.damping) * reverse_msg +
self.damping * old_reverse)
self.graph.messages_factor_to_var[(factor_id, var_id)] = reverse_msg
elif schedule == 'flooding':
# 洪水算法:同时更新所有消息
new_var_to_factor = {}
new_factor_to_var = {}
for var_id, var in self.graph.variables.items():
for factor_id in self.graph.adjacency[var_id]:
new_var_to_factor[(var_id, factor_id)] = \
self.compute_variable_to_factor_message(var_id, factor_id)
for factor_id, factor in self.graph.factors.items():
for var_id in factor.variable_ids:
new_factor_to_var[(factor_id, var_id)] = \
self.compute_factor_to_variable_message(factor_id, var_id)
# 更新所有消息
for key, msg in new_var_to_factor.items():
old_msg = self.graph.messages_var_to_factor.get(
key, torch.ones_like(msg)
)
change = torch.abs(msg - old_msg).max().item()
max_change = max(max_change, change)
if self.damping > 0:
msg = (1 - self.damping) * msg + self.damping * old_msg
self.graph.messages_var_to_factor[key] = msg
for key, msg in new_factor_to_var.items():
old_msg = self.graph.messages_factor_to_var.get(
key, torch.ones_like(msg)
)
if self.damping > 0:
msg = (1 - self.damping) * msg + self.damping * old_msg
self.graph.messages_factor_to_var[key] = msg
# 计算当前信念
beliefs = self.compute_beliefs()
# 计算伪对数似然(用于早停)
energy = self._compute_pseudo_energy(beliefs)
if energy < best_energy:
best_energy = energy
best_beliefs = {k: v.clone() for k, v in beliefs.items()}
if max_change < tol:
print(f"LBP converged after {iteration + 1} iterations")
break
return best_beliefs if best_beliefs else beliefs
def compute_beliefs(self) -> Dict[int, torch.Tensor]:
"""计算所有变量的信念(边缘分布)"""
beliefs = {}
for var_id, var in self.graph.variables.items():
belief = torch.ones(var.num_states, device=self.device)
for factor_id in self.graph.adjacency[var_id]:
key = (factor_id, var_id)
if key in self.graph.messages_factor_to_var:
belief = belief * self.graph.messages_factor_to_var[key]
# 归一化
belief = belief / belief.sum()
beliefs[var_id] = belief
return beliefs
def _compute_pseudo_energy(self, beliefs: Dict[int, torch.Tensor]) -> float:
"""
计算伪能量(用于监测收敛)
近似于负对数似然
"""
energy = 0.0
for factor_id, factor in self.graph.factors.items():
# 计算因子的期望能量
potential = factor.potential.to(self.device)
for var_id in factor.variable_ids:
belief = beliefs[var_id]
# 简化:计算势函数的期望
energy -= torch.sum(belief * torch.log(potential + 1e-10))
return energy.item()8.3 高斯消息传递实现
class GaussianMessagePassing:
"""
高斯消息传递实现
用于线性高斯模型的精确推断
"""
def __init__(self, device=None):
self.device = device or torch.device('cpu')
def gaussian_product(self, lambda1: torch.Tensor, xi1: torch.Tensor,
lambda2: torch.Tensor, xi2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
高斯分布的乘积
N(μ1, Σ1) * N(μ2, Σ2) = N(μ, Σ)
其中:
Σ = (Σ1⁻¹ + Σ2⁻¹)⁻¹
μ = Σ (Σ1⁻¹μ1 + Σ2⁻¹μ2)
自然参数形式:
Λ = Λ1 + Λ2
ξ = ξ1 + ξ2
"""
# 精度矩阵形式
Lambda = lambda1 + lambda2
xi = xi1 + xi2
return Lambda, xi
def gaussian_marginalize(self, Lambda: torch.Tensor, xi: torch.Tensor,
indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
高斯分布的边际化
保持指定索引的变量
"""
Lambda_mm = Lambda[indices][:, indices]
xi_m = xi[indices]
return Lambda_mm, xi_m
def kalman_update(self, mu: torch.Tensor, Sigma: torch.Tensor,
y: torch.Tensor, H: torch.Tensor, R: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Kalman更新(高斯BP的特殊情况)
Args:
mu: 先验均值
Sigma: 先验协方差
y: 观测值
H: 观测矩阵
R: 观测噪声协方差
Returns:
mu_posterior, Sigma_posterior: 后验均值和协方差
"""
# 预测
S = H @ Sigma @ H.T + R # 观测预测协方差
K = Sigma @ H.T @ torch.linalg.inv(S) # Kalman增益
# 更新
mu_posterior = mu + K @ (y - H @ mu)
Sigma_posterior = (torch.eye(mu.shape[0], device=self.device) - K @ H) @ Sigma
return mu_posterior, Sigma_posterior
def kalman_predict(self, mu: torch.Tensor, Sigma: torch.Tensor,
A: torch.Tensor, Q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Kalman预测(状态转移)
"""
mu_pred = A @ mu
Sigma_pred = A @ Sigma @ A.T + Q
return mu_pred, Sigma_pred
def run_linear_gaussian_bsf(self, observations: List[Tuple[int, torch.Tensor, torch.Tensor]],
F: torch.Tensor, Q: torch.Tensor,
H: torch.Tensor, R: torch.Tensor,
x0: torch.Tensor, Sigma0: torch.Tensor) -> Tuple[List, List]:
"""
运行线性高斯系统的最优估计(Kalman滤波/平滑)
Args:
observations: [(time, y, H), ...] 观测列表
F: 状态转移矩阵
Q: 过程噪声协方差
H: 观测矩阵
R: 观测噪声协方差
x0: 初始均值
Sigma0: 初始协方差
Returns:
filtered_means, filtered_covs: 滤波结果
smoothed_means, smoothed_covs: 平滑结果
"""
T = max(obs[0] for obs in observations) + 1
# 初始化
mu = x0
Sigma = Sigma0
filtered_means = []
filtered_covs = []
# 前向滤波
for t in range(T):
# 预测
mu_pred, Sigma_pred = self.kalman_predict(mu, Sigma, F, Q)
# 检查该时间是否有观测
obs_t = [(y, H_t) for (time, y, H_t) in observations if time == t]
if obs_t:
for y, H_t in obs_t:
# 更新
mu, Sigma = self.kalman_update(mu_pred, Sigma_pred, y, H_t, R)
mu_pred, Sigma_pred = mu, Sigma # 用于下一个观测
else:
mu, Sigma = mu_pred, Sigma_pred
filtered_means.append(mu)
filtered_covs.append(Sigma)
# 后向平滑
smoothed_means = filtered_means.copy()
smoothed_covs = filtered_covs.copy()
mu_smooth = filtered_means[-1]
Sigma_smooth = filtered_covs[-1]
for t in range(T - 2, -1, -1):
# 预测
mu_pred, Sigma_pred = self.kalman_predict(
filtered_means[t], filtered_covs[t], F, Q
)
# 平滑增益
try:
G = filtered_covs[t] @ F.T @ torch.linalg.inv(Sigma_pred)
except:
G = filtered_covs[t] @ F.T @ (Sigma_pred + 1e-6 * torch.eye(Sigma_pred.shape[0], device=self.device)).inverse()
# 平滑
mu_smooth = filtered_means[t] + G @ (mu_smooth - mu_pred)
Sigma_smooth = filtered_covs[t] + G @ (Sigma_smooth - Sigma_pred) @ G.T
smoothed_means[t] = mu_smooth
smoothed_covs[t] = Sigma_smooth
return (filtered_means, filtered_covs), (smoothed_means, smoothed_covs)8.4 使用示例
def example_simple():
"""简单示例:二元变量因子图"""
# 创建因子图
graph = FactorGraph("Simple Example")
# 添加变量
graph.add_variable(0, "x0", num_states=2)
graph.add_variable(1, "x1", num_states=2)
graph.add_variable(2, "x2", num_states=2)
# 添加因子
# f0(x0, x1) - 鼓励 x0 = x1
potential01 = torch.tensor([[2.0, 0.5],
[0.5, 2.0]])
graph.add_factor(0, "f01", [0, 1], potential01)
# f1(x1, x2) - 鼓励 x1 = x2
potential12 = torch.tensor([[2.0, 0.5],
[0.5, 2.0]])
graph.add_factor(1, "f12", [1, 2], potential12)
# f2(x0) - x0 的先验
potential0 = torch.tensor([0.3, 0.7])
graph.add_factor(2, "f0", [0], potential0)
# 运行和积算法
spa = SumProductAlgorithm(graph, damping=0.3)
beliefs = spa.run_loopy_bp(max_iter=100, tol=1e-6)
print("边缘分布:")
for var_id, belief in beliefs.items():
var = graph.get_variable(var_id)
print(f" {var.name}: {belief.numpy()}")
return beliefs
def example_gaussian():
"""示例:高斯消息传递(Kalman滤波)"""
# 状态空间模型参数
dt = 0.1 # 时间步长
F = torch.tensor([[1, dt],
[0, 1]]) # 状态转移
Q = torch.tensor([[0.01, 0],
[0, 0.01]]) # 过程噪声
H = torch.tensor([[1, 0]]) # 观测矩阵
R = torch.tensor([[0.1]]) # 观测噪声
# 初始状态
x0 = torch.tensor([0, 1])
Sigma0 = torch.tensor([[1, 0],
[0, 1]])
# 生成观测数据
np.random.seed(42)
true_states = [x0.numpy()]
observations = []
for t in range(20):
# 真实状态转移
x_true = F @ torch.tensor(true_states[-1]) + np.random.randn(2) * 0.1
true_states.append(x_true)
# 观测
y = H @ x_true + np.random.randn(1) * np.sqrt(R[0, 0])
observations.append((t, torch.tensor(y), H))
# 运行Kalman滤波/平滑
gmp = GaussianMessagePassing()
(filtered_means, filtered_covs), (smoothed_means, smoothed_covs) = \
gmp.run_linear_gaussian_bsf(observations, F, Q, H, R, x0, Sigma0)
print("滤波结果(显示前5个时间步):")
for t in range(5):
print(f" t={t}: mean={filtered_means[t].numpy()}, std={np.sqrt(filtered_covs[t].numpy().diagonal())}")
print("\n平滑结果(显示前5个时间步):")
for t in range(5):
print(f" t={t}: mean={smoothed_means[t].numpy()}, std={np.sqrt(smoothed_covs[t].numpy().diagonal())}")
return filtered_means, filtered_covs, smoothed_means, smoothed_covs
if __name__ == "__main__":
print("=" * 60)
print("示例1: 简单离散因子图")
print("=" * 60)
example_simple()
print("\n" + "=" * 60)
print("示例2: 高斯消息传递(Kalman滤波)")
print("=" * 60)
example_gaussian()9 理论总结
9.1 核心概念回顾
| 概念 | 定义 | 关键性质 |
|---|---|---|
| 因子图 | 变量-因子二部图 | 分解联合分布 |
| 和积算法 | 树结构的精确BP | 推断 |
| 循环BP | 有环图的近似推断 | 迭代收敛 |
| 高斯BP | 线性高斯的精确推断 | 闭式解 |
| 期望传播 | 指数族近似推断 | 变分扩展 |
9.2 与深度学习的统一视角
消息传递机制是连接概率推断与深度学习的桥梁:
┌─────────────────────────────────────────────────────────────────────────┐
│ 消息传递的统一视角 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────────┐ ┌───────────────────┐ │
│ │ 概率图模型 │ │ 深度学习 │ │
│ ├───────────────────┤ ├───────────────────┤ │
│ │ 因子节点 = 消息函数│ │ 神经网络层 │ │
│ │ 变量节点 = 隐状态 │ │ 神经元 │ │
│ │ 消息传递 = 推断 │ │ 前向传播 │ │
│ │ 信念 = 边缘分布 │ │ 激活值 │ │
│ │ 优化 = 最大化似然 │ │ 梯度下降 │ │
│ └───────────────────┘ └───────────────────┘ │
│ \ / │
│ \ / │
│ ▼ ▼ │
│ ┌─────────────────────────────────┐ │
│ │ 统一的消息传递框架 │ │
│ │ h_v^{(l+1)} = Update(h_v^{(l)}, │ │
│ │ AGG_{u∈N(v)} │ │
│ │ Message(h_u^{(l)}│ │
│ │ h_v^{(l)}, e_{uv})) │
│ └─────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
9.3 未来研究方向
- 可扩展性:大规模图的高效消息传递
- 异构图:多关系、多模态图的消息传递
- 动态图:时变图结构的消息传递
- 理论保证:收敛性、表达能力的形式化分析
- 与Transformer的融合:注意力机制作为软消息传递
参考文献
相关文档
- 因子图与消息传递算法 — 基础概念
- 因子图与置信传播的统一框架 — 置信传播框架
- 消息传递神经网络 — MPNN框架
- GNN消息传递机制深度解析 — GNN消息传递
- 概率与期望 — 概率论基础
- 卡尔曼滤波器 — 高斯BP的特例
Footnotes
-
Kschischang, F. R., Frey, B. J., & Loeliger, H. A. (2001). Factor graphs and the sum-product algorithm. IEEE Transactions on Information Theory, 47(2), 498-519. ↩