团树算法(Junction Tree Algorithm)

1. 算法引入

1.1 精确推理的挑战

在贝叶斯网络中,计算边缘概率 涉及对所有其他变量的求和。当变量数目增加时,精确推理的计算复杂度呈指数增长

例子:对于有 个二值变量的链式网络,边缘化需要 次运算。

核心思想:利用图的稀疏结构分解性质,将联合分布分解为局部因子的乘积,从而将计算复杂度从 降低到与图的树宽相关的复杂度。

1.2 为什么需要团树?

变量消除(Variable Elimination)是一个简单但低效的方法:它每次只消除一个变量,每次消除都需要计算新的因子。

团树算法通过预处理将图转化为团树(clique tree),使得所有计算可以高效复用,本质上是变量消除的缓存优化版本


2. 图论基础

2.1 团的定义

定义:在无向图 中,(clique)是一个顶点集合 ,使得 中任意两个顶点之间都有边相连。

最大团(maximal clique):不是任何其他团的真超集的团。

最大团(maximum clique):包含顶点数最多的团。

2.2 弦图(Chordal Graph)

定义:一个无向图是弦图(chordal graph)当且仅当它的每个长度大于3的环都有一个(连接环上两个非相邻顶点的边)。

性质:弦图具有以下重要特性:

  1. 可以找到一个完美的顶点顺序(perfect elimination ordering)
  2. 最大团集合可以唯一确定
  3. 最小填充边集合可高效计算

2.3 运行交叉属性(RIP)

定义:给定一个团树 ,如果对于任意两个团 ,它们的交集 出现在连接 的路径上的所有团中,则称 满足运行交叉属性(Running Intersection Property)。

RIP的直观理解:共享变量(如条件独立的媒介)在团树中必须保持”连续”。


3. 团树构造算法

3.1 从贝叶斯网络到团树

步骤1:道德化(Moralization)

将贝叶斯网络转换为马尔可夫网络:

  1. 连接同父节点的父节点(添加边)
  2. 移除所有边的方向
def moralize(bn):
    """道德化贝叶斯网络
    
    Args:
        bn: 贝叶斯网络(有向无环图)
    
    Returns:
        moral_graph: 道德化的无向图
    """
    moral_graph = bn.to_undirected()
    
    # 连接同父节点的父节点
    for node in bn.nodes():
        parents = list(bn.predecessors(node))
        if len(parents) >= 2:
            for i, p1 in enumerate(parents):
                for p2 in parents[i+1:]:
                    moral_graph.add_edge(p1, p2)
    
    return moral_graph

步骤2:三角化(Triangulation)

向道德化图中添加边(填充边),使其成为弦图。

算法:最大基数搜索(Maximum Cardinality Search, MCS)

def triangulate_mcs(graph):
    """使用最大基数搜索进行三角化
    
    MCS通过给顶点编号来检测图是否已经是弦图,
    如果不是,则添加必要的填充边。
    
    Returns:
        chordal_graph: 三角化后的弦图
        fill_edges: 添加的填充边集合
    """
    n = len(graph.nodes())
    alpha = {v: -1 for v in graph.nodes()}
    unnumbered = set(graph.nodes())
    fill_edges = set()
    
    for k in range(n, 0, -1):
        # 选择编号最大的未编号顶点的邻居
        v = max(unnumbered, key=lambda x: len([u for u in graph.neighbors(x) 
                                                if alpha[u] != -1]))
        alpha[v] = k
        unnumbered.remove(v)
        
        # 找出所有已编号邻居
        numbered_neighbors = [u for u in graph.neighbors(v) if alpha[u] != -1]
        
        # 连接这些邻居(添加填充边)
        for i, u1 in enumerate(numbered_neighbors):
            for u2 in numbered_neighbors[i+1:]:
                if not graph.has_edge(u1, u2):
                    fill_edges.add((min(u1, u2), max(u1, u2)))
                    graph.add_edge(u1, u2)
    
    return graph, fill_edges

3.2 构建团树

步骤3:识别最大团

在三角化图中,最大团可以通过以下方式识别:

  • 使用完美消除顺序的逆序构建

步骤4:构建团树

将最大团组织成一棵树,满足RIP。

算法:最优团树构建

def build_junction_tree(moral_graph, root=None):
    """构建团树
    
    Args:
        moral_graph: 三角化后的道德化图
        root: 树根(可选)
    
    Returns:
        junction_tree: 团树
        cliques: 最大团列表
        separators: 分离器集合
    """
    # 找到所有最大团
    cliques = find_maximal_cliques(moral_graph)
    
    # 计算每个团的势(大小)
    clique_potentials = {c: len(c) for c in cliques}
    
    # 构建团树:使用最大团优先构建最大生成树
    # 边的权重 = |C_i ∩ C_j|
    edges = []
    for i, c1 in enumerate(cliques):
        for j, c2 in enumerate(cliques):
            if i < j:
                separator_size = len(set(c1) & set(c2))
                if separator_size > 0:
                    edges.append((i, j, separator_size))
    
    # Kruskal算法构建最大权重生成树
    cliques = list(cliques)
    edges.sort(key=lambda x: -x[2])  # 按分离器大小降序
    
    parent = list(range(len(cliques)))
    
    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]
    
    tree_edges = []
    for i, j, weight in edges:
        pi, pj = find(i), find(j)
        if pi != pj:
            parent[pi] = pj
            tree_edges.append((i, j))
    
    # 构建团树对象
    junction_tree = {
        'cliques': cliques,
        'edges': tree_edges,
        'root': root if root is not None else 0
    }
    
    return junction_tree

4. 消息传递机制

4.1 势函数与边际

因子(Factor):定义在变量子集上的非负函数。对于贝叶斯网络,因子通常是条件概率表(CPT)。

势函数(Potential):在团树算法中,势函数 定义在每个团 上,表示该团中变量的联合分布(未归一化)。

4.2 Sum-Product消息传递

消息定义:从团 向相邻团 (通过分离器 )的消息为:

含义:通过在 边缘化(消去)不属于 的变量,得到关于分离器的函数。

4.3 Hugin架构(两遍消息传递)

收集阶段(Collection Phase)

  1. 从叶节点开始
  2. 叶子向父节点发送消息
  3. 父节点收集所有子节点的消息后,汇总并向上传递
  4. 到达根节点

分发阶段(Distribution Phase)

  1. 从根节点开始
  2. 根节点向子节点分发消息
  3. 子节点接收父节点消息后,向自己的子节点分发
  4. 直到所有叶子节点
def junction_tree_inference(junction_tree, factors, query_vars=None, evidence=None):
    """团树推断算法
    
    Args:
        junction_tree: 团树结构
        factors: 每个团的初始势函数(因子)
        query_vars: 查询变量集合
        evidence: 观测变量及其取值
    
    Returns:
        beliefs: 每个团的边际(信念)
    """
    cliques = junction_tree['cliques']
    edges = junction_tree['edges']
    
    # 构建邻接表
    adj = {i: [] for i in range(len(cliques))}
    for i, j in edges:
        adj[i].append(j)
        adj[j].append(i)
    
    # 初始化势函数
    potentials = {i: factors[i].copy() for i in range(len(cliques))}
    
    # ===== 收集阶段 =====
    visited = set()
    messages_collected = {i: [] for i in range(len(cliques))}
    
    def collect(node, parent):
        for child in adj[node]:
            if child != parent:
                collect(child, node)
                # 收集子节点的消息
                msg = compute_message(potentials[child], cliques[child], 
                                    cliques[node], messages_collected[child])
                messages_collected[node].append((child, msg))
        visited.add(node)
    
    # 从某个叶子开始收集
    leaves = [i for i in range(len(cliques)) if len(adj[i]) == 1]
    root = leaves[0] if leaves else 0
    collect(root, -1)
    
    # 更新根节点的势函数
    for child, msg in messages_collected[root]:
        potentials[root] = multiply_factors(potentials[root], msg, cliques[root])
    
    # ===== 分发阶段 =====
    beliefs = {}
    
    def distribute(node, parent, incoming_msg=None):
        # 更新当前团
        if incoming_msg is not None:
            potentials[node] = multiply_factors(potentials[node], 
                                               incoming_msg, cliques[node])
        
        # 计算边缘(信念)
        belief = marginalize(potentials[node], query_vars)
        beliefs[node] = belief
        
        # 向子节点分发消息
        for child in adj[node]:
            if child != parent:
                # 计算给子节点的消息
                msg_to_child = compute_message(potentials[node], cliques[node],
                                             cliques[child], messages_collected[node])
                distribute(child, node, msg_to_child)
    
    distribute(root, -1)
    
    return beliefs
 
 
def compute_message(potential, source_clique, target_clique, collected_messages):
    """计算消息"""
    separator = set(source_clique) & set(target_clique)
    # 将收集的消息相乘
    combined = potential
    for child, msg in collected_messages:
        combined = multiply_factors(combined, msg, source_clique)
    # 边缘化到分离器
    message = marginalize(combined, list(separator))
    return message
 
 
def multiply_factors(f1, f2, clique):
    """因子乘法"""
    # 实现因子乘法的逻辑
    pass
 
 
def marginalize(factor, vars_to_keep):
    """因子边缘化"""
    # 实现边缘化的逻辑
    pass

4.4 Shafer-Shenoy架构

与Hugin架构不同,Shafer-Shenoy架构不需要在本地存储边缘势函数,每次查询时重新计算消息。


5. 与变量消除的关系

5.1 变量消除的本质

变量消除(Variable Elimination)通过依次消除变量来计算边缘分布:

过程

  1. 将联合分布分解为因子的乘积
  2. 选择消除顺序
  3. 依次消除变量:乘相关因子,边缘化目标变量

5.2 团树 = 变量消除的缓存优化

方面变量消除团树算法
计算每次消除独立计算预处理后复用计算
多次查询每次查询重复计算一次构建,多次使用
中间因子存储在消除过程中存储在团中
复杂度 每步 每消息

其中 是变量的基数(或取值数), 是树宽(最大团大小减1)。


6. 完整实现示例

6.1 贝叶斯网络到团树的转换

import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
 
class JunctionTree:
    """团树类"""
    
    def __init__(self):
        self.cliques = []  # 团列表
        self.separators = {}  # 分离器
        self.potentials = {}  # 势函数
        self.beliefs = {}  # 信念
        
    def build_from_bayesian_network(self, bn):
        """从贝叶斯网络构建团树
        
        Args:
            bn: BayesianNetwork对象
        """
        # 步骤1:道德化
        moral_graph = self._moralize(bn)
        
        # 步骤2:三角化
        chordal_graph, fill_edges = self._triangulate(moral_graph)
        
        # 步骤3:找最大团
        self.cliques = self._find_maximal_cliques(chordal_graph)
        
        # 步骤4:构建团树
        self._build_tree()
        
        # 初始化势函数
        self._initialize_potentials(bn)
        
    def _moralize(self, bn):
        """道德化"""
        moral = bn.to_undirected()
        
        # 连接同父节点的父节点
        for node in bn.nodes():
            parents = list(bn.predecessors(node))
            for i, p1 in enumerate(parents):
                for p2 in parents[i+1:]:
                    if not moral.has_edge(p1, p2):
                        moral.add_edge(p1, p2)
        
        return moral
    
    def _triangulate(self, graph):
        """三角化(使用最大基数搜索)"""
        # 简化实现
        chordal = graph.copy()
        fill_edges = []
        
        # MCS三角化算法
        # ...
        
        return chordal, fill_edges
    
    def _find_maximal_cliques(self, chordal_graph):
        """找最大团(使用Bron–Kerbosch算法)"""
        cliques = []
        
        def bron_kerbosch(R, P, X):
            if not P and not X:
                cliques.append(R)
                return
            for v in list(P):
                bron_kerbosch(
                    R | {v},
                    P & set(chordal_graph.neighbors(v)),
                    X & set(chordal_graph.neighbors(v))
                )
                P.remove(v)
                X.add(v)
        
        bron_kerbosch(set(), set(chordal_graph.nodes()), set())
        
        # 过滤掉非最大团
        maximal_cliques = []
        cliques.sort(key=len, reverse=True)
        for c in cliques:
            is_maximal = True
            for existing in maximal_cliques:
                c.issubset(set(existing)):
                    is_maximal = False
                    break
            if is_maximal:
                maximal_cliques.append(list(c))
        
        return maximal_cliques
    
    def _build_tree(self):
        """构建团树"""
        # 使用最大权重生成树
        edges = []
        for i, c1 in enumerate(self.cliques):
            for j, c2 in enumerate(self.cliques):
                if i < j:
                    sep = set(c1) & set(c2)
                    if sep:
                        edges.append((i, j, len(sep)))
                        self.separators[(i, j)] = list(sep)
                        self.separators[(j, i)] = list(sep)
        
        # Kruskal建树
        edges.sort(key=lambda x: -x[2])
        parent = list(range(len(self.cliques)))
        
        def find(x):
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]
        
        tree_edges = []
        for i, j, _ in edges:
            pi, pj = find(i), find(j)
            if pi != pj:
                parent[pi] = pj
                tree_edges.append((i, j))
        
        self.tree_edges = tree_edges
    
    def _initialize_potentials(self, bn):
        """初始化势函数"""
        for i, clique in enumerate(self.cliques):
            # 初始化为均匀分布
            self.potentials[i] = np.ones([2] * len(clique))
    
    def run_sum_product(self):
        """执行Sum-Product消息传递"""
        # Hugin架构:两遍消息传递
        pass

6.2 实际应用:医疗诊断网络

def medical_diagnosis_example():
    """
    医疗诊断贝叶斯网络示例
    
    网络结构:
    吸烟(S) → 肺结核(T)
    吸烟(S) → 肺癌(L)
    肺结核(T) → 呼吸困难(D)
    肺癌(L) → 呼吸困难(D)
    呼吸困难(D) → 胸片异常(X)
    """
    
    # 定义网络结构
    bn = BayesianNetwork([
        ('S', 'T'), ('S', 'L'),
        ('T', 'D'), ('L', 'D'),
        ('D', 'X')
    ])
    
    # 定义CPT(简化)
    cpt = {
        'S': [0.1, 0.9],  # P(S)
        'T': {'S': [[0.05, 0.95], [0.01, 0.99]]},  # P(T|S)
        'L': {'S': [[0.10, 0.90], [0.01, 0.99]]},  # P(L|S)
        'D': {('T', 'L'): [[0.90, 0.10], [0.70, 0.30], 
                           [0.70, 0.30], [0.10, 0.90]]},  # P(D|T,L)
        'X': {'D': [[0.85, 0.15], [0.02, 0.98]]}   # P(X|D)
    }
    
    # 构建团树
    jt = JunctionTree()
    jt.build_from_bayesian_network(bn)
    
    # 查询:已知胸片异常,求肺癌概率
    query_result = jt.query(evidence={'X': 1}, target='L')
    
    return query_result

7. 算法复杂度分析

7.1 时间复杂度

其中:

  • :每个变量的取值数(或基数)
  • 树宽 = (最大团大小减1)
  • :变量数

7.2 空间复杂度

主要空间用于存储团的势函数。

7.3 树宽的影响

问题类型树宽适用性
树状网络1✅ 精确推理非常高效
信念网络(低连通)2-5✅ 精确推理可行
完全图❌ 需要近似推理

8. 与其他算法的关系

8.1 信念传播(Belief Propagation)

关系:信念传播是团树算法在树结构图上的特例。

算法图结构消息传递
信念传播单遍/两遍
团树算法一般图两遍(收集+分发)
Loopy BP有环图迭代(不保证收敛)

8.2 变量消除(Variable Elimination)

关系:团树算法是变量消除的缓存优化版本

方面变量消除团树算法
预处理需要构建团树
多次查询每次重新消除一次预处理,多次使用
中间因子丢弃存储在团中

8.3 变分推断(Variational Inference)

当树宽过大时,精确推理不可行,需要使用近似推理

方法原理权衡
Loopy Belief Propagation忽略环进行消息传递不保证收敛
变分推断用近似分布替代真实分布可能远离真实后验
采样方法(MCMC)从后验采样计算成本高

9. 实践注意事项

9.1 最优三角化

问题:不同的三角化会产生不同的团树,从而影响效率。

目标:最小化最大团的大小(从而最小化树宽)。

NP难问题,常用启发式方法:

  • 最小度启发式(Minimum Degree)
  • 最小填充启发式(Minimum Fill)

9.2 数值稳定性

  • 使用对数空间存储势函数
  • 避免下溢(underflow)
  • 必要时进行归一化

9.3 实际工具

工具语言特点
pgmpyPython完整的PGM工具包
bnlearnR专注贝叶斯网络
Microsoft Infer.NETC#工业级
StanC++MCMC为主,支持精确推断

参考文献