深度概率推断实践
1 引言
深度概率推断(Deep Probabilistic Inference)将概率图模型的推断能力与深度学习的表示学习能力相结合。本章提供完整的PyTorch实现,包括:
- 可微分消息传递层
- 概率推断模块
- 端到端训练流程
- 实际应用案例
2 可微分消息传递层
2.1 基础消息传递
消息传递是概率图模型的核心操作。我们首先实现一个通用的可微分消息传递层:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
class DifferentiableMessagePassing(nn.Module):
    """Generic differentiable message-passing layer.

    One forward pass performs three steps: compute one message per edge from
    the source node (and optional edge features), mean-aggregate the messages
    arriving at each target node, and update every node state with a GRU cell.

    Args:
        node_dim: Width of the node feature vectors. The layer output has the
            same width because the GRU cell uses it as its hidden size.
        edge_dim: Width of the edge feature vectors (0 for graphs without
            edge features).
        msg_dim: Width of the per-edge messages.
    """

    def __init__(self, node_dim: int, edge_dim: int, msg_dim: int):
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.msg_dim = msg_dim
        # Message function: maps [source features, edge features] to a message.
        self.msg_fn = nn.Sequential(
            nn.Linear(node_dim + edge_dim, msg_dim),
            nn.ReLU(),
            nn.Linear(msg_dim, msg_dim),
        )
        # Linear transform applied to the mean-aggregated messages.
        self.aggr_fn = nn.Linear(msg_dim, msg_dim)
        # Node update: aggregated message is the input, node state the hidden.
        self.update_fn = nn.GRUCell(msg_dim, node_dim)

    def message(self, source: torch.Tensor, target: torch.Tensor,
                edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the message sent along every edge.

        Args:
            source: (num_edges, node_dim) features of the source nodes.
            target: (num_edges, node_dim) features of the target nodes.
                Accepted for interface symmetry; not used by this base layer.
            edge_attr: (num_edges, edge_dim) edge features, or None.

        Returns:
            (num_edges, msg_dim) messages.
        """
        if edge_attr is None and self.edge_dim > 0:
            # msg_fn expects node_dim + edge_dim inputs; substitute zeros for
            # the missing edge features instead of failing on a shape error.
            edge_attr = source.new_zeros((source.size(0), self.edge_dim))
        if edge_attr is None:
            combined = source
        else:
            combined = torch.cat([source, edge_attr], dim=-1)
        return self.msg_fn(combined)

    def aggregate(self, messages: torch.Tensor,
                  index: torch.Tensor,
                  num_nodes: int) -> torch.Tensor:
        """Mean-aggregate messages at their target nodes.

        Args:
            messages: (num_edges, msg_dim) messages.
            index: (num_edges,) target node index of every message.
            num_nodes: Total number of nodes.

        Returns:
            (num_nodes, msg_dim) aggregated messages; nodes without incoming
            edges receive the transform of an all-zero vector.
        """
        summed = messages.new_zeros((num_nodes, messages.size(-1)))
        summed.index_add_(0, index, messages)
        # Count neighbours in the message dtype so that the division below
        # does not silently change precision.
        counts = messages.new_zeros((num_nodes,))
        counts.index_add_(0, index, messages.new_ones((index.numel(),)))
        # clamp avoids division by zero for nodes with no incoming edge.
        counts = counts.clamp(min=1).unsqueeze(-1)
        return self.aggr_fn(summed / counts)

    def update(self, node_attr: torch.Tensor,
               messages: torch.Tensor) -> torch.Tensor:
        """Update node states from the aggregated messages.

        Args:
            node_attr: (num_nodes, node_dim) current node features.
            messages: (num_nodes, msg_dim) aggregated messages.

        Returns:
            (num_nodes, node_dim) updated node features.
        """
        return self.update_fn(messages, node_attr)

    def forward(self, node_attr: torch.Tensor,
                edge_index: torch.Tensor,
                edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run one round of message passing.

        Args:
            node_attr: (num_nodes, node_dim) node features.
            edge_index: (2, num_edges) edge list; row 0 holds source indices
                and row 1 holds target indices.
            edge_attr: (num_edges, edge_dim) edge features, or None.

        Returns:
            (num_nodes, node_dim) updated node features.
        """
        row, col = edge_index
        messages = self.message(node_attr[row], node_attr[col], edge_attr)
        aggregated = self.aggregate(messages, col, node_attr.size(0))
        return self.update(node_attr, aggregated)


# 2.2 信念传播层 (section heading: "Belief Propagation Layer")
实现一个基于和-积算法(Sum-Product Algorithm)思想的信念传播层,其中边上的势函数由神经网络参数化:
class BeliefPropagationLayer(nn.Module):
    """Belief-propagation layer with a learned pairwise potential.

    Every directed edge carries a message over ``num_states`` states. In each
    iteration the message on edge ``s -> t`` is recomputed from the sum of the
    messages arriving at ``s``, excluding those that came from ``t``. The
    unary beliefs are not part of the message update; they enter only in the
    final belief computation.

    Args:
        num_states: Number of discrete states per node.
        message_dim: Hidden width of the potential network.
    """

    def __init__(self, num_states: int, message_dim: int):
        super().__init__()
        self.num_states = num_states
        self.message_dim = message_dim
        # Scores one target state given [cavity message sum, one-hot state].
        self.potential_fn = nn.Sequential(
            nn.Linear(num_states * 2, message_dim),
            nn.ReLU(),
            nn.Linear(message_dim, 1),
        )

    def forward(self, beliefs: torch.Tensor,
                edge_index: torch.Tensor,
                num_iterations: int = 3) -> torch.Tensor:
        """Run belief-propagation iterations and return normalised beliefs.

        Args:
            beliefs: (num_nodes, num_states) initial beliefs.
            edge_index: (2, num_edges) directed edge list (source, target).
            num_iterations: Number of message-update sweeps.

        Returns:
            (num_nodes, num_states) final beliefs; every row is a softmax
            output and therefore sums to one.
        """
        num_nodes, num_states = beliefs.shape
        row, col = edge_index[0], edge_index[1]
        num_edges = row.numel()

        # Messages start out uniform.
        messages = beliefs.new_full((num_edges, num_states), 1.0 / num_states)

        if num_edges > 0 and num_iterations > 0:
            # Give every ordered node pair an id so that, for an edge s -> t,
            # the messages travelling t -> s can be looked up in one gather
            # (parallel edges between the same pair share an id).
            forward_key = row * num_nodes + col
            reverse_key = col * num_nodes + row
            unique_keys, pair_id = torch.unique(
                torch.cat([forward_key, reverse_key]), return_inverse=True
            )
            forward_id = pair_id[:num_edges]
            reverse_id = pair_id[num_edges:]
            num_pairs = unique_keys.numel()
            state_one_hot = torch.eye(
                num_states, device=beliefs.device, dtype=beliefs.dtype
            ).expand(num_edges, -1, -1)

            for _ in range(num_iterations):
                # Sum of all messages arriving at each node.
                node_in = beliefs.new_zeros((num_nodes, num_states)).index_add(
                    0, col, messages
                )
                # Sum of the messages on each ordered node pair.
                pair_sum = beliefs.new_zeros((num_pairs, num_states)).index_add(
                    0, forward_id, messages
                )
                # Incoming messages at the source minus what the target sent.
                cavity = node_in[row] - pair_sum[reverse_id]
                features = torch.cat(
                    [cavity.unsqueeze(1).expand(-1, num_states, -1),
                     state_one_hot],
                    dim=-1,
                )
                # Softmax over target states keeps each message normalised.
                messages = F.softmax(
                    self.potential_fn(features).squeeze(-1), dim=-1
                )

        # Final beliefs: the initial belief times the summed incoming
        # messages; nodes without incoming edges keep their initial belief.
        node_in = beliefs.new_zeros((num_nodes, num_states)).index_add(
            0, col, messages
        )
        has_incoming = torch.zeros(
            num_nodes, dtype=torch.bool, device=beliefs.device
        )
        has_incoming[col] = True
        combined = torch.where(
            has_incoming.unsqueeze(-1), beliefs * node_in, beliefs
        )
        return F.softmax(combined, dim=-1)


# 2.3 变分消息传递层 (section heading: "Variational Message Passing Layer")
将变分推断嵌入消息传递框架:
class VariationalMessagePassing(nn.Module):
    """Variational message-passing layer.

    Encodes node observations into Gaussian variational parameters, samples a
    latent state per node, refines the latents with GRU-based message passing
    over the graph, and decodes the latents back to the observation space.

    Args:
        node_dim: Width of the observed node features.
        latent_dim: Width of the latent variable of each node.
    """

    def __init__(self, node_dim: int, latent_dim: int):
        super().__init__()
        self.node_dim = node_dim
        self.latent_dim = latent_dim
        # Encoder: observation -> concatenated [mu, log_var].
        self.encoder = nn.Sequential(
            nn.Linear(node_dim, latent_dim * 2),
            nn.Tanh(),
            nn.Linear(latent_dim * 2, latent_dim * 2),
        )
        # Message network: neighbour summary is the input, own latent the
        # hidden state.
        self.message_net = nn.GRU(
            input_size=latent_dim,
            hidden_size=latent_dim,
            batch_first=True,
        )
        # Decoder: latent -> reconstruction in (0, 1) (sigmoid output).
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, node_dim),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor,
                edge_index: torch.Tensor,
                num_iterations: int = 3
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run variational message passing.

        Args:
            x: (num_nodes, node_dim) observed node features.
            edge_index: (2, num_edges) directed edge list (source, target).
            num_iterations: Number of message-passing refinement steps.

        Returns:
            A 4-tuple ``(z, mu, log_var, x_recon)``: latent sample, latent
            mean, latent log-variance (each (num_nodes, latent_dim)) and the
            reconstruction (num_nodes, node_dim).
        """
        encoded = self.encoder(x)  # (N, 2 * latent_dim)
        mu = encoded[:, :self.latent_dim]
        log_var = encoded[:, self.latent_dim:]
        z = self.reparameterize(mu, log_var)

        src, dst = edge_index[0], edge_index[1]
        # In-degree per node, clamped so isolated nodes divide by one.
        in_degree = z.new_zeros((z.size(0),)).index_add(
            0, dst, z.new_ones((dst.numel(),))
        ).clamp(min=1).unsqueeze(-1)

        for _ in range(num_iterations):
            # Mean of the neighbours' latents along the graph edges. Running
            # the GRU over the node axis as a sequence would make the result
            # depend on the arbitrary node numbering instead of the graph.
            neighbour_mean = z.new_zeros(z.shape).index_add(0, dst, z[src]) / in_degree
            # Each node is one batch element with a length-1 sequence; its
            # own latent is the initial hidden state.
            gru_out, _ = self.message_net(
                neighbour_mean.unsqueeze(1), z.unsqueeze(0).contiguous()
            )
            msg = gru_out.squeeze(1)  # (N, latent_dim)
            # Refine the sample and the variational parameters.
            z = torch.tanh(msg + mu)
            mu = mu + 0.1 * msg
            log_var = log_var - 0.05 * msg.pow(2)

        x_recon = self.decoder(z)
        return z, mu, log_var, x_recon

    def elbo(self, x: torch.Tensor,
             z: torch.Tensor, mu: torch.Tensor,
             log_var: torch.Tensor) -> torch.Tensor:
        """Compute the evidence lower bound.

        Args:
            x: Observations; binary cross-entropy requires values in [0, 1].
            z: Latent sample to decode.
            mu: Latent mean.
            log_var: Latent log-variance.

        Returns:
            Scalar ELBO: the negative of (summed reconstruction BCE + KL to a
            standard normal prior).
        """
        x_recon = self.decoder(z)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - log_var.exp()
        )
        return -(recon_loss + kl_loss)


# 3 高斯过程推断层 (section heading: "Gaussian Process Inference Layer")
3.1 变分高斯过程
class VariationalGPLayer(nn.Module):
    """Sparse variational Gaussian-process layer with an RBF kernel.

    The posterior over the inducing outputs is Gaussian with mean ``mean``
    and diagonal covariance ``exp(cov_log_diag)``.

    Args:
        input_dim: Width of the inputs.
        num_inducing: Number of inducing points.
        kernel_dim: Kept for backward compatibility and stored on the module;
            the RBF kernel itself does not use it.
    """

    def __init__(self, input_dim: int, num_inducing: int, kernel_dim: int = 128):
        super().__init__()
        self.input_dim = input_dim
        self.num_inducing = num_inducing
        self.kernel_dim = kernel_dim
        # Inducing point locations.
        self.inducing_points = nn.Parameter(
            torch.randn(num_inducing, input_dim) * 0.1
        )
        # Variational mean and log-diagonal covariance of the inducing outputs.
        self.mean = nn.Parameter(torch.zeros(num_inducing, 1))
        self.cov_log_diag = nn.Parameter(torch.zeros(num_inducing))
        # Learnable RBF lengthscale, initialised to 1 (log 1 = 0). torch.nn
        # provides no RBF kernel module, so the kernel is computed in
        # kernel_matrix().
        self.log_lengthscale = nn.Parameter(torch.zeros(()))
        # Linear mean function added to the GP prediction.
        self.mean_fn = nn.Linear(input_dim, 1)

    def kernel_matrix(self, X: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
        """Return the RBF kernel matrix between the rows of X and Z.

        Args:
            X: (n, input_dim) inputs.
            Z: (m, input_dim) inputs.

        Returns:
            (n, m) matrix with entries ``exp(-0.5 * ||x - z||^2 / l^2)``.
        """
        lengthscale = self.log_lengthscale.exp()
        # Squared distances are formed directly (no square root), which keeps
        # gradients finite where two points coincide, e.g. on the K_zz diagonal.
        diff = (X.unsqueeze(1) - Z.unsqueeze(0)) / lengthscale
        return torch.exp(-0.5 * diff.pow(2).sum(dim=-1))

    def forward(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict mean and variance at X.

        Args:
            X: (batch_size, input_dim) inputs.

        Returns:
            mean: (batch_size, 1) predictive mean (GP term plus mean function).
            var: (batch_size, 1) predictive variance, floored at 1e-6.
        """
        Z = self.inducing_points
        Kzz = self.kernel_matrix(Z, Z)
        Kxz = self.kernel_matrix(X, Z)
        # Jitter on the diagonal for numerical stability.
        jitter = 1e-6 * torch.eye(
            self.num_inducing, device=Kzz.device, dtype=Kzz.dtype
        )
        # A = K_xz K_zz^{-1} via a linear solve instead of an explicit
        # inverse (K_zz is symmetric, so solving against K_zx and
        # transposing is equivalent).
        A = torch.linalg.solve(Kzz + jitter, Kxz.T).T

        mean = A @ self.mean

        # k(x, x) = 1 for this RBF kernel. The variance is the prior variance
        # minus the part explained by the inducing points, plus the
        # uncertainty of the inducing outputs themselves (A S A^T, S diagonal).
        explained = (A * Kxz).sum(dim=-1, keepdim=True)
        inducing_uncertainty = (
            A.pow(2) * self.cov_log_diag.exp()
        ).sum(dim=-1, keepdim=True)
        var = 1.0 - explained + inducing_uncertainty
        # Floor rather than softplus: softplus would inflate small variances
        # (softplus(0) is about 0.69).
        var = var.clamp(min=1e-6)
        return mean + self.mean_fn(X), var


# 4 端到端训练框架 (section heading: "End-to-End Training Framework")
4.1 概率推断训练器
from torch.utils.data import DataLoader
from typing import Dict, Any
class ProbabilisticInferenceTrainer:
    """Training loop for probabilistic inference models.

    Batches are dictionaries. The entries whose keys match parameters of the
    model's ``forward`` are passed as keyword arguments; the remaining entries
    (for example ``'target'``) stay available to ``compute_loss``.

    Args:
        model: Model to train; its forward returns a dict of tensors.
        optimizer: Optimizer built over the model parameters.
        device: Target device. A CUDA device falls back to CPU, with a
            warning, when CUDA is not available.
    """

    def __init__(self, model: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 device: str = 'cuda'):
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            import warnings
            warnings.warn("CUDA is not available; falling back to CPU.")
            device = 'cpu'
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.model.to(device)
        # Per-epoch metric dictionaries appended by fit().
        self.train_history = []
        self.val_history = []

    def _to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move every value that has a ``.to`` method to the device; keep the rest."""
        return {
            key: value.to(self.device) if hasattr(value, 'to') else value
            for key, value in batch.items()
        }

    def _forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Call the model with the batch entries its forward accepts."""
        import inspect
        params = inspect.signature(self.model.forward).parameters
        accepts_any = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if accepts_any:
            kwargs = batch
        else:
            kwargs = {k: v for k, v in batch.items() if k in params}
        return self.model(**kwargs)

    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """Train for one epoch.

        Returns:
            ``{'loss': mean batch loss}``; the value is NaN when the loader
            yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            batch = self._to_device(batch)
            output = self._forward(batch)
            loss = self.compute_loss(output, batch)

            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping guards against unstable updates.
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm=1.0
            )
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1
        # Batches are counted in the loop, so loaders without __len__ work.
        mean_loss = total_loss / num_batches if num_batches else float('nan')
        return {'loss': mean_loss}

    def compute_loss(self, output: Dict[str, torch.Tensor],
                     batch: Dict[str, Any]) -> torch.Tensor:
        """Compute the training loss.

        Uses ``output['loss']`` when the model supplies it. Otherwise the
        loss is the MSE between ``output['recon']`` and ``batch['target']``
        plus 0.01 times the KL term when ``'mu'`` and ``'log_var'`` are in
        the output.

        Raises:
            ValueError: If neither form of loss can be derived.
        """
        if 'loss' in output:
            return output['loss']
        if 'recon' in output and 'target' in batch:
            recon_loss = F.mse_loss(output['recon'], batch['target'])
            kl_loss = 0.0
            if 'mu' in output and 'log_var' in output:
                kl_loss = -0.5 * torch.sum(
                    1 + output['log_var'] - output['mu'].pow(2)
                    - output['log_var'].exp()
                )
            return recon_loss + 0.01 * kl_loss
        raise ValueError("Cannot compute loss from output")

    def validate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate without gradient tracking.

        Returns:
            ``{'val_loss': mean batch loss}``; NaN for an empty loader.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in dataloader:
                batch = self._to_device(batch)
                output = self._forward(batch)
                total_loss += self.compute_loss(output, batch).item()
                num_batches += 1
        mean_loss = total_loss / num_batches if num_batches else float('nan')
        return {'val_loss': mean_loss}

    def fit(self, train_loader: DataLoader,
            val_loader: DataLoader,
            num_epochs: int,
            checkpoint_path: str = 'best_model.pt') -> None:
        """Train for ``num_epochs`` epochs, checkpointing the best model.

        Args:
            train_loader: Training batches.
            val_loader: Validation batches.
            num_epochs: Number of epochs.
            checkpoint_path: File written whenever the validation loss
                improves. A NaN validation loss never counts as an
                improvement.
        """
        best_val_loss = float('inf')
        for epoch in range(num_epochs):
            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate(val_loader)

            self.train_history.append(train_metrics)
            self.val_history.append(val_metrics)

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f" Train Loss: {train_metrics['loss']:.4f}")
            print(f" Val Loss: {val_metrics['val_loss']:.4f}")

            if val_metrics['val_loss'] < best_val_loss:
                best_val_loss = val_metrics['val_loss']
                self.save_checkpoint(checkpoint_path)

    def save_checkpoint(self, path: str) -> None:
        """Save the model and optimizer state dicts to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)


# 5 实际应用案例 (section heading: "Practical Application Cases")
5.1 贝叶斯图神经网络
class BayesianGNN(nn.Module):
    """Bayesian graph neural network for node classification.

    Node features are projected to ``hidden_dim``, refined by two
    message-passing layers, turned into a stochastic latent by the variational
    layer, and classified from that latent.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Hidden and latent width.
        out_dim: Number of classes.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        # A message-passing layer returns features of its own input width,
        # so the inputs are projected to hidden_dim before the first layer.
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        self.mpnn1 = DifferentiableMessagePassing(hidden_dim, 0, hidden_dim)
        self.mpnn2 = DifferentiableMessagePassing(hidden_dim, 0, hidden_dim)
        # Variational layer over the hidden node states.
        self.vmp = VariationalMessagePassing(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor,
                edge_index: torch.Tensor,
                training: bool = True) -> Dict[str, torch.Tensor]:
        """Compute class logits and the variational quantities.

        Args:
            x: (num_nodes, in_dim) node features.
            edge_index: (2, num_edges) edge list.
            training: Accepted for interface compatibility; not used.

        Returns:
            Dict with ``'logits'`` (class logits), ``'mu'`` and ``'log_var'``
            (variational mean and log-variance) and ``'recon'``
            (reconstruction of the hidden node states).
        """
        h = F.relu(self.mpnn1(self.input_proj(x), edge_index))
        h = F.relu(self.mpnn2(h, edge_index))

        z, mu, log_var, h_recon = self.vmp(h, edge_index)

        # Classify from the sampled latent so that the variational layer
        # actually influences, and adds uncertainty to, the predictions.
        logits = self.classifier(z)
        return {
            'logits': logits,
            'mu': mu,
            'log_var': log_var,
            'recon': h_recon,
        }

    def predict_with_uncertainty(self, x: torch.Tensor,
                                 edge_index: torch.Tensor,
                                 num_samples: int = 10) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate predictive uncertainty by Monte Carlo sampling.

        Runs ``num_samples`` stochastic forward passes and summarises the
        class probabilities.

        Returns:
            mean: Average class probabilities.
            std: Standard deviation of the class probabilities across the
                samples (the uncertainty estimate).
        """
        was_training = self.training
        self.train()
        try:
            # Prediction only: no autograd graph is needed.
            with torch.no_grad():
                samples = torch.stack([
                    F.softmax(self.forward(x, edge_index)['logits'], dim=-1)
                    for _ in range(num_samples)
                ])
        finally:
            # Restore whichever mode the caller had set.
            self.train(was_training)
        return samples.mean(dim=0), samples.std(dim=0)


# 5.2 概率链接预测 (section heading: "Probabilistic Link Prediction")
class ProbabilisticLinkPrediction(nn.Module):
    """Link prediction model built on stacked message-passing layers.

    Args:
        node_dim: Width of the node features and embeddings.
        num_layers: Number of residual message-passing layers.
    """

    def __init__(self, node_dim: int, num_layers: int = 3):
        super().__init__()
        # One shared embedding row, expanded over the nodes when used.
        self.embedding = nn.Parameter(
            torch.randn(1, node_dim) * 0.1
        )
        self.message_layers = nn.ModuleList([
            DifferentiableMessagePassing(node_dim, 0, node_dim)
            for _ in range(num_layers)
        ])
        # Scores a node pair from the concatenated endpoint embeddings.
        self.link_predictor = nn.Sequential(
            nn.Linear(node_dim * 2, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, 1),
        )

    def forward(self, x: torch.Tensor,
                edge_index: torch.Tensor,
                edge_index_pos: torch.Tensor,
                edge_index_neg: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Score positive and negative candidate links.

        Args:
            x: (num_nodes, node_dim) node features.
            edge_index: (2, num_edges) edges used for message passing. To
                avoid label leakage this should contain only edges that are
                observed at prediction time, not the held-out edges being
                scored.
            edge_index_pos: (2, num_pos) positive candidate edges.
            edge_index_neg: (2, num_neg) negative candidate edges.

        Returns:
            Dict with ``'pos_score'`` (num_pos, 1), ``'neg_score'``
            (num_neg, 1) and ``'embeddings'`` (final node embeddings).
        """
        num_nodes = x.size(0)
        # NOTE(review): the shared embedding replaces x only when x has a
        # single row; confirm whether it was meant as a fallback for graphs
        # without node features.
        if num_nodes == 1:
            h = self.embedding.expand(num_nodes, -1)
        else:
            h = x

        for layer in self.message_layers:
            h = h + layer(h, edge_index)  # residual connection

        pos_src, pos_dst = edge_index_pos
        pos_score = self.link_predictor(
            torch.cat([h[pos_src], h[pos_dst]], dim=-1)
        )

        neg_src, neg_dst = edge_index_neg
        neg_score = self.link_predictor(
            torch.cat([h[neg_src], h[neg_dst]], dim=-1)
        )
        return {
            'pos_score': pos_score,
            'neg_score': neg_score,
            'embeddings': h,
        }

    def loss(self, output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Margin ranking (hinge) loss with margin 1.0.

        With equally many positive and negative scores they are compared
        pairwise in order. With differing counts every positive is compared
        against every negative. If either set is empty the loss is zero.

        Args:
            output: Dict holding ``'pos_score'`` and ``'neg_score'``.

        Returns:
            Scalar loss.
        """
        pos = output['pos_score'].reshape(-1)
        neg = output['neg_score'].reshape(-1)
        margin = 1.0
        if pos.numel() == 0 or neg.numel() == 0:
            # Nothing to rank; keep the result attached to the graph.
            return (pos.sum() + neg.sum()) * 0.0
        if pos.numel() == neg.numel():
            return F.margin_ranking_loss(
                pos, neg, torch.ones_like(pos), margin=margin
            )
        # Unequal counts: hinge over all (positive, negative) pairs.
        return F.relu(margin - pos.unsqueeze(1) + neg.unsqueeze(0)).mean()


# 6 总结与展望 (section heading: "Summary and Outlook")
6.1 核心要点
- 可微分消息传递是连接概率推断与深度学习的桥梁
- 变分推断提供了近似推断的可扩展框架
- 端到端训练允许联合学习推断网络和下游任务
6.2 进阶方向
- 结构化变分推断:利用图的稀疏性
- 层次化消息传递:多尺度特征聚合
- 可解释性:将不确定性量化融入模型解释
- 组合优化:将推理问题嵌入神经网络
6.3 实际建议
- 从简单的消息传递层开始,逐步增加复杂度
- 使用梯度裁剪防止训练不稳定
- 监控KL散度与重构损失的比例
- 对于大规模图,考虑稀疏矩阵操作和采样技术