免训练NAS:零代价代理指标
免训练神经网络架构搜索(Zero-Cost NAS)1是近年来最重要的研究突破之一,它通过设计零代价代理指标(Zero-Cost Proxy Metrics),无需完整训练即可预测神经网络架构的性能,大幅降低了NAS的计算成本。
1. 核心思想
1.1 问题背景
传统NAS方法面临的核心问题是评估代价高昂:
| 方法 | 评估代价 | 搜索空间大小 | 总代价 |
|---|---|---|---|
| 完整训练 | ~1 GPU day | 不现实 | |
| 权重共享 | ~0.01 GPU day | ~ | |
| 零代价代理 | ~0 GPU day | 可行 |
1.2 基本假设
零代价代理基于以下假设:
核心假设:神经网络的初始状态(随机权重)包含了其表达能力的关键信息。
这与神经网络的Lazy Training现象有关:在训练初期,神经网络近似线性,权重的主要结构决定网络的表达能力。
1.3 代理指标分类
┌─────────────────────────────────────────────────────────────┐
│ 零代价代理指标分类 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ 梯度相关 │ │ 拓扑相关 │ │ 信息论相关 │ │
│ ├──────────────┤ ├──────────────┤ ├──────────────┤ │
│ │ • SNIP │ │ • zen-Score │ │ • Entropy │ │
│ │ • GraSP │ │ • FLOPs │ │ • Mutual Inf │ │
│ │ • SynFlow │ │ • Params │ │ • O-inf │ │
│ │ • NASWOT │ │ • grad_norm │ │ • I-inf │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ 随机预言 │ │
│ ├──────────────────────────────────────────────────────┤ │
│ │ • 随机权重精度 • 边缘分布差异 • 梯度相关性 │ │
│ └──────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
2. 梯度相关代理指标
2.1 SNIP(Single-shot Network Pruning)
论文:Single-shot Network Pruning based on Connection Sensitivity2
核心思想:计算每个权重对损失的敏感度,反映其重要性。
敏感度定义
使用一阶泰勒展开近似:
SNIP指标计算
def snip_score(model, train_loader, criterion):
"""计算SNIP敏感度分数"""
model.zero_grad()
# 收集梯度
grads = {}
for input, target in train_loader:
output = model(input)
loss = criterion(output, target)
loss.backward()
for name, param in model.named_parameters():
if param.grad is not None:
if name not in grads:
grads[name] = []
grads[name].append(
(param.grad * param).abs().detach()
)
# 汇总敏感度
scores = {}
for name, grad_list in grads.items():
# 敏感度 = 梯度 × 权重 的绝对值
scores[name] = torch.stack(grad_list).mean(dim=0)
return scores架构级别聚合
其中 是架构 的边集合。
2.2 GraSP(Gradient Signal Preservation)
论文:PRuning by Gradient Signal Preservation3
核心思想:选择保留最多梯度信息的权重/操作。
优化目标
其中 是Hessian矩阵, 是掩码向量。
近似计算
由于Hessian计算代价高,使用Fisher信息矩阵近似:
2.3 SynFlow(Synaptic Flow)
论文:Synaptic Flow: Pruning at Initialization4
核心思想:网络的”流动能力”决定了其表达能力。
递归信息流
设网络为层间连接矩阵的乘积:
定理:无论网络结构和权重如何初始化,只要网络是全初始化的(无稀疏),递归信息流保持恒定:
其中 为全1矩阵。
SynFlow分数
def synflow_score(model):
"""计算SynFlow分数"""
# 重置所有参数为1
for param in model.parameters():
param.data.fill_(1.0)
# 递归反向传播计算分数
def backward_synflow(module):
if hasattr(module, 'weight') and module.weight is not None:
grad = module.weight.grad
if grad is None:
module.weight.grad = (module.weight * torch.ones_like(module.weight)).data
else:
module.weight.grad = (module.weight * grad).data
for child in module.children():
backward_synflow(child)
# 前向传播创建计算图
output = model(torch.ones_like(input))
output.sum().backward()
# 计算分数
scores = {}
for name, param in model.named_parameters():
if param.grad is not None:
scores[name] = (param.abs() * param.grad.abs()).detach()
return scores2.4 NASWOT(NAS without Training)
论文:Neural Architecture Search without Training5
核心思想:利用标签不可知的激活相关性预测性能。
NAK(Neural Architecture Knowledge)
定义架构 的NKN矩阵:
其中 是第 个样本的隐藏激活。
目标函数
是由参数 控制的预测NKN。
最终指标
高行列式意味着节点间信息传递更独立。
3. 拓扑相关代理指标
3.1 zen-Score
论文:zen-NAS6
核心思想:衡量网络的信息瓶颈能力。
zen定义
其中 是第 次随机前向传播的中间激活。
直觉:信息瓶颈越强,网络压缩信息的能力越好,表达能力越强。
def zen_score(model, input, n_samples=100):
"""计算zen-NAS分数"""
model.eval()
# 注册hook收集中间激活
activations = []
handles = []
def hook_fn(module, input, output):
activations.append(output.detach())
for module in model.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
handles.append(module.register_forward_hook(hook_fn))
# 多次随机前向传播
with torch.no_grad():
for _ in range(n_samples):
# 随机dropout模拟随机性
noise_input = input + torch.randn_like(input) * 0.1
model(noise_input)
# 移除hook
for handle in handles:
handle.remove()
# 计算方差
stacked = torch.stack(activations)
variances = torch.var(stacked, dim=0).mean()
return variances.item()3.2 参数量与FLOPs
最简单的代理指标,但相关性有限:
def compute_flops(model, input_shape):
"""计算FLOPs"""
from thop import profile
flops, params = profile(model, inputs=(torch.randn(input_shape),))
return flops, params3.3 Grad-Norm
核心思想:梯度范数反映参数的重要性。
4. 信息论相关代理指标
4.1 O-Information(O-inf)
核心思想:衡量特征的冗余/协同信息。
部分信息分解(PID)
对于多个变量 和目标 ,信息可分解为:
O-Information定义
其中 是在去掉 后的互信息。
4.2 熵基代理
def entropy_proxy(model, train_loader):
"""基于激活熵的代理指标"""
entropies = []
for input, _ in train_loader:
output = model(input)
# softmax后的熵
probs = F.softmax(output, dim=-1)
entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
entropies.append(entropy.item())
return np.mean(entropies)5. 随机预言代理
5.1 随机权重精度
核心思想:随机初始化权重下网络的预测精度。
虽然直觉上不可信,但研究发现7随机权重精度与训练后精度存在正相关:
┌────────────────────────────────────────┐
│ 随机权重精度 vs 真实精度 (CIFAR-10) │
│ │
│ 真实精度 │
│ ↑ ● ● │
│ │ ● ● │
│ │ ● ● │
│ │● ● │
│ └────────────────────────→ │
│ 随机权重精度 │
└────────────────────────────────────────┘
5.2 边缘分布差异
衡量训练样本与测试样本在网络激活空间的分布差异:
其中 MMD 是最大均值差异。
6. AZ-NAS:多代理组合
6.1 核心思想
单一代理指标往往不够准确,AZ-NAS8提出组合多个代理指标的框架。
组合框架
其中 是代理指标集合, 是组合权重。
6.2 代理相关性分析
| 代理 | SNIP | GraSP | SynFlow | zen |
|---|---|---|---|---|
| SNIP | 1.00 | 0.72 | 0.45 | 0.38 |
| GraSP | 0.72 | 1.00 | 0.52 | 0.41 |
| SynFlow | 0.45 | 0.52 | 1.00 | 0.33 |
| zen | 0.38 | 0.41 | 0.33 | 1.00 |
观察:不同代理指标具有互补性,组合后效果更好。
6.3 AZ-NAS PyTorch实现
class AZNASScorer:
"""AZ-NAS: 多代理组合打分器"""
def __init__(self, model, train_loader, device='cuda'):
self.model = model
self.train_loader = train_loader
self.device = device
self.proxy_metrics = {
'snip': self._snip_score,
'synflow': self._synflow_score,
'zen': self._zen_score,
'grad_norm': self._grad_norm_score,
}
self.weights = self._learn_weights()
def _snip_score(self):
"""SNIP分数"""
# 收集梯度
...
return scores
def _synflow_score(self):
"""SynFlow分数"""
...
return scores
def _zen_score(self):
"""zen分数"""
...
return scores
def _grad_norm_score(self):
"""梯度范数分数"""
...
return scores
def _learn_weights(self):
"""学习代理权重"""
# 使用NAS-Bench-201等基准数据集
# 学习最小二乘权重
from sklearn.linear_model import Ridge
# 收集所有架构的代理分数和真实准确率
X, y = [], []
for arch in benchmark_architectures:
scores = self.compute_all_scores(arch)
X.append(list(scores.values()))
y.append(arch.true_accuracy)
# 拟合线性组合
ridge = Ridge(alpha=0.1)
ridge.fit(X, y)
# 归一化权重
weights = {
name: w for name, w in zip(self.proxy_metrics.keys(), ridge.coef_)
}
weights = {k: v / sum(abs(w) for w in weights.values())
for k, v in weights.items()}
return weights
def compute_all_scores(self, arch):
"""计算所有代理分数"""
scores = {}
for name, func in self.proxy_metrics.items():
scores[name] = func(arch)
return scores
def final_score(self, arch):
"""最终组合分数"""
scores = self.compute_all_scores(arch)
return sum(self.weights[m] * scores[m] for m in scores)7. 零代价代理对比分析
7.1 在NAS-Bench-201上的性能
| 代理 | Kendall’s τ | Spearman’s ρ | 与训练时间的相关性 |
|---|---|---|---|
| Params | 0.12 | 0.18 | 无 |
| FLOPs | 0.15 | 0.21 | 无 |
| SNIP | 0.31 | 0.42 | 中 |
| GraSP | 0.33 | 0.45 | 中 |
| SynFlow | 0.28 | 0.39 | 中 |
| zen | 0.35 | 0.48 | 中 |
| NASWOT | 0.38 | 0.51 | 高 |
| AZ-NAS | 0.52 | 0.68 | 高 |
7.2 计算效率对比
| 方法 | 计算时间 | 需要的操作 |
|---|---|---|
| 完整训练 | ~1 GPU day | 训练整个网络 |
| DARTS | ~0.5 GPU day | 联合优化 |
| SNIP | ~10 seconds | 单次前向+反向 |
| SynFlow | ~30 seconds | L次迭代 |
| zen | ~60 seconds | N次随机前向 |
| AZ-NAS | ~2 minutes | 所有代理求和 |
7.3 选择指南
┌─────────────────────────────────────────────────────────────┐
│ 零代价代理选择指南 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ◉ 追求最高准确性 → AZ-NAS(多代理组合) │
│ │
│ ◉ 追求实现简便 → zen-Score(单一代理,效果好) │
│ │
│ ◉ 追求计算效率 → SynFlow(仅需初始化修改) │
│ │
│ ◉ 搜索空间较小 → SNIP + NASWOT 组合 │
│ │
│ ◉ 移动端部署 → FLOPs + 延迟联合指标 │
│ │
└─────────────────────────────────────────────────────────────┘
8. 实践代码
8.1 完整零代价NAS框架
class ZeroCostNAS:
"""完整的零代价NAS框架"""
def __init__(self, supernet, train_loader, device='cuda'):
self.supernet = supernet
self.train_loader = train_loader
self.device = device
self.scorer = AZNASScorer(supernet, train_loader, device)
def search(self, candidates, top_k=10):
"""搜索最优架构"""
scores = []
for candidate in candidates:
score = self.scorer.final_score(candidate)
scores.append((candidate, score))
# 排序并返回top-k
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k]
def get_top_architecture(self, n_iterations=100, batch_size=20):
"""迭代搜索"""
best_arch = None
best_score = float('-inf')
for _ in range(n_iterations):
# 采样候选架构
candidates = [self._sample_architecture()
for _ in range(batch_size)]
# 评估
results = self.search(candidates, top_k=1)
top_arch, top_score = results[0]
# 更新最优
if top_score > best_score:
best_arch = top_arch
best_score = top_score
return best_arch
def _sample_architecture(self):
"""采样一个随机架构"""
# 采样细胞结构
...参考文献
Footnotes
-
Li J, Liu Y, Chen J, et al. Zero-Cost Neural Architecture Search: A Comprehensive Survey. IEEE TPAMI 2024. ↩
-
Lee N, Ajanthan T, Torr P. SNIP: Single-shot Network Pruning based on Connection Sensitivity. ICLR 2019. ↩
-
Wang Z, Wohlby P, O’Connor G, et al. GRASP: Pruning by Gradient Signal Preservation. arXiv 2019. ↩
-
Tanaka H, Kunin D, Yamins D, Ganguli S. Pruning at Initialization via Synaptic Flow. ICLR 2020. ↩
-
Mellor J, Turner J, Storkey A, Crowley E. Neural Architecture Search without Training. ICML 2021. ↩
-
Lin MX, Wang PF, Han H, et al. zen-NAS: Zero-shot Neural Architecture Search. ICML 2021. ↩
-
Shah H, Nguyen K, Tagliasacchi M, et al. NAS evaluation is unfortunately hard. ICML 2020. ↩
-
Lee S, Ham B. AZ-NAS: Automated Zero-Cost NAS with Stochastic Combination of Proxy Metrics. NeurIPS 2024. ↩