NAS评估方法

评估方法是NAS系统的关键瓶颈之一。由于候选架构数量庞大(通常以上),完整训练每个候选架构代价极高。本章详细介绍各类评估方法:One-Shot评估、Zero-Cost代理和预测器方法。

一、评估方法的分类

┌─────────────────────────────────────────────────────────────┐
│                    NAS评估方法                              │
├──────────────────┬──────────────────┬───────────────────────┤
│   需要训练       │   需要训练        │    不需要训练         │
│   (低/中成本)    │   (高成本)        │    (极低成本)         │
├──────────────────┼──────────────────┼───────────────────────┤
│ One-Shot (超网络)│ Full Training    │ Zero-Cost Proxy       │
│ Predictor (预测器)│ Progressive NAS │ SynFlow, Zen-NAS      │
│ K-shot NAS       │                  │ GradNorm, NASWOT      │
└──────────────────┴──────────────────┴───────────────────────┘

二、One-Shot评估与超网络

2.1 核心思想

构建一个包含所有候选架构的”超网络”(Supernet),子架构直接继承共享权重:

┌──────────────────────────────────────────────────────────┐
│                      超网络 (Supernet)                    │
│                                                          │
│  Layer 1  ─┬─ [Conv3x3] ──┬─ [Conv5x5] ──┬─ [Dil3x3] ──┐ │
│            └─ [Skip]     └─ [Pool]      └─ [None]    │ │
│                           │                            │ │
│  Layer 2  ─┬─ [Conv3x3] ──┴─ [Conv5x5] ──┬─ [Dil3x3] ──┼─┼─→ Out
│            └─ [Skip]     ┌─ [Pool]      └─ [None]    │ │
│                           │                            │ │
│  子架构1: Conv3x3→Conv3x3  │  子架构2: Skip→Conv5x5    │ │
│  (直接继承权重)              │  (直接继承权重)           │ │
└──────────────────────────────────────────────────────────┘

2.2 权重复用策略

标准One-Shot

所有子架构共享完全相同的权重:

class OneShotSupernet(nn.Module):
    def __init__(self, num_ops=6, hidden_dim=64):
        super().__init__()
        # 所有操作共享同一组权重
        self.conv3x3 = Conv2d(hidden_dim, hidden_dim, 3, padding=1)
        self.conv5x5 = Conv2d(hidden_dim, hidden_dim, 5, padding=2)
        self.skip = SkipConnection()
        
    def forward(self, x, arch):
        """
        arch: 架构选择 [op1, op2, op3, ...]
        """
        for op_id in arch:
            if op_id == 0:
                x = self.conv3x3(x)
            elif op_id == 1:
                x = self.conv5x5(x)
            # ...
        return x

优点:实现简单,内存效率高
缺点:不同架构间存在权重干扰

K-shot NAS

维护K个不同的权重空间,减少干扰:

class KShotSupernet(nn.Module):
    def __init__(self, k=3, num_ops=6, hidden_dim=64):
        super().__init__()
        self.k = k
        # 每个操作有K组独立权重
        self.op_weights = nn.ParameterList([
            nn.Parameter(torch.randn(hidden_dim, hidden_dim))
            for _ in range(k * num_ops)
        ])
        
    def forward(self, x, arch, path_id):
        """
        path_id: 选择使用哪组权重空间 (0 to k-1)
        """
        offset = path_id * self.num_ops
        for i, op_id in enumerate(arch):
            # 使用对应路径的权重
            w = self.op_weights[offset + op_id]
            x = F.conv2d(x, w)
        return x

2.3 超网络训练策略

class SupernetTrainer:
    def __init__(self, supernet, search_space):
        self.supernet = supernet
        self.search_space = search_space
        
    def train_uniform(self, loader, epochs):
        """均匀采样训练"""
        for epoch in range(epochs):
            for batch in loader:
                arch = self.search_space.uniform_sample()
                loss = self.compute_loss(self.supernet, batch, arch)
                loss.backward()
                self.optimizer.step()
                
    def train_arch_gradient(self, loader, epochs):
        """架构梯度训练 (DARTS风格)"""
        for epoch in range(epochs):
            # 1. 更新权重 (下层)
            self.weight_optimizer.zero_grad()
            arch = self.search_space.sample()
            train_loss = self.compute_loss(self.supernet, train_data, arch)
            train_loss.backward()
            self.weight_optimizer.step()
            
            # 2. 更新架构参数 (上层)
            self.arch_optimizer.zero_grad()
            val_loss = self.compute_loss(self.supernet, val_data, arch)
            val_loss.backward()
            self.arch_optimizer.step()

2.4 超网络的问题与挑战

问题描述解决方案
权重耦合不同架构共享权重导致冲突K-shot, 路径分离
过度评估某些操作被高估/低估FairDARTS, 早停
性能崩溃训练后期所有架构都变差早停策略
置信度偏差与真实性能相关性不稳定蒸馏, 微调

三、Zero-Cost评估指标

3.1 核心思想

完全不需要训练网络,直接从随机初始化的网络快照中评估架构质量。

┌─────────────────────────────────────────────────────────┐
│              Zero-Cost NAS 流程                         │
│                                                         │
│   随机初始化网络 ──→ 前向传播 ──→ 计算指标 ──→ 排序架构   │
│         │                                              │
│         │  (无需训练! 耗时: 秒级)                       │
│         ▼                                              │
│   指标包括: 梯度、激活、参数范数、批归一化统计量等        │
└─────────────────────────────────────────────────────────┘

3.2 经典Zero-Cost指标

SynFlow (2021)

基于”参数重要性”的度量:

def synflow_score(model, data):
    """
    SynFlow核心思想:好的架构应该对参数具有均匀的敏感性
    """
    # 用全1输入计算梯度
    x = torch.ones_like(list(model.parameters())[0])
    
    def hook(module, input, output):
        return output * x
    
    # 注册hook并前向传播
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            m.register_forward_hook(hook)
    
    with torch.no_grad():
        output = model(data)
    
    # 分数 = 所有参数乘以梯度的和
    score = sum((p.abs() * p.grad).sum() for p in model.parameters())
    return score

Zen-NAS (2021)

利用批归一化统计量评估架构:

def zen_score(model, data_loader):
    """
    Zen-NAS核心思想:好的架构应该有"良好校准"的BN统计量
    """
    model.eval()
    moments = []
    
    with torch.no_grad():
        for x, _ in data_loader:
            x = x.cuda()
            for module in model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    # 收集统计量
                    moments.append(module.running_mean.var())
                    moments.append(module.running_var.mean())
    
    # Zen分数 = 统计量的负熵
    # 好的架构:统计量分布均匀(高熵)
    score = -sum(m * torch.log(m + 1e-10) for m in moments)
    return score

GradNorm (2021)

基于梯度范数的度量:

def gradnorm_score(model, data):
    """
    GradNorm:好的架构应该有适中的梯度规模
    """
    grads = []
    
    for x, y in data:
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        
        # 收集参数梯度范数
        grad_norms = [p.grad.norm().item() 
                      for p in model.parameters() if p.grad is not None]
        grads.append(grad_norms)
    
    # 平均梯度范数作为分数
    return np.mean(grads)

NASWOT (2020)

基于激活轨迹的图核方法:

def naswot_score(model, data):
    """
    NASWOT:利用激活轨迹构建图核
    """
    model.eval()
    signatures = []
    
    def hook(module, input, output):
        signatures.append(input[0].detach().clone())
    
    # 注册hook
    handles = []
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            handles.append(m.register_forward_hook(hook))
    
    with torch.no_grad():
        for x, _ in data:
            model(x)
    
    # 清理hook
    for h in handles:
        h.remove()
    
    # 计算图核相似度
    return compute_graph_kernel(signatures)

3.3 2025年新指标:L-SWAG

Layer-Sample Wise Activation with Gradients

专门为Vision Transformer设计的Zero-Cost指标:

指标CIFAR-10ImageNet
SynFlow0.610.52
Zen-NAS0.940.71
L-SWAG0.950.89

四、预测器方法

4.1 核心思想

训练一个代理模型,直接预测给定架构的性能。

┌─────────────────────────────────────────────────────────┐
│                    预测器NAS流程                        │
│                                                         │
│  1. 采样架构集                                           │
│       │                                                 │
│       ▼                                                 │
│  2. 训练架构 → 收集(架构, 性能)对                         │
│       │                                                 │
│       ▼                                                 │
│  3. 训练预测器 f: 架构 → 性能                            │
│       │                                                 │
│       ▼                                                 │
│  4. 使用预测器评估新架构                                 │
│       │                                                 │
│       ▼                                                 │
│  5. 选择预测性能最佳的架构                                │
└─────────────────────────────────────────────────────────┘

4.2 架构编码方法

基于操作序列的编码

class SequentialEncoder:
    def encode(self, arch):
        """
        arch: [op1, op2, op3, ...] 操作序列
        """
        return torch.tensor(arch)  # 直接编码为向量

基于图神经网络的编码

class GNNEncoder(nn.Module):
    """利用GNN编码架构拓扑"""
    
    def __init__(self, node_dim=64, hidden_dim=128):
        super().__init__()
        self.node_embedding = nn.Embedding(num_ops, node_dim)
        self.gnn_layers = nn.ModuleList([
            GATConv(node_dim, hidden_dim),
            GATConv(hidden_dim, hidden_dim),
        ])
        self.readout = nn.Linear(hidden_dim, 1)
        
    def forward(self, arch_graph):
        """
        arch_graph: 架构的图表示
        """
        x = self.node_embedding(arch_graph.node_ids)
        
        for gnn in self.gnn_layers:
            x = gnn(x, arch_graph.edge_index)
            x = F.relu(x)
        
        # 图级别聚合
        graph_rep = torch.mean(x, dim=0)
        return self.readout(graph_rep)

4.3 代表性预测器方法

BANANAS

Bayesian Optimization with Neural Architectures

class BANANAS:
    def __init__(self):
        self.predictor = NeuralPredictor()
        self.acquisition = EI()  # Expected Improvement
        
    def search(self, initial_archs, n_trials):
        # 训练预测器
        self.predictor.fit(initial_archs)
        
        for _ in range(n_trials):
            # 选择下一个架构
            arch = self.select_next_arch()
            
            # 评估架构
            performance = evaluate(arch)
            
            # 更新预测器
            self.predictor.update(arch, performance)
            
        return self.predictor.best_arch()

ONNX-Net (2025)

跨基准迁移的预测器:

源基准目标基准迁移准确率
NAS-Bench-101NAS-Bench-201~85%
NAS-Bench-201ImageNet~78%

五、方法对比

5.1 计算成本对比

方法评估时间内存需求额外数据
Full Training数百GPU小时-
One-Shot (DARTS)~1 GPU天
K-shot NAS~1-2 GPU天很高
Predictor分钟级10-50K样本
SynFlow秒级
Zen-NAS秒级
L-SWAG秒级

5.2 准确性对比(SRCC on NAS-Bench-201)

方法SRCC备注
Random0.00基线
SynFlow0.77梯度敏感
NASWOT0.61激活轨迹
GradNorm0.72梯度范数
Zen-NAS0.94批归一化
L-SWAG0.95ViT专用

5.3 适用场景

场景推荐方法
快速架构筛选Zero-Cost (Zen-NAS)
高精度搜索Predictor + 验证
大规模搜索空间DARTS
资源受限环境Zero-Cost
Transformer架构L-SWAG

六、实际应用建议

6.1 选择评估方法的决策树

开始
  │
  ├─ 资源充足?
  │    ├─ 是 → Full Training 或 Predictor
  │    │
  │    └─ 否 → 资源受限?
  │         ├─ 是 → Zero-Cost Proxy
  │         │
  │         └─ 否 → One-Shot (DARTS)
  │
  └─ 架构类型?
       ├─ CNN → Zen-NAS, SynFlow
       ├─ ViT → L-SWAG
       └─ 其他 → Predictor

6.2 混合策略

实际应用中,常组合多种评估方法:

class HybridEvaluator:
    def __init__(self):
        # 1. 快速初筛
        self.zero_cost = ZenNAS()
        # 2. 精确评估
        self.predictor = NeuralPredictor()
        # 3. 最终验证
        self.full_train = FullTrainer()
        
    def evaluate(self, arch):
        # 第一阶段:Zero-Cost筛选
        zc_score = self.zero_cost(arch)
        if zc_score < threshold:
            return float('-inf')
        
        # 第二阶段:预测器排序
        pred_score = self.predictor(arch)
        
        # 第三阶段:选择top-k完整训练
        if pred_score > pred_threshold:
            return self.full_train(arch)
        
        return pred_score

七、参考论文