NAS基准测试
NAS基准测试库是解决神经架构搜索可复现性问题的关键基础设施。本文档详细介绍主流NAS基准的设计、特点和使用方法。
1. 问题背景
1.1 可复现性危机
NAS领域面临严重的可复现性危机1:
| 问题类型 | 具体表现 |
|---|---|
| 搜索空间不一致 | 不同论文使用不同搜索空间 |
| 训练配置差异 | 学习率、数据增强等超参数不同 |
| 评估指标不统一 | Top-1/Top-5、原始/微调 |
| 硬件差异 | GPU型号、批量大小差异 |
| 随机种子 | 不同随机初始化影响结果 |
1.2 基准测试的必要性
┌─────────────────────────────────────────────────────────────┐
│ 基准测试的价值 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ✓ 公平比较: 所有方法在同一环境中评估 │
│ ✓ 快速评估: 无需完整训练,快速验证想法 │
│ ✓ 可复现性: 固定数据集,消除随机性 │
│ ✓ 科学进步: 促进方法论改进 │
│ │
└─────────────────────────────────────────────────────────────┘
2. NAS-Bench-101
2.1 概述
NAS-Bench-1012是首个大规模NAS基准,包含超过423,000个预训练网络。
核心特点:
- 固定搜索空间:基于有向无环图(DAG)
- 多任务评估:CIFAR-10, CIFAR-100, ImageNet
- 预训练模型:所有架构已完整训练
2.2 搜索空间定义
# NAS-Bench-101搜索空间
SEARCH_SPACE = {
'num_nodes': 7, # 节点数
'num_edges': 9, # 边数
'ops': [
'none', # 无操作
'skip_connect', # 跳跃连接
'conv_3x3', # 3x3卷积
'conv_3x3_relur', # ReLU后卷积
'conv_1x1', # 1x1卷积
'max_pool_3x3', # 3x3最大池化
'avg_pool_3x3', # 3x3平均池化
],
'max_edges': 9, # 最大边数
}2.3 架构编码
每个架构用矩阵表示:
adjacency matrix (7x7):
┌─────────────────┐
│ 0 1 0 0 0 0 0 │
│ 0 0 1 1 0 0 0 │
│ 0 0 0 1 1 0 0 │
│ 0 0 0 0 1 1 0 │
│ 0 0 0 0 0 1 1 │
│ 0 0 0 0 0 0 1 │
│ 0 0 0 0 0 0 0 │
└─────────────────┘
operation list: [conv_3x3, max_pool, skip, conv_1x1, avg_pool, conv_3x3, none]
2.4 API使用
from nasbench import api as nb_api
# 加载基准
nasbench = nb_api.NASBench101(
'/path/to/nasbench.tfrecord',
师父_cache_size=1000
)
# 查询架构性能
model_spec = nb_api.ModelSpec(
matrix=adjacency_matrix,
ops=operation_list
)
# 获取完整训练历史
metrics = nasbench.query(model_spec)
# 返回结果
{
'train_accuracy': 95.32,
'valid_accuracy': 91.56,
'test_accuracy': 90.23,
'training_time': 1234.5,
'parameters': 2.34e6,
'flops': 45.2e6,
}2.5 使用示例
def compare_nas_methods(nasbench):
"""比较不同NAS方法的性能"""
results = {}
# 1. 随机搜索
random_search_results = random_search(nasbench, n_samples=500)
results['random'] = {
'mean_acc': np.mean([r['test_accuracy'] for r in random_search_results]),
'std_acc': np.std([r['test_accuracy'] for r in random_search_results]),
'best_acc': max([r['test_accuracy'] for r in random_search_results])
}
# 2. 进化算法
evolution_results = evolution_search(nasbench, n_generations=100)
results['evolution'] = {...}
# 3. DARTS
darts_results = darts_search(nasbench, epochs=50)
results['darts'] = {...}
return results
def random_search(nasbench, n_samples=500):
"""随机搜索"""
results = []
for _ in range(n_samples):
# 生成随机架构
model_spec = nasbench.random_resample()
# 查询
metrics = nasbench.query(model_spec)
results.append(metrics)
# 进度
if (len(results)) % 100 == 0:
print(f"Completed {len(results)} samples")
return results3. NAS-Bench-201
3.1 概述
NAS-Bench-2013扩展了搜索空间,支持更复杂的操作和连接模式。
核心特点:
- 4个预训练数据集:CIFAR-10, CIFAR-100, ImageNet-16-120
- 更丰富的操作集:包括注意力机制
- 更小的搜索空间:便于方法比较
3.2 搜索空间定义
# NAS-Bench-201搜索空间
SEARCH_SPACE_201 = {
'num_nodes': 4, # 节点数
'num_ops': 5, # 每个节点的操作数
'ops': [
'none', # skip (via identity)
'skip_connect', # skip (via structure)
'conv_3x3', # 3x3卷积
'conv_1x1', # 1x1卷积
'avg_pool_3x3', # 3x3平均池化
],
}3.3 邻接矩阵表示
节点0 → 节点1 → 节点2 → 节点3
↓ ↓ ↓ ↓
输入 op[1] op[2] op[3] → 输出
邻接矩阵 (4x4):
┌─────────────────┐
│ 0 1 0 0 │ ← 节点0指向节点1
│ 0 0 1 0 │ ← 节点1指向节点2
│ 0 0 0 1 │ ← 节点2指向节点3
│ 0 0 0 0 │ ← 节点3无输出
└─────────────────┘
3.4 API使用
from nas_201_api import NASBench201API
# 加载基准
api = NASBench201API('/path/to/NAS-Bench-201.h5')
# 获取所有架构信息
total_archs = api.get_algo_count()
print(f"Total architectures: {total_archs}") # 15,625
# 查询架构
arch_index = 1234
metrics = api.query_by_index(arch_index, 'cifar10')
# 或通过邻接矩阵查询
adj = [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]]
ops = ['none', 'conv_3x3', 'conv_1x1', 'avg_pool_3x3']
metrics = api.query_meta_info([adj, ops], None, 'cifar10')
# 获取训练曲线
info = api.get_more_info(arch_index, 'cifar10', use_12epochs=False)
# 返回训练时间、验证准确率历史等3.5 性能对比
| 架构 | CIFAR-10 | CIFAR-100 | ImageNet-16 |
|---|---|---|---|
| 最优架构 | 94.37% | 73.51% | 47.31% |
| 次优架构 | 94.27% | 73.33% | 47.13% |
| 平均架构 | 87.23% | 60.42% | 30.11% |
| 随机架构 | 84.23% | 56.72% | 25.82% |
4. NATS-Bench
4.1 概述
NATS-Bench4是NAS-Bench-201的扩展,包含**拓扑(Topology)和大小(Size)**两个搜索空间。
核心特点:
- 拓扑空间:与NAS-Bench-201相同(15,625个架构)
- 大小空间:更大的通道数和深度变化(32,768个配置)
- 更全面的评估:不同训练配置的结果
4.2 大小搜索空间
# NATS-Bench大小搜索空间
SIZE_SPACE = {
'depth': [1, 2, 3, 4], # 网络深度
'width': [8, 16, 24, 32], # 通道数
'num_nodes': 4, # 节点数
'resolution': [32, 64, 96, 128], # 输入分辨率
}4.3 API使用
from nats_bench import NATS
# 加载NATS-Bench
api = NATS('/path/to/NATS-Bench-4-size')
# 查询大小空间
config = {
'depth': 3,
'width': 16,
'num_nodes': 4,
'resolution': 32
}
metrics = api.query(config, 'cifar10')
# 获取拓扑空间性能
api_topology = NATS('/path/to/NATS-Bench-4-topology')
topo_metrics = api_topology.query_by_index(1234, 'cifar10')5. 搜索空间比较
5.1 搜索空间规模
| 基准 | 架构数量 | 搜索空间大小 |
|---|---|---|
| NAS-Bench-101 | 423,624 | |
| NAS-Bench-201 | 15,625 | |
| NATS-Bench-Topo | 15,625 | |
| NATS-Bench-Size | 32,768 |
5.2 操作集比较
| 操作 | NB-101 | NB-201 | NB-301 |
|---|---|---|---|
| none | ✓ | ✓ | ✓ |
| skip_connect | ✓ | ✓ | ✓ |
| conv_3x3 | ✓ | ✓ | ✓ |
| conv_1x1 | ✓ | ✓ | ✓ |
| max_pool | ✓ | ✗ | ✗ |
| avg_pool | ✓ | ✓ | ✓ |
| sep_conv | ✓ | ✗ | ✓ |
| dil_conv | ✓ | ✗ | ✗ |
5.3 评估数据集
| 基准 | CIFAR-10 | CIFAR-100 | ImageNet-16-120 | 其他 |
|---|---|---|---|---|
| NAS-Bench-101 | ✓ | ✓ | ✗ | ✓ |
| NAS-Bench-201 | ✓ | ✓ | ✓ | ✗ |
| NATS-Bench | ✓ | ✓ | ✓ | ✗ |
6. 公平比较实践
6.1 评估协议
class NASBenchmarkEvaluator:
"""标准NAS基准评估器"""
def __init__(self, benchmark_name):
self.benchmark = self._load_benchmark(benchmark_name)
def evaluate(self, method, n_trials=10):
"""评估方法"""
results = []
for trial in range(n_trials):
# 设置随机种子
set_seed(trial)
# 运行方法
arch = method.search()
# 获取性能
if isinstance(self.benchmark, NASBench201API):
metrics = self.benchmark.query_by_index(arch, 'cifar10')
else:
metrics = self.benchmark.query(arch)
results.append({
'trial': trial,
'arch': arch,
'test_acc': metrics['test_accuracy'],
'val_acc': metrics['valid_accuracy'],
'train_time': metrics['training_time'],
})
return self._summarize(results)
def _summarize(self, results):
"""汇总结果"""
return {
'mean_test_acc': np.mean([r['test_acc'] for r in results]),
'std_test_acc': np.std([r['test_acc'] for r in results]),
'mean_train_time': np.mean([r['train_time'] for r in results]),
'best_test_acc': max([r['test_acc'] for r in results]),
}6.2 常用指标
| 指标 | 描述 | 计算方法 |
|---|---|---|
| Best Accuracy | 最优架构准确率 | |
| Mean Accuracy | 平均准确率 | |
| Regret | 与最优的差距 | |
| Correlation | 与真实性能相关性 | Spearman ρ / Kendall τ |
| Efficiency | 搜索效率 |
6.3 方法对比框架
def compare_methods(benchmark, methods, n_trials=10):
"""比较多种NAS方法"""
evaluator = NASBenchmarkEvaluator(benchmark)
comparison = {}
for name, method in methods.items():
print(f"Evaluating {name}...")
results = evaluator.evaluate(method, n_trials=n_trials)
comparison[name] = results
# 打印摘要
print(f" Mean Acc: {results['mean_test_acc']:.2f}% ± {results['std_test_acc']:.2f}%")
print(f" Best Acc: {results['best_test_acc']:.2f}%")
print(f" Mean Time: {results['mean_train_time']:.2f}s")
print()
return comparison
# 使用示例
methods = {
'random': RandomSearch(),
'evolution': EvolutionSearch(population_size=50),
'darts': DARTSearch(epochs=50),
'bayes': BayesianOptimization(),
}
results = compare_methods('nasbench201', methods)7. 基准使用注意事项
7.1 常见陷阱
| 陷阱 | 说明 | 避免方法 |
|---|---|---|
| 查询偏差 | 重复查询同一架构导致过拟合 | 记录查询次数 |
| 超参数不一致 | 不同方法使用不同训练超参数 | 使用统一评估协议 |
| 早停滥用 | 允许早停导致不公平比较 | 固定训练时长 |
| 随机性忽视 | 未报告多次运行的方差 | 多次运行取平均 |
7.2 正确使用方式
# ✓ 正确做法
def proper_usage(api):
# 1. 使用统一的评估配置
eval_config = {
'dataset': 'cifar10',
'epochs': 12, # 或108
'use_cutout': True,
'use_aug': True,
}
# 2. 多次运行取平均
results = []
for seed in range(5):
arch = search_method(seed=seed)
metrics = api.query_by_index(arch, dataset='cifar10')
results.append(metrics)
mean_acc = np.mean([r['test_accuracy'] for r in results])
# 3. 报告完整信息
return {
'mean': mean_acc,
'std': np.std([r['test_accuracy'] for r in results]),
'n_trials': len(results),
}
# ✗ 错误做法
def improper_usage(api):
# 只查询一次
arch = search_method()
metrics = api.query_by_index(arch)
return metrics['test_accuracy']8. 其他基准资源
8.1 特定任务基准
| 基准 | 任务 | 说明 |
|---|---|---|
| NAS-Bench-301 | 超分辨率 | 超分辨率网络搜索 |
| Transformer Bench | NLP | Transformer架构搜索 |
| NAS-Bench-Macro | 检测/分割 | 目标检测/分割 |
| Once-for-All | 多平台 | 跨平台部署 |
8.2 开源工具
# NASBench相关库
pip install nasbench # NAS-Bench-101
pip install nas-201-api # NAS-Bench-201
pip install nats-bench # NATS-Bench
# 使用示例
from nasbench import api as nb
from nas_201_api import NASBench201API
from nats_bench import NATS参考文献
Footnotes
-
Ying CX, Klein A, Christiansen E, et al. NAS-Bench-101: Towards Reproducible Neural Architecture Search. ICML 2019. ↩
-
Dong X, Yang Y. NAS-Bench-201: Extending the Benchmark for Neural Architecture Search. NeurIPS 2020. ↩
-
Dong X, Yang Y. NAS-Bench-201: Extending the Benchmark for Neural Architecture Search. CVPR 2020. ↩
-
Li C, Peng Z, Yuan Z, et al. NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size. CVPR 2022. ↩