因子图与消息传递现代理论

1 引言

因子图(Factor Graph)作为概率图模型的核心表示框架,为联合概率分布的分解与高效推断提供了统一的数学语言。自Kschischang等人于2001年系统化提出因子图理论以来1,消息传递算法已成为统计推断、机器学习和深度学习领域的基石性技术。

本文系统阐述因子图与消息传递的现代理论框架,深入分析和积算法与置信传播的数学本质,探讨高斯消息传递在线性系统中的精确推断能力,并揭示其与神经网络的深刻联系。

1.1 与现有内容的关系

本文建立在以下相关文档的理论基础之上:

1.2 符号约定

本文采用以下符号约定:

符号含义
随机变量集合
个变量
个因子节点
节点 的邻居集合
从变量 到因子 的消息
从因子 到变量 的消息
配分函数(归一化常数)

2 因子图基础回顾

2.1 因子图的数学定义

定义2.1(因子图):因子图是一个二分图 ,其中:

  • 变量节点集合
  • 因子节点集合
  • 是连接变量与因子的边集合

每个变量节点 对应一个随机变量,每个因子节点 对应一个局部势函数(potential function)。

2.2 联合分布的因子分解

表示所有变量的集合, 表示与因子 相连的变量集合。因子图表示的联合分布为:

其中 配分函数(partition function),定义为:

:势函数 不必是归一化的概率分布,只需要是非负函数。归一化由配分函数 完成。

2.3 因子图表示能力的分析

因子图的表示能力源于其对条件独立性的精确编码。由因子图的结构可以直接读出变量间的条件独立性:

定理2.1(条件独立性):在因子图 中,给定变量集合 ,变量 条件独立当且仅当所有从 的路径都被 阻断。

这一性质使得因子图成为编码概率论中复杂依赖结构的利器。

2.4 与其他图模型的比较

因子图与两类经典图模型有着深刻的联系:

图模型类型表示方式因子图视角
贝叶斯网络有向无环图每个条件概率分布 是一个因子
马尔可夫随机场无向图每个最大团上的势函数是一个因子
因子图二分图显式分离变量与因子

因子图的二分图结构是其独特优势:变量节点只与因子节点相连,因子节点只与变量节点相连。这种结构消除了歧义,使得消息传递算法的推导更加清晰。


3 和积算法(Sum-Product Algorithm)深度解析

3.1 算法目标与动机

和积算法(又称置信传播)旨在高效计算因子图中所有变量的边缘概率分布

直接计算边缘分布的复杂度是指数级的(),而和积算法利用因子图的分解结构,将复杂度降低到与因子图的树宽成正比。

3.2 消息传递规则的形式化推导

3.2.1 从因子到变量的消息

设因子节点 连接到变量集合 向变量 传递的消息定义为:

推导:消息 表示在因子 的约束下,变量 应该携带的信息。这通过对因子势函数求边缘化得到,同时”吸收”了来自其他邻居变量的消息。

3.2.2 从变量到因子的消息

设变量节点 连接到因子集合 向因子 传递的消息定义为:

推导:消息 聚合了 从所有其他因子接收到的所有信息。由于这些因子对 的影响是独立的,消息是它们的乘积。

3.2.3 消息传递的计算图

┌─────────────────────────────────────────────────────────────────────────┐
│                        消息传递规则图示                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│    因子节点 f_a                                                          │
│           │                                                              │
│     ┌─────┼─────┐                                                        │
│     │     │     │                                                        │
│     ▼     ▼     ▼                                                        │
│   x_i   x_j   x_k   ← 变量节点                                           │
│     │           │                                                        │
│     │           │                                                        │
│     ▼           ▼                                                        │
│   μ_{f_a→x_i}  μ_{f_a→x_k}                                              │
│                                                                         │
│  消息计算:                                                              │
│  μ_{f_a→x_j}(x_j) = Σ_{x_i,x_k} f_a(x_i,x_j,x_k) · μ_{x_i→f_a}(x_i) · μ_{x_k→f_a}(x_k)  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

3.3 边缘概率计算

当消息传递完成后(树结构下为一次遍历),每个变量的边缘分布可以计算为:

定理3.1(和积算法的正确性):对于树结构的因子图,和积算法计算得到的边缘分布是精确的。

证明梗概:通过归纳法证明每条消息的正确性。基础情况是叶子节点的消息;归纳步骤假设所有已计算消息正确,证明相邻消息正确。

3.4 归一化常数的计算

配分函数 可以通过聚合根节点的消息来计算:

其中 是因子 的局部归一化常数。

3.5 对数域消息传递

为了数值稳定性,实际实现中常使用对数域的消息传递:

然而,加法的对数操作需要替换为 Log-Sum-Exp 操作:

3.6 最大乘积算法(Max-Product Algorithm)

MAP推断(最大后验概率推断)将求和替换为取最大值:

消息更新规则变为:

最大乘积算法与Viterbi算法有深刻联系,在序列标注问题中广泛应用。


4 置信传播(Belief Propagation)

4.1 置信消息的定义

定义4.1(置信):变量 置信(belief)定义为对其边缘分布的近似:

类似地,因子 的置信为:

4.2 消息调度策略

消息传递的顺序(调度)对算法收敛速度有重要影响。

4.2.1 同步消息传递

所有消息同时更新:

优点:易于并行化
缺点:可能振荡,难以收敛

4.2.2 异步消息传递(洪水算法)

每次只更新一个节点的消息,顺序执行。

优点:通常收敛更快
缺点:难以并行化

4.2.3 残差消息传递(Residual Belief Propagation)

每次选择变化最大的消息优先更新:

优先更新具有最大残差的消息。

4.3 收敛性分析

4.3.1 树结构的收敛保证

定理4.1:对于树结构的因子图,和积算法在单次遍历后收敛到精确的边缘分布。

4.3.2 循环结构的挑战

当因子图存在环(cycles)时,消息传递可能:

  1. 收敛到近似解
  2. 振荡而不收敛
  3. 发散(数值不稳定)

4.3.3 收敛的充分条件

定理4.2(势函数有界性):如果所有因子势函数被常数 上下界,即 ,则循环置信传播(LBP)收敛。

直觉:归一化的消息空间是紧的,消息映射是连续的,因此存在不动点。

4.4 循环置信传播(Loopy Belief Propagation)

对于有环因子图,循环置信传播是一种近似推断方法:

class LoopyBeliefPropagation:
    """
    循环置信传播实现
    
    适用于有环图模型的近似推断
    """
    
    def __init__(self, num_states, damping=0.5):
        """
        Args:
            num_states: 每个变量的状态数
            damping: 阻尼因子 (0-1),用于加速收敛
        """
        self.num_states = num_states
        self.damping = damping
    
    def run(self, factors, adjacency, max_iter=100, tol=1e-6):
        """
        执行循环置信传播
        
        Args:
            factors: dict {factor_id: {'vars': [var_ids], 'potential': array}}
            adjacency: dict {var_id: [factor_ids]}
            max_iter: 最大迭代次数
            tol: 收敛容忍度
        
        Returns:
            beliefs: dict {var_id: belief_array}
        """
        num_vars = len(adjacency)
        
        # 初始化消息
        messages_f_to_x = {}  # (factor_id, var_id) -> message
        messages_x_to_f = {}  # (var_id, factor_id) -> message
        
        for _ in range(max_iter):
            max_change = 0
            
            # 更新所有因子到变量的消息
            for f_id, factor in factors.items():
                vars_in_factor = factor['vars']
                potential = factor['potential']
                
                for i, x_id in enumerate(vars_in_factor):
                    # 计算新消息
                    new_msg = self._compute_factor_to_var_message(
                        factor, i, vars_in_factor, messages_x_to_f
                    )
                    
                    old_msg = messages_f_to_x.get((f_id, x_id), 
                                                   torch.ones(self.num_states))
                    
                    # 阻尼更新
                    damped_msg = (1 - self.damping) * new_msg + self.damping * old_msg
                    
                    # 归一化
                    damped_msg = damped_msg / damped_msg.sum()
                    
                    messages_f_to_x[(f_id, x_id)] = damped_msg
                    
                    max_change = max(max_change, 
                                   torch.abs(damped_msg - old_msg).max().item())
            
            # 更新所有变量到因子的消息
            for x_id, factor_ids in adjacency.items():
                for f_id in factor_ids:
                    # 收集来自其他因子的消息
                    other_msg = torch.ones(self.num_states)
                    for other_f_id in factor_ids:
                        if other_f_id != f_id:
                            if (other_f_id, x_id) in messages_f_to_x:
                                other_msg = other_msg * messages_f_to_x[(other_f_id, x_id)]
                    
                    old_msg = messages_x_to_f.get((x_id, f_id), 
                                                   torch.ones(self.num_states))
                    
                    # 阻尼更新
                    damped_msg = (1 - self.damping) * other_msg + self.damping * old_msg
                    
                    # 归一化
                    damped_msg = damped_msg / damped_msg.sum()
                    
                    messages_x_to_f[(x_id, f_id)] = damped_msg
            
            # 检查收敛
            if max_change < tol:
                print(f"Converged after {_+1} iterations")
                break
        
        # 计算最终信念
        beliefs = {}
        for x_id in range(num_vars):
            belief = torch.ones(self.num_states)
            for f_id in adjacency[x_id]:
                if (f_id, x_id) in messages_f_to_x:
                    belief = belief * messages_f_to_x[(f_id, x_id)]
            beliefs[x_id] = belief / belief.sum()
        
        return beliefs
    
    def _compute_factor_to_var_message(self, factor, var_idx, vars_in_factor, messages):
        """计算因子到变量的消息"""
        potential = factor['potential']
        num_vars = len(vars_in_factor)
        
        # 对所有其他变量求和/积分
        result = torch.zeros(self.num_states)
        
        # 简化实现:假设二元因子
        if num_vars == 2:
            other_idx = 1 - var_idx
            for s1 in range(self.num_states):
                for s2 in range(self.num_states):
                    if var_idx == 0:
                        msg_val = potential[s1, s2]
                    else:
                        msg_val = potential[s2, s1]
                    
                    # 乘以来自其他变量的消息
                    if (f"var_{other_idx}", f"factor_{factor['id']}") in messages:
                        other_msg = messages[(f"var_{other_idx}", f"factor_{factor['id']}")][s2]
                        msg_val = msg_val * other_msg
                    
                    result[s1] += msg_val
        
        return result

4.5 置信传播的变体

方法描述适用范围
标准BP精确消息传递树结构
循环BP迭代直到收敛有环图
衰减BP阻尼因子稳定化振荡问题
Tree-Reweighted BP加权消息边界低纠缠图
Fractional BP分数加权近似精度控制

5 高斯消息传递

5.1 线性高斯模型

高斯消息传递是处理线性高斯模型精确推断的强大工具。考虑以下模型:

其中均值 ,协方差 是正定矩阵。

5.2 高斯分布的消息表示

高斯分布可以用自然参数表示:

定义自然参数

  • 精度矩阵
  • 信息向量

则高斯分布可以写为:

5.3 高斯消息传递规则

5.3.1 乘积操作

两个高斯分布的乘积仍是高斯分布:

意义:这正是变量节点的消息传递规则——乘积对应信息组合。

5.3.2 边际化操作

高斯分布的边际化也是高斯分布:

,联合分布为:

则边缘分布为:

其中 是边缘精度矩阵。

意义:这正是因子节点的消息传递规则——边际化对应信息传递。

5.4 Kalman滤波器作为高斯BP

卡尔曼滤波器是高斯BP在时序模型中的特例。

5.4.1 状态空间模型

5.4.2 Kalman滤波的消息传递视角

时间步 t-1                    时间步 t
    │                              │
    ▼                              ▼
┌────────┐                    ┌────────┐
│ 状态先验 │                    │ 观测模型 │
│ x_{t-1}│                    │  y_t   │
└────────┘                    └────────┘
    │                              ▲
    │ μ_{t-1→t}                    │
    ▼                              │
┌────────┐                        │
│ 状态转移 │                        │
│  A      │                        │
└────────┘                        │
    │                              │
    ▼                              │
┌────────┐                    ┌────────┐
│ 预测分布 │ ──────────────────→ │ 更新分布 │
│ p(x_t)  │                    │ p(x_t|y_t) │
└────────┘                    └────────┘

预测步(消息从 传到 ):

更新步(融合观测):

5.5 与变分推断的联系

高斯消息传递与变分推断中的平均场方法有深刻联系。

平均场假设下的变分分布:

变分消息更新(对数域):

对于高斯模型,这简化为矩匹配(moment matching):


6 因子图与神经网络的对应关系

6.1 MPNN的消息传递机制

消息传递神经网络(MPNN)与因子图消息传递有着形式上的深刻对应:

MPNN组件因子图BP组件数学联系
消息函数 因子势函数 可学习的局部变换
聚合操作 求和操作 邻居信息组合
更新函数 置信计算状态更新
迭代层数消息传递步数信息传播范围

6.2 消息传递视角下的GNN

以图注意力网络(GAT)为例,其消息传递可以解释为软化的因子图消息传递

其中注意力系数 可以视为动态调整的势函数权重。

6.3 神经网络层与消息传递的统一

我们可以将神经网络层统一理解为参数化的消息传递

关键区别

方面因子图BP神经网络
消息函数由概率模型固定可学习
目标推断后验分布最小化任务损失
优化闭式/迭代梯度下降

6.4 可微分置信传播

现代深度学习框架可以将BP嵌入神经网络,实现可微分的消息传递

class DifferentiableMessagePassing(nn.Module):
    """
    可微分消息传递层
    
    将因子图BP的消息传递操作参数化
    """
    
    def __init__(self, node_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        # 消息函数(替代因子势函数)
        self.message_net = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 注意力网络(动态势函数)
        self.attention_net = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, num_heads)
        )
        
        # 更新函数
        self.update_net = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
    
    def forward(self, x, edge_index):
        """
        Args:
            x: 节点特征 (num_nodes, node_dim)
            edge_index: 边索引 (2, num_edges)
        
        Returns:
            更新后的节点特征 (num_nodes, hidden_dim)
        """
        row, col = edge_index  # 源节点 -> 目标节点
        
        # 计算消息
        source_features = x[row]  # (num_edges, node_dim)
        target_features = x[col]  # (num_edges, node_dim)
        
        # 拼接源和目标特征
        combined = torch.cat([source_features, target_features], dim=-1)
        
        # 注意力权重
        att_scores = self.attention_net(combined)  # (num_edges, num_heads)
        att_weights = F.softmax(att_scores, dim=0)  # (num_edges, num_heads)
        
        # 消息内容
        messages = self.message_net(combined)  # (num_edges, hidden_dim)
        messages = messages.view(-1, self.num_heads, self.head_dim)
        
        # 加权消息
        weighted_messages = messages * att_weights.unsqueeze(-1)  # (num_edges, heads, head_dim)
        
        # 聚合消息(按头聚合)
        aggregated = weighted_messages.sum(dim=0)  # (num_nodes, hidden_dim)
        
        # 更新节点特征
        updated = self.update_net(torch.cat([x[:, :self.num_heads * self.head_dim], aggregated], dim=-1))
        
        return updated

6.5 神经网络反向传播作为消息传递

神经网络中的反向传播算法与因子图BP有深刻的数学联系:

正向传播(因子图视角):

  • 每层是一个因子节点
  • 神经元激活是变量
  • 权重是因子势函数

反向传播(消息传递视角):

  • 梯度从输出传回输入
  • 链式法则对应消息组合
  • 每层的局部梯度是消息函数
class BPAsMessagePassing:
    """
    将反向传播解释为消息传递
    
    展示BP与因子图BP的数学联系
    """
    
    def forward_message(self, x, W):
        """
        正向消息:x -> z = Wx + b
        
        等价于因子节点的消息传递
        """
        z = torch.matmul(x, W.T)
        return z
    
    def backward_message(self, grad_z, W):
        """
        反向消息:∂L/∂x = (∂L/∂z) · W
        
        等价于变量节点的消息传递
        """
        grad_x = torch.matmul(grad_z, W)
        return grad_x
    
    def weight_gradient_message(self, grad_z, x):
        """
        权重梯度:∂L/∂W = (∂L/∂z)^T · x
        
        等价于因子节点的边际化
        """
        grad_W = torch.matmul(grad_z.T, x)
        return grad_W

7 现代扩展

7.1 期望传播(Expectation Propagation)

期望传播(EP)由Minka提出,是BP的变分扩展,适用于难以精确边际化的因子。

7.1.1 基本思想

EP用指数族分布的乘积近似后验分布:

每个 是对应因子的瘦息分布(cavity distribution)。

7.1.2 EP更新规则

对于因子 ,EP的更新步骤:

  1. 构造瘦息分布
  1. 计算 tilt 分布
  1. 匹配矩

其中 是指数族分布族。

7.1.3 高斯EP的实现

class ExpectationPropagation:
    """
    期望传播实现(高斯情况)
    """
    
    def __init__(self, num_vars, num_factors):
        self.num_vars = num_vars
        self.factors = {}  # factor_id -> potential function
        
        # 初始化瘦息分布参数
        self.cavity_precision = {}  # (factor_id, var_id) -> precision
        self.cavity_mean = {}       # (factor_id, var_id) -> mean
    
    def cavity_update(self, factor_id, var_id):
        """
        构造瘦息分布
        
        q_{-a}(x) = Π_{b≠a} q_b(x)
        
        对于高斯分布,乘积对应精度矩阵和均值向量的加法
        """
        # 从当前信念中移除该因子的贡献
        # 简化实现:假设单变量情况
        pass
    
    def moment_match(self, factor_id, var_id, tilt_dist):
        """
        矩匹配
        
        从tilt分布计算均值和方差,更新瘦息分布
        """
        # 计算tilt分布的矩
        new_mean = tilt_dist.mean()
        new_var = tilt_dist.variance()
        
        # 更新瘦息分布参数
        return new_mean, new_var
    
    def run(self, max_iter=100, tol=1e-6):
        """运行EP迭代"""
        for iteration in range(max_iter):
            max_change = 0
            
            for factor_id, factor in self.factors.items():
                for var_id in factor['vars']:
                    # 1. 计算瘦息分布
                    cavity = self.cavity_update(factor_id, var_id)
                    
                    # 2. 计算tilt分布
                    tilt = self.compute_tilt_distribution(factor, cavity)
                    
                    # 3. 矩匹配
                    new_mean, new_var = self.moment_match(factor_id, var_id, tilt)
                    
                    max_change = max(max_change, 
                                   abs(new_mean - cavity.mean) + 
                                   abs(new_var - cavity.variance))
            
            if max_change < tol:
                print(f"EP converged after {iteration + 1} iterations")
                break
        
        return self.compute_posterior()

7.2 变分消息传递

变分消息传递是变分推断与消息传递的结合,适用于大规模近似推断。

7.2.1 平均场变分推断

假设变分分布可分解:

最小化 等价于最大化ELBO:

7.2.2 变分消息更新

变量 的最优变分分布

这正是消息传递框架中的消息计算

class VariationalMessagePassing:
    """
    变分消息传递
    
    实现平均场变分推断的消息传递形式
    """
    
    def __init__(self, num_vars, num_states):
        self.num_vars = num_vars
        self.num_states = num_states
        
        # 变分参数
        self.variational_params = nn.Parameter(
            torch.randn(num_vars, num_states)
        )
    
    def compute_expected_log_potential(self, factor, q_dists):
        """
        计算 E_q[log f_a(X_a)]
        
        这是变分消息的核心计算
        """
        vars_in_factor = factor['vars']
        potential = factor['potential']
        
        expected = 0.0
        for state_config in itertools.product(range(self.num_states), repeat=len(vars_in_factor)):
            # 计算该配置的log势能
            log_pot = potential[state_config]
            
            # 计算各变量的变分概率
            for i, var_id in enumerate(vars_in_factor):
                log_pot += torch.log(q_dists[var_id][state_config[i]])
            
            expected += torch.exp(log_pot)
        
        return expected
    
    def variational_message_update(self, var_id, factors, q_dists):
        """
        变量节点的变分消息更新
        
        log q_i*(x_i) ∝ E_{q_-i}[log p(X)]
        """
        new_log_q = torch.zeros(self.num_states)
        
        for factor in factors:
            if var_id not in factor['vars']:
                continue
            
            # 计算期望log势能
            expected_log_pot = self.compute_expected_log_potential(factor, q_dists)
            new_log_q += expected_log_pot
        
        # 归一化
        new_q = F.softmax(new_log_q, dim=-1)
        return new_q
    
    def run_variational_inference(self, factors, max_iter=100):
        """
        运行变分推断
        """
        q_dists = [torch.ones(self.num_states) / self.num_states 
                   for _ in range(self.num_vars)]
        
        for iteration in range(max_iter):
            new_q_dists = []
            
            for var_id in range(self.num_vars):
                # 收集涉及该变量的因子
                relevant_factors = [f for f in factors if var_id in f['vars']]
                
                # 更新变分分布
                new_q = self.variational_message_update(var_id, relevant_factors, q_dists)
                new_q_dists.append(new_q)
            
            # 计算变化
            max_change = max(torch.abs(new_q - old_q).max() 
                           for new_q, old_q in zip(new_q_dists, q_dists))
            
            q_dists = new_q_dists
            
            if max_change < 1e-6:
                print(f"Converged after {iteration + 1} iterations")
                break
        
        return q_dists

7.3 粒子消息传递

粒子消息传递(Particle Message Passing)使用蒙特卡洛采样近似消息,适用于复杂非共轭模型。

7.3.1 基本思想

用粒子集合 表示分布:

7.3.2 粒子消息更新

class ParticleMessagePassing:
    """
    粒子消息传递
    
    使用重要性采样近似消息传递
    """
    
    def __init__(self, num_particles=100):
        self.num_particles = num_particles
    
    def sample_particles(self, proposal_dist, num_samples):
        """从提议分布采样粒子"""
        samples = proposal_dist.sample((num_samples,))
        return samples
    
    def compute_importance_weights(self, target_log_prob, proposal_log_prob, samples):
        """
        计算重要性权重
        
        w ∝ p(x) / q(x)
        """
        target_log_probs = target_log_prob(samples)
        proposal_log_probs = proposal_log_prob(samples)
        
        log_weights = target_log_probs - proposal_log_probs
        weights = F.softmax(log_weights, dim=0)
        
        return weights
    
    def particle_message_update(self, factor, particles, weights):
        """
        因子到变量的粒子消息更新
        
        μ_{f→x}(x) ≈ Σ_s w_s f(x, x_-^{(s)}) δ(x - x^{(s)})
        """
        # 计算每个粒子的势函数值
        factor_values = factor.potential(particles)  # (num_particles,)
        
        # 加权
        weighted_values = weights * factor_values
        
        # 重采样(可选)
        new_particles = self.resample(particles, weights)
        
        return new_particles, weighted_values
    
    def resample(self, particles, weights):
        """多项式重采样"""
        num_particles = particles.shape[0]
        indices = torch.multinomial(weights, num_particles, replacement=True)
        return particles[indices]
    
    def run(self, factors, num_iter=10):
        """运行粒子消息传递"""
        # 初始化粒子
        particles = {i: torch.randn(self.num_particles) 
                     for i in range(num_vars)}
        weights = {i: torch.ones(self.num_particles) / self.num_particles 
                   for i in range(num_vars)}
        
        for _ in range(num_iter):
            # 消息传递迭代
            for factor in factors:
                # 更新涉及该因子的变量的粒子
                pass
        
        return particles, weights

7.4 方法比较

方法消息形式适用范围计算复杂度
精确BP闭式树结构
循环BP迭代有环图取决于收敛
期望传播指数族近似一般图 每步
变分消息传递变分分布大规模可并行
粒子消息传递粒子集合非共轭模型 每步

8 PyTorch完整实现

8.1 因子图类定义

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import itertools
 
 
@dataclass
class VariableNode:
    """变量节点"""
    id: int
    name: str
    num_states: int
    domain: Optional[torch.Tensor] = None  # 连续变量的取值范围
 
 
@dataclass
class FactorNode:
    """因子节点"""
    id: int
    name: str
    variable_ids: List[int]
    potential: torch.Tensor  # 势函数(未归一化)
 
 
class FactorGraph(nn.Module):
    """
    因子图类
    
    支持构建、消息传递和推断
    """
    
    def __init__(self, name="FactorGraph"):
        super().__init__()
        self.name = name
        self.variables: Dict[int, VariableNode] = {}
        self.factors: Dict[int, FactorNode] = {}
        self.adjacency: Dict[int, List[int]] = {}  # var_id -> [factor_ids]
        self.factor_to_vars: Dict[int, List[int]] = {}  # factor_id -> [var_ids]
        
        # 消息缓存
        self.messages_var_to_factor: Dict[Tuple[int, int], torch.Tensor] = {}
        self.messages_factor_to_var: Dict[Tuple[int, int], torch.Tensor] = {}
    
    def add_variable(self, var_id: int, name: str, num_states: int = None, 
                     domain: torch.Tensor = None):
        """添加变量节点"""
        if num_states is None and domain is None:
            raise ValueError("Must specify either num_states or domain")
        
        self.variables[var_id] = VariableNode(
            id=var_id,
            name=name,
            num_states=num_states or len(domain),
            domain=domain
        )
        self.adjacency[var_id] = []
    
    def add_factor(self, factor_id: int, name: str, variable_ids: List[int],
                   potential: torch.Tensor):
        """
        添加因子节点
        
        Args:
            potential: 势函数张量,维度与variable_ids对应
        """
        self.factors[factor_id] = FactorNode(
            id=factor_id,
            name=name,
            variable_ids=variable_ids,
            potential=potential
        )
        self.factor_to_vars[factor_id] = variable_ids
        
        for var_id in variable_ids:
            if var_id not in self.adjacency:
                self.adjacency[var_id] = []
            self.adjacency[var_id].append(factor_id)
    
    def get_variable(self, var_id: int) -> VariableNode:
        return self.variables[var_id]
    
    def get_factor(self, factor_id: int) -> FactorNode:
        return self.factors[factor_id]

8.2 和积算法实现

class SumProductAlgorithm:
    """
    和积算法(Sum-Product Algorithm)实现
    
    支持树结构和有环图的近似推断
    """
    
    def __init__(self, graph: FactorGraph, damping: float = 0.0):
        """
        Args:
            graph: 因子图
            damping: 阻尼因子 (0-1),用于LBP稳定化
        """
        self.graph = graph
        self.damping = damping
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    def initialize_messages(self):
        """初始化消息"""
        # 变量到因子的消息初始化为均匀分布
        for var_id, var in self.graph.variables.items():
            for factor_id in self.graph.adjacency[var_id]:
                key = (var_id, factor_id)
                self.graph.messages_var_to_factor[key] = torch.ones(
                    var.num_states, device=self.device
                ) / var.num_states
        
        # 因子到变量的消息初始化为因子势函数
        for factor_id, factor in self.graph.factors.items():
            for var_id in factor.variable_ids:
                key = (factor_id, var_id)
                # 边缘化势函数到该变量
                msg = self._marginalize_potential(factor, var_id)
                self.graph.messages_factor_to_var[key] = msg
    
    def _marginalize_potential(self, factor: FactorNode, target_var_id: int) -> torch.Tensor:
        """将势函数边缘化到目标变量"""
        potential = factor.potential.to(self.device)
        var_ids = factor.variable_ids
        target_idx = var_ids.index(target_var_id)
        
        # 对所有其他维度求和
        axes = [i for i in range(len(var_ids)) if i != target_idx]
        if axes:
            marginal = torch.sum(potential, dim=axes)
        else:
            marginal = potential
        
        # 归一化
        marginal = marginal / marginal.sum()
        
        return marginal
    
    def compute_variable_to_factor_message(self, var_id: int, 
                                           factor_id: int) -> torch.Tensor:
        """计算变量到因子的消息"""
        var = self.graph.variables[var_id]
        
        # 消息是所有其他因子传来消息的乘积
        msg = torch.ones(var.num_states, device=self.device)
        
        for other_factor_id in self.graph.adjacency[var_id]:
            if other_factor_id != factor_id:
                key = (other_factor_id, var_id)
                if key in self.graph.messages_factor_to_var:
                    msg = msg * self.graph.messages_factor_to_var[key]
        
        # 归一化
        msg = msg / msg.sum()
        
        return msg
    
    def compute_factor_to_variable_message(self, factor_id: int,
                                           var_id: int) -> torch.Tensor:
        """计算因子到变量的消息"""
        factor = self.graph.factors[factor_id]
        var_ids = factor.variable_ids
        var_idx = var_ids.index(var_id)
        
        # 获取势函数
        potential = factor.potential.to(self.device)
        
        # 计算消息:边缘化势函数并乘以传入消息
        # 简化实现:假设势函数维度不大
        num_vars = len(var_ids)
        num_states = self.graph.variables[var_id].num_states
        
        if num_vars == 1:
            # 一元因子
            msg = potential
        elif num_vars == 2:
            # 二元因子
            other_var_id = var_ids[1 - var_idx]
            other_msg = self.graph.messages_var_to_factor[(other_var_id, factor_id)]
            
            if var_idx == 0:
                msg = torch.sum(potential * other_msg.unsqueeze(1), dim=1)
            else:
                msg = torch.sum(potential * other_msg.unsqueeze(0), dim=0)
        else:
            # 通用实现
            msg = self._general_factor_message(potential, var_ids, var_idx)
        
        # 归一化
        msg = msg / msg.sum()
        
        return msg
    
    def _general_factor_message(self, potential: torch.Tensor, 
                                 var_ids: List[int], 
                                 target_idx: int) -> torch.Tensor:
        """通用因子消息计算(支持任意数量变量)"""
        # 收集所有传入消息
        incoming_messages = []
        axes_to_sum = []
        
        for i, var_id in enumerate(var_ids):
            if i != target_idx:
                msg = self.graph.messages_var_to_factor[(var_id, self.graph.factors[var_ids[0]].id)]
                incoming_messages.append((i, msg))
                axes_to_sum.append(i)
        
        # 乘以势函数
        result = potential.clone()
        for axis, msg in incoming_messages:
            # 为消息添加维度以便广播
            shape = [1] * len(var_ids)
            shape[axis] = -1
            result = result * msg.view(shape)
        
        # 求和边缘化
        msg = torch.sum(result, dim=axes_to_sum)
        
        return msg
    
    def run_tree_bp(self) -> Dict[int, torch.Tensor]:
        """
        在树结构图上运行和积算法
        
        Returns:
            beliefs: 每个变量的边缘分布
        """
        # 找到根节点(选择第一个变量)
        root_var_id = list(self.graph.variables.keys())[0]
        
        # 计算节点顺序用于后序遍历
        parent_map, order = self._get_traversal_order(root_var_id)
        
        # 自底向上:计算向叶子方向的消息
        for var_id in reversed(order):
            for factor_id in self.graph.adjacency[var_id]:
                if parent_map.get(var_id) != factor_id:
                    # 这是一个叶子方向的因子
                    pass
        
        # 自顶向下:传递向根方向的消息
        for var_id in order:
            parent_factor = parent_map.get(var_id)
            if parent_factor is not None:
                # 计算从父因子到该变量的消息
                msg = self.compute_factor_to_variable_message(parent_factor, var_id)
                self.graph.messages_factor_to_var[(parent_factor, var_id)] = msg
        
        # 计算信念
        return self.compute_beliefs()
    
    def _get_traversal_order(self, root_var_id: int) -> Tuple[Dict, List]:
        """获取遍历顺序和父子关系"""
        parent_map = {root_var_id: None}
        order = [root_var_id]
        
        # BFS遍历
        queue = [root_var_id]
        while queue:
            var_id = queue.pop(0)
            for factor_id in self.graph.adjacency[var_id]:
                for next_var_id in self.graph.factor_to_vars[factor_id]:
                    if next_var_id not in parent_map:
                        parent_map[next_var_id] = factor_id
                        order.append(next_var_id)
                        queue.append(next_var_id)
        
        return parent_map, order
    
    def run_loopy_bp(self, max_iter: int = 100, tol: float = 1e-6,
                     schedule: str = 'random') -> Dict[int, torch.Tensor]:
        """
        在有环图上运行循环置信传播
        
        Args:
            max_iter: 最大迭代次数
            tol: 收敛容忍度
            schedule: 调度策略 ('random', 'residual', 'flooding')
        
        Returns:
            beliefs: 每个变量的边缘分布
        """
        self.initialize_messages()
        
        best_beliefs = None
        best_energy = float('inf')
        
        for iteration in range(max_iter):
            max_change = 0.0
            
            if schedule == 'random':
                # 随机调度
                var_ids = list(self.graph.variables.keys())
                np.random.shuffle(var_ids)
                
                for var_id in var_ids:
                    for factor_id in self.graph.adjacency[var_id]:
                        # 计算并更新消息
                        new_msg = self.compute_variable_to_factor_message(var_id, factor_id)
                        old_msg = self.graph.messages_var_to_factor.get(
                            (var_id, factor_id), torch.ones_like(new_msg)
                        )
                        
                        # 阻尼
                        if self.damping > 0:
                            new_msg = ((1 - self.damping) * new_msg + 
                                      self.damping * old_msg)
                        
                        change = torch.abs(new_msg - old_msg).max().item()
                        max_change = max(max_change, change)
                        
                        self.graph.messages_var_to_factor[(var_id, factor_id)] = new_msg
                        
                        # 更新反向消息
                        reverse_msg = self.compute_factor_to_variable_message(
                            factor_id, var_id
                        )
                        
                        if self.damping > 0:
                            old_reverse = self.graph.messages_factor_to_var.get(
                                (factor_id, var_id), torch.ones_like(reverse_msg)
                            )
                            reverse_msg = ((1 - self.damping) * reverse_msg + 
                                          self.damping * old_reverse)
                        
                        self.graph.messages_factor_to_var[(factor_id, var_id)] = reverse_msg
            
            elif schedule == 'flooding':
                # 洪水算法:同时更新所有消息
                new_var_to_factor = {}
                new_factor_to_var = {}
                
                for var_id, var in self.graph.variables.items():
                    for factor_id in self.graph.adjacency[var_id]:
                        new_var_to_factor[(var_id, factor_id)] = \
                            self.compute_variable_to_factor_message(var_id, factor_id)
                
                for factor_id, factor in self.graph.factors.items():
                    for var_id in factor.variable_ids:
                        new_factor_to_var[(factor_id, var_id)] = \
                            self.compute_factor_to_variable_message(factor_id, var_id)
                
                # 更新所有消息
                for key, msg in new_var_to_factor.items():
                    old_msg = self.graph.messages_var_to_factor.get(
                        key, torch.ones_like(msg)
                    )
                    change = torch.abs(msg - old_msg).max().item()
                    max_change = max(max_change, change)
                    
                    if self.damping > 0:
                        msg = (1 - self.damping) * msg + self.damping * old_msg
                    
                    self.graph.messages_var_to_factor[key] = msg
                
                for key, msg in new_factor_to_var.items():
                    old_msg = self.graph.messages_factor_to_var.get(
                        key, torch.ones_like(msg)
                    )
                    
                    if self.damping > 0:
                        msg = (1 - self.damping) * msg + self.damping * old_msg
                    
                    self.graph.messages_factor_to_var[key] = msg
            
            # 计算当前信念
            beliefs = self.compute_beliefs()
            
            # 计算伪对数似然(用于早停)
            energy = self._compute_pseudo_energy(beliefs)
            if energy < best_energy:
                best_energy = energy
                best_beliefs = {k: v.clone() for k, v in beliefs.items()}
            
            if max_change < tol:
                print(f"LBP converged after {iteration + 1} iterations")
                break
        
        return best_beliefs if best_beliefs else beliefs
    
    def compute_beliefs(self) -> Dict[int, torch.Tensor]:
        """计算所有变量的信念(边缘分布)"""
        beliefs = {}
        
        for var_id, var in self.graph.variables.items():
            belief = torch.ones(var.num_states, device=self.device)
            
            for factor_id in self.graph.adjacency[var_id]:
                key = (factor_id, var_id)
                if key in self.graph.messages_factor_to_var:
                    belief = belief * self.graph.messages_factor_to_var[key]
            
            # 归一化
            belief = belief / belief.sum()
            beliefs[var_id] = belief
        
        return beliefs
    
    def _compute_pseudo_energy(self, beliefs: Dict[int, torch.Tensor]) -> float:
        """
        计算伪能量(用于监测收敛)
        
        近似于负对数似然
        """
        energy = 0.0
        
        for factor_id, factor in self.graph.factors.items():
            # 计算因子的期望能量
            potential = factor.potential.to(self.device)
            
            for var_id in factor.variable_ids:
                belief = beliefs[var_id]
                # 简化:计算势函数的期望
                energy -= torch.sum(belief * torch.log(potential + 1e-10))
        
        return energy.item()

8.3 高斯消息传递实现

class GaussianMessagePassing:
    """
    高斯消息传递实现
    
    用于线性高斯模型的精确推断
    """
    
    def __init__(self, device=None):
        self.device = device or torch.device('cpu')
    
    def gaussian_product(self, lambda1: torch.Tensor, xi1: torch.Tensor,
                         lambda2: torch.Tensor, xi2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        高斯分布的乘积
        
        N(μ1, Σ1) * N(μ2, Σ2) = N(μ, Σ)
        
        其中:
        Σ = (Σ1⁻¹ + Σ2⁻¹)⁻¹
        μ = Σ (Σ1⁻¹μ1 + Σ2⁻¹μ2)
        
        自然参数形式:
        Λ = Λ1 + Λ2
        ξ = ξ1 + ξ2
        """
        # 精度矩阵形式
        Lambda = lambda1 + lambda2
        xi = xi1 + xi2
        
        return Lambda, xi
    
    def gaussian_marginalize(self, Lambda: torch.Tensor, xi: torch.Tensor,
                             indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        高斯分布的边际化
        
        保持指定索引的变量
        """
        Lambda_mm = Lambda[indices][:, indices]
        xi_m = xi[indices]
        
        return Lambda_mm, xi_m
    
    def kalman_update(self, mu: torch.Tensor, Sigma: torch.Tensor,
                      y: torch.Tensor, H: torch.Tensor, R: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Kalman更新(高斯BP的特殊情况)
        
        Args:
            mu: 先验均值
            Sigma: 先验协方差
            y: 观测值
            H: 观测矩阵
            R: 观测噪声协方差
        
        Returns:
            mu_posterior, Sigma_posterior: 后验均值和协方差
        """
        # 预测
        S = H @ Sigma @ H.T + R  # 观测预测协方差
        K = Sigma @ H.T @ torch.linalg.inv(S)  # Kalman增益
        
        # 更新
        mu_posterior = mu + K @ (y - H @ mu)
        Sigma_posterior = (torch.eye(mu.shape[0], device=self.device) - K @ H) @ Sigma
        
        return mu_posterior, Sigma_posterior
    
    def kalman_predict(self, mu: torch.Tensor, Sigma: torch.Tensor,
                       A: torch.Tensor, Q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Kalman预测(状态转移)
        """
        mu_pred = A @ mu
        Sigma_pred = A @ Sigma @ A.T + Q
        
        return mu_pred, Sigma_pred
    
    def run_linear_gaussian_bsf(self, observations: List[Tuple[int, torch.Tensor, torch.Tensor]],
                                  F: torch.Tensor, Q: torch.Tensor,
                                  H: torch.Tensor, R: torch.Tensor,
                                  x0: torch.Tensor, Sigma0: torch.Tensor) -> Tuple[List, List]:
        """
        运行线性高斯系统的最优估计(Kalman滤波/平滑)
        
        Args:
            observations: [(time, y, H), ...] 观测列表
            F: 状态转移矩阵
            Q: 过程噪声协方差
            H: 观测矩阵
            R: 观测噪声协方差
            x0: 初始均值
            Sigma0: 初始协方差
        
        Returns:
            filtered_means, filtered_covs: 滤波结果
            smoothed_means, smoothed_covs: 平滑结果
        """
        T = max(obs[0] for obs in observations) + 1
        
        # 初始化
        mu = x0
        Sigma = Sigma0
        
        filtered_means = []
        filtered_covs = []
        
        # 前向滤波
        for t in range(T):
            # 预测
            mu_pred, Sigma_pred = self.kalman_predict(mu, Sigma, F, Q)
            
            # 检查该时间是否有观测
            obs_t = [(y, H_t) for (time, y, H_t) in observations if time == t]
            
            if obs_t:
                for y, H_t in obs_t:
                    # 更新
                    mu, Sigma = self.kalman_update(mu_pred, Sigma_pred, y, H_t, R)
                    mu_pred, Sigma_pred = mu, Sigma  # 用于下一个观测
            else:
                mu, Sigma = mu_pred, Sigma_pred
            
            filtered_means.append(mu)
            filtered_covs.append(Sigma)
        
        # 后向平滑
        smoothed_means = filtered_means.copy()
        smoothed_covs = filtered_covs.copy()
        
        mu_smooth = filtered_means[-1]
        Sigma_smooth = filtered_covs[-1]
        
        for t in range(T - 2, -1, -1):
            # 预测
            mu_pred, Sigma_pred = self.kalman_predict(
                filtered_means[t], filtered_covs[t], F, Q
            )
            
            # 平滑增益
            try:
                G = filtered_covs[t] @ F.T @ torch.linalg.inv(Sigma_pred)
            except:
                G = filtered_covs[t] @ F.T @ (Sigma_pred + 1e-6 * torch.eye(Sigma_pred.shape[0], device=self.device)).inverse()
            
            # 平滑
            mu_smooth = filtered_means[t] + G @ (mu_smooth - mu_pred)
            Sigma_smooth = filtered_covs[t] + G @ (Sigma_smooth - Sigma_pred) @ G.T
            
            smoothed_means[t] = mu_smooth
            smoothed_covs[t] = Sigma_smooth
        
        return (filtered_means, filtered_covs), (smoothed_means, smoothed_covs)

8.4 使用示例

def example_simple():
    """简单示例:二元变量因子图"""
    
    # 创建因子图
    graph = FactorGraph("Simple Example")
    
    # 添加变量
    graph.add_variable(0, "x0", num_states=2)
    graph.add_variable(1, "x1", num_states=2)
    graph.add_variable(2, "x2", num_states=2)
    
    # 添加因子
    # f0(x0, x1) - 鼓励 x0 = x1
    potential01 = torch.tensor([[2.0, 0.5],
                                [0.5, 2.0]])
    graph.add_factor(0, "f01", [0, 1], potential01)
    
    # f1(x1, x2) - 鼓励 x1 = x2
    potential12 = torch.tensor([[2.0, 0.5],
                                [0.5, 2.0]])
    graph.add_factor(1, "f12", [1, 2], potential12)
    
    # f2(x0) - x0 的先验
    potential0 = torch.tensor([0.3, 0.7])
    graph.add_factor(2, "f0", [0], potential0)
    
    # 运行和积算法
    spa = SumProductAlgorithm(graph, damping=0.3)
    beliefs = spa.run_loopy_bp(max_iter=100, tol=1e-6)
    
    print("边缘分布:")
    for var_id, belief in beliefs.items():
        var = graph.get_variable(var_id)
        print(f"  {var.name}: {belief.numpy()}")
    
    return beliefs
 
 
def example_gaussian():
    """示例:高斯消息传递(Kalman滤波)"""
    
    # 状态空间模型参数
    dt = 0.1  # 时间步长
    F = torch.tensor([[1, dt],
                      [0, 1]])  # 状态转移
    
    Q = torch.tensor([[0.01, 0],
                      [0, 0.01]])  # 过程噪声
    
    H = torch.tensor([[1, 0]])  # 观测矩阵
    
    R = torch.tensor([[0.1]])  # 观测噪声
    
    # 初始状态
    x0 = torch.tensor([0, 1])
    Sigma0 = torch.tensor([[1, 0],
                           [0, 1]])
    
    # 生成观测数据
    np.random.seed(42)
    true_states = [x0.numpy()]
    observations = []
    
    for t in range(20):
        # 真实状态转移
        x_true = F @ torch.tensor(true_states[-1]) + np.random.randn(2) * 0.1
        true_states.append(x_true)
        
        # 观测
        y = H @ x_true + np.random.randn(1) * np.sqrt(R[0, 0])
        observations.append((t, torch.tensor(y), H))
    
    # 运行Kalman滤波/平滑
    gmp = GaussianMessagePassing()
    (filtered_means, filtered_covs), (smoothed_means, smoothed_covs) = \
        gmp.run_linear_gaussian_bsf(observations, F, Q, H, R, x0, Sigma0)
    
    print("滤波结果(显示前5个时间步):")
    for t in range(5):
        print(f"  t={t}: mean={filtered_means[t].numpy()}, std={np.sqrt(filtered_covs[t].numpy().diagonal())}")
    
    print("\n平滑结果(显示前5个时间步):")
    for t in range(5):
        print(f"  t={t}: mean={smoothed_means[t].numpy()}, std={np.sqrt(smoothed_covs[t].numpy().diagonal())}")
    
    return filtered_means, filtered_covs, smoothed_means, smoothed_covs
 
 
if __name__ == "__main__":
    print("=" * 60)
    print("示例1: 简单离散因子图")
    print("=" * 60)
    example_simple()
    
    print("\n" + "=" * 60)
    print("示例2: 高斯消息传递(Kalman滤波)")
    print("=" * 60)
    example_gaussian()

9 理论总结

9.1 核心概念回顾

概念定义关键性质
因子图变量-因子二部图分解联合分布
和积算法树结构的精确BP 推断
循环BP有环图的近似推断迭代收敛
高斯BP线性高斯的精确推断闭式解
期望传播指数族近似推断变分扩展

9.2 与深度学习的统一视角

消息传递机制是连接概率推断与深度学习的桥梁:

┌─────────────────────────────────────────────────────────────────────────┐
│                    消息传递的统一视角                                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌───────────────────┐      ┌───────────────────┐                       │
│  │  概率图模型        │      │  深度学习          │                       │
│  ├───────────────────┤      ├───────────────────┤                       │
│  │  因子节点 = 消息函数│      │  神经网络层        │                       │
│  │  变量节点 = 隐状态  │      │  神经元            │                       │
│  │  消息传递 = 推断   │      │  前向传播          │                       │
│  │  信念 = 边缘分布   │      │  激活值            │                       │
│  │  优化 = 最大化似然 │      │  梯度下降          │                       │
│  └───────────────────┘      └───────────────────┘                       │
│              \                      /                                   │
│               \                    /                                    │
│                ▼                  ▼                                     │
│        ┌─────────────────────────────────┐                              │
│        │     统一的消息传递框架           │                              │
│        │  h_v^{(l+1)} = Update(h_v^{(l)}, │                              │
│        │                 AGG_{u∈N(v)}     │                              │
│        │                 Message(h_u^{(l)}│                              │
│        │                 h_v^{(l)}, e_{uv}))                             │
│        └─────────────────────────────────┘                              │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

9.3 未来研究方向

  1. 可扩展性:大规模图的高效消息传递
  2. 异构图:多关系、多模态图的消息传递
  3. 动态图:时变图结构的消息传递
  4. 理论保证:收敛性、表达能力的形式化分析
  5. 与Transformer的融合:注意力机制作为软消息传递

参考文献


相关文档

Footnotes

  1. Kschischang, F. R., Frey, B. J., & Loeliger, H. A. (2001). Factor graphs and the sum-product algorithm. IEEE Transactions on Information Theory, 47(2), 498-519.