NAS搜索空间设计
搜索空间是NAS的基础,决定了探索的可能性边界和最终架构的质量上限。本章详细介绍搜索空间的设计原则、常见范式和前沿进展。
一、搜索空间的组成
┌─────────────────────────────────────────────────────────────┐
│ 搜索空间组成 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
│ │ 操作集合 │ │ 拓扑结构 │ │ 宏观配置 │ │
│ │ (Ops) │ │ (Edges) │ │ (Macro) │ │
│ ├─────────────┤ ├─────────────┤ ├─────────────────────┤ │
│ │ • Conv3x3 │ │ • DAG节点 │ │ • 深度 (层数) │ │
│ │ • Conv5x5 │ │ • 连接关系 │ │ • 宽度 (通道数) │ │
│ │ • DilConv │ │ • 跳过连接 │ │ • 分辨率 │ │
│ │ • Pool │ │ • 残差连接 │ │ • 模块数量 │ │
│ │ • Attention │ │ • 多路径 │ │ • 阶段划分 │ │
│ └─────────────┘ └─────────────┘ └─────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
二、单元格结构(Cell-based)
2.1 核心思想
最流行的搜索空间设计范式:将网络划分为重复的单元格结构。
┌─────────────────────────────────────────────────────────────┐
│ Cell-based网络结构 │
│ │
│ Input │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Normal Cell ×N │ ← 重复N次,搜索最佳结构 │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Reduction Cell │ ← 特征图尺寸减半 │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Normal Cell ×N │ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ Output │
└─────────────────────────────────────────────────────────────┘
2.2 单元格内部结构
典型的DAG(有向无环图)结构:
Normal Cell 内部结构 (4节点)
┌─────────────────────────────────────────────────┐
│ │
│ node_0 ───────┬──────────────────────┐ │
│ │ │ │ │
│ ├──────────┼──→ op[0,1] ──────────┤ │
│ │ │ │ │
│ ├──────────┼──→ op[0,2] ──────────┤ │
│ │ │ │ │
│ │ ├──────────────────────┼──→ node_3
│ │ │ │ │
│ └──────────┼──→ op[0,3] ──────────┤ │
│ │ │ │
│ node_1 ───────┼──→ op[1,2] ──────────┤ │
│ │ │ │ │
│ ├──────────┼──→ op[1,3] ──────────┤ │
│ │ │ │ │
│ node_2 ───────┼──→ op[2,3] ──────────┘ │
│ │ │ │
└─────────────────────────────────────────────────┘
Reduction Cell 内部结构 (H=W/2)
┌─────────────────────────────────────────────────┐
│ │
│ node_0 ──┬─→ op[0,1] ──┬─→ op[1,2] ──┐ │
│ └─→ op[0,2] ──┴─→ op[2,3] ──┼──→ out │
│ │ │
│ node_1 ──┬─→ op[1,2] ────────────────┤ │
│ └─→ op[1,3] ────────────────┘ │
│ │
│ node_2 ──┴─→ op[2,3] ──────────────────── │
│ │
└─────────────────────────────────────────────────┘
2.3 标准操作集
NAS-Bench-201操作集
CANDIDATE_OPERATIONS = {
0: None, # 'none': 无操作,残差断开
1: 'skip_connect', # 恒等映射
2: 'sep_conv_3x3', # 深度可分离卷积 3×3
3: 'sep_conv_5x5', # 深度可分离卷积 5×5
4: 'dil_conv_3x3', # 空洞卷积 3×3 (dilation=2)
5: 'dil_conv_5x5', # 空洞卷积 5×5 (dilation=2)
}NASNet风格操作集
NASNET_OPERATIONS = [
# 卷积操作
'sep_conv_3x3', # 深度可分离卷积
'sep_conv_5x5',
'sep_conv_7x7',
'max_pool_3x3',
'avg_pool_3x3',
# 空洞卷积
'dil_conv_3x3',
'dil_conv_5x5',
# 其他
'skip_connect', # 残差连接
'none', # 无连接
]2.4 搜索空间规模
| 搜索空间 | 架构数量 | 特点 |
|---|---|---|
| NAS-Bench-101 | ~433K | 可训练的DAG |
| NAS-Bench-201 | 15,625 | 固定拓扑,4节点 |
| NAS-Bench-301 | ~10^18 | DARTS空间 |
| NAS-Bench-NLP | 10^18+ | NLP专用 |
计算方法:
- NAS-Bench-201: 6条边 × 6种操作 = 6^6 = 15,625
- NAS-Bench-301: 搜索空间更大,组合爆炸
三、宏观架构搜索
3.1 什么是宏观架构
宏观架构定义了网络级别的配置,不仅仅是单元格内部结构:
@dataclass
class MacroConfig:
"""宏观架构配置"""
# 网络深度
num_cells: int # 总单元格数量
# 各阶段配置
stages: List[StageConfig] # 每个阶段的配置
# 输入处理
input_channels: int # 输入通道数
stem_channels: int # Stem层通道数
# 分类头
num_classes: int # 分类类别数
# 辅助任务
use_aux_head: bool # 是否使用辅助分类头3.2 深度与宽度的联合搜索
class DepthWidthSearchSpace:
"""深度和宽度的联合搜索空间"""
def __init__(self):
# 深度搜索范围
self.depth_options = [8, 12, 16, 20, 24]
# 宽度(通道数)选项
self.width_options = [0.5, 0.75, 1.0, 1.25, 1.5]
# 分辨率选项
self.resolution_options = [224, 240, 256, 288]
def sample(self):
"""采样一个宏观配置"""
return {
'depth': random.choice(self.depth_options),
'width': random.choice(self.width_options),
'resolution': random.choice(self.resolution_options),
}3.3 EfficientNet的复合缩放
EfficientNet使用复合系数同时缩放深度、宽度和分辨率:
约束条件:
四、分层搜索空间
4.1 上下文无关文法(CFG)
使用CFG定义多层次的架构模板:
# CFG定义
S → Cell+
Cell → Stage+
Stage → Block+
Block → Op(Op, Op) | ResBlock(Op, Op) | AttentionBlock
# 操作
Op → Conv3x3 | Conv5x5 | DilConv3x3 | DilConv5x5 | Skip | Pool
4.2 层次化表示
class HierarchicalSearchSpace:
"""分层搜索空间"""
def __init__(self):
# 第一层:操作原语
self.primitives = ['conv3x3', 'conv5x5', 'skip', 'pool']
# 第二层:操作组合
self.motif_grammar = [
'conv3x3-conv5x5',
'skip-conv3x3',
'pool-conv3x3-conv5x5',
]
# 第三层:宏模式
self.macro_patterns = [
'repeat(motif_A, 3)',
'repeat(motif_B, 2) + repeat(motif_A, 1)',
]
def generate_architecture(self):
"""生成完整架构"""
return HierarchicalArchitecture(
macro=self.sample_macro(),
motifs=[self.sample_motif() for _ in range(n)],
ops=[self.sample_op() for _ in range(m)]
)五、拓扑约束的演进
5.1 传统约束 vs 无约束
| 约束类型 | DARTS | FX-DARTS (2025) |
|---|---|---|
| 同类型Cell共享拓扑 | ✓ | ✗ |
| 固定节点数 | ✓ | ✗ |
| 固定连接模式 | ✓ | ✗ |
5.2 FX-DARTS:无拓扑约束
class FXDARTSCell:
"""
FX-DARTS: 每个Cell独立学习拓扑
"""
def __init__(self, num_nodes, num_ops):
# 每个Cell有独立的架构参数
self.edge_weights = nn.ParameterDict()
# 为每个Cell单独初始化
for cell_id in range(num_cells):
for i in range(num_nodes):
for j in range(i+1, num_nodes):
# 不同Cell可以有不同的边权重
key = f'cell_{cell_id}_edge_{i}_{j}'
self.edge_weights[key] = nn.Parameter(torch.zeros(num_ops))
def forward(self, inputs, cell_id):
"""每个Cell独立前向"""
states = list(inputs)
for i in range(self.num_nodes):
for j in range(i+1, self.num_nodes):
# 使用该Cell特有的边权重
key = f'cell_{cell_id}_edge_{i}_{j}'
weights = F.softmax(self.edge_weights[key], dim=-1)
# 混合操作...六、特殊领域搜索空间
6.1 Vision Transformer搜索空间
class ViTSearchSpace:
"""ViT专用搜索空间"""
def __init__(self):
# Patch Embedding
self.patch_size = [4, 8, 16]
# 注意力参数
self.num_heads = [4, 8, 12, 16]
self.head_dim = [16, 32, 64]
# MLP配置
self.mlp_ratio = [2.0, 4.0, 8.0]
# 层配置
self.depth = [6, 8, 12, 16]
# 注意力类型
self.attention_types = [
'standard', # 标准注意力
'window', # 窗口注意力
'shifted_window', # 移位窗口
'linear', # 线性注意力
]6.2 目标检测搜索空间
class DetectionSearchSpace:
"""目标检测专用搜索空间"""
def __init__(self):
# Backbone搜索
self.backbone_ops = ['conv3x3', 'conv5x5', 'skip', 'dil_conv']
# Neck搜索 (FPN等)
self.neck_configs = ['FPN', 'BiFPN', 'PAN', 'NAS-FPN']
# Head搜索
self.head_depth = [1, 2, 3, 4]
self.head_channels = [64, 128, 256, 512]
# 特征融合
self.fusion_operations = ['add', 'concat', 'attention']6.3 NAS-Bench-201结构详解
# NAS-Bench-201完整定义
NASBENCH201_CONFIG = {
'num_vertices': 4, # 节点数 (包含输入输出)
'num_operations': 6, # 操作数
'max_edges': 6, # 最大边数
'search_space_size': 6**6, # 15625
'operations': [
'none', # 0
'skip_connect', # 1
'sep_conv_3x3', # 2
'sep_conv_5x5', # 3
'dil_conv_3x3', # 4
'dil_conv_5x5', # 5
],
# 固定输入输出
'input_node': 0,
'output_node': 3,
}七、搜索空间设计最佳实践
7.1 设计原则
-
表达力与效率平衡
- 太简单:限制探索能力
- 太复杂:搜索困难
-
领域先验融合
- 将人类知识编码到搜索空间
- 减少无效探索
-
可验证性
- 使用NAS-Bench基准验证设计
- 确保可复现
7.2 常见陷阱
| 陷阱 | 描述 | 解决方案 |
|---|---|---|
| 跳过连接过多 | DARTS倾向于选择Skip | P-DARTS, FairDARTS |
| 节点退化 | 只用少数节点 | 添加稀疏正则 |
| 空间碎片化 | 拓扑过于多样 | 层级化设计 |
7.3 搜索效率提升
class SearchSpaceOptimizer:
"""搜索空间优化"""
def prune_space(self, search_space, dataset):
"""
基于先验知识剪枝搜索空间
"""
pruned = copy.deepcopy(search_space)
# 1. 移除已知无效操作
pruned.operations = [op for op in pruned.operations
if self.is_effective(op, dataset)]
# 2. 限制拓扑复杂度
pruned.max_edges = min(pruned.max_edges, 6)
# 3. 添加约束
pruned.constraints = ['max_skip_ratio<0.3']
return pruned