DARTS可微架构搜索
DARTS(Differentiable Architecture Search)1是神经架构搜索领域的里程碑工作,首次将离散架构搜索问题转化为连续可微优化问题,实现了高效的端到端架构搜索。
1. 核心思想
DARTS的核心创新是连续松弛(Continuous Relaxation):将离散的架构选择转化为连续可微的操作权重。
1.1 离散问题形式化
原始问题为组合优化:
其中 是架构 的最优权重。
1.2 连续松弛
将操作选择从离散集合松弛为所有操作的加权和:
其中:
- 是边 的架构参数
- 操作权重通过 归一化
1.3 混合操作可视化
原始离散选择: 连续松弛后:
conv_3x3
│ ┌─────────────────────────┐
▼ │ α₁·conv_3x3 │
conv_5x5 │ α₂·conv_5x5 │
│ │ α₃·max_pool │
▼ │ α₄·skip_connect │
skip_connect └──────────┬──────────────┘
│
▼
加权和输出
2. 双层优化框架
2.1 优化目标
DARTS求解以下双层优化问题:
其中:
- 上层变量 :架构参数
- 下层变量 :网络权重
2.2 梯度计算
直接计算 需要在内层求解优化问题,计算代价高昂。使用链式法则近似:
其中 是内层优化的学习率。
一阶近似:当 时,
二阶近似:当 时,
2.3 PyTorch实现
class DARTSController(nn.Module):
"""DARTS架构控制器"""
def __init__(self, num_nodes, num_ops, C_in, C_out):
super().__init__()
self.num_nodes = num_nodes
self.num_ops = num_ops
# 边操作权重: (num_ops) for each edge
# edges: (i,j) where i < j, total = num_nodes * (num_nodes-1) // 2
self.edge_weights = nn.Parameter(
torch.randn(num_nodes, num_nodes, num_ops) * 1e-3
)
def mixed_op(self, x, edge_idx):
"""混合操作: 所有操作的加权和"""
# 归一化权重
weights = F.softmax(self.edge_weights[edge_idx], dim=-1)
# 各操作的输出
op_outs = []
for op_idx, op in enumerate(OPS):
if op == 'skip_connect' or op == 'none':
op_outs.append(x)
elif op == 'max_pool':
op_outs.append(F.max_pool2d(x, kernel_size=3, stride=1, padding=1))
elif op == 'avg_pool':
op_outs.append(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
elif 'sep_conv' in op:
# 可分离卷积
op_outs.append(self._sep_conv(x, op))
elif 'conv' in op:
# 标准卷积
op_outs.append(self._conv(x, op))
# 加权和
return sum(w * out for w, out in zip(weights, op_outs))
def forward(self, input_nodes):
"""前向传播"""
nodes = list(input_nodes) # [node_0, node_1, ...]
for node_idx in range(1, self.num_nodes):
# 聚合所有前驱节点的信息
aggregated = 0
for prev_idx in range(node_idx):
edge_weight = self.edge_weights[prev_idx, node_idx]
mixed_out = self.mixed_op(nodes[prev_idx], (prev_idx, node_idx))
aggregated = aggregated + mixed_out
# ReLU激活(跟在BatchNorm之后)
nodes.append(F.relu(aggregated))
return nodes[-1]
def get_discrete_architecture(self):
"""导出离散架构"""
arch = []
for i in range(self.num_nodes):
for j in range(i + 1, self.num_nodes):
edge_weights = self.edge_weights[i, j]
best_op_idx = edge_weights.argmax().item()
arch.append((i, j, best_op_idx))
return arch3. Skip连接问题
3.1 问题分析
DARTS的一个关键问题是**skip连接主导(Skip Connection Dominance)**现象。
在双层优化中,skip连接因为:
- 梯度流直接,训练稳定
- 不引入额外参数
- 允许信息直接传递
导致架构参数收敛到skip连接占主导的状态,产生性能崩塌(Performance Collapse)。
3.2 原因分析
从优化角度,skip连接的损失景观更加平滑:
skip连接不改变特征维度,避免了卷积操作引入的维度变换损失。
3.3 解决方案
方案1:P-DARTS(Progressive DARTS)
核心思想:渐进增加搜索深度,同时限制skip连接数量。
class PDARTS(nn.Module):
def __init__(self):
self.skip_connect_thresh = 2 # 每阶段最多2个skip连接
def penalize_skip_connect(self, edge_weights):
"""惩罚skip连接权重"""
# 将skip连接的权重降低
skip_idx = OPS.index('skip_connect')
penalized_weights = edge_weights.clone()
penalized_weights[skip_idx] *= 0.5
return F.softmax(penalized_weights, dim=-1)方案2:DARTS+(DARTS with Warm Start)
核心思想:在搜索初期使用更多的卷积操作作为热身。
方案3:DropBranch
核心思想:随机丢弃部分skip连接进行正则化。
def drop_branch_hook(module, input, output):
if random.random() < self.drop_prob:
return output * 0 # 随机丢弃
return output方案4:随机丢弃边
训练时随机丢弃边,防止过于依赖某些连接:
class DARTSWithEdgeDropout(nn.Module):
def __init__(self, drop_prob=0.2):
self.drop_prob = drop_prob
def forward(self, x, edge_weights, edge_idx):
weights = F.softmax(edge_weights, dim=-1)
if self.training and random.random() < self.drop_prob:
# 随机选择一个非skip连接的操作
non_skip_weights = weights.clone()
non_skip_weights[skip_idx] = 0
probs = non_skip_weights / non_skip_weights.sum()
best_op = torch.multinomial(non_skip_weights, 1).item()
else:
best_op = weights.argmax().item()
return OPS[best_op](x)4. DARTS变体
4.1 P-DARTS
论文:Progressive Differentiable Architecture Search2
核心贡献:
- 渐进增加搜索深度
- 边级别dropout防止skip主导
- 两阶段训练:搜索阶段 + 微调阶段
训练流程:
Stage 1 (浅层): 4个节点, 8个epochs
↓
Stage 2 (中层): 6个节点, 25个epochs
↓
Stage 3 (深层): 8个节点, 50个epochs
4.2 PC-DARTS
论文:PC-DARTS: Partial Channel Connections3
核心思想:部分通道连接减少显存。
其中仅随机选择 的通道进行操作连接。
4.3 DrNAS
论文:DrNAS: Dirichlet Neural Architecture Search4
核心思想:使用Dirichlet分布建模架构参数的随机性。
优势:
- 自然实现操作级别的dropout
- 避免过度拟合搜索空间
- 更好的泛化性能
4.4 OStr-DARTS
论文:One-Shot Neural Architecture Search with Proper Orthogonal Decomposition5
核心思想:使用主成分分析(PCA)减少架构参数过拟合。
class OStrDARTS:
def __init__(self, n_components=4):
self.n_components = n_components
def architecture_loss(self, alpha):
"""带正交正则的损失"""
# 架构参数的正交分解
U, S, Vt = torch.pca_lowrank(alpha, q=self.n_components)
# 正交正则化项
ortho_loss = ((alpha - U @ Vt @ alpha) ** 2).mean()
return self.task_loss + self.lambda_ortho * ortho_loss4.5 DARTS变体对比
| 变体 | 核心改进 | 论文年份 | 主要贡献 |
|---|---|---|---|
| P-DARTS | 渐进搜索 | 2019 | 防止skip主导 |
| PC-DARTS | 部分连接 | 2020 | 减少显存 |
| DrNAS | Dirichlet分布 | 2021 | 随机正则化 |
| OStr-DARTS | 正交分解 | 2020 | 减少过拟合 |
| ZO-DARTS++ | 零阶优化 | 2023 | 鲁棒性增强 |
5. 完整训练流程
5.1 训练代码
class DARTSearcher:
def __init__(self, model, criterion, optimizer, scheduler):
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.scheduler = scheduler
def step(self, input, target):
"""一次搜索迭代"""
# ====== Step 1: 更新权重 ======
self.optimizer.zero_grad()
output = self.model(input, 'alphas_train')
loss = self.criterion(output, target)
loss.backward()
# 梯度裁剪防止权重爆炸
torch.nn.utils.clip_grad_norm_(self.model.weight_parameters(), 5.0)
self.optimizer.step()
# ====== Step 2: 更新架构参数 ======
# 使用验证集更新架构
self.optimizer.zero_grad()
output = self.model(input, 'alphas_val')
loss = self.criterion(output, target)
# 一阶近似: 直接反向传播
loss.backward()
self.optimizer.arch_optimizer.step()
return loss.item()
def train(self, train_loader, val_loader, epochs):
"""完整训练流程"""
for epoch in range(epochs):
# 训练权重
self.model.train()
for input, target in train_loader:
self.step(input, target)
# 验证
self.model.eval()
with torch.no_grad():
acc = self._evaluate(val_loader)
self.scheduler.step()
# 定期导出架构
if (epoch + 1) % 50 == 0:
arch = self.model.get_discrete_architecture()
print(f"Epoch {epoch+1}: Acc={acc:.2f}%")
print(f"Architecture: {arch}")
def _evaluate(self, loader):
"""评估模型"""
correct = 0
total = 0
for input, target in loader:
output = self.model(input)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
return 100 * correct / total5.2 架构导出
def derive_discrete_architecture(model, X_small, device='cuda'):
"""从连续参数导出离散架构"""
model.eval()
# 方法1: 直接选择最大权重
def get_best_op(weights):
return weights.argmax().item()
# 方法2: 基于小批量数据的Gumbel-Softmax
def get_gumbel_op(weights, temp=1.0, hard=True):
gumbels = -torch.empty_like(weights).exponential_().log()
logits = (weights.log_softmax(dim=-1) + gumbels) / temp
probs = F.softmax(logits, dim=-1)
if hard:
one_hot = F.one_hot(probs.argmax(dim=-1), num_classes=len(OPS))
return (one_hot.float() - probs).detach() + probs
return probs
# 构建离散架构
arch = {
'nodes': [],
'edges': []
}
with torch.no_grad():
for i in range(model.num_nodes):
for j in range(i + 1, model.num_nodes):
weights = model.edge_weights[i, j]
# op_idx = get_best_op(weights)
op_idx = get_gumbel_op(weights, temp=0.1).argmax().item()
arch['edges'].append({
'from': i,
'to': j,
'operation': OPS[op_idx]
})
return arch6. 理论分析
6.1 表达能力保证
DARTS搜索到的架构表达能力不低于任何单操作基线:
定理:对于任意离散架构 ,存在连续参数 使得:
其中 是架构 对应的操作。
证明:取 for ,其他为 ,则:
6.2 收敛性分析
双层优化的收敛性分析6:
假设:
- 和 关于 是 -光滑的
- 内层问题是强凸的
结论:使用一阶近似,DARTS收敛到:
其中 是内层优化的步长。
7. 实践注意事项
7.1 超参数设置
| 超参数 | 典型值 | 说明 |
|---|---|---|
| 搜索 epochs | 50 | 搜索阶段训练轮数 |
| 权重学习率 | 0.025 | 网络权重学习率 |
| 架构学习率 | 0.003 | 架构参数学习率 |
| 权重衰减 | 权重正则化 | |
| DropPath | 0.3 | 路径丢弃概率 |
| Cutout | 16 | 数据增强 |
7.2 常见问题
| 问题 | 原因 | 解决方案 |
|---|---|---|
| Skip主导 | Skip连接过于优势 | P-DARTS, DropBranch |
| 过拟合 | 搜索空间过拟合 | OStr-DARTS, 正则化 |
| 显存爆炸 | 全连接混合操作 | PC-DARTS, 梯度检查点 |
| 不稳定 | 梯度估计误差 | 减小架构学习率 |
7.3 与其他方法对比
| 维度 | DARTS | RL-based NAS | 进化NAS |
|---|---|---|---|
| 搜索效率 | 高 | 低 | 低 |
| 实现复杂度 | 中 | 高 | 高 |
| 搜索质量 | 中 | 高 | 高 |
| 可解释性 | 好 | 差 | 中 |
| 跳过连接问题 | 存在 | 无 | 无 |
参考文献
Footnotes
-
Liu H, Simonyan K, Yang YM. DARTS: Differentiable Architecture Search. ICLR 2019. ↩
-
Chen X, Xie L, Wu J, Tian Q. Progressive Differentiable Architecture Search: Bridging the Depth Gap Between Search and Evaluation. ICCV 2019. ↩
-
Xu Y, Xie L, Zhang P, Chen X, Tian Q. PC-DARTS: Partial Channel Connections for Memory-Efficient Architecture Search. ICLR 2020. ↩
-
Zhou D, Zhou X, Zhang W, et al. DrNAS: Dirichlet Neural Architecture Search. ICLR 2021. ↩
-
Xiong Y, Yeh I, Liao SC, Yaser AF. OStr-DARTS: One-Shot Differentiable Architecture Search with Orthogonal Decomposition. AAAI 2021. ↩
-
Xie L, Chen X, Bi K, et al. Differentiable Architecture Search: Theory and Practice. JMLR 2022. ↩