免训练NAS:零代价代理指标

免训练神经网络架构搜索(Zero-Cost NAS)1是近年来最重要的研究突破之一,它通过设计零代价代理指标(Zero-Cost Proxy Metrics),无需完整训练即可预测神经网络架构的性能,大幅降低了NAS的计算成本。


1. 核心思想

1.1 问题背景

传统NAS方法面临的核心问题是评估代价高昂

方法评估代价搜索空间大小总代价
完整训练~1 GPU day不现实
权重共享~0.01 GPU day~
零代价代理~0 GPU day可行

1.2 基本假设

零代价代理基于以下假设:

核心假设:神经网络的初始状态(随机权重)包含了其表达能力的关键信息。

这与神经网络的Lazy Training现象有关:在训练初期,神经网络近似线性,权重的主要结构决定网络的表达能力。

1.3 代理指标分类

┌─────────────────────────────────────────────────────────────┐
│                    零代价代理指标分类                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐     │
│  │ 梯度相关    │  │ 拓扑相关     │  │ 信息论相关   │     │
│  ├──────────────┤  ├──────────────┤  ├──────────────┤     │
│  │ • SNIP      │  │ • zen-Score  │  │ • Entropy    │     │
│  │ • GraSP     │  │ • FLOPs      │  │ • Mutual Inf │     │
│  │ • SynFlow   │  │ • Params     │  │ • O-inf      │     │
│  │ • NASWOT    │  │ • grad_norm  │  │ • I-inf      │     │
│  └──────────────┘  └──────────────┘  └──────────────┘     │
│                                                             │
│  ┌──────────────────────────────────────────────────────┐  │
│  │                    随机预言                            │  │
│  ├──────────────────────────────────────────────────────┤  │
│  │ • 随机权重精度  • 边缘分布差异  • 梯度相关性           │  │
│  └──────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────┘

2. 梯度相关代理指标

2.1 SNIP(Single-shot Network Pruning)

论文:Single-shot Network Pruning based on Connection Sensitivity2

核心思想:计算每个权重对损失的敏感度,反映其重要性。

敏感度定义

使用一阶泰勒展开近似:

SNIP指标计算

def snip_score(model, train_loader, criterion):
    """计算SNIP敏感度分数"""
    model.zero_grad()
    
    # 收集梯度
    grads = {}
    for input, target in train_loader:
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        
        for name, param in model.named_parameters():
            if param.grad is not None:
                if name not in grads:
                    grads[name] = []
                grads[name].append(
                    (param.grad * param).abs().detach()
                )
    
    # 汇总敏感度
    scores = {}
    for name, grad_list in grads.items():
        # 敏感度 = 梯度 × 权重 的绝对值
        scores[name] = torch.stack(grad_list).mean(dim=0)
    
    return scores

架构级别聚合

其中 是架构 的边集合。

2.2 GraSP(Gradient Signal Preservation)

论文:PRuning by Gradient Signal Preservation3

核心思想:选择保留最多梯度信息的权重/操作。

优化目标

其中 是Hessian矩阵, 是掩码向量。

近似计算

由于Hessian计算代价高,使用Fisher信息矩阵近似:

2.3 SynFlow(Synaptic Flow)

论文:Synaptic Flow: Pruning at Initialization4

核心思想:网络的”流动能力”决定了其表达能力。

递归信息流

设网络为层间连接矩阵的乘积:

定理:无论网络结构和权重如何初始化,只要网络是全初始化的(无稀疏),递归信息流保持恒定:

其中 为全1矩阵。

SynFlow分数

def synflow_score(model):
    """计算SynFlow分数"""
    # 重置所有参数为1
    for param in model.parameters():
        param.data.fill_(1.0)
    
    # 递归反向传播计算分数
    def backward_synflow(module):
        if hasattr(module, 'weight') and module.weight is not None:
            grad = module.weight.grad
            if grad is None:
                module.weight.grad = (module.weight * torch.ones_like(module.weight)).data
            else:
                module.weight.grad = (module.weight * grad).data
        
        for child in module.children():
            backward_synflow(child)
    
    # 前向传播创建计算图
    output = model(torch.ones_like(input))
    output.sum().backward()
    
    # 计算分数
    scores = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            scores[name] = (param.abs() * param.grad.abs()).detach()
    
    return scores

2.4 NASWOT(NAS without Training)

论文:Neural Architecture Search without Training5

核心思想:利用标签不可知的激活相关性预测性能。

NAK(Neural Architecture Knowledge)

定义架构 的NKN矩阵:

其中 是第 个样本的隐藏激活。

目标函数

是由参数 控制的预测NKN。

最终指标

高行列式意味着节点间信息传递更独立。


3. 拓扑相关代理指标

3.1 zen-Score

论文:zen-NAS6

核心思想:衡量网络的信息瓶颈能力。

zen定义

其中 是第 次随机前向传播的中间激活。

直觉:信息瓶颈越强,网络压缩信息的能力越好,表达能力越强。

def zen_score(model, input, n_samples=100):
    """计算zen-NAS分数"""
    model.eval()
    
    # 注册hook收集中间激活
    activations = []
    handles = []
    
    def hook_fn(module, input, output):
        activations.append(output.detach())
    
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            handles.append(module.register_forward_hook(hook_fn))
    
    # 多次随机前向传播
    with torch.no_grad():
        for _ in range(n_samples):
            # 随机dropout模拟随机性
            noise_input = input + torch.randn_like(input) * 0.1
            model(noise_input)
    
    # 移除hook
    for handle in handles:
        handle.remove()
    
    # 计算方差
    stacked = torch.stack(activations)
    variances = torch.var(stacked, dim=0).mean()
    
    return variances.item()

3.2 参数量与FLOPs

最简单的代理指标,但相关性有限:

def compute_flops(model, input_shape):
    """计算FLOPs"""
    from thop import profile
    flops, params = profile(model, inputs=(torch.randn(input_shape),))
    return flops, params

3.3 Grad-Norm

核心思想:梯度范数反映参数的重要性。


4. 信息论相关代理指标

4.1 O-Information(O-inf)

核心思想:衡量特征的冗余/协同信息。

部分信息分解(PID)

对于多个变量 和目标 ,信息可分解为:

O-Information定义

其中 是在去掉 后的互信息。

4.2 熵基代理

def entropy_proxy(model, train_loader):
    """基于激活熵的代理指标"""
    entropies = []
    
    for input, _ in train_loader:
        output = model(input)
        
        # softmax后的熵
        probs = F.softmax(output, dim=-1)
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
        entropies.append(entropy.item())
    
    return np.mean(entropies)

5. 随机预言代理

5.1 随机权重精度

核心思想:随机初始化权重下网络的预测精度。

虽然直觉上不可信,但研究发现7随机权重精度与训练后精度存在正相关

┌────────────────────────────────────────┐
│   随机权重精度 vs 真实精度 (CIFAR-10)    │
│                                        │
│   真实精度                               │
│     ↑      ●  ●                        │
│     │    ●      ●                       │
│     │  ●          ●                    │
│     │●               ●                │
│     └────────────────────────→          │
│           随机权重精度                    │
└────────────────────────────────────────┘

5.2 边缘分布差异

衡量训练样本与测试样本在网络激活空间的分布差异:

其中 MMD 是最大均值差异。


6. AZ-NAS:多代理组合

6.1 核心思想

单一代理指标往往不够准确,AZ-NAS8提出组合多个代理指标的框架。

组合框架

其中 是代理指标集合, 是组合权重。

6.2 代理相关性分析

代理SNIPGraSPSynFlowzen
SNIP1.000.720.450.38
GraSP0.721.000.520.41
SynFlow0.450.521.000.33
zen0.380.410.331.00

观察:不同代理指标具有互补性,组合后效果更好。

6.3 AZ-NAS PyTorch实现

class AZNASScorer:
    """AZ-NAS: 多代理组合打分器"""
    def __init__(self, model, train_loader, device='cuda'):
        self.model = model
        self.train_loader = train_loader
        self.device = device
        self.proxy_metrics = {
            'snip': self._snip_score,
            'synflow': self._synflow_score,
            'zen': self._zen_score,
            'grad_norm': self._grad_norm_score,
        }
        self.weights = self._learn_weights()
    
    def _snip_score(self):
        """SNIP分数"""
        # 收集梯度
        ...
        return scores
    
    def _synflow_score(self):
        """SynFlow分数"""
        ...
        return scores
    
    def _zen_score(self):
        """zen分数"""
        ...
        return scores
    
    def _grad_norm_score(self):
        """梯度范数分数"""
        ...
        return scores
    
    def _learn_weights(self):
        """学习代理权重"""
        # 使用NAS-Bench-201等基准数据集
        # 学习最小二乘权重
        from sklearn.linear_model import Ridge
        
        # 收集所有架构的代理分数和真实准确率
        X, y = [], []
        for arch in benchmark_architectures:
            scores = self.compute_all_scores(arch)
            X.append(list(scores.values()))
            y.append(arch.true_accuracy)
        
        # 拟合线性组合
        ridge = Ridge(alpha=0.1)
        ridge.fit(X, y)
        
        # 归一化权重
        weights = {
            name: w for name, w in zip(self.proxy_metrics.keys(), ridge.coef_)
        }
        weights = {k: v / sum(abs(w) for w in weights.values()) 
                   for k, v in weights.items()}
        
        return weights
    
    def compute_all_scores(self, arch):
        """计算所有代理分数"""
        scores = {}
        for name, func in self.proxy_metrics.items():
            scores[name] = func(arch)
        return scores
    
    def final_score(self, arch):
        """最终组合分数"""
        scores = self.compute_all_scores(arch)
        return sum(self.weights[m] * scores[m] for m in scores)

7. 零代价代理对比分析

7.1 在NAS-Bench-201上的性能

代理Kendall’s τSpearman’s ρ与训练时间的相关性
Params0.120.18
FLOPs0.150.21
SNIP0.310.42
GraSP0.330.45
SynFlow0.280.39
zen0.350.48
NASWOT0.380.51
AZ-NAS0.520.68

7.2 计算效率对比

方法计算时间需要的操作
完整训练~1 GPU day训练整个网络
DARTS~0.5 GPU day联合优化
SNIP~10 seconds单次前向+反向
SynFlow~30 secondsL次迭代
zen~60 secondsN次随机前向
AZ-NAS~2 minutes所有代理求和

7.3 选择指南

┌─────────────────────────────────────────────────────────────┐
│                   零代价代理选择指南                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ◉ 追求最高准确性 → AZ-NAS(多代理组合)                     │
│                                                             │
│  ◉ 追求实现简便 → zen-Score(单一代理,效果好)              │
│                                                             │
│  ◉ 追求计算效率 → SynFlow(仅需初始化修改)                  │
│                                                             │
│  ◉ 搜索空间较小 → SNIP + NASWOT 组合                        │
│                                                             │
│  ◉ 移动端部署 → FLOPs + 延迟联合指标                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

8. 实践代码

8.1 完整零代价NAS框架

class ZeroCostNAS:
    """完整的零代价NAS框架"""
    def __init__(self, supernet, train_loader, device='cuda'):
        self.supernet = supernet
        self.train_loader = train_loader
        self.device = device
        self.scorer = AZNASScorer(supernet, train_loader, device)
    
    def search(self, candidates, top_k=10):
        """搜索最优架构"""
        scores = []
        
        for candidate in candidates:
            score = self.scorer.final_score(candidate)
            scores.append((candidate, score))
        
        # 排序并返回top-k
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_k]
    
    def get_top_architecture(self, n_iterations=100, batch_size=20):
        """迭代搜索"""
        best_arch = None
        best_score = float('-inf')
        
        for _ in range(n_iterations):
            # 采样候选架构
            candidates = [self._sample_architecture() 
                         for _ in range(batch_size)]
            
            # 评估
            results = self.search(candidates, top_k=1)
            top_arch, top_score = results[0]
            
            # 更新最优
            if top_score > best_score:
                best_arch = top_arch
                best_score = top_score
        
        return best_arch
    
    def _sample_architecture(self):
        """采样一个随机架构"""
        # 采样细胞结构
        ...

参考文献

Footnotes

  1. Li J, Liu Y, Chen J, et al. Zero-Cost Neural Architecture Search: A Comprehensive Survey. IEEE TPAMI 2024.

  2. Lee N, Ajanthan T, Torr P. SNIP: Single-shot Network Pruning based on Connection Sensitivity. ICLR 2019.

  3. Wang Z, Wohlby P, O’Connor G, et al. GRASP: Pruning by Gradient Signal Preservation. arXiv 2019.

  4. Tanaka H, Kunin D, Yamins D, Ganguli S. Pruning at Initialization via Synaptic Flow. ICLR 2020.

  5. Mellor J, Turner J, Storkey A, Crowley E. Neural Architecture Search without Training. ICML 2021.

  6. Lin MX, Wang PF, Han H, et al. zen-NAS: Zero-shot Neural Architecture Search. ICML 2021.

  7. Shah H, Nguyen K, Tagliasacchi M, et al. NAS evaluation is unfortunately hard. ICML 2020.

  8. Lee S, Ham B. AZ-NAS: Automated Zero-Cost NAS with Stochastic Combination of Proxy Metrics. NeurIPS 2024.