Zero-Shot NAS方法综述

1. 概述

Zero-Shot NAS(零代价NAS)是一类无需训练即可预测架构性能的NAS方法1。传统NAS方法需要完整训练每个候选架构,计算成本高昂,而Zero-Shot方法通过设计代理指标(Proxy Metrics)来大幅降低评估成本。

核心思想: 架构的”可训练性”与其最终性能相关,因此可以通过训练前的某些统计量来预测性能。


2. 梯度度量方法

2.1 SNIP (Single-shot Network Pruning)

论文: Singh et al., “SNIP: Single-shot Network Pruning Based on Connection Sensitivity” (ICLR 2019, adapted for NAS)

核心指标: 基于梯度-权重乘积的连接重要性

原理: 梯度的幅度反映了权重对损失的影响程度。高敏感度的连接在训练中更重要。

算法流程:

def snip_score(model):
    """计算SNIP重要性分数"""
    # 获取数据和损失
    x, y = next(iter(dataloader))
    output = model(x)
    loss = criterion(output, y)
    
    # 计算梯度
    gradients = torch.autograd.grad(loss, model.parameters())
    
    # 计算敏感性分数
    scores = []
    for w, g in zip(model.parameters(), gradients):
        scores.append(torch.abs(w * g))
    
    return scores

2.2 GraSP (Gradient Signal Preservation)

论文: Wang et al., “Greedy Optimization Provably Wins the Lottery” (ICML 2020)

核心指标: 基于梯度流重要性的度量

原理: 保留那些对梯度流贡献大的权重,移除贡献小的权重。

与SNIP的区别:

特性SNIPGraSP
优化目标保持损失不变保持梯度流
公式乘积形式平方形式
理论基础敏感性分析Hessian近似

2.3 Synflow (Synthetic Gradient)

论文: Tanaka et al., “A Theory and Neural Architecture Search Based Approach” (NeurIPS 2020)

核心指标: 基于网络流敏感性的度量

原理: 忽略数据分布,通过计算合成输入下的梯度流来评估架构。

合成梯度方法:

def synflow_score(model):
    """计算Synflow重要性分数"""
    @torch.no_grad()
    def linearize(model):
        """线性化模型"""
        signs = {}
        for name, param in model.named_parameters():
            signs[name] = torch.sign(param)
            param.abs_()
        return signs
    
    @torch.no_grad()
    def nonlinearize(model, signs):
        """恢复原始符号"""
        for name, param in model.named_parameters():
            param.mul_(signs[name])
    
    signs = linearize(model)
    output = model(torch.ones_like(next(iter(dataloader))[0]))
    loss = output.sum()
    gradients = torch.autograd.grad(loss, model.parameters())
    nonlinearize(model, signs)
    
    return [g.abs() for g in gradients]

3. NTK分析方法

3.1 理论基础

神经正切核(Neural Tangent Kernel, NTK)理论为Zero-Shot NAS提供了理论基础。

NTK定义: 在无限宽度极限下,神经网络的学习动态可以解析地由NTK描述:

3.2 NTK基代理

论文: Chen et al., “Neural Architecture Search with Karush-Kuhn-Tucker Guidance” (ICLR 2022)

核心思想: 使用NTK的特征值分布作为架构性能指标。

代理指标:

其中是最大特征值,是条件数。

3.3 NTK范数方法

论文: Park et al., “A Unified Approach to Neural Architecture Search” (ICLR 2022)

核心指标: NTK的F范数

直觉: 更大的NTK范数意味着更强的函数空间覆盖能力。


4. 频谱分析方法

4.1 特征值分布分析

核心思想: 权重矩阵的谱(特征值分布)反映了网络的表达能力。

代理指标:

其中是权重矩阵的归一化特征值分布,是熵函数。

4.2 谱熵与性能相关性

实验发现:适当大小的谱熵与架构性能正相关。

  • 低谱熵:权重高度集中在少数方向,表达能力受限
  • 高谱熵:权重分布过于均匀,可能缺乏特异性
  • 适中谱熵:平衡表达能力和特异性

4.3 Dextr方法

论文: Asthana et al., “Dextr: Zero-Shot NAS with SVD and Extrinsic Curvature” (arXiv 2508.12977)

核心创新: 结合SVD和外部曲率分析

其中是损失景观的外部曲率。


5. 因果推断方法

5.1 因果视角的NAS问题

问题定义: 传统方法混淆了架构与权重的因果效应。

因果图:

其中是架构,是权重,是性能。

5.2 从One-Shot到Zero-Shot的因果分析

论文: “From One to Zero: Causal Zero-Shot NAS” (NeurIPS 2023)

核心贡献: 使用do-calculus解耦架构效应

内在代理(Intrinsic Proxy):

5.3 因果零代价框架

框架组成:

  1. 因果结构学习: 识别架构-权重-性能因果链
  2. 介入评估: 使用do-calculus估计架构的因果效应
  3. 代理学习: 学习代理指标到因果效应的映射

6. 其他零代价代理方法

6.1 NASWOT (NAS Without Training)

论文: Abdelfattah et al., “NAS-WOT: Neural Architecture Search Without Training” (ICML 2020)

核心指标: 基于激活相关性的度量

6.2 Buter

改进: 使用互信息而非相关性

6.3 TE-NAS (Training-free)

核心思想: 结合染色数和网络流分析


7. 方法比较与理论分析

7.1 方法对比表

方法理论基础数据依赖计算复杂度精度
SNIP敏感性分析中等
GraSPHessian分析中等
Synflow网络流较高
NTK基核理论较高
频谱谱理论中等
因果推断因果理论部分
Dextr曲率分析待验证

7.2 理论局限性

  1. 代理-性能相关性假设: 假设可训练性=性能,可能不总是成立
  2. 数据依赖性: 多数方法仍依赖数据分布
  3. 尺度问题: 小规模验证的结论可能不适用于大规模
  4. 分布偏移: 搜索空间和评估任务的差异

8. 实践指南

8.1 方法选择建议

场景推荐方法原因
快速筛选Synflow无数据依赖
高精度NTK基/因果推断理论更完善
资源受限SNIP计算简单
大规模搜索Dextr可并行化

8.2 实现注意事项

class ZeroShotNAS:
    def __init__(self, method='synflow'):
        self.method = method
        
    def score_architecture(self, model, dataloader):
        if self.method == 'snip':
            return self._snip_score(model, dataloader)
        elif self.method == 'synflow':
            return self._synflow_score(model)
        elif self.method == 'ntk':
            return self._ntk_score(model, dataloader)
    
    def _synflow_score(self, model):
        """Synflow方法:推荐使用"""
        # 线性化 + 合成梯度
        # 无需真实数据
        pass

8.3 评估协议

  1. 在NAS-Bench-201等基准上验证代理质量
  2. 计算代理分数与真实准确率的Rank相关性(Spearman/ Kendall)
  3. 评估Top-K选择准确率

9. 相关主题


参考文献

Footnotes

  1. Abdelfattah, M. S., et al. (2020). NAS-WOT: Neural Architecture Search Without Training. ICML 2020.