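Hardware-Aware NAS

Hardware-aware neural architecture search (Hardware-Aware NAS) [1] jointly optimizes model accuracy and hardware efficiency (latency, energy, power) in order to close the "search-deployment gap".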


1. Problem Background

1.1 The Search-Deployment Gap

Traditional NAS optimizes accuracy alone and ignores hardware characteristics:

┌─────────────────────────────────────────────────────────────┐
│                  Traditional NAS pipeline                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Search phase:          Deployment phase:                   │
│  ┌──────────────┐       ┌──────────────┐                    │
│  │ Best search  │  →    │ Real latency │                    │
│  │ accuracy     │       │ blows up     │                    │
│  └──────────────┘       └──────────────┘                    │
│         ↓                      ↓                            │
│  Optimal FLOPs/Params   GPU/CPU/mobile differences          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

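Root causes

  1. Hardware differences: each device has its own characteristics (GPU parallelism, memory bandwidth, cache hierarchy)
  2. Library optimization differences: cuDNN heavily optimizes standard Conv2d but optimizes depthwise convolution less thoroughly
  3. Quantization differences: behavior varies widely across FP32/INT8/INT4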
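
The gap between FLOPs and real latency can be illustrated with simple arithmetic. The sketch below compares a standard 3x3 convolution with a depthwise one on a 64-channel, 56x56 feature map (sizes chosen purely for illustration). The helper names `conv_macs` and `conv_memory_elems` are hypothetical, and the memory count is an idealized estimate that assumes every tensor is touched exactly once.

```python
def conv_macs(c_in, c_out, k, h_out, w_out, groups=1):
    """Multiply-accumulate count of a 2D convolution (bias ignored)."""
    return (c_in // groups) * c_out * k * k * h_out * w_out


def conv_memory_elems(c_in, c_out, k, h_in, w_in, h_out, w_out, groups=1):
    """Idealized element traffic: input + weights + output, each touched once."""
    weights = (c_in // groups) * c_out * k * k
    return c_in * h_in * w_in + weights + c_out * h_out * w_out


# Illustrative layer: 64 -> 64 channels, 3x3 kernel, 56x56 feature map
std_macs = conv_macs(64, 64, 3, 56, 56)
dw_macs = conv_macs(64, 64, 3, 56, 56, groups=64)   # depthwise: groups == channels

std_mem = conv_memory_elems(64, 64, 3, 56, 56, 56, 56)
dw_mem = conv_memory_elems(64, 64, 3, 56, 56, 56, 56, groups=64)

mac_ratio = std_macs / dw_macs        # how many times fewer MACs depthwise needs
std_intensity = std_macs / std_mem    # MACs per element moved
dw_intensity = dw_macs / dw_mem

print(f"MAC reduction: {mac_ratio:.0f}x")
print(f"arithmetic intensity: standard={std_intensity:.1f}, depthwise={dw_intensity:.1f}")
```

The depthwise layer needs far fewer MACs but moves almost the same amount of data, so its arithmetic intensity is much lower. On hardware where such a layer becomes memory-bound, the measured speedup is typically well below the FLOP reduction, which is why FLOPs alone is an unreliable latency proxy.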

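1.2 Hardware Characteristics

| Hardware   | Optimized operation   | Bottleneck                    | Well-suited architecture   |
|------------|-----------------------|-------------------------------|----------------------------|
| NVIDIA GPU | cuDNN convolution     | Memory bandwidth              | Convs with many channels   |
| ARM CPU    | NEON vectorization    | Instruction-level parallelism | Depthwise Conv             |
| Mobile GPU | Mali GPU              | Fill rate                     | Low-resolution features    |
| NPU        | Winograd acceleration | Data layout                   | Winograd-friendly layers   |
| DSP        | Fixed-point arithmetic| Precision loss                | INT8-quantized networks    |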

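2. Latency Predictors

2.1 Problem Definition

Objective (maximize accuracy over the search space $\mathcal{A}$ subject to a latency budget $T$ on the target hardware):

$$
\max_{a \in \mathcal{A}} \ \mathrm{Acc}(a) \quad \text{s.t.} \quad \mathrm{Lat}(a) \le T
$$

Or, as a weighted objective with trade-off coefficient $\lambda$:

$$
\min_{a \in \mathcal{A}} \ \mathcal{L}_{\mathrm{CE}}(a) + \lambda \cdot \mathrm{Lat}(a)
$$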

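2.2 Taxonomy of Latency Prediction Methods

| Method           | Description                                   | Strength      | Weakness                            |
|------------------|-----------------------------------------------|---------------|-------------------------------------|
| Lookup table     | Pre-measure the latency of each operation     | Accurate      | Hard to model operator combinations |
| Analytical model | Compute FLOPs / memory accesses theoretically | Interpretable | Ignores hardware optimizations      |
| Learned model    | A neural network predicts latency             | Flexible      | Requires training data              |
| Hybrid           | Lookup table + learned correction             | Balanced      | Complex to implement                |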
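
As a rough illustration of the hybrid approach, the sketch below fits a linear correction on top of summed lookup-table latencies so that the prediction matches end-to-end measurements. The calibration numbers are made up for the example, and `fit_linear_correction` is a hypothetical helper. A real predictor would likely use more features than a single sum.

```python
def fit_linear_correction(table_sums, measured):
    """Least-squares fit of measured ~= a * table_sum + b (closed form)."""
    n = len(table_sums)
    mean_x = sum(table_sums) / n
    mean_y = sum(measured) / n
    var_x = sum((x - mean_x) ** 2 for x in table_sums)
    cov_xy = sum((x - mean_x) * (y - mean_y)
                 for x, y in zip(table_sums, measured))
    a = cov_xy / var_x
    b = mean_y - a * mean_x
    return a, b


# Hypothetical calibration set: summed per-op table latency vs. the
# end-to-end latency measured on the device (microseconds).
table_sums = [100.0, 200.0, 300.0, 400.0]
measured = [130.0, 250.0, 370.0, 490.0]

a, b = fit_linear_correction(table_sums, measured)


def predict_latency(table_sum):
    """Corrected latency estimate for a new architecture."""
    return a * table_sum + b


print(a, b, predict_latency(250.0))
```

The slope absorbs systematic effects the table misses (for example operator fusion or cache behavior), and the intercept absorbs fixed per-inference overhead.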

2.3 ProxylessNAS: Direct Search

Paper: ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware [2]

Core idea: drop the proxy task and search directly on the target task and target hardware.

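Binarized architecture parameters

The continuously relaxed architecture parameters $\alpha_1, \dots, \alpha_N$ of a block are binarized into a one-hot gate, so that only one candidate path is active at a time:

$$
g = \mathrm{binarize}(p_1, \dots, p_N) = e_j \ \text{with probability } p_j, \qquad p_j = \frac{\exp(\alpha_j)}{\sum_{k} \exp(\alpha_k)}
$$

The discrete sampling can be approximated with a Gumbel-Softmax relaxation (the original paper instead trains the gates with a BinaryConnect-style gradient estimator or REINFORCE):

$$
y_j = \frac{\exp\big((\log p_j + \epsilon_j)/\tau\big)}{\sum_{k} \exp\big((\log p_k + \epsilon_k)/\tau\big)}
$$

where $\epsilon_j = -\log(-\log u_j)$ with $u_j \sim \mathrm{Uniform}(0, 1)$ is Gumbel noise and $\tau$ is the temperature.

Hardware-aware gradient

With $F(o_j)$ denoting the predicted latency of candidate operation $o_j$, the expected latency of a block and its gradient are:

$$
\mathbb{E}[\mathrm{Lat}] = \sum_{j} p_j \, F(o_j), \qquad \frac{\partial \, \mathbb{E}[\mathrm{Lat}]}{\partial p_j} = F(o_j)
$$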
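
To make the latency term concrete, the following sketch computes the expected latency of one block under softmax probabilities and its gradient with respect to the architecture logits. It uses plain Python, and the three latency values are made up. The gradient formula `p_k * (F_k - E[Lat])` follows from the softmax Jacobian and is written out here as an illustration, not taken from the paper's code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_latency(alpha, op_latency):
    """E[Lat] = sum_j p_j * F(o_j) with p = softmax(alpha)."""
    p = softmax(alpha)
    return sum(pj * fj for pj, fj in zip(p, op_latency))


def latency_grad_wrt_alpha(alpha, op_latency):
    """dE[Lat]/d(alpha_k) = p_k * (F_k - E[Lat])."""
    p = softmax(alpha)
    e_lat = sum(pj * fj for pj, fj in zip(p, op_latency))
    return [pk * (fk - e_lat) for pk, fk in zip(p, op_latency)]


alpha = [0.0, 0.0, 0.0]          # uniform preference over three candidate ops
op_latency = [1.0, 2.0, 6.0]     # hypothetical per-op latencies (ms)

e_lat = expected_latency(alpha, op_latency)
grad = latency_grad_wrt_alpha(alpha, op_latency)
print(e_lat, grad)
```

A gradient-descent step on this term lowers the logit of the slowest op (positive gradient) and raises the logits of ops that are faster than the current expectation (negative gradient).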

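import torch
import torch.nn as nn
import torch.nn.functional as F


class ProxylessNAS(nn.Module):
    """ProxylessNAS-style hardware-aware architecture search.

    candidate_ops: list of zero-argument factories, each returning a
        shape-preserving nn.Module (so candidate outputs can be summed).
    op_latencies: latency of each candidate op on the target hardware,
        e.g. taken from a lookup table.
    head: optional module mapping the final features to class logits.
    """
    def __init__(self, num_blocks, candidate_ops, op_latencies, head=None):
        super().__init__()
        self.num_blocks = num_blocks
        self.candidate_ops = candidate_ops

        # Architecture parameters (logits) of each block
        self.arch_params = nn.ParameterList([
            nn.Parameter(torch.randn(len(candidate_ops)))
            for _ in range(num_blocks)
        ])

        # Candidate operation modules of each block
        self.ops = nn.ModuleList([
            nn.ModuleList([make_op() for make_op in candidate_ops])
            for _ in range(num_blocks)
        ])

        # Per-op latencies, stored as a buffer so they follow .to(device)
        self.register_buffer(
            'op_latencies',
            torch.as_tensor(op_latencies, dtype=torch.float32)
        )
        self.head = head if head is not None else nn.Identity()

    def forward(self, x):
        for arch_param, op_list in zip(self.arch_params, self.ops):
            # Gumbel-Softmax sampling
            probs = F.gumbel_softmax(arch_param, tau=1.0, dim=-1)

            # Weighted sum over candidate operations
            x = sum(p * op(x) for p, op in zip(probs, op_list))

        return self.head(x)

    def get_latency(self):
        """Expected latency, differentiable w.r.t. the architecture parameters.

        An argmax-based estimate would have zero gradient, so the
        expectation under softmax probabilities is used instead.
        """
        latency = self.op_latencies.new_zeros(())
        for arch_param in self.arch_params:
            probs = F.softmax(arch_param, dim=-1)
            latency = latency + (probs * self.op_latencies).sum()
        return latency

    def hardware_aware_loss(self, x, target, lambda_lat=0.1):
        """Hardware-aware loss."""
        output = self.forward(x)
        acc_loss = F.cross_entropy(output, target)

        # Latency penalty: log(1 + latency) damps very large latencies
        latency_loss = lambda_lat * torch.log1p(self.get_latency())

        return acc_loss + latency_loss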

2.4 Latency Lookup Table

class LatencyLookupTable:
    """Latency lookup table."""
    def __init__(self):
        self.table = self._build_table()
    
    def _build_table(self):
        """Build the latency table."""
        table = {}
        
        # Predefined operator latencies at 224x224 input (unit: microseconds)
        base_ops = {
            # Conv2d: (out_channels, kernel_size, stride, groups)
            ('conv', (32, 3, 1, 1)): 120,      # 32 output channels, 3x3, stride=1
            ('conv', (32, 3, 2, 1)): 180,      # stride=2 (slower in this example table)
            ('conv', (64, 3, 1, 1)): 240,
            ('conv', (64, 3, 2, 1)): 300,
            ('dw_conv', (32, 3, 1, 32)): 50,  # Depthwise
            ('dw_conv', (32, 3, 2, 32)): 80,
            ('pool', ('max', 3, 2)): 30,
            ('pool', ('avg', 3, 2)): 40,
            ('skip', (1, 1)): 5,                # Skip connection
        }
        
        # Account for the input resolution
        for resolution in [112, 224, 336]:
            table[f'res_{resolution}'] = {}
            for op, base_lat in base_ops.items():
                # Latency grows with the pixel count, i.e. roughly
                # quadratically with the side length
                scale = (resolution / 224) ** 2
                table[f'res_{resolution}'][op] = base_lat * scale
        
        return table
    
    def predict(self, op_type, params, resolution):
        """Predict the latency of one operation."""
        key = (op_type, params)
        
        if f'res_{resolution}' in self.table:
            # Unknown operations fall back to a default of 100 us
            return self.table[f'res_{resolution}'].get(key, 100)
        
        # Unlisted resolution: scale the 100 us default by pixel count
        return 100 * (resolution / 224) ** 2
    
    def get_architecture_latency(self, arch, input_shape):
        """Predict the latency of a whole architecture."""
        total_latency = 0
        resolution = input_shape[2]  # input_shape is (N, C, H, W)
        
        for block in arch.blocks:
            op = block['op']
            params = block['params']
            
            latency = self.predict(op, params, resolution)
            total_latency += latency
            
            # A stride-2 block halves the resolution seen by later blocks
            if block.get('stride') == 2:
                resolution //= 2
        
        return total_latency
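
The table entries above have to come from measurements on the target device. A minimal timing harness might look like the sketch below, which uses warm-up runs and the median of repeated measurements to reduce noise. `measure_latency_us` and `dummy_op` are hypothetical names, and the dummy workload merely stands in for a real operator.

```python
import statistics
import time


def measure_latency_us(fn, warmup=10, repeats=50):
    """Median wall-clock latency of fn() in microseconds."""
    # Warm-up runs let caches, allocators and lazy initialization settle
    for _ in range(warmup):
        fn()

    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e6)

    # The median is less sensitive to scheduling spikes than the mean
    return statistics.median(samples)


def dummy_op():
    """Stand-in workload; replace with the real operator call."""
    return sum(i * i for i in range(1000))


measured_table = {('dummy', (1000,)): measure_latency_us(dummy_op)}
print(measured_table)
```

When timing GPU kernels from PyTorch, the device must be synchronized (for example with `torch.cuda.synchronize()`) before reading the clock, because kernel launches are asynchronous.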

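3. Multi-Objective Optimization

3.1 Pareto Front

When accuracy and latency are optimized simultaneously, the Pareto front is the set of architectures that no other architecture dominates:

$$
\mathcal{P} = \{\, a \in \mathcal{A} \;\mid\; \nexists\, a' \in \mathcal{A}:\ \mathrm{Acc}(a') \ge \mathrm{Acc}(a),\ \mathrm{Lat}(a') \le \mathrm{Lat}(a),\ \text{with at least one inequality strict} \,\}
$$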

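3.2 NSGA-III for NAS

import random

import torch


class MultiObjectiveNAS:
    """Multi-objective NAS framework."""
    def __init__(self, supernet, train_loader, input_shape,
                 objectives=('acc', 'latency', 'params')):
        self.supernet = supernet
        self.train_loader = train_loader  # used by the accuracy proxy
        self.input_shape = input_shape    # (N, C, H, W), used by the latency table
        self.objectives = list(objectives)
        self.reference_points = self._generate_reference_points()
    
    def _generate_reference_points(self):
        """Generate reference points in objective space."""
        n_obj = len(self.objectives)
        ref_points = []
        
        for i in range(n_obj):
            # One extreme point per objective
            point = [0.0] * n_obj
            point[i] = 1.0
            ref_points.append(point)
        
        # Random interior points (a simplification: NSGA-III normally uses
        # structured, evenly spaced reference points)
        for _ in range(10):
            point = [random.random() for _ in range(n_obj)]
            ref_points.append(point)
        
        return torch.tensor(ref_points, dtype=torch.float32)
    
    def evaluate(self, arch):
        """Evaluate an architecture."""
        metrics = {}
        
        # Accuracy (requires training or a proxy)
        metrics['acc'] = self._predict_accuracy(arch)
        
        # Latency
        metrics['latency'] = self._predict_latency(arch)
        
        # Parameter count
        metrics['params'] = self._count_params(arch)
        
        return metrics
    
    def _predict_accuracy(self, arch):
        """Predict accuracy with a zero-cost proxy."""
        # zero_cost_nas is assumed to be an external module providing AZNASScorer
        from zero_cost_nas import AZNASScorer
        scorer = AZNASScorer(self.supernet, self.train_loader)
        return scorer.final_score(arch)
    
    def _predict_latency(self, arch):
        """Predict latency."""
        latency_table = LatencyLookupTable()
        return latency_table.get_architecture_latency(arch, self.input_shape)
    
    def _count_params(self, arch):
        """Count parameters."""
        ...
    
    def nsga_iii_select(self, population, n_select):
        """NSGA-III selection."""
        # Evaluate every individual
        fitness = [self.evaluate(arch) for arch in population]
        
        # Normalize the objectives
        normalized = self._normalize_objectives(fitness)
        
        # Associate individuals with reference points (helper not shown here)
        associations = self._associate_to_reference(normalized)
        
        # Select based on the reference points (helper not shown here)
        selected = self._select_based_on_reference(
            population, associations, n_select
        )
        
        return selected
    
    def _normalize_objectives(self, fitness):
        """Min-max normalize each objective to [0, 1].

        Note: 'acc' is maximized while 'latency' and 'params' are minimized,
        so the downstream helpers must account for the direction.
        """
        normalized = {}
        for obj in self.objectives:
            values = [f[obj] for f in fitness]
            min_v, max_v = min(values), max(values)
            normalized[obj] = [
                (v - min_v) / (max_v - min_v + 1e-10) 
                for v in values
            ]
        return normalized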
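
NSGA-III usually relies on structured reference points that are evenly spread over the unit simplex (the Das-Dennis construction). The sketch below enumerates them with a stars-and-bars argument. `das_dennis_points` is a hypothetical helper name and the partition count is an illustrative choice.

```python
from itertools import combinations


def das_dennis_points(n_obj, n_partitions):
    """All points on the unit simplex whose coordinates are k / n_partitions."""
    points = []
    n_slots = n_partitions + n_obj - 1
    # Stars and bars: pick the positions of the n_obj - 1 bars; the gaps
    # between consecutive bars give the integer share of each objective.
    for bars in combinations(range(n_slots), n_obj - 1):
        counts = []
        prev = -1
        for b in bars:
            counts.append(b - prev - 1)
            prev = b
        counts.append(n_slots - prev - 1)
        points.append([c / n_partitions for c in counts])
    return points


# Three objectives (e.g. accuracy, latency, params), four partitions per axis
ref_points = das_dennis_points(n_obj=3, n_partitions=4)
print(len(ref_points))
```

With three objectives and four partitions this yields C(6, 2) = 15 points, including the three single-objective extremes.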

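3.3 Selecting Pareto-Optimal Architectures

def select_pareto_optimal(archs, metrics, objective_weights=None):
    """Select the Pareto-optimal architectures.

    The sign of each weight encodes the objective's direction: positive
    means higher is better (e.g. accuracy), negative means lower is better
    (e.g. latency). The default of 1.0 treats every objective as maximized,
    so pass e.g. {'acc': 1.0, 'latency': -1.0} for mixed directions.
    """
    if objective_weights is None:
        objective_weights = {obj: 1.0 for obj in metrics[0].keys()}
    
    pareto_front = []
    
    for i, (arch, m) in enumerate(zip(archs, metrics)):
        is_dominated = False
        
        for j, (other_arch, other_m) in enumerate(zip(archs, metrics)):
            if i == j:
                continue
            
            # Check whether `other` dominates `arch`: no worse on every
            # objective and strictly better on at least one
            dominated = True
            strictly_better = False
            
            for obj, weight in objective_weights.items():
                if weight * m[obj] > weight * other_m[obj]:
                    dominated = False
                    break
                if weight * m[obj] < weight * other_m[obj]:
                    strictly_better = True
            
            if dominated and strictly_better:
                is_dominated = True
                break
        
        if not is_dominated:
            pareto_front.append((arch, m))
    
    return pareto_front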

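4. Once-for-All Networks

4.1 Core Idea

Once-for-All (OFA) [3] introduces the concept of an elastic network:

Train a single supernet that supports many configurations (depth, width, resolution), then extract sub-networks on demand.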

┌─────────────────────────────────────────────────────────────┐
│                    Once-for-All supernet                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│              ┌────────────────────────────────┐             │
│              │     Shared-weight supernet     │             │
│              │                                │             │
│              │    ┌─────┐ ┌─────┐ ┌─────┐     │             │
│              │    │Block│ │Block│ │Block│     │             │
│              │    └──┬──┘ └──┬──┘ └──┬──┘     │             │
│              │       ↓       ↓       ↓        │             │
│              │ Elastic depth/width/resolution │             │
│              └──────────┬─────────────────────┘             │
│                         ↓                                   │
│    ┌─────────────┬─────────────┬─────────────┐              │
│    ↓             ↓             ↓             ↓              │
│ ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐      │
│ │  Large   │  │  Medium  │  │  Small   │  │   Tiny   │      │
│ │  (GPU)   │  │ (mobile) │  │  (edge)  │  │(embedded)│      │
│ └──────────┘  └──────────┘  └──────────┘  └──────────┘      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

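4.2 Elastic Dimensions

| Elastic dimension | Range                 | Strategy                    |
|-------------------|-----------------------|-----------------------------|
| Depth             | D_min <= d <= D_max   | Keep the first d blocks     |
| Width             | W_min <= w <= W_max   | Keep the first w channels   |
| Resolution        | R_min <= r <= R_max   | Variable input resolution   |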

4.3 Progressive Shrinking

Training strategy:

┌─────────────────────────────────────────────────────────────┐
│             Progressive Shrinking training flow             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ Stage 1: Train largest network (d=D_max, w=W_max, r=R_max)  │
│          ↓                                                  │
│ Stage 2: Fine-tune sub-networks (d<D_max, w<W_max)          │
│          ↓                                                  │
│ Stage 3: Keep fine-tuning smaller sub-networks              │
│          ↓                                                  │
│ Stage 4: Smallest network (d=D_min, w=W_min, r=R_min)       │
│                                                             │
└─────────────────────────────────────────────────────────────┘
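
The staged flow can be sketched as a schedule that starts from the largest configuration and unlocks smaller choices step by step, sampling a sub-network configuration from the currently allowed set at each training step. The function names are hypothetical, and the ordering used here (smaller depths first, then smaller widths) is an assumption made for illustration.

```python
import random


def progressive_shrinking_stages(depths, widths):
    """Return stages with growing sets of allowed choices, largest first."""
    depths = sorted(depths, reverse=True)
    widths = sorted(widths, reverse=True)

    # Stage 1: only the largest network
    stages = [{'depth': depths[:1], 'width': widths[:1]}]

    # Then unlock smaller depths one at a time (width stays at maximum)
    for i in range(2, len(depths) + 1):
        stages.append({'depth': depths[:i], 'width': widths[:1]})

    # Finally unlock smaller widths one at a time (all depths allowed)
    for j in range(2, len(widths) + 1):
        stages.append({'depth': depths, 'width': widths[:j]})

    return stages


def sample_subnet(stage, rng):
    """Sample one sub-network configuration allowed in this stage."""
    return {'depth': rng.choice(stage['depth']),
            'width': rng.choice(stage['width'])}


stages = progressive_shrinking_stages([2, 3, 4], [0.65, 0.8, 1.0])
rng = random.Random(0)
config = sample_subnet(stages[-1], rng)
print(len(stages), config)
```

Because each stage's choices are a superset of the previous stage's, the large sub-networks keep being trained while smaller ones are added, which is the intent of progressive shrinking.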

4.4 OFA in PyTorch

import copy

import torch.nn as nn


class OnceForAll(nn.Module):
    """Once-for-All network.

    ElasticConvBlock, ElasticDepthBlock and ResolutionEmbedding are assumed
    to be defined elsewhere.
    """
    def __init__(self, 
                 depth_list=(2, 3, 4),
                 width_list=(0.65, 0.8, 1.0),
                 resolution_list=(160, 176, 192, 208, 224)):
        super().__init__()
        
        self.depth_list = list(depth_list)
        self.width_list = list(width_list)
        self.resolution_list = list(resolution_list)
        
        # Build the supernet
        self.blocks = nn.ModuleList([
            ElasticConvBlock(elastic_channels=self.width_list),
            ElasticDepthBlock(depth_list=self.depth_list),
            ElasticConvBlock(elastic_channels=self.width_list),
            ...
        ])
        
        self.resolution_embed = ResolutionEmbedding(self.resolution_list)
    
    def forward(self, x, depth=None, width=None, resolution=None):
        """Elastic forward pass; unspecified dimensions default to the maximum."""
        # Elastic resolution
        if resolution is not None:
            x = self.resolution_embed(x, resolution)
        
        num_active = depth if depth is not None else max(self.depth_list)
        active_width = width if width is not None else max(self.width_list)
        
        # Elastic blocks
        for block in self.blocks:
            if isinstance(block, ElasticDepthBlock):
                # Elastic depth: the block runs only its first num_active layers
                x = block(x, num_active)
            elif isinstance(block, ElasticConvBlock):
                # Elastic width: the block uses only its first channels
                x = block(x, active_width)
            else:
                x = block(x)
        
        return x
    
    def extract_subnet(self, depth, width, resolution):
        """Extract the sub-network for a given configuration.

        Only the width requires slicing weights here; depth and resolution
        are selected at call time through forward().
        """
        subnet = copy.deepcopy(self)
        
        for block in subnet.blocks:
            if hasattr(block, 'extract_subnet'):
                block.extract_subnet(width)
        
        return subnet

4.5 Latency-Constrained Search

class OFASearcher:
    """OFA searcher."""
    def __init__(self, ofa_net, latency_predictor):
        self.ofa_net = ofa_net
        self.predictor = latency_predictor
    
    def search(self, latency_budget, n_candidates=1000):
        """Search for the best configuration under a latency budget.

        n_candidates is unused by this exhaustive grid enumeration.
        """
        candidates = []
        
        # Enumerate the candidate configurations
        for depth in self.ofa_net.depth_list:
            for width in self.ofa_net.width_list:
                for res in self.ofa_net.resolution_list:
                    candidates.append({
                        'depth': depth,
                        'width': width,
                        'resolution': res
                    })
        
        # Keep only the configurations that satisfy the latency budget
        valid_candidates = []
        for config in candidates:
            latency = self.predictor.predict(
                depth=config['depth'],
                width=config['width'],
                resolution=config['resolution']
            )
            
            if latency <= latency_budget:
                config['latency'] = latency
                valid_candidates.append(config)
        
        # Evaluate accuracy (zero-cost proxy or fine-tuning);
        # _evaluate_config is assumed to be implemented elsewhere
        for config in valid_candidates:
            accuracy = self._evaluate_config(config)
            config['accuracy'] = accuracy
        
        # Return the feasible configurations, most accurate first
        return sorted(valid_candidates, 
                     key=lambda x: x['accuracy'], 
                     reverse=True)

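5. Edge Deployment in Practice

5.1 Deployment Platform Characteristics

| Platform            | Characteristics                                     | Recommended architecture   |
|---------------------|-----------------------------------------------------|----------------------------|
| NVIDIA Jetson       | CUDA acceleration, memory shared between CPU and GPU| Large networks             |
| Apple Neural Engine | Dedicated NPU, Core ML                              | MobileNetV3, EfficientNet  |
| Google Edge TPU     | INT8-optimized, 4 TOPS                              | MobileNetV2 + quantization |
| Qualcomm DSP        | Hexagon HVX/HTA                                     | Small CNNs                 |
| ARM Cortex-M        | Extremely constrained, optional FPU                 | MicroNet, MobileNet-Tiny   |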

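5.2 Quantization-Aware NAS

class QuantizationAwareNAS(nn.Module):
    """Quantization-aware NAS.

    `quantize` is assumed to be a fake-quantization helper defined
    elsewhere, and `self.weight` a convolution weight tensor registered
    as an nn.Parameter by the full implementation.
    """
    def __init__(self):
        super().__init__()  # required before assigning attributes on nn.Module
        self.weight_bits = [2, 4, 8]  # searchable weight bit-widths
        self.act_bits = [4, 8]         # searchable activation bit-widths
    
    def forward(self, x, w_bits, a_bits):
        """Quantization-aware forward pass."""
        # Quantize the weights
        weight = quantize(self.weight, w_bits)
        
        # Quantize the activations
        x = quantize(x, a_bits)
        
        # Forward computation (remaining conv arguments elided)
        return F.conv2d(x, weight, ...)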
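
The effect of the bit-width choice can be seen with a small quantize-dequantize example. The sketch below implements symmetric uniform quantization in plain Python; it is one common scheme, not necessarily the one behind `quantize` above, and `quantize_uniform` is a hypothetical name.

```python
def quantize_uniform(values, bits):
    """Symmetric uniform quantize-dequantize of a list of floats."""
    qmax = 2 ** (bits - 1) - 1              # e.g. 127 for 8 bits, 1 for 2 bits
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        return list(values)
    scale = max_abs / qmax                  # real value of one integer step
    out = []
    for v in values:
        q = max(-qmax, min(qmax, round(v / scale)))   # integer code
        out.append(q * scale)                          # back to float
    return out


def max_error(a, b):
    """Largest absolute element-wise difference."""
    return max(abs(x - y) for x, y in zip(a, b))


weights = [-1.0, -0.3, 0.0, 0.4, 1.0]       # illustrative values
w8 = quantize_uniform(weights, 8)
w2 = quantize_uniform(weights, 2)

err8 = max_error(weights, w8)
err2 = max_error(weights, w2)
print(w2, err8, err2)
```

At 2 bits only the codes -1, 0 and 1 remain, so the small weights collapse to zero. This is why the bit-width is worth searching per layer instead of fixing it globally.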

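5.3 End-to-End Deployment Pipeline

┌─────────────────────────────────────────────────────────────┐
│           End-to-end hardware-aware NAS pipeline            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ 1. Define the search space                                  │
│    - Operator set (Conv, DWConv, Skip, etc.)                │
│    - Depth range (1-4 blocks)                               │
│    - Width range (0.65-1.0)                                 │
│    ↓                                                        │
│ 2. Build the latency predictor                              │
│    - Measure operator latency on the target hardware        │
│    - Build a lookup table or train a neural network         │
│    ↓                                                        │
│ 3. Multi-objective search (accuracy + latency)              │
│    - NSGA-III or weighted sum                               │
│    - Produce the Pareto front                               │
│    ↓                                                        │
│ 4. Sub-network extraction (optional: Once-for-All)          │
│    - Elastic depth/width/resolution                         │
│    ↓                                                        │
│ 5. Training and fine-tuning                                 │
│    - Knowledge distillation (if using OFA)                  │
│    ↓                                                        │
│ 6. Quantization and deployment                              │
│    - INT8 quantization (PTQ/QAT)                            │
│    - Compile for the target platform (TensorRT/TVM/CoreML)  │
│                                                             │
└─────────────────────────────────────────────────────────────┘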

References

  1. Wu B, Dai X, Zhang P, et al. FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search. CVPR 2019.

  2. Cai H, Zhu L, Han S. ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware. ICLR 2019.

  3. Cai H, Gan C, Wang T, Zhang Z, Han S. Once-for-All: Train One Network and Specialize It for Efficient Deployment. ICLR 2020.