DARTS可微架构搜索

DARTS(Differentiable Architecture Search)1是神经架构搜索领域的里程碑工作,首次将离散架构搜索问题转化为连续可微优化问题,实现了高效的端到端架构搜索。


1. 核心思想

DARTS的核心创新是连续松弛(Continuous Relaxation):将离散的架构选择转化为连续可微的操作权重。

1.1 离散问题形式化

原始问题为组合优化:

其中 是架构 的最优权重。

1.2 连续松弛

将操作选择从离散集合松弛为所有操作的加权和:

其中:

  • 是边 的架构参数
  • 操作权重通过 归一化

1.3 混合操作可视化

原始离散选择:              连续松弛后:
     conv_3x3                   
         │                   ┌─────────────────────────┐
         ▼                   │  α₁·conv_3x3           │
      conv_5x5                │  α₂·conv_5x5           │
         │                   │  α₃·max_pool          │
         ▼                   │  α₄·skip_connect       │
      skip_connect             └──────────┬──────────────┘
                                         │
                                         ▼
                                   加权和输出

2. 双层优化框架

2.1 优化目标

DARTS求解以下双层优化问题:

其中:

  • 上层变量 :架构参数
  • 下层变量 :网络权重

2.2 梯度计算

直接计算 需要在内层求解优化问题,计算代价高昂。使用链式法则近似:

其中 是内层优化的学习率。

一阶近似:当 时,

二阶近似:当 时,

2.3 PyTorch实现

class DARTSController(nn.Module):
    """DARTS架构控制器"""
    def __init__(self, num_nodes, num_ops, C_in, C_out):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_ops = num_ops
        
        # 边操作权重: (num_ops) for each edge
        # edges: (i,j) where i < j, total = num_nodes * (num_nodes-1) // 2
        self.edge_weights = nn.Parameter(
            torch.randn(num_nodes, num_nodes, num_ops) * 1e-3
        )
    
    def mixed_op(self, x, edge_idx):
        """混合操作: 所有操作的加权和"""
        # 归一化权重
        weights = F.softmax(self.edge_weights[edge_idx], dim=-1)
        
        # 各操作的输出
        op_outs = []
        for op_idx, op in enumerate(OPS):
            if op == 'skip_connect' or op == 'none':
                op_outs.append(x)
            elif op == 'max_pool':
                op_outs.append(F.max_pool2d(x, kernel_size=3, stride=1, padding=1))
            elif op == 'avg_pool':
                op_outs.append(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
            elif 'sep_conv' in op:
                # 可分离卷积
                op_outs.append(self._sep_conv(x, op))
            elif 'conv' in op:
                # 标准卷积
                op_outs.append(self._conv(x, op))
        
        # 加权和
        return sum(w * out for w, out in zip(weights, op_outs))
    
    def forward(self, input_nodes):
        """前向传播"""
        nodes = list(input_nodes)  # [node_0, node_1, ...]
        
        for node_idx in range(1, self.num_nodes):
            # 聚合所有前驱节点的信息
            aggregated = 0
            for prev_idx in range(node_idx):
                edge_weight = self.edge_weights[prev_idx, node_idx]
                mixed_out = self.mixed_op(nodes[prev_idx], (prev_idx, node_idx))
                aggregated = aggregated + mixed_out
            
            # ReLU激活(跟在BatchNorm之后)
            nodes.append(F.relu(aggregated))
        
        return nodes[-1]
    
    def get_discrete_architecture(self):
        """导出离散架构"""
        arch = []
        for i in range(self.num_nodes):
            for j in range(i + 1, self.num_nodes):
                edge_weights = self.edge_weights[i, j]
                best_op_idx = edge_weights.argmax().item()
                arch.append((i, j, best_op_idx))
        return arch

3. Skip连接问题

3.1 问题分析

DARTS的一个关键问题是**skip连接主导(Skip Connection Dominance)**现象。

在双层优化中,skip连接因为:

  1. 梯度流直接,训练稳定
  2. 不引入额外参数
  3. 允许信息直接传递

导致架构参数收敛到skip连接占主导的状态,产生性能崩塌(Performance Collapse)

3.2 原因分析

从优化角度,skip连接的损失景观更加平滑:

skip连接不改变特征维度,避免了卷积操作引入的维度变换损失。

3.3 解决方案

方案1:P-DARTS(Progressive DARTS)

核心思想:渐进增加搜索深度,同时限制skip连接数量。

class PDARTS(nn.Module):
    def __init__(self):
        self.skip_connect_thresh = 2  # 每阶段最多2个skip连接
    
    def penalize_skip_connect(self, edge_weights):
        """惩罚skip连接权重"""
        # 将skip连接的权重降低
        skip_idx = OPS.index('skip_connect')
        penalized_weights = edge_weights.clone()
        penalized_weights[skip_idx] *= 0.5
        return F.softmax(penalized_weights, dim=-1)

方案2:DARTS+(DARTS with Warm Start)

核心思想:在搜索初期使用更多的卷积操作作为热身。

方案3:DropBranch

核心思想:随机丢弃部分skip连接进行正则化。

def drop_branch_hook(module, input, output):
    if random.random() < self.drop_prob:
        return output * 0  # 随机丢弃
    return output

方案4:随机丢弃边

训练时随机丢弃边,防止过于依赖某些连接:

class DARTSWithEdgeDropout(nn.Module):
    def __init__(self, drop_prob=0.2):
        self.drop_prob = drop_prob
    
    def forward(self, x, edge_weights, edge_idx):
        weights = F.softmax(edge_weights, dim=-1)
        
        if self.training and random.random() < self.drop_prob:
            # 随机选择一个非skip连接的操作
            non_skip_weights = weights.clone()
            non_skip_weights[skip_idx] = 0
            probs = non_skip_weights / non_skip_weights.sum()
            best_op = torch.multinomial(non_skip_weights, 1).item()
        else:
            best_op = weights.argmax().item()
        
        return OPS[best_op](x)

4. DARTS变体

4.1 P-DARTS

论文:Progressive Differentiable Architecture Search2

核心贡献

  1. 渐进增加搜索深度
  2. 边级别dropout防止skip主导
  3. 两阶段训练:搜索阶段 + 微调阶段

训练流程

Stage 1 (浅层): 4个节点, 8个epochs
       ↓
Stage 2 (中层): 6个节点, 25个epochs
       ↓
Stage 3 (深层): 8个节点, 50个epochs

4.2 PC-DARTS

论文:PC-DARTS: Partial Channel Connections3

核心思想:部分通道连接减少显存。

其中仅随机选择 的通道进行操作连接。

4.3 DrNAS

论文:DrNAS: Dirichlet Neural Architecture Search4

核心思想:使用Dirichlet分布建模架构参数的随机性。

优势

  1. 自然实现操作级别的dropout
  2. 避免过度拟合搜索空间
  3. 更好的泛化性能

4.4 OStr-DARTS

论文:One-Shot Neural Architecture Search with Proper Orthogonal Decomposition5

核心思想:使用主成分分析(PCA)减少架构参数过拟合。

class OStrDARTS:
    def __init__(self, n_components=4):
        self.n_components = n_components
    
    def architecture_loss(self, alpha):
        """带正交正则的损失"""
        # 架构参数的正交分解
        U, S, Vt = torch.pca_lowrank(alpha, q=self.n_components)
        
        # 正交正则化项
        ortho_loss = ((alpha - U @ Vt @ alpha) ** 2).mean()
        
        return self.task_loss + self.lambda_ortho * ortho_loss

4.5 DARTS变体对比

变体核心改进论文年份主要贡献
P-DARTS渐进搜索2019防止skip主导
PC-DARTS部分连接2020减少显存
DrNASDirichlet分布2021随机正则化
OStr-DARTS正交分解2020减少过拟合
ZO-DARTS++零阶优化2023鲁棒性增强

5. 完整训练流程

5.1 训练代码

class DARTSearcher:
    def __init__(self, model, criterion, optimizer, scheduler):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
    
    def step(self, input, target):
        """一次搜索迭代"""
        # ====== Step 1: 更新权重 ======
        self.optimizer.zero_grad()
        output = self.model(input, 'alphas_train')
        loss = self.criterion(output, target)
        loss.backward()
        
        # 梯度裁剪防止权重爆炸
        torch.nn.utils.clip_grad_norm_(self.model.weight_parameters(), 5.0)
        self.optimizer.step()
        
        # ====== Step 2: 更新架构参数 ======
        # 使用验证集更新架构
        self.optimizer.zero_grad()
        output = self.model(input, 'alphas_val')
        loss = self.criterion(output, target)
        
        # 一阶近似: 直接反向传播
        loss.backward()
        
        self.optimizer.arch_optimizer.step()
        
        return loss.item()
    
    def train(self, train_loader, val_loader, epochs):
        """完整训练流程"""
        for epoch in range(epochs):
            # 训练权重
            self.model.train()
            for input, target in train_loader:
                self.step(input, target)
            
            # 验证
            self.model.eval()
            with torch.no_grad():
                acc = self._evaluate(val_loader)
            
            self.scheduler.step()
            
            # 定期导出架构
            if (epoch + 1) % 50 == 0:
                arch = self.model.get_discrete_architecture()
                print(f"Epoch {epoch+1}: Acc={acc:.2f}%")
                print(f"Architecture: {arch}")
    
    def _evaluate(self, loader):
        """评估模型"""
        correct = 0
        total = 0
        for input, target in loader:
            output = self.model(input)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
        return 100 * correct / total

5.2 架构导出

def derive_discrete_architecture(model, X_small, device='cuda'):
    """从连续参数导出离散架构"""
    model.eval()
    
    # 方法1: 直接选择最大权重
    def get_best_op(weights):
        return weights.argmax().item()
    
    # 方法2: 基于小批量数据的Gumbel-Softmax
    def get_gumbel_op(weights, temp=1.0, hard=True):
        gumbels = -torch.empty_like(weights).exponential_().log()
        logits = (weights.log_softmax(dim=-1) + gumbels) / temp
        probs = F.softmax(logits, dim=-1)
        if hard:
            one_hot = F.one_hot(probs.argmax(dim=-1), num_classes=len(OPS))
            return (one_hot.float() - probs).detach() + probs
        return probs
    
    # 构建离散架构
    arch = {
        'nodes': [],
        'edges': []
    }
    
    with torch.no_grad():
        for i in range(model.num_nodes):
            for j in range(i + 1, model.num_nodes):
                weights = model.edge_weights[i, j]
                # op_idx = get_best_op(weights)
                op_idx = get_gumbel_op(weights, temp=0.1).argmax().item()
                arch['edges'].append({
                    'from': i,
                    'to': j,
                    'operation': OPS[op_idx]
                })
    
    return arch

6. 理论分析

6.1 表达能力保证

DARTS搜索到的架构表达能力不低于任何单操作基线:

定理:对于任意离散架构 ,存在连续参数 使得:

其中 是架构 对应的操作。

证明:取 for ,其他为 ,则:

6.2 收敛性分析

双层优化的收敛性分析6

假设

  1. 关于 -光滑的
  2. 内层问题是强凸的

结论:使用一阶近似,DARTS收敛到:

其中 是内层优化的步长。


7. 实践注意事项

7.1 超参数设置

超参数典型值说明
搜索 epochs50搜索阶段训练轮数
权重学习率0.025网络权重学习率
架构学习率0.003架构参数学习率
权重衰减权重正则化
DropPath0.3路径丢弃概率
Cutout16数据增强

7.2 常见问题

问题原因解决方案
Skip主导Skip连接过于优势P-DARTS, DropBranch
过拟合搜索空间过拟合OStr-DARTS, 正则化
显存爆炸全连接混合操作PC-DARTS, 梯度检查点
不稳定梯度估计误差减小架构学习率

7.3 与其他方法对比

维度DARTSRL-based NAS进化NAS
搜索效率
实现复杂度
搜索质量
可解释性
跳过连接问题存在

参考文献

Footnotes

  1. Liu H, Simonyan K, Yang YM. DARTS: Differentiable Architecture Search. ICLR 2019.

  2. Chen X, Xie L, Wu J, Tian Q. Progressive Differentiable Architecture Search: Bridging the Depth Gap Between Search and Evaluation. ICCV 2019.

  3. Xu Y, Xie L, Zhang P, Chen X, Tian Q. PC-DARTS: Partial Channel Connections for Memory-Efficient Architecture Search. ICLR 2020.

  4. Zhou D, Zhou X, Zhang W, et al. DrNAS: Dirichlet Neural Architecture Search. ICLR 2021.

  5. Xiong Y, Yeh I, Liao SC, Yaser AF. OStr-DARTS: One-Shot Differentiable Architecture Search with Orthogonal Decomposition. AAAI 2021.

  6. Xie L, Chen X, Bi K, et al. Differentiable Architecture Search: Theory and Practice. JMLR 2022.