概述

因子图(Factor Graph)是一种统一表示概率图模型的二分图结构,能够同时自然地表示贝叶斯网络的有向依赖和马尔可夫随机场的无向依赖。1

本章系统阐述因子图的基本概念、和积算法(Sum-Product Algorithm),并探索其与神经网络反向传播之间的深刻联系。


因子图的基本定义

从贝叶斯网络到因子图

考虑一个联合概率分解:

传统的贝叶斯网络用有向图表示这种分解。但当变量之间存在复杂的条件依赖时,因子图提供了一种更灵活的表示方式。

因子图结构

因子图是一个二分图 ,其中:

要素符号说明
变量节点随机变量
因子节点局部势函数
连接变量与涉及它的因子

图示

    f₁        f₂        f₃        f₄
   / | \      / \        |        /
  X₁ X₂ X₃──X₂ X₃────X₃ X₄──X₄

因子图的数学表示

因子图表示的联合分布:

其中:

  • 是因子节点 的势函数(Potential Function)
  • 是与因子 相连的变量集合
  • 是配分函数

和积算法(Sum-Product Algorithm)

边缘概率的分解计算

目标:计算边缘概率

对于因子图表示:

消息传递框架

和积算法的核心是消息(Message)的递归计算。

消息类型

消息方向定义说明
变量→因子从变量 传递到因子
因子→变量从因子 传递到变量

变量节点的消息传递

当变量节点 连接到多个因子时,从变量到因子的消息是所有其他因子传来消息的乘积:

图示

        f_b               f_a
         |                 |
         |                 |
    ┌────┴────┐       ┌────┴────┐
    │ x_i     │       │         │
    └─────────┘       └─────────┘
    
    μ_{x_i→f_a} = ∏_{f_b∈N(x_i)\{f_a}} μ_{f_b→x_i}

因子节点的消息传递

因子节点 连接到变量 ,向变量 传递的消息:

边缘概率计算

当所有消息计算完成后,变量 的边缘概率:

最大后验概率(MAP)推断

和积算法也可用于MAP推断,将求和替换为取最大值:

对应的消息更新规则:


置信传播的矩阵形式

树结构的矩阵表示

当因子图是树结构时,置信传播可以用矩阵运算高效实现。

线性链CRF的置信传播

对于线性链条件随机场,定义:

  • 转移矩阵
  • 发射矩阵

前向消息

后向消息

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
 
class FactorGraphBP(nn.Module):
    """
    基于因子图的消息传递实现
    支持树结构的精确推断
    """
    def __init__(self, num_states, num_factors):
        super().__init__()
        self.num_states = num_states
        self.num_factors = num_factors
        
        # 因子势函数(可学习)
        self.factor_potentials = nn.Parameter(torch.randn(num_factors, num_states))
    
    def forward(self, adjacency_list):
        """
        执行置信传播
        
        Args:
            adjacency_list: 因子图邻接表
                [{'type': 'factor', 'id': 0, 'neighbors': [0, 1, 2]},
                 {'type': 'var', 'id': 0, 'neighbors': [0]}]
        
        Returns:
            beliefs: 每个变量的边缘概率分布
        """
        messages = {}
        
        # 初始化消息(从叶子节点开始)
        for node in adjacency_list:
            if node['type'] == 'var' and len(node['neighbors']) == 1:
                # 叶子变量节点,发送均匀消息
                messages[(node['id'], node['neighbors'][0])] = torch.ones(self.num_states)
        
        # 迭代消息传递(这里简化版本,实际需要多次迭代)
        for _ in range(len(adjacency_list)):
            for node in adjacency_list:
                if node['type'] == 'factor':
                    # 因子节点计算消息
                    for neighbor in node['neighbors']:
                        # 收集除neighbor外的所有消息
                        incoming = []
                        for other in node['neighbors']:
                            if other != neighbor:
                                if (other, node['id']) in messages:
                                    incoming.append(messages[(other, node['id'])])
                        
                        if incoming:
                            # 计算势函数与消息的乘积
                            product = torch.ones(self.num_states)
                            for msg in incoming:
                                product = product * msg
                            
                            # 乘以势函数并归一化
                            belief = product * torch.softmax(self.factor_potentials[node['id']], dim=0)
                            messages[(node['id'], neighbor)] = belief / belief.sum()
        
        # 计算边缘分布(信念)
        beliefs = {}
        for node in adjacency_list:
            if node['type'] == 'var':
                belief = torch.ones(self.num_states)
                for factor_id in node['neighbors']:
                    if (factor_id, node['id']) in messages:
                        belief = belief * messages[(factor_id, node['id'])]
                beliefs[node['id']] = belief / belief.sum()
        
        return beliefs
 
 
class LinearChainCRF(nn.Module):
    """
    线性链CRF的矩阵形式置信传播
    """
    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        
        # 转移分数(未归一化的对数势函数)
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        
        # 发射分数(由编码器提供)
        self._init_transitions()
    
    def _init_transitions(self):
        """初始化转移矩阵,限制转移的合理性"""
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        # 禁止非法转移(如I-PAD不允许)
        self.transitions.data[0, 0] = -float('inf')  # PAD->PAD
    
    def forward_messages(self, emissions):
        """
        前向消息传递
        
        Args:
            emissions: (batch, seq_len, num_tags) 发射分数
        
        Returns:
            alpha: (batch, seq_len, num_tags) 前向消息
        """
        batch_size, seq_len, _ = emissions.shape
        
        # 初始化
        alpha = torch.zeros(batch_size, self.num_tags, device=emissions.device)
        alpha[:, 0] = emissions[:, 0]  # 初始分数
        
        # 递归计算
        for t in range(1, seq_len):
            # 发射分数 + 转移分数
            scores = alpha[:, t-1].unsqueeze(2) + self.transitions + emissions[:, t].unsqueeze(1)
            # Log-Sum-Exp 稳定性
            alpha[:, t] = torch.logsumexp(scores, dim=1)
        
        return alpha
    
    def backward_messages(self, emissions):
        """
        后向消息传递
        
        Args:
            emissions: (batch, seq_len, num_tags) 发射分数
        
        Returns:
            beta: (batch, seq_len, num_tags) 后向消息
        """
        batch_size, seq_len, _ = emissions.shape
        
        beta = torch.zeros(batch_size, self.num_tags, device=emissions.device)
        
        for t in range(seq_len - 2, -1, -1):
            scores = self.transitions + emissions[:, t+1].unsqueeze(1) + beta[:, t+1].unsqueeze(2)
            beta[:, t] = torch.logsumexp(scores, dim=2)
        
        return beta
    
    def forward(self, emissions):
        """
        计算对数似然
        
        Args:
            emissions: (batch, seq_len, num_tags)
        
        Returns:
            log_likelihood: (batch,)
        """
        alpha = self.forward_messages(emissions)
        return torch.logsumexp(alpha[:, -1], dim=1)
    
    def decode(self, emissions):
        """
        Viterbi解码
        
        Args:
            emissions: (batch, seq_len, num_tags)
        
        Returns:
            best_paths: (batch, seq_len) 最优路径
        """
        batch_size, seq_len, _ = emissions.shape
        backpointers = []
        
        # 前向递推
        viterbi = torch.zeros_like(emissions)
        viterbi[:, 0] = emissions[:, 0]
        
        for t in range(1, seq_len):
            # (batch, tag_j) + (tag_j, tag_i) -> (batch, tag_i)
            scores = viterbi[:, t-1].unsqueeze(2) + self.transitions + emissions[:, t].unsqueeze(1)
            best_scores, best_tags = scores.max(dim=1)
            viterbi[:, t] = best_scores
            backpointers.append(best_tags)
        
        # 回溯
        best_path = torch.zeros(batch_size, seq_len, dtype=torch.long, device=emissions.device)
        best_last_tag = viterbi[:, -1].argmax(dim=1)
        best_path[:, -1] = best_last_tag
        
        for t in range(seq_len - 2, -1, -1):
            best_path[:, t] = backpointers[t].gather(1, best_path[:, t+1].unsqueeze(1)).squeeze(1)
        
        return best_path

置信传播与反向传播的联系

结构类比

置信传播和反向传播在结构上有深刻的联系:

方面置信传播反向传播
消息概率分布/边缘分布梯度
消息组合乘积/求和乘积/链式法则
局部计算因子势函数局部梯度
并行性节点间并行层间并行

数学形式的统一

置信传播的消息更新

反向传播的梯度传递

可微分置信传播

现代深度学习框架可以将置信传播嵌入神经网络中:

class DifferentiableBP(nn.Module):
    """
    可微分的置信传播层
    使用Gumbel-Softmax等技术实现离散消息的软化
    """
    def __init__(self, num_states, temperature=1.0):
        super().__init__()
        self.num_states = num_states
        self.temperature = temperature
        
        # 因子势函数参数化
        self.factor_params = nn.Parameter(torch.randn(num_states))
    
    def soft_message(self, logits):
        """
        使用Gumbel-Softmax生成软消息
        """
        return F.gumbel_softmax(logits, tau=self.temperature, hard=False)
    
    def forward(self, incoming_messages):
        """
        软消息传递
        
        Args:
            incoming_messages: list of (num_states,) 传入的消息
        
        Returns:
            out_message: (num_states,) 传出的软消息
        """
        # 乘积操作(对数空间更稳定)
        log_sum = torch.logsumexp(torch.stack(incoming_messages), dim=0)
        
        # 加上势函数
        combined = log_sum + self.factor_params
        
        # 软最大化
        return F.softmax(combined, dim=0)

循环置信传播与变分推断

循环结构的挑战

当因子图存在环时,精确置信传播不再适用。需要使用近似方法:

方法原理适用范围
循环置信传播(LBP)迭代消息传递直到收敛松弛环、结构先验
变分推断用变分分布近似后验一般图模型
粒子方法用粒子近似消息大规模问题

循环置信传播的实践

class LoopyBeliefPropagation(nn.Module):
    """
    循环置信传播实现
    适用于有环图模型
    """
    def __init__(self, num_nodes, num_states):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_states = num_states
        
        # 初始化消息(每个节点对)
        self.messages = nn.Parameter(torch.randn(num_nodes, num_nodes, num_states))
    
    def run_lbp(self, node_potentials, num_iter=10):
        """
        执行循环置信传播
        
        Args:
            node_potentials: (num_nodes, num_states) 节点势函数
            num_iter: 迭代次数
        
        Returns:
            beliefs: (num_nodes, num_states) 边缘分布
        """
        messages = self.messages.clone()
        
        for iteration in range(num_iter):
            new_messages = []
            
            for i in range(self.num_nodes):
                for j in range(self.num_nodes):
                    if i == j:
                        continue
                    
                    # 收集除j外的所有消息
                    incoming = []
                    for k in range(self.num_nodes):
                        if k != i and k != j:
                            incoming.append(messages[k, i])
                    
                    # 计算新消息
                    # 消息 = 节点势函数 * 传入消息的乘积
                    belief = node_potentials[i].clone()
                    for msg in incoming:
                        belief = belief * torch.softmax(msg, dim=-1)
                    
                    # 归一化
                    belief = F.softmax(belief, dim=-1)
                    new_messages.append((i, j, belief))
            
            # 更新消息
            for i, j, msg in new_messages:
                messages[i, j] = msg
        
        # 计算信念(边缘分布)
        beliefs = []
        for i in range(self.num_nodes):
            belief = node_potentials[i].clone()
            for j in range(self.num_nodes):
                if i != j:
                    belief = belief * torch.softmax(messages[j, i], dim=-1)
            beliefs.append(F.softmax(belief, dim=-1))
        
        return torch.stack(beliefs)
    
    def compute_marginals(self, beliefs):
        """
        计算归一化的边缘分布
        """
        return F.softmax(beliefs, dim=-1)

应用场景

1. 自然语言处理

  • 依存句法分析:将句法树建模为因子图
  • 命名实体识别:使用线性链CRF
  • 机器翻译:用于 SMT 的短语对齐

2. 计算机视觉

  • 图像分割:像素级CRF后处理
  • 姿态估计:人体关节点的关系建模
  • 场景解析:物体间的空间关系

3. 语音识别

  • 隐马尔可夫模型:经典的序列建模
  • 端到端模型:CTC与注意力机制的融合

4. 深度学习中的可视化

  • 神经网络剪枝:用因子图建模权重依赖
  • 知识蒸馏:Teacher-Student网络的概率解释

与现有wiki内容的联系


参考


相关阅读

Footnotes

  1. Kschischang, F. R., Frey, B. J., & Loeliger, H. A. (2001). Factor graphs and the sum-product algorithm. IEEE Transactions on information theory, 47(2), 498-519.