因子图与信念传播

因子图(Factor Graph)是一种灵活的概率图表示方法,将复杂概率分布的分解结构显式建模。信念传播(Belief Propagation)算法利用因子图进行高效的边缘概率推断。1

一、为什么需要因子图

从联合分布到因子分解

对于包含 个变量的联合分布 ,直接计算和存储是指数级的。通过条件独立性假设,可以将其分解为局部函数的乘积:

其中:

  • 是团的集合
  • 是团势函数(clique potential)
  • 是归一化常数(配分函数)

因子图的引入

因子图提供了一种显式表示这种分解的方式:

  • 变量节点:表示随机变量(圆形)
  • 因子节点:表示局部函数(方形)
  • :连接因子与其作用的变量

这种二部图结构使得消息传递算法的推导变得直观自然。

二、因子图定义

形式化定义

因子图是一个二部图

  • :因子节点集合
  • :变量节点集合
  • :边集合,连接因子 与其依赖的变量

全局函数分解

设全局函数 可以分解为:

其中 是因子 依赖的变量集合。

示例:从贝叶斯网络到因子图

考虑一个简单的贝叶斯网络:

I → G ← D
│       │
└───────┘
   S

其中:

  • :智力
  • :课程难度
  • :成绩
  • :SAT分数

联合概率分解

对应的因子图

  f1      f2         f3        f4
  ○        ○          ○         ○
  │        │          │         │
  I ────── f_G ────── D        S
  │        │          │         │
  └────────┼──────────┘         │
          │                     │
          └─────────────────────┘

对应的因子:

三、因子图上的概率推断

问题定义

给定因子图,求变量 的边缘概率:

其中 表示除 外的所有变量。

变量消除算法

变量消除(Variable Elimination)是理解信念传播的基础。

核心思想:按顺序消除变量,每次只处理涉及当前变量的因子。

def variable_elimination(factors, query_var, eliminate_order):
    """
    变量消除算法
    
    Args:
        factors: 因子列表
        query_var: 待查询变量
        eliminate_order: 消除顺序
    
    Returns:
        边缘概率分布
    """
    active_factors = factors.copy()
    
    for var in eliminate_order:
        if var == query_var:
            continue
        
        # 1. 收集所有包含该变量的因子
        relevant = [f for f in active_factors if var in f.variables]
        
        # 2. 乘积:所有相关因子的乘积
        product = relevant[0]
        for f in relevant[1:]:
            product = multiply_factors(product, f)
        
        # 3. 边缘化:对该变量求和
        marginalized = sum_out(product, var)
        
        # 4. 更新活跃因子列表
        active_factors = [f for f in active_factors if f not in relevant]
        active_factors.append(marginalized)
    
    # 最后对查询变量求和
    result = sum_out(active_factors[0], query_var)
    return normalize(result)

复杂度分析

变量消除的计算复杂度取决于:

  • 消除顺序:好的顺序可以大幅降低复杂度
  • 因子树宽(treewidth):图的树分解宽度

其中 是树宽, 是变量域大小。

四、信念传播算法(Sum-Product Algorithm)

算法概述

信念传播(Belief Propagation)利用因子图的树结构,通过消息传递高效计算所有边缘概率。

关键假设:因子图是树结构(或已转化为树分解)

消息定义

变量节点到因子节点的消息

变量节点 向因子节点 发送的消息是所有其他相邻因子传递消息的乘积:

其中 的邻居集合。

因子节点到变量节点的消息

因子节点 向变量节点 发送的消息是因子函数与其他所有变量的消息乘积的边缘化:

边缘概率计算

当消息传递完成后,每个变量节点的边缘概率为:

树结构上的消息传递流程

对于树结构的因子图,信念传播的流程:

1. 初始化:所有叶子节点的消息为均匀分布
2. 迭代:叶子节点向邻居发送消息
3. 收敛:当所有消息稳定后,计算边缘概率
def belief_propagation(factor_graph, max_iter=100, tol=1e-6):
    """
    信念传播算法
    
    适用于树结构的因子图,保证收敛到精确解
    """
    messages = {}  # (src, dst) -> message
    beliefs = {}   # var -> belief
    
    # 初始化叶子节点消息
    for node in factor_graph.leaves:
        if isinstance(node, VariableNode):
            messages[(node, node.neighbors[0])] = np.ones(node.domain_size)
        else:  # FactorNode
            # 叶子因子节点的消息通过边缘化得到
            messages[(node, node.neighbors[0])] = marginalize(node.factor, node.other_vars)
    
    # 迭代消息传递
    for iteration in range(max_iter):
        old_messages = messages.copy()
        
        for factor in factor_graph.factors:
            for var in factor.neighbors:
                # 计算 var -> factor 的消息
                msg_v_to_f = np.ones(var.domain_size)
                for other_factor in var.neighbors:
                    if other_factor != factor:
                        msg_v_to_f *= messages.get((other_factor, var), np.ones(var.domain_size))
                messages[(var, factor)] = msg_v_to_f
                
                # 计算 factor -> var 的消息
                msg_f_to_v = factor.compute_message(var, messages)
                messages[(factor, var)] = msg_f_to_v
        
        # 检查收敛
        diff = sum(np.abs(messages.get(k, 0) - old_messages.get(k, 0)) 
                   for k in set(messages.keys()) | set(old_messages.keys()))
        if diff < tol:
            break
    
    # 计算信念(边缘概率)
    for var in factor_graph.variables:
        belief = np.ones(var.domain_size)
        for factor in var.neighbors:
            belief *= messages.get((factor, var), np.ones(var.domain_size))
        beliefs[var] = belief / belief.sum()
    
    return beliefs, messages

五、线性链条件随机场的消息传递

信念传播在序列模型中有特殊形式。考虑线性链CRF:

X1 ──f12── X2 ──f23── X3 ──f34── X4
 │          │          │
f1( )      f2( )      f3( )      f4( )

前向-后向算法

前向消息

后向消息

边缘概率

def forward_backward(observations, potentials, transition_potentials):
    """
    前向-后向算法(线性链CRF)
    
    Args:
        observations: 观测序列
        potentials: 发射势函数 ψ(x_t)
        transition_potentials: 转移势函数 ψ(x_{t-1}, x_t)
    """
    T = len(observations)
    states = potentials[0].shape[0]
    
    # 前向传递
    alpha = np.zeros((T, states))
    alpha[0] = potentials[0]
    
    for t in range(1, T):
        for j in range(states):
            alpha[t, j] = potentials[t, j] * np.sum(
                alpha[t-1, i] * transition_potentials[i, j]
                for i in range(states)
            )
    
    # 后向传递
    beta = np.zeros((T, states))
    beta[T-1] = 1.0
    
    for t in range(T-2, -1, -1):
        for i in range(states):
            beta[t, i] = np.sum(
                potentials[t+1, j] * transition_potentials[i, j] * beta[t+1, j]
                for j in range(states)
            )
    
    # 计算边缘概率
    Z = alpha[-1].sum()  # 配分函数
    marginals = alpha * beta / Z
    
    return marginals, alpha, beta

六、置信传播与贝叶斯网络的关系

有向图到因子图的转换

贝叶斯网络的联合概率:

转换为因子图:每个条件概率 对应一个因子节点。

例子:学生成绩网络

# 定义变量
I = VariableNode('Intelligence', ['high', 'low'])  # 智力
D = VariableNode('Difficulty', ['easy', 'hard'])    # 难度
G = VariableNode('Grade', ['A', 'B', 'C'])          # 成绩
S = VariableNode('SAT', ['high', 'low'])            # SAT分数
 
# 定义因子
# P(I)
f_intel = FactorNode('P(I)', [I], 
                     values=np.array([0.3, 0.7]))  # P(I=high)=0.3
 
# P(D)
f_diff = FactorNode('P(D)', [D],
                    values=np.array([0.6, 0.4]))  # P(D=easy)=0.6
 
# P(G | I, D) - 条件概率表
cpt_grade = np.array([
    # I=high, D=easy   I=high, D=hard   I=low, D=easy   I=low, D=hard
    [0.9, 0.6, 0.3, 0.1],  # G=A
    [0.08, 0.3, 0.4, 0.3],  # G=B
    [0.02, 0.1, 0.3, 0.6]   # G=C
])
f_grade = FactorNode('P(G|I,D)', [G, I, D], values=cpt_grade)
 
# P(S | I)
cpt_sat = np.array([
    # I=high   I=low
    [0.8, 0.2],  # S=high
    [0.2, 0.8]   # S=low
])
f_sat = FactorNode('P(S|I)', [S, I], values=cpt_sat)
 
# 构建因子图
factor_graph = FactorGraph([f_intel, f_diff, f_grade, f_sat], [I, D, G, S])
 
# 运行信念传播
beliefs = belief_propagation(factor_graph)
print(f"P(I=high) = {beliefs[I][0]:.3f}")
print(f"P(G=A) = {beliefs[G][0]:.3f}")

七、Loopy Belief Propagation

环的存在

实际应用中,因子图往往包含环(loops),不再是树结构:

A ──f1── B
│         │
f2        f3
│         │
D ──f4── C

近似处理

对于带环的因子图,可以使用Loopy Belief Propagation (LBP)

  1. 初始化所有消息为均匀分布
  2. 随机选择一个节点发送消息
  3. 迭代直到收敛(或达到最大迭代次数)

注意:LBP不保证收敛,且结果为近似解。

class LoopyBeliefPropagation:
    """Loopy Belief Propagation for general factor graphs"""
    
    def __init__(self, factor_graph, max_iter=100, tol=1e-4):
        self.fg = factor_graph
        self.max_iter = max_iter
        self.tol = tol
        self.messages = {}
        self.beliefs = {}
    
    def run(self):
        """运行LBP"""
        # 初始化
        self._init_messages()
        
        for iteration in range(self.max_iter):
            old_beliefs = self.beliefs.copy()
            
            # 更新所有因子节点的消息
            for factor in self.fg.factors:
                for var in factor.neighbors:
                    self._update_factor_to_var(factor, var)
            
            # 更新所有变量节点的消息
            for var in self.fg.variables:
                for factor in var.neighbors:
                    self._update_var_to_factor(var, factor)
            
            # 更新信念
            self._update_beliefs()
            
            # 检查收敛
            diff = sum(np.abs(self.beliefs[k] - old_beliefs.get(k, 0)).sum() 
                       for k in self.beliefs)
            if diff < self.tol:
                print(f"Converged at iteration {iteration}")
                break
        
        return self.beliefs
    
    def _update_factor_to_var(self, factor, var):
        """因子到变量的消息传递"""
        neighbors = [n for n in factor.neighbors if n != var]
        
        # 消息 = 因子 × 所有其他邻居的消息的乘积,然后边缘化
        message = factor.values.copy()
        for neighbor in neighbors:
            msg = self.messages.get((neighbor, factor), np.ones(neighbor.domain_size))
            message = message * np.expand_dims(msg, axis=-1)
        
        # 边缘化
        axes = [i for i, n in enumerate(factor.neighbors) if n != var]
        message = np.sum(message, axis=tuple(axes))
        
        # 归一化
        message = message / (message.sum() + 1e-10)
        
        self.messages[(factor, var)] = message
    
    def _update_var_to_factor(self, var, factor):
        """变量到因子的消息传递"""
        neighbors = [n for n in var.neighbors if n != factor]
        
        # 消息 = 所有其他邻居的消息的乘积
        message = np.ones(var.domain_size)
        for neighbor in neighbors:
            msg = self.messages.get((neighbor, var), np.ones(var.domain_size))
            message = message * msg
        
        self.messages[(var, factor)] = message
    
    def _update_beliefs(self):
        """更新信念(边缘概率近似)"""
        for var in self.fg.variables:
            belief = np.ones(var.domain_size)
            for factor in var.neighbors:
                msg = self.messages.get((factor, var), np.ones(var.domain_size))
                belief = belief * msg
            self.beliefs[var] = belief / (belief.sum() + 1e-10)

八、置信传播与变分推断的联系

置信传播作为变分推断

置信传播可以理解为一种变分推断方法,其中:

  • 近似分布:边缘分布的乘积
  • 优化目标:最小化 或最大化 ELBO

平均场近似

平均场近似假设:

这与置信传播在树结构上的结果一致。

与变分推断的关系

方法近似形式优化方式
置信传播边缘分布乘积固定点迭代
变分推断参数化分布族梯度下降
平均场VI独立分布坐标上升

详见 变分推断

九、应用场景

1. 图像去噪

# MRF/CRF图像去噪模型
# 变量:每个像素的标签
# 因子:观测似然 + 平滑先验
 
def denoise_image(image, noise_var, smooth_weight, n_iter=10):
    """
    使用置信传播进行图像去噪
    
    Args:
        image: 带噪声的图像
        noise_var: 噪声方差
        smooth_weight: 平滑项权重
    """
    h, w = image.shape
    n_vars = h * w
    
    # 构建因子图
    factors = []
    
    # 一元因子(观测)
    for i in range(h):
        for j in range(w):
            var = pixel_vars[i * w + j]
            # P(观测 | 真实值) ∝ exp(-(y-x)²/2σ²)
            unary = np.exp(-(image[i, j] - pixel_values)**2 / (2 * noise_var))
            factors.append(FactorNode(f'unary_{i}_{j}', [var], unary))
    
    # 二元因子(平滑)
    for i in range(h):
        for j in range(w - 1):
            var1, var2 = pixel_vars[i*w+j], pixel_vars[i*w+j+1]
            # P(x_i, x_j) ∝ exp(-λ(x_i - x_j)²)
            pairwise = np.exp(-smooth_weight * (pixel_values[:, None] - pixel_values[None, :])**2)
            factors.append(FactorNode(f'pair_{i}_{j}', [var1, var2], pairwise))
    
    # 运行置信传播
    fg = FactorGraph(factors, pixel_vars)
    beliefs = loopy_belief_propagation(fg, n_iter)
    
    # 提取去噪结果
    denoised = np.zeros_like(image)
    for i in range(h):
        for j in range(w):
            denoised[i, j] = pixel_values[np.argmax(beliefs[pixel_vars[i*w+j]])]
    
    return denoised

2. 医疗诊断

因子图可用于构建医疗专家系统:

症状 ──f1── 疾病 ──f2── 检查结果
            │
            └───f3── 其他风险因素

3. 推荐系统

# 因子图推荐模型
# 变量:用户偏好、物品特征
# 因子:用户-物品交互、偏好相似性
 
def build_recommendation_factor_graph(user_ids, item_ids, interactions):
    """构建推荐系统的因子图"""
    factors = []
    
    # 用户偏好因子
    for u in user_ids:
        user_var = VariableNode(f'user_{u}', domain_size=10)
        factors.append(FactorNode(f'pref_{u}', [user_var], user_prior[u]))
    
    # 物品特征因子
    for i in item_ids:
        item_var = VariableNode(f'item_{i}', domain_size=20)
        factors.append(FactorNode(f'feat_{i}', [item_var], item_prior[i]))
    
    # 交互因子
    for (u, i, rating) in interactions:
        factors.append(
            FactorNode(f'interact_{u}_{i}', 
                      [user_vars[u], item_vars[i]], 
                      rating_matrix[u, i])
        )
    
    return FactorGraph(factors, all_vars)

十、算法复杂度分析

场景时间复杂度空间复杂度
树结构(精确)
树宽 (精确)
带环(LBP)

其中 是变量数, 是域大小, 是因子数, 是最大因子阶数。

十一、与其他方法的关系

┌─────────────────────────────────────────────────────────────┐
│                    概率推断方法                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐    │
│  │  精确推断   │     │  近似推断   │     │  采样方法   │    │
│  └─────────────┘     └─────────────┘     └─────────────┘    │
│         │                  │                   │            │
│         ▼                  ▼                   ▼            │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐    │
│  │ 变量消除    │     │  变分推断   │     │    MCMC     │    │
│  │ 信念传播    │     │  Loopy BP  │     │  (见下文)   │    │
│  └─────────────┘     └─────────────┘     └─────────────┘    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

详见:

参考

Footnotes

  1. Kschischang, F.R., Frey, B.J., & Loeliger, H.A. (2001). “Factor Graphs and the Sum-Product Algorithm”. IEEE Transactions on Information Theory, 47(2), 498-519.