Prototypical Networks
Prototypical Networks, proposed by Jake Snell et al. in 2017, are a metric-learning approach to few-shot learning.[^1] The core idea: learn a prototype representation for each class, then classify a query sample by its distance to each class prototype.
Core Idea
Difference from Siamese Networks
Siamese Network:
two input samples → shared encoder → similarity score
Prototypical Networks:
support set → one prototype per class → compare the query against the prototypes
Key Insight
Samples of the same class cluster together in the embedding space, while samples of different classes stay far apart.
Embedding space
┌─────────────────────────┐
│                         │
│    ●  ●                 │
│   ●  ●       ■ ■        │   ●: class-1 samples
│    ● ●      ■ ■ ■       │   ■: class-2 samples
│     ★                   │   ★: class-1 prototype
│              ☆          │   ☆: class-2 prototype
│                         │
└─────────────────────────┘
A query sample q that lies closer to ★ than to ☆ is classified as class 1.
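As a toy illustration of this rule (the prototype and query values below are made up, not from the paper), nearest-prototype classification in PyTorch reduces to an argmin over distances:

import torch

prototypes = torch.tensor([[0.0, 0.0], [4.0, 4.0]])  # class-1 and class-2 prototypes
q = torch.tensor([[0.5, 0.2]])                       # one query embedding
dists = torch.cdist(q, prototypes)                   # (1, 2) Euclidean distances
print(dists.argmin(dim=-1))                          # tensor([0]) -> class 1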
Mathematical Formulation
Problem Setup
Given an N-way K-shot task:
- Support set $S = \{(x_1, y_1), \dots, (x_{N \times K}, y_{N \times K})\}$ with labels $y_i \in \{1, \dots, N\}$
- Query set $Q = \{(\tilde{x}_1, \tilde{y}_1), \dots, (\tilde{x}_{N \times Q}, \tilde{y}_{N \times Q})\}$
- Encoder $f_\phi: \mathbb{R}^D \to \mathbb{R}^M$ with learnable parameters $\phi$
Prototype Computation
For each class $k$, the prototype is the mean of its embedded support samples:
$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)$$
where $S_k = \{(x_i, y_i) \in S : y_i = k\}$ is the set of support samples of class $k$.
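For instance, with $K = 2$ support embeddings for class 1, say $f_\phi(x_1) = (1, 2)$ and $f_\phi(x_2) = (3, 4)$ (made-up numbers), the prototype is simply their mean:
$$c_1 = \frac{1}{2}\big[(1, 2) + (3, 4)\big] = (2, 3)$$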
Classification Rule
Classification is a softmax over negative distances to the prototypes:
$$p_\phi(y = k \mid x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_{k'} \exp(-d(f_\phi(x), c_{k'}))}$$
where $d(\cdot, \cdot)$ is a distance function (the paper uses squared Euclidean distance).
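Continuing the toy numbers above (assumed for illustration): if a query has distances $d_1 = 1$ and $d_2 = 3$ to the two prototypes, then
$$p(y = 1 \mid x) = \frac{e^{-1}}{e^{-1} + e^{-3}} \approx 0.88$$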
Loss Function
Training minimizes the negative log-likelihood of the true class:
$$J(\phi) = -\log p_\phi(y = k \mid x)$$
which expands to:
$$J(\phi) = d(f_\phi(x), c_k) + \log \sum_{k'} \exp(-d(f_\phi(x), c_{k'}))$$
Minimizing $J(\phi)$ therefore pulls each query embedding toward its own class prototype while pushing it away from the other prototypes.
Why Euclidean Distance?
The paper argues for squared Euclidean distance: it is a Bregman divergence, for which the class mean is the optimal prototype. It is also closely related to cosine similarity. For any vectors $u, v$:
$$\|u - v\|_2^2 = \|u\|_2^2 + \|v\|_2^2 - 2\,u^\top v$$
If the embeddings are L2-normalized ($\|u\|_2 = \|v\|_2 = 1$), this reduces to $2 - 2\cos(u, v)$, so ranking classes by squared Euclidean distance is then equivalent to ranking them by cosine similarity.
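A quick numerical check of this identity (a standalone sketch; the vectors are random):

import torch
import torch.nn.functional as F

u = F.normalize(torch.randn(8), dim=0)   # unit-norm vectors
v = F.normalize(torch.randn(8), dim=0)
lhs = ((u - v) ** 2).sum()               # squared Euclidean distance
rhs = 2 - 2 * torch.dot(u, v)            # 2 - 2*cos(u, v) for unit vectors
print(torch.allclose(lhs, rhs))          # True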
Episodic Training
What Is an Episode?
An episode is one complete few-shot learning task:
Episode = one N-way K-shot problem
        = support set + query set
Training Procedure
# Episodic training pseudocode
for epoch in range(num_epochs):
    for _ in range(num_episodes):
        # 1. Sample N classes
        classes = sample_classes(num_classes=N)
        # 2. Sample K support + N_query query samples per class
        support_x, support_y = [], []
        query_x, query_y = [], []
        for label, c in enumerate(classes):
            samples = get_samples(c, num=K + N_query)
            support_x.append(samples[:K])
            support_y.extend([label] * K)
            query_x.append(samples[K:])
            query_y.extend([label] * N_query)
        # 3. Compute prototypes from the support set
        prototypes = compute_prototypes(support_x, support_y)
        # 4. Compute the loss on the query set
        query_preds = classify(query_x, prototypes)
        loss = cross_entropy(query_preds, query_y)
        # 5. Backpropagate to update the encoder
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Code Implementation
Basic Prototypical Network
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
import numpy as np
class PrototypicalNetworks(nn.Module):
    """
    Prototypical network.

    Args:
        encoder: feature encoder (e.g., a CNN or Transformer)
        distance: distance function ('euclidean' or 'cosine')
    """
    def __init__(
        self,
        encoder: nn.Module,
        distance: str = 'euclidean'
    ):
        super().__init__()
        self.encoder = encoder
        self.distance = distance
    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        way: int,
        shot: int,
        query: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            support_x: support images (N * K, C, H, W)
            support_y: support labels (N * K,)
            query_x: query images (N * Q, C, H, W)
            way: number of classes N
            shot: support samples per class K
            query: query samples per class Q

        Returns:
            log_probs: query log-probabilities (N * Q, N)
            prototypes: class prototypes (N, d)
            query_emb: query embeddings (N * Q, d)
        """
        # 1. Encode the support and query sets
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)
        # 2. Compute one prototype per class
        prototypes = self.compute_prototypes(support_emb, support_y, way)
        # 3. Compute query-to-prototype distances
        if self.distance == 'euclidean':
            dists = torch.cdist(query_emb, prototypes, p=2)
        else:
            # Cosine distance (requires normalization)
            query_emb = F.normalize(query_emb, p=2, dim=-1)
            prototypes_norm = F.normalize(prototypes, p=2, dim=-1)
            dists = 1 - torch.mm(query_emb, prototypes_norm.t())
        # 4. Softmax over negative distances
        log_probs = F.log_softmax(-dists, dim=-1)
        return log_probs, prototypes, query_emb
    def compute_prototypes(
        self,
        support_emb: torch.Tensor,
        support_y: torch.Tensor,
        way: int
    ) -> torch.Tensor:
        """
        Compute the prototype of each class.

        Args:
            support_emb: support embeddings (N * K, d)
            support_y: support labels (N * K,)
            way: number of classes

        Returns:
            prototypes: class prototypes (N, d)
        """
        classes = torch.arange(way, device=support_emb.device)
        prototypes = torch.stack([
            support_emb[support_y == c].mean(0) for c in classes
        ])
        return prototypes
    def loss(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        query_y: torch.Tensor,
        way: int,
        shot: int
    ) -> torch.Tensor:
        """
        Compute the loss for one episode.
        """
        log_probs, _, _ = self.forward(
            support_x, support_y, query_x, way, shot
        )
        return F.nll_loss(log_probs, query_y)
class ConvEncoder(nn.Module):
    """
    4-block convolutional encoder (suitable for Omniglot and Mini-ImageNet).
    """
    def __init__(self, in_channels=1, hid_dim=64, out_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            # Block 1: 28x28 -> 14x14 (Omniglot)
            nn.Conv2d(in_channels, hid_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hid_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2: 14x14 -> 7x7
            nn.Conv2d(hid_dim, hid_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hid_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 3: 7x7 -> 3x3
            nn.Conv2d(hid_dim, hid_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hid_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 4: 3x3 -> 1x1
            nn.Conv2d(hid_dim, out_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        # Optional: L2-normalize the output embedding
        self.use_norm = True

    def forward(self, x):
        x = self.net(x)
        x = x.view(x.size(0), -1)
        if self.use_norm:
            x = F.normalize(x, p=2, dim=-1)
        return x
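A quick shape check of the encoder (a standalone sketch with random input):

encoder = ConvEncoder(in_channels=1, hid_dim=64, out_dim=64)
out = encoder(torch.randn(4, 1, 28, 28))  # a batch of 4 Omniglot-sized images
print(out.shape)                          # torch.Size([4, 64]); rows are unit-norm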
Episode Sampler
class EpisodeSampler:
    """
    Episode sampler.

    Samples training/evaluation episodes from a labeled dataset.
    """
    def __init__(
        self,
        labels: np.ndarray,
        way: int = 5,
        shot: int = 1,
        query: int = 15,
        num_episodes: int = 100
    ):
        self.labels = labels
        self.way = way
        self.shot = shot
        self.query = query
        self.num_episodes = num_episodes
        # Group sample indices by label
        self.class_to_indices = {}
        for idx, label in enumerate(labels):
            if label not in self.class_to_indices:
                self.class_to_indices[label] = []
            self.class_to_indices[label].append(idx)
    def __iter__(self):
        # Reset the episode counter so the sampler can be iterated repeatedly
        self._remaining = self.num_episodes
        return self

    def __next__(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Return the index arrays for one episode.
        """
        if self._remaining <= 0:
            raise StopIteration
        self._remaining -= 1
        # 1. Randomly pick `way` classes
        selected_classes = np.random.choice(
            list(self.class_to_indices.keys()),
            size=self.way,
            replace=False
        )
        support_x, support_y = [], []
        query_x, query_y = [], []
        for label_idx, class_label in enumerate(selected_classes):
            indices = self.class_to_indices[class_label]
            # 2. Sample support and query indices from each class
            sampled = np.random.choice(
                indices,
                size=self.shot + self.query,
                replace=False
            )
            support_x.extend(sampled[:self.shot])
            support_y.extend([label_idx] * self.shot)
            query_x.extend(sampled[self.shot:])
            query_y.extend([label_idx] * self.query)
        # Shuffle within the support and query sets
        perm_support = np.random.permutation(len(support_x))
        perm_query = np.random.permutation(len(query_x))
return (
np.array(support_x)[perm_support],
np.array(support_y)[perm_support],
np.array(query_x)[perm_query],
np.array(query_y)[perm_query]
)
    def __len__(self):
        return self.num_episodes
Training Loop
def train_proto_net():
    """Prototypical network training example."""
    # Hyperparameters
    WAY = 5
    SHOT = 5
    QUERY = 15
    NUM_TASKS = 1000
    EPOCHS = 50
    # Model
    encoder = ConvEncoder(in_channels=1, hid_dim=64, out_dim=64)
    model = PrototypicalNetworks(encoder, distance='euclidean')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    # Training
    for epoch in range(EPOCHS):
        total_loss = 0.0
        total_acc = 0.0
        sampler = EpisodeSampler(
            labels=np.repeat(np.arange(100), SHOT + QUERY),  # 100 classes, 20 samples each
            way=WAY,
            shot=SHOT,
            query=QUERY,
            num_episodes=NUM_TASKS
        )
        for support_idx, support_y, query_idx, query_y in sampler:
            # Load the data (random tensors stand in for real images here)
            support_x = torch.randn(len(support_idx), 1, 28, 28)
            query_x = torch.randn(len(query_idx), 1, 28, 28)
            support_y = torch.LongTensor(support_y)
            query_y = torch.LongTensor(query_y)
            # Forward
            log_probs, _, _ = model(support_x, support_y, query_x, WAY, SHOT, QUERY)
            # Loss
            loss = F.nll_loss(log_probs, query_y)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Statistics
            preds = log_probs.argmax(dim=-1)
            acc = (preds == query_y).float().mean()
            total_loss += loss.item()
            total_acc += acc.item()
        scheduler.step()
        print(f"Epoch {epoch}: Loss={total_loss/NUM_TASKS:.4f}, "
              f"Acc={total_acc/NUM_TASKS:.4f}")
Comparison with Other Methods
Method Comparison
| Method | Classification mechanism | Characteristics |
|---|---|---|
| Siamese Network | pairwise similarity | simple, but must compare every pair |
| Matching Networks | attention-weighted matching | weights the support set, interpretable |
| Prototypical Networks | distance to class prototypes | efficient, easy to train |
| Relation Network | learned relation module | most flexible, but needs more data |
Advantages of Prototypical Networks
- Computationally efficient: only the N class prototypes are compared, not every support pair
- Strong regularization: averaging within each class reduces noise
- Easy to train: plain episodic training is enough
- Generalizes well: the prototype representation is robust to outliers
Extension: Semi-Supervised Prototypical Networks
Transductive Setting
Use the unlabeled query samples to help classification:

def semi_supervised_proto_net(support_x, support_y, query_x, way, shot):
    """
    Semi-supervised prototypical network (transductive setting).

    1. Compute initial prototypes from the support set
    2. Refine the prototypes with soft-assigned query samples
    3. Re-classify

    Here support_x and query_x are embeddings (already encoded);
    compute_prototypes is assumed to be a module-level helper with
    the same logic as PrototypicalNetworks.compute_prototypes.
    """
    # Step 1: initial prototypes
    initial_prototypes = compute_prototypes(support_x, support_y, way)
    # Step 2: soft-assign the query samples to the prototypes
    dists = torch.cdist(query_x, initial_prototypes)
    probs = F.softmax(-dists, dim=-1)
    # Step 3: refine the prototypes (support samples + soft-weighted query samples)
    updated_prototypes = []
    for c in range(way):
        support_mask = (support_y == c)
        prototype = (
            support_x[support_mask].sum(0) +
            (probs[:, c:c+1] * query_x).sum(0)
        ) / (shot + probs[:, c].sum())
        updated_prototypes.append(prototype)
    return torch.stack(updated_prototypes)
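A toy invocation (assumed shapes: 5-way 5-shot with 2-D embeddings, and the module-level compute_prototypes helper noted above):

support_x = torch.randn(25, 2)                    # 5 classes x 5 shots, 2-D embeddings
support_y = torch.arange(5).repeat_interleave(5)  # labels 0..4, five each
query_x = torch.randn(15, 2)                      # unlabeled query embeddings
refined = semi_supervised_proto_net(support_x, support_y, query_x, way=5, shot=5)
print(refined.shape)                              # torch.Size([5, 2])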
References
[^1]: Snell, J., Swersky, K., & Zemel, R. (2017). "Prototypical Networks for Few-shot Learning". Advances in Neural Information Processing Systems (NeurIPS).