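Introduction
Capsule networks explicitly model the spatial (part-whole) relationships between features, which has given them an edge in a range of tasks. This document surveys the main application areas of capsule networks and representative work in each.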
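Image classification and recognition
Basic image classification
On basic image classification benchmarks, CapsNet achieves lower error rates than the CNN baseline: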
| Dataset | CNN error | CapsNet error | Relative error reduction |
|---|---|---|---|
| MNIST | 0.95% | 0.25% | 73.7% |
| SmallNORB | 2.7% | 1.8% | 33.3% |
| CIFAR-10 | 7.42% | 5.6% | 24.5% |
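The last column is the relative reduction in error, (CNN error - CapsNet error) / CNN error. A small Python helper reproduces it from the two error columns. The name `relative_error_reduction` is illustrative.

```python
def relative_error_reduction(cnn_err: float, caps_err: float) -> float:
    """Percentage by which the CapsNet error is lower than the CNN error."""
    return (cnn_err - caps_err) / cnn_err * 100


# Error rates (%) copied from the table above
table = {
    "MNIST": (0.95, 0.25),
    "SmallNORB": (2.7, 1.8),
    "CIFAR-10": (7.42, 5.6),
}

for name, (cnn, caps) in table.items():
    print(f"{name}: {relative_error_reduction(cnn, caps):.1f}%")
```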
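MultiMNIST analysis
CapsNet has a marked advantage at recognizing overlapping digits because it learns the spatial relationships between features. Lower-level capsules are routed to the digit capsule whose pose they agree with, so strokes from two superimposed digits can be attributed to the correct digit.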
Experimental setup:
- Two digits are overlaid on each other
- Training data: 60,000 images
- Test data: 10,000 images
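The setup above only states that two digits are overlaid. Below is a minimal NumPy sketch of one way to build such a pair. It rests on several assumptions:

- both digits are 28×28 grayscale images in [0, 1], from different classes;
- each digit gets a random shift of up to `max_shift` pixels;
- a pixel-wise maximum merges them.

The helper names `shift_image` and `make_overlay_pair` are hypothetical.

```python
import numpy as np


def shift_image(img: np.ndarray, dy: int, dx: int) -> np.ndarray:
    """Shift a 2-D image by (dy, dx) pixels, filling vacated pixels with 0."""
    out = np.zeros_like(img)
    h, w = img.shape
    src_y = slice(max(0, -dy), min(h, h - dy))
    dst_y = slice(max(0, dy), min(h, h + dy))
    src_x = slice(max(0, -dx), min(w, w - dx))
    dst_x = slice(max(0, dx), min(w, w + dx))
    out[dst_y, dst_x] = img[src_y, src_x]
    return out


def make_overlay_pair(img_a, label_a, img_b, label_b, max_shift=2, rng=None):
    """Overlay two digits of different classes; return the image and a two-hot target."""
    if label_a == label_b:
        raise ValueError("the two digits must come from different classes")
    rng = np.random.default_rng() if rng is None else rng
    shifted = [shift_image(img, *rng.integers(-max_shift, max_shift + 1, size=2))
               for img in (img_a, img_b)]
    merged = np.maximum(shifted[0], shifted[1])  # pixel-wise max keeps both strokes
    target = np.zeros(10, dtype=np.float32)
    target[[label_a, label_b]] = 1.0
    return merged, target


rng = np.random.default_rng(0)
a, b = rng.random((28, 28)), rng.random((28, 28))
img, target = make_overlay_pair(a, 3, b, 7, rng=rng)
print(img.shape, target.nonzero()[0])  # (28, 28) [3 7]
```

The two-hot target matches the "identify both digits" objective: at test time, the two longest digit capsules should correspond to its two nonzero entries.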
Results:
- CNN: struggles to separate the overlapping digits
- CapsNet: correctly identifies both digits through dynamic routing
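Dynamic routing (routing-by-agreement, Sabour et al., 2017) lets each digit capsule claim the lower-level parts whose predictions agree with it. Below is a minimal NumPy sketch of the iteration. It assumes the prediction vectors û_{j|i} have already been produced by the per-pair transformation matrices. The function names are illustrative.

```python
import numpy as np


def squash(s: np.ndarray, axis: int = -1, eps: float = 1e-8) -> np.ndarray:
    """Shrink vectors to length in [0, 1) while keeping their direction."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)


def dynamic_routing(u_hat: np.ndarray, num_iters: int = 3) -> np.ndarray:
    """Routing-by-agreement.

    u_hat: predictions from lower capsules, shape (num_in, num_out, dim_out).
    Returns the output capsules v, shape (num_out, dim_out).
    """
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))                  # routing logits
    for _ in range(num_iters):
        c = np.exp(b - b.max(axis=1, keepdims=True))
        c /= c.sum(axis=1, keepdims=True)            # softmax over output capsules
        s = np.einsum("io,iod->od", c, u_hat)        # weighted sum of predictions
        v = squash(s)
        b = b + np.einsum("iod,od->io", u_hat, v)    # agreement = dot product
    return v


# Toy check: 4 lower capsules agree on output 0 but contradict each other on output 1
u_hat = np.zeros((4, 2, 2))
u_hat[:, 0] = [1.0, 0.0]
u_hat[:, 1] = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
v = dynamic_routing(u_hat)
print(np.linalg.norm(v, axis=-1))  # output 0 is long, output 1 has length 0
```

Each iteration shifts coupling weight toward the output whose current vector agrees with a lower capsule's prediction. Consistent parts therefore reinforce one capsule, while contradictory parts cancel out.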
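Code implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# PrimaryCapsule and DigitCapsule are assumed to come from the companion
# "Capsule Network Fundamentals" document: PrimaryCapsule(in_channels, num_caps, caps_dim)
# uses a 9x9, stride-2 convolution, and DigitCapsule(in_caps, in_dim, out_caps, out_dim)
# returns (capsules, coupling_coefficients).


class MultiMNISTCapsNet(nn.Module):
    """CapsNet for MultiMNIST (two overlapping digits on a 28x28 canvas)."""

    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        self.num_classes = num_classes
        self.img_size = img_size
        # Feature extraction: 28x28 -> 20x20 (conv1) -> 6x6 grid of primary capsules
        self.conv1 = nn.Conv2d(1, 256, 9)
        self.primary_caps = PrimaryCapsule(256, 32, 8)
        self.digit_caps = DigitCapsule(32 * 6 * 6, 8, num_classes, 16)
        # The decoder reconstructs ONE digit from the masked digit capsules;
        # forward() calls it once per detected digit.
        self.decoder = nn.Sequential(
            nn.Linear(16 * num_classes, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_size * img_size),
            nn.Sigmoid(),
        )

    def forward(self, x, num_digits=2):
        x = F.relu(self.conv1(x))
        x = self.primary_caps(x)
        x, _ = self.digit_caps(x)                          # (B, num_classes, 16)
        # Capsule length = probability that the class is present
        lengths = x.norm(dim=-1)                           # (B, num_classes)
        # Take the num_digits longest capsules
        top_k = lengths.topk(num_digits, dim=1).indices    # (B, num_digits)
        # Reconstruct each detected digit separately by masking all other capsules
        reconstructions = []
        for k in range(num_digits):
            mask = F.one_hot(top_k[:, k], self.num_classes).unsqueeze(-1).to(x.dtype)
            masked = (x * mask).flatten(start_dim=1)       # (B, num_classes * 16)
            recon = self.decoder(masked)
            reconstructions.append(recon.view(-1, 1, self.img_size, self.img_size))
        # Reconstructions: (B, num_digits, 1, H, W)
        return lengths, torch.stack(reconstructions, dim=1), top_k
```
Medical image analysis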
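CardioCaps (2024)
CardioCaps is an attention-based capsule network for class-imbalanced echocardiogram classification.
Key innovations
- Class-imbalance handling: a weighted loss function
- Attention mechanism: focuses on the important image regions
- Dynamic routing: models the relationships between cardiac structures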
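One common way to realize a weighted loss for capsule outputs is to scale each class's term of the margin loss by an inverse-frequency weight. The NumPy sketch below assumes this scheme. It is not necessarily the loss used in the CardioCaps paper, and both function names are hypothetical. The margins m+ = 0.9, m- = 0.1 and λ = 0.5 are the standard CapsNet values.

```python
import numpy as np


def inverse_frequency_weights(class_counts):
    """Per-class weights proportional to 1/count, normalized to average 1."""
    counts = np.asarray(class_counts, dtype=np.float64)
    w = 1.0 / counts
    return w * len(w) / w.sum()


def weighted_margin_loss(lengths, targets, class_weights,
                         m_plus=0.9, m_minus=0.1, lam=0.5):
    """Class-weighted margin loss.

    lengths: (B, C) capsule lengths; targets: (B,) integer labels.
    Returns the mean loss over the batch.
    """
    B, C = lengths.shape
    T = np.zeros((B, C))
    T[np.arange(B), targets] = 1.0
    present = T * np.maximum(0.0, m_plus - lengths) ** 2
    absent = lam * (1.0 - T) * np.maximum(0.0, lengths - m_minus) ** 2
    per_class = (present + absent) * class_weights   # weights broadcast over the batch
    return per_class.sum(axis=1).mean()


w = inverse_frequency_weights([900, 90, 10])          # heavily imbalanced 3-class set
print(np.round(w, 3))                                  # rarest class gets the largest weight
lengths = np.array([[0.95, 0.05, 0.05]])
print(weighted_margin_loss(lengths, np.array([0]), w))  # 0.0: all margins satisfied
```

Normalizing the weights to average 1 keeps the overall loss scale comparable to the unweighted version, so the learning rate does not need retuning.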
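Architecture
```python
import torch
import torch.nn as nn

# CapsuleLayer is assumed to come from the companion fundamentals document, with the
# signature CapsuleLayer(in_caps, in_dim, out_caps, out_dim, routing_iterations=3),
# mapping (B, in_caps, in_dim) -> (B, out_caps, out_dim) via dynamic routing.


class CardioCaps(nn.Module):
    """Simplified CardioCaps-style model for echocardiogram classification (sketch)."""

    def __init__(self, num_classes, imbalanced_ratio=1.0):
        super().__init__()
        # Feature extractor; the adaptive pool fixes the grid at 8x8 for any input size
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(8),
        )
        # Capsule layers: 8x8 positions x (128 channels / 8) = 1024 primary capsules of dim 8
        self.caps = nn.ModuleList([
            CapsuleLayer(8 * 8 * 16, 8, 16, 16, routing_iterations=3),
            CapsuleLayer(16, 16, num_classes, 16, routing_iterations=3),
        ])
        # Class weights for the loss; registered as a buffer so they follow .to(device)
        self.register_buffer("class_weights",
                             self._compute_weights(num_classes, imbalanced_ratio))

    @staticmethod
    def _compute_weights(num_classes, ratio):
        """Heuristic weights (1/(i+1))^ratio; assumes class 0 is the rarest class."""
        return (1.0 / torch.arange(1, num_classes + 1, dtype=torch.float32)) ** ratio

    def forward(self, x):
        x = self.encoder(x)                                    # (B, 128, 8, 8)
        # Group every 8 consecutive channels at each position into one capsule
        x = x.permute(0, 2, 3, 1).reshape(x.size(0), -1, 8)   # (B, 1024, 8)
        for layer in self.caps:
            x = layer(x)
        return x                                               # (B, num_classes, 16)
```
Retinal disease detection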
Hybrid architectures that combine a Vision Transformer with a Capsule Network have been used to detect diabetic retinopathy.
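Method
```python
import timm
import torch.nn as nn

# CapsuleLayer is assumed to come from the companion fundamentals document:
# CapsuleLayer(in_caps, in_dim, out_caps, out_dim) maps (B, in_caps, in_dim) -> (B, out_caps, out_dim).


class HybridViTCapsNet(nn.Module):
    """Hybrid ViT-CapsNet: ViT patch tokens serve as primary capsules."""

    def __init__(self, num_classes=5):
        super().__init__()
        # ViT feature extractor (224x224 input, 16x16 patches -> 14x14 = 196 patch tokens);
        # num_classes=0 removes the classification head
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        # Capsule layer: 196 tokens of dim 768 -> one 16-D capsule per class
        self.caps = CapsuleLayer(196, 768, num_classes, 16)

    def forward(self, x):
        # forward_features returns the token sequence (B, 1 + 196, 768), CLS token first
        tokens = self.vit.forward_features(x)
        # Drop the CLS token and treat each patch token as an input capsule
        caps = tokens[:, 1:, :]                       # (B, 196, 768)
        return self.caps(caps)                        # (B, num_classes, 16)
```
Brain tumor MRI classification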
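A hybrid ViT-CapsNet framework has also been applied to brain tumor diagnosis from MRI scans.
Data augmentation strategy
- Random flips: horizontal and vertical
- Random rotation: ±15°
- Color jitter: brightness, contrast, saturation
- Elastic deformation: simulates tissue deformation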
Text classification
Sentence capsule network
Capsule networks can also be applied to text classification:
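```python
import torch.nn as nn

# CapsuleLayer is assumed to come from the companion fundamentals document:
# CapsuleLayer(in_caps, in_dim, out_caps, out_dim) maps (B, in_caps, in_dim) -> (B, out_caps, out_dim).
# Routing uses one weight matrix per input capsule, so inputs are assumed padded to max_len tokens.


class TextCapsNet(nn.Module):
    """Text capsule network."""

    def __init__(self, vocab_size, embed_dim, num_classes, max_len, dim_caps=16):
        super().__init__()
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Convolutions keep the sequence length (kernel 3, padding 1)
        self.conv = nn.Sequential(
            nn.Conv1d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Capsule layers: 32 primary capsules (8-D) per position -> 8 capsules -> classes
        self.caps = nn.ModuleList([
            CapsuleLayer(max_len * 32, 8, 8, 8),
            CapsuleLayer(8, 8, num_classes, dim_caps),
        ])

    def forward(self, x):
        x = self.embedding(x)                   # (B, max_len, embed_dim)
        x = self.conv(x.transpose(1, 2))        # (B, 256, max_len)
        # Split the 256 channels at each position into 32 capsules of dim 8
        x = x.transpose(1, 2).reshape(x.size(0), -1, 8)   # (B, max_len * 32, 8)
        for cap_layer in self.caps:
            x = cap_layer(x)
        return x                                # (B, num_classes, dim_caps)
```
Sentiment analysis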
Dataset: IMDB movie reviews
Results:
- CNN accuracy: 88.5%
- CapsNet accuracy: 89.7%
Video action recognition
Temporal capsule network
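It helps to track what the spatial stem of the module below does to a clip of shape (B, 3, T, H, W). The (1, 7, 7) convolution with stride (1, 2, 2) and padding (0, 3, 3) halves H and W and leaves T unchanged. The (2, 1, 1) max-pool, whose stride defaults to its kernel size, then halves T. The short helper below applies the standard output-size formula. The 16-frame 112×112 clip is only an example input, and the helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one axis for Conv/MaxPool (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def spatial_stem_shape(T: int, H: int, W: int):
    """(T, H, W) after Conv3d(k=(1,7,7), s=(1,2,2), p=(0,3,3)) then MaxPool3d(k=(2,1,1))."""
    t, h, w = conv_out(T, 1), conv_out(H, 7, 2, 3), conv_out(W, 7, 2, 3)
    # MaxPool3d uses stride = kernel_size by default, so only the time axis shrinks
    return conv_out(t, 2, 2), conv_out(h, 1, 1), conv_out(w, 1, 1)


print(spatial_stem_shape(16, 112, 112))  # (8, 56, 56)
```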
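```python
import torch.nn as nn

# TemporalCapsuleLayer and DigitCapsule are assumed to come from the companion
# fundamentals document. TemporalCapsuleLayer is assumed to emit
# num_capsules * temporal_window capsules of size dim_capsule, and
# DigitCapsule(in_caps, in_dim, out_caps, out_dim) returns (capsules, coupling_coefficients).


class TemporalCapsNet(nn.Module):
    """Temporal capsule network for video action recognition."""

    def __init__(self, num_classes):
        super().__init__()
        # Spatial features: per-frame 7x7 conv (halves H and W), then pool frame pairs (halves T)
        self.spatial_cnn = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3)),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 1, 1)),
        )
        # Temporal capsules
        self.temporal_caps = TemporalCapsuleLayer(
            in_channels=64,
            num_capsules=16,
            dim_capsule=8,
            temporal_window=4,
        )
        # Class capsules: 16 capsules x 4 (temporal_window) 8-D inputs -> num_classes 16-D outputs
        self.class_caps = DigitCapsule(16 * 4, 8, num_classes, 16)

    def forward(self, x):
        # x: (B, C, T, H, W)
        x = self.spatial_cnn(x)          # (B, 64, T/2, H/2, W/2)
        x = self.temporal_caps(x)        # (B, 16 * 4, 8)
        x, _ = self.class_caps(x)        # (B, num_classes, 16)
        return x
```
Zero-shot learning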
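Zero-shot advantages of capsule networks
The vector representations of capsules are naturally suited to zero-shot learning, because:
- Pose encoding: the representation can generalize to poses not seen during training
- Attribute representation: the orientation of a capsule encodes visual attributes
- Compositionality: new classes can be described as combinations of known attributes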
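The compositionality argument can be made concrete with the classic attribute-matching recipe, in the spirit of direct attribute prediction. First predict an attribute vector from the image, then pick the class whose attribute signature is most similar. The NumPy sketch below uses made-up attributes and classes. It assumes the predicted attributes come from the capsule representation, and `zero_shot_predict` is a hypothetical name.

```python
import numpy as np


def zero_shot_predict(pred_attrs, class_attrs, candidate_ids):
    """Pick the candidate class whose attribute signature has the highest cosine similarity."""
    a = pred_attrs / np.linalg.norm(pred_attrs, axis=-1, keepdims=True)
    s = class_attrs / np.linalg.norm(class_attrs, axis=-1, keepdims=True)
    scores = a @ s[candidate_ids].T                  # (B, num_candidates)
    return np.asarray(candidate_ids)[scores.argmax(axis=1)]


# Hypothetical binary attributes: [has_stripes, has_hooves, is_black_and_white]
class_attrs = np.array([
    [0, 1, 0],   # 0: horse (seen)
    [1, 0, 0],   # 1: tiger (seen)
    [1, 1, 1],   # 2: zebra (unseen, composed from known attributes)
], dtype=float)

pred = np.array([[0.9, 0.8, 0.7]])   # attributes predicted from one image
print(zero_shot_predict(pred, class_attrs, candidate_ids=[0, 1, 2]))  # [2]
```

The zebra class never needs training images: its signature is built from attributes learned on horses and tigers, which is the compositionality point in the list above.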
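Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# VisualCapsEncoder is assumed to be defined elsewhere and to return a capsule
# embedding of shape (B, embed_dim). The linear attribute head/encoder and the
# cosine-similarity classifier below are illustrative choices.


class ZeroShotCapsNet(nn.Module):
    """Zero-shot capsule network (attribute based)."""

    def __init__(self, class_attributes, embed_dim=64):
        # class_attributes: (num_classes, num_attributes) for seen AND unseen classes
        super().__init__()
        num_attributes = class_attributes.size(1)
        # Visual capsule encoder
        self.visual_encoder = VisualCapsEncoder(embed_dim=embed_dim)
        # Attribute prediction from the capsule embedding
        self.attribute_head = nn.Linear(embed_dim, num_attributes)
        # Attribute encoder: maps each class's attribute vector into the embedding space
        self.attribute_encoder = nn.Linear(num_attributes, embed_dim)
        self.register_buffer("class_attributes", class_attributes.float())

    def classify_zero_shot(self, visual_caps):
        """Cosine similarity between image and class embeddings -> (B, num_classes)."""
        class_emb = self.attribute_encoder(self.class_attributes)
        return F.normalize(visual_caps, dim=-1) @ F.normalize(class_emb, dim=-1).T

    def forward(self, x, seen_mask=None):
        visual_caps = self.visual_encoder(x)                 # (B, embed_dim)
        predicted_attrs = self.attribute_head(visual_caps)   # (B, num_attributes)
        logits = self.classify_zero_shot(visual_caps)        # (B, num_classes)
        if seen_mask is not None:
            # seen_mask: (num_classes,) bool; score only the seen classes (e.g. during training)
            logits = logits.masked_fill(~seen_mask, float('-inf'))
        return logits, predicted_attrs
```
Robustness analysis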
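Adversarial robustness
Capsule Networks are more robust than CNNs to some types of adversarial attack:
| Attack | CNN accuracy under attack | CapsNet accuracy under attack |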
|---|---|---|
| FGSM | 81.2% | 87.5% |
| PGD | 76.3% | 83.1% |
| DeepFool | 79.5% | 85.2% |
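FGSM, the simplest attack in the table, perturbs each input by one step of size `eps` along the sign of the loss gradient. Below is a minimal PyTorch sketch. It assumes the model returns per-class scores (for a CapsNet, the capsule lengths). It uses a toy linear classifier and random data as stand-ins, and `eps=0.1` is an arbitrary example budget. The helper names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fgsm_attack(model: nn.Module, images: torch.Tensor, labels: torch.Tensor,
                eps: float = 0.1) -> torch.Tensor:
    """Fast Gradient Sign Method: one signed-gradient step of size eps."""
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    grad, = torch.autograd.grad(loss, images)
    adv = images + eps * grad.sign()
    return adv.clamp(0.0, 1.0).detach()        # keep pixels in the valid range


@torch.no_grad()
def accuracy(model, images, labels):
    return (model(images).argmax(dim=1) == labels).float().mean().item()


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))   # stand-in classifier
x = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))
x_adv = fgsm_attack(model, x, y, eps=0.1)
print(accuracy(model, x, y), accuracy(model, x_adv, y))
```

PGD repeats this step several times and projects back into the eps-ball after each step. That is why it is usually the stronger of the two attacks, consistent with the lower accuracies in its row.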
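Robustness to pose changes
Because capsules encode pose explicitly, capsule networks are expected to handle changes in input pose better than CNNs. The harness below measures this directly:
```python
import torch
import torchvision.transforms.functional as TF


class PoseRobustnessTest:
    """Pose-robustness test harness.

    Assumes model(images) returns per-class scores of shape (B, num_classes);
    wrap a CapsNet so that it returns capsule lengths rather than raw capsules.
    """

    def __init__(self, model):
        self.model = model.eval()

    @torch.no_grad()
    def _predict(self, images):
        return self.model(images).argmax(dim=1)

    @staticmethod
    def rotate_image(images, angle):
        return TF.rotate(images, angle)

    @staticmethod
    def scale_image(images, scale):
        # Zoom about the image center; the output size is unchanged
        return TF.affine(images, angle=0.0, translate=[0, 0], scale=scale, shear=[0.0])

    @staticmethod
    def occlude_image(images, ratio):
        # Zero a centered square covering roughly `ratio` of the image area
        out = images.clone()
        h, w = images.shape[-2:]
        oh, ow = int(h * ratio ** 0.5), int(w * ratio ** 0.5)
        top, left = (h - oh) // 2, (w - ow) // 2
        out[..., top:top + oh, left:left + ow] = 0
        return out

    def _consistency(self, batches):
        # Fraction of predictions that match the predictions on the first transform
        preds = torch.stack([self._predict(b) for b in batches])
        return (preds == preds[0]).float().mean()

    def test_rotation(self, images, angles):
        """Rotation invariance: prediction consistency across angles."""
        return self._consistency([self.rotate_image(images, a) for a in angles])

    def test_scale(self, images, scales):
        """Scale invariance: prediction consistency across scales."""
        return self._consistency([self.scale_image(images, s) for s in scales])

    def test_occlusion(self, images, labels, occlusion_ratios):
        """Accuracy per occlusion ratio; expectation: it drops more slowly as occlusion grows."""
        return [(self._predict(self.occlude_image(images, r)) == labels).float().mean().item()
                for r in occlusion_ratios]
```
Practical recommendations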
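1. Matching the architecture to the task
| Task type | Recommended architecture | Rationale |
|---|---|---|
| Image classification | Original CapsNet | Simple and effective |
| Medical imaging | IBCapsNet / CardioCaps | Robust to noise |
| Fine-grained classification | MSPCaps | Multi-scale features |
| Text classification | TextCapsNet | Sequence modeling |
| Graph data | PR-CapsNet | Graph-structure modeling |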
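2. Hyperparameter selection
```python
# Recommended configuration
config = {
    'num_routing_iterations': 3,      # MNIST: 3, CIFAR: 5
    'dim_capsule': 16,                # classification: 16, detection: 8
    'num_primary_capsules': 32,       # typical range 32-64
    'primary_capsule_dim': 8,         # typical range 8-16
    'reconstruction_weight': 0.0005,  # tune per task
    'margin_loss_margin': 0.9,        # m+: target length for the present class
    'margin_loss_m_minus': 0.1,       # m-: maximum length for absent classes
    'margin_loss_down_weight': 0.5,   # lambda: down-weights the absent-class term
}
```
3. Training tips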
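The first three tips that follow (cosine schedule, gradient clipping, label smoothing) each belong at a specific point in the training loop. Clipping in particular only takes effect between `backward()` and `step()`. The minimal sketch below shows where each one goes. A toy linear model, random tensors and `epochs = 2` stand in for a real CapsNet, data loader and schedule length (all assumptions).

```python
import torch
import torch.nn as nn

# Toy stand-ins (assumptions): replace with the real CapsNet, data loader and epoch count
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 2
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
batches = [(torch.rand(4, 3, 32, 32), torch.randint(0, 10, (4,))) for _ in range(3)]

for epoch in range(epochs):
    for images, labels in batches:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        # Clip after backward() (gradients now exist) and before step() (they get used)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    scheduler.step()   # one scheduler step per epoch, matching T_max=epochs
    print(epoch, scheduler.get_last_lr())
```

Stepping the cosine scheduler per epoch with `T_max=epochs` makes the learning rate reach `eta_min` exactly at the end of training. If you step per batch instead, set `T_max` to the total number of batches.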
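```python
import torch
import torch.nn as nn
from torchvision import transforms

# `model`, `optimizer` and `epochs` are assumed to be defined by your training script.

# 1. Learning-rate schedule: cosine annealing; call scheduler.step() once per epoch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs, eta_min=1e-6
)
# 2. Gradient clipping: call after loss.backward() and before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 3. Label smoothing (applies when training with cross-entropy instead of margin loss)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# 4. Data augmentation (CIFAR-style 32x32 inputs)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])
```
References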
Related: Capsule Network Fundamentals | Modern Capsule Architectures | Capsules vs. ViT comparison