引言
胶囊神经网络(Capsule Network,CapsNet)由 Geoffrey Hinton 等人于2017年在论文 Dynamic Routing Between Capsules 中首次提出。1
CNN的局限性
传统的卷积神经网络(CNN)在图像分类任务中取得了巨大成功,但仍存在以下局限性:
- 最大池化的信息损失:最大池化只保留最显著的特征,丢弃了其他位置信息
- 空间关系建模不足:CNN难以学习特征之间的空间层次关系
- 平移不变性过强:CNN的对平移的不变性可能导致对旋转、缩放等变换过于敏感
胶囊网络的核心思想
胶囊网络通过以下创新来解决这些问题:
- 向量表示:胶囊使用向量而非标量来表示神经元
- 空间编码:向量的方向编码了特征的属性(如位置、姿态)
- 动态路由:低层胶囊通过迭代协议决定向哪些高层胶囊传递信息
胶囊的定义
从标量神经元到向量胶囊
传统神经元:
胶囊神经元:
其中 是胶囊的输入向量, 是胶囊的输出向量。
Squash函数
Squash函数将短向量压缩到接近0,长向量压缩到接近1的长度:
- 当 时,
- 当 时,
动态路由算法
动态路由是CapsNet的核心创新。低层胶囊首先预测高层胶囊的输入,然后通过迭代过程调整路由系数。
路由算法步骤
输入:低层胶囊输出 ,迭代次数
输出:高层胶囊输出
算法:
1. 初始化:b_ij = 0,对于所有 i, j
2. for r iterations:
3. c_i = softmax(b_i) # 路由系数归一化
4. s_j = Σ_i c_ij * uhat_j|i # 加权求和
5. v_j = squash(s_j) # 非线性压缩
6. b_ij += uhat_j|i · v_j # 更新 logits
7. return v_j
数学形式化
预测向量:
其中 是第 层胶囊 到第 层胶囊 的权重矩阵。
加权和:
路由系数:
一致性更新:
路由系数的意义
路由系数 表示第 层胶囊对第 层胶囊的”投票”权重。通过迭代,低层胶囊逐渐学会将信息发送到最需要它们的高层胶囊。
胶囊网络的结构
经典CapsNet架构
输入图像
↓
卷积层 (Conv1): 256个 9×9 卷积核, stride=1
↓
PrimaryCaps: 32个胶囊通道,每个通道8个胶囊
每个胶囊:8个 9×9 卷积核, stride=2
↓
动态路由 (迭代r次)
↓
DigitCaps: 10个胶囊(MNIST 10个类别)
每个胶囊:16维向量
↓
重构网络(可选)
DigitCaps层
DigitCaps层对每个数字类别输出一个16维胶囊:
- 向量长度表示该类别的概率(通过squash函数归一化到0-1)
- 向量方向编码了该类别的姿态信息
重构网络
重构网络用于正则化,通过重构输入图像来确保胶囊编码了足够的信息:
DigitCaps (16维 × 10)
↓ 全连接层 (512)
↓ 全连接层 (1024)
↓ 全连接层 (784)
↓ sigmoid
重构图像 (28×28)
重构损失:
其中 当输入属于类别 ,否则 。
与CNN的对比
特征表示对比
| 特性 | CNN | Capsule Network |
|---|---|---|
| 神经元输出 | 标量 | 向量 |
| 特征表示 | 激活值 | 激活值+方向 |
| 位置编码 | 最大池化丢失 | 胶囊方向保留 |
| 层次关系 | 隐式学习 | 显式建模 |
路由机制对比
| 特性 | 最大池化 | 动态路由 |
|---|---|---|
| 决策方式 | 硬选择 | 软选择 |
| 信息保留 | 丢弃非最大值 | 加权聚合 |
| 可学习性 | 不可学习 | 可学习 |
| 计算开销 | 低 | 较高 |
鲁棒性对比
Capsule Network对以下变换具有更好的鲁棒性:
- 旋转:膂囊编码了方向信息
- 缩放:空间关系保持
- 部分遮挡:路由机制确保信息传递
- 姿态变化:向量表示自然编码姿态
PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class Squash(nn.Module):
"""Squash激活函数"""
def forward(self, x):
"""
Args:
x: 输入向量 (batch, num_caps, dim_caps)
Returns:
压缩后的向量
"""
squared_norm = torch.sum(x ** 2, dim=-1, keepdim=True)
scale = squared_norm / (1 + squared_norm)
return scale * x / torch.sqrt(squared_norm + 1e-8)
class CapsuleLayer(nn.Module):
"""胶囊层"""
def __init__(self, num_caps_in, dim_caps_in,
num_caps_out, dim_caps_out,
num_routing=3):
super().__init__()
self.num_routing = num_routing
self.num_caps_in = num_caps_in
self.num_caps_out = num_caps_out
# 权重矩阵
self.W = nn.Parameter(
torch.randn(num_caps_in, num_caps_out, dim_caps_in, dim_caps_out)
)
def forward(self, u):
"""
Args:
u: 低层胶囊输出 (batch, num_caps_in, dim_caps_in)
Returns:
高层胶囊输出
"""
batch_size = u.size(0)
# u_hat: (batch, num_caps_in, 1, 1, dim_caps_in)
# → (batch, num_caps_in, num_caps_out, dim_caps_out)
u_expanded = u.unsqueeze(2).unsqueeze(4)
W_expanded = self.W.unsqueeze(0)
u_hat = torch.matmul(u_expanded, W_expanded)
u_hat = u_hat.squeeze(4) # (batch, num_caps_in, num_caps_out, dim_caps_out)
# 路由系数
b = torch.zeros(batch_size, self.num_caps_in, self.num_caps_out).to(u.device)
for _ in range(self.num_routing):
c = F.softmax(b, dim=2) # (batch, num_caps_in, num_caps_out)
c_expanded = c.unsqueeze(3) # (batch, num_caps_in, num_caps_out, 1)
s = torch.sum(c_expanded * u_hat, dim=1) # (batch, num_caps_out, dim_caps_out)
v = Squash()(s)
# 更新 logits
if self.num_routing > 1:
u_dot_v = torch.matmul(u_hat.unsqueeze(4), v.unsqueeze(4).transpose(3, 4))
b = b + u_dot_v.squeeze(4).squeeze(4)
return v
class PrimaryCaps(nn.Module):
"""主胶囊层"""
def __init__(self, in_channels, out_dim, num_caps, kernel_size=9, stride=2):
super().__init__()
self.num_caps = num_caps
self.out_dim = out_dim
self.conv = nn.Conv2d(
in_channels,
num_caps * out_dim,
kernel_size,
stride=stride,
padding=0
)
def forward(self, x):
batch_size = x.size(0)
x = self.conv(x)
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(batch_size, -1, self.num_caps, self.out_dim)
# Squash each capsule
x = Squash()(x)
return x
class CapsNet(nn.Module):
"""完整CapsNet"""
def __init__(self, num_classes=10, routing_iterations=3):
super().__init__()
self.num_classes = num_classes
# Conv1
self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
# PrimaryCaps
self.primary_caps = PrimaryCaps(256, 8, 32)
# DigitCaps
self.digit_caps = CapsuleLayer(
num_caps_in=32 * 6 * 6, # PrimaryCaps的胶囊数
dim_caps_in=8,
num_caps_out=num_classes,
dim_caps_out=16,
num_routing=routing_iterations
)
# 重构网络
self.decoder = nn.Sequential(
nn.Linear(16 * num_classes, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784),
nn.Sigmoid()
)
def forward(self, x, reconstruction=True):
x = F.relu(self.conv1(x))
x = self.primary_caps(x)
x = self.digit_caps(x)
# 取最长向量的长度作为类别概率
classes = torch.sqrt(torch.sum(x ** 2, dim=-1))
class_pred = classes.argmax(dim=1)
if reconstruction and self.training:
# 重构
batch_size = x.size(0)
masked_x = x * classes.argmax(dim=1, keepdim=True).unsqueeze(2).unsqueeze(3)
masked_x = masked_x.view(batch_size, -1)
reconstruction = self.decoder(masked_x)
return classes, reconstruction
else:
return classes, class_pred
def margin_loss(x, target, margin=0.9, down_weight=0.5):
"""
边缘损失函数
Args:
x: 胶囊输出的长度 (batch, num_classes)
target: 目标类别 (batch,)
margin: 正样本的边缘
down_weight: 负样本的权重
"""
batch_size = x.size(0)
target_one_hot = torch.zeros_like(x).scatter_(1, target.unsqueeze(1), 1)
left = F.relu(margin - x, inplace=True) ** 2
right = F.relu(x - 0.1, inplace=True) ** 2
loss = target_one_hot * left + down_weight * (1 - target_one_hot) * right
return loss.sum(dim=1).mean()实验结果
MNIST分类
| 方法 | 错误率 |
|---|---|
| CNN (Max Pooling) | 0.95% |
| Capsule Network | 0.25% |
CIFAR-10分类
| 方法 | 错误率 |
|---|---|
| CNN | 7.42% |
| Capsule Network | 5.6% |
MultiMNIST
CapsNet在重叠数字识别任务上显著优于CNN,因为它能够学习特征之间的空间关系。
参考文献
Footnotes
-
Sabour, S., Frosst, N., & Hinton, G. E. (2017). Dynamic routing between capsules. NeurIPS. ↩