动态路由概述
动态路由(Dynamic Routing)是胶囊网络的核心机制,它通过迭代协议使低层胶囊能够自适应地将信息传递到最需要的高层胶囊。相比于CNN中的最大池化,动态路由是一种”软件选择”(software choice)而非”硬选择”(hard choice)。
为什么需要动态路由?
- 信息聚合的自适应性:不同输入需要不同的信息聚合方式
- 层次结构的显式建模:空间关系通过路由显式编码
- 协议式信息传递:低层胶囊”投票”决定信息流向
标准动态路由算法
算法流程
def dynamic_routing(u_hat, num_routing=3):
"""
标准动态路由算法
Args:
u_hat: 预测向量 (batch, num_caps_in, num_caps_out, dim_out)
num_routing: 迭代次数
Returns:
v: 高层胶囊输出
"""
b = torch.zeros_like(u_hat[..., 0]) # (batch, num_caps_in, num_caps_out)
for _ in range(num_routing):
# 1. 计算路由系数
c = F.softmax(b, dim=2) # 沿输出胶囊维度归一化
# 2. 加权求和
s = torch.sum(c.unsqueeze(-1) * u_hat, dim=1)
# 3. 非线性压缩
v = squash(s)
# 4. 更新 logits(一致性)
b = b + torch.sum(u_hat * v.unsqueeze(1), dim=-1)
return v数学形式化
路由系数的softmax计算:
加权求和:
一致性更新:
标准路由的问题
- 计算开销大:每次迭代需要计算所有胶囊对之间的关系
- 梯度不稳定:更新公式依赖于输出,容易陷入局部最优
- 冗余胶囊:胶囊之间可能存在冗余
EM路由算法
Hinton等人于2018年提出了基于期望最大化(EM)的路由算法,用于Matrix Capsules架构。
EM算法的类比
| EM算法 | Matrix Capsules |
|---|---|
| E步 | 分配样本到聚类 |
| M步 | 更新聚类参数 |
| 混合系数 | 路由系数 |
EM路由步骤
输入:低层胶囊激活 ,迭代次数
E步(Expectation):
其中 是高层胶囊 的位置参数。
M步(Maximization):
更新位置:
更新方差:
更新成本:
PyTorch实现
class EMRouting(nn.Module):
"""EM路由实现"""
def __init__(self, num_caps_in, dim_caps_in,
num_caps_out, dim_caps_out,
num_iterations=2):
super().__init__()
self.num_iterations = num_iterations
# 权重矩阵
self.W = nn.Parameter(
torch.randn(num_caps_in, num_caps_out, dim_caps_in, dim_caps_out)
)
# 可学习的缩放参数
self.beta_u = nn.Parameter(torch.zeros(1, num_caps_out))
self.beta_a = nn.Parameter(torch.zeros(1, num_caps_out))
def em_routing(self, u_hat, lambda_=1.0):
"""
EM路由过程
Args:
u_hat: 预测向量 (batch, num_in, num_out, dim)
lambda_: annealing参数
"""
batch_size = u_hat.size(0)
num_in = u_hat.size(1)
num_out = u_hat.size(2)
dim_out = u_hat.size(3)
# 初始化
R = torch.ones(batch_size, num_in, num_out).to(u_hat.device) / num_out
# 初始化参数
sigma_sq = torch.ones(batch_size, num_out, 1).to(u_hat.device)
beta = self.beta_a.data.expand(batch_size, num_out, 1)
for _ in range(self.num_iterations):
# E步:计算后验概率
# 成本 = (u_hat - mu)^2 / sigma^2 + log(sigma^3)
cost = torch.sum((u_hat.unsqueeze(2) - sigma_sq) ** 2 / sigma_sq + torch.log(sigma_sq + 1e-8), dim=-1)
cost = cost - torch.max(cost, dim=-1, keepdim=True)[0]
# 使用lambda_进行annealing
logit = -(cost + self.beta_u.data) * lambda_
R = F.softmax(logit, dim=2)
# M步:更新参数
R_sum = R.sum(dim=1, keepdim=True) + 1e-8
# 更新均值
mu = torch.sum(R.unsqueeze(-1) * u_hat, dim=1) / R_sum
# 更新方差
diff_sq = (u_hat - mu.unsqueeze(1)) ** 2
sigma_sq = torch.sum(R.unsqueeze(-1) * diff_sq, dim=1) / R_sum + 1e-8
# 更新beta
beta = torch.sum(R.unsqueeze(-1) * (diff_sq / sigma_sq + torch.log(sigma_sq + 1e-8)), dim=1) / R_sum
# 最终加权求和
s = torch.sum(R.unsqueeze(-1) * u_hat, dim=1)
# 激活函数
v = squash(s)
return v注意力路由
Efficient-CapsNet (2021) 提出了基于自注意力的路由机制,显著降低了计算复杂度。
核心思想
将路由系数的计算视为一种注意力机制:
注意力路由的优势
- 可并行计算:利用矩阵乘法高效实现
- 感受野扩展:注意力机制可以建模更长距离的关系
- 可学习性:注意力权重可以通过反向传播学习
实现
class AttentionRouting(nn.Module):
"""注意力路由"""
def __init__(self, dim_caps, num_heads=4):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim_caps // num_heads
# 可学习的查询和键
self.query = nn.Linear(dim_caps, dim_caps)
self.key = nn.Linear(dim_caps, dim_caps)
def forward(self, capsules):
"""
Args:
capsules: (batch, num_caps, dim_caps)
"""
batch_size, num_caps, dim_caps = capsules.shape
# 多头注意力
Q = self.query(capsules) # (batch, num_caps, dim_caps)
K = self.key(capsules) # (batch, num_caps, dim_caps)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (dim_caps ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
# 加权聚合
v = torch.matmul(attn_weights, capsules)
return v, attn_weights稀疏注意力路由
OrthCaps (CVPR 2024) 提出了正交权重与稀疏注意力路由的结合。
稀疏注意力机制
只保留最重要的 个连接:
class SparseAttentionRouting(nn.Module):
"""稀疏注意力路由"""
def __init__(self, top_k=8):
super().__init__()
self.top_k = top_k
def forward(self, capsules):
"""
Args:
capsules: (batch, num_caps, dim_caps)
Returns:
稀疏路由后的胶囊
"""
batch_size, num_caps, dim_caps = capsules.shape
# 计算胶囊之间的相似度
similarity = torch.matmul(capsules, capsules.transpose(-2, -1))
# Top-k 选择
top_k_values, indices = torch.topk(similarity, k=self.top_k, dim=-1)
# 创建稀疏注意力矩阵
sparse_attn = torch.zeros_like(similarity)
for i in range(batch_size):
for j in range(num_caps):
sparse_attn[i, j, indices[i, j]] = F.softmax(top_k_values[i, j], dim=-1)
# 加权聚合
v = torch.matmul(sparse_attn, capsules)
return v, sparse_attn窗口气囊路由
2024年的研究表明,使用局部窗口可以提高路由效率并减少梯度消失问题。
窗口设计
将胶囊划分为局部窗口,只在窗口内进行路由:
class WindowedRouting(nn.Module):
"""窗口气囊路由"""
def __init__(self, window_size=3, stride=1):
super().__init__()
self.window_size = window_size
self.stride = stride
def forward(self, capsules):
"""
Args:
capsules: (batch, num_caps, dim_caps)
Returns:
路由后的胶囊
"""
batch_size, num_caps, dim_caps = capsules.shape
out_capsules = []
# 滑动窗口
for start in range(0, num_caps - self.window_size + 1, self.stride):
window = capsules[:, start:start + self.window_size, :]
# 在窗口内进行路由
routed = self._intra_window_routing(window)
out_capsules.append(routed)
return torch.stack(out_capsules, dim=1)
def _intra_window_routing(self, window):
"""窗口内路由"""
# 简化:使用平均池化
return window.mean(dim=1)协议路由机制
协议路由(Agreement Routing)通过最大化胶囊之间的”协议”来优化路由系数。
核心公式
其中 是通过以下协议更新:
协议度量
- 点积相似度:
- 余弦相似度:
- 欧氏距离相似度:
路由算法对比
| 算法 | 计算复杂度 | 收敛速度 | 表达能力 | 主要应用 |
|---|---|---|---|---|
| 标准路由 | 慢 | 中 | MNIST/CIFAR | |
| EM路由 | 快 | 高 | SmallNORB | |
| 注意力路由 | 快 | 高 | 实时应用 | |
| 稀疏路由 | 快 | 中 | 大规模 | |
| 窗口气囊路由 | 快 | 中 | 高效推理 |
实践技巧
1. 路由迭代次数
- MNIST: 3次迭代足够
- CIFAR-10: 3-5次迭代
- 迭代过多可能导致过拟合
2. 初始化
# 初始化 logits 为零
b = torch.zeros(num_caps_in, num_caps_out)3. 梯度裁剪
# 防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)4. 温度参数
使用温度参数控制路由的”软硬”程度:
- :更软的路由
- :更硬的路由