Overview

Geometry-Aware Probabilistic Circuits are a framework that embeds the local geometric structure of the data directly into a probabilistic circuit by means of a Voronoi tessellation.¹

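The limitations of conventional probabilistic circuits are:

  • Input-independent mixture weights: fixed mixture weights over Gaussian or uniform components cannot capture the local geometry of the data
  • Ignored manifold structure: the data distribution is implicitly treated as globally homogeneous, so the model cannot adapt to curved data manifolds
  • Limited expressiveness: performance degrades on curved manifolds in high dimensions

The core insight of geometry-aware probabilistic circuits:

The local geometric structure of a data manifold can be encoded naturally into a probabilistic circuit through a Voronoi tessellation, while tractable inference is preserved.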


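1. Background

1.1 Data Manifolds and Geometric Structure

Real-world data typically lies on a low-dimensional manifold embedded in a high-dimensional space: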

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│        High-dimensional data space R^d                      │
│                                                             │
│                    ╭───────────╮                           │
│                   ╱             ╲                          │
│                  ╱  data manifold  ╲                        │
│                 ╱ (low-dim surface) ╲                       │
│                ╱                   ╲                      │
│               │    ● ●    ●         │                     │
│               │  ●    ●  ●  ●       │                     │
│               │    ●  ●   ●         │                     │
│               ╲      ●        ●    ╱                      │
│                ╲       ●   ●      ╱                       │
│                 ╲      ●●     ╱                          │
│                  ╰─────────────╯                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

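1.2 Problems with Conventional Probabilistic Circuits

Conventional probabilistic circuits (e.g., SPNs) use fixed mixture weights:

import math

def gaussian(x, mean, std):
    """Univariate Gaussian density N(x; mean, std^2)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))

# Mixture at a sum node of a conventional SPN
def traditional_mixture(x, mixture_weights, means, stds):
    total = 0.0
    for weight, mean, std in zip(mixture_weights, means, stds):
        # Fixed weight: learned once, does not depend on x
        total += weight * gaussian(x, mean, std)
    return total

Problems:

  1. The weights $w_k$ do not depend on the input $x$
  2. They cannot adapt to local changes in geometry
  3. Performance on manifolds is poor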

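1.3 The Geometry-Aware Solution

The core idea of geometry-aware probabilistic circuits:

Partition the input space with a Voronoi tessellation and model each Voronoi region separately, so that local geometry is captured.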

┌─────────────────────────────────────────────────────────────┐
│                   Voronoi tessellation                      │
│                                                             │
│         ╱╲     ╱╲     ╱╲                                  │
│        ╱  ╲   ╱  ╲   ╱  ╲                                 │
│       ╱    ╲ ╱    ╲ ╱    ╲                                │
│      ╱   ╱╲ ╲   ╱╲ ╲   ╱╲ ╲                              │
│     ╱   ╱  ╲ ╲ ╱  ╲ ╲ ╱  ╲ ╲                             │
│    │   │ V1 │ │ V2 │ │ V3 │  ...                          │
│    │   │ ●  │ │ ●● │ │  ● │                              │
│    ╲   ╲  ●╱ ╲   ●╱ ╲  ●╱   ╲                            │
│     ╲   ╲╱    ╲╱    ╲╱    ╱   ╲                           │
│      ╲       ╲      ╱      ╱                              │
│       ╲       ╲    ╱      ╱                               │
│                                                             │
│  Each Voronoi region V_i has its own model P(x | x ∈ V_i)   │
└─────────────────────────────────────────────────────────────┘

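2. Voronoi Tessellation Basics

2.1 Definition

Definition (Voronoi tessellation): Given a set of seed points $\{c_1, \dots, c_K\} \subset \mathbb{R}^d$, the Voronoi tessellation partitions the space into $K$ regions:

$$V_i = \{\, x \in \mathbb{R}^d : \|x - c_i\| \le \|x - c_j\| \ \ \forall j \neq i \,\}, \qquad i = 1, \dots, K$$

where $c_i$ is the seed point (center) of the $i$-th Voronoi cell.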
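The definition can be checked directly: a point lies in cell $V_i$ exactly when no other center is closer. A minimal pure-Python sketch; the helpers `nearest_center` and `in_cell` are illustrative names, not part of the source:

```python
import math

def nearest_center(x, centers):
    """Index i of the Voronoi cell V_i containing x (ties go to the lowest index)."""
    dists = [math.dist(x, c) for c in centers]
    return min(range(len(centers)), key=dists.__getitem__)

def in_cell(x, centers, i):
    """Direct check of the definition: ||x - c_i|| <= ||x - c_j|| for all j != i."""
    d_i = math.dist(x, centers[i])
    return all(d_i <= math.dist(x, c) for j, c in enumerate(centers) if j != i)

centers = [(0.0, 0.0), (4.0, 0.0), (0.0, 4.0)]
print(nearest_center((1.0, 0.5), centers))   # 0
print(nearest_center((3.0, 1.0), centers))   # 1
# (2, 0) is equidistant from centers 0 and 1, so it lies on their shared boundary
print(in_cell((2.0, 0.0), centers, 0), in_cell((2.0, 0.0), centers, 1))
```

Points on a shared boundary satisfy the inequality for both cells; a hard assignment (such as an `argmin` over distances) breaks such ties arbitrarily, which is harmless because boundaries have measure zero.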

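2.2 Geometric Properties

| Property | Description |
|---|---|
| Coverage | $\bigcup_{i=1}^{K} V_i = \mathbb{R}^d$ |
| Mutual exclusivity | $\operatorname{int}(V_i) \cap \operatorname{int}(V_j) = \emptyset$ for $i \neq j$; cells share only boundaries, which have measure zero |
| Locality | Each region can be modeled independently |
| Differentiability | Subgradients can be defined at cell boundaries |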

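2.3 Voronoi Tessellations and Probability Distributions

Key insight: a Voronoi tessellation can be combined with a probability distribution:

$$p(x) = \sum_{i=1}^{K} \pi_i \, p(x \mid x \in V_i)$$

where $\pi_i = P(x \in V_i)$ is the prior probability that $x$ falls in the $i$-th Voronoi region.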
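Because the cells are disjoint, only the cell that owns $x$ contributes to the sum, so $p(x) = \pi_{i(x)}\, p(x \mid x \in V_{i(x)})$ with $i(x)$ the index of the nearest center. The 1-D sketch below makes this concrete under assumptions of our own (a bounded domain `[lo, hi]` and a uniform local density on each cell, neither taken from the source) and checks numerically that the resulting density integrates to 1:

```python
def cell_intervals(centers, lo, hi):
    """1-D Voronoi cells of sorted centers on [lo, hi]: boundaries are midpoints."""
    mids = [(a + b) / 2 for a, b in zip(centers, centers[1:])]
    edges = [lo] + mids + [hi]
    return list(zip(edges, edges[1:]))

def voronoi_mixture_density(x, centers, priors, lo, hi):
    """p(x) = sum_i pi_i * p(x | x in V_i), with p(. | V_i) uniform on V_i."""
    total = 0.0
    for (a, b), pi in zip(cell_intervals(centers, lo, hi), priors):
        if a <= x < b:                 # only the owning cell contributes
            total += pi / (b - a)      # uniform density on V_i is 1 / |V_i|
    return total

centers, priors = [0.0, 2.0, 5.0], [0.5, 0.3, 0.2]
lo, hi, n = -1.0, 7.0, 8_000
h = (hi - lo) / n
# Midpoint-rule integral of p over [lo, hi]
integral = h * sum(
    voronoi_mixture_density(lo + (k + 0.5) * h, centers, priors, lo, hi)
    for k in range(n)
)
print(round(integral, 6))  # 1.0
```

Each cell contributes exactly its prior $\pi_i$ to the integral, which is why the priors alone determine how probability mass is shared between regions.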


3. Geometry-Aware Probabilistic Circuit Architecture

3.1 Core Architecture

┌─────────────────────────────────────────────────────────────────────┐
│            Geometry-Aware Probabilistic Circuit (VT-PC)             │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│      input x ──→ [distance layer] ──→ [Voronoi assignment]          │
│                         │                      │                    │
│                         ▼                      ▼                    │
│               distances to centers    owning Voronoi cell           │
│                         │                      │                    │
│                         └──────────┬───────────┘                    │
│                                    ▼                                │
│                        [geometry-aware mixture]                     │
│                                    │                                │
│              ┌─────────────────────┼─────────────────────┐          │
│              ▼                     ▼                     ▼          │
│      geometric weights      local densities      boundary handling  │
│       (Voronoi-based)       (within cells)       (soft boundaries)  │
│                                    │                                │
│                                    ▼                                │
│                           output density P(x)                       │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

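3.2 Geometric Weight Computation

import torch
import torch.nn as nn
import torch.nn.functional as F


class GeometricWeights(nn.Module):
    """
    Voronoi-based geometric weights.
    """
    def __init__(self, input_dim, num_centers, temperature=1.0):
        super().__init__()
        self.centers = nn.Parameter(
            torch.randn(num_centers, input_dim)
        )  # Voronoi seed points c_1, ..., c_K
        self.temperature = temperature

    def compute_weights(self, x):
        """
        Soft assignment of each sample to the Voronoi regions.

        Args:
            x: (batch, input_dim)

        Returns:
            weights: (batch, num_centers), softmax over negative distances;
                each row sums to 1
            distances: (batch, num_centers), Euclidean distances to the centers
        """
        # Euclidean distance to every center
        distances = torch.cdist(x, self.centers)  # (batch, num_centers)

        # Soft assignment: softmax over -distance / temperature.
        # As temperature -> 0 this approaches the hard Voronoi assignment.
        logits = -distances / self.temperature
        weights = F.softmax(logits, dim=-1)  # (batch, num_centers)

        return weights, distances

    def voronoi_assignment(self, x):
        """
        Hard Voronoi assignment: each point belongs to its nearest center.
        """
        distances = torch.cdist(x, self.centers)
        assignment = distances.argmin(dim=-1)  # (batch,)
        return assignment, distances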
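The effect of `temperature` in `compute_weights` can be seen on a toy example: the soft weights are a softmax over negative distances, and as the temperature shrinks they approach the hard assignment returned by `voronoi_assignment`. A minimal pure-Python sketch of the same rule; `soft_assignment` is a hypothetical helper, not part of the code above:

```python
import math

def soft_assignment(x, centers, temperature=1.0):
    """softmax(-||x - c_k|| / T) over the centers, as in compute_weights."""
    logits = [-math.dist(x, c) / temperature for c in centers]
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

centers = [(0.0, 0.0), (4.0, 0.0)]
x = (1.5, 0.0)                               # distances 1.5 and 2.5: closer to center 0
for t in (10.0, 1.0, 0.1):
    print(t, [round(w, 4) for w in soft_assignment(x, centers, t)])
```

The weight on center 0 rises from about 0.52 at `T = 10` to about 0.73 at `T = 1` and is essentially 1 at `T = 0.1`, so `temperature` interpolates between a nearly uniform mixture and the hard Voronoi partition.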

3.3 Local Probabilistic Models

class LocalProbabilisticModel(nn.Module):
    """
    Local probabilistic model for a single Voronoi cell: a diagonal Gaussian.
    """
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # hidden_dim is unused by this diagonal-Gaussian version; kept for API compatibility.
        # Each instance holds the parameters of one Voronoi cell.
        self.means = nn.Parameter(torch.randn(1, input_dim))
        self.log_stds = nn.Parameter(torch.zeros(1, input_dim))

    def forward(self, x):
        """
        Local log-density log N(x; mean, diag(std^2)), shape (batch,).
        """
        diff = x - self.means
        var = torch.exp(2 * self.log_stds)
        log_prob = -0.5 * diff ** 2 / (var + 1e-8) - 0.5 * torch.log(2 * torch.pi * var + 1e-8)
        return log_prob.sum(dim=-1)
 
 
class GeometryAwarePC(nn.Module):
    """
    Complete geometry-aware probabilistic circuit.
    """
    def __init__(self, input_dim, num_centers, hidden_dim):
        super().__init__()
        self.num_centers = num_centers

        # Voronoi geometry layer
        self.geo_weights = GeometricWeights(input_dim, num_centers)

        # One local probabilistic model per Voronoi cell
        self.local_models = nn.ModuleList([
            LocalProbabilisticModel(input_dim, hidden_dim)
            for _ in range(num_centers)
        ])

        # Boundary-softening parameter (used by voronoi_gradient in Section 4.2)
        self.boundary_softness = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """
        Geometry-aware forward pass: log sum_i w_i(x) p_i(x), shape (batch,).
        Note: because w_i(x) depends on x, the result is not guaranteed to integrate to 1.
        """
        # Geometric (soft Voronoi) weights
        weights, distances = self.geo_weights.compute_weights(x)

        # Local log-density of every Voronoi cell
        local_log_probs = torch.stack(
            [model(x) for model in self.local_models], dim=1
        )  # (batch, num_centers)

        # Weighted mixture in log space: log sum_i exp(log w_i + log p_i)
        log_weights = torch.log(weights + 1e-8)
        log_marginal = torch.logsumexp(local_log_probs + log_weights, dim=1)

        return log_marginal

    def voronoi_density(self, x):
        """
        Voronoi-conditional log-density log p(x | Voronoi region), shape (batch, num_centers):
        the owning cell's log-density in its column, -inf (zero density) elsewhere.
        """
        assignment, distances = self.geo_weights.voronoi_assignment(x)

        densities = []
        for i, model in enumerate(self.local_models):
            log_prob = model(x).masked_fill(assignment != i, float('-inf'))  # (batch,)
            densities.append(log_prob)

        return torch.stack(densities, dim=1)

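4. Tractability Analysis

4.1 Main Theorem

Theorem (VT-PC tractability): Let a VT-PC be a geometry-aware probabilistic circuit with $K$ Voronoi cells. Then:

  1. Marginal inference: computable in time linear in $K$
  2. Conditional probabilities: computable exactly
  3. Expectations: computable exactly

Proof sketch:

  • Mutual exclusivity of the Voronoi cells gives $\int_{\mathbb{R}^d} p(x)\,dx = \sum_{i=1}^{K} \int_{V_i} p(x)\,dx$
  • The integral within each $V_i$ is local to that cell
  • The total cost is therefore linear in the number of Voronoi cells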
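For the part of the model whose weights do not depend on $x$ (for example the `mixture_logits`-weighted mixture of diagonal Gaussians in Section 10), marginalizing a variable is exact and cheap: each component is fully factorized, so integrating out dimension $j$ simply drops that factor, at $O(Kd)$ cost. A minimal sketch under those assumptions (fixed weights, diagonal Gaussian components; the input-dependent Voronoi weights are left out of it), with a numerical check against integrating the joint density:

```python
import math

def log_normal(v, mu, sigma):
    """log N(v; mu, sigma^2) for a scalar v."""
    return -0.5 * ((v - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def mixture_density(x, weights, means, stds, keep=None):
    """Mixture of diagonal Gaussians with fixed weights.

    `keep` lists the dimensions to retain; every other dimension is marginalized
    out, which for a factorized component simply drops that dimension's factor.
    """
    dims = range(len(x)) if keep is None else keep
    total = 0.0
    for w, mu, sd in zip(weights, means, stds):
        total += w * math.exp(sum(log_normal(x[j], mu[j], sd[j]) for j in dims))
    return total

weights = [0.6, 0.4]
means = [(0.0, 1.0), (2.0, -1.0)]
stds = [(1.0, 0.5), (0.7, 2.0)]
x1 = 0.8

# Exact marginal p(x1): drop dimension 1 (its value is ignored)
exact = mixture_density((x1, 0.0), weights, means, stds, keep=[0])

# Numerical check: integrate the joint over x2 with the midpoint rule on [-20, 20]
h = 0.001
numeric = h * sum(
    mixture_density((x1, -20.0 + (k + 0.5) * h), weights, means, stds)
    for k in range(40_000)
)
print(exact, numeric)
```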

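4.2 Gradient Computation

def compute_gradient(x, model):
    """
    Gradients of the VT-PC log-likelihood w.r.t. the input and the parameters.
    """
    x = x.detach().requires_grad_(True)
    model.zero_grad()
    log_prob = model(x)  # (batch,)

    # Backpropagate; backward() needs a scalar, so sum over the batch
    log_prob.sum().backward()

    # Extract gradients
    grad_x = x.grad
    grad_params = {name: p.grad for name, p in model.named_parameters()}

    return grad_x, grad_params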
 
 
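def voronoi_gradient(x, assignment, model):
    """
    Voronoi-aware input gradient. If assignment is None, the hard Voronoi
    assignment is computed from the model's centers.
    """
    x = x.detach().requires_grad_(True)
    log_prob = model(x)
    (grad_x,) = torch.autograd.grad(log_prob.sum(), x)

    # Voronoi-aware correction: the gradient should follow the local geometry
    # of the manifold; the direction to the nearest center is used as a proxy
    centers = model.geo_weights.centers
    with torch.no_grad():
        if assignment is None:
            assignment, _ = model.geo_weights.voronoi_assignment(x)
        nearest_center = centers[assignment]
        direction = nearest_center - x  # (batch, input_dim)

        # Corrected gradient
        corrected_grad = grad_x + model.boundary_softness * direction

    return corrected_grad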
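In `compute_gradient`, autograd returns $\nabla_x \log p(x)$. For a single diagonal Gaussian cell this gradient has the closed form $-(x - \mu_i)/\sigma_i^2$ (elementwise), which makes a convenient sanity check. A pure-Python sketch comparing it with central finite differences; all helper names are hypothetical:

```python
import math

def diag_gauss_logpdf(x, mu, sigma):
    """log N(x; mu, diag(sigma^2)) as a sum of per-dimension terms."""
    return sum(-0.5 * ((xi - m) / s) ** 2 - math.log(s * math.sqrt(2 * math.pi))
               for xi, m, s in zip(x, mu, sigma))

def analytic_grad(x, mu, sigma):
    # d/dx_j log N(x; mu, diag(sigma^2)) = -(x_j - mu_j) / sigma_j^2
    return [-(xi - m) / s ** 2 for xi, m, s in zip(x, mu, sigma)]

def finite_diff_grad(f, x, eps=1e-6):
    """Central differences (f(x + eps e_j) - f(x - eps e_j)) / (2 eps)."""
    grad = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        grad.append((f(xp) - f(xm)) / (2 * eps))
    return grad

x, mu, sigma = [0.3, -1.2], [1.0, 0.0], [0.5, 2.0]

def f(v):
    return diag_gauss_logpdf(v, mu, sigma)

print(analytic_grad(x, mu, sigma))
print(finite_diff_grad(f, x))
```

The same finite-difference check can be pointed at any autograd-based gradient when debugging corrections such as the one in `voronoi_gradient`.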

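5. Soft Boundaries and Continuity

5.1 The Problem with Hard Boundaries

Hard Voronoi boundaries introduce discontinuities:

Hard boundary:                   Soft boundary: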
    │                            /
    │    A │ B                   │   A+B
    │──────┼────                 │──/────
    │      │                     │ /
    ▼      ▼                     ▼▼
    
discontinuous                    continuous

5.2 Soft Boundary Implementation

class SoftVoronoiBoundary(nn.Module):
    """
    Soft Voronoi boundaries.
    """
    def __init__(self, num_centers, input_dim, beta=1.0):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_centers, input_dim))
        self.beta = beta  # sharpness: larger beta -> closer to hard boundaries

    def soft_boundary_weights(self, x):
        """
        Soft boundary weights based on a logistic sigmoid (requires num_centers >= 2).
        """
        distances = torch.cdist(x, self.centers)  # (batch, K)

        # Sort distances
        sorted_dist, indices = torch.sort(distances, dim=-1)

        # Nearest and second-nearest distances
        d1 = sorted_dist[:, 0]
        d2 = sorted_dist[:, 1]

        # Share of the nearest center: -> 1 deep inside a cell, exactly 0.5 on the boundary
        alpha = torch.sigmoid(self.beta * (d2 - d1))  # (batch,)

        # Nearest center gets alpha, second-nearest gets 1 - alpha, all others 0.
        # Crossing a boundary swaps the two centers while d2 - d1 passes through 0,
        # so the weights change continuously across that boundary.
        weights = torch.zeros_like(distances)
        weights = weights.scatter(1, indices[:, :1], alpha.unsqueeze(1))
        weights = weights.scatter(1, indices[:, 1:2], (1 - alpha).unsqueeze(1))

        return weights, indices
 
 
class ContinuousVTPC(nn.Module):
    """
    Continuous geometry-aware probabilistic circuit.
    """
    def __init__(self, input_dim, num_centers, hidden_dim):
        super().__init__()
        self.soft_boundary = SoftVoronoiBoundary(num_centers, input_dim)
        self.local_models = nn.ModuleList([
            LocalProbabilisticModel(input_dim, hidden_dim)
            for _ in range(num_centers)
        ])

    def forward(self, x):
        """
        Forward pass with soft boundaries.
        """
        # Soft boundary weights
        weights, indices = self.soft_boundary.soft_boundary_weights(x)

        # Local log-densities
        local_log_probs = torch.stack([
            model(x) for model in self.local_models
        ], dim=1)  # (batch, K)

        # Weighted mixture in log space
        log_weights = torch.log(weights + 1e-8)
        weighted = local_log_probs + log_weights

        # Mix via log-sum-exp
        log_marginal = torch.logsumexp(weighted, dim=1)

        return log_marginal
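The two-nearest-center smoothing in `soft_boundary_weights` can be checked in one dimension: the nearest center gets $\sigma(\beta(d_2 - d_1))$ and the runner-up gets the rest. A minimal pure-Python sketch of that rule (`soft_boundary_weights_1d` is a hypothetical helper):

```python
import math

def soft_boundary_weights_1d(x, centers, beta=1.0):
    """Nearest center gets sigmoid(beta * (d2 - d1)), second nearest gets the rest."""
    order = sorted(range(len(centers)), key=lambda k: abs(x - centers[k]))
    i, j = order[0], order[1]
    d1, d2 = abs(x - centers[i]), abs(x - centers[j])
    alpha = 1.0 / (1.0 + math.exp(-beta * (d2 - d1)))  # 0.5 on the boundary, -> 1 deep inside
    w = [0.0] * len(centers)
    w[i], w[j] = alpha, 1.0 - alpha
    return w

centers = [0.0, 2.0, 6.0]          # boundary between cells 0 and 1 at x = 1
for x in (0.999, 1.0, 1.001):
    print(x, [round(v, 4) for v in soft_boundary_weights_1d(x, centers, beta=5.0)])
```

Cell 0's weight moves smoothly from just above 0.5 to just below it as `x` crosses 1, whereas a hard assignment would jump from 1 to 0. Larger `beta` narrows the transition zone.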

6. Learning Algorithms

6.1 Learning the Voronoi Centers

def train_voronoi_centers(model, data_loader, lr=1e-3, num_iterations=100):
    """
    Learn the Voronoi centers only (all other parameters stay fixed).
    """
    # Adam expects an iterable of parameters
    optimizer = torch.optim.Adam([model.geo_weights.centers], lr=lr)

    for iteration in range(num_iterations):  # one pass over data_loader per iteration
        total_loss = 0
        for x, in data_loader:
            optimizer.zero_grad()

            # Negative log-likelihood
            log_prob = model(x)
            loss = -log_prob.mean()

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if iteration % 10 == 0:
            print(f"Iteration {iteration}: Loss = {total_loss / len(data_loader):.4f}")

    return model
 
 
def voronoi_kmeans_loss(x, centers):
    """
    Voronoi k-means loss (sum of squared distances to the nearest center),
    used to initialize the centers.
    """
    distances = torch.cdist(x, centers)
    assignment = distances.argmin(dim=1)

    loss = 0
    for i in range(centers.size(0)):
        mask = (assignment == i)
        if mask.any():
            cluster_points = x[mask]
            loss += ((cluster_points - centers[i]) ** 2).sum()

    return loss
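`voronoi_kmeans_loss` is the standard k-means objective, so one natural way to initialize the centers is a few rounds of Lloyd's algorithm: assign every point to its Voronoi cell, then move each center to its cell's mean. Neither step can increase that loss. A minimal pure-Python sketch; using Lloyd's algorithm here is our suggestion, not something the source specifies:

```python
import math

def kmeans_loss(points, centers):
    """Sum of squared distances to the nearest center (same objective as voronoi_kmeans_loss)."""
    return sum(min(math.dist(p, c) ** 2 for c in centers) for p in points)

def lloyd_step(points, centers):
    """One Lloyd iteration: Voronoi assignment, then move each center to its cell mean."""
    cells = [[] for _ in centers]
    for p in points:
        k = min(range(len(centers)), key=lambda i: math.dist(p, centers[i]))
        cells[k].append(p)
    new_centers = []
    for c, members in zip(centers, cells):
        if members:
            new_centers.append(tuple(sum(xs) / len(members) for xs in zip(*members)))
        else:
            new_centers.append(c)              # empty cell: keep the old center
    return new_centers

points = [(0.0, 0.1), (0.2, -0.1), (-0.1, 0.0), (5.0, 5.1), (5.2, 4.9), (4.9, 5.0)]
initial = [(1.0, 1.0), (2.0, 2.0)]
centers = initial
for _ in range(5):
    centers = lloyd_step(points, centers)
print(centers, kmeans_loss(points, initial), kmeans_loss(points, centers))
```

The resulting centers can then be copied into `model.geo_weights.centers` before gradient-based training.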

6.2 End-to-End Training

def train_geometry_aware_pc(model, train_loader, num_epochs=100, lr=1e-3):
    """
    End-to-end training of a geometry-aware probabilistic circuit.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    for epoch in range(num_epochs):
        epoch_loss = 0
        for x, in train_loader:
            optimizer.zero_grad()

            # Negative log-likelihood
            log_prob = model(x)
            nll_loss = -log_prob.mean()

            # Regularizer: encourage a uniform (high-entropy) soft Voronoi assignment
            # by subtracting the mean per-sample entropy
            weights, _ = model.geo_weights.compute_weights(x)
            entropy = -(weights * torch.log(weights + 1e-8)).sum(dim=-1).mean()

            # Total loss
            loss = nll_loss - 0.1 * entropy

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {epoch_loss / len(train_loader):.4f}")

    return model
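The regularizer works with the per-sample entropy of the soft Voronoi weights, $H(w) = -\sum_k w_k \log w_k$, which is 0 for a one-hot (hard) assignment and reaches its maximum $\log K$ for the uniform assignment. A quick numeric check in pure Python; `assignment_entropy` is a hypothetical helper, and its `eps` mirrors the `1e-8` used in the training code:

```python
import math

def assignment_entropy(w, eps=1e-8):
    """Per-sample entropy H(w) = -sum_k w_k log(w_k + eps) of one soft assignment."""
    return -sum(wk * math.log(wk + eps) for wk in w)

K = 4
print(assignment_entropy([1.0, 0.0, 0.0, 0.0]))      # ~0: crisp (hard) assignment
print(assignment_entropy([0.7, 0.1, 0.1, 0.1]))      # in between
print(assignment_entropy([0.25] * K), math.log(K))   # both ~1.386: uniform assignment
```

Because $H$ is bounded by $\log K$, the entropy term can shift the loss by at most $0.1 \log K$, so it nudges the assignment without overwhelming the likelihood.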

7. Applications

7.1 Density Estimation

from torch.utils.data import DataLoader, TensorDataset

class DensityEstimator(nn.Module):
    """
    Density estimation with a geometry-aware PC.
    """
    def __init__(self, input_dim, num_components):
        super().__init__()
        self.pc = GeometryAwarePC(
            input_dim=input_dim,
            num_centers=num_components,
            hidden_dim=128
        )

    def fit(self, data, num_epochs=100):
        """
        Fit the data distribution.
        """
        loader = DataLoader(TensorDataset(data), batch_size=256, shuffle=True)
        return train_geometry_aware_pc(self.pc, loader, num_epochs)

    def log_prob(self, x):
        """Log-density of x."""
        return self.pc(x)

    def sample(self, num_samples):
        """
        Approximate sampling: pick a Voronoi cell uniformly at random and sample its
        Gaussian. This ignores the input-dependent weights w_i(x), so the samples are
        not exact draws from the density returned by log_prob.
        """
        with torch.no_grad():
            # Sample Voronoi cells
            indices = torch.randint(0, self.pc.num_centers, (num_samples,))

            # Stack the per-cell Gaussian parameters: (K, input_dim)
            means = torch.cat([m.means for m in self.pc.local_models], dim=0)
            stds = torch.exp(torch.cat([m.log_stds for m in self.pc.local_models], dim=0))

            # Sample from each selected cell's distribution
            z = torch.randn(num_samples, means.size(1))
            return means[indices] + stds[indices] * z

7.2 Anomaly Detection

import math

class AnomalyDetector(nn.Module):
    """
    Anomaly detection with a geometry-aware PC.
    """
    def __init__(self, input_dim, num_components, threshold=0.01):
        super().__init__()
        self.pc = GeometryAwarePC(input_dim, num_components, 128)
        self.threshold = threshold  # density threshold

    def forward(self, x):
        """
        Anomaly score = log-density (lower means more anomalous). Working in log space
        avoids underflow of exp(log_prob) in high dimensions.
        """
        return self.pc(x)

    def predict(self, x):
        """
        Flag x as anomalous if its density is below the threshold (compared in log space).
        """
        scores = self.forward(x)
        return scores < math.log(self.threshold)
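Working with log-densities matters in high dimensions: large negative log-likelihoods are routine there, and exponentiating them underflows, which destroys the ranking an anomaly score relies on. A pure-Python illustration in double precision (PyTorch's default float32 underflows much earlier, around $e^{-104}$):

```python
import math

log_p_a, log_p_b = -800.0, -900.0              # sample b is far less likely than sample a
print(math.exp(log_p_a), math.exp(log_p_b))    # 0.0 0.0: both densities underflow
print(math.exp(log_p_a) > math.exp(log_p_b))   # False: the ranking is lost
print(log_p_a > log_p_b)                       # True: log-densities keep it
```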

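7.3 Image Generation

class ImageGenerator(nn.Module):
    """
    Image generation with a geometry-aware PC on the latent space.
    """
    def __init__(self, latent_dim, num_components, image_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size

        # Geometry-aware PC over the latent space
        self.pc = GeometryAwarePC(
            input_dim=latent_dim,
            num_centers=num_components,
            hidden_dim=256
        )

        # Decoder: latent space -> image space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, image_size * image_size * 3),
            nn.Sigmoid()
        )

        # Encoder: image space -> latent space, needed by log_prob
        # (this architecture is an assumed mirror of the decoder)
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(image_size * image_size * 3, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

    def forward(self, z):
        """
        Generate images from latent codes z of shape (batch, latent_dim).
        """
        images = self.decoder(z)
        return images.view(-1, 3, self.image_size, self.image_size)

    def log_prob(self, x):
        """
        Log-probability of images.
        Simplification: evaluated in latent space only, as log p(encoder(x));
        this is not an exact image log-likelihood.
        """
        z = self.encoder(x)
        return self.pc(z)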

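8. Comparison with Other Methods

8.1 Comparison Table

| Method | Geometry-aware | Tractability | Expressiveness | Scalability |
|---|---|---|---|---|
| Geometry-aware PC | ✓ native | ✓ exact | | medium |
| Standard GMM | | ✓ exact | | medium |
| Standard SPN | | ✓ exact | medium | medium |
| VAE | | ✗ approximate | | |
| NF | partial | ✓ exact | | medium |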

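8.2 Advantages of Geometry-Aware PCs

  1. Local geometric modeling: the Voronoi tessellation naturally encodes local structure
  2. Exact inference: the tractability of probabilistic circuits is preserved
  3. Differentiable learning: the model is trainable end to end
  4. Flexible boundaries: soft boundaries preserve continuity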

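9. Theoretical Analysis

9.1 Expressiveness

Theorem (geometry-aware expressiveness): Let $p^*$ be the true distribution defined on a manifold $\mathcal{M}$. The approximation error of a VT-PC is bounded in terms of the projection of $p^*$ under the Voronoi tessellation.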

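9.2 Convergence

Theorem (convergence): Suppose the training set $\{x_1, \dots, x_n\}$ is sampled from $p^*$, and let the empirical risk be the average negative log-likelihood

$$\hat{R}_n(\theta) = -\frac{1}{n} \sum_{j=1}^{n} \log p_\theta(x_j)$$

Then $\hat{R}_n$ satisfies: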

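9.3 Computational Complexity

| Operation | Time complexity | Space complexity |
|---|---|---|
| Forward pass | | |
| Gradient computation | | |
| Voronoi update | | |

where $B$ is the batch size, $K$ the number of Voronoi cells, and $d$ the input dimension.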
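As a rough worked example of how cost scales with $B$, $K$ and $d$: if the local log-densities are evaluated for all $K$ cells at once by broadcasting (`x.unsqueeze(1) - means` with `means` of shape `(K, d)`), the forward pass materializes a `(B, K, d)` difference tensor, so memory grows linearly in each of the three. The sizes below are illustrative assumptions: `B = 256` is the training script's batch size, `K = 64`, and `d = 784` (e.g. flattened 28×28 images) in float32:

```python
def forward_intermediate_bytes(B, K, d, bytes_per_elem=4):
    """Bytes in a (B, K, d) tensor, e.g. x.unsqueeze(1) - means with means of shape (K, d)."""
    return B * K * d * bytes_per_elem

size = forward_intermediate_bytes(256, 64, 784)
print(f"{size / 1e6:.1f} MB")  # 51.4 MB
```

Doubling the number of cells doubles this intermediate, which is one concrete face of the high-dimensional and many-cell costs listed under the limitations below.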


10. Implementation Details

10.1 Full Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class VoronoiTessellationPC(nn.Module):
    """
    Complete Voronoi-tessellation probabilistic circuit.
    """
    def __init__(self, input_dim, num_centers, hidden_dim=128,
                 learnable_boundary=True, temperature=1.0):
        super().__init__()
        # hidden_dim is currently unused (the local models are diagonal Gaussians)

        self.input_dim = input_dim
        self.num_centers = num_centers
        self.temperature = temperature

        # Voronoi centers (learnable)
        self.centers = nn.Parameter(
            torch.randn(num_centers, input_dim) * 0.1
        )

        # Parameters of the local diagonal Gaussians, one row per cell
        self.local_means = nn.Parameter(
            torch.randn(num_centers, input_dim) * 0.1
        )
        self.local_log_stds = nn.Parameter(
            torch.zeros(num_centers, input_dim)
        )

        # Input-independent mixture weights (optional)
        self.mixture_logits = nn.Parameter(torch.zeros(num_centers))

        # Boundary-softening parameter (not used by forward below)
        if learnable_boundary:
            self.boundary_softness = nn.Parameter(torch.tensor(1.0))
        else:
            self.boundary_softness = None

    def compute_voronoi_weights(self, x):
        """
        Soft and hard Voronoi assignments.
        """
        # Distances to the centers
        distances = torch.cdist(x, self.centers)  # (B, K)

        # Soft assignment
        logits = -distances / self.temperature
        weights = F.softmax(logits, dim=-1)

        # Hard assignment
        assignment = distances.argmin(dim=-1)

        return weights, assignment, distances

    def local_log_prob(self, x):
        """
        Log-density of every sample under every local Gaussian, shape (B, K).
        """
        means = self.local_means                 # (K, d)
        stds = torch.exp(self.local_log_stds)    # (K, d)

        # Gaussian log-density, broadcast over cells
        diff = x.unsqueeze(1) - means            # (B, K, d)
        log_prob = -0.5 * (
            (diff ** 2) / (stds ** 2 + 1e-8) +
            torch.log(2 * torch.pi * stds ** 2 + 1e-8)
        )

        return log_prob.sum(dim=-1)  # (B, K)

    def forward(self, x):
        """
        Forward pass: log-likelihood of each sample, shape (B,).
        """
        # Voronoi weights
        weights, assignment, distances = self.compute_voronoi_weights(x)

        # Local log-densities
        local_log_probs = self.local_log_prob(x)  # (B, K)

        # Mixture weights in log space
        log_mix = F.log_softmax(self.mixture_logits, dim=0)  # (K,)

        # Weighted mixture in log space
        weighted = local_log_probs + torch.log(weights + 1e-8) + log_mix

        # log-sum-exp over cells
        return torch.logsumexp(weighted, dim=-1)

    def sample(self, num_samples):
        """
        Sample from the mixture_logits-weighted Gaussians (the input-dependent
        Voronoi weights are not used here).
        """
        with torch.no_grad():
            # Sample Voronoi cells
            indices = torch.multinomial(
                F.softmax(self.mixture_logits, dim=0),
                num_samples,
                replacement=True
            )

            # Sample from each selected cell's Gaussian
            means = self.local_means[indices]
            stds = torch.exp(self.local_log_stds[indices])

            z = torch.randn(num_samples, self.input_dim, device=means.device)
            return means + stds * z

10.2 Training Script

def train_voronoi_pc(model, train_data, num_epochs=100, batch_size=256, lr=1e-3):
    """
    Train the Voronoi probabilistic circuit.
    """
    dataset = torch.utils.data.TensorDataset(train_data)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        total_loss, total_nll = 0.0, 0.0
        for x, in loader:
            optimizer.zero_grad()

            # Negative log-likelihood
            log_prob = model(x)
            nll = -log_prob.mean()

            # Optional regularizers
            # 1. Entropy regularization (encourages uniform soft assignments)
            weights, _, _ = model.compute_voronoi_weights(x)
            entropy = -(weights * torch.log(weights + 1e-8)).sum(dim=-1).mean()

            # 2. Center-spread regularization: adding the mean pairwise distance
            #    penalizes spread and keeps the centers from drifting apart
            center_dists = torch.pdist(model.centers).mean()

            loss = nll - 0.1 * entropy + 0.01 * center_dists

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_nll += nll.item()

        scheduler.step()

        if epoch % 10 == 0:
            n = len(loader)
            print(f"Epoch {epoch}: NLL = {total_nll / n:.4f}, loss = {total_loss / n:.4f}")

    return model

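11. Limitations and Future Directions

11.1 Current Limitations

| Issue | Description | Impact |
|---|---|---|
| Sensitivity to center initialization | The initial centers affect the final result | Good initialization is required |
| High dimensionality | Voronoi partitions may perform poorly in high dimensions | Curse of dimensionality |
| Dynamic geometry | A static tessellation cannot adapt to changing data | Limits the application scenarios |

11.2 Future Directions

  1. Hierarchical Voronoi: multi-scale geometric modeling
  2. Dynamic Voronoi: geometric structure that evolves over time
  3. Deep Voronoi: Voronoi tessellations combined with deep representation learning
  4. Fusion with other models: Voronoi + NFs, Voronoi + GNNs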

12. References


Related documents: Neural Probabilistic Circuits | Probabilistic Graph Circuits | Probabilistic Circuit Fundamentals

Footnotes

  1. Sidheekh et al. (2026): Geometry-Aware Probabilistic Circuits via Voronoi Tessellations. arXiv:2603.11946.