匹配网络

匹配网络(Matching Networks)由Oriol Vinyals等人于2016年提出,是最早将注意力机制引入少样本学习的元学习方法之一。1 与原型网络不同,匹配网络为每个支持样本分配注意力权重,而非仅依赖类原型。

核心思想

与原型网络的关键区别

原型网络:
同类样本 → 平均 → 类原型 → 与Query比较

匹配网络:
每个Support样本 → 计算注意力权重 → 加权组合 → 与Query比较

注意力机制

其中:

  • 是Query样本
  • 是支持集的特征和标签
  • 是注意力函数

数学推导

任务定义

给定支持集 和查询样本

  1. 编码支持集,称为支持集编码器
  2. 编码查询集,称为查询集编码器
  3. 注意力计算

注意力函数

Cosine相似度注意力

MLP注意力(Relation Network)

分类输出

注意:这里 是 one-hot 编码的标签向量。


完全注意力记忆(FOCAL Attention)

支持集作为外部记忆

匹配网络可以看作一个记忆增强神经网络(Memory-Augmented Neural Network):

┌─────────────────────────────────────────────────────────┐
│                    查询样本 x̂                             │
│                         ↓                                │
│                   查询编码 f(x̂)                           │
│                         ↓                                │
│              ┌──────────────────────┐                    │
│              │    注意力机制         │                    │
│              │  a(f(x̂), g(x₁))      │                    │
│              │  a(f(x̂), g(x₂))  ... │                    │
│              └──────────────────────┘                    │
│                         ↓                                │
│              加权组合 → 预测标签 ŷ                          │
│                         ↓                                │
│              支持集 S = {(xᵢ, yᵢ)} 作为外部记忆            │
└─────────────────────────────────────────────────────────┘

与LSTM记忆的区别

组件Neural Turing MachineMatching Networks
记忆可学习的内部记忆固定的支持集
读取内容寻址 + 位置寻址仅内容寻址
写入可修改不可修改

代码实现

基础匹配网络

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class MatchingNetwork(nn.Module):
    """
    匹配网络
    
    支持集编码器 g 和查询集编码器 f 可以是独立的或共享的
    """
    def __init__(
        self,
        encoder: nn.Module,
        attention: str = 'cosine',
        metric: str = 'cosine'
    ):
        super().__init__()
        self.encoder = encoder
        self.attention = attention
        self.metric = metric
        
        if attention == 'cosine':
            self.cos = nn.CosineSimilarity(dim=-1, eps=1e-7)
    
    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        way: int,
        shot: int
    ) -> torch.Tensor:
        """
        Args:
            support_x: 支持集特征 (N*K, feature_dim)
            support_y: 支持集标签 (N*K,) - 整数标签
            query_x: 查询集特征 (N*Q, feature_dim)
            way: 类别数
            shot: 每类样本数
        
        Returns:
            query_logits: 查询集预测 logits
        """
        # 编码
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)
        
        # 计算注意力(支持集对查询集)
        if self.attention == 'cosine':
            # 计算余弦相似度矩阵
            # query_emb: (Q, d), support_emb: (S, d)
            attn = self.cosine_attention(query_emb, support_emb)
        elif self.attention == 'dot':
            attn = torch.mm(query_emb, support_emb.t())
            attn = F.softmax(attn, dim=-1)
        else:
            raise ValueError(f"Unknown attention: {self.attention}")
        
        # 将标签转为one-hot
        support_y_onehot = F.one_hot(support_y, num_classes=way).float()
        
        # 加权求和
        predictions = torch.mm(attn, support_y_onehot)
        
        # 归一化(确保概率和为1)
        predictions = predictions / (predictions.sum(dim=-1, keepdim=True) + 1e-8)
        
        return torch.log(predictions + 1e-8)
    
    def cosine_attention(self, query_emb: torch.Tensor, support_emb: torch.Tensor) -> torch.Tensor:
        """
        计算余弦注意力矩阵
        
        Args:
            query_emb: (Q, d)
            support_emb: (S, d)
        
        Returns:
            attn: (Q, S)
        """
        # 余弦相似度
        cos_sim = self.cos(query_emb.unsqueeze(1), support_emb.unsqueeze(0))
        
        # Softmax归一化
        attn = F.softmax(cos_sim, dim=-1)
        
        return attn
 
 
class BidirectionalLSTMEncoder(nn.Module):
    """
    双向LSTM编码器
    
    论文中提出的支持集编码器 g
    
    特点:
    - 将整个支持集作为一个序列
    - 使用BiLSTM编码每个样本的上下文
    """
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )
        # 映射到输出维度
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, in_dim) - 视为序列
        
        Returns:
            out: (batch, seq_len, hidden_dim)
        """
        lstm_out, _ = self.lstm(x)
        return self.fc(lstm_out)
 
 
class FullContextEmbedding(nn.Module):
    """
    全上下文嵌入
    
    使用BiLSTM对支持集进行编码,加入样本间的上下文信息
    """
    def __init__(self, in_dim: int, embed_dim: int):
        super().__init__()
        self.encoder = BidirectionalLSTMEncoder(in_dim, embed_dim // 2)
        self.out_proj = nn.Linear(embed_dim // 2, embed_dim)
    
    def forward(self, support_x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            support_x: (N*K, feature_dim) 或 (batch, N*K, feature_dim)
        
        Returns:
            embeddings: 全上下文嵌入
        """
        # 重新reshape为序列
        if support_x.dim() == 2:
            support_x = support_x.unsqueeze(0)
            squeeze = True
        else:
            squeeze = False
        
        batch_size, seq_len, feature_dim = support_x.shape
        
        # BiLSTM编码
        embeddings = self.encoder(support_x)  # (batch, seq, embed_dim/2)
        
        # 最终嵌入 = 平均 + LSTMOUT的最后状态(简化)
        final_embed = embeddings.mean(1)  # (batch, embed_dim/2)
        
        # 广播到每个位置
        embeddings = embeddings + final_embed.unsqueeze(1)
        embeddings = self.out_proj(embeddings)
        
        if squeeze:
            embeddings = embeddings.squeeze(0)
        
        return embeddings
 
 
class SimpleEncoder(nn.Module):
    """
    简单CNN编码器(用于图像)
    """
    def __init__(self, in_channels=1, out_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, out_dim)
        )
    
    def forward(self, x):
        return self.net(x)
 
 
class RelationNetworkAttention(nn.Module):
    """
    Relation Network注意力
    
    使用MLP计算样本间的关系分数
    """
    def __init__(self, embed_dim: int, hidden_dim: int = 8):
        super().__init__()
        self.relation_module = nn.Sequential(
            nn.Linear(embed_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, query_emb: torch.Tensor, support_emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query_emb: (Q, d)
            support_emb: (S, d)
        
        Returns:
            attn: (Q, S)
        """
        # 构造所有(query, support)对
        query_expand = query_emb.unsqueeze(1).expand(-1, support_emb.size(0), -1)
        support_expand = support_emb.unsqueeze(0).expand(query_emb.size(0), -1, -1)
        
        # 拼接
        pairs = torch.cat([query_expand, support_expand], dim=-1)
        
        # 关系分数
        relations = self.relation_module(pairs).squeeze(-1)
        
        # Softmax
        return F.softmax(relations, dim=-1)

Episode训练

def train_matching_net():
    """匹配网络训练"""
    import numpy as np
    
    WAY = 5
    SHOT = 1
    QUERY = 15
    EMBED_DIM = 64
    EPOCHS = 100
    
    # 模型
    encoder = SimpleEncoder(out_dim=EMBED_DIM)
    model = MatchingNetwork(encoder, attention='cosine')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # 训练循环
    for epoch in range(EPOCHS):
        # 采样Episode(这里用随机数据模拟)
        support_x = torch.randn(WAY * SHOT, 1, 28, 28)
        support_y = torch.LongTensor([i // SHOT for i in range(WAY * SHOT)])
        query_x = torch.randn(WAY * QUERY, 1, 28, 28)
        query_y = torch.LongTensor([i // QUERY for i in range(WAY * QUERY)])
        
        # 前向
        log_probs = model(support_x, support_y, query_x, WAY, SHOT)
        
        # 损失
        loss = F.nll_loss(log_probs, query_y)
        
        # 反向
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 统计
        preds = log_probs.argmax(dim=-1)
        acc = (preds == query_y).float().mean()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss={loss.item():.4f}, Acc={acc:.4f}")
 
 
if __name__ == '__main__':
    train_matching_net()

匹配网络 vs 原型网络

核心区别

方面匹配网络原型网络
表示方式每个样本独立类内平均
注意力所有支持样本加权仅类原型
计算复杂度
表达能力更强更简洁

数学对比

匹配网络:
ŷ = Σᵢ a(x̂, xᵢ) · yᵢ

原型网络(匹配网络的特例):
cₖ = (1/K) Σᵢ₌₁ᴷ xᵢ  for yᵢ=k
ŷ = Σₖ a(x̂, cₖ) · yₖ

原型网络可以看作匹配网络的一种硬注意力变体。


扩展:Transductive Matching Networks

转导设置

利用查询集样本帮助分类:

class TransductiveMatchingNet(nn.Module):
    """
    转导匹配网络
    
    使用查询集样本更新注意力权重
    """
    def __init__(self, encoder, way, shot, query):
        super().__init__()
        self.encoder = encoder
        self.way = way
        self.shot = shot
        self.query = query
    
    def forward(self, support_x, support_y, query_x, num_iterations=5):
        """
        转导推理
        """
        # 初始编码
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)
        
        # 标签one-hot
        support_y_onehot = F.one_hot(support_y, num_classes=self.way).float()
        
        # 迭代更新
        for _ in range(num_iterations):
            # 计算相似度
            sim = torch.mm(query_emb, torch.cat([support_emb, query_emb]).t())
            
            # 软标签(包含查询集)
            all_labels = torch.cat([support_y_onehot, torch.zeros(self.way, self.way)])
            all_labels[torch.arange(self.way) + self.way, torch.arange(self.way)] = 1
            
            # 更新预测
            attn = F.softmax(sim, dim=-1)
            preds = torch.mm(attn, all_labels)
            
            # 用预测更新查询集嵌入(可选)
            # query_emb = query_emb + ... 
        
        return torch.log(preds + 1e-8)

参考文献

相关文章

Footnotes

  1. Vinyals, O., Blundell, C., Lillicrap, T., & Wierstra, D. (2016). “Matching Networks for One Shot Learning”. Advances in Neural Information Processing Systems (NeurIPS).