引言

胶囊神经网络(Capsule Network,CapsNet)由 Geoffrey Hinton 等人于2017年在论文 Dynamic Routing Between Capsules 中首次提出。1

CNN的局限性

传统的卷积神经网络(CNN)在图像分类任务中取得了巨大成功,但仍存在以下局限性:

  1. 最大池化的信息损失:最大池化只保留最显著的特征,丢弃了其他位置信息
  2. 空间关系建模不足:CNN难以学习特征之间的空间层次关系
  3. 平移不变性过强:CNN的对平移的不变性可能导致对旋转、缩放等变换过于敏感

胶囊网络的核心思想

胶囊网络通过以下创新来解决这些问题:

  • 向量表示:胶囊使用向量而非标量来表示神经元
  • 空间编码:向量的方向编码了特征的属性(如位置、姿态)
  • 动态路由:低层胶囊通过迭代协议决定向哪些高层胶囊传递信息

胶囊的定义

从标量神经元到向量胶囊

传统神经元:

胶囊神经元:

其中 是胶囊的输入向量, 是胶囊的输出向量。

Squash函数

Squash函数将短向量压缩到接近0,长向量压缩到接近1的长度:

  • 时,
  • 时,

动态路由算法

动态路由是CapsNet的核心创新。低层胶囊首先预测高层胶囊的输入,然后通过迭代过程调整路由系数。

路由算法步骤

输入:低层胶囊输出 ,迭代次数

输出:高层胶囊输出

算法

1. 初始化:b_ij = 0,对于所有 i, j
2. for r iterations:
3.     c_i = softmax(b_i)  # 路由系数归一化
4.     s_j = Σ_i c_ij * uhat_j|i  # 加权求和
5.     v_j = squash(s_j)  # 非线性压缩
6.     b_ij += uhat_j|i · v_j  # 更新 logits
7. return v_j

数学形式化

预测向量

其中 是第 层胶囊 到第 层胶囊 的权重矩阵。

加权和

路由系数

一致性更新

路由系数的意义

路由系数 表示第 层胶囊对第 层胶囊的”投票”权重。通过迭代,低层胶囊逐渐学会将信息发送到最需要它们的高层胶囊。

胶囊网络的结构

经典CapsNet架构

输入图像
    ↓
卷积层 (Conv1): 256个 9×9 卷积核, stride=1
    ↓
PrimaryCaps: 32个胶囊通道,每个通道8个胶囊
           每个胶囊:8个 9×9 卷积核, stride=2
    ↓
动态路由 (迭代r次)
    ↓
DigitCaps: 10个胶囊(MNIST 10个类别)
           每个胶囊:16维向量
    ↓
重构网络(可选)

DigitCaps层

DigitCaps层对每个数字类别输出一个16维胶囊:

  • 向量长度表示该类别的概率(通过squash函数归一化到0-1)
  • 向量方向编码了该类别的姿态信息

重构网络

重构网络用于正则化,通过重构输入图像来确保胶囊编码了足够的信息:

DigitCaps (16维 × 10)
    ↓ 全连接层 (512)
    ↓ 全连接层 (1024)
    ↓ 全连接层 (784)
    ↓ sigmoid
重构图像 (28×28)

重构损失:

其中 当输入属于类别 ,否则

与CNN的对比

特征表示对比

特性CNNCapsule Network
神经元输出标量向量
特征表示激活值激活值+方向
位置编码最大池化丢失胶囊方向保留
层次关系隐式学习显式建模

路由机制对比

特性最大池化动态路由
决策方式硬选择软选择
信息保留丢弃非最大值加权聚合
可学习性不可学习可学习
计算开销较高

鲁棒性对比

Capsule Network对以下变换具有更好的鲁棒性:

  1. 旋转:膂囊编码了方向信息
  2. 缩放:空间关系保持
  3. 部分遮挡:路由机制确保信息传递
  4. 姿态变化:向量表示自然编码姿态

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Squash(nn.Module):
    """Squash激活函数"""
    def forward(self, x):
        """
        Args:
            x: 输入向量 (batch, num_caps, dim_caps)
        Returns:
            压缩后的向量
        """
        squared_norm = torch.sum(x ** 2, dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * x / torch.sqrt(squared_norm + 1e-8)
 
 
class CapsuleLayer(nn.Module):
    """胶囊层"""
    def __init__(self, num_caps_in, dim_caps_in, 
                 num_caps_out, dim_caps_out, 
                 num_routing=3):
        super().__init__()
        self.num_routing = num_routing
        self.num_caps_in = num_caps_in
        self.num_caps_out = num_caps_out
        
        # 权重矩阵
        self.W = nn.Parameter(
            torch.randn(num_caps_in, num_caps_out, dim_caps_in, dim_caps_out)
        )
    
    def forward(self, u):
        """
        Args:
            u: 低层胶囊输出 (batch, num_caps_in, dim_caps_in)
        Returns:
            高层胶囊输出
        """
        batch_size = u.size(0)
        
        # u_hat: (batch, num_caps_in, 1, 1, dim_caps_in) 
        #       → (batch, num_caps_in, num_caps_out, dim_caps_out)
        u_expanded = u.unsqueeze(2).unsqueeze(4)
        W_expanded = self.W.unsqueeze(0)
        u_hat = torch.matmul(u_expanded, W_expanded)
        u_hat = u_hat.squeeze(4)  # (batch, num_caps_in, num_caps_out, dim_caps_out)
        
        # 路由系数
        b = torch.zeros(batch_size, self.num_caps_in, self.num_caps_out).to(u.device)
        
        for _ in range(self.num_routing):
            c = F.softmax(b, dim=2)  # (batch, num_caps_in, num_caps_out)
            c_expanded = c.unsqueeze(3)  # (batch, num_caps_in, num_caps_out, 1)
            
            s = torch.sum(c_expanded * u_hat, dim=1)  # (batch, num_caps_out, dim_caps_out)
            v = Squash()(s)
            
            # 更新 logits
            if self.num_routing > 1:
                u_dot_v = torch.matmul(u_hat.unsqueeze(4), v.unsqueeze(4).transpose(3, 4))
                b = b + u_dot_v.squeeze(4).squeeze(4)
        
        return v
 
 
class PrimaryCaps(nn.Module):
    """主胶囊层"""
    def __init__(self, in_channels, out_dim, num_caps, kernel_size=9, stride=2):
        super().__init__()
        self.num_caps = num_caps
        self.out_dim = out_dim
        
        self.conv = nn.Conv2d(
            in_channels, 
            num_caps * out_dim, 
            kernel_size, 
            stride=stride, 
            padding=0
        )
    
    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1).contiguous()
        x = x.view(batch_size, -1, self.num_caps, self.out_dim)
        # Squash each capsule
        x = Squash()(x)
        return x
 
 
class CapsNet(nn.Module):
    """完整CapsNet"""
    def __init__(self, num_classes=10, routing_iterations=3):
        super().__init__()
        self.num_classes = num_classes
        
        # Conv1
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        
        # PrimaryCaps
        self.primary_caps = PrimaryCaps(256, 8, 32)
        
        # DigitCaps
        self.digit_caps = CapsuleLayer(
            num_caps_in=32 * 6 * 6,  # PrimaryCaps的胶囊数
            dim_caps_in=8,
            num_caps_out=num_classes,
            dim_caps_out=16,
            num_routing=routing_iterations
        )
        
        # 重构网络
        self.decoder = nn.Sequential(
            nn.Linear(16 * num_classes, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
    
    def forward(self, x, reconstruction=True):
        x = F.relu(self.conv1(x))
        x = self.primary_caps(x)
        x = self.digit_caps(x)
        
        # 取最长向量的长度作为类别概率
        classes = torch.sqrt(torch.sum(x ** 2, dim=-1))
        class_pred = classes.argmax(dim=1)
        
        if reconstruction and self.training:
            # 重构
            batch_size = x.size(0)
            masked_x = x * classes.argmax(dim=1, keepdim=True).unsqueeze(2).unsqueeze(3)
            masked_x = masked_x.view(batch_size, -1)
            reconstruction = self.decoder(masked_x)
            return classes, reconstruction
        else:
            return classes, class_pred
 
 
def margin_loss(x, target, margin=0.9, down_weight=0.5):
    """
    边缘损失函数
    
    Args:
        x: 胶囊输出的长度 (batch, num_classes)
        target: 目标类别 (batch,)
        margin: 正样本的边缘
        down_weight: 负样本的权重
    """
    batch_size = x.size(0)
    target_one_hot = torch.zeros_like(x).scatter_(1, target.unsqueeze(1), 1)
    
    left = F.relu(margin - x, inplace=True) ** 2
    right = F.relu(x - 0.1, inplace=True) ** 2
    
    loss = target_one_hot * left + down_weight * (1 - target_one_hot) * right
    return loss.sum(dim=1).mean()

实验结果

MNIST分类

方法错误率
CNN (Max Pooling)0.95%
Capsule Network0.25%

CIFAR-10分类

方法错误率
CNN7.42%
Capsule Network5.6%

MultiMNIST

CapsNet在重叠数字识别任务上显著优于CNN,因为它能够学习特征之间的空间关系。

参考文献


相关链接:动态路由算法详解 | 现代胶囊架构

Footnotes

  1. Sabour, S., Frosst, N., & Hinton, G. E. (2017). Dynamic routing between capsules. NeurIPS.