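Introduction

Capsule networks model the spatial relationships between features explicitly, which has shown advantages across a range of tasks. This document surveys their main application areas and representative work.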

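Image Classification and Recognition

Basic Image Classification

On basic image classification tasks, capsule networks report lower error rates than CNN baselines: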

Dataset      CNN error   CapsNet error   Relative error reduction
MNIST        0.95%       0.25%           73.7%
SmallNORB    2.7%        1.8%            33.3%
CIFAR-10     7.42%       5.6%            24.5%
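
The last column appears to be the relative reduction in error rate, that is (CNN error - CapsNet error) / CNN error. A minimal check in plain Python, using only the figures from the table above; the function name is illustrative:

```python
def relative_error_reduction(baseline_error: float, new_error: float) -> float:
    """Relative reduction in error rate, as a percentage of the baseline error."""
    return (baseline_error - new_error) / baseline_error * 100.0


# Error rates (in percent) copied from the table above.
results = {
    "MNIST": (0.95, 0.25),
    "SmallNORB": (2.7, 1.8),
    "CIFAR-10": (7.42, 5.6),
}

for dataset, (cnn_error, capsnet_error) in results.items():
    reduction = relative_error_reduction(cnn_error, capsnet_error)
    print(f"{dataset}: {reduction:.1f}%")
```

Rounded to one decimal place, the three values reproduce the 73.7%, 33.3% and 24.5% listed in the table.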

MultiMNIST Analysis

CapsNet has a clear advantage in recognizing overlapping digits because it learns the spatial relationships between features.

Experimental setup

  • Two digits are superimposed in a single image
  • Training data: 60,000 images
  • Test data: 10,000 images
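
One way to build such an overlapping sample can be sketched in plain Python on tiny nested-list images. Combining the two images with a pixel-wise maximum is an assumption (the setup above only says the digits are superimposed), and `overlay_digits` and `make_sample` are illustrative names. The label of the composite is the set of both digit classes.

```python
from typing import List, Set, Tuple

Image = List[List[float]]  # grayscale image as rows of intensities in [0, 1]


def overlay_digits(img_a: Image, img_b: Image) -> Image:
    """Superimpose two same-sized images by keeping the brighter pixel at each position."""
    if len(img_a) != len(img_b) or any(len(ra) != len(rb) for ra, rb in zip(img_a, img_b)):
        raise ValueError("images must have the same shape")
    return [[max(pa, pb) for pa, pb in zip(ra, rb)] for ra, rb in zip(img_a, img_b)]


def make_sample(img_a: Image, label_a: int, img_b: Image, label_b: int) -> Tuple[Image, Set[int]]:
    """Return the overlaid image together with its two-class label set."""
    return overlay_digits(img_a, img_b), {label_a, label_b}


# Toy 3x3 "digits": a vertical stroke and a horizontal stroke.
vertical = [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
horizontal = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
image, labels = make_sample(vertical, 1, horizontal, 7)
```

The composite keeps every stroke of both inputs, which is what makes the task hard: the network has to attribute each pixel to one of two classes.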

Results

  • CNN: struggles to separate the overlapping digits accurately
  • CapsNet: identifies both digits accurately through dynamic routing
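
Dynamic routing, mentioned above, can be sketched in plain Python. This is a minimal illustration of routing-by-agreement on nested lists, not the implementation used in the experiments. It assumes the prediction vectors `u_hat[i][j]` (input capsule i's prediction for output capsule j) have already been computed, and all names are illustrative.

```python
import math
from typing import List

Vector = List[float]


def squash(s: Vector) -> Vector:
    """Shrink a vector so its length lies in [0, 1) while keeping its direction."""
    sq_norm = sum(x * x for x in s)
    if sq_norm == 0.0:
        return [0.0 for _ in s]
    scale = sq_norm / (1.0 + sq_norm) / math.sqrt(sq_norm)
    return [scale * x for x in s]


def softmax(logits: Vector) -> Vector:
    m = max(logits)
    exps = [math.exp(value - m) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def length(vec: Vector) -> float:
    return math.sqrt(sum(x * x for x in vec))


def dynamic_routing(u_hat: List[List[Vector]], iterations: int = 3) -> List[Vector]:
    """Route predictions u_hat[i][j] from input capsules i to output capsules j."""
    num_in, num_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * num_out for _ in range(num_in)]  # routing logits
    v: List[Vector] = []
    for _ in range(iterations):
        # Each input capsule distributes its vote over the output capsules.
        c = [softmax(row) for row in b]
        v = []
        for j in range(num_out):
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(num_in)) for d in range(dim)]
            v.append(squash(s_j))
        # Raise the logit wherever a prediction agrees with the resulting output.
        for i in range(num_in):
            for j in range(num_out):
                b[i][j] += sum(p * q for p, q in zip(u_hat[i][j], v[j]))
    return v


# Three input capsules agree on output 0 and disagree on output 1.
u_hat = [
    [[1.0, 0.0], [1.0, 0.0]],
    [[1.0, 0.0], [-1.0, 0.0]],
    [[1.0, 0.0], [0.0, 1.0]],
]
outputs = dynamic_routing(u_hat)
```

In this toy case the output capsule that receives consistent predictions ends up longer than the one that receives conflicting predictions, which is the mechanism that lets two overlapping digits activate two separate class capsules.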

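Code implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

# PrimaryCapsule and DigitCapsule are assumed to be defined elsewhere;
# DigitCapsule is assumed to return (capsules, routing_coefficients).

class MultiMNISTCapsNet(nn.Module):
    """MultiMNIST classification network."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # Feature extraction
        self.conv1 = nn.Conv2d(1, 256, 9)
        self.primary_caps = PrimaryCapsule(256, 32, 8)
        self.digit_caps = DigitCapsule(32*6*6, 8, num_classes, 16)

        # Two-digit decoder: one masked block of class capsules per digit
        self.decoder = nn.Sequential(
            nn.Linear(16 * num_classes * 2, 512),  # 2 digits
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * 784),  # reconstructions of both digits
            nn.Sigmoid()
        )

    def forward(self, x, num_digits=2):
        # The decoder width is fixed for two digits, so num_digits must stay 2.
        x = F.relu(self.conv1(x))
        x = self.primary_caps(x)
        x, _ = self.digit_caps(x)  # (B, num_classes, 16)

        # Capsule length is the class score; take the top-k longest capsules
        classes = torch.sqrt(torch.sum(x ** 2, dim=-1))
        top_k = torch.topk(classes, num_digits, dim=1)

        # Reconstruction: for each selected digit, zero out every other capsule
        mask = F.one_hot(top_k.indices, self.num_classes).to(x.dtype)  # (B, k, num_classes)
        masked_x = x.unsqueeze(1) * mask.unsqueeze(-1)  # (B, k, num_classes, 16)
        reconstruction = self.decoder(masked_x.flatten(1))

        return classes, reconstruction, top_k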
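
The `32*6*6` passed to `DigitCapsule` follows from convolution output-size arithmetic. The sketch below assumes a 28x28 input and a primary capsule layer built on a 9x9, stride-2 convolution as in the original CapsNet design; `PrimaryCapsule` is not shown in this document, so that second assumption should be checked against its actual definition.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


input_size = 28  # assumed 28x28 input
after_conv1 = conv_output_size(input_size, kernel=9)               # nn.Conv2d(1, 256, 9)
after_primary = conv_output_size(after_conv1, kernel=9, stride=2)  # assumed primary layer
num_primary_capsules = 32 * after_primary * after_primary

# The same arithmetic for a larger 36x36 canvas.
after_primary_36 = conv_output_size(conv_output_size(36, kernel=9), kernel=9, stride=2)
```

Under these assumptions the feature map is 20x20 after `conv1` and 6x6 after the primary layer, giving 32 * 6 * 6 = 1152 primary capsules. A 36x36 canvas would give a 10x10 map instead, so the first argument of `DigitCapsule` has to change with the input size.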

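Medical Image Analysis

CardioCaps (2024)

CardioCaps is an attention-based capsule network for class-imbalanced classification of echocardiogram images.

Key innovations

  1. Class imbalance handling: a weighted loss function
  2. Attention mechanism: focuses on the important regions
  3. Dynamic routing: models the relationships between cardiac structures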
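
A weighted loss of the kind named in item 1 can be sketched as a class-weighted margin loss. This is an assumption about the general technique, not the exact loss used by CardioCaps, which is not specified here. The defaults `m_plus=0.9`, `m_minus=0.1` and `down_weight=0.5` are the usual margin-loss values, and the weights in the example are hypothetical.

```python
from typing import Sequence


def weighted_margin_loss(
    lengths: Sequence[float],
    target: int,
    class_weights: Sequence[float],
    m_plus: float = 0.9,
    m_minus: float = 0.1,
    down_weight: float = 0.5,
) -> float:
    """Margin loss for one sample, scaled by the weight of its true class."""
    loss = 0.0
    for k, capsule_length in enumerate(lengths):
        if k == target:
            # The true class capsule should be longer than m_plus.
            loss += max(0.0, m_plus - capsule_length) ** 2
        else:
            # Every other capsule should be shorter than m_minus.
            loss += down_weight * max(0.0, capsule_length - m_minus) ** 2
    return class_weights[target] * loss


# Hypothetical two-class problem in which class 1 is rare and weighted 4x.
weights = [1.0, 4.0]
capsule_lengths = [0.6, 0.3]
loss_common = weighted_margin_loss(capsule_lengths, target=0, class_weights=weights)
loss_rare = weighted_margin_loss(capsule_lengths, target=1, class_weights=weights)
```

With the same capsule lengths, a sample from the rare class contributes a much larger loss, so its gradient is not drowned out by the majority class.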

Architecture

import torch
import torch.nn as nn

# CapsuleLayer is assumed to be defined elsewhere.

class CardioCaps(nn.Module):
    """CardioCaps for Echocardiogram Classification"""
    def __init__(self, num_classes, imbalanced_ratio=1.0):
        super().__init__()

        # Feature extractor
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Capsule layers
        self.caps = nn.ModuleList([
            CapsuleLayer(128, 8, 16, 16, routing_iterations=3),
            CapsuleLayer(16 * 8 * 8, 16, num_classes, 16, routing_iterations=3)
        ])

        # Class weights for the weighted loss, stored as a buffer so they
        # move with the model on .to(device)
        self.register_buffer(
            'class_weights', self._compute_weights(num_classes, imbalanced_ratio)
        )

    def _compute_weights(self, num_classes, ratio):
        """Compute class weights: class i gets (1 / (i + 1)) ** ratio."""
        ranks = torch.arange(1, num_classes + 1, dtype=torch.float32)
        return (1.0 / ranks) ** ratio

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1, 8)  # group features into 8-D capsules
        x = self.caps[0](x)
        x = self.caps[1](x)
        return x

Retinal Disease Detection

A hybrid architecture that combines a Vision Transformer with a capsule network is used for diabetic retinopathy detection.

Method

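import timm
import torch.nn as nn

# CapsuleLayer is assumed to be defined elsewhere.

class HybridViTCapsNet(nn.Module):
    """Hybrid ViT-CapsNet"""
    def __init__(self, num_classes=5):
        super().__init__()

        # ViT feature extractor
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.vit.head = nn.Identity()

        # Capsule layer
        self.caps = CapsuleLayer(
            in_features=768,
            out_caps=num_classes,
            dim_caps=16
        )

    def forward(self, x):
        # Token features: recent timm versions return (B, 1 + N, 768) here,
        # with the class token first. Calling self.vit(x) would return a
        # pooled (B, 768) vector instead.
        tokens = self.vit.forward_features(x)

        # Convert to capsule format: drop the class token, keep patch tokens
        caps = tokens[:, 1:]  # (B, N, 768)

        # Capsule processing
        output = self.caps(caps)

        return output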
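
The number N of tokens handed to the capsule layer follows from the patch grid. A small sketch of that arithmetic for the `vit_base_patch16_224` configuration named above (224x224 input, 16x16 patches); the helper name is illustrative:

```python
def vit_token_count(image_size: int = 224, patch_size: int = 16, class_token: bool = True) -> int:
    """Number of tokens a ViT produces for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + (1 if class_token else 0)


num_patch_tokens = vit_token_count(class_token=False)
num_tokens = vit_token_count()
```

A 224x224 image yields a 14x14 grid, so 196 patch tokens (197 with the class token), each of dimension 768 for ViT-Base. Every patch token then acts as one input capsule.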

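Brain Tumor MRI Classification

The Hybrid ViT-CapsNet framework is applied to brain tumor diagnosis from MRI images.

Data augmentation strategy

  1. Random flips: horizontal and vertical
  2. Random rotation: ±15 degrees
  3. Color jitter: brightness, contrast, saturation
  4. Elastic deformation: simulates tissue deformation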

Text Classification

Sentence Capsule Network

Capsule networks can also be applied to text classification:

import torch.nn as nn

# CapsuleLayer is assumed to be defined elsewhere. Keyword arguments follow
# the (in_features, out_caps, dim_caps) form used in HybridViTCapsNet.

class TextCapsNet(nn.Module):
    """Text capsule network"""
    def __init__(self, vocab_size, embed_dim, num_classes, dim_caps=16):
        super().__init__()

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Convolutional layers (padding=1 keeps seq_len unchanged)
        self.conv = nn.Sequential(
            nn.Conv1d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Capsule layers
        self.caps = nn.ModuleList([
            CapsuleLayer(in_features=256, out_caps=8, dim_caps=8),
            CapsuleLayer(in_features=8, out_caps=num_classes, dim_caps=dim_caps)
        ])

    def forward(self, x):
        # Embedding
        x = self.embedding(x)  # (B, seq_len, embed_dim)

        # Convolution
        x = x.transpose(1, 2)  # (B, embed_dim, seq_len)
        x = self.conv(x)

        # Capsules
        x = x.transpose(1, 2)  # (B, seq_len, channels)
        for cap_layer in self.caps:
            x = cap_layer(x)

        return x

Sentiment Analysis Application

Dataset: IMDB movie reviews

Results

  • CNN accuracy: 88.5%
  • CapsNet accuracy: 89.7%

Video Action Recognition

Temporal Capsule Network

import torch.nn as nn

# TemporalCapsuleLayer and DigitCapsule are assumed to be defined elsewhere;
# DigitCapsule is assumed to return (capsules, routing_coefficients).

class TemporalCapsNet(nn.Module):
    """Temporal capsule network"""
    def __init__(self, num_classes):
        super().__init__()

        # Spatial feature extraction (halves H and W, then halves T)
        self.spatial_cnn = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3)),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 1, 1)),
        )

        # Temporal capsules
        self.temporal_caps = TemporalCapsuleLayer(
            in_channels=64,
            num_capsules=16,
            dim_capsule=8,
            temporal_window=4
        )

        # Classification capsules
        self.class_caps = DigitCapsule(
            num_capsules=16 * 4,  # num_capsules * temporal_window
            dim_capsules=8,
            out_caps=num_classes,
            out_dim=16
        )

    def forward(self, x):
        # x: (B, C, T, H, W)
        x = self.spatial_cnn(x)

        # Temporal capsules
        x = self.temporal_caps(x)

        # Classification
        x, _ = self.class_caps(x)

        return x

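Zero-Shot Learning

Zero-Shot Advantages of Capsule Networks

The vector representations of capsule networks are a natural fit for zero-shot learning because of:

  1. Pose encoding: they can generalize to poses not seen in training
  2. Attribute representation: the direction of a capsule encodes visual attributes
  3. Compositionality: new classes can be composed from known attributes

Implementation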
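
The compositionality argument in item 3 can be illustrated with attribute-based classification: a class that was never seen in training is described by a combination of known attributes, and an image is assigned to the class whose attribute signature best matches the predicted attributes. The sketch below is plain Python with hypothetical attributes, classes and scores; matching by cosine similarity is an assumption, not a method taken from a specific paper.

```python
import math
from typing import Dict, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def classify_by_attributes(
    predicted_attrs: Sequence[float],
    class_signatures: Dict[str, Sequence[float]],
) -> str:
    """Return the class whose attribute signature is closest to the prediction."""
    return max(
        class_signatures,
        key=lambda name: cosine_similarity(predicted_attrs, class_signatures[name]),
    )


# Hypothetical attribute order: [striped, has_mane, has_wings]
class_signatures = {
    "horse": [0.0, 1.0, 0.0],  # seen during training
    "tiger": [1.0, 0.0, 0.0],  # seen during training
    "zebra": [1.0, 1.0, 0.0],  # unseen: composed from known attributes
}

# Attribute scores a trained model might predict for a zebra image.
predicted = [0.9, 0.8, 0.1]
label = classify_by_attributes(predicted, class_signatures)
```

Here the unseen class wins because its signature combines two attributes that were each learned from a different seen class.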

import torch.nn as nn

# VisualCapsEncoder, AttributeEncoder and ZeroShotClassifier are assumed to be
# defined elsewhere.

class ZeroShotCapsNet(nn.Module):
    """Zero-shot capsule network"""
    def __init__(self, num_seen_classes, num_attributes, embed_dim=64):
        super().__init__()

        # Visual capsule encoder
        self.visual_encoder = VisualCapsEncoder(embed_dim=embed_dim)

        # Attribute encoder
        self.attribute_encoder = AttributeEncoder(
            num_attributes=num_attributes,
            embed_dim=embed_dim
        )

        # Zero-shot classifier
        self.classifier = ZeroShotClassifier(embed_dim=embed_dim)

    def forward(self, x, seen_mask=None):
        # Visual capsule representation
        visual_caps = self.visual_encoder(x)  # (B, embed_dim)

        # Attribute prediction
        # NOTE: predict_attributes, classify_zero_shot and class_attributes are
        # not defined in this snippet; the full model has to provide them.
        predicted_attrs = self.predict_attributes(visual_caps)  # (B, num_attrs)

        # Zero-shot classification over all classes
        logits = self.classify_zero_shot(visual_caps, self.class_attributes)
        if seen_mask is not None:
            # Keep scores for seen classes only
            logits = logits.masked_fill(~seen_mask, float('-inf'))
        return logits, predicted_attrs

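Robustness Analysis

Adversarial Robustness

Capsule networks are more robust than CNNs to certain types of adversarial attack:

Attack method   CNN accuracy   CapsNet accuracy
FGSM            81.2%          87.5%
PGD             76.3%          83.1%
DeepFool        79.5%          85.2%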
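
For readers unfamiliar with the first attack in the table, FGSM perturbs the input by a fixed step in the direction of the sign of the loss gradient: x_adv = x + epsilon * sign(dL/dx). The sketch below applies it to a tiny logistic regression model rather than a capsule network, because the input gradient then has the closed form (p - y) * w and no autograd library is needed. The model, weights and epsilon are illustrative only.

```python
import math
from typing import List, Sequence


def predict_proba(x: Sequence[float], w: Sequence[float], b: float) -> float:
    """Probability of class 1 under a logistic regression model."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))


def cross_entropy(x: Sequence[float], y: int, w: Sequence[float], b: float) -> float:
    p = predict_proba(x, w, b)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def sign(value: float) -> int:
    return (value > 0) - (value < 0)


def fgsm_attack(x: Sequence[float], y: int, w: Sequence[float], b: float, epsilon: float) -> List[float]:
    """One FGSM step: move each input feature by epsilon along the gradient sign."""
    p = predict_proba(x, w, b)
    grad = [(p - y) * wi for wi in w]  # dL/dx for logistic regression
    return [xi + epsilon * sign(g) for xi, g in zip(x, grad)]


weights, bias = [2.0, -1.0], 0.0
x_clean, y_true = [0.5, -0.5], 1
x_adv = fgsm_attack(x_clean, y_true, weights, bias, epsilon=0.1)

loss_clean = cross_entropy(x_clean, y_true, weights, bias)
loss_adv = cross_entropy(x_adv, y_true, weights, bias)
```

The perturbation never exceeds epsilon per feature, yet the loss on the true label goes up. The accuracies in the table measure how often a model still predicts correctly after such perturbations.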

Robustness to Pose Changes

Capsule networks encode pose explicitly, which gives them a built-in robustness to changes in input pose:

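import torch
import torchvision.transforms.functional as TF

class PoseRobustnessTest:
    """Pose robustness test"""
    def __init__(self, model):
        self.model = model

    @torch.no_grad()
    def _predict(self, images):
        """Predicted class; capsule outputs (B, classes, dim) are scored by length."""
        output = self.model(images)
        if output.dim() == 3:
            output = output.norm(dim=-1)
        return output.argmax(dim=1)

    def _consistency(self, batches):
        """Fraction of predictions that match those for the first batch."""
        preds = torch.stack([self._predict(batch) for batch in batches])
        return (preds == preds[0]).float().mean()

    # The three image helpers below are simple assumed implementations.
    def rotate_image(self, images, angle):
        return TF.rotate(images, angle)

    def scale_image(self, images, scale):
        return TF.affine(images, angle=0.0, translate=[0, 0], scale=scale, shear=[0.0, 0.0])

    def occlude_image(self, images, ratio):
        """Zero out a centered rectangle covering about `ratio` of the image area."""
        occluded = images.clone()
        h, w = images.shape[-2:]
        oh, ow = int(h * ratio ** 0.5), int(w * ratio ** 0.5)
        top, left = (h - oh) // 2, (w - ow) // 2
        occluded[..., top:top + oh, left:left + ow] = 0
        return occluded

    def test_rotation(self, images, angles):
        """Test rotation invariance"""
        return self._consistency([self.rotate_image(images, a) for a in angles])

    def test_scale(self, images, scales):
        """Test scale invariance"""
        return self._consistency([self.scale_image(images, s) for s in scales])

    def test_occlusion(self, images, occlusion_ratios):
        """Test occlusion robustness"""
        # Expectation: accuracy degrades more slowly as occlusion increases
        return [self._predict(self.occlude_image(images, r)) for r in occlusion_ratios]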

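Practical Recommendations

1. Task Fit

Task type                     Recommended architecture   Rationale
Image classification          Original CapsNet           Simple and effective
Medical imaging               IBCapsNet / CardioCaps     Robust to noise
Fine-grained classification   MSPCaps                    Multi-scale features
Text classification           TextCapsNet                Sequence modeling
Graph data                    PR-CapsNet                 Graph-structure modeling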

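2. Hyperparameter Selection

# Recommended configuration
config = {
    'num_routing_iterations': 3,      # MNIST: 3, CIFAR: 5
    'dim_capsule': 16,                # classification: 16, detection: 8
    'num_primary_capsules': 32,       # 32-64
    'primary_capsule_dim': 8,         # 8-16
    'reconstruction_weight': 0.0005,  # tune per task
    'margin_loss_margin': 0.9,        # upper margin m+ of the margin loss
    'margin_loss_down_weight': 0.5,   # weight on the loss for absent classes
}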
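
How `reconstruction_weight` enters training can be sketched as follows. The example assumes the common formulation in which the total loss is the margin loss plus a down-weighted sum-of-squares reconstruction error; the function name and the example numbers are illustrative.

```python
from typing import Sequence


def total_loss(
    margin_loss: float,
    reconstruction: Sequence[float],
    target: Sequence[float],
    reconstruction_weight: float = 0.0005,
) -> float:
    """Margin loss plus a down-weighted sum-of-squares reconstruction error."""
    sse = sum((r - t) ** 2 for r, t in zip(reconstruction, target))
    return margin_loss + reconstruction_weight * sse


# Hypothetical 28x28 image in which every reconstructed pixel is off by 0.1.
target_pixels = [0.5] * 784
reconstructed_pixels = [0.6] * 784
loss = total_loss(0.11, reconstructed_pixels, target_pixels)
```

The small weight matters because the reconstruction error is summed over hundreds of pixels. Without it, the reconstruction term would dominate the margin loss and the network would be trained mainly as an autoencoder.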

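3. Training Tips

import torch
import torch.nn as nn
from torchvision import transforms

# model, optimizer and epochs are assumed to be defined by the training script.

# 1. Learning-rate schedule
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs, eta_min=1e-6
)

# 2. Gradient clipping (call after loss.backward() and before optimizer.step())
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 3. Label smoothing (for models trained on logits with cross-entropy)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# 4. Data augmentation (32x32 crops, e.g. CIFAR-sized images)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
])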
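
The learning-rate curve produced by the cosine schedule in tip 1 has a simple closed form when no restarts are used: lr(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2. A plain-Python sketch, with an assumed base learning rate of 1e-3 and 100 epochs:

```python
import math


def cosine_annealing_lr(step: int, t_max: int, base_lr: float, eta_min: float = 1e-6) -> float:
    """Closed-form cosine annealing from base_lr at step 0 to eta_min at step t_max."""
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max)) / 2.0


# Assumed settings: base learning rate 1e-3, 100 epochs.
schedule = [cosine_annealing_lr(epoch, t_max=100, base_lr=1e-3) for epoch in range(101)]
```

The rate starts at the base value, passes the midpoint between `base_lr` and `eta_min` halfway through training, and ends at `eta_min`, with the slowest change at both ends of the schedule.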
