NAS搜索策略详解

搜索策略是NAS的核心,决定了如何在庞大的架构空间中高效探索最优解。本章详细介绍三类主流搜索策略:强化学习、进化算法和梯度优化方法。

一、基于强化学习的NAS

1.1 核心原理

强化学习NAS将架构生成建模为序列决策问题:

┌─────────────────────────────────────────────────────┐
│                  RNN控制器                          │
│                                                     │
│  ┌─────┐    ┌─────┐    ┌─────┐    ┌─────┐          │
│  │ h₁ │ -> │ h₂ │ -> │ h₃ │ -> │ h₄ │ -> ...      │
│  └─────┘    └─────┘    └─────┘    └─────┘          │
│     │          │          │          │              │
│     ▼          ▼          ▼          ▼              │
│  [op选择]  [连接选择]  [节点输入]  [更多决策]        │
└─────────────────────────────────────────────────────┘

关键组件

组件描述
AgentRNN控制器,输出架构描述
StateRNN隐藏状态
Action选择操作类型、连接关系等
Reward子网络验证集准确率

1.2 控制器设计

// RNN控制器伪代码
class Controller {
private:
    LSTMCell rnn;
    int hidden_size = 100;
    int num_ops = 6; // 候选操作数
    
public:
    // 采样一个完整的架构
    vector<int> sample_architecture(int num_layers) {
        vector<int> arch;
        vector<float> hidden_states;
        
        for (int step = 0; step < num_layers; step++) {
            // 更新RNN状态
            auto [h, c] = rnn(previous_hidden);
            
            // 输出层选择概率
            vector<float> logits = fully_connected(h, num_ops);
            vector<float> probs = softmax(logits);
            
            // 采样操作
            int action = categorical_sample(probs);
            arch.push_back(action);
            
            previous_hidden = h;
        }
        return arch;
    }
    
    // 策略梯度更新
    void update(vector<int> arch, float reward) {
        // 计算对数概率
        float log_prob = sum(log(probs[i]) for i in arch);
        
        // 策略梯度: ∇J = E[∇log π(a|s) * R]
        float loss = -log_prob * reward;
        loss.backward();
        optimizer.step();
    }
};

1.3 代表性方法

NASNet (Zoph et al., 2017)

开创性工作,但计算代价极高:

指标数值
GPU Days800
搜索空间大量候选操作
最终准确率SOTA on CIFAR-10

ENAS - 高效神经网络架构搜索

核心贡献:引入权重复用,避免从头训练

# ENAS的关键思想
class ENAS:
    def __init__(self):
        self.shared_weights = build_supernet()
    
    def train_controller(self):
        """训练RNN控制器"""
        for _ in range(num_epochs):
            # 1. 用控制器采样架构
            arch = self.controller.sample()
            
            # 2. 继承超网络的权重
            child_weights = self.shared_weights.get_subnet_weights(arch)
            
            # 3. 快速微调
            val_acc = fine_tune(child_weights, arch, few_epochs)
            
            # 4. 用验证准确率更新控制器
            self.controller.update(arch, val_acc)

效率提升:从800 GPU days降至0.5 GPU days

二、基于进化算法的NAS

2.1 核心原理

进化算法将网络架构编码为”染色体”,通过自然选择的机制进化:

┌──────────────────────────────────────────────────────┐
│                   进化流程                           │
│                                                      │
│   初始化 ──> 评估 ──> 选择 ──> 交叉 ──> 变异 ──> 评估 │
│     │                                      │        │
│     └──────────────────────────────────────┘        │
│                      (循环)                          │
└──────────────────────────────────────────────────────┘

关键操作

操作描述示例
选择保留优秀个体锦标赛选择、轮盘赌
交叉组合两个个体单点交叉、均匀交叉
变异随机修改个体操作替换、连接变化

2.2 Regularized Evolution

Real等人提出的方法,通过”老化池”机制改进进化:

class RegularizedEvolution {
private:
    int population_size;
    int tournament_size;
    queue<Individual> history; // 老化池
    
public:
    Individual evolve() {
        // 1. 锦标赛选择
        vector<Individual> tournament;
        for (int i = 0; i < tournament_size; i++) {
            tournament.push_back(sample_random(history));
        }
        Individual parent = tournament_best(tournament);
        
        // 2. 变异
        Individual child = mutate(parent);
        
        // 3. 评估
        child.fitness = evaluate(child);
        
        // 4. 加入历史,移除最老个体(FIFO)
        history.push(child);
        history.pop();
        
        return child;
    }
};

2.3 AmoebaNet

将强化学习的奖励塑形与进化算法结合:

  • 使用PPN(Proximal Policy)替代随机采样
  • 自适应变异概率
  • 在ImageNet上达到SOTA

三、基于梯度的NAS(DARTS系列)

3.1 DARTS核心原理

DARTS是NAS领域的里程碑工作,将离散的架构选择松弛为连续可优化参数。

操作松弛

原始离散选择:

松弛为连续加权:

其中 是边 的架构参数。

双层优化

层级优化目标更新频率
上层架构参数 低频
下层网络权重 高频

3.2 DARTS搜索空间

# NAS-Bench-201标准操作集
CANDIDATE_OPS = {
    0: 'none',           # 无操作(残差断开)
    1: 'skip_connect',   # 恒等映射
    2: 'sep_conv_3x3',   # 深度可分离卷积 3x3
    3: 'sep_conv_5x5',   # 深度可分离卷积 5x5
    4: 'dil_conv_3x3',   # 空洞卷积 3x3
    5: 'dil_conv_5x5',   # 空洞卷积 5x5
}
 
# 4节点DAG搜索空间
# 节点: 0=输入, 1,2,3=中间节点, 4=输出
# 边: 0→1, 0→2, 0→3, 1→2, 1→3, 2→3
# 组合数: 6^6 = 15625

3.3 DARTS实现

class DARTSCell {
private:
    int num_nodes;
    int num_ops;
    vector<nn.Parameter> edge_weights; // [6 edges x 6 ops]
    
public:
    // 前向传播:混合所有操作的输出
    Tensor forward(Tensor x, Tensor h_prev) {
        vector<Tensor> states = {x, h_prev};
        
        for (int node = 1; node <= num_nodes; node++) {
            Tensor aggregate;
            // 遍历所有输入边
            for (auto [src, dst] : edges_to(node)) {
                // 软最大加权
                Tensor w = softmax(edge_weights[dst][src]);
                // 混合操作
                for (int op_id = 0; op_id < num_ops; op_id++) {
                    aggregate += w[op_id] * ops[op_id](states[src]);
                }
            }
            states.push_back(aggregate);
        }
        return states.back();
    }
    
    // 离散化获取最终架构
    vector<int> get_architecture() {
        vector<int> arch;
        for (auto& w : edge_weights) {
            arch.push_back(w.argmax()); // 取最大权重的操作
        }
        return arch;
    }
};

3.4 DARTS变体

方法年份核心改进解决的问题
P-DARTS2019渐进式搜索跳过连接过多
FairDARTS2020竞争消除操作间不公平
DARTS+2020早停策略性能崩溃
SNAS2019随机NAT梯度估计偏差
GDAS2019可微采样One-hot梯度问题

P-DARTS:渐进式搜索

分阶段逐步增加搜索复杂度:

阶段1: 深度=2, 跳过连接概率高
阶段2: 深度=4, 正常训练
阶段3: 深度=6, 精细化

FairDARTS:竞争消除

问题:跳过连接因”零操作”特性被过度选择

解决方案:引入sigmoid门控替代softmax

改为:

3.5 2024-2025年DARTS新进展

BOSE-NAS (2024)

Bi-level Optimization with Stable Equilibrium

  • 提出**均衡影响(Equilibrium Influential, EI)**指标评估操作重要性
  • 解决双层优化的不稳定性
  • 在多个基准上达到SOTA

FX-DARTS (2025)

移除传统DARTS的关键约束:

传统DARTS假设:同类型单元格(Normal/Reduction)必须共享拓扑
FX-DARTS创新:允许不同单元格独立学习最优拓扑

四、方法对比总结

方法搜索效率搜索质量计算成本稳定性
RL-based⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐
Evolution⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐
DARTS⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐
DARTS+变体⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐

各方法适用场景

场景推荐方法
超大规模搜索DARTS
高精度需求RL + 长时间搜索
多目标优化进化算法
快速原型DARTS
资源受限DARTS变体

五、代码示例:简化DARTS

// 完整的简化DARTS训练流程
int main() {
    // 1. 构建超网络
    DARTSearchSpace search_space(num_nodes=4, num_ops=6);
    
    // 2. 定义优化器
    auto weight_opt = torch::optim::Adam(search_space.weights(), lr=0.001);
    auto arch_opt = torch::optim::Adam(search_space.arch_params(), lr=0.0003);
    
    // 3. 双层交替优化
    for (int epoch = 0; epoch < 50; epoch++) {
        // 下层:更新权重
        weight_opt.zero_grad();
        float train_loss = compute_loss(search_space, train_data);
        train_loss.backward();
        weight_opt.step();
        
        // 上层:更新架构参数
        arch_opt.zero_grad();
        float val_loss = compute_loss(search_space, val_data);
        val_loss.backward();
        arch_opt.step();
    }
    
    // 4. 离散化获取最终架构
    auto arch = search_space.get_architecture();
    print("Found architecture:", arch);
    
    // 5. 从头训练最终架构
    auto final_model = build_model(arch);
    train(final_model, full_data);
    
    return 0;
}

六、参考论文