NAS搜索空间设计

搜索空间是NAS的基础,决定了探索的可能性边界和最终架构的质量上限。本章详细介绍搜索空间的设计原则、常见范式和前沿进展。

一、搜索空间的组成

┌─────────────────────────────────────────────────────────────┐
│                    搜索空间组成                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────┐ │
│  │  操作集合   │  │  拓扑结构   │  │   宏观配置          │ │
│  │  (Ops)     │  │  (Edges)   │  │   (Macro)           │ │
│  ├─────────────┤  ├─────────────┤  ├─────────────────────┤ │
│  │ • Conv3x3   │  │ • DAG节点   │  │ • 深度 (层数)       │ │
│  │ • Conv5x5   │  │ • 连接关系  │  │ • 宽度 (通道数)     │ │
│  │ • DilConv   │  │ • 跳过连接  │  │ • 分辨率            │ │
│  │ • Pool      │  │ • 残差连接  │  │ • 模块数量          │ │
│  │ • Attention │  │ • 多路径   │  │ • 阶段划分          │ │
│  └─────────────┘  └─────────────┘  └─────────────────────┘ │
│                                                             │
└─────────────────────────────────────────────────────────────┘

二、单元格结构(Cell-based)

2.1 核心思想

最流行的搜索空间设计范式:将网络划分为重复的单元格结构。

┌─────────────────────────────────────────────────────────────┐
│                   Cell-based网络结构                         │
│                                                             │
│  Input                                                      │
│    │                                                        │
│    ▼                                                        │
│  ┌─────────────────┐                                        │
│  │  Normal Cell ×N │  ← 重复N次,搜索最佳结构               │
│  └─────────────────┘                                        │
│    │                                                        │
│    ▼                                                        │
│  ┌─────────────────┐                                        │
│  │ Reduction Cell  │  ← 特征图尺寸减半                       │
│  └─────────────────┘                                        │
│    │                                                        │
│    ▼                                                        │
│  ┌─────────────────┐                                        │
│  │  Normal Cell ×N │                                        │
│  └─────────────────┘                                        │
│    │                                                        │
│    ▼                                                        │
│  Output                                                     │
└─────────────────────────────────────────────────────────────┘

2.2 单元格内部结构

典型的DAG(有向无环图)结构:

Normal Cell 内部结构 (4节点)
┌─────────────────────────────────────────────────┐
│                                                 │
│  node_0 ───────┬──────────────────────┐         │
│     │          │                      │         │
│     ├──────────┼──→ op[0,1] ──────────┤         │
│     │          │                      │         │
│     ├──────────┼──→ op[0,2] ──────────┤         │
│     │          │                      │         │
│     │          ├──────────────────────┼──→ node_3
│     │          │                      │         │
│     └──────────┼──→ op[0,3] ──────────┤         │
│                │                      │         │
│  node_1 ───────┼──→ op[1,2] ──────────┤         │
│     │          │                      │         │
│     ├──────────┼──→ op[1,3] ──────────┤         │
│     │          │                      │         │
│  node_2 ───────┼──→ op[2,3] ──────────┘         │
│     │          │                                    │
└─────────────────────────────────────────────────┘

Reduction Cell 内部结构 (H=W/2)
┌─────────────────────────────────────────────────┐
│                                                 │
│  node_0 ──┬─→ op[0,1] ──┬─→ op[1,2] ──┐         │
│           └─→ op[0,2] ──┴─→ op[2,3] ──┼──→ out  │
│                                        │         │
│  node_1 ──┬─→ op[1,2] ────────────────┤         │
│           └─→ op[1,3] ────────────────┘         │
│                                                  │
│  node_2 ──┴─→ op[2,3] ────────────────────       │
│                                                  │
└─────────────────────────────────────────────────┘

2.3 标准操作集

NAS-Bench-201操作集

CANDIDATE_OPERATIONS = {
    0: None,              # 'none': 无操作,残差断开
    1: 'skip_connect',   # 恒等映射
    2: 'sep_conv_3x3',   # 深度可分离卷积 3×3
    3: 'sep_conv_5x5',   # 深度可分离卷积 5×5
    4: 'dil_conv_3x3',   # 空洞卷积 3×3 (dilation=2)
    5: 'dil_conv_5x5',   # 空洞卷积 5×5 (dilation=2)
}

NASNet风格操作集

NASNET_OPERATIONS = [
    # 卷积操作
    'sep_conv_3x3',      # 深度可分离卷积
    'sep_conv_5x5',
    'sep_conv_7x7',
    'max_pool_3x3',
    'avg_pool_3x3',
    # 空洞卷积
    'dil_conv_3x3',
    'dil_conv_5x5',
    # 其他
    'skip_connect',       # 残差连接
    'none',              # 无连接
]

2.4 搜索空间规模

搜索空间架构数量特点
NAS-Bench-101~433K可训练的DAG
NAS-Bench-20115,625固定拓扑,4节点
NAS-Bench-301~10^18DARTS空间
NAS-Bench-NLP10^18+NLP专用

计算方法

  • NAS-Bench-201: 6条边 × 6种操作 = 6^6 = 15,625
  • NAS-Bench-301: 搜索空间更大,组合爆炸

三、宏观架构搜索

3.1 什么是宏观架构

宏观架构定义了网络级别的配置,不仅仅是单元格内部结构:

@dataclass
class MacroConfig:
    """宏观架构配置"""
    
    # 网络深度
    num_cells: int              # 总单元格数量
    
    # 各阶段配置
    stages: List[StageConfig]   # 每个阶段的配置
    
    # 输入处理
    input_channels: int         # 输入通道数
    stem_channels: int          # Stem层通道数
    
    # 分类头
    num_classes: int            # 分类类别数
    
    # 辅助任务
    use_aux_head: bool          # 是否使用辅助分类头

3.2 深度与宽度的联合搜索

class DepthWidthSearchSpace:
    """深度和宽度的联合搜索空间"""
    
    def __init__(self):
        # 深度搜索范围
        self.depth_options = [8, 12, 16, 20, 24]
        
        # 宽度(通道数)选项
        self.width_options = [0.5, 0.75, 1.0, 1.25, 1.5]
        
        # 分辨率选项
        self.resolution_options = [224, 240, 256, 288]
    
    def sample(self):
        """采样一个宏观配置"""
        return {
            'depth': random.choice(self.depth_options),
            'width': random.choice(self.width_options),
            'resolution': random.choice(self.resolution_options),
        }

3.3 EfficientNet的复合缩放

EfficientNet使用复合系数同时缩放深度、宽度和分辨率:

约束条件:

四、分层搜索空间

4.1 上下文无关文法(CFG)

使用CFG定义多层次的架构模板:

# CFG定义
S       → Cell+
Cell    → Stage+
Stage   → Block+
Block   → Op(Op, Op) | ResBlock(Op, Op) | AttentionBlock

# 操作
Op      → Conv3x3 | Conv5x5 | DilConv3x3 | DilConv5x5 | Skip | Pool

4.2 层次化表示

class HierarchicalSearchSpace:
    """分层搜索空间"""
    
    def __init__(self):
        # 第一层:操作原语
        self.primitives = ['conv3x3', 'conv5x5', 'skip', 'pool']
        
        # 第二层:操作组合
        self.motif_grammar = [
            'conv3x3-conv5x5',
            'skip-conv3x3', 
            'pool-conv3x3-conv5x5',
        ]
        
        # 第三层:宏模式
        self.macro_patterns = [
            'repeat(motif_A, 3)',
            'repeat(motif_B, 2) + repeat(motif_A, 1)',
        ]
    
    def generate_architecture(self):
        """生成完整架构"""
        return HierarchicalArchitecture(
            macro=self.sample_macro(),
            motifs=[self.sample_motif() for _ in range(n)],
            ops=[self.sample_op() for _ in range(m)]
        )

五、拓扑约束的演进

5.1 传统约束 vs 无约束

约束类型DARTSFX-DARTS (2025)
同类型Cell共享拓扑
固定节点数
固定连接模式

5.2 FX-DARTS:无拓扑约束

class FXDARTSCell:
    """
    FX-DARTS: 每个Cell独立学习拓扑
    """
    
    def __init__(self, num_nodes, num_ops):
        # 每个Cell有独立的架构参数
        self.edge_weights = nn.ParameterDict()
        
        # 为每个Cell单独初始化
        for cell_id in range(num_cells):
            for i in range(num_nodes):
                for j in range(i+1, num_nodes):
                    # 不同Cell可以有不同的边权重
                    key = f'cell_{cell_id}_edge_{i}_{j}'
                    self.edge_weights[key] = nn.Parameter(torch.zeros(num_ops))
    
    def forward(self, inputs, cell_id):
        """每个Cell独立前向"""
        states = list(inputs)
        
        for i in range(self.num_nodes):
            for j in range(i+1, self.num_nodes):
                # 使用该Cell特有的边权重
                key = f'cell_{cell_id}_edge_{i}_{j}'
                weights = F.softmax(self.edge_weights[key], dim=-1)
                # 混合操作...

六、特殊领域搜索空间

6.1 Vision Transformer搜索空间

class ViTSearchSpace:
    """ViT专用搜索空间"""
    
    def __init__(self):
        # Patch Embedding
        self.patch_size = [4, 8, 16]
        
        # 注意力参数
        self.num_heads = [4, 8, 12, 16]
        self.head_dim = [16, 32, 64]
        
        # MLP配置
        self.mlp_ratio = [2.0, 4.0, 8.0]
        
        # 层配置
        self.depth = [6, 8, 12, 16]
        
        # 注意力类型
        self.attention_types = [
            'standard',      # 标准注意力
            'window',        # 窗口注意力
            'shifted_window', # 移位窗口
            'linear',        # 线性注意力
        ]

6.2 目标检测搜索空间

class DetectionSearchSpace:
    """目标检测专用搜索空间"""
    
    def __init__(self):
        # Backbone搜索
        self.backbone_ops = ['conv3x3', 'conv5x5', 'skip', 'dil_conv']
        
        # Neck搜索 (FPN等)
        self.neck_configs = ['FPN', 'BiFPN', 'PAN', 'NAS-FPN']
        
        # Head搜索
        self.head_depth = [1, 2, 3, 4]
        self.head_channels = [64, 128, 256, 512]
        
        # 特征融合
        self.fusion_operations = ['add', 'concat', 'attention']

6.3 NAS-Bench-201结构详解

# NAS-Bench-201完整定义
NASBENCH201_CONFIG = {
    'num_vertices': 4,        # 节点数 (包含输入输出)
    'num_operations': 6,      # 操作数
    'max_edges': 6,           # 最大边数
    'search_space_size': 6**6,  # 15625
    
    'operations': [
        'none',           # 0
        'skip_connect',   # 1
        'sep_conv_3x3',  # 2
        'sep_conv_5x5',  # 3
        'dil_conv_3x3',  # 4
        'dil_conv_5x5',  # 5
    ],
    
    # 固定输入输出
    'input_node': 0,
    'output_node': 3,
}

七、搜索空间设计最佳实践

7.1 设计原则

  1. 表达力与效率平衡

    • 太简单:限制探索能力
    • 太复杂:搜索困难
  2. 领域先验融合

    • 将人类知识编码到搜索空间
    • 减少无效探索
  3. 可验证性

    • 使用NAS-Bench基准验证设计
    • 确保可复现

7.2 常见陷阱

陷阱描述解决方案
跳过连接过多DARTS倾向于选择SkipP-DARTS, FairDARTS
节点退化只用少数节点添加稀疏正则
空间碎片化拓扑过于多样层级化设计

7.3 搜索效率提升

class SearchSpaceOptimizer:
    """搜索空间优化"""
    
    def prune_space(self, search_space, dataset):
        """
        基于先验知识剪枝搜索空间
        """
        pruned = copy.deepcopy(search_space)
        
        # 1. 移除已知无效操作
        pruned.operations = [op for op in pruned.operations 
                           if self.is_effective(op, dataset)]
        
        # 2. 限制拓扑复杂度
        pruned.max_edges = min(pruned.max_edges, 6)
        
        # 3. 添加约束
        pruned.constraints = ['max_skip_ratio<0.3']
        
        return pruned

八、参考论文