NAS基准测试

NAS基准测试库是解决神经架构搜索可复现性问题的关键基础设施。本文档详细介绍主流NAS基准的设计、特点和使用方法。


1. 问题背景

1.1 可复现性危机

NAS领域面临严重的可复现性危机1

问题类型具体表现
搜索空间不一致不同论文使用不同搜索空间
训练配置差异学习率、数据增强等超参数不同
评估指标不统一Top-1/Top-5、原始/微调
硬件差异GPU型号、批量大小差异
随机种子不同随机初始化影响结果

1.2 基准测试的必要性

┌─────────────────────────────────────────────────────────────┐
│              基准测试的价值                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ✓ 公平比较: 所有方法在同一环境中评估                        │
│  ✓ 快速评估: 无需完整训练,快速验证想法                      │
│  ✓ 可复现性: 固定数据集,消除随机性                          │
│  ✓ 科学进步: 促进方法论改进                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. NAS-Bench-101

2.1 概述

NAS-Bench-1012是首个大规模NAS基准,包含超过423,000个预训练网络。

核心特点

  • 固定搜索空间:基于有向无环图(DAG)
  • 多任务评估:CIFAR-10, CIFAR-100, ImageNet
  • 预训练模型:所有架构已完整训练

2.2 搜索空间定义

# NAS-Bench-101搜索空间
SEARCH_SPACE = {
    'num_nodes': 7,           # 节点数
    'num_edges': 9,           # 边数
    'ops': [
        'none',              # 无操作
        'skip_connect',      # 跳跃连接
        'conv_3x3',         # 3x3卷积
        'conv_3x3_relur',   # ReLU后卷积
        'conv_1x1',         # 1x1卷积
        'max_pool_3x3',     # 3x3最大池化
        'avg_pool_3x3',     # 3x3平均池化
    ],
    'max_edges': 9,           # 最大边数
}

2.3 架构编码

每个架构用矩阵表示:

 adjacency matrix (7x7):
 ┌─────────────────┐
 │ 0 1 0 0 0 0 0 │
 │ 0 0 1 1 0 0 0 │
 │ 0 0 0 1 1 0 0 │
 │ 0 0 0 0 1 1 0 │
 │ 0 0 0 0 0 1 1 │
 │ 0 0 0 0 0 0 1 │
 │ 0 0 0 0 0 0 0 │
 └─────────────────┘
 
 operation list: [conv_3x3, max_pool, skip, conv_1x1, avg_pool, conv_3x3, none]

2.4 API使用

from nasbench import api as nb_api
 
# 加载基准
nasbench = nb_api.NASBench101(
    '/path/to/nasbench.tfrecord',
   师父_cache_size=1000
)
 
# 查询架构性能
model_spec = nb_api.ModelSpec(
    matrix=adjacency_matrix,
    ops=operation_list
)
 
# 获取完整训练历史
metrics = nasbench.query(model_spec)
 
# 返回结果
{
    'train_accuracy': 95.32,
    'valid_accuracy': 91.56,
    'test_accuracy': 90.23,
    'training_time': 1234.5,
    'parameters': 2.34e6,
    'flops': 45.2e6,
}

2.5 使用示例

def compare_nas_methods(nasbench):
    """比较不同NAS方法的性能"""
    results = {}
    
    # 1. 随机搜索
    random_search_results = random_search(nasbench, n_samples=500)
    results['random'] = {
        'mean_acc': np.mean([r['test_accuracy'] for r in random_search_results]),
        'std_acc': np.std([r['test_accuracy'] for r in random_search_results]),
        'best_acc': max([r['test_accuracy'] for r in random_search_results])
    }
    
    # 2. 进化算法
    evolution_results = evolution_search(nasbench, n_generations=100)
    results['evolution'] = {...}
    
    # 3. DARTS
    darts_results = darts_search(nasbench, epochs=50)
    results['darts'] = {...}
    
    return results
 
def random_search(nasbench, n_samples=500):
    """随机搜索"""
    results = []
    
    for _ in range(n_samples):
        # 生成随机架构
        model_spec = nasbench.random_resample()
        
        # 查询
        metrics = nasbench.query(model_spec)
        results.append(metrics)
        
        # 进度
        if (len(results)) % 100 == 0:
            print(f"Completed {len(results)} samples")
    
    return results

3. NAS-Bench-201

3.1 概述

NAS-Bench-2013扩展了搜索空间,支持更复杂的操作和连接模式。

核心特点

  • 4个预训练数据集:CIFAR-10, CIFAR-100, ImageNet-16-120
  • 更丰富的操作集:包括注意力机制
  • 更小的搜索空间:便于方法比较

3.2 搜索空间定义

# NAS-Bench-201搜索空间
SEARCH_SPACE_201 = {
    'num_nodes': 4,           # 节点数
    'num_ops': 5,            # 每个节点的操作数
    'ops': [
        'none',              # skip (via identity)
        'skip_connect',      # skip (via structure)
        'conv_3x3',         # 3x3卷积
        'conv_1x1',         # 1x1卷积
        'avg_pool_3x3',     # 3x3平均池化
    ],
}

3.3 邻接矩阵表示

节点0 → 节点1 → 节点2 → 节点3
   ↓        ↓        ↓        ↓
  输入    op[1]    op[2]    op[3]   → 输出

邻接矩阵 (4x4):
┌─────────────────┐
│ 0 1 0 0 │ ← 节点0指向节点1
│ 0 0 1 0 │ ← 节点1指向节点2
│ 0 0 0 1 │ ← 节点2指向节点3
│ 0 0 0 0 │ ← 节点3无输出
└─────────────────┘

3.4 API使用

from nas_201_api import NASBench201API
 
# 加载基准
api = NASBench201API('/path/to/NAS-Bench-201.h5')
 
# 获取所有架构信息
total_archs = api.get_algo_count()
print(f"Total architectures: {total_archs}")  # 15,625
 
# 查询架构
arch_index = 1234
metrics = api.query_by_index(arch_index, 'cifar10')
 
# 或通过邻接矩阵查询
adj = [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]]
ops = ['none', 'conv_3x3', 'conv_1x1', 'avg_pool_3x3']
metrics = api.query_meta_info([adj, ops], None, 'cifar10')
 
# 获取训练曲线
info = api.get_more_info(arch_index, 'cifar10', use_12epochs=False)
# 返回训练时间、验证准确率历史等

3.5 性能对比

架构CIFAR-10CIFAR-100ImageNet-16
最优架构94.37%73.51%47.31%
次优架构94.27%73.33%47.13%
平均架构87.23%60.42%30.11%
随机架构84.23%56.72%25.82%

4. NATS-Bench

4.1 概述

NATS-Bench4是NAS-Bench-201的扩展,包含**拓扑(Topology)大小(Size)**两个搜索空间。

核心特点

  • 拓扑空间:与NAS-Bench-201相同(15,625个架构)
  • 大小空间:更大的通道数和深度变化(32,768个配置)
  • 更全面的评估:不同训练配置的结果

4.2 大小搜索空间

# NATS-Bench大小搜索空间
SIZE_SPACE = {
    'depth': [1, 2, 3, 4],           # 网络深度
    'width': [8, 16, 24, 32],        # 通道数
    'num_nodes': 4,                   # 节点数
    'resolution': [32, 64, 96, 128], # 输入分辨率
}

4.3 API使用

from nats_bench import NATS
 
# 加载NATS-Bench
api = NATS('/path/to/NATS-Bench-4-size')
 
# 查询大小空间
config = {
    'depth': 3,
    'width': 16,
    'num_nodes': 4,
    'resolution': 32
}
metrics = api.query(config, 'cifar10')
 
# 获取拓扑空间性能
api_topology = NATS('/path/to/NATS-Bench-4-topology')
topo_metrics = api_topology.query_by_index(1234, 'cifar10')

5. 搜索空间比较

5.1 搜索空间规模

基准架构数量搜索空间大小
NAS-Bench-101423,624
NAS-Bench-20115,625
NATS-Bench-Topo15,625
NATS-Bench-Size32,768

5.2 操作集比较

操作NB-101NB-201NB-301
none
skip_connect
conv_3x3
conv_1x1
max_pool
avg_pool
sep_conv
dil_conv

5.3 评估数据集

基准CIFAR-10CIFAR-100ImageNet-16-120其他
NAS-Bench-101
NAS-Bench-201
NATS-Bench

6. 公平比较实践

6.1 评估协议

class NASBenchmarkEvaluator:
    """标准NAS基准评估器"""
    def __init__(self, benchmark_name):
        self.benchmark = self._load_benchmark(benchmark_name)
    
    def evaluate(self, method, n_trials=10):
        """评估方法"""
        results = []
        
        for trial in range(n_trials):
            # 设置随机种子
            set_seed(trial)
            
            # 运行方法
            arch = method.search()
            
            # 获取性能
            if isinstance(self.benchmark, NASBench201API):
                metrics = self.benchmark.query_by_index(arch, 'cifar10')
            else:
                metrics = self.benchmark.query(arch)
            
            results.append({
                'trial': trial,
                'arch': arch,
                'test_acc': metrics['test_accuracy'],
                'val_acc': metrics['valid_accuracy'],
                'train_time': metrics['training_time'],
            })
        
        return self._summarize(results)
    
    def _summarize(self, results):
        """汇总结果"""
        return {
            'mean_test_acc': np.mean([r['test_acc'] for r in results]),
            'std_test_acc': np.std([r['test_acc'] for r in results]),
            'mean_train_time': np.mean([r['train_time'] for r in results]),
            'best_test_acc': max([r['test_acc'] for r in results]),
        }

6.2 常用指标

指标描述计算方法
Best Accuracy最优架构准确率
Mean Accuracy平均准确率
Regret与最优的差距
Correlation与真实性能相关性Spearman ρ / Kendall τ
Efficiency搜索效率

6.3 方法对比框架

def compare_methods(benchmark, methods, n_trials=10):
    """比较多种NAS方法"""
    evaluator = NASBenchmarkEvaluator(benchmark)
    
    comparison = {}
    
    for name, method in methods.items():
        print(f"Evaluating {name}...")
        results = evaluator.evaluate(method, n_trials=n_trials)
        comparison[name] = results
        
        # 打印摘要
        print(f"  Mean Acc: {results['mean_test_acc']:.2f}% ± {results['std_test_acc']:.2f}%")
        print(f"  Best Acc: {results['best_test_acc']:.2f}%")
        print(f"  Mean Time: {results['mean_train_time']:.2f}s")
        print()
    
    return comparison
 
# 使用示例
methods = {
    'random': RandomSearch(),
    'evolution': EvolutionSearch(population_size=50),
    'darts': DARTSearch(epochs=50),
    'bayes': BayesianOptimization(),
}
 
results = compare_methods('nasbench201', methods)

7. 基准使用注意事项

7.1 常见陷阱

陷阱说明避免方法
查询偏差重复查询同一架构导致过拟合记录查询次数
超参数不一致不同方法使用不同训练超参数使用统一评估协议
早停滥用允许早停导致不公平比较固定训练时长
随机性忽视未报告多次运行的方差多次运行取平均

7.2 正确使用方式

# ✓ 正确做法
def proper_usage(api):
    # 1. 使用统一的评估配置
    eval_config = {
        'dataset': 'cifar10',
        'epochs': 12,  # 或108
        'use_cutout': True,
        'use_aug': True,
    }
    
    # 2. 多次运行取平均
    results = []
    for seed in range(5):
        arch = search_method(seed=seed)
        metrics = api.query_by_index(arch, dataset='cifar10')
        results.append(metrics)
    
    mean_acc = np.mean([r['test_accuracy'] for r in results])
    
    # 3. 报告完整信息
    return {
        'mean': mean_acc,
        'std': np.std([r['test_accuracy'] for r in results]),
        'n_trials': len(results),
    }
 
# ✗ 错误做法
def improper_usage(api):
    # 只查询一次
    arch = search_method()
    metrics = api.query_by_index(arch)
    return metrics['test_accuracy']

8. 其他基准资源

8.1 特定任务基准

基准任务说明
NAS-Bench-301超分辨率超分辨率网络搜索
Transformer BenchNLPTransformer架构搜索
NAS-Bench-Macro检测/分割目标检测/分割
Once-for-All多平台跨平台部署

8.2 开源工具

# NASBench相关库
pip install nasbench         # NAS-Bench-101
pip install nas-201-api      # NAS-Bench-201
pip install nats-bench       # NATS-Bench
 
# 使用示例
from nasbench import api as nb
from nas_201_api import NASBench201API
from nats_bench import NATS

参考文献

Footnotes

  1. Ying CX, Klein A, Christiansen E, et al. NAS-Bench-101: Towards Reproducible Neural Architecture Search. ICML 2019.

  2. Dong X, Yang Y. NAS-Bench-201: Extending the Benchmark for Neural Architecture Search. NeurIPS 2020.

  3. Dong X, Yang Y. NAS-Bench-201: Extending the Benchmark for Neural Architecture Search. CVPR 2020.

  4. Li C, Peng Z, Yuan Z, et al. NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size. CVPR 2022.