NAS搜索策略详解
搜索策略是NAS的核心,决定了如何在庞大的架构空间中高效探索最优解。本章详细介绍三类主流搜索策略:强化学习、进化算法和梯度优化方法。
一、基于强化学习的NAS
1.1 核心原理
强化学习NAS将架构生成建模为序列决策问题:
┌─────────────────────────────────────────────────────┐
│ RNN控制器 │
│ │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │ h₁ │ -> │ h₂ │ -> │ h₃ │ -> │ h₄ │ -> ... │
│ └─────┘ └─────┘ └─────┘ └─────┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ [op选择] [连接选择] [节点输入] [更多决策] │
└─────────────────────────────────────────────────────┘
关键组件:
| 组件 | 描述 |
|---|---|
| Agent | RNN控制器,输出架构描述 |
| State | RNN隐藏状态 |
| Action | 选择操作类型、连接关系等 |
| Reward | 子网络验证集准确率 |
1.2 控制器设计
// RNN控制器伪代码
class Controller {
private:
LSTMCell rnn;
int hidden_size = 100;
int num_ops = 6; // 候选操作数
public:
// 采样一个完整的架构
vector<int> sample_architecture(int num_layers) {
vector<int> arch;
vector<float> hidden_states;
for (int step = 0; step < num_layers; step++) {
// 更新RNN状态
auto [h, c] = rnn(previous_hidden);
// 输出层选择概率
vector<float> logits = fully_connected(h, num_ops);
vector<float> probs = softmax(logits);
// 采样操作
int action = categorical_sample(probs);
arch.push_back(action);
previous_hidden = h;
}
return arch;
}
// 策略梯度更新
void update(vector<int> arch, float reward) {
// 计算对数概率
float log_prob = sum(log(probs[i]) for i in arch);
// 策略梯度: ∇J = E[∇log π(a|s) * R]
float loss = -log_prob * reward;
loss.backward();
optimizer.step();
}
};1.3 代表性方法
NASNet (Zoph et al., 2017)
开创性工作,但计算代价极高:
| 指标 | 数值 |
|---|---|
| GPU Days | 800 |
| 搜索空间 | 大量候选操作 |
| 最终准确率 | SOTA on CIFAR-10 |
ENAS - 高效神经网络架构搜索
核心贡献:引入权重复用,避免从头训练
# ENAS的关键思想
class ENAS:
def __init__(self):
self.shared_weights = build_supernet()
def train_controller(self):
"""训练RNN控制器"""
for _ in range(num_epochs):
# 1. 用控制器采样架构
arch = self.controller.sample()
# 2. 继承超网络的权重
child_weights = self.shared_weights.get_subnet_weights(arch)
# 3. 快速微调
val_acc = fine_tune(child_weights, arch, few_epochs)
# 4. 用验证准确率更新控制器
self.controller.update(arch, val_acc)效率提升:从800 GPU days降至0.5 GPU days
二、基于进化算法的NAS
2.1 核心原理
进化算法将网络架构编码为”染色体”,通过自然选择的机制进化:
┌──────────────────────────────────────────────────────┐
│ 进化流程 │
│ │
│ 初始化 ──> 评估 ──> 选择 ──> 交叉 ──> 变异 ──> 评估 │
│ │ │ │
│ └──────────────────────────────────────┘ │
│ (循环) │
└──────────────────────────────────────────────────────┘
关键操作:
| 操作 | 描述 | 示例 |
|---|---|---|
| 选择 | 保留优秀个体 | 锦标赛选择、轮盘赌 |
| 交叉 | 组合两个个体 | 单点交叉、均匀交叉 |
| 变异 | 随机修改个体 | 操作替换、连接变化 |
2.2 Regularized Evolution
Real等人提出的方法,通过”老化池”机制改进进化:
class RegularizedEvolution {
private:
int population_size;
int tournament_size;
queue<Individual> history; // 老化池
public:
Individual evolve() {
// 1. 锦标赛选择
vector<Individual> tournament;
for (int i = 0; i < tournament_size; i++) {
tournament.push_back(sample_random(history));
}
Individual parent = tournament_best(tournament);
// 2. 变异
Individual child = mutate(parent);
// 3. 评估
child.fitness = evaluate(child);
// 4. 加入历史,移除最老个体(FIFO)
history.push(child);
history.pop();
return child;
}
};2.3 AmoebaNet
将强化学习的奖励塑形与进化算法结合:
- 使用PPN(Proximal Policy)替代随机采样
- 自适应变异概率
- 在ImageNet上达到SOTA
三、基于梯度的NAS(DARTS系列)
3.1 DARTS核心原理
DARTS是NAS领域的里程碑工作,将离散的架构选择松弛为连续可优化参数。
操作松弛
原始离散选择:
松弛为连续加权:
其中 是边 的架构参数。
双层优化
| 层级 | 优化目标 | 更新频率 |
|---|---|---|
| 上层 | 架构参数 | 低频 |
| 下层 | 网络权重 | 高频 |
3.2 DARTS搜索空间
# NAS-Bench-201标准操作集
CANDIDATE_OPS = {
0: 'none', # 无操作(残差断开)
1: 'skip_connect', # 恒等映射
2: 'sep_conv_3x3', # 深度可分离卷积 3x3
3: 'sep_conv_5x5', # 深度可分离卷积 5x5
4: 'dil_conv_3x3', # 空洞卷积 3x3
5: 'dil_conv_5x5', # 空洞卷积 5x5
}
# 4节点DAG搜索空间
# 节点: 0=输入, 1,2,3=中间节点, 4=输出
# 边: 0→1, 0→2, 0→3, 1→2, 1→3, 2→3
# 组合数: 6^6 = 156253.3 DARTS实现
class DARTSCell {
private:
int num_nodes;
int num_ops;
vector<nn.Parameter> edge_weights; // [6 edges x 6 ops]
public:
// 前向传播:混合所有操作的输出
Tensor forward(Tensor x, Tensor h_prev) {
vector<Tensor> states = {x, h_prev};
for (int node = 1; node <= num_nodes; node++) {
Tensor aggregate;
// 遍历所有输入边
for (auto [src, dst] : edges_to(node)) {
// 软最大加权
Tensor w = softmax(edge_weights[dst][src]);
// 混合操作
for (int op_id = 0; op_id < num_ops; op_id++) {
aggregate += w[op_id] * ops[op_id](states[src]);
}
}
states.push_back(aggregate);
}
return states.back();
}
// 离散化获取最终架构
vector<int> get_architecture() {
vector<int> arch;
for (auto& w : edge_weights) {
arch.push_back(w.argmax()); // 取最大权重的操作
}
return arch;
}
};3.4 DARTS变体
| 方法 | 年份 | 核心改进 | 解决的问题 |
|---|---|---|---|
| P-DARTS | 2019 | 渐进式搜索 | 跳过连接过多 |
| FairDARTS | 2020 | 竞争消除 | 操作间不公平 |
| DARTS+ | 2020 | 早停策略 | 性能崩溃 |
| SNAS | 2019 | 随机NAT | 梯度估计偏差 |
| GDAS | 2019 | 可微采样 | One-hot梯度问题 |
P-DARTS:渐进式搜索
分阶段逐步增加搜索复杂度:
阶段1: 深度=2, 跳过连接概率高
阶段2: 深度=4, 正常训练
阶段3: 深度=6, 精细化
FairDARTS:竞争消除
问题:跳过连接因”零操作”特性被过度选择
解决方案:引入sigmoid门控替代softmax
改为:
3.5 2024-2025年DARTS新进展
BOSE-NAS (2024)
Bi-level Optimization with Stable Equilibrium
- 提出**均衡影响(Equilibrium Influential, EI)**指标评估操作重要性
- 解决双层优化的不稳定性
- 在多个基准上达到SOTA
FX-DARTS (2025)
移除传统DARTS的关键约束:
传统DARTS假设:同类型单元格(Normal/Reduction)必须共享拓扑
FX-DARTS创新:允许不同单元格独立学习最优拓扑
四、方法对比总结
| 方法 | 搜索效率 | 搜索质量 | 计算成本 | 稳定性 |
|---|---|---|---|---|
| RL-based | ⭐⭐ | ⭐⭐⭐⭐⭐ | 高 | ⭐⭐⭐ |
| Evolution | ⭐⭐ | ⭐⭐⭐⭐⭐ | 高 | ⭐⭐⭐ |
| DARTS | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 低 | ⭐⭐⭐ |
| DARTS+变体 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | 低 | ⭐⭐⭐⭐ |
各方法适用场景
| 场景 | 推荐方法 |
|---|---|
| 超大规模搜索 | DARTS |
| 高精度需求 | RL + 长时间搜索 |
| 多目标优化 | 进化算法 |
| 快速原型 | DARTS |
| 资源受限 | DARTS变体 |
五、代码示例:简化DARTS
// 完整的简化DARTS训练流程
int main() {
// 1. 构建超网络
DARTSearchSpace search_space(num_nodes=4, num_ops=6);
// 2. 定义优化器
auto weight_opt = torch::optim::Adam(search_space.weights(), lr=0.001);
auto arch_opt = torch::optim::Adam(search_space.arch_params(), lr=0.0003);
// 3. 双层交替优化
for (int epoch = 0; epoch < 50; epoch++) {
// 下层:更新权重
weight_opt.zero_grad();
float train_loss = compute_loss(search_space, train_data);
train_loss.backward();
weight_opt.step();
// 上层:更新架构参数
arch_opt.zero_grad();
float val_loss = compute_loss(search_space, val_data);
val_loss.backward();
arch_opt.step();
}
// 4. 离散化获取最终架构
auto arch = search_space.get_architecture();
print("Found architecture:", arch);
// 5. 从头训练最终架构
auto final_model = build_model(arch);
train(final_model, full_data);
return 0;
}