团树算法(Junction Tree Algorithm)
1. 算法引入
1.1 精确推理的挑战
在贝叶斯网络中,计算边缘概率 涉及对所有其他变量的求和。当变量数目增加时,精确推理的计算复杂度呈指数增长。
例子:对于有 个二值变量的链式网络,边缘化需要 次运算。
核心思想:利用图的稀疏结构和分解性质,将联合分布分解为局部因子的乘积,从而将计算复杂度从 降低到与图的树宽相关的复杂度。
1.2 为什么需要团树?
变量消除(Variable Elimination)是一个简单但低效的方法:它每次只消除一个变量,每次消除都需要计算新的因子。
团树算法通过预处理将图转化为团树(clique tree),使得所有计算可以高效复用,本质上是变量消除的缓存优化版本。
2. 图论基础
2.1 团的定义
定义:在无向图 中,团(clique)是一个顶点集合 ,使得 中任意两个顶点之间都有边相连。
最大团(maximal clique):不是任何其他团的真超集的团。
最大团(maximum clique):包含顶点数最多的团。
2.2 弦图(Chordal Graph)
定义:一个无向图是弦图(chordal graph)当且仅当它的每个长度大于3的环都有一个弦(连接环上两个非相邻顶点的边)。
性质:弦图具有以下重要特性:
- 可以找到一个完美的顶点顺序(perfect elimination ordering)
- 最大团集合可以唯一确定
- 最小填充边集合可高效计算
2.3 运行交叉属性(RIP)
定义:给定一个团树 ,如果对于任意两个团 ,它们的交集 出现在连接 和 的路径上的所有团中,则称 满足运行交叉属性(Running Intersection Property)。
RIP的直观理解:共享变量(如条件独立的媒介)在团树中必须保持”连续”。
3. 团树构造算法
3.1 从贝叶斯网络到团树
步骤1:道德化(Moralization)
将贝叶斯网络转换为马尔可夫网络:
- 连接同父节点的父节点(添加边)
- 移除所有边的方向
def moralize(bn):
"""道德化贝叶斯网络
Args:
bn: 贝叶斯网络(有向无环图)
Returns:
moral_graph: 道德化的无向图
"""
moral_graph = bn.to_undirected()
# 连接同父节点的父节点
for node in bn.nodes():
parents = list(bn.predecessors(node))
if len(parents) >= 2:
for i, p1 in enumerate(parents):
for p2 in parents[i+1:]:
moral_graph.add_edge(p1, p2)
return moral_graph步骤2:三角化(Triangulation)
向道德化图中添加边(填充边),使其成为弦图。
算法:最大基数搜索(Maximum Cardinality Search, MCS)
def triangulate_mcs(graph):
"""使用最大基数搜索进行三角化
MCS通过给顶点编号来检测图是否已经是弦图,
如果不是,则添加必要的填充边。
Returns:
chordal_graph: 三角化后的弦图
fill_edges: 添加的填充边集合
"""
n = len(graph.nodes())
alpha = {v: -1 for v in graph.nodes()}
unnumbered = set(graph.nodes())
fill_edges = set()
for k in range(n, 0, -1):
# 选择编号最大的未编号顶点的邻居
v = max(unnumbered, key=lambda x: len([u for u in graph.neighbors(x)
if alpha[u] != -1]))
alpha[v] = k
unnumbered.remove(v)
# 找出所有已编号邻居
numbered_neighbors = [u for u in graph.neighbors(v) if alpha[u] != -1]
# 连接这些邻居(添加填充边)
for i, u1 in enumerate(numbered_neighbors):
for u2 in numbered_neighbors[i+1:]:
if not graph.has_edge(u1, u2):
fill_edges.add((min(u1, u2), max(u1, u2)))
graph.add_edge(u1, u2)
return graph, fill_edges3.2 构建团树
步骤3:识别最大团
在三角化图中,最大团可以通过以下方式识别:
- 使用完美消除顺序的逆序构建
步骤4:构建团树
将最大团组织成一棵树,满足RIP。
算法:最优团树构建
def build_junction_tree(moral_graph, root=None):
"""构建团树
Args:
moral_graph: 三角化后的道德化图
root: 树根(可选)
Returns:
junction_tree: 团树
cliques: 最大团列表
separators: 分离器集合
"""
# 找到所有最大团
cliques = find_maximal_cliques(moral_graph)
# 计算每个团的势(大小)
clique_potentials = {c: len(c) for c in cliques}
# 构建团树:使用最大团优先构建最大生成树
# 边的权重 = |C_i ∩ C_j|
edges = []
for i, c1 in enumerate(cliques):
for j, c2 in enumerate(cliques):
if i < j:
separator_size = len(set(c1) & set(c2))
if separator_size > 0:
edges.append((i, j, separator_size))
# Kruskal算法构建最大权重生成树
cliques = list(cliques)
edges.sort(key=lambda x: -x[2]) # 按分离器大小降序
parent = list(range(len(cliques)))
def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
tree_edges = []
for i, j, weight in edges:
pi, pj = find(i), find(j)
if pi != pj:
parent[pi] = pj
tree_edges.append((i, j))
# 构建团树对象
junction_tree = {
'cliques': cliques,
'edges': tree_edges,
'root': root if root is not None else 0
}
return junction_tree4. 消息传递机制
4.1 势函数与边际
因子(Factor):定义在变量子集上的非负函数。对于贝叶斯网络,因子通常是条件概率表(CPT)。
势函数(Potential):在团树算法中,势函数 定义在每个团 上,表示该团中变量的联合分布(未归一化)。
4.2 Sum-Product消息传递
消息定义:从团 向相邻团 (通过分离器 )的消息为:
含义:通过在 中边缘化(消去)不属于 的变量,得到关于分离器的函数。
4.3 Hugin架构(两遍消息传递)
收集阶段(Collection Phase):
- 从叶节点开始
- 叶子向父节点发送消息
- 父节点收集所有子节点的消息后,汇总并向上传递
- 到达根节点
分发阶段(Distribution Phase):
- 从根节点开始
- 根节点向子节点分发消息
- 子节点接收父节点消息后,向自己的子节点分发
- 直到所有叶子节点
def junction_tree_inference(junction_tree, factors, query_vars=None, evidence=None):
"""团树推断算法
Args:
junction_tree: 团树结构
factors: 每个团的初始势函数(因子)
query_vars: 查询变量集合
evidence: 观测变量及其取值
Returns:
beliefs: 每个团的边际(信念)
"""
cliques = junction_tree['cliques']
edges = junction_tree['edges']
# 构建邻接表
adj = {i: [] for i in range(len(cliques))}
for i, j in edges:
adj[i].append(j)
adj[j].append(i)
# 初始化势函数
potentials = {i: factors[i].copy() for i in range(len(cliques))}
# ===== 收集阶段 =====
visited = set()
messages_collected = {i: [] for i in range(len(cliques))}
def collect(node, parent):
for child in adj[node]:
if child != parent:
collect(child, node)
# 收集子节点的消息
msg = compute_message(potentials[child], cliques[child],
cliques[node], messages_collected[child])
messages_collected[node].append((child, msg))
visited.add(node)
# 从某个叶子开始收集
leaves = [i for i in range(len(cliques)) if len(adj[i]) == 1]
root = leaves[0] if leaves else 0
collect(root, -1)
# 更新根节点的势函数
for child, msg in messages_collected[root]:
potentials[root] = multiply_factors(potentials[root], msg, cliques[root])
# ===== 分发阶段 =====
beliefs = {}
def distribute(node, parent, incoming_msg=None):
# 更新当前团
if incoming_msg is not None:
potentials[node] = multiply_factors(potentials[node],
incoming_msg, cliques[node])
# 计算边缘(信念)
belief = marginalize(potentials[node], query_vars)
beliefs[node] = belief
# 向子节点分发消息
for child in adj[node]:
if child != parent:
# 计算给子节点的消息
msg_to_child = compute_message(potentials[node], cliques[node],
cliques[child], messages_collected[node])
distribute(child, node, msg_to_child)
distribute(root, -1)
return beliefs
def compute_message(potential, source_clique, target_clique, collected_messages):
"""计算消息"""
separator = set(source_clique) & set(target_clique)
# 将收集的消息相乘
combined = potential
for child, msg in collected_messages:
combined = multiply_factors(combined, msg, source_clique)
# 边缘化到分离器
message = marginalize(combined, list(separator))
return message
def multiply_factors(f1, f2, clique):
"""因子乘法"""
# 实现因子乘法的逻辑
pass
def marginalize(factor, vars_to_keep):
"""因子边缘化"""
# 实现边缘化的逻辑
pass4.4 Shafer-Shenoy架构
与Hugin架构不同,Shafer-Shenoy架构不需要在本地存储边缘势函数,每次查询时重新计算消息。
5. 与变量消除的关系
5.1 变量消除的本质
变量消除(Variable Elimination)通过依次消除变量来计算边缘分布:
过程:
- 将联合分布分解为因子的乘积
- 选择消除顺序
- 依次消除变量:乘相关因子,边缘化目标变量
5.2 团树 = 变量消除的缓存优化
| 方面 | 变量消除 | 团树算法 |
|---|---|---|
| 计算 | 每次消除独立计算 | 预处理后复用计算 |
| 多次查询 | 每次查询重复计算 | 一次构建,多次使用 |
| 中间因子 | 存储在消除过程中 | 存储在团中 |
| 复杂度 | 每步 | 每消息 |
其中 是变量的基数(或取值数), 是树宽(最大团大小减1)。
6. 完整实现示例
6.1 贝叶斯网络到团树的转换
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
class JunctionTree:
"""团树类"""
def __init__(self):
self.cliques = [] # 团列表
self.separators = {} # 分离器
self.potentials = {} # 势函数
self.beliefs = {} # 信念
def build_from_bayesian_network(self, bn):
"""从贝叶斯网络构建团树
Args:
bn: BayesianNetwork对象
"""
# 步骤1:道德化
moral_graph = self._moralize(bn)
# 步骤2:三角化
chordal_graph, fill_edges = self._triangulate(moral_graph)
# 步骤3:找最大团
self.cliques = self._find_maximal_cliques(chordal_graph)
# 步骤4:构建团树
self._build_tree()
# 初始化势函数
self._initialize_potentials(bn)
def _moralize(self, bn):
"""道德化"""
moral = bn.to_undirected()
# 连接同父节点的父节点
for node in bn.nodes():
parents = list(bn.predecessors(node))
for i, p1 in enumerate(parents):
for p2 in parents[i+1:]:
if not moral.has_edge(p1, p2):
moral.add_edge(p1, p2)
return moral
def _triangulate(self, graph):
"""三角化(使用最大基数搜索)"""
# 简化实现
chordal = graph.copy()
fill_edges = []
# MCS三角化算法
# ...
return chordal, fill_edges
def _find_maximal_cliques(self, chordal_graph):
"""找最大团(使用Bron–Kerbosch算法)"""
cliques = []
def bron_kerbosch(R, P, X):
if not P and not X:
cliques.append(R)
return
for v in list(P):
bron_kerbosch(
R | {v},
P & set(chordal_graph.neighbors(v)),
X & set(chordal_graph.neighbors(v))
)
P.remove(v)
X.add(v)
bron_kerbosch(set(), set(chordal_graph.nodes()), set())
# 过滤掉非最大团
maximal_cliques = []
cliques.sort(key=len, reverse=True)
for c in cliques:
is_maximal = True
for existing in maximal_cliques:
c.issubset(set(existing)):
is_maximal = False
break
if is_maximal:
maximal_cliques.append(list(c))
return maximal_cliques
def _build_tree(self):
"""构建团树"""
# 使用最大权重生成树
edges = []
for i, c1 in enumerate(self.cliques):
for j, c2 in enumerate(self.cliques):
if i < j:
sep = set(c1) & set(c2)
if sep:
edges.append((i, j, len(sep)))
self.separators[(i, j)] = list(sep)
self.separators[(j, i)] = list(sep)
# Kruskal建树
edges.sort(key=lambda x: -x[2])
parent = list(range(len(self.cliques)))
def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
tree_edges = []
for i, j, _ in edges:
pi, pj = find(i), find(j)
if pi != pj:
parent[pi] = pj
tree_edges.append((i, j))
self.tree_edges = tree_edges
def _initialize_potentials(self, bn):
"""初始化势函数"""
for i, clique in enumerate(self.cliques):
# 初始化为均匀分布
self.potentials[i] = np.ones([2] * len(clique))
def run_sum_product(self):
"""执行Sum-Product消息传递"""
# Hugin架构:两遍消息传递
pass6.2 实际应用:医疗诊断网络
def medical_diagnosis_example():
"""
医疗诊断贝叶斯网络示例
网络结构:
吸烟(S) → 肺结核(T)
吸烟(S) → 肺癌(L)
肺结核(T) → 呼吸困难(D)
肺癌(L) → 呼吸困难(D)
呼吸困难(D) → 胸片异常(X)
"""
# 定义网络结构
bn = BayesianNetwork([
('S', 'T'), ('S', 'L'),
('T', 'D'), ('L', 'D'),
('D', 'X')
])
# 定义CPT(简化)
cpt = {
'S': [0.1, 0.9], # P(S)
'T': {'S': [[0.05, 0.95], [0.01, 0.99]]}, # P(T|S)
'L': {'S': [[0.10, 0.90], [0.01, 0.99]]}, # P(L|S)
'D': {('T', 'L'): [[0.90, 0.10], [0.70, 0.30],
[0.70, 0.30], [0.10, 0.90]]}, # P(D|T,L)
'X': {'D': [[0.85, 0.15], [0.02, 0.98]]} # P(X|D)
}
# 构建团树
jt = JunctionTree()
jt.build_from_bayesian_network(bn)
# 查询:已知胸片异常,求肺癌概率
query_result = jt.query(evidence={'X': 1}, target='L')
return query_result7. 算法复杂度分析
7.1 时间复杂度
其中:
- :每个变量的取值数(或基数)
- :树宽 = (最大团大小减1)
- :变量数
7.2 空间复杂度
主要空间用于存储团的势函数。
7.3 树宽的影响
| 问题类型 | 树宽 | 适用性 |
|---|---|---|
| 树状网络 | 1 | ✅ 精确推理非常高效 |
| 信念网络(低连通) | 2-5 | ✅ 精确推理可行 |
| 完全图 | ❌ 需要近似推理 |
8. 与其他算法的关系
8.1 信念传播(Belief Propagation)
关系:信念传播是团树算法在树结构图上的特例。
| 算法 | 图结构 | 消息传递 |
|---|---|---|
| 信念传播 | 树 | 单遍/两遍 |
| 团树算法 | 一般图 | 两遍(收集+分发) |
| Loopy BP | 有环图 | 迭代(不保证收敛) |
8.2 变量消除(Variable Elimination)
关系:团树算法是变量消除的缓存优化版本。
| 方面 | 变量消除 | 团树算法 |
|---|---|---|
| 预处理 | 无 | 需要构建团树 |
| 多次查询 | 每次重新消除 | 一次预处理,多次使用 |
| 中间因子 | 丢弃 | 存储在团中 |
8.3 变分推断(Variational Inference)
当树宽过大时,精确推理不可行,需要使用近似推理:
| 方法 | 原理 | 权衡 |
|---|---|---|
| Loopy Belief Propagation | 忽略环进行消息传递 | 不保证收敛 |
| 变分推断 | 用近似分布替代真实分布 | 可能远离真实后验 |
| 采样方法(MCMC) | 从后验采样 | 计算成本高 |
9. 实践注意事项
9.1 最优三角化
问题:不同的三角化会产生不同的团树,从而影响效率。
目标:最小化最大团的大小(从而最小化树宽)。
NP难问题,常用启发式方法:
- 最小度启发式(Minimum Degree)
- 最小填充启发式(Minimum Fill)
9.2 数值稳定性
- 使用对数空间存储势函数
- 避免下溢(underflow)
- 必要时进行归一化
9.3 实际工具
| 工具 | 语言 | 特点 |
|---|---|---|
| pgmpy | Python | 完整的PGM工具包 |
| bnlearn | R | 专注贝叶斯网络 |
| Microsoft Infer.NET | C# | 工业级 |
| Stan | C++ | MCMC为主,支持精确推断 |