匹配网络
匹配网络(Matching Networks)由Oriol Vinyals等人于2016年提出,是最早将注意力机制引入少样本学习的元学习方法之一。1 与原型网络不同,匹配网络为每个支持样本分配注意力权重,而非仅依赖类原型。
核心思想
与原型网络的关键区别
原型网络:
同类样本 → 平均 → 类原型 → 与Query比较
匹配网络:
每个Support样本 → 计算注意力权重 → 加权组合 → 与Query比较
注意力机制
其中:
- 是Query样本
- 是支持集的特征和标签
- 是注意力函数
数学推导
任务定义
给定支持集 和查询样本 :
- 编码支持集:,称为支持集编码器
- 编码查询集:,称为查询集编码器
- 注意力计算:
注意力函数
Cosine相似度注意力
MLP注意力(Relation Network)
分类输出
注意:这里 是 one-hot 编码的标签向量。
完全注意力记忆(FOCAL Attention)
支持集作为外部记忆
匹配网络可以看作一个记忆增强神经网络(Memory-Augmented Neural Network):
┌─────────────────────────────────────────────────────────┐
│ 查询样本 x̂ │
│ ↓ │
│ 查询编码 f(x̂) │
│ ↓ │
│ ┌──────────────────────┐ │
│ │ 注意力机制 │ │
│ │ a(f(x̂), g(x₁)) │ │
│ │ a(f(x̂), g(x₂)) ... │ │
│ └──────────────────────┘ │
│ ↓ │
│ 加权组合 → 预测标签 ŷ │
│ ↓ │
│ 支持集 S = {(xᵢ, yᵢ)} 作为外部记忆 │
└─────────────────────────────────────────────────────────┘
与LSTM记忆的区别
| 组件 | Neural Turing Machine | Matching Networks |
|---|---|---|
| 记忆 | 可学习的内部记忆 | 固定的支持集 |
| 读取 | 内容寻址 + 位置寻址 | 仅内容寻址 |
| 写入 | 可修改 | 不可修改 |
代码实现
基础匹配网络
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MatchingNetwork(nn.Module):
"""
匹配网络
支持集编码器 g 和查询集编码器 f 可以是独立的或共享的
"""
def __init__(
self,
encoder: nn.Module,
attention: str = 'cosine',
metric: str = 'cosine'
):
super().__init__()
self.encoder = encoder
self.attention = attention
self.metric = metric
if attention == 'cosine':
self.cos = nn.CosineSimilarity(dim=-1, eps=1e-7)
def forward(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
query_x: torch.Tensor,
way: int,
shot: int
) -> torch.Tensor:
"""
Args:
support_x: 支持集特征 (N*K, feature_dim)
support_y: 支持集标签 (N*K,) - 整数标签
query_x: 查询集特征 (N*Q, feature_dim)
way: 类别数
shot: 每类样本数
Returns:
query_logits: 查询集预测 logits
"""
# 编码
support_emb = self.encoder(support_x)
query_emb = self.encoder(query_x)
# 计算注意力(支持集对查询集)
if self.attention == 'cosine':
# 计算余弦相似度矩阵
# query_emb: (Q, d), support_emb: (S, d)
attn = self.cosine_attention(query_emb, support_emb)
elif self.attention == 'dot':
attn = torch.mm(query_emb, support_emb.t())
attn = F.softmax(attn, dim=-1)
else:
raise ValueError(f"Unknown attention: {self.attention}")
# 将标签转为one-hot
support_y_onehot = F.one_hot(support_y, num_classes=way).float()
# 加权求和
predictions = torch.mm(attn, support_y_onehot)
# 归一化(确保概率和为1)
predictions = predictions / (predictions.sum(dim=-1, keepdim=True) + 1e-8)
return torch.log(predictions + 1e-8)
def cosine_attention(self, query_emb: torch.Tensor, support_emb: torch.Tensor) -> torch.Tensor:
"""
计算余弦注意力矩阵
Args:
query_emb: (Q, d)
support_emb: (S, d)
Returns:
attn: (Q, S)
"""
# 余弦相似度
cos_sim = self.cos(query_emb.unsqueeze(1), support_emb.unsqueeze(0))
# Softmax归一化
attn = F.softmax(cos_sim, dim=-1)
return attn
class BidirectionalLSTMEncoder(nn.Module):
"""
双向LSTM编码器
论文中提出的支持集编码器 g
特点:
- 将整个支持集作为一个序列
- 使用BiLSTM编码每个样本的上下文
"""
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.lstm = nn.LSTM(
input_size=in_dim,
hidden_size=hidden_dim,
num_layers=1,
bidirectional=True,
batch_first=True
)
# 映射到输出维度
self.fc = nn.Linear(hidden_dim * 2, hidden_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, seq_len, in_dim) - 视为序列
Returns:
out: (batch, seq_len, hidden_dim)
"""
lstm_out, _ = self.lstm(x)
return self.fc(lstm_out)
class FullContextEmbedding(nn.Module):
"""
全上下文嵌入
使用BiLSTM对支持集进行编码,加入样本间的上下文信息
"""
def __init__(self, in_dim: int, embed_dim: int):
super().__init__()
self.encoder = BidirectionalLSTMEncoder(in_dim, embed_dim // 2)
self.out_proj = nn.Linear(embed_dim // 2, embed_dim)
def forward(self, support_x: torch.Tensor) -> torch.Tensor:
"""
Args:
support_x: (N*K, feature_dim) 或 (batch, N*K, feature_dim)
Returns:
embeddings: 全上下文嵌入
"""
# 重新reshape为序列
if support_x.dim() == 2:
support_x = support_x.unsqueeze(0)
squeeze = True
else:
squeeze = False
batch_size, seq_len, feature_dim = support_x.shape
# BiLSTM编码
embeddings = self.encoder(support_x) # (batch, seq, embed_dim/2)
# 最终嵌入 = 平均 + LSTMOUT的最后状态(简化)
final_embed = embeddings.mean(1) # (batch, embed_dim/2)
# 广播到每个位置
embeddings = embeddings + final_embed.unsqueeze(1)
embeddings = self.out_proj(embeddings)
if squeeze:
embeddings = embeddings.squeeze(0)
return embeddings
class SimpleEncoder(nn.Module):
"""
简单CNN编码器(用于图像)
"""
def __init__(self, in_channels=1, out_dim=64):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(128, out_dim)
)
def forward(self, x):
return self.net(x)
class RelationNetworkAttention(nn.Module):
"""
Relation Network注意力
使用MLP计算样本间的关系分数
"""
def __init__(self, embed_dim: int, hidden_dim: int = 8):
super().__init__()
self.relation_module = nn.Sequential(
nn.Linear(embed_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
def forward(self, query_emb: torch.Tensor, support_emb: torch.Tensor) -> torch.Tensor:
"""
Args:
query_emb: (Q, d)
support_emb: (S, d)
Returns:
attn: (Q, S)
"""
# 构造所有(query, support)对
query_expand = query_emb.unsqueeze(1).expand(-1, support_emb.size(0), -1)
support_expand = support_emb.unsqueeze(0).expand(query_emb.size(0), -1, -1)
# 拼接
pairs = torch.cat([query_expand, support_expand], dim=-1)
# 关系分数
relations = self.relation_module(pairs).squeeze(-1)
# Softmax
return F.softmax(relations, dim=-1)Episode训练
def train_matching_net():
"""匹配网络训练"""
import numpy as np
WAY = 5
SHOT = 1
QUERY = 15
EMBED_DIM = 64
EPOCHS = 100
# 模型
encoder = SimpleEncoder(out_dim=EMBED_DIM)
model = MatchingNetwork(encoder, attention='cosine')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练循环
for epoch in range(EPOCHS):
# 采样Episode(这里用随机数据模拟)
support_x = torch.randn(WAY * SHOT, 1, 28, 28)
support_y = torch.LongTensor([i // SHOT for i in range(WAY * SHOT)])
query_x = torch.randn(WAY * QUERY, 1, 28, 28)
query_y = torch.LongTensor([i // QUERY for i in range(WAY * QUERY)])
# 前向
log_probs = model(support_x, support_y, query_x, WAY, SHOT)
# 损失
loss = F.nll_loss(log_probs, query_y)
# 反向
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 统计
preds = log_probs.argmax(dim=-1)
acc = (preds == query_y).float().mean()
if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss={loss.item():.4f}, Acc={acc:.4f}")
if __name__ == '__main__':
train_matching_net()匹配网络 vs 原型网络
核心区别
| 方面 | 匹配网络 | 原型网络 |
|---|---|---|
| 表示方式 | 每个样本独立 | 类内平均 |
| 注意力 | 所有支持样本加权 | 仅类原型 |
| 计算复杂度 | ||
| 表达能力 | 更强 | 更简洁 |
数学对比
匹配网络:
ŷ = Σᵢ a(x̂, xᵢ) · yᵢ
原型网络(匹配网络的特例):
cₖ = (1/K) Σᵢ₌₁ᴷ xᵢ for yᵢ=k
ŷ = Σₖ a(x̂, cₖ) · yₖ
原型网络可以看作匹配网络的一种硬注意力变体。
扩展:Transductive Matching Networks
转导设置
利用查询集样本帮助分类:
class TransductiveMatchingNet(nn.Module):
"""
转导匹配网络
使用查询集样本更新注意力权重
"""
def __init__(self, encoder, way, shot, query):
super().__init__()
self.encoder = encoder
self.way = way
self.shot = shot
self.query = query
def forward(self, support_x, support_y, query_x, num_iterations=5):
"""
转导推理
"""
# 初始编码
support_emb = self.encoder(support_x)
query_emb = self.encoder(query_x)
# 标签one-hot
support_y_onehot = F.one_hot(support_y, num_classes=self.way).float()
# 迭代更新
for _ in range(num_iterations):
# 计算相似度
sim = torch.mm(query_emb, torch.cat([support_emb, query_emb]).t())
# 软标签(包含查询集)
all_labels = torch.cat([support_y_onehot, torch.zeros(self.way, self.way)])
all_labels[torch.arange(self.way) + self.way, torch.arange(self.way)] = 1
# 更新预测
attn = F.softmax(sim, dim=-1)
preds = torch.mm(attn, all_labels)
# 用预测更新查询集嵌入(可选)
# query_emb = query_emb + ...
return torch.log(preds + 1e-8)参考文献
相关文章
Footnotes
-
Vinyals, O., Blundell, C., Lillicrap, T., & Wierstra, D. (2016). “Matching Networks for One Shot Learning”. Advances in Neural Information Processing Systems (NeurIPS). ↩