NAS评估方法
评估方法是NAS系统的关键瓶颈之一。由于候选架构数量庞大(通常以上),完整训练每个候选架构代价极高。本章详细介绍各类评估方法:One-Shot评估、Zero-Cost代理和预测器方法。
一、评估方法的分类
┌─────────────────────────────────────────────────────────────┐
│ NAS评估方法 │
├──────────────────┬──────────────────┬───────────────────────┤
│ 需要训练 │ 需要训练 │ 不需要训练 │
│ (低/中成本) │ (高成本) │ (极低成本) │
├──────────────────┼──────────────────┼───────────────────────┤
│ One-Shot (超网络)│ Full Training │ Zero-Cost Proxy │
│ Predictor (预测器)│ Progressive NAS │ SynFlow, Zen-NAS │
│ K-shot NAS │ │ GradNorm, NASWOT │
└──────────────────┴──────────────────┴───────────────────────┘
二、One-Shot评估与超网络
2.1 核心思想
构建一个包含所有候选架构的”超网络”(Supernet),子架构直接继承共享权重:
┌──────────────────────────────────────────────────────────┐
│ 超网络 (Supernet) │
│ │
│ Layer 1 ─┬─ [Conv3x3] ──┬─ [Conv5x5] ──┬─ [Dil3x3] ──┐ │
│ └─ [Skip] └─ [Pool] └─ [None] │ │
│ │ │ │
│ Layer 2 ─┬─ [Conv3x3] ──┴─ [Conv5x5] ──┬─ [Dil3x3] ──┼─┼─→ Out
│ └─ [Skip] ┌─ [Pool] └─ [None] │ │
│ │ │ │
│ 子架构1: Conv3x3→Conv3x3 │ 子架构2: Skip→Conv5x5 │ │
│ (直接继承权重) │ (直接继承权重) │ │
└──────────────────────────────────────────────────────────┘
2.2 权重复用策略
标准One-Shot
所有子架构共享完全相同的权重:
class OneShotSupernet(nn.Module):
def __init__(self, num_ops=6, hidden_dim=64):
super().__init__()
# 所有操作共享同一组权重
self.conv3x3 = Conv2d(hidden_dim, hidden_dim, 3, padding=1)
self.conv5x5 = Conv2d(hidden_dim, hidden_dim, 5, padding=2)
self.skip = SkipConnection()
def forward(self, x, arch):
"""
arch: 架构选择 [op1, op2, op3, ...]
"""
for op_id in arch:
if op_id == 0:
x = self.conv3x3(x)
elif op_id == 1:
x = self.conv5x5(x)
# ...
return x优点:实现简单,内存效率高
缺点:不同架构间存在权重干扰
K-shot NAS
维护K个不同的权重空间,减少干扰:
class KShotSupernet(nn.Module):
def __init__(self, k=3, num_ops=6, hidden_dim=64):
super().__init__()
self.k = k
# 每个操作有K组独立权重
self.op_weights = nn.ParameterList([
nn.Parameter(torch.randn(hidden_dim, hidden_dim))
for _ in range(k * num_ops)
])
def forward(self, x, arch, path_id):
"""
path_id: 选择使用哪组权重空间 (0 to k-1)
"""
offset = path_id * self.num_ops
for i, op_id in enumerate(arch):
# 使用对应路径的权重
w = self.op_weights[offset + op_id]
x = F.conv2d(x, w)
return x2.3 超网络训练策略
class SupernetTrainer:
def __init__(self, supernet, search_space):
self.supernet = supernet
self.search_space = search_space
def train_uniform(self, loader, epochs):
"""均匀采样训练"""
for epoch in range(epochs):
for batch in loader:
arch = self.search_space.uniform_sample()
loss = self.compute_loss(self.supernet, batch, arch)
loss.backward()
self.optimizer.step()
def train_arch_gradient(self, loader, epochs):
"""架构梯度训练 (DARTS风格)"""
for epoch in range(epochs):
# 1. 更新权重 (下层)
self.weight_optimizer.zero_grad()
arch = self.search_space.sample()
train_loss = self.compute_loss(self.supernet, train_data, arch)
train_loss.backward()
self.weight_optimizer.step()
# 2. 更新架构参数 (上层)
self.arch_optimizer.zero_grad()
val_loss = self.compute_loss(self.supernet, val_data, arch)
val_loss.backward()
self.arch_optimizer.step()2.4 超网络的问题与挑战
| 问题 | 描述 | 解决方案 |
|---|---|---|
| 权重耦合 | 不同架构共享权重导致冲突 | K-shot, 路径分离 |
| 过度评估 | 某些操作被高估/低估 | FairDARTS, 早停 |
| 性能崩溃 | 训练后期所有架构都变差 | 早停策略 |
| 置信度偏差 | 与真实性能相关性不稳定 | 蒸馏, 微调 |
三、Zero-Cost评估指标
3.1 核心思想
完全不需要训练网络,直接从随机初始化的网络快照中评估架构质量。
┌─────────────────────────────────────────────────────────┐
│ Zero-Cost NAS 流程 │
│ │
│ 随机初始化网络 ──→ 前向传播 ──→ 计算指标 ──→ 排序架构 │
│ │ │
│ │ (无需训练! 耗时: 秒级) │
│ ▼ │
│ 指标包括: 梯度、激活、参数范数、批归一化统计量等 │
└─────────────────────────────────────────────────────────┘
3.2 经典Zero-Cost指标
SynFlow (2021)
基于”参数重要性”的度量:
def synflow_score(model, data):
"""
SynFlow核心思想:好的架构应该对参数具有均匀的敏感性
"""
# 用全1输入计算梯度
x = torch.ones_like(list(model.parameters())[0])
def hook(module, input, output):
return output * x
# 注册hook并前向传播
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
m.register_forward_hook(hook)
with torch.no_grad():
output = model(data)
# 分数 = 所有参数乘以梯度的和
score = sum((p.abs() * p.grad).sum() for p in model.parameters())
return scoreZen-NAS (2021)
利用批归一化统计量评估架构:
def zen_score(model, data_loader):
"""
Zen-NAS核心思想:好的架构应该有"良好校准"的BN统计量
"""
model.eval()
moments = []
with torch.no_grad():
for x, _ in data_loader:
x = x.cuda()
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
# 收集统计量
moments.append(module.running_mean.var())
moments.append(module.running_var.mean())
# Zen分数 = 统计量的负熵
# 好的架构:统计量分布均匀(高熵)
score = -sum(m * torch.log(m + 1e-10) for m in moments)
return scoreGradNorm (2021)
基于梯度范数的度量:
def gradnorm_score(model, data):
"""
GradNorm:好的架构应该有适中的梯度规模
"""
grads = []
for x, y in data:
loss = F.cross_entropy(model(x), y)
loss.backward()
# 收集参数梯度范数
grad_norms = [p.grad.norm().item()
for p in model.parameters() if p.grad is not None]
grads.append(grad_norms)
# 平均梯度范数作为分数
return np.mean(grads)NASWOT (2020)
基于激活轨迹的图核方法:
def naswot_score(model, data):
"""
NASWOT:利用激活轨迹构建图核
"""
model.eval()
signatures = []
def hook(module, input, output):
signatures.append(input[0].detach().clone())
# 注册hook
handles = []
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
handles.append(m.register_forward_hook(hook))
with torch.no_grad():
for x, _ in data:
model(x)
# 清理hook
for h in handles:
h.remove()
# 计算图核相似度
return compute_graph_kernel(signatures)3.3 2025年新指标:L-SWAG
Layer-Sample Wise Activation with Gradients
专门为Vision Transformer设计的Zero-Cost指标:
| 指标 | CIFAR-10 | ImageNet |
|---|---|---|
| SynFlow | 0.61 | 0.52 |
| Zen-NAS | 0.94 | 0.71 |
| L-SWAG | 0.95 | 0.89 |
四、预测器方法
4.1 核心思想
训练一个代理模型,直接预测给定架构的性能。
┌─────────────────────────────────────────────────────────┐
│ 预测器NAS流程 │
│ │
│ 1. 采样架构集 │
│ │ │
│ ▼ │
│ 2. 训练架构 → 收集(架构, 性能)对 │
│ │ │
│ ▼ │
│ 3. 训练预测器 f: 架构 → 性能 │
│ │ │
│ ▼ │
│ 4. 使用预测器评估新架构 │
│ │ │
│ ▼ │
│ 5. 选择预测性能最佳的架构 │
└─────────────────────────────────────────────────────────┘
4.2 架构编码方法
基于操作序列的编码
class SequentialEncoder:
def encode(self, arch):
"""
arch: [op1, op2, op3, ...] 操作序列
"""
return torch.tensor(arch) # 直接编码为向量基于图神经网络的编码
class GNNEncoder(nn.Module):
"""利用GNN编码架构拓扑"""
def __init__(self, node_dim=64, hidden_dim=128):
super().__init__()
self.node_embedding = nn.Embedding(num_ops, node_dim)
self.gnn_layers = nn.ModuleList([
GATConv(node_dim, hidden_dim),
GATConv(hidden_dim, hidden_dim),
])
self.readout = nn.Linear(hidden_dim, 1)
def forward(self, arch_graph):
"""
arch_graph: 架构的图表示
"""
x = self.node_embedding(arch_graph.node_ids)
for gnn in self.gnn_layers:
x = gnn(x, arch_graph.edge_index)
x = F.relu(x)
# 图级别聚合
graph_rep = torch.mean(x, dim=0)
return self.readout(graph_rep)4.3 代表性预测器方法
BANANAS
Bayesian Optimization with Neural Architectures
class BANANAS:
def __init__(self):
self.predictor = NeuralPredictor()
self.acquisition = EI() # Expected Improvement
def search(self, initial_archs, n_trials):
# 训练预测器
self.predictor.fit(initial_archs)
for _ in range(n_trials):
# 选择下一个架构
arch = self.select_next_arch()
# 评估架构
performance = evaluate(arch)
# 更新预测器
self.predictor.update(arch, performance)
return self.predictor.best_arch()ONNX-Net (2025)
跨基准迁移的预测器:
| 源基准 | 目标基准 | 迁移准确率 |
|---|---|---|
| NAS-Bench-101 | NAS-Bench-201 | ~85% |
| NAS-Bench-201 | ImageNet | ~78% |
五、方法对比
5.1 计算成本对比
| 方法 | 评估时间 | 内存需求 | 额外数据 |
|---|---|---|---|
| Full Training | 数百GPU小时 | - | 无 |
| One-Shot (DARTS) | ~1 GPU天 | 高 | 无 |
| K-shot NAS | ~1-2 GPU天 | 很高 | 无 |
| Predictor | 分钟级 | 中 | 10-50K样本 |
| SynFlow | 秒级 | 低 | 无 |
| Zen-NAS | 秒级 | 低 | 无 |
| L-SWAG | 秒级 | 低 | 无 |
5.2 准确性对比(SRCC on NAS-Bench-201)
| 方法 | SRCC | 备注 |
|---|---|---|
| Random | 0.00 | 基线 |
| SynFlow | 0.77 | 梯度敏感 |
| NASWOT | 0.61 | 激活轨迹 |
| GradNorm | 0.72 | 梯度范数 |
| Zen-NAS | 0.94 | 批归一化 |
| L-SWAG | 0.95 | ViT专用 |
5.3 适用场景
| 场景 | 推荐方法 |
|---|---|
| 快速架构筛选 | Zero-Cost (Zen-NAS) |
| 高精度搜索 | Predictor + 验证 |
| 大规模搜索空间 | DARTS |
| 资源受限环境 | Zero-Cost |
| Transformer架构 | L-SWAG |
六、实际应用建议
6.1 选择评估方法的决策树
开始
│
├─ 资源充足?
│ ├─ 是 → Full Training 或 Predictor
│ │
│ └─ 否 → 资源受限?
│ ├─ 是 → Zero-Cost Proxy
│ │
│ └─ 否 → One-Shot (DARTS)
│
└─ 架构类型?
├─ CNN → Zen-NAS, SynFlow
├─ ViT → L-SWAG
└─ 其他 → Predictor
6.2 混合策略
实际应用中,常组合多种评估方法:
class HybridEvaluator:
def __init__(self):
# 1. 快速初筛
self.zero_cost = ZenNAS()
# 2. 精确评估
self.predictor = NeuralPredictor()
# 3. 最终验证
self.full_train = FullTrainer()
def evaluate(self, arch):
# 第一阶段:Zero-Cost筛选
zc_score = self.zero_cost(arch)
if zc_score < threshold:
return float('-inf')
# 第二阶段:预测器排序
pred_score = self.predictor(arch)
# 第三阶段:选择top-k完整训练
if pred_score > pred_threshold:
return self.full_train(arch)
return pred_score