概述

概率图模型(Probabilistic Graphical Models, PGM) 是用图结构表示随机变量之间条件独立关系的概率模型框架。这一框架将贝叶斯网络的有向表示和马尔可夫随机场的无向表示统一起来,为不确定性推理、结构化预测和因果推断提供了强大工具。1

┌────────────────────────────────────────────────────────────────────────────┐
│                          概率图模型体系架构                                   │
├────────────────────────────────────────────────────────────────────────────┤
│                                                                            │
│                         概率图模型 (PGM)                                     │
│                              │                                              │
│              ┌───────────────┴───────────────┐                            │
│              │                               │                            │
│        有向图模型                        无向图模型                        │
│              │                               │                            │
│    ┌─────────┴─────────┐           ┌────────┴────────┐                  │
│    │                   │           │                 │                   │
│ 贝叶斯网络         动态模型       马尔可夫随机场      条件随机场             │
│ (DAG)              (TBN/HMM)         (MRF)           (CRF)                  │
│                                                                            │
│                              │                                              │
│                        因子图表示                                          │
│                              │                                              │
│              ┌───────────────┴───────────────┐                            │
│              │                               │                            │
│        Sum-Product算法                  Max-Product算法                    │
│        (边缘推断)                       (MAP推断)                          │
│                                                                            │
└────────────────────────────────────────────────────────────────────────────┘

1. 概率图模型基础

1.1 图表示的基本要素

概率图模型通过有向或无向图 表示变量之间的依赖关系:

要素符号含义
节点/顶点随机变量集合
变量间的直接依赖关系
父节点指向节点 的所有节点
子节点从节点 指向的所有节点
邻居与节点 相邻的节点(无向图)
马尔可夫毯完全决定节点 的条件分布

1.2 有向图 vs 无向图

特性有向图(贝叶斯网络)无向图(马尔可夫随机场)
边方向有向边(因果/条件依赖)无向边(相关关系)
图结构必须是有向无环图(DAG)可以有环(loopy)
联合分布$P(X_1,…,X_n) = \prod_i P(X_i\text{Pa}(X_i))$
势函数自然对应条件概率需要定义势函数
适用场景因果建模、诊断推理空间相关性、图像处理

1.3 条件独立性

条件独立性是概率图模型的核心概念。在给定其他变量的条件下,变量之间可能变得独立。

定义:设 为三个互不相交的变量集合,若

则称 在给定 的条件下条件独立,记作

在图模型中的意义

  • 有向图中,条件独立性通过 D-分离(d-separation)判定
  • 无向图中,条件独立性由图的连通性决定

1.4 D-分离(D-Separation)

D-分离是判定有向图中条件独立性的图判定方法。对于三个不相交的节点集合

三种基本结构

结构图形独立性性质
链式结构
分叉结构
碰撞结构 未观测时)
观测时)

D-分离判定规则

┌─────────────────────────────────────────────────────────────────┐
│                     D-分离判定算法                                │
├─────────────────────────────────────────────────────────────────┤
│ 输入:有向无环图 G,节点集合 X, Y, Z                              │
│                                                                  │
│ 1. 构建 actives 和 inactives 路径                                │
│ 2. 若存在 active 路径连接 X 和 Y(给定 Z),则 X ≁ Y | Z        │
│ 3. 若所有路径都是 inactive,则 X ⟂ Y | Z                         │
│                                                                  │
│ Active条件:                                                     │
│   - 链式 X→Z→Y 或 X←Z←Y:Z被观测 → active                       │
│   - 碰撞 X→Z←Y:Z及其后代都未观测 → active                       │
└─────────────────────────────────────────────────────────────────┘

Python 实现

import numpy as np
from collections import defaultdict, deque
 
class DSeparator:
    """D-分离判定器"""
    
    def __init__(self, dag):
        """
        dag: dict, 邻接表表示的有向无环图
              e.g., {'X': ['Y', 'Z'], 'Y': ['W']}
              表示 X→Y, X→Z, Y→W
        """
        self.dag = dag
        self.parents = self._compute_parents()
    
    def _compute_parents(self):
        """计算每个节点的父节点集合"""
        parents = defaultdict(set)
        for node in self.dag:
            for child in self.dag[node]:
                parents[child].add(node)
        return parents
    
    def is_d_separated(self, X, Y, Z):
        """
        判断 X 和 Y 是否在给定 Z 的条件下 D-分离
        
        Args:
            X, Y, Z: set, 节点集合
        """
        # 合并所有集合
        all_nodes = set(self.dag.keys())
        for neighbors in self.dag.values():
            all_nodes.update(neighbors)
        
        # 祖先节点集合(包括 Z)
        ancestors = self._get_ancestors(Z | {X, Y})
        
        # 在证据节点下考虑祖先子图
        considered = ancestors | Z
        
        # 从 X 开始BFS,检查是否有 active 路径到达 Y
        visited = set()
        queue = [(x, 'descendant') for x in X]
        
        while queue:
            node, direction = queue.pop(0)
            
            if node in visited:
                continue
            visited.add(node)
            
            if node in Y and node not in Z:
                return False  # 发现 active 路径,X 和 Y 不 D-分离
            
            # 探索相邻节点
            if direction == 'descendant':
                # 沿边向下(向子节点方向)
                for child in self.dag.get(node, []):
                    if child not in visited:
                        if child in Z:
                            # 碰撞节点被观测,路径阻断
                            continue
                        queue.append((child, 'descendant'))
                
                # 沿边上(向父节点方向)
                for parent in self.parents[node]:
                    if parent not in visited:
                        queue.append((parent, 'ancestor'))
            
            else:  # direction == 'ancestor'
                # 只允许向上移动
                for parent in self.parents[node]:
                    if parent not in visited:
                        queue.append((parent, 'ancestor'))
        
        return True  # 没有 active 路径,X 和 Y D-分离
    
    def _get_ancestors(self, nodes):
        """计算节点的祖先集合"""
        ancestors = set(nodes)
        queue = list(nodes)
        
        while queue:
            node = queue.pop(0)
            for parent in self.parents[node]:
                if parent not in ancestors:
                    ancestors.add(parent)
                    queue.append(parent)
        
        return ancestors
 
# 示例用法
dag = {
    'Intelligence': ['Grade', 'SAT'],
    'Difficulty': ['Grade'],
    'Grade': ['Letter'],
    'SAT': [],
    'Letter': []
}
 
dsep = DSeparator(dag)
 
# 测试碰撞结构:Grade 是碰撞节点
print(dsep.is_d_separated({'Intelligence'}, {'Difficulty'}, set()))  # True: 未观测Grade
print(dsep.is_d_separated({'Intelligence'}, {'Difficulty'}, {'Grade'}))  # False: 观测Grade后不独立
 
# 测试链式结构:Grade 在中间
print(dsep.is_d_separated({'Intelligence'}, {'Letter'}, {'Grade'}))  # True: 给定Grade后独立

1.5 图模型的学习与推断概述

任务描述方法
结构学习从数据中学习图的拓扑结构PC算法、HC算法、贪婪搜索
参数学习学习条件概率表或势函数参数MLE、贝叶斯估计、EM算法
精确推断计算精确的边缘/后验概率变量消除、信念传播、Junction Tree
近似推断处理大规模/复杂模型变分推断、MCMC采样、BP近似

2. 贝叶斯网络

2.1 有向无环图结构

贝叶斯网络(Bayesian Network),又称信念网络或有向无环图模型,是一种用有向无环图(DAG)表示随机变量之间条件依赖关系的概率图模型。1

DAG 的性质

  • 没有有向环
  • 可以进行拓扑排序
  • 每个节点的条件分布只依赖于其父节点

示例:学生成绩模型

                    [Intelligence] ───┬──→ [Grade]
                         │            │
                    [Difficulty] ─────┘    │
                         │                 ▼
                         │            [Letter]
                         │                 
                         └───────→ [SAT]

这个网络表示:

  • 智力(Intelligence)独立于其他变量
  • 难度(Difficulty)独立于其他变量
  • 成绩(Grade)依赖于智力和难度
  • SAT分数依赖于智力
  • 推荐信(Letter)依赖于成绩

2.2 条件概率表(CPT)

在贝叶斯网络中,每个节点都有一个条件概率表(CPT),描述该节点在其父节点取不同值时的条件概率分布。

CPT 示例(成绩 Grade)

IntelligenceDifficultyP(Grade=Good)P(Grade=Medium)P(Grade=Bad)
HighEasy0.900.080.02
HighHard0.700.250.05
LowEasy0.600.300.10
LowHard0.200.400.40

2.3 因子分解

贝叶斯网络的联合概率分布可以分解为所有节点条件概率的乘积:

其中 的父节点集合。

对于学生成绩模型

因子分解的优势

  • 联合分布需要 个参数
  • 分解后: 个参数
  • 指数级减少参数量

2.4 朴素贝叶斯分类器

朴素贝叶斯是贝叶斯网络的特例,假设所有特征条件独立于目标变量:

import numpy as np
from collections import defaultdict
 
class NaiveBayesClassifier:
    """朴素贝叶斯分类器"""
    
    def __init__(self, alpha=1.0):
        """
        Args:
            alpha: 拉普拉斯平滑参数
        """
        self.alpha = alpha
        self.classes = None
        self.class_prior = {}  # P(Y)
        self.feature_probs = {}  # P(X_i | Y)
    
    def fit(self, X, y):
        """
        训练朴素贝叶斯分类器
        
        Args:
            X: np.array, shape (n_samples, n_features), 特征矩阵
            y: np.array, shape (n_samples,), 标签
        """
        n_samples = len(y)
        self.classes, class_counts = np.unique(y, return_counts=True)
        n_classes = len(self.classes)
        
        # 计算类先验 P(Y)
        for c, count in zip(self.classes, class_counts):
            self.class_prior[c] = (count + self.alpha) / (n_samples + n_classes * self.alpha)
        
        # 计算条件概率 P(X_i | Y)
        self.feature_probs = {}
        n_features = X.shape[1]
        
        for c in self.classes:
            # 获取类别 c 的样本
            X_c = X[y == c]
            n_c = len(X_c)
            
            for feat_idx in range(n_features):
                # 统计每个特征值在该类别中的出现次数
                feature_vals, counts = np.unique(X_c[:, feat_idx], return_counts=True)
                
                # 存储为字典:{(class, feature_idx, feature_val): probability}
                for val, count in zip(feature_vals, counts):
                    key = (c, feat_idx, val)
                    # 拉普拉斯平滑
                    prob = (count + self.alpha) / (n_c + self.alpha * 2)
                    self.feature_probs[key] = prob
        
        return self
    
    def _compute_log_posterior(self, x, y_class):
        """计算对数后验概率 P(y_class | x)"""
        log_prob = np.log(self.class_prior[y_class])
        
        for feat_idx, feat_val in enumerate(x):
            key = (y_class, feat_idx, feat_val)
            if key in self.feature_probs:
                log_prob += np.log(self.feature_probs[key])
            else:
                # 未见过的特征值,使用平滑
                log_prob += np.log(self.alpha / (2 * self.alpha))
        
        return log_prob
    
    def predict(self, X):
        """预测类别"""
        predictions = []
        for x in X:
            log_probs = [self._compute_log_posterior(x, c) for c in self.classes]
            predictions.append(self.classes[np.argmax(log_probs)])
        
        return np.array(predictions)
    
    def predict_proba(self, X):
        """预测类别概率"""
        n_samples = len(X)
        n_classes = len(self.classes)
        probs = np.zeros((n_samples, n_classes))
        
        for i, x in enumerate(X):
            log_probs = [self._compute_log_posterior(x, c) for c in self.classes]
            # 归一化
            log_probs = np.array(log_probs)
            log_probs -= np.max(log_probs)  # 数值稳定性
            probs[i] = np.exp(log_probs)
            probs[i] /= probs[i].sum()
        
        return probs
 
# 使用示例
if __name__ == "__main__":
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    
    # 生成示例数据
    X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # 离散化(朴素贝叶斯需要离散特征)
    X_train_discrete = np.round(X_train * 2).astype(int) % 2
    X_test_discrete = np.round(X_test * 2).astype(int) % 2
    
    # 训练和预测
    clf = NaiveBayesClassifier(alpha=1.0)
    clf.fit(X_train_discrete, y_train)
    y_pred = clf.predict(X_test_discrete)
    
    print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")

2.5 马尔可夫链与动态贝叶斯网络

**隐马尔可夫模型(HMM)**是贝叶斯网络在时序数据上的扩展:

时间片 t-1           时间片 t
┌─────────┐         ┌─────────┐
│  X_{t-1} │────────→│   X_t   │
│  Y_{t-1} │    ↑    │   Y_t   │
└─────────┘    │    └─────────┘
               │         ↑
               └─────────┘

更多内容请参见:HMM详解深度学习中的HMM


3. 马尔可夫随机场

3.1 无向图模型

马尔可夫随机场(Markov Random Field, MRF),又称马尔可夫网络或条件随机场的特例,是一种用无向图表示变量之间依赖关系的概率图模型。1

与贝叶斯网络的主要区别

  • 边无方向,无法表示因果关系
  • 可以表示循环依赖
  • 使用势函数而非条件概率

示例:图像去噪模型

像素依赖关系(相邻像素相关)
┌───┬───┬───┐
│   │   │   │
├───┼───┼───┤
│   │ ● │   │  ● 为观测像素
├───┼───┼───┤
│   │   │   │
└───┴───┴───┘

3.2 势函数与配分函数

MRF 的联合概率分布定义为:

其中:

  • 是定义在团 上的势函数(Potential Function)
  • 配分函数(Partition Function),用于归一化:

势函数的性质

  • 非负:
  • 不必是概率分布
  • 通常取指数形式:,其中 能量函数

常见的势函数形式

类型势函数形式应用场景
对数线性模型统计建模
高斯势函数连续变量
Potts势函数图像分割

3.3 团与最大团

团(Clique):无向图中两两相邻的节点集合。

最大团(Maximal Clique):不是任何其他团子集的团。

团的示例:
    ●──●──●

单个节点 {A}, {B}, {C} 是团
边 {A,B}, {B,C} 是团
三元组 {A,B,C} 是最大团

Hammersley-Clifford 定理:若正分布 对所有状态成立,则它是一个 MRF 当且仅当它可以分解为最大团势函数的乘积:

这将 MRF 与吉布斯分布联系起来,形成吉布斯随机场(Gibbs Random Field)

3.4 马尔可夫随机场的实现

import numpy as np
from itertools import combinations
 
class MarkovRandomField:
    """马尔可夫随机场实现"""
    
    def __init__(self, graph, features):
        """
        Args:
            graph: 无向图的邻接表,dict {node: [neighbors]}
            features: 每个节点的特征矩阵,dict {node: np.array}
        """
        self.graph = graph
        self.features = features
        self.nodes = list(graph.keys())
        self.max_cliques = self._find_max_cliques()
        self.weights = {}  # 势函数权重
    
    def _find_max_cliques(self):
        """寻找所有最大团(使用Bron-Kerbosch算法简化版)"""
        # 对于简单的网格图,直接识别
        if self._is_grid_graph():
            return self._grid_max_cliques()
        
        # 通用方法:枚举所有极大团
        max_cliques = []
        cliques = self._bron_kerbosch([])
        
        # 过滤出极大的
        for c in cliques:
            is_maximal = True
            for other in cliques:
                if c < other and set(c).issubset(set(other)):
                    is_maximal = False
                    break
            if is_maximal:
                max_cliques.append(tuple(c))
        
        return max_cliques
    
    def _is_grid_graph(self):
        """检查是否为网格图"""
        return False  # 简化实现
    
    def _bron_kerbosch(self, R, P=None, X=None):
        """Bron-Kerbosch算法找所有极大团"""
        if P is None:
            P = set(self.nodes)
        if X is None:
            X = set()
        
        if not P and not X:
            return [R]
        
        cliques = []
        P_list = list(P)
        
        for v in P_list:
            cliques.extend(self._bron_kerbosch(
                R + [v],
                P & set(self.graph[v]),
                X & set(self.graph[v])
            ))
            P.remove(v)
            X.add(v)
        
        return cliques
    
    def compute_potential(self, clique, assignment):
        """计算团的势函数"""
        # 使用对数线性模型
        # ψ(clique) = exp(w^T · f(clique, assignment))
        if tuple(clique) not in self.weights:
            self.weights[tuple(clique)] = np.random.randn(len(clique) * 2)
        
        w = self.weights[tuple(clique)]
        
        # 构造特征向量
        features = []
        for node in clique:
            feat = self.features.get(node, np.array([0, 0]))
            if isinstance(feat, np.ndarray):
                features.extend(feat[:2])
            else:
                features.extend([feat, 0])
        
        # 简化的势函数
        energy = sum(assignment.get(n, 0) for n in clique)
        return np.exp(energy)
    
    def partition_function(self):
        """计算配分函数(精确计算,仅对小规模图有效)"""
        Z = 0
        n_nodes = len(self.nodes)
        
        # 遍历所有可能的赋值
        for assignment in range(2 ** n_nodes):
            bits = format(assignment, f'0{n_nodes}b')
            state = {node: int(bits[i]) for i, node in enumerate(self.nodes)}
            
            # 计算联合势
            joint_potential = 1
            for clique in self.max_cliques:
                joint_potential *= self.compute_potential(clique, state)
            
            Z += joint_potential
        
        return Z
    
    def gibbs_sample(self, current_state, n_iter=1000):
        """吉布斯采样"""
        state = current_state.copy()
        
        for _ in range(n_iter):
            for node in self.nodes:
                # 计算 P(X_node = 1 | 其他变量)
                state[0] = 1
                p1 = self._compute_conditional(node, state)
                
                state[0] = 0
                p0 = self._compute_conditional(node, state)
                
                # 归一化
                Z = p0 + p1
                p1 /= Z
                
                # 采样
                state[node] = 1 if np.random.random() < p1 else 0
        
        return state
    
    def _compute_conditional(self, node, state):
        """计算条件概率 P(X_node | 其他变量) ∝ 势函数乘积"""
        prob = 1.0
        
        # 找到包含该节点的所有最大团
        for clique in self.max_cliques:
            if node in clique:
                prob *= self.compute_potential(clique, state)
        
        return prob
 
# 图像去噪示例
class ImageDenoisingMRF:
    """用于图像去噪的MRF"""
    
    def __init__(self, beta=1.0, sigma=0.5):
        """
        Args:
            beta: 相似像素惩罚参数
            sigma: 噪声标准差
        """
        self.beta = beta
        self.sigma = sigma
    
    def _build_grid_graph(self, image):
        """构建网格图"""
        h, w = image.shape
        graph = {}
        
        for i in range(h):
            for j in range(w):
                node = (i, j)
                neighbors = []
                
                if i > 0:
                    neighbors.append((i-1, j))
                if i < h-1:
                    neighbors.append((i+1, j))
                if j > 0:
                    neighbors.append((i, j-1))
                if j < w-1:
                    neighbors.append((i, j+1))
                
                graph[node] = neighbors
        
        return graph
    
    def energy(self, x, y):
        """计算MRF能量函数"""
        h, w = x.shape
        energy = 0
        
        # 数据项
        energy += np.sum((x - y) ** 2) / (2 * self.sigma ** 2)
        
        # 平滑项(4邻域)
        for i in range(h):
            for j in range(w):
                if i > 0:
                    energy += self.beta * (x[i,j] - x[i-1,j]) ** 2
                if j > 0:
                    energy += self.beta * (x[i,j] - x[i,j-1]) ** 2
        
        return energy
    
    def denoise(self, noisy_image, n_iter=100):
        """ICM(迭代条件模式)算法去噪"""
        x = noisy_image.copy()
        h, w = x.shape
        
        for _ in range(n_iter):
            changed = False
            
            for i in range(h):
                for j in range(w):
                    current = x[i,j]
                    
                    # 尝试翻转
                    best_val = self.energy(x, noisy_image)
                    best_state = current
                    
                    for val in [0, 1]:
                        if val != current:
                            x[i,j] = val
                            e = self.energy(x, noisy_image)
                            if e < best_val:
                                best_val = e
                                best_state = val
                                changed = True
                    
                    x[i,j] = best_state
            
            if not changed:
                break
        
        return x

3.5 条件随机场(CRF)

**条件随机场(Conditional Random Field, CRF)**是一种判别式无向图模型,直接建模条件概率 ,而非联合分布。2

CRF 广泛应用于序列标注任务(NER、词性标注等)。更多内容请参见:CRF详解序列CRF


4. 因子图与消息传递

4.1 因子图表示

**因子图(Factor Graph)**是一种二分图表示方法,同时表示变量和因子(势函数),使消息传递算法的推导更加清晰。3

有向图表示:                    因子图表示:
    A → B → C                      A ─┐
        ↓                          B ─┼─ f₁
        D                          C ─┤
                                     D ─┘

    变量节点:○ A, ○ B, ○ C, ○ D
    因子节点:□ f₁(A,B), □ f₂(B,C,D)

因子图的形式化定义

因子图是一个二分图 ,其中:

  • :变量节点集合
  • :因子节点集合
  • :边,仅连接变量节点和因子节点

联合分布表示

其中 是因子 连接的变量集合。

4.2 Sum-Product算法(和积算法)

Sum-Product算法(又称信念传播)是在因子图上计算边缘概率的消息传递算法。3

核心思想

对于变量 的边缘概率:

其中 是从因子 传递给变量 的消息。

消息传递规则

┌─────────────────────────────────────────────────────────────────┐
│                    Sum-Product 消息传递                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  变量 → 因子消息:                                                │
│                                                                  │
│      ○ x ─────────→ □ f                                          │
│      ↑           │                                               │
│      │           │                                               │
│      │    μ_{x→f}(x) = ∏_{h∈ne(x)\{f}} μ_{h→x}(x)               │
│      │                                                               │
│      │                                                               │
│      ○ h₁ ○ h₂                                                    │
│                                                                  │
│  因子 → 变量消息:                                                │
│                                                                  │
│      □ f ─────────→ ○ x                                          │
│      │                                                               │
│      │    μ_{f→x}(x) = ∑_{X_{ne(f)\{x}}} f(X_{ne(f)})            │
│      │                  ∏_{y∈ne(f)\{x}} μ_{y→f}(y)               │
│      │                                                               │
│      ○ y₁ ○ y₂ ○ y₃                                              │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

4.3 Sum-Product算法实现

import numpy as np
from collections import defaultdict
 
class FactorGraph:
    """因子图及Sum-Product算法实现"""
    
    def __init__(self):
        self.variables = {}  # {name: num_states}
        self.factors = {}    # {name: factor_table}
        self.var_to_factors = defaultdict(list)  # 变量 -> 因子
        self.factor_to_vars = {}  # 因子 -> 变量
        self.messages = {}  # 缓存消息
    
    def add_variable(self, name, n_states=2):
        """添加变量节点"""
        self.variables[name] = n_states
    
    def add_factor(self, name, var_names, table):
        """
        添加因子节点
        
        Args:
            name: 因子名称
            var_names: 变量名列表
            table: numpy array, 势函数值
        """
        self.factors[name] = table
        self.factor_to_vars[name] = var_names
        
        for var in var_names:
            self.var_to_factors[var].append(name)
    
    def variable_to_factor_message(self, var_name, factor_name):
        """计算变量到因子的消息 μ_{var→factor}"""
        cache_key = (var_name, factor_name)
        if cache_key in self.messages:
            return self.messages[cache_key]
        
        n_states = self.variables[var_name]
        message = np.ones(n_states)
        
        # 乘以来自其他因子的消息
        for other_factor in self.var_to_factors[var_name]:
            if other_factor != factor_name:
                msg = self.factor_to_variable_message(other_factor, var_name)
                message *= msg
        
        self.messages[cache_key] = message
        return message
    
    def factor_to_variable_message(self, factor_name, var_name):
        """计算因子到变量的消息 μ_{factor→var}"""
        cache_key = (factor_name, var_name)
        if cache_key in self.messages:
            return self.messages[cache_key]
        
        factor = self.factors[factor_name]
        var_names = self.factor_to_vars[factor_name]
        var_idx = var_names.index(var_name)
        
        # 计算因子的边缘
        # 消息 = 对其他变量求和(乘以它们的传入消息)
        message = np.zeros(self.variables[var_name])
        
        # 遍历该变量的所有可能值
        for val in range(self.variables[var_name]):
            total = 0.0
            
            # 遍历因子表中的所有条目
            for idx in np.ndindex(*[self.variables[v] for v in var_names]):
                # 如果该索引在变量 var_name 处等于 val
                if idx[var_idx] == val:
                    # 计算该条目的权重
                    weight = factor[idx]
                    
                    # 乘以其他变量的消息
                    for i, v in enumerate(var_names):
                        if i != var_idx:
                            incoming = self.variable_to_factor_message(v, factor_name)
                            weight *= incoming[idx[i]]
                    
                    total += weight
            
            message[val] = total
        
        # 归一化(可选)
        if message.sum() > 0:
            message /= message.sum()
        
        self.messages[cache_key] = message
        return message
    
    def compute_belief(self, var_name):
        """计算变量的边缘分布(信念)"""
        n_states = self.variables[var_name]
        belief = np.ones(n_states)
        
        for factor in self.var_to_factors[var_name]:
            msg = self.factor_to_variable_message(factor, var_name)
            belief *= msg
        
        # 归一化
        belief /= belief.sum()
        return belief
    
    def run_sum_product(self, max_iter=100):
        """运行Sum-Product算法直到收敛"""
        # 清除缓存的消息
        self.messages = {}
        
        # 迭代更新
        for iteration in range(max_iter):
            old_messages = dict(self.messages)
            
            # 更新所有因子到变量的消息
            for factor_name in self.factors:
                for var_name in self.factor_to_vars[factor_name]:
                    self.factor_to_variable_message(factor_name, var_name)
            
            # 检查收敛
            converged = True
            for key in old_messages:
                if key not in self.messages:
                    converged = False
                    break
                if not np.allclose(old_messages[key], self.messages[key], atol=1e-6):
                    converged = False
                    break
            
            if converged:
                print(f"Converged at iteration {iteration}")
                break
        
        return {var: self.compute_belief(var) for var in self.variables}
 
# 示例:简单贝叶斯网络
def simple_bayes_net_example():
    """示例:P(A, B, C) = P(A) P(B|A) P(C|B)"""
    fg = FactorGraph()
    
    # 添加变量(都是二元)
    fg.add_variable('A', n_states=2)
    fg.add_variable('B', n_states=2)
    fg.add_variable('C', n_states=2)
    
    # 添加因子
    # P(A)
    pA = np.array([0.7, 0.3])
    fg.add_factor('fA', ['A'], pA)
    
    # P(B|A)
    pB_given_A = np.array([
        [0.9, 0.1],  # A=0: P(B=0|A=0), P(B=1|A=0)
        [0.2, 0.8]   # A=1: P(B=0|A=1), P(B=1|A=1)
    ])
    fg.add_factor('fB', ['A', 'B'], pB_given_A)
    
    # P(C|B)
    pC_given_B = np.array([
        [0.8, 0.2],  # B=0: P(C=0|B=0), P(C=1|B=0)
        [0.1, 0.9]   # B=1: P(C=0|B=1), P(C=1|B=1)
    ])
    fg.add_factor('fC', ['B', 'C'], pC_given_B)
    
    # 运行Sum-Product
    beliefs = fg.run_sum_product()
    
    print("Marginal distributions:")
    for var, belief in beliefs.items():
        print(f"P({var}) = {belief}")
    
    return beliefs
 
# 运行示例
if __name__ == "__main__":
    simple_bayes_net_example()

4.4 Max-Product算法(最大乘积算法)

Max-Product算法用于计算最大后验概率(MAP)推断,即找到最可能的变量赋值:

消息传递规则(与Sum-Product类似,但用max代替sum):

Max-Sum算法:为避免数值下溢,通常使用对数形式:

class MaxProductAlgorithm:
    """Max-Product算法(MAP推断)"""
    
    def __init__(self, factor_graph):
        self.fg = factor_graph
        self.messages = {}
        self.backpointers = {}  # 用于回溯
    
    def variable_to_factor_message(self, var_name, factor_name):
        """变量到因子的消息(与Sum-Product相同)"""
        return self.fg.variable_to_factor_message(var_name, factor_name)
    
    def factor_to_variable_message(self, factor_name, var_name):
        """因子到变量的消息(使用max而非sum)"""
        cache_key = (factor_name, var_name)
        if cache_key in self.messages:
            return self.messages[cache_key]
        
        factor = self.fg.factors[factor_name]
        var_names = self.fg.factor_to_vars[factor_name]
        var_idx = var_names.index(var_name)
        
        n_states = self.fg.variables[var_name]
        message = np.zeros(n_states)
        argmax_table = {}  # 存储每个状态的argmax
        
        for val in range(n_states):
            best = -np.inf
            
            # 遍历因子表中的所有条目
            for idx in np.ndindex(*[self.fg.variables[v] for v in var_names]):
                if idx[var_idx] == val:
                    weight = np.log(factor[idx] + 1e-10)  # 对数势函数
                    
                    for i, v in enumerate(var_names):
                        if i != var_idx:
                            incoming = self.fg.variable_to_factor_message(v, factor_name)
                            weight += np.log(incoming[idx[i]] + 1e-10)
                    
                    if weight > best:
                        best = weight
                        argmax_table[val] = idx
            
            message[val] = np.exp(best) if best > -np.inf else 0.0
            self.backpointers[(factor_name, var_name, val)] = argmax_table.get(val)
        
        self.messages[cache_key] = message
        return message
    
    def compute_map(self):
        """计算MAP赋值"""
        # 运行消息传递
        self.fg.messages = {}
        self.messages = {}
        
        # 迭代更新
        for _ in range(len(self.fg.variables)):
            for factor_name in self.fg.factors:
                for var_name in self.fg.factor_to_vars[factor_name]:
                    self.factor_to_variable_message(factor_name, var_name)
        
        # 计算每个变量的MAP值
        map_assignment = {}
        
        for var_name in self.fg.variables:
            belief = np.zeros(self.fg.variables[var_name])
            
            for factor in self.fg.var_to_factors[var_name]:
                msg = self.factor_to_variable_message(factor, var_name)
                belief *= msg
            
            map_assignment[var_name] = np.argmax(belief)
        
        return map_assignment
 
    def decode_with_backtracking(self):
        """回溯找到MAP赋值(树结构保证正确)"""
        map_assignment = {}
        
        for var_name in self.fg.variables:
            # 选择消息乘积最大的状态
            belief = np.zeros(self.fg.variables[var_name])
            
            for factor in self.fg.var_to_factors[var_name]:
                msg = self.factor_to_variable_message(factor, var_name)
                belief *= msg
            
            map_assignment[var_name] = np.argmax(belief)
        
        return map_assignment

4.5 变量消除

**变量消除(Variable Elimination, VE)**是精确推断的基本方法,通过按顺序消除变量来计算边缘概率。1

算法步骤

┌─────────────────────────────────────────────────────────────────┐
│                     变量消除算法                                  │
├─────────────────────────────────────────────────────────────────┤
│ 输入:因子集合 F,消除变量顺序 Z = [z₁, z₂, ..., zₖ]             │
│                                                                  │
│ for each z in Z:                                                 │
│     1. 收集:找出所有包含 z 的因子 F_z                           │
│     2. 乘积:将 F_z 中的因子相乘得到 τ_z                         │
│     3. 边缘化:对 τ_z 求和(积分)消除 z,得到 m_z               │
│     4. 替换:用 m_z 替换 F_z,更新因子集合 F                     │
│                                                                  │
│ 输出:边缘概率或归一化常数                                         │
└─────────────────────────────────────────────────────────────────┘
import numpy as np
from functools import reduce
 
class VariableElimination:
    """变量消除算法实现"""
    
    def __init__(self, factors):
        """
        Args:
            factors: list of (var_names, table) tuples
        """
        self.original_factors = factors
        self.factors = list(factors)
    
    def eliminate(self, var_name):
        """消除一个变量"""
        # 1. 收集所有包含该变量的因子
        relevant = [(vars, table) for vars, table in self.factors 
                   if var_name in vars]
        
        # 2. 移除这些因子
        self.factors = [(vars, table) for vars, table in self.factors 
                        if var_name not in vars]
        
        if not relevant:
            return
        
        # 3. 乘积所有相关因子
        product_vars = list(reduce(
            lambda a, b: set(a[0]) | set(b[0]), 
            [(vars,) for vars, _ in relevant]
        ))
        
        # 简化版:假设最多两个因子
        if len(relevant) == 1:
            (v1, t1), = relevant
            return (v1, t1)
        
        # 乘积两个因子
        def multiply_factors(f1, f2):
            v1, t1 = f1
            v2, t2 = f2
            
            # 合并变量
            common = set(v1) & set(v2)
            new_vars = [v for v in v1 if v in v2]  # 保持顺序
            
            # 计算乘积
            shape = [2] * (len(v1) + len(v2) - len(common))
            result = np.zeros(shape)
            
            for idx1 in np.ndindex(*[2]*len(v1)):
                for idx2 in np.ndindex(*[2]*len(v2)):
                    # 检查一致性
                    consistent = True
                    assign = {}
                    for i, v in enumerate(v1):
                        assign[v] = idx1[i]
                    for i, v in enumerate(v2):
                        if v in assign and assign[v] != idx2[i]:
                            consistent = False
                            break
                        assign[v] = idx2[i]
                    
                    if consistent:
                        idx = tuple(assign[v] for v in v1 + v2 if v not in common)
                        result[idx] = t1[idx1] * t2[idx2]
            
            return (v1 + [v for v in v2 if v not in v1], result)
        
        product = reduce(multiply_factors, relevant)
        
        # 4. 边缘化(求和)
        result_vars = [v for v in product[0] if v != var_name]
        result_table = np.sum(product[1], axis=product[0].index(var_name))
        
        # 5. 添加结果因子
        if result_table.sum() > 0:
            self.factors.append((result_vars, result_table))
        
        return (result_vars, result_table)
    
    def query(self, query_vars, evidence=None, elimination_order=None):
        """
        计算后验边缘概率 P(query_vars | evidence)
        
        Args:
            query_vars: 查询变量列表
            evidence: 证据变量字典 {var: value}
            elimination_order: 消除顺序(可选)
        
        Returns:
            归一化的概率表
        """
        if evidence is None:
            evidence = {}
        
        # 重置因子
        self.factors = list(self.original_factors)
        
        # 添加证据因子(固定变量值)
        for var, val in evidence.items():
            # 创建一个只有该变量的因子
            table = np.array([1.0, 0.0])
            if val == 1:
                table = np.array([0.0, 1.0])
            self.factors.append(([var], table))
        
        # 确定消除顺序
        all_vars = set()
        for vars, _ in self.factors:
            all_vars.update(vars)
        
        eliminate_vars = [v for v in all_vars 
                         if v not in query_vars and v not in evidence]
        
        if elimination_order is None:
            elimination_order = eliminate_vars
        else:
            elimination_order = [v for v in elimination_order if v in eliminate_vars]
        
        # 消除变量
        for var in elimination_order:
            self.eliminate(var)
        
        # 剩余因子即为查询结果
        if self.factors:
            result_vars = list(reduce(lambda a, b: set(a) | set(b), 
                                      [set(v) for v, _ in self.factors]))
            result = self.factors[0][1]
            for _, table in self.factors[1:]:
                result = np.outer(result, table).flatten()
            
            # 归一化
            if result.sum() > 0:
                result /= result.sum()
            
            return result
        
        return np.array([1.0])

更多内容请参见:因子图详解置信传播理论


5. 推断方法

5.1 精确推断

5.1.1 变量消除法(Variable Elimination)

变量消除是最基本的精确推断方法,时间复杂度为 ,其中 是变量数, 是每个变量的取值数, 是最大团的大小。

优缺点

优点缺点
实现简单每次查询需要重新计算
内存需求低时间复杂度指数级
可优化消除顺序不适合大规模问题

5.1.2 信念传播(Belief Propagation)

信念传播利用消息传递避免重复计算,在树结构图上具有线性时间复杂度。

适用于

  • 树结构(有向或无向)
  • 因子图
  • 链式结构

消息计算

收敛性

  • 在树结构上:必定收敛到精确解
  • 在有环图上(Loopy BP):可能不收敛或收敛到局部最优

更多内容请参见:深度学习中的信念传播

5.1.3 Junction Tree算法

Junction Tree算法将一般图转换为树结构,从而可以使用信念传播。

步骤

1. 道德化(Moralization):将有向图转为无向图(去除方向,连接父节点)
2. 三角化(Triangulation):添加边使图弦数为零
3. 构建团树(Clique Tree/Junction Tree):团作为节点,最大团树
4. 消息传递:在团树上运行信念传播
class JunctionTree:
    """Junction Tree算法实现"""
    
    def __init__(self, dag):
        """
        Args:
            dag: dict, 有向无环图的邻接表
        """
        self.dag = dag
        self.moralized = None
        self.cliques = []
        self.junction_tree = None
    
    def moralize(self):
        """道德化:将DAG转为无向图并连接父节点"""
        moral = defaultdict(set)
        
        # 添加所有边(无向)
        for node in self.dag:
            for parent in self.dag[node]:
                moral[node].add(parent)
                moral[parent].add(node)
        
        # 连接同一节点的父节点
        for node in self.dag:
            parents = list(self.dag[node])
            for i in range(len(parents)):
                for j in range(i+1, len(parents)):
                    moral[parents[i]].add(parents[j])
                    moral[parents[j]].add(parents[i])
        
        self.moralized = dict(moral)
        return self.moralized
    
    def triangulate(self):
        """三角化(简化版)"""
        # 使用最大Cardinality搜索
        # 这里是简化实现
        return self.moralized
    
    def find_cliques(self):
        """寻找所有最大团"""
        # 使用Bron-Kerbosch算法
        # 这里是简化实现
        pass
    
    def build_junction_tree(self):
        """构建团树"""
        # 团树满足运行intersection性质
        pass
    
    def run_belief_propagation(self):
        """在团树上运行信念传播"""
        pass

5.2 近似推断

5.2.1 变分推断(Variational Inference)

变分推断将推断问题转化为优化问题,通过引入变分分布近似真实后验分布。4

基本思想

设真实后验为 ,引入变分分布 ,通过最小化 KL散度:

其中 变分下界(ELBO)

平均场变分推断

假设变分分布可分解:

最优的 满足:

import numpy as np
 
class MeanFieldVariationalInference:
    """平均场变分推断"""
    
    def __init__(self, model, n_iter=100, tol=1e-4):
        """
        Args:
            model: 概率模型,包含 log_joint() 方法
            n_iter: 最大迭代次数
            tol: 收敛阈值
        """
        self.model = model
        self.n_iter = n_iter
        self.tol = tol
        self.q_params = {}  # 变分参数
        self.elbo_history = []
    
    def fit(self, X):
        """
        运行变分推断
        
        Args:
            X: 观测数据
        """
        # 初始化变分参数
        self._initialize_params(X)
        
        for iteration in range(self.n_iter):
            # 更新每个变分因子
            self._coordinate_ascent(X)
            
            # 计算ELBO
            elbo = self._compute_elbo(X)
            self.elbo_history.append(elbo)
            
            # 检查收敛
            if len(self.elbo_history) > 1:
                if abs(elbo - self.elbo_history[-2]) < self.tol:
                    print(f"Converged at iteration {iteration}")
                    break
        
        return self
    
    def _initialize_params(self, X):
        """初始化变分参数"""
        pass  # 由子类实现
    
    def _coordinate_ascent(self, X):
        """坐标上升更新"""
        pass  # 由子类实现
    
    def _compute_elbo(self, X):
        """计算ELBO"""
        pass  # 由子类实现
    
    def get_posterior(self):
        """获取近似后验分布"""
        return {k: v for k, v in self.q_params.items()}
 
 
class VariationalBayesGaussianMixture(MeanFieldVariationalInference):
    """高斯混合模型的变分贝叶斯推断"""
    
    def __init__(self, n_components=3, n_iter=100, tol=1e-4):
        super().__init__(None, n_iter, tol)
        self.n_components = n_components
        self.X = None
    
    def fit(self, X):
        """拟合GMM"""
        self.X = X
        n_samples, n_features = X.shape
        
        # 初始化参数
        self.q_params = {
            'phi': np.ones((n_samples, self.n_components)) / self.n_components,
            'mu': X[np.random.choice(n_samples, self.n_components)],
            'Sigma': [np.eye(n_features) for _ in range(self.n_components)],
            'alpha': np.ones(self.n_components),
            'beta': np.ones(self.n_components),
            'nu': n_features * np.ones(self.n_components)
        }
        
        for iteration in range(self.n_iter):
            # E步:更新职责
            self._e_step()
            
            # M步:更新参数
            self._m_step()
            
            # 计算ELBO
            elbo = self._compute_elbo(X)
            self.elbo_history.append(elbo)
            
            if len(self.elbo_history) > 1:
                if abs(elbo - self.elbo_history[-2]) < self.tol:
                    print(f"Converged at iteration {iteration}")
                    break
        
        return self
    
    def _e_step(self):
        """E步:更新变分分布"""
        X = self.X
        n_samples, n_features = X.shape
        
        for k in range(self.n_components):
            mu_k = self.q_params['mu'][k]
            Sigma_k = self.q_params['Sigma'][k]
            
            # 计算似然
            diff = X - mu_k
            cov_term = np.sum(diff @ np.linalg.inv(Sigma_k) * diff, axis=1)
            log_lik = -0.5 * (n_features * np.log(2*np.pi) + 
                             np.log(np.linalg.det(Sigma_k)) + 
                             cov_term)
            
            # 更新phi
            self.q_params['phi'][:, k] = np.exp(log_lik)
        
        # 归一化
        self.q_params['phi'] /= self.q_params['phi'].sum(axis=1, keepdims=True)
    
    def _m_step(self):
        """M步:更新变分参数"""
        X = self.X
        n_samples, n_features = X.shape
        
        N_k = self.q_params['phi'].sum(axis=0)
        
        for k in range(self.n_components):
            # 更新均值
            self.q_params['mu'][k] = (self.q_params['phi'][:, k:k+1].T @ X) / N_k[k]
            
            # 更新协方差(简化版)
            diff = X - self.q_params['mu'][k]
            cov = (self.q_params['phi'][:, k:k+1].T * diff).T @ diff / N_k[k]
            self.q_params['Sigma'][k] = cov + np.eye(n_features) * 1e-6
    
    def _compute_elbo(self, X):
        """计算ELBO(简化版)"""
        return np.sum(self.q_params['phi'] * np.log(self.q_params['phi'] + 1e-10))
    
    def predict(self, X):
        """预测"""
        self.X = X
        self._e_step()
        return np.argmax(self.q_params['phi'], axis=1)
 
# 使用示例
if __name__ == "__main__":
    from sklearn.datasets import make_blobs
    
    # 生成数据
    X, _ = make_blobs(n_samples=300, centers=3, n_features=2, random_state=42)
    
    # 拟合
    vbgmm = VariationalBayesGaussianMixture(n_components=3)
    vbgmm.fit(X)
    
    # 预测
    labels = vbgmm.predict(X)
    print(f"Cluster centers:\n{vbgmm.q_params['mu']}")

更多内容请参见:变分推断深度解析变分推断进阶

5.2.2 马尔可夫链蒙特卡洛(MCMC)

MCMC通过构造马尔可夫链从后验分布中采样。5

Metropolis-Hastings算法

def metropolis_hastings(log_posterior, proposal, x0, n_samples=10000):
    """
    Metropolis-Hastings采样
    
    Args:
        log_posterior: 对数后验函数
        proposal: 提议分布
        x0: 初始值
        n_samples: 采样数量
    """
    samples = [x0]
    current = x0
    current_log_prob = log_posterior(x0)
    
    accept_count = 0
    
    for i in range(n_samples):
        # 提议新状态
        proposed = proposal(current)
        proposed_log_prob = log_posterior(proposed)
        
        # 计算接受率
        log_accept_ratio = proposed_log_prob - current_log_prob
        
        # 接受或拒绝
        if np.log(np.random.random()) < log_accept_ratio:
            current = proposed
            current_log_prob = proposed_log_prob
            accept_count += 1
        
        samples.append(current)
    
    print(f"Acceptance rate: {accept_count / n_samples:.2%}")
    return np.array(samples)
 
# 吉布斯采样
def gibbs_sampler(joint_log_prob, initial_state, n_samples=10000):
    """吉布斯采样"""
    current = initial_state.copy()
    samples = []
    
    for _ in range(n_samples):
        for dim in range(len(current)):
            # 条件分布:在其他维度固定时,当前维度的分布
            def conditional(x):
                state = current.copy()
                state[dim] = x
                return np.exp(joint_log_prob(state))
            
            # 采样(简化实现)
            current[dim] = np.random.choice(2, p=[0.5, 0.5])
        
        samples.append(current.copy())
    
    return np.array(samples)

Hamiltonian Monte Carlo(HMC):利用梯度信息提高采样效率。

更多内容请参见:MCMC方法HMC几何理论

5.3 推断方法对比

方法类型时间复杂度精度适用场景
变量消除精确精确小规模问题
信念传播精确精确(树结构)树结构模型
Junction Tree精确精确中等规模
变分推断近似可控近似大规模问题
MCMC近似不确定渐进精确复杂后验
Loopy BP近似不保证大规模有环图

6. PyTorch 实现:神经网络中的概率图模型

6.1 变分自编码器(VAE)

VAE 是深度学习与概率图模型的经典结合:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli
 
class VariationalAutoencoder(nn.Module):
    """变分自编码器"""
    
    def __init__(self, input_dim, latent_dim, hidden_dim=256):
        super().__init__()
        
        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # 潜在变量的均值和方差
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # 解码器
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
        
        self.latent_dim = latent_dim
    
    def encode(self, x):
        """编码"""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """重参数化技巧"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        """解码"""
        return self.decoder(z)
    
    def forward(self, x):
        """前向传播"""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar
    
    def loss(self, x, x_recon, mu, logvar):
        """VAE损失函数 = 重构损失 + KL散度"""
        # 重构损失(伯努利分布)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        
        # KL散度:KL(N(μ,σ) || N(0,I))
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        return recon_loss + kl_loss
 
# 训练函数
def train_vae(model, dataloader, optimizer, device='cuda', n_epochs=50):
    model.train()
    
    for epoch in range(n_epochs):
        total_loss = 0
        
        for batch in dataloader:
            x = batch.to(device)
            
            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)
            loss = model.loss(x, x_recon, mu, logvar)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Loss: {total_loss/len(dataloader.dataset):.4f}")

6.2 图神经网络中的消息传递

GNN 可以看作是概率图模型中信念传播的推广:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class MessagePassingLayer(nn.Module):
    """消息传递层(对应PGM中的信念传播)"""
    
    def __init__(self, node_dim, edge_dim, out_dim):
        super().__init__()
        
        # 消息函数
        self.message_mlp = nn.Sequential(
            nn.Linear(node_dim * 2 + edge_dim, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, node_dim)
        )
        
        # 聚合函数(对应消息组合)
        # 更新函数(对应信念更新)
        self.update_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, out_dim)
        )
    
    def message(self, src, dst, edge_attr):
        """计算消息:对应 μ_{y→x}"""
        # 拼接源节点、目标节点和边特征
        cat = torch.cat([src, dst, edge_attr], dim=-1)
        return self.message_mlp(cat)
    
    def aggregate(self, messages, index):
        """聚合消息:对应 ∑ 或 ∏"""
        # Sum aggregation
        return torch.zeros(index.max() + 1, messages.size(-1)).to(messages.device).scatter_add(0, index.unsqueeze(-1).expand_as(messages), messages)
    
    def update(self, node_features, aggr_features):
        """更新节点特征:对应信念更新"""
        cat = torch.cat([node_features, aggr_features], dim=-1)
        return self.update_mlp(cat)
    
    def forward(self, x, edge_index, edge_attr):
        """
        完整的消息传递步骤
        
        Args:
            x: 节点特征 [num_nodes, node_dim]
            edge_index: 边索引 [2, num_edges]
            edge_attr: 边特征 [num_edges, edge_dim]
        """
        src, dst = edge_index
        
        # 1. 计算消息
        msgs = self.message(x[src], x[dst], edge_attr)
        
        # 2. 聚合消息
        aggr = torch.zeros(x.size(0), x.size(1)).to(x.device)
        aggr = aggr.scatter_add(0, src.unsqueeze(-1).expand_as(msgs), msgs)
        
        # 3. 更新
        out = self.update(x, aggr)
        
        return out
 
 
class GraphNeuralNetwork(nn.Module):
    """图神经网络"""
    
    def __init__(self, node_dim, edge_dim, hidden_dim, out_dim, n_layers=3):
        super().__init__()
        
        self.layers = nn.ModuleList([
            MessagePassingLayer(
                node_dim if i == 0 else hidden_dim,
                edge_dim,
                hidden_dim if i < n_layers - 1 else out_dim
            )
            for i in range(n_layers)
        ])
        
        # 边预测头
        self.edge_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x, edge_index, edge_attr):
        """前向传播"""
        for layer in self.layers[:-1]:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        
        x = self.layers[-1](x, edge_index, edge_attr)
        return x
    
    def link_prediction(self, x, edge_index):
        """链接预测"""
        src, dst = edge_index
        h_src = x[src]
        h_dst = x[dst]
        
        cat = torch.cat([h_src, h_dst], dim=-1)
        return self.edge_predictor(cat)

7. 总结

概率图模型是现代机器学习和人工智能的基础工具,它们:

  1. 统一表示:用图结构自然地表示变量之间的依赖关系
  2. 条件独立:利用条件独立性减少参数量,实现高效推理
  3. 模块化:因子图提供灵活的表示框架
  4. 与深度学习融合:变分推断、消息传递在神经网络中广泛应用

学习路径建议

入门:
├── 贝叶斯网络基础
├── 条件独立与D-分离
└── 朴素贝叶斯分类器

进阶:
├── 马尔可夫随机场
├── 因子图与消息传递
└── 精确推断算法

深入:
├── 变分推断
├── MCMC采样
└── 结构学习

参考

Footnotes

  1. Koller, D., & Friedman, N. (2009). Probabilistic Graphical Models: Principles and Techniques. MIT Press. 2 3 4

  2. Sutton, C., & McCallum, A. (2012). An Introduction to Conditional Random Fields. Foundations and Trends in Machine Learning, 4(4), 267-373.

  3. Kschischang, F. R., Frey, B. J., & Loeliger, H. A. (2001). Factor Graphs and the Sum-Product Algorithm. IEEE Transactions on Information Theory, 47(2), 498-519. 2

  4. Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859-877.

  5. Robert, C., & Casella, G. (2004). Monte Carlo Statistical Methods. Springer.