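Overview
Geometry-Aware Probabilistic Circuits are a framework that embeds local geometric structure directly into probabilistic circuits via Voronoi tessellations [1].
Traditional probabilistic circuits have three limitations:
- Data-independent mixture weights: the fixed weights of Gaussian/uniform mixtures cannot capture the local geometry of the data
- Ignored manifold structure: the data distribution is implicitly treated as globally homogeneous, so the model cannot adapt to curved data manifolds
- Limited expressiveness: performance degrades on high-dimensional curved manifolds
The core insight of geometry-aware probabilistic circuits:
The local geometry of the data manifold can be encoded naturally into a probabilistic circuit through a Voronoi tessellation, while preserving tractable inference.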
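1. Background
1.1 Data Manifolds and Geometric Structure
Real-world data often lies on a low-dimensional manifold embedded in a high-dimensional space: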
```
┌──────────────────────────────────────────────────┐
│ High-dimensional data space R^d                  │
│                                                  │
│          ╭────────────────────╮                  │
│         ╱    Data manifold     ╲                 │
│        ╱  (low-dim surface)     ╲                │
│        │   ●   ●   ●            │                │
│        │     ●   ●   ●   ●      │                │
│         ╲      ●    ●          ╱                 │
│          ╰─────────●●─────────╯                  │
│                                                  │
└──────────────────────────────────────────────────┘
```
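1.2 Problems with Traditional Probabilistic Circuits
Traditional probabilistic circuits, such as SPNs (sum-product networks), use fixed mixture weights: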
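```python
import torch
from torch.distributions import Normal


# Traditional SPN-style mixture: a sum node over Gaussian leaves
def traditional_mixture(x, mixture_weights, means, stds):
    """x: (batch, d); mixture_weights: (K,); means, stds: (K, d)"""
    total = torch.zeros(x.shape[0])
    for k in range(mixture_weights.shape[0]):
        # Fixed weight, independent of the data
        weight = mixture_weights[k]  # does not depend on x
        # Diagonal Gaussian density: product over dimensions
        prob = Normal(means[k], stds[k]).log_prob(x).sum(dim=-1).exp()
        total += weight * prob
    return total
```

Problems:
- The weights $w_k$ do not depend on the input $x$
- They cannot adapt to local changes in geometry
- They perform poorly when the data lies on a manifold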
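1.3 The Geometry-Aware Solution
The core idea of geometry-aware probabilistic circuits:
Partition the input space with a Voronoi tessellation and model each Voronoi region separately, so that local geometry is captured.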
```
┌────────────────────────────────────────────────────┐
│                Voronoi tessellation                │
│                                                    │
│    ╲     V1     │     V2     ╱     V3              │
│     ╲    ●c1    │    ●c2    ╱    ●c3               │
│      ╲          │          ╱                       │
│       ╲─────────┼─────────╱                        │
│           V4    │    V5                            │
│           ●c4   │    ●c5         ...               │
│                                                    │
│  Each region V_i has its own model P(x | x ∈ V_i)  │
└────────────────────────────────────────────────────┘
```
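The idea can be sketched in a few lines. The sketch assumes plain hard assignment to the nearest seed and one diagonal Gaussian fitted per cell. `fit_per_cell_gaussians` is a hypothetical helper, not code from the source.

```python
import torch


def fit_per_cell_gaussians(x, centers):
    """Hard-assign each point to its nearest center, then fit one diagonal Gaussian per cell.

    x: (N, d) data; centers: (K, d) Voronoi seeds (hypothetical helper, illustration only).
    Returns (assignment (N,), means (K, d), stds (K, d), counts (K,)).
    """
    assignment = torch.cdist(x, centers).argmin(dim=1)  # Voronoi cell of each point
    K, d = centers.shape
    means = centers.clone()   # fallback for cells with fewer than 2 points: the seed itself
    stds = torch.ones(K, d)   # fallback std (an arbitrary assumption)
    counts = torch.bincount(assignment, minlength=K)
    for k in range(K):
        pts = x[assignment == k]  # only this cell's points inform its model
        if pts.shape[0] >= 2:
            means[k] = pts.mean(dim=0)
            stds[k] = pts.std(dim=0) + 1e-3  # small floor keeps densities finite
    return assignment, means, stds, counts


# Usage: a noisy ring (a 1-D manifold in R^2), seeds taken from the data
torch.manual_seed(0)
theta = torch.rand(500) * 6.2832
x = torch.stack([theta.cos(), theta.sin()], dim=1) + 0.05 * torch.randn(500, 2)
assignment, means, stds, counts = fit_per_cell_gaussians(x, x[:8].clone())
```

Each cell's Gaussian sees only its own points, which is what allows the fitted model to follow the curve of the ring locally.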
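2. Voronoi Tessellation Basics
2.1 Definition
Definition (Voronoi tessellation): Given a set of seed points $\{c_1, \dots, c_K\} \subset \mathbb{R}^d$, the Voronoi tessellation partitions the space into $K$ regions:
$$V_k = \{\, x \in \mathbb{R}^d : \|x - c_k\| \le \|x - c_j\| \ \ \forall j \ne k \,\}, \quad k = 1, \dots, K$$
where $c_k$ is the seed point (center) of the $k$-th Voronoi cell.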
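The definition translates directly into a pairwise comparison of distances. The sketch below builds the membership mask from the inequality itself; `voronoi_membership` is a hypothetical helper name. Away from boundaries, the mask agrees with the usual `argmin` assignment. A point on a shared boundary belongs to more than one cell.

```python
import torch


def voronoi_membership(x, centers):
    """M[n, k] = True iff ||x_n - c_k|| <= ||x_n - c_j|| for every j (the set definition of V_k)."""
    dist = torch.cdist(x, centers)  # (N, K)
    # Compare each cell's distance with every other cell's: (N, K, 1) <= (N, 1, K)
    return (dist.unsqueeze(2) <= dist.unsqueeze(1)).all(dim=2)


centers = torch.tensor([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]])
torch.manual_seed(0)
x = torch.randn(200, 2)
M = voronoi_membership(x, centers)
# (1, 0) is equidistant from the first two seeds, so it lies on their shared boundary
M_boundary = voronoi_membership(torch.tensor([[1.0, 0.0]]), centers)
```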
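2.2 Geometric Properties
| Property | Description |
|---|---|
| Coverage | $\bigcup_{k=1}^{K} V_k = \mathbb{R}^d$ |
| Mutual exclusivity | $\operatorname{int}(V_i) \cap \operatorname{int}(V_j) = \emptyset$ for $i \ne j$ (shared boundaries have measure zero) |
| Locality | Each region can be modeled independently |
| Differentiability | Subgradients can be defined at the boundaries |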
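2.3 Voronoi Tessellations and Probability Distributions
Key insight: a Voronoi tessellation can be combined with a probability distribution:
$$p(x) = \sum_{k=1}^{K} \pi_k \, p(x \mid x \in V_k)$$
where $\pi_k = P(x \in V_k)$ is the prior probability that $x$ falls in the $k$-th Voronoi region.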
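The decomposition can be made concrete in one dimension. This is a minimal sketch, assuming each conditional $p(x \mid x \in V_k)$ is a Gaussian truncated to its cell; that choice and the helper names are illustrative assumptions. In 1-D, sorted seeds give interval cells bounded by midpoints. Each truncated Gaussian can then be normalized in closed form, so the mixture integrates to 1 and each cell carries mass exactly $\pi_k$.

```python
import math

import torch


def std_normal_cdf(z):
    """Standard normal CDF via the error function (handles +/- inf)."""
    return 0.5 * (1 + torch.erf(z / math.sqrt(2)))


def voronoi_mixture_pdf_1d(x, seeds, pis, mus, sigmas):
    """p(x) = sum_k pi_k * p(x | x in V_k) in 1-D; each conditional is a Gaussian truncated to V_k."""
    seeds, order = torch.sort(seeds)
    pis, mus, sigmas = pis[order], mus[order], sigmas[order]
    mids = (seeds[1:] + seeds[:-1]) / 2               # 1-D cell boundaries
    inf = torch.full((1,), math.inf, dtype=seeds.dtype)
    lo = torch.cat([-inf, mids])                      # left edge of each cell
    hi = torch.cat([mids, inf])                       # right edge of each cell
    # Normalizer of each truncated Gaussian: Phi(b) - Phi(a)
    Z = std_normal_cdf((hi - mus) / sigmas) - std_normal_cdf((lo - mus) / sigmas)
    k = torch.bucketize(x, mids)                      # index of the cell containing each x
    z = (x - mus[k]) / sigmas[k]
    gauss = torch.exp(-0.5 * z ** 2) / (sigmas[k] * math.sqrt(2 * math.pi))
    # Only the term of the cell that contains x is non-zero
    return pis[k] * gauss / Z[k]


seeds = torch.tensor([0.0, -2.0, 3.0], dtype=torch.float64)   # deliberately unsorted
pis = torch.tensor([0.5, 0.2, 0.3], dtype=torch.float64)
mus = torch.tensor([0.2, -2.5, 3.0], dtype=torch.float64)
sigmas = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)
xs = torch.linspace(-30, 30, 300001, dtype=torch.float64)
p = voronoi_mixture_pdf_1d(xs, seeds, pis, mus, sigmas)
```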
3. Architecture of Geometry-Aware Probabilistic Circuits
3.1 Core Architecture
```
┌────────────────────────────────────────────────────────────┐
│        Geometry-aware probabilistic circuit (VT-PC)        │
├────────────────────────────────────────────────────────────┤
│                                                            │
│  input x ──→ [distance layer] ──→ [Voronoi assignment]     │
│                     │                      │               │
│                     ▼                      ▼               │
│           distances to centers   assigned Voronoi cell     │
│                     │                      │               │
│                     └──────────┬───────────┘               │
│                                ▼                           │
│                    [geometry-aware mixture]                │
│                                │                           │
│          ┌─────────────────────┼─────────────────────┐     │
│          ▼                     ▼                     ▼     │
│  geometric weights       local density       soft boundary │
│   (Voronoi-based)        (within cell)     (smoothed edges)│
│          └─────────────────────┼─────────────────────┘     │
│                                ▼                           │
│                           output P(x)                      │
│                                                            │
└────────────────────────────────────────────────────────────┘
```
3.2 Geometric Weight Computation
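The class below turns distances into weights with a temperature-scaled softmax over negative distances. A small sketch shows what the temperature does, using only `torch.cdist` and `F.softmax`; `soft_voronoi_weights` is a hypothetical helper. As the temperature approaches 0, the weights concentrate on the nearest center, which is the hard Voronoi cell. As it grows, they approach uniform.

```python
import torch
import torch.nn.functional as F


def soft_voronoi_weights(x, centers, temperature):
    """softmax(-||x - c_k|| / T) over the K centers."""
    return F.softmax(-torch.cdist(x, centers) / temperature, dim=-1)


torch.manual_seed(0)
x = torch.randn(50, 2)
centers = torch.randn(4, 2)
hard = torch.cdist(x, centers).argmin(dim=1)        # hard Voronoi assignment
w_cold = soft_voronoi_weights(x, centers, 1e-3)     # nearly one-hot
w_hot = soft_voronoi_weights(x, centers, 1e3)       # nearly uniform (1/K each)
```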
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeometricWeights(nn.Module):
    """
    Voronoi-based geometric weights
    """
    def __init__(self, input_dim, num_centers, temperature=1.0):
        super().__init__()
        self.centers = nn.Parameter(
            torch.randn(num_centers, input_dim)
        )  # Voronoi seed points
        self.temperature = temperature

    def compute_weights(self, x):
        """
        Compute the geometric weight of each Voronoi region.
        Args:
            x: (batch, input_dim)
        Returns:
            weights: (batch, num_centers), soft assignment of each sample to each region
            distances: (batch, num_centers), Euclidean distance to each center
        """
        # Distances to every center
        distances = torch.cdist(x, self.centers)  # (batch, num_centers)
        # Soft assignment: softmax over negative distances
        logits = -distances / self.temperature
        weights = F.softmax(logits, dim=-1)  # (batch, num_centers)
        return weights, distances

    def voronoi_assignment(self, x):
        """
        Hard Voronoi assignment (each point belongs to its nearest center)
        """
        distances = torch.cdist(x, self.centers)
        assignment = distances.argmin(dim=-1)  # (batch,)
        return assignment, distances
```

3.3 Local Probabilistic Model
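Each local model in this section is a diagonal Gaussian, parameterized by a mean and a log standard deviation. As a quick sanity check, the closed-form log-density can be compared against `torch.distributions.Normal`. `diag_gaussian_log_prob` is a hypothetical helper.

```python
import math

import torch
from torch.distributions import Normal


def diag_gaussian_log_prob(x, mean, log_std):
    """Closed-form log N(x; mean, diag(exp(log_std)^2)), summed over dimensions."""
    var = torch.exp(2 * log_std)
    return (-0.5 * (x - mean) ** 2 / var - 0.5 * torch.log(2 * math.pi * var)).sum(dim=-1)


# Usage: compare against torch.distributions.Normal on random inputs
torch.manual_seed(0)
x = torch.randn(16, 3)
mean = torch.randn(1, 3)
log_std = 0.3 * torch.randn(1, 3)
reference = Normal(mean, torch.exp(log_std)).log_prob(x).sum(dim=-1)
max_err = (diag_gaussian_log_prob(x, mean, log_std) - reference).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```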
```python
import math

import torch
import torch.nn as nn


class LocalProbabilisticModel(nn.Module):
    """
    Local probabilistic model inside one Voronoi cell (a diagonal Gaussian)
    """
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # hidden_dim is unused by this diagonal-Gaussian version
        # Each Voronoi cell has its own parameters
        self.means = nn.Parameter(torch.randn(1, input_dim))
        self.log_stds = nn.Parameter(torch.zeros(1, input_dim))

    def forward(self, x):
        """
        Local log-density log p(x), shape (batch,)
        """
        var = torch.exp(2 * self.log_stds)
        diff = x - self.means
        log_prob = -0.5 * diff ** 2 / (var + 1e-8) - 0.5 * torch.log(2 * math.pi * var + 1e-8)
        return log_prob.sum(dim=-1)


class GeometryAwarePC(nn.Module):
    """
    Full geometry-aware probabilistic circuit
    """
    def __init__(self, input_dim, num_centers, hidden_dim):
        super().__init__()
        self.num_centers = num_centers
        # Voronoi geometry layer (GeometricWeights is defined in Section 3.2)
        self.geo_weights = GeometricWeights(input_dim, num_centers)
        # One local model per Voronoi cell
        self.local_models = nn.ModuleList([
            LocalProbabilisticModel(input_dim, hidden_dim)
            for _ in range(num_centers)
        ])
        # Boundary-softening parameter
        self.boundary_softness = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """
        Geometry-aware forward pass: log sum_k w_k(x) p_k(x)
        """
        # Geometric weights
        weights, distances = self.geo_weights.compute_weights(x)
        # Local log-densities of every Voronoi cell
        local_probs = torch.stack(
            [model(x) for model in self.local_models], dim=1
        )  # (batch, num_centers)
        # Weighted sum in log space
        weighted_probs = local_probs + torch.log(weights + 1e-8)
        # Mixture via log-sum-exp
        log_marginal = torch.logsumexp(weighted_probs, dim=1)
        return log_marginal

    def voronoi_density(self, x):
        """
        Voronoi-conditional log-densities log p(x | Voronoi region), shape (batch, num_centers);
        entries for cells that do not contain x are -inf
        """
        assignment, distances = self.geo_weights.voronoi_assignment(x)
        densities = []
        for i, model in enumerate(self.local_models):
            mask = assignment == i  # (batch,)
            log_p = model(x)        # (batch,)
            densities.append(torch.where(mask, log_p, torch.full_like(log_p, float("-inf"))))
        return torch.stack(densities, dim=1)
```

4. Tractability Analysis
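4.1 Core Theorem
Theorem (VT-PC tractability): Let a VT-PC be a geometry-aware probabilistic circuit with $K$ Voronoi cells. Then:
- Marginal inference: can be computed in time linear in $K$
- Conditional probabilities: can be computed exactly
- Expectations: can be computed exactly
Proof sketch:
- The mutual exclusivity of the Voronoi cells guarantees $\int_{\mathbb{R}^d} p(x)\,dx = \sum_{k=1}^{K} \int_{V_k} p(x)\,dx$
- The integral inside each $V_k$ is local to that cell
- The total cost therefore grows linearly with the number of Voronoi cells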
4.2 Gradient Computation
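Autograd needs a scalar, so input gradients of a batched log-likelihood are usually taken after summing over the batch. Each sample's log-likelihood depends only on its own row of `x`, so the sum still yields per-sample gradients. A minimal check uses a single diagonal Gaussian, whose score is known in closed form, $\nabla_x \log p(x) = -(x - \mu)/\sigma^2$. `score_wrt_input` is a hypothetical helper.

```python
import torch
from torch.distributions import Normal


def score_wrt_input(log_prob_fn, x):
    """Per-sample gradients d log p(x_b) / d x_b for a batch x of shape (B, d)."""
    x = x.detach().requires_grad_(True)
    # Summing is safe because row b only affects log p(x_b)
    (grad,) = torch.autograd.grad(log_prob_fn(x).sum(), x)
    return grad


# Usage: a diagonal Gaussian with a known closed-form score
mu = torch.tensor([1.0, -2.0])
sigma = torch.tensor([0.5, 2.0])
x = torch.randn(10, 2)
g = score_wrt_input(lambda v: Normal(mu, sigma).log_prob(v).sum(dim=-1), x)
expected = -(x - mu) / sigma ** 2
```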
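```python
import torch


def compute_gradient(x, model):
    """
    Gradients of the VT-PC log-likelihood w.r.t. the input and the parameters
    """
    x = x.detach().requires_grad_(True)
    model.zero_grad()
    log_prob = model(x)  # (batch,)
    # Backpropagate; backward() needs a scalar, so sum the per-sample log-likelihoods
    log_prob.sum().backward()
    # Extract gradients
    grad_x = x.grad
    grad_params = {name: p.grad for name, p in model.named_parameters()}
    return grad_x, grad_params


def voronoi_gradient(x, assignment, model):
    """
    Voronoi-aware gradient: input gradient plus a pull toward the nearest center.
    Pass assignment=None to compute the hard assignment here.
    """
    x = x.detach().requires_grad_(True)
    log_prob = model(x)
    (grad_x,) = torch.autograd.grad(log_prob.sum(), x)
    # Voronoi-aware correction: the gradient should point along the local geometry,
    # approximated here by the direction to the nearest Voronoi center
    centers = model.geo_weights.centers.detach()
    if assignment is None:
        assignment, _ = model.geo_weights.voronoi_assignment(x.detach())
    nearest_center = centers[assignment]  # (batch, input_dim)
    direction = nearest_center - x.detach()
    # Corrected gradient
    corrected_grad = grad_x + model.boundary_softness.detach() * direction
    return corrected_grad
```

5. Soft Boundaries and Continuity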
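5.1 The Problem with Hard Boundaries
Hard Voronoi boundaries introduce discontinuities: the weight of a cell jumps from 1 to 0 as $x$ crosses its boundary.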
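```
Hard boundary (weight of cell A)     Soft boundary (weight of cell A)

 1 ┤──────────┐                       1 ┤────────╮
   │          │                         │         ╲
   │          │                         │          ╲
 0 ┤          └──────────             0 ┤           ╰─────────
   └──────────┼────────── x             └──────────┼────────── x
            A │ B                                A │ B
   discontinuous at the boundary        continuous across it
```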
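The jump can be seen numerically. This 1-D sketch uses two centers and takes a sigmoid of the distance gap as the soft weight; both helper names are hypothetical. Just left and just right of the midpoint, the hard weight flips from 1 to 0. The soft weight barely changes there and equals exactly 0.5 on the boundary.

```python
import torch


def hard_weight_a(x, c_a, c_b):
    """1 where x is closer to c_a than to c_b, else 0 (hard Voronoi assignment)."""
    return ((x - c_a).abs() < (x - c_b).abs()).float()


def soft_weight_a(x, c_a, c_b, beta):
    """Sigmoid of the distance gap: 0.5 on the boundary, smooth everywhere."""
    return torch.sigmoid(beta * ((x - c_b).abs() - (x - c_a).abs()))


eps = 1e-4
x = torch.tensor([1 - eps, 1 + eps], dtype=torch.float64)  # boundary of c_a=0, c_b=2 is x=1
hard = hard_weight_a(x, 0.0, 2.0)
soft = soft_weight_a(x, 0.0, 2.0, beta=5.0)
```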
5.2 Soft Boundary Implementation
```python
import torch
import torch.nn as nn


class SoftVoronoiBoundary(nn.Module):
    """
    Soft Voronoi boundaries
    """
    def __init__(self, num_centers, input_dim, beta=1.0):
        super().__init__()
        assert num_centers >= 2, "a boundary needs at least two centers"
        self.centers = nn.Parameter(torch.randn(num_centers, input_dim))
        self.beta = beta  # softening parameter (larger = sharper boundary)

    def soft_boundary_weights(self, x):
        """
        Soft boundary weights based on a logistic sigmoid of the distance gap
        """
        distances = torch.cdist(x, self.centers)  # (batch, K)
        # Sort distances
        sorted_dist, indices = torch.sort(distances, dim=-1)
        # Nearest and second-nearest distances
        d1 = sorted_dist[:, 0]
        d2 = sorted_dist[:, 1]
        # Weight of the second-nearest cell: exactly 0.5 on the boundary (d1 == d2),
        # decaying toward 0 deep inside the nearest cell
        w_second = torch.sigmoid(-self.beta * (d2 - d1))
        # Nearest cell gets 1 - w_second, second-nearest gets w_second, all others 0
        # (vectorized; continuous across the nearest/second-nearest boundary)
        weights = torch.zeros_like(distances)
        weights = weights.scatter(1, indices[:, :1], (1 - w_second).unsqueeze(1))
        weights = weights.scatter(1, indices[:, 1:2], w_second.unsqueeze(1))
        return weights, indices


class ContinuousVTPC(nn.Module):
    """
    Continuous geometry-aware probabilistic circuit
    """
    def __init__(self, input_dim, num_centers, hidden_dim):
        super().__init__()
        self.soft_boundary = SoftVoronoiBoundary(num_centers, input_dim)
        # LocalProbabilisticModel is defined in Section 3.3
        self.local_models = nn.ModuleList([
            LocalProbabilisticModel(input_dim, hidden_dim)
            for _ in range(num_centers)
        ])

    def forward(self, x):
        """
        Forward pass with soft boundaries
        """
        # Soft boundary weights
        weights, indices = self.soft_boundary.soft_boundary_weights(x)
        # Local log-densities
        local_probs = torch.stack([
            model(x) for model in self.local_models
        ], dim=1)  # (batch, K)
        # Weighted mixture in log space (zero-weight cells contribute ~log(1e-8))
        weighted_probs = local_probs + torch.log(weights + 1e-8)
        # Log-sum-exp mixture
        log_marginal = torch.logsumexp(weighted_probs, dim=1)
        return log_marginal
```

6. Learning Algorithms
6.1 Learning the Voronoi Centers
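The k-means loss in this section is meant for initializing the centers. Below is a minimal sketch of Lloyd's algorithm for that initialization, assuming randomly chosen data points as starting seeds. `kmeans_init_centers` is a hypothetical helper. Each iteration alternates a Voronoi assignment step with a centroid update, which never increases the k-means loss.

```python
import torch


def kmeans_init_centers(x, num_centers, num_iters=20, seed=0):
    """Lloyd's k-means, used here to initialize Voronoi seeds before gradient training."""
    g = torch.Generator().manual_seed(seed)
    centers = x[torch.randperm(x.shape[0], generator=g)[:num_centers]].clone()
    for _ in range(num_iters):
        assignment = torch.cdist(x, centers).argmin(dim=1)  # Voronoi assignment step
        for k in range(num_centers):
            members = x[assignment == k]
            if members.shape[0] > 0:  # leave an empty cell's seed where it is
                centers[k] = members.mean(dim=0)
    return centers


# Usage: two well-separated clusters
torch.manual_seed(0)
x = torch.cat([
    torch.randn(100, 2) * 0.1 + torch.tensor([5.0, 5.0]),
    torch.randn(100, 2) * 0.1 - torch.tensor([5.0, 5.0]),
])
centers = kmeans_init_centers(x, 2)
```

The returned tensor could then be copied into `model.geo_weights.centers` before calling the training function below.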
```python
import torch


def train_voronoi_centers(model, data_loader, lr=1e-3, num_iterations=100):
    """
    Learn the Voronoi centers only (the optimizer updates no other parameter)
    """
    # Adam expects an iterable of parameters, so wrap the single tensor in a list
    optimizer = torch.optim.Adam([model.geo_weights.centers], lr=lr)
    for iteration in range(num_iterations):
        total_loss = 0.0
        for (x,) in data_loader:
            optimizer.zero_grad()
            # Negative log-likelihood
            log_prob = model(x)
            loss = -log_prob.mean()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if iteration % 10 == 0:
            print(f"Iteration {iteration}: Loss = {total_loss / len(data_loader):.4f}")
    return model


def voronoi_kmeans_loss(x, centers):
    """
    Voronoi k-means loss (sum of squared distances to the nearest center),
    used to initialize the centers
    """
    distances = torch.cdist(x, centers)
    assignment = distances.argmin(dim=1)
    # Squared distance of each point to its own (nearest) center; empty cells contribute 0
    return ((x - centers[assignment]) ** 2).sum()
```

6.2 End-to-End Training
```python
import torch


def train_geometry_aware_pc(model, train_loader, num_epochs=100, lr=1e-3):
    """
    End-to-end training of a geometry-aware probabilistic circuit
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for (x,) in train_loader:
            optimizer.zero_grad()
            # Negative log-likelihood
            log_prob = model(x)
            nll_loss = -log_prob.mean()
            # Regularizer: per-sample entropy of the Voronoi weights
            weights, _ = model.geo_weights.compute_weights(x)
            entropy = -(weights * torch.log(weights + 1e-8)).sum(dim=-1).mean()
            # Total loss: subtracting the entropy rewards more uniform assignments
            loss = nll_loss - 0.1 * entropy
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {epoch_loss / len(train_loader):.4f}")
    return model
```

7. Applications
7.1 Density Estimation
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class DensityEstimator(nn.Module):
    """
    Density estimation with a geometry-aware PC
    (GeometryAwarePC: Section 3.3; train_geometry_aware_pc: Section 6.2)
    """
    def __init__(self, input_dim, num_components):
        super().__init__()
        self.pc = GeometryAwarePC(
            input_dim=input_dim,
            num_centers=num_components,
            hidden_dim=128
        )

    def fit(self, data, num_epochs=100):
        """
        Fit the data distribution
        """
        loader = DataLoader(TensorDataset(data), batch_size=256, shuffle=True)
        return train_geometry_aware_pc(self.pc, loader, num_epochs)

    def log_prob(self, x):
        """Log-density log p(x)"""
        return self.pc(x)

    def sample(self, num_samples):
        """Draw samples from the learned local Gaussians"""
        with torch.no_grad():
            # Pick a Voronoi cell uniformly at random (an approximation:
            # the model's mixture weights depend on x, so there is no fixed prior)
            indices = torch.randint(0, self.pc.num_centers, (num_samples,))
            # Sample from each chosen cell's diagonal Gaussian
            samples = []
            for i in range(num_samples):
                model = self.pc.local_models[int(indices[i])]
                z = torch.randn(1, self.pc.geo_weights.centers.size(1))
                std = torch.exp(model.log_stds)
                samples.append(model.means + std * z)  # (1, input_dim)
            return torch.cat(samples, dim=0)
```

7.2 Anomaly Detection
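In high dimensions, log-densities can easily be hundreds below zero. Then `exp(log_prob)` underflows to exactly 0 in float32, and all such points get the same score. A sketch of a safer comparison checks $\log p(x)$ against $\log(\text{threshold})$, assuming the density threshold used below. `is_anomaly` is a hypothetical helper.

```python
import math

import torch


def is_anomaly(log_density, density_threshold):
    """Flag p(x) < threshold, evaluated as log p(x) < log(threshold) to avoid underflow."""
    return log_density < math.log(density_threshold)


log_p = torch.tensor([-200.0, -300.0, -1.0])
print(torch.exp(log_p))         # the first two underflow to 0 in float32
print(is_anomaly(log_p, 0.01))  # still well defined in log space
```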
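```python
import math

import torch.nn as nn


class AnomalyDetector(nn.Module):
    """
    Anomaly detection with a geometry-aware PC
    """
    def __init__(self, input_dim, num_components, threshold=0.01):
        super().__init__()
        self.pc = GeometryAwarePC(input_dim, num_components, 128)
        self.threshold = threshold  # density threshold

    def forward(self, x):
        """
        Anomaly score = log-density (lower = more anomalous)
        """
        # Stay in log space: exp(log_prob) underflows to 0 in high dimensions
        return self.pc(x)

    def predict(self, x):
        """
        Flag anomalies: p(x) < threshold, compared as log p(x) < log(threshold)
        """
        scores = self.forward(x)
        return scores < math.log(self.threshold)
```

7.3 Image Generation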
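```python
import torch.nn as nn


class ImageGenerator(nn.Module):
    """
    Image generation with a geometry-aware PC over a latent space
    """
    def __init__(self, latent_dim, num_components, image_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size
        # Geometry-aware PC over the latent space (not over pixels)
        self.pc = GeometryAwarePC(
            input_dim=latent_dim,
            num_centers=num_components,
            hidden_dim=256
        )
        # Hypothetical encoder (image -> latent), an assumed MLP mirroring the decoder
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(image_size * image_size * 3, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )
        # Mapping from latent space to image space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, image_size * image_size * 3),
            nn.Sigmoid()
        )

    def forward(self, z):
        """
        Decode latent codes into images
        """
        images = self.decoder(z)
        return images.view(-1, 3, self.image_size, self.image_size)

    def log_prob(self, x):
        """
        Log-probability of images
        """
        # Simplification: evaluated only in latent space, log p(encoder(x)),
        # not the exact pixel-space likelihood
        z = self.encoder(x)
        return self.pc(z)
```

8. Comparison with Other Methods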
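8.1 Method Comparison
| Method | Geometry-aware | Tractability | Expressiveness | Scalability |
|---|---|---|---|---|
| Geometry-aware PC | ✓ native | ✓ exact | Medium | Good |
| Standard GMM | ✗ | ✓ exact | Medium | Good |
| Standard SPN | ✗ | ✓ exact | Medium | Medium |
| VAE | ✗ | ✗ approximate | Strong | Good |
| NF (normalizing flow) | Partial | ✓ exact | Strong | Medium |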
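8.2 Advantages of Geometry-Aware PCs
- Local geometric modeling: the Voronoi tessellation naturally encodes local structure
- Exact inference: the tractability of probabilistic circuits is preserved
- Differentiable learning: the model is trainable end to end
- Flexible boundaries: soft boundaries preserve continuity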
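9. Theoretical Analysis
9.1 Expressiveness
Theorem (geometry-aware expressiveness): Let $p^*$ be the true distribution, defined on the data manifold $\mathcal{M}$. The theorem bounds the approximation error of a VT-PC in terms of the projection of $p^*$ onto the Voronoi tessellation.
9.2 Convergence
Theorem (convergence): Let the training set $\{x_i\}_{i=1}^{N}$ be sampled from $p^*$. The theorem then bounds the empirical risk of the learned VT-PC.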
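9.3 Computational Complexity
| Operation | Time complexity | Space complexity |
|---|---|---|
| Forward pass | $O(BKd)$ | $O(BKd)$ |
| Gradient computation | $O(BKd)$ | $O(BKd)$ |
| Voronoi update | $O(BKd)$ | $O(BK + Kd)$ |

where $B$ is the batch size, $K$ the number of Voronoi cells, and $d$ the input dimension.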
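The $O(BKd)$ memory of a broadcast implementation comes from materializing a $(B, K, d)$ difference tensor. For diagonal Gaussians, that tensor can be avoided by expanding $\sum_j (x_j - \mu_{kj})^2 / \sigma_{kj}^2$ into matrix products. The distance computation can be reduced the same way; `torch.cdist` has a matrix-multiplication compute mode built on this idea. The sketch below is an assumption-level optimization, not the source's implementation, and the helper name is hypothetical. It keeps extra memory at $O(BK + Kd)$.

```python
import math

import torch


def diag_gaussian_log_probs_mm(x, means, log_stds):
    """log N(x_b; mu_k, diag(sigma_k^2)) for all b, k as a (B, K) matrix, without a (B, K, d) tensor."""
    inv_var = torch.exp(-2 * log_stds)                   # (K, d)
    quad = (x ** 2) @ inv_var.T                          # sum_j x_bj^2 / s_kj^2
    quad = quad - 2 * x @ (means * inv_var).T            # - 2 sum_j x_bj mu_kj / s_kj^2
    quad = quad + (means ** 2 * inv_var).sum(dim=1)      # + sum_j mu_kj^2 / s_kj^2
    # Per-component normalizer: -sum_j log s_kj - (d/2) log(2 pi)
    log_norm = -(log_stds.sum(dim=1) + 0.5 * x.shape[1] * math.log(2 * math.pi))
    return -0.5 * quad + log_norm


torch.manual_seed(0)
x = torch.randn(32, 5, dtype=torch.float64)
means = torch.randn(7, 5, dtype=torch.float64)
log_stds = 0.3 * torch.randn(7, 5, dtype=torch.float64)
out = diag_gaussian_log_probs_mm(x, means, log_stds)  # (32, 7)
```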
10. Implementation Details
10.1 Full Implementation
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class VoronoiTessellationPC(nn.Module):
    """
    Full Voronoi-tessellation probabilistic circuit
    """
    def __init__(self, input_dim, num_centers, hidden_dim=128,
                 learnable_boundary=True, temperature=1.0):
        super().__init__()
        self.input_dim = input_dim
        self.num_centers = num_centers
        self.temperature = temperature
        # Voronoi centers (learnable)
        self.centers = nn.Parameter(torch.randn(num_centers, input_dim) * 0.1)
        # Parameters of the local Gaussian in each cell
        self.local_means = nn.Parameter(torch.randn(num_centers, input_dim) * 0.1)
        self.local_log_stds = nn.Parameter(torch.zeros(num_centers, input_dim))
        # Global mixture logits (optional)
        self.mixture_logits = nn.Parameter(torch.zeros(num_centers))
        # Boundary-softening parameter (not used by forward below)
        if learnable_boundary:
            self.boundary_softness = nn.Parameter(torch.tensor(1.0))
        else:
            self.boundary_softness = None

    def compute_voronoi_weights(self, x):
        """
        Voronoi weights: soft assignment, hard assignment, and distances
        """
        distances = torch.cdist(x, self.centers)  # (B, K)
        # Soft assignment
        weights = F.softmax(-distances / self.temperature, dim=-1)
        # Hard assignment
        assignment = distances.argmin(dim=-1)
        return weights, assignment, distances

    def local_log_prob(self, x):
        """
        Log-density of x under every cell's diagonal Gaussian, shape (B, K)
        """
        stds = torch.exp(self.local_log_stds)                   # (K, d)
        diff = x.unsqueeze(1) - self.local_means.unsqueeze(0)   # (B, K, d)
        log_prob = -0.5 * (
            diff ** 2 / (stds ** 2 + 1e-8)
            + torch.log(2 * math.pi * stds ** 2 + 1e-8)
        )
        return log_prob.sum(dim=-1)                             # (B, K)

    def forward(self, x):
        """
        Forward pass: log sum_k w_k(x) * pi_k * p_k(x)
        """
        # Voronoi weights
        weights, assignment, distances = self.compute_voronoi_weights(x)
        # Local log-densities under all K cells
        local_probs = self.local_log_prob(x)                    # (B, K)
        # Global mixture weights in log space
        log_mix = F.log_softmax(self.mixture_logits, dim=0)     # (K,)
        # Weighted mixture via log-sum-exp
        weighted_probs = local_probs + torch.log(weights + 1e-8) + log_mix
        return torch.logsumexp(weighted_probs, dim=-1)

    def sample(self, num_samples):
        """
        Sample from the model (cells drawn from the global mixture weights only)
        """
        with torch.no_grad():
            # Sample Voronoi cells
            indices = torch.multinomial(
                F.softmax(self.mixture_logits, dim=0),
                num_samples,
                replacement=True
            )
            # Sample from each cell's Gaussian
            means = self.local_means[indices]
            stds = torch.exp(self.local_log_stds[indices])
            z = torch.randn(num_samples, self.input_dim, device=means.device)
            return means + stds * z
```

10.2 Training Script
```python
import torch


def train_voronoi_pc(model, train_data, num_epochs=100, batch_size=256, lr=1e-3):
    """
    Train a Voronoi probabilistic circuit (VoronoiTessellationPC)
    """
    dataset = torch.utils.data.TensorDataset(train_data)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    for epoch in range(num_epochs):
        total_nll = 0.0
        for (x,) in loader:
            optimizer.zero_grad()
            # Negative log-likelihood
            nll = -model(x).mean()
            # Optional regularizers
            # 1. Entropy regularization (rewards more uniform assignments)
            weights, _, _ = model.compute_voronoi_weights(x)
            entropy = -(weights * torch.log(weights + 1e-8)).sum(dim=-1).mean()
            # 2. Center-dispersion regularization: an inverse-distance repulsion that
            #    penalizes centers sitting close together (requires num_centers >= 2)
            repulsion = (1.0 / (torch.pdist(model.centers) + 1e-3)).mean()
            loss = nll - 0.1 * entropy + 0.01 * repulsion
            loss.backward()
            optimizer.step()
            total_nll += nll.item()
        scheduler.step()
        if epoch % 10 == 0:
            avg_nll = total_nll / len(loader)
            print(f"Epoch {epoch}: NLL = {avg_nll:.4f}")
    return model
```

11. Limitations and Future Directions
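11.1 Current Limitations
| Issue | Description | Impact |
|---|---|---|
| Sensitivity to center initialization | The initial centers affect the final result | Requires good initialization |
| High dimensionality | Voronoi partitions may perform poorly in high dimensions | Curse of dimensionality |
| Dynamic geometry | A static tessellation cannot adapt to changing data | Limits the range of applications |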
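11.2 Future Directions
- Hierarchical Voronoi: multi-scale geometric modeling
- Dynamic Voronoi: geometric structure that evolves over time
- Deep Voronoi: Voronoi tessellations combined with deep representation learning
- Fusion with other models: Voronoi + NFs, Voronoi + GNNs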
12. References
[1] Sidheekh et al. (2026). Geometry-Aware Probabilistic Circuits via Voronoi Tessellations. arXiv:2603.11946.