Prototypical Networks

Prototypical Networks, proposed by Jake Snell et al. in 2017, are a few-shot learning method based on metric learning.1 The core idea: learn a prototype representation for each class, then classify each query sample by its distance to the class prototypes.

Core Idea

Differences from Siamese Networks

Siamese Network:
Input two samples → shared encoder → output a similarity score

Prototypical Networks:
Support set → compute one prototype per class → compare the query against the prototypes

Key Insight

Samples of the same class cluster together in the embedding space, while samples of different classes stay far apart.

           Embedding space
    ┌─────────────────────────┐
    │                         │
    │    ● ●                  │
    │   ● ★ ●    ■ ■          │  ●: class-1 samples
    │    ● ●     ■ ☆ ■        │  ■: class-2 samples
    │              ■          │  ★: class-1 prototype
    │                         │  ☆: class-2 prototype
    └─────────────────────────┘

    A query sample q that lies closer to ★ → classified as class 1

Mathematical Formulation

Problem Setup

Given an N-way K-shot task:

  • Support set $S = \{(x_i, y_i)\}_{i=1}^{N \times K}$: $K$ labeled examples from each of $N$ classes
  • Query set $Q$: the unlabeled samples to classify
  • Encoder $f_\phi$: a learned embedding function mapping each input into $\mathbb{R}^M$

Prototype Computation

For each class $k$, compute its prototype as the mean of the embedded support samples:

$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)$$

where $S_k$ is the set of support samples belonging to class $k$.
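
The prototype computation above can be sketched in a few lines of PyTorch; the 5-way 2-shot embeddings here are random and purely illustrative:

```python
import torch

torch.manual_seed(0)
# Hypothetical 5-way 2-shot support embeddings: 10 vectors of dimension 4
support_emb = torch.randn(10, 4)
support_y = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])

# Each prototype is the mean of its class's support embeddings
prototypes = torch.stack([
    support_emb[support_y == c].mean(0) for c in range(5)
])
print(prototypes.shape)  # torch.Size([5, 4])
```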

Classification

Classification applies a softmax over the negative distances to the prototypes:

$$p_\phi(y = k \mid x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_{k'} \exp(-d(f_\phi(x), c_{k'}))}$$

where $d(\cdot, \cdot)$ is a distance function (typically the squared Euclidean distance).
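
The softmax-over-negative-distances rule can be sketched as follows (random embeddings, for illustration only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
query_emb = torch.randn(3, 4)    # 3 query embeddings
prototypes = torch.randn(5, 4)   # 5 class prototypes

dists = torch.cdist(query_emb, prototypes, p=2)  # (3, 5) pairwise distances
log_probs = F.log_softmax(-dists, dim=-1)        # closer prototype -> higher probability
preds = log_probs.argmax(dim=-1)                 # predicted class per query
```

Note that the argmax of the log-probabilities is exactly the nearest prototype, since the softmax is monotone in the negated distance.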

Loss Function

Training minimizes the negative log-likelihood of the true class $k$ for each query sample $x$:

$$J(\phi) = -\log p_\phi(y = k \mid x)$$

which expands to:

$$J(\phi) = d(f_\phi(x), c_k) + \log \sum_{k'} \exp\bigl(-d(f_\phi(x), c_{k'})\bigr)$$

Why Euclidean Distance?

The paper uses the squared Euclidean distance inside the exponent. For unit-norm vectors, squared Euclidean distance and cosine similarity are directly related:

$$\|u - v\|^2 = \|u\|^2 + \|v\|^2 - 2\langle u, v \rangle = 2 - 2\cos(u, v) \quad \text{for } \|u\| = \|v\| = 1$$

So if the embeddings are L2-normalized, ranking classes by squared Euclidean distance is equivalent to ranking them by cosine similarity.
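
A quick numeric check of this identity for unit vectors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
u = F.normalize(torch.randn(8), dim=0)  # unit-norm embeddings
v = F.normalize(torch.randn(8), dim=0)

sq_euclid = ((u - v) ** 2).sum()
cosine = (u * v).sum()

# ||u - v||^2 = 2 - 2 cos(u, v) when ||u|| = ||v|| = 1
assert torch.allclose(sq_euclid, 2 - 2 * cosine, atol=1e-5)
```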


Episode-Based Training

What Is an Episode?

An episode is one complete few-shot learning task:

Episode = one N-way K-shot problem
        = support set + query set
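
For example, the number of samples in a 5-way 1-shot episode with 15 queries per class works out as:

```python
N, K, Q = 5, 1, 15                # 5-way 1-shot, 15 queries per class

support_size = N * K              # 5 labeled samples
query_size = N * Q                # 75 samples to classify
episode_size = support_size + query_size
print(episode_size)  # 80
```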

Training Procedure

# Episode training pseudocode
for epoch in range(num_epochs):
    for _ in range(num_episodes):
        # 1. Sample N classes
        classes = sample_classes(num_classes=N)

        # 2. Sample K support + N_query query samples per class
        support_x, support_y = [], []
        query_x, query_y = [], []
        for label, c in enumerate(classes):
            samples = get_samples(c, num=K + N_query)
            support_x.extend(samples[:K])
            support_y.extend([label] * K)
            query_x.extend(samples[K:])
            query_y.extend([label] * N_query)

        # 3. Compute class prototypes
        prototypes = compute_prototypes(support_x, support_y)

        # 4. Compute the loss on the query set
        query_preds = classify(query_x, prototypes)
        loss = cross_entropy(query_preds, query_y)

        # 5. Backpropagate to update the encoder
        loss.backward()

Code Implementation

Basic Prototypical Network

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
import numpy as np
 
class PrototypicalNetworks(nn.Module):
    """
    Prototypical Network.

    Args:
        encoder: feature encoder (CNN, Transformer, etc.)
        distance: distance function ('euclidean' or 'cosine')
    """
    def __init__(
        self,
        encoder: nn.Module,
        distance: str = 'euclidean'
    ):
        super().__init__()
        self.encoder = encoder
        self.distance = distance
    
    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        way: int,
        shot: int,
        query: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            support_x: support images (N * K, C, H, W)
            support_y: support labels (N * K,)
            query_x: query images (N * Q, C, H, W)
            way: number of classes N
            shot: support samples per class K (kept for call-site clarity)
            query: query samples per class Q (kept for call-site clarity)

        Returns:
            log_probs: per-class log-probabilities for the query set (N * Q, N)
            prototypes: class prototypes (N, d)
            query_emb: query embeddings (N * Q, d)
        """
        # 1. Encode the support and query sets
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)
        
        # 2. Compute class prototypes
        prototypes = self.compute_prototypes(support_emb, support_y, way)
        
        # 3. Distances from queries to prototypes
        if self.distance == 'euclidean':
            dists = torch.cdist(query_emb, prototypes, p=2)
        else:
            # Cosine distance (requires normalization)
            query_emb = F.normalize(query_emb, p=2, dim=-1)
            prototypes_norm = F.normalize(prototypes, p=2, dim=-1)
            dists = 1 - torch.mm(query_emb, prototypes_norm.t())
        
        # 4. Softmax classification: smaller distance -> higher probability
        log_probs = F.log_softmax(-dists, dim=-1)
        
        return log_probs, prototypes, query_emb
    
    def compute_prototypes(
        self,
        support_emb: torch.Tensor,
        support_y: torch.Tensor,
        way: int
    ) -> torch.Tensor:
        """
        Compute one prototype per class.

        Args:
            support_emb: support embeddings (N * K, d)
            support_y: support labels (N * K,)
            way: number of classes

        Returns:
            prototypes: class prototypes (N, d)
        """
        classes = torch.arange(way, device=support_emb.device)
        prototypes = torch.stack([
            support_emb[support_y == c].mean(0) for c in classes
        ])
        return prototypes
    
    def loss(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        query_y: torch.Tensor,
        way: int,
        shot: int
    ) -> torch.Tensor:
        """
        Compute the episode loss (negative log-likelihood on the query set).
        """
        log_probs, _, _ = self.forward(
            support_x, support_y, query_x, way, shot
        )
        return F.nll_loss(log_probs, query_y)
 
 
class ConvEncoder(nn.Module):
    """
    4-layer convolutional encoder (for Omniglot and Mini-ImageNet).
    """
    def __init__(self, in_channels=1, hid_dim=64, out_dim=64):
        super().__init__()
        
        self.net = nn.Sequential(
            # Block 1: 28x28 -> 14x14 (Omniglot)
            nn.Conv2d(in_channels, hid_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hid_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Block 2: 14x14 -> 7x7
            nn.Conv2d(hid_dim, hid_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hid_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Block 3: 7x7 -> 3x3
            nn.Conv2d(hid_dim, hid_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hid_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Block 4: 3x3 -> 1x1
            nn.Conv2d(hid_dim, out_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Optional: L2-normalize the output embedding
        self.use_norm = True
    
    def forward(self, x):
        x = self.net(x)
        x = x.view(x.size(0), -1)
        if self.use_norm:
            x = F.normalize(x, p=2, dim=-1)
        return x
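
A quick shape check of the encoder interface. To keep the snippet self-contained, this uses a minimal stand-in network rather than the full ConvEncoder above, with the same contract: an image batch in, a (batch, embedding_dim) tensor out.

```python
import torch
import torch.nn as nn

# Minimal stand-in with the same interface as the encoder above
encoder = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # global average pool to 1x1
    nn.Flatten(),             # (batch, 64, 1, 1) -> (batch, 64)
)

x = torch.randn(4, 1, 28, 28)  # a batch of 4 Omniglot-sized images
emb = encoder(x)
print(emb.shape)  # torch.Size([4, 64])
```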

Episode Sampler

class EpisodeSampler:
    """
    Episode sampler.

    Samples training/evaluation episodes from a labeled dataset.
    """
    def __init__(
        self,
        labels: np.ndarray,
        way: int = 5,
        shot: int = 1,
        query: int = 15,
        num_episodes: int = 100
    ):
        self.labels = labels
        self.way = way
        self.shot = shot
        self.query = query
        self.num_episodes = num_episodes
        
        # Group sample indices by label
        self.class_to_indices = {}
        for idx, label in enumerate(labels):
            self.class_to_indices.setdefault(label, []).append(idx)
    
    def __iter__(self):
        self._emitted = 0
        return self
    
    def __next__(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Return the sample indices and labels for one episode.
        """
        if self._emitted >= self.num_episodes:
            raise StopIteration
        self._emitted += 1
        
        # 1. Randomly pick `way` classes
        selected_classes = np.random.choice(
            list(self.class_to_indices.keys()),
            size=self.way,
            replace=False
        )
        
        support_x, support_y = [], []
        query_x, query_y = [], []
        
        for label_idx, class_label in enumerate(selected_classes):
            indices = self.class_to_indices[class_label]
            
            # 2. Sample support and query indices from each class
            sampled = np.random.choice(
                indices,
                size=self.shot + self.query,
                replace=False
            )
            
            support_x.extend(sampled[:self.shot])
            support_y.extend([label_idx] * self.shot)
            
            query_x.extend(sampled[self.shot:])
            query_y.extend([label_idx] * self.query)
        
        # Shuffle within the support and query sets
        perm_support = np.random.permutation(len(support_x))
        perm_query = np.random.permutation(len(query_x))
        
        return (
            np.array(support_x)[perm_support],
            np.array(support_y)[perm_support],
            np.array(query_x)[perm_query],
            np.array(query_y)[perm_query]
        )
    
    def __len__(self):
        return self.num_episodes

Training Loop

def train_proto_net():
    """Prototypical network training example."""
    # Hyperparameters
    WAY = 5
    SHOT = 5
    QUERY = 15
    NUM_TASKS = 1000
    EPOCHS = 50
    
    # Model
    encoder = ConvEncoder(in_channels=1, hid_dim=64, out_dim=64)
    model = PrototypicalNetworks(encoder, distance='euclidean')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    
    # Training
    for epoch in range(EPOCHS):
        total_loss = 0.0
        total_acc = 0.0
        
        sampler = EpisodeSampler(
            # 100 classes with SHOT + QUERY dummy samples each
            labels=np.repeat(np.arange(100), SHOT + QUERY),
            way=WAY,
            shot=SHOT,
            query=QUERY,
            num_episodes=NUM_TASKS
        )
        
        for support_idx, support_y, query_idx, query_y in sampler:
            # Load data (random tensors stand in for real images here)
            support_x = torch.randn(len(support_idx), 1, 28, 28)
            query_x = torch.randn(len(query_idx), 1, 28, 28)
            support_y = torch.LongTensor(support_y)
            query_y = torch.LongTensor(query_y)
            
            # Forward
            log_probs, _, _ = model(support_x, support_y, query_x, WAY, SHOT, QUERY)
            
            # Loss
            loss = F.nll_loss(log_probs, query_y)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Statistics
            preds = log_probs.argmax(dim=-1)
            acc = (preds == query_y).float().mean()
            
            total_loss += loss.item()
            total_acc += acc.item()
        
        scheduler.step()
        
        print(f"Epoch {epoch}: Loss={total_loss/NUM_TASKS:.4f}, "
              f"Acc={total_acc/NUM_TASKS:.4f}")

Comparison with Other Methods

Method Comparison

| Method                | Classification mechanism     | Characteristics                        |
|-----------------------|------------------------------|----------------------------------------|
| Siamese Network       | Pairwise similarity          | Simple, but must compare all pairs     |
| Matching Networks     | Attention-weighted matching  | Weights the support set; interpretable |
| Prototypical Networks | Distance to class prototypes | Efficient and easy to train            |
| Relation Network      | Learned relation module      | Most flexible, but needs more data     |

Advantages of Prototypical Networks

  1. Computationally efficient: only N class prototypes are computed and compared, rather than all sample pairs
  2. Strong regularization: averaging within-class samples reduces noise
  3. Easy to train: plain episodic training suffices
  4. Good generalization: the prototype representation is robust to outliers

Extension: Semi-Supervised Prototypical Networks

Transductive Setting

Use the unlabeled query samples to help classify:

def semi_supervised_proto_net(support_emb, support_y, query_emb, way, shot):
    """
    Semi-supervised prototypical network (transductive setting).

    1. Compute initial prototypes from the support set.
    2. Soft-assign the query samples and fold them into the prototypes.
    3. Reclassify against the updated prototypes.

    All inputs are embeddings, i.e. the encoder has already been applied.
    """
    # Step 1: initial prototypes
    initial_prototypes = torch.stack([
        support_emb[support_y == c].mean(0) for c in range(way)
    ])
    
    # Step 2: soft-assign query samples to classes
    dists = torch.cdist(query_emb, initial_prototypes, p=2)
    probs = F.softmax(-dists, dim=-1)  # (num_query, way)
    
    # Step 3: update the prototypes with the soft query contributions
    updated_prototypes = []
    for c in range(way):
        # Support samples + softly weighted query samples
        support_mask = (support_y == c)
        prototype = (
            support_emb[support_mask].sum(0) + 
            (probs[:, c:c + 1] * query_emb).sum(0)
        ) / (shot + probs[:, c].sum())
        updated_prototypes.append(prototype)
    
    return torch.stack(updated_prototypes)
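
The soft prototype update can also be written in vectorized form. A self-contained sketch with random embeddings (the shapes are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
way, shot, dim = 3, 2, 8
support_emb = torch.randn(way * shot, dim)
support_y = torch.arange(way).repeat_interleave(shot)  # [0, 0, 1, 1, 2, 2]
query_emb = torch.randn(10, dim)

# Initial prototypes from the support set
protos = torch.stack([support_emb[support_y == c].mean(0) for c in range(way)])

# Soft-assign queries, then fold them into the prototypes
probs = F.softmax(-torch.cdist(query_emb, protos, p=2), dim=-1)  # (10, way)
numer = torch.stack([support_emb[support_y == c].sum(0) for c in range(way)])
numer = numer + probs.t() @ query_emb                            # (way, dim)
denom = shot + probs.sum(0, keepdim=True).t()                    # (way, 1)
updated_protos = numer / denom
```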


Footnotes

  1. Snell, J., Swersky, K., & Zemel, R. (2017). “Prototypical Networks for Few-shot Learning”. Advances in Neural Information Processing Systems (NeurIPS).