引言
自2017年CapsNet首次提出以来,研究者们提出了多种改进架构,以提高胶囊网络的效率、表达能力和应用范围。本文档将系统介绍这些现代架构及其核心创新。
经典CapsNet架构
Sabour et al. (2017) - 原始CapsNet
架构:
输入 (28×28)
→ Conv1 (256, 9×9, stride=1)
→ PrimaryCaps (32×8, 9×9, stride=2)
→ DigitCaps (10×16)
→ 重构网络
特点:
- 首次提出动态路由机制
- 向量胶囊替代标量神经元
- 重构作为正则化
Hinton et al. (2018) - Matrix Capsules with EM Routing
架构创新:
- 矩阵胶囊替代向量胶囊
- EM路由替代标准路由
- 位置编码更加丰富
EM路由优势:
- 更快的收敛
- 更好的聚类效果
- 对噪声更鲁棒
Efficient-CapsNet (2021)
核心创新
Efficient-CapsNet 由 Mazzia 等人提出,主要改进包括:
- 自注意力路由:替代迭代式动态路由
- 可分离卷积:减少参数量
- 多头注意力:增强表达能力
架构
class EfficientCapsNet(nn.Module):
"""Efficient-CapsNet架构"""
def __init__(self, num_classes=10):
super().__init__()
# 特征提取器
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 32, 3, stride=2, padding=1),
)
# 胶囊层
self.caps1 = ConvCapsuleLayer(
in_channels=32,
out_channels=16,
kernel_size=3,
stride=2,
num_capsules=8,
dim_capsules=8
)
# 自注意力路由
self.attention_routing = AttentionRouting(dim_caps=16, num_heads=4)
# 输出胶囊
self.caps2 = DenseCapsuleLayer(
num_caps_in=8 * 7 * 7,
dim_caps_in=64,
num_caps_out=num_classes,
dim_caps_out=16
)
def forward(self, x):
x = self.conv1(x)
x = self.caps1(x)
x, attn = self.attention_routing(x)
x = self.caps2(x)
return x, attn性能对比
| 方法 | MNIST错误率 | 参数量 |
|---|---|---|
| 原始CapsNet | 0.25% | 8.2M |
| Efficient-CapsNet | 0.28% | 1.5M |
OrthCaps (CVPR 2024)
核心创新
OrthCaps 由 Geng 等人在 CVPR 2024 提出,主要解决胶囊网络的冗余问题。
三大创新:
- 正交权重:减少胶囊间的冗余
- 稀疏注意力路由:提高计算效率
- 剪枝:去除不重要的胶囊
正交权重初始化
传统CapsNet的权重初始化可能导致胶囊之间的冗余。OrthCaps采用正交初始化:
class OrthogonalCapsuleLayer(nn.Module):
"""正交胶囊层"""
def __init__(self, num_caps_in, num_caps_out, dim_caps):
super().__init__()
self.W = nn.Parameter(torch.randn(num_caps_in, num_caps_out, dim_caps, dim_caps))
# 正交初始化
nn.init.orthogonal_(self.W.data)
def forward(self, u):
# 正交变换
u_hat = torch.einsum('bic,icoo->bico', u, self.W)
return u_hat稀疏注意力路由
OrthCaps使用稀疏注意力来替代密集路由:
class SparseAttentionRouting(nn.Module):
"""稀疏注意力路由"""
def __init__(self, num_heads=4, top_k=4):
super().__init__()
self.num_heads = num_heads
self.top_k = top_k
def forward(self, capsules):
batch_size, num_caps, dim_caps = capsules.shape
# 多头查询
Q = capsules.view(batch_size, num_caps, self.num_heads, -1)
K = capsules.view(batch_size, num_caps, self.num_heads, -1)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (dim_caps ** 0.5)
# Top-K 稀疏化
top_k_scores, indices = torch.topk(scores, k=self.top_k, dim=-1)
# 归一化
sparse_attn = F.softmax(top_k_scores, dim=-1)
return sparse_attn, indices性能对比
| 方法 | CIFAR-10 | 参数量 | 推理速度 |
|---|---|---|---|
| 原始CapsNet | 5.6% | 7.2M | 1× |
| Efficient-CapsNet | 6.1% | 1.5M | 3.2× |
| OrthCaps | 5.2% | 2.1M | 2.8× |
IBCapsNet (2026)
信息瓶颈理论
IBCapsNet 将信息瓶颈(Information Bottleneck)理论引入胶囊网络,实现噪声鲁棒的表示学习。
核心思想
最大化压缩表示与任务相关信息之间的互信息:
其中 是输入, 是胶囊表示, 是标签。
架构设计
class IBCapsLayer(nn.Module):
"""信息瓶颈胶囊层"""
def __init__(self, dim_caps, beta=0.1):
super().__init__()
self.beta = beta
# 编码器
self.encoder = nn.Sequential(
nn.Linear(dim_caps, dim_caps),
nn.ReLU()
)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(dim_caps, dim_caps),
nn.Sigmoid()
)
# beta 参数
self.beta_param = nn.Parameter(torch.tensor(beta))
def information_bottleneck(self, u):
"""信息瓶颈压缩"""
# 均值和方差
mu = self.encoder(u)
log_var = torch.log(torch.var(u, dim=-1, keepdim=True) + 1e-8)
# 重参化
z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
# 重构
u_recon = self.decoder(z)
# 信息瓶颈损失
ib_loss = self.beta_param * (mu ** 2 + log_var).mean()
return z, ib_lossPR-CapsNet (2025)
伪黎曼几何
PR-CapsNet 将胶囊网络推广到伪黎曼流形上,特别适用于图结构数据。
核心公式
在伪黎曼流形上定义胶囊:
其中 是伪黎曼范数。
图胶囊路由
class PseudoRiemannianRouting(nn.Module):
"""伪黎曼路由"""
def __init__(self, curvature=1.0):
super().__init__()
self.curvature = curvature
def metric(self, x):
"""伪黎曼度量"""
# diag(-1, 1, 1, ..., 1)
g = torch.eye(x.size(-1)).to(x.device)
g[0, 0] = -1
return g
def riemannian_norm(self, x):
"""伪黎曼范数"""
g = self.metric(x)
return torch.sqrt(torch.abs(torch.sum(x @ g * x, dim=-1)) + 1e-8)
def forward(self, u_hat, adj_matrix):
"""
Args:
u_hat: 预测向量
adj_matrix: 邻接矩阵
"""
batch_size = u_hat.size(0)
num_nodes = u_hat.size(1)
# 伪黎曼距离
dist = self.riemannian_norm(u_hat.unsqueeze(2) - u_hat.unsqueeze(1))
# 结合图结构
routing_weights = torch.softmax(-dist * adj_matrix, dim=-1)
s = torch.sum(routing_weights.unsqueeze(-1) * u_hat, dim=2)
v = self.squash_r(s)
return vMSPCaps (2025)
多尺度补丁
MSPCaps 提出多尺度补丁表示与交叉协议路由,显著提升了视觉识别能力。
核心创新
- 多尺度补丁划分:不同尺度的局部特征
- 交叉协议路由:跨尺度信息交互
- 尺度感知注意力:自适应融合多尺度信息
架构
class MultiScalePatchify(nn.Module):
"""多尺度补丁划分"""
def __init__(self, scales=[2, 4, 8]):
super().__init__()
self.scales = scales
def forward(self, x):
"""
Args:
x: (batch, channels, H, W)
Returns:
多尺度补丁列表
"""
patches = []
for scale in self.scales:
# 划分补丁
B, C, H, W = x.shape
patch_size = H // scale
x_patched = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
x_patched = x_patched.contiguous().view(B, C, -1, patch_size, patch_size)
patches.append(x_patched)
return patches
class CrossScaleRouting(nn.Module):
"""交叉尺度路由"""
def __init__(self, dim_caps):
super().__init__()
self.scale_attention = nn.Linear(dim_caps * len(scales), dim_caps)
def forward(self, multi_scale_caps):
"""
Args:
multi_scale_caps: 多尺度胶囊列表
"""
# 拼接多尺度胶囊
concat_caps = torch.cat(multi_scale_caps, dim=-1)
# 尺度感知注意力
scale_weights = torch.sigmoid(self.scale_attention(concat_caps))
# 加权融合
fused = sum(w * c for w, c in zip(scale_weights.chunk(len(multi_scale_caps), -1),
multi_scale_caps))
return fused图胶囊网络
Graph CapsNet
将胶囊网络扩展到图结构数据:
class GraphCapsuleLayer(nn.Module):
"""图胶囊层"""
def __init__(self, in_features, out_caps, dim_caps, num_iterations=3):
super().__init__()
self.out_caps = out_caps
self.dim_caps = dim_caps
self.num_iterations = num_iterations
# 节点特定的变换
self.W = nn.Parameter(
torch.randn(in_features, out_caps * dim_caps)
)
def forward(self, x, adj):
"""
Args:
x: 节点特征 (batch, num_nodes, in_features)
adj: 邻接矩阵 (batch, num_nodes, num_nodes)
"""
batch_size, num_nodes, _ = x.shape
# 线性变换
h = torch.matmul(x, self.W)
h = h.view(batch_size, num_nodes, self.out_caps, self.dim_caps)
# 图感知的动态路由
for _ in range(self.num_iterations):
# 基于图的路由权重
routing = torch.softmax(adj.unsqueeze(-1).unsqueeze(-1), dim=2)
# 加权聚合
s = torch.sum(routing * h, dim=1, keepdim=True)
v = self.squash(s)
# 更新(使用图结构)
adj_expanded = adj.unsqueeze(-1).unsqueeze(-1)
h = h + torch.matmul(adj_expanded, (v - h))
return v.squeeze(1)架构对比总结
| 架构 | 年份 | 主要创新 | 优势 | 应用场景 |
|---|---|---|---|---|
| 原始CapsNet | 2017 | 动态路由 | 理论基础 | 图像分类 |
| Matrix Capsules | 2018 | EM路由 | 表达能力强 | 小样本学习 |
| Efficient-CapsNet | 2021 | 注意力路由 | 高效 | 实时应用 |
| OrthCaps | 2024 | 正交+稀疏 | 减少冗余 | 大规模部署 |
| IBCapsNet | 2026 | 信息瓶颈 | 噪声鲁棒 | 医学图像 |
| PR-CapsNet | 2025 | 伪黎曼几何 | 图结构建模 | 知识图谱 |
| MSPCaps | 2025 | 多尺度 | 细粒度特征 | 细粒度分类 |
实现建议
选择指南
- 小数据集:原始CapsNet或Matrix Capsules
- 实时应用:Efficient-CapsNet
- 大规模部署:OrthCaps
- 噪声环境:IBCapsNet
- 图数据:PR-CapsNet
- 细粒度任务:MSPCaps
实践技巧
# 1. 根据任务选择路由算法
if dataset_size < 10000:
routing = "EM" # 小数据集用EM路由
else:
routing = "attention" # 大数据集用注意力路由
# 2. 胶囊维度选择
if num_classes <= 10:
dim_caps = 16
elif num_classes <= 100:
dim_caps = 8
else:
dim_caps = 4
# 3. 层数选择
if input_resolution < 64:
num_caps_layers = 2
else:
num_caps_layers = 3