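Hardware-Aware NAS
Hardware-aware neural architecture search (Hardware-Aware NAS) [1] jointly optimizes model accuracy and hardware efficiency (latency, energy, power). Its goal is to close the "search-deployment gap": an architecture that looks optimal during search can perform poorly once it runs on the target device.
1. Background
1.1 The search-deployment gap
Traditional NAS optimizes only for accuracy and ignores hardware characteristics: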
┌─────────────────────────────────────────────────────────────┐
│                    Traditional NAS flow                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   Search stage:                 Deployment stage:           │
│   ┌────────────────┐            ┌────────────────┐          │
│   │ Best search    │     →      │ Real latency   │          │
│   │ accuracy       │            │ blows up       │          │
│   └────────────────┘            └────────────────┘          │
│           ↓                             ↓                   │
│   FLOPs/Params optimal          GPU/CPU/mobile differences  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
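Root causes:
- Hardware differences: each device has its own characteristics (GPU parallelism, memory bandwidth, cache hierarchy)
- Library optimization differences: cuDNN heavily optimizes dense Conv2d, but depthwise convolution is not optimized to the same degree
- Quantization differences: behavior varies widely across FP32, INT8, and INT4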
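The mismatch between FLOPs and measured latency can be made concrete with a rough cost count. The sketch below compares a dense 3x3 convolution with a depthwise one, assuming FP32 tensors, stride 1, "same" padding, and that every tensor is read or written exactly once (caches and kernel fusion are ignored). The layer sizes and the `conv_cost` helper are hypothetical and chosen only for illustration.

```python
def conv_cost(c_in, c_out, k, h, w, groups=1):
    """Return (MACs, bytes moved) for a stride-1, same-padded k x k convolution in FP32."""
    macs = (c_in // groups) * c_out * k * k * h * w
    weights = (c_in // groups) * c_out * k * k
    activations = c_in * h * w + c_out * h * w  # read the input once, write the output once
    return macs, 4 * (weights + activations)    # 4 bytes per FP32 value


# Hypothetical layer: 128 -> 128 channels on a 28x28 feature map.
dense_macs, dense_bytes = conv_cost(128, 128, 3, 28, 28)
dw_macs, dw_bytes = conv_cost(128, 128, 3, 28, 28, groups=128)

print(f"MAC ratio dense / depthwise: {dense_macs / dw_macs:.0f}x")
print(f"MACs per byte, dense:        {dense_macs / dense_bytes:.1f}")
print(f"MACs per byte, depthwise:    {dw_macs / dw_bytes:.1f}")
```

Under these assumptions the depthwise layer needs 128 times fewer MACs, but it also performs far fewer MACs per byte moved. On a bandwidth-limited device its latency therefore shrinks much less than its FLOP count suggests.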
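1.2 Hardware characteristics
| Hardware | Well-optimized operations | Bottleneck | Well-suited architectures |
|---|---|---|---|
| NVIDIA GPU | cuDNN convolutions | Memory bandwidth | Convolutions with many channels |
| ARM CPU | NEON vectorization | Instruction-level parallelism | Depthwise convolutions |
| Mobile GPU (e.g. Mali) | - | Fill rate | Low-resolution feature maps |
| NPU | Winograd acceleration | Data layout | Winograd-friendly convolutions |
| DSP | Fixed-point arithmetic | Precision loss | INT8-quantized networks |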
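2. Latency predictors
2.1 Problem definition
Objective (constrained form), where $\mathcal{A}$ is the search space, $\mathrm{Acc}(a)$ the accuracy of architecture $a$, $\mathrm{Lat}(a)$ its latency on the target hardware, and $T$ the latency budget:
$$\max_{a \in \mathcal{A}} \ \mathrm{Acc}(a) \quad \text{s.t.} \quad \mathrm{Lat}(a) \le T$$
Or, as a weighted objective, where $\mathcal{L}_{\mathrm{task}}$ is the task loss and $\lambda \ge 0$ sets the trade-off:
$$\min_{a \in \mathcal{A}} \ \mathcal{L}_{\mathrm{task}}(a) + \lambda \cdot \mathrm{Lat}(a)$$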
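The two formulations can pick different architectures from the same candidate pool. A minimal sketch, assuming a small list of already evaluated candidates; the names, accuracies, and latencies are made up, and the weighted form uses `1 - accuracy` as a stand-in for the task loss:

```python
# Hypothetical candidates: (name, top-1 accuracy, latency in ms). Values are illustrative only.
candidates = [
    ("A", 0.76, 32.0),
    ("B", 0.74, 18.0),
    ("C", 0.71, 11.0),
    ("D", 0.65, 6.0),
]


def best_under_budget(cands, budget_ms):
    """Constrained form: maximize accuracy subject to latency <= budget."""
    feasible = [c for c in cands if c[2] <= budget_ms]
    return max(feasible, key=lambda c: c[1]) if feasible else None


def best_weighted(cands, lam):
    """Weighted form: minimize (1 - accuracy) + lam * latency."""
    return min(cands, key=lambda c: (1.0 - c[1]) + lam * c[2])


print(best_under_budget(candidates, budget_ms=20.0)[0])  # B
print(best_weighted(candidates, lam=0.001)[0])           # A
print(best_weighted(candidates, lam=0.01)[0])            # C
```

The constrained form needs only a budget in physical units, while the weighted form needs a value of `lam` whose effect depends on the scale of both terms.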
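2.2 Categories of latency prediction methods
| Method | Description | Strength | Weakness |
|---|---|---|---|
| Lookup table | Pre-measure the latency of each operation | Accurate per operation | Hard to model how operations interact when combined |
| Analytical model | Derive latency from FLOPs and memory accesses | Interpretable | Ignores hardware-level optimizations |
| Learned model | A neural network predicts latency | Flexible | Requires training data |
| Hybrid | Lookup table plus a learned correction | Balanced | More complex to implement |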
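As an illustration of the analytical row, a roofline-style estimate takes the slower of compute time and memory time as a lower bound on latency. This is a sketch under stated assumptions: the device figures (1000 GMAC/s peak, 50 GB/s bandwidth) and the layer costs (a hypothetical 128-channel 3x3 layer on a 28x28 feature map, dense versus depthwise) are illustrative, not measurements.

```python
def roofline_latency_ms(macs, bytes_moved, peak_gmacs, bandwidth_gbs):
    """Lower-bound latency in ms: the slower of compute time and memory time."""
    compute_ms = macs / (peak_gmacs * 1e9) * 1e3
    memory_ms = bytes_moved / (bandwidth_gbs * 1e9) * 1e3
    return max(compute_ms, memory_ms)


# Hypothetical device and layer costs (MACs, bytes moved).
PEAK_GMACS, BANDWIDTH_GBS = 1000, 50
dense = roofline_latency_ms(115_605_504, 1_392_640, PEAK_GMACS, BANDWIDTH_GBS)
depthwise = roofline_latency_ms(903_168, 807_424, PEAK_GMACS, BANDWIDTH_GBS)

print(f"dense:     {dense:.4f} ms")
print(f"depthwise: {depthwise:.4f} ms")
print(f"predicted speedup: {dense / depthwise:.1f}x (MAC ratio is 128x)")
```

The model explains why a 128x reduction in MACs can yield a far smaller speedup, but it knows nothing about kernel selection or fusion, which is the weakness listed in the table.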
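2.3 ProxylessNAS: direct search
Paper: ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware [2]
Core idea: search directly on the target task and hardware instead of on a proxy task.
Binarized architecture parameters
The continuously relaxed architecture parameters $\alpha$ are binarized into a one-hot gate $g$, so that only one candidate path is active at a time:
$$g = \mathrm{binarize}(p_1, \dots, p_N), \qquad \Pr[g = e_i] = p_i = \frac{\exp(\alpha_i)}{\sum_j \exp(\alpha_j)}$$
Gumbel-Softmax gives a differentiable approximation:
$$g_i = \frac{\exp\big((\alpha_i + \epsilon_i)/\tau\big)}{\sum_j \exp\big((\alpha_j + \epsilon_j)/\tau\big)}$$
where $\epsilon_i \sim \mathrm{Gumbel}(0, 1)$ and $\tau > 0$ is the temperature.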
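The sampling step can be sketched without any deep learning framework. The code below is a minimal illustration, assuming the architecture parameters act as unnormalized logits; the function name `gumbel_softmax` and the example logits are chosen here for illustration and mirror, but are not, a library call.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Sample a relaxed one-hot vector from architecture logits."""
    noisy = []
    for a in logits:
        # Gumbel(0, 1) noise by inverse transform sampling: -log(-log(U)), U ~ Uniform(0, 1)
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((a - math.log(-math.log(u))) / tau)
    m = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
alpha = [1.0, 0.5, -1.0]  # hypothetical logits for three candidate ops

print(gumbel_softmax(alpha, tau=5.0, rng=rng))  # high temperature: smoother sample
print(gumbel_softmax(alpha, tau=0.1, rng=rng))  # low temperature: usually close to one-hot

# The argmax of a sample follows softmax(alpha), so op 0 is picked most often.
counts = [0, 0, 0]
for _ in range(10_000):
    g = gumbel_softmax(alpha, tau=1.0, rng=rng)
    counts[g.index(max(g))] += 1
print(counts)
```

Lowering `tau` pushes samples toward one-hot vectors at the cost of higher-variance gradients, which is why the temperature is often annealed during search.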
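Hardware-aware gradient
Latency gradient: with $p_i$ the probability of choosing candidate operation $o_i$ and $F(o_i)$ its predicted latency, the expected latency of a block is differentiable in the path probabilities:
$$\mathbb{E}[\mathrm{Lat}] = \sum_i p_i \, F(o_i), \qquad \frac{\partial\, \mathbb{E}[\mathrm{Lat}]}{\partial p_i} = F(o_i)$$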
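Assuming the expected latency of a block is the probability-weighted sum of per-operation latencies, $\mathbb{E}[\mathrm{Lat}] = \sum_i p_i F(o_i)$ with $p = \mathrm{softmax}(\alpha)$, the chain rule gives $\partial \mathbb{E}[\mathrm{Lat}] / \partial \alpha_i = p_i \,(F(o_i) - \mathbb{E}[\mathrm{Lat}])$. The sketch below checks that expression against a finite difference; the logits and latencies are hypothetical.

```python
import math


def softmax(alpha):
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    total = sum(exps)
    return [e / total for e in exps]


def expected_latency(alpha, op_latency):
    """E[Lat] = sum_i p_i * F(o_i), with p = softmax(alpha)."""
    return sum(p * f for p, f in zip(softmax(alpha), op_latency))


def expected_latency_grad(alpha, op_latency):
    """Analytic gradient w.r.t. the logits: p_i * (F(o_i) - E[Lat])."""
    p = softmax(alpha)
    e = sum(pi * fi for pi, fi in zip(p, op_latency))
    return [pi * (fi - e) for pi, fi in zip(p, op_latency)]


alpha = [0.3, -0.2, 1.1]        # hypothetical architecture logits
op_latency = [120.0, 50.0, 5.0]  # hypothetical per-op latencies in microseconds

analytic = expected_latency_grad(alpha, op_latency)
eps = 1e-6
numeric = []
for i in range(len(alpha)):
    shifted = list(alpha)
    shifted[i] += eps
    numeric.append((expected_latency(shifted, op_latency) - expected_latency(alpha, op_latency)) / eps)

print(analytic)
print(numeric)
```

Operations slower than the current expectation receive a positive gradient, so gradient descent on the latency term lowers their logits.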
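import torch
import torch.nn as nn
import torch.nn.functional as F


class ProxylessNAS(nn.Module):
    """Hardware-aware architecture search in the style of ProxylessNAS."""

    def __init__(self, num_blocks, candidate_ops, op_latencies, head=None):
        # candidate_ops: list of zero-argument factories, each returning a
        #   shape-preserving nn.Module (replaces the undefined _build_ops call).
        # op_latencies: one latency per candidate op, e.g. read from a lookup table
        #   (replaces the undefined lookup_latency call).
        # head: optional module that maps features to class logits.
        super().__init__()
        self.num_blocks = num_blocks
        self.candidate_ops = candidate_ops
        # Architecture parameters for each block
        self.arch_params = nn.ParameterList([
            nn.Parameter(torch.randn(len(candidate_ops)))
            for _ in range(num_blocks)
        ])
        # Candidate operation modules for each block
        self.ops = nn.ModuleList([
            nn.ModuleList([make_op() for make_op in candidate_ops])
            for _ in range(num_blocks)
        ])
        self.register_buffer(
            "op_latencies", torch.as_tensor(op_latencies, dtype=torch.float32)
        )
        self.head = head if head is not None else nn.Identity()

    def forward(self, x):
        for arch_param, op_list in zip(self.arch_params, self.ops):
            # Gumbel-Softmax sample over the candidate ops
            probs = F.gumbel_softmax(arch_param, tau=1.0, dim=-1)
            # Weighted sum of the candidate outputs
            x = sum(p * op(x) for p, op in zip(probs, op_list))
        return self.head(x)

    def get_latency(self):
        """Expected latency, differentiable w.r.t. the architecture parameters."""
        latency = torch.zeros((), device=self.op_latencies.device)
        for arch_param in self.arch_params:
            # An argmax here would cut the gradient, so weight by probability instead
            probs = F.softmax(arch_param, dim=-1)
            latency = latency + (probs * self.op_latencies).sum()
        return latency

    def hardware_aware_loss(self, x, target, lambda_lat=0.1):
        """Hardware-aware loss: task loss plus a log-latency penalty."""
        output = self.forward(x)
        acc_loss = F.cross_entropy(output, target)
        # Latency penalty
        latency_loss = lambda_lat * torch.log(self.get_latency() + 1)
        return acc_loss + latency_loss
2.4 Latency lookup table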
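Lookup-table entries come from measurements on the target device. A minimal timing-harness sketch, assuming wall-clock timing with warm-up runs and a median over repeats; `measure_latency_us` and the stand-in workload are hypothetical, and a real table would time a single layer's forward pass instead.

```python
import statistics
import time


def measure_latency_us(fn, warmup=10, repeats=50):
    """Median wall-clock latency of fn() in microseconds."""
    for _ in range(warmup):  # warm caches, allocators, and any lazy initialization
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e6)
    return statistics.median(samples)  # the median is robust to occasional stalls


# Stand-in workloads keyed like table entries: (op name, size parameter).
table = {
    ("workload", n): measure_latency_us(lambda n=n: sum(i * i for i in range(n)))
    for n in (1_000, 10_000)
}
for key, latency in table.items():
    print(key, f"{latency:.1f} us")
```

On accelerators that execute asynchronously, the device has to be synchronized before each timestamp is taken, otherwise only the launch overhead is measured.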
class LatencyLookupTable:
    """Per-operation latency lookup table."""

    def __init__(self):
        self.table = self._build_table()

    def _build_table(self):
        """Build the latency table."""
        table = {}
        # Per-op latency in microseconds at 224x224 input.
        # These are illustrative placeholders, not measurements; replace them
        # with values timed on the target device.
        base_ops = {
            # Conv2d: (out_channels, kernel_size, stride, groups)
            ('conv', (32, 3, 1, 1)): 120,     # 32 filters, 3x3, stride 1
            ('conv', (32, 3, 2, 1)): 180,     # stride 2
            ('conv', (64, 3, 1, 1)): 240,
            ('conv', (64, 3, 2, 1)): 300,
            ('dw_conv', (32, 3, 1, 32)): 50,  # depthwise
            ('dw_conv', (32, 3, 2, 32)): 80,
            ('pool', ('max', 3, 2)): 30,
            ('pool', ('avg', 3, 2)): 40,
            ('skip', (1, 1)): 5,              # skip connection
        }
        # Account for the input resolution
        for resolution in [112, 224, 336]:
            table[f'res_{resolution}'] = {}
            for op, base_lat in base_ops.items():
                # Latency is assumed proportional to pixel count,
                # i.e. quadratic in the side length
                scale = (resolution / 224) ** 2
                table[f'res_{resolution}'][op] = base_lat * scale
        return table

    def predict(self, op_type, params, resolution):
        """Predict the latency of one operation (100 us default for unknown ops)."""
        key = (op_type, params)
        if f'res_{resolution}' in self.table:
            return self.table[f'res_{resolution}'].get(key, 100)
        # Resolution not tabulated: scale the 224x224 entry instead
        base = self.table['res_224'].get(key, 100)
        return base * (resolution / 224) ** 2

    def get_architecture_latency(self, arch, input_shape):
        """Predict the latency of a whole architecture."""
        total_latency = 0
        resolution = input_shape[2]  # assumes NCHW layout
        for block in arch.blocks:
            op = block['op']
            params = tuple(block['params'])  # tuple so it can be used as a dict key
            total_latency += self.predict(op, params, resolution)
            # Update the resolution after a stride-2 block
            if block.get('stride') == 2:
                resolution //= 2
        return total_latency
3. Multi-objective optimization
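3.1 Pareto front
When accuracy and latency are optimized together, an architecture $a$ dominates $a'$ (written $a \succ a'$) if it is no worse on both objectives and strictly better on at least one:
$$a \succ a' \iff \mathrm{Acc}(a) \ge \mathrm{Acc}(a') \ \wedge\ \mathrm{Lat}(a) \le \mathrm{Lat}(a') \ \wedge\ \big(\mathrm{Acc}(a) > \mathrm{Acc}(a') \ \vee\ \mathrm{Lat}(a) < \mathrm{Lat}(a')\big)$$
The Pareto front is the set of architectures in the search space $\mathcal{A}$ that no other architecture dominates:
$$\mathcal{P} = \{\, a \in \mathcal{A} : \nexists\, a' \in \mathcal{A},\ a' \succ a \,\}$$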
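A small worked example of dominance on two objectives, with accuracy maximized and latency minimized. The four points are made up for illustration.

```python
def dominates(a, b):
    """True if a is no worse than b on both objectives and strictly better on one.

    Points are (accuracy, latency): higher accuracy is better, lower latency is better.
    """
    acc_a, lat_a = a
    acc_b, lat_b = b
    return acc_a >= acc_b and lat_a <= lat_b and (acc_a > acc_b or lat_a < lat_b)


# Hypothetical architectures: name -> (top-1 accuracy, latency in ms).
points = {"A": (0.76, 32.0), "B": (0.74, 18.0), "C": (0.73, 21.0), "D": (0.65, 6.0)}

front = sorted(
    name for name, p in points.items()
    if not any(dominates(q, p) for q in points.values())
)
print(front)  # ['A', 'B', 'D']
```

C drops out because B is both more accurate and faster; A, B, and D each win on at least one objective against every other point.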
3.2 NSGA-III for NAS
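NSGA-III conventionally places its reference points on the unit simplex with the Das-Dennis construction, where every coordinate is a multiple of 1/H and the coordinates sum to 1. A minimal sketch of that construction; the function name `das_dennis` and the choice of 3 objectives with H = 4 are illustrative.

```python
from itertools import combinations


def das_dennis(n_obj, n_partitions):
    """All points whose coordinates are multiples of 1/n_partitions and sum to 1."""
    points = []
    # Stars and bars: place n_obj - 1 bars among n_partitions + n_obj - 1 slots;
    # the runs of stars between bars are the integer coordinates.
    for bars in combinations(range(n_partitions + n_obj - 1), n_obj - 1):
        counts, prev = [], -1
        for b in bars:
            counts.append(b - prev - 1)
            prev = b
        counts.append(n_partitions + n_obj - 2 - prev)
        points.append([c / n_partitions for c in counts])
    return points


ref_points = das_dennis(n_obj=3, n_partitions=4)
print(len(ref_points))  # 15 = C(4 + 3 - 1, 3 - 1)
print(ref_points[:3])
```

Evenly spaced reference directions are what let the niching step spread the selected population across the whole trade-off surface.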
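import random

import torch


class MultiObjectiveNAS:
    """Multi-objective NAS framework."""

    def __init__(self, supernet, objectives=('acc', 'latency', 'params'),
                 train_loader=None, input_shape=(1, 3, 224, 224)):
        self.supernet = supernet
        self.objectives = list(objectives)
        self.train_loader = train_loader  # data for the accuracy proxy
        self.input_shape = input_shape    # NCHW shape used for latency prediction
        self.reference_points = self._generate_reference_points()

    def _generate_reference_points(self):
        """Generate reference points in objective space."""
        n_obj = len(self.objectives)
        ref_points = []
        for i in range(n_obj):
            # Axis points: one per single-objective optimum
            point = [0.0] * n_obj
            point[i] = 1.0
            ref_points.append(point)
        # Random interior points
        for _ in range(10):
            point = [random.random() for _ in range(n_obj)]
            ref_points.append(point)
        return torch.tensor(ref_points)

    def evaluate(self, arch):
        """Evaluate an architecture."""
        metrics = {}
        # Accuracy (requires training or a proxy)
        metrics['acc'] = self._predict_accuracy(arch)
        # Latency
        metrics['latency'] = self._predict_latency(arch)
        # Parameter count
        metrics['params'] = self._count_params(arch)
        return metrics

    def _predict_accuracy(self, arch):
        """Predict accuracy with a zero-cost proxy."""
        # `zero_cost_nas.AZNASScorer` is the scorer named in these notes and is
        # assumed to be a project-local module with this interface.
        from zero_cost_nas import AZNASScorer
        scorer = AZNASScorer(self.supernet, self.train_loader)
        return scorer.final_score(arch)

    def _predict_latency(self, arch):
        """Predict latency from the lookup table."""
        latency_table = LatencyLookupTable()
        return latency_table.get_architecture_latency(arch, self.input_shape)

    def _count_params(self, arch):
        """Count parameters."""
        ...

    def nsga_iii_select(self, population, n_select):
        """NSGA-III-style selection (simplified: no non-dominated sorting)."""
        # Evaluate every individual
        fitness = [self.evaluate(arch) for arch in population]
        # Normalize the objectives
        normalized = self._normalize_objectives(fitness)
        # Associate individuals with reference points
        associations = self._associate_to_reference(normalized)
        # Select across reference points
        return self._select_based_on_reference(population, associations, n_select)

    def _normalize_objectives(self, fitness):
        """Normalize each objective to [0, 1], where 0 is best."""
        normalized = {}
        for obj in self.objectives:
            values = [f[obj] for f in fitness]
            min_v, max_v = min(values), max(values)
            scaled = [(v - min_v) / (max_v - min_v + 1e-10) for v in values]
            # Accuracy is maximized; flip it so that every objective is minimized
            normalized[obj] = [1 - s for s in scaled] if obj == 'acc' else scaled
        return normalized

    def _associate_to_reference(self, normalized):
        """Index of the nearest reference direction for each individual."""
        points = torch.tensor(
            [normalized[obj] for obj in self.objectives], dtype=torch.float32
        ).T  # shape [N, n_obj]
        dirs = self.reference_points / self.reference_points.norm(dim=1, keepdim=True)
        proj = points @ dirs.T  # scalar projection onto each direction, shape [N, R]
        # Perpendicular distance from each point to each reference line
        dist = (points.pow(2).sum(dim=1, keepdim=True) - proj.pow(2)).clamp(min=0).sqrt()
        return dist.argmin(dim=1).tolist()

    def _select_based_on_reference(self, population, associations, n_select):
        """Simplified niching: pick individuals round-robin across reference points."""
        niches = {}
        for idx, ref in enumerate(associations):
            niches.setdefault(ref, []).append(idx)
        target = min(n_select, len(population))
        selected = []
        while len(selected) < target:
            for members in niches.values():
                if members and len(selected) < target:
                    selected.append(population[members.pop(0)])
        return selected
3.3 Pareto-optimal selection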
def select_pareto_optimal(archs, metrics, objective_weights=None):
    """Select the Pareto-optimal architectures.

    The sign of each weight gives the direction of its objective:
    positive means maximize, negative means minimize.
    """
    if objective_weights is None:
        # Default: maximize accuracy, minimize every other objective
        objective_weights = {
            obj: (1.0 if obj == 'acc' else -1.0) for obj in metrics[0].keys()
        }
    pareto_front = []
    for i, (arch, m) in enumerate(zip(archs, metrics)):
        is_dominated = False
        for j, other_m in enumerate(metrics):
            if i == j:
                continue
            # `other_m` dominates `m` if it is no worse on every objective
            # and strictly better on at least one
            no_worse = all(
                w * other_m[obj] >= w * m[obj] for obj, w in objective_weights.items()
            )
            strictly_better = any(
                w * other_m[obj] > w * m[obj] for obj, w in objective_weights.items()
            )
            if no_worse and strictly_better:
                is_dominated = True
                break
        if not is_dominated:
            pareto_front.append((arch, m))
    return pareto_front
4. Once-for-All networks
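4.1 Core idea
Once-for-All (OFA) [3] introduces the idea of an elastic network:
train a single supernet that supports many configurations (depth, width, resolution), then extract sub-networks on demand.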
┌─────────────────────────────────────────────────────────────┐
│                    Once-for-All supernet                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│             ┌────────────────────────────────┐              │
│             │    Weight-sharing supernet     │              │
│             │                                │              │
│             │  ┌─────┐   ┌─────┐   ┌─────┐   │              │
│             │  │Block│   │Block│   │Block│   │              │
│             │  └──┬──┘   └──┬──┘   └──┬──┘   │              │
│             │     ↓         ↓         ↓      │              │
│             │ Elastic depth/width/resolution │              │
│             └────────────────┬───────────────┘              │
│                              │                              │
│       ┌──────────────┬───────┴──────┬──────────────┐        │
│       ↓              ↓              ↓              ↓        │
│  ┌──────────┐   ┌──────────┐   ┌──────────┐   ┌──────────┐  │
│  │  Large   │   │  Medium  │   │  Small   │   │   Tiny   │  │
│  │  (GPU)   │   │ (mobile) │   │  (edge)  │   │(embedded)│  │
│  └──────────┘   └──────────┘   └──────────┘   └──────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
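4.2 Elastic dimensions
| Elastic dimension | Range | Strategy |
|---|---|---|
| Depth | $d \in [D_{\min}, D_{\max}]$ | Keep the first $d$ blocks |
| Width | $w \in [W_{\min}, W_{\max}]$ | Keep the first $w$ channels |
| Resolution | $r \in [R_{\min}, R_{\max}]$ | Variable input resolution |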
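The three dimensions define a discrete configuration space that can be enumerated or sampled. A minimal sketch using the depth, width, and resolution choices that appear as defaults in the OFA code in these notes; the helper names and the assumption of one global depth and width per network (rather than per-stage choices) are simplifications made here.

```python
import itertools
import random

DEPTHS = [2, 3, 4]
WIDTHS = [0.65, 0.8, 1.0]
RESOLUTIONS = [160, 176, 192, 208, 224]


def sample_config(rng=random):
    """Draw one sub-network configuration uniformly at random."""
    return {
        "depth": rng.choice(DEPTHS),
        "width": rng.choice(WIDTHS),
        "resolution": rng.choice(RESOLUTIONS),
    }


def active_channels(max_channels, width):
    """Number of leading channels kept for a width multiplier."""
    return max(1, round(max_channels * width))


all_configs = list(itertools.product(DEPTHS, WIDTHS, RESOLUTIONS))
print(len(all_configs))             # 45 configurations with global choices
print(sample_config(random.Random(0)))
print(active_channels(64, 0.65))    # 42 of 64 channels

# If depth and width were chosen independently in each of 5 hypothetical stages,
# the space would grow to (3 * 3) ** 5 * 5 configurations.
print((len(DEPTHS) * len(WIDTHS)) ** 5 * len(RESOLUTIONS))
```

With global choices the space is small enough to enumerate exhaustively; per-stage choices make sampling or guided search necessary.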
4.3 Progressive Shrinking
Training strategy:
┌─────────────────────────────────────────────────────────────┐
│             Progressive Shrinking training flow             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Stage 1: train the largest net (d=D_max, w=W_max, r=R_max) │
│          ↓                                                  │
│  Stage 2: fine-tune sub-networks (d<D_max, w<W_max)         │
│          ↓                                                  │
│  Stage 3: keep fine-tuning smaller sub-networks             │
│          ↓                                                  │
│  Stage 4: smallest network (d=D_min, w=W_min, r=R_min)      │
│                                                             │
└─────────────────────────────────────────────────────────────┘
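The staged flow above can be sketched as a sampling schedule whose pool of allowed sub-networks grows at each stage, with larger sub-networks staying in the pool. The concrete `SCHEDULE` below is an assumption made for illustration, and `train_step` is a hypothetical callback standing in for one optimizer step on the sampled sub-network.

```python
import random

# Illustrative schedule: each stage lists the choices that may be sampled.
SCHEDULE = [
    {"depth": [4],       "width": [1.0],            "resolution": [224]},
    {"depth": [4, 3],    "width": [1.0, 0.8],       "resolution": [224, 208, 192]},
    {"depth": [4, 3, 2], "width": [1.0, 0.8, 0.65], "resolution": [224, 208, 192, 176]},
    {"depth": [4, 3, 2], "width": [1.0, 0.8, 0.65], "resolution": [224, 208, 192, 176, 160]},
]


def sample_subnet(stage, rng):
    """Sample one sub-network configuration from the pool allowed at `stage`."""
    return {name: rng.choice(choices) for name, choices in SCHEDULE[stage].items()}


def train_progressively(steps_per_stage, train_step, rng):
    """Run every stage in order, calling train_step(config) once per step."""
    for stage in range(len(SCHEDULE)):
        for _ in range(steps_per_stage):
            train_step(sample_subnet(stage, rng))


seen = []
train_progressively(steps_per_stage=3, train_step=seen.append, rng=random.Random(0))
for config in seen:
    print(config)
```

Stage 1 always yields the largest configuration, so the shared weights are first fitted to the full network before smaller sub-networks start to constrain them.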
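4.4 OFA implementation in PyTorch
import copy

import torch.nn as nn
import torch.nn.functional as F


class OnceForAll(nn.Module):
    """Once-for-All network."""

    # ElasticConvBlock and ElasticDepthBlock are not defined in these notes;
    # they are assumed to be nn.Module subclasses with the call signatures used below.
    def __init__(self,
                 depth_list=(2, 3, 4),
                 width_list=(0.65, 0.8, 1.0),
                 resolution_list=(160, 176, 192, 208, 224)):
        super().__init__()
        self.depth_list = list(depth_list)
        self.width_list = list(width_list)
        self.resolution_list = list(resolution_list)
        # Build the supernet
        self.blocks = nn.ModuleList([
            ElasticConvBlock(elastic_channels=self.width_list),
            ElasticDepthBlock(depth_list=self.depth_list),
            ElasticConvBlock(elastic_channels=self.width_list),
            # ... (remaining blocks are elided in the original)
        ])

    def forward(self, x, depth=None, width=None, resolution=None):
        """Elastic forward pass."""
        # Elastic resolution: resize the input to the requested size
        if resolution is not None:
            x = F.interpolate(x, size=(resolution, resolution),
                              mode='bilinear', align_corners=False)
        num_active = depth if depth is not None else max(self.depth_list)
        active_width = width if width is not None else max(self.width_list)
        # Elastic blocks
        for block in self.blocks:
            if isinstance(block, ElasticDepthBlock):
                # Elastic depth: the block runs only its first `num_active` layers
                x = block(x, num_active)
            elif isinstance(block, ElasticConvBlock):
                # Elastic width: keep the leading channels for this multiplier
                x = block(x, active_width)
            else:
                x = block(x)
        return x

    def extract_subnet(self, depth, width, resolution):
        """Extract the sub-network for a given configuration."""
        # Only the width is baked into the copy here; depth and resolution
        # still have to be passed to forward().
        subnet = copy.deepcopy(self)
        for block in subnet.blocks:
            if hasattr(block, 'extract_subnet'):
                block.extract_subnet(width)
        return subnet
4.5 Latency-constrained search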
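import itertools
import random


class OFASearcher:
    """Searcher over OFA sub-network configurations."""

    def __init__(self, ofa_net, latency_predictor, accuracy_fn=None):
        self.ofa_net = ofa_net
        self.predictor = latency_predictor
        # accuracy_fn(config) -> float; the original calls _evaluate_config
        # without defining it, so the estimator is injected here.
        self.accuracy_fn = accuracy_fn

    def _evaluate_config(self, config):
        """Estimate accuracy (zero-cost proxy or fine-tuning)."""
        if self.accuracy_fn is None:
            raise NotImplementedError("supply accuracy_fn to estimate accuracy")
        return self.accuracy_fn(config)

    def search(self, latency_budget, n_candidates=1000):
        """Search for the best configuration under a latency budget."""
        # Generate candidate configurations
        candidates = [
            {'depth': d, 'width': w, 'resolution': r}
            for d, w, r in itertools.product(
                self.ofa_net.depth_list,
                self.ofa_net.width_list,
                self.ofa_net.resolution_list,
            )
        ]
        # Cap the number of configurations that get evaluated
        if len(candidates) > n_candidates:
            candidates = random.sample(candidates, n_candidates)
        # Keep only configurations within the latency budget
        valid_candidates = []
        for config in candidates:
            latency = self.predictor.predict(
                depth=config['depth'],
                width=config['width'],
                resolution=config['resolution']
            )
            if latency <= latency_budget:
                config['latency'] = latency
                valid_candidates.append(config)
        # Estimate accuracy for the feasible configurations
        for config in valid_candidates:
            config['accuracy'] = self._evaluate_config(config)
        # Feasible configurations, most accurate first
        return sorted(valid_candidates,
                      key=lambda c: c['accuracy'],
                      reverse=True)
5. Edge deployment in practice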
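5.1 Deployment platform characteristics
| Platform | Characteristics | Recommended architectures |
|---|---|---|
| NVIDIA Jetson | CUDA acceleration, memory shared between CPU and GPU | Large networks |
| Apple Neural Engine | Dedicated NPU, Core ML | MobileNetV3, EfficientNet |
| Google Edge TPU | INT8-optimized, 4 TOPS | Quantized MobileNetV2 |
| Qualcomm DSP | Hexagon HVX/HTA | Small CNNs |
| ARM Cortex-M | Severely resource-constrained, optional FPU | MicroNet, MobileNet-Tiny |
5.2 Quantization-aware NAS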
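Bit-width is worth searching over because quantization error grows quickly as precision drops. A minimal numeric sketch, assuming symmetric uniform quantization with a single per-tensor scale (one common scheme; these notes do not specify which scheme is intended). The weight values are made up.

```python
def fake_quantize(values, bits):
    """Quantize to a signed `bits`-bit grid with one shared scale, then map back to float."""
    qmax = 2 ** (bits - 1) - 1
    scale = (max(abs(v) for v in values) or 1.0) / qmax  # avoid a zero scale
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


weights = [0.82, -0.41, 0.07, -0.95, 0.33, 0.002]  # hypothetical weights
errors = {bits: mse(weights, fake_quantize(weights, bits)) for bits in (8, 4, 2)}
for bits, err in errors.items():
    print(f"{bits}-bit MSE: {err:.6f}")
```

Each layer can tolerate a different amount of this error, which is what makes per-layer bit-widths a meaningful search dimension alongside the architecture itself.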
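import torch
import torch.nn as nn
import torch.nn.functional as F

def quantize(t, bits):
    """Symmetric uniform fake quantization with a straight-through gradient (assumed scheme)."""
    qmax = 2 ** (bits - 1) - 1
    scale = t.detach().abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(t / scale), -qmax, qmax) * scale
    return t + (q - t).detach()

class QuantizationAwareNAS(nn.Module):
    """Quantization-aware NAS layer."""
    def __init__(self, in_channels=16, out_channels=16, kernel_size=3):
        super().__init__()
        self.weight_bits = [2, 4, 8]  # candidate weight bit-widths
        self.act_bits = [4, 8]        # candidate activation bit-widths
        # Layer sizes are placeholders; the original does not define self.weight
        self.weight = nn.Parameter(
            0.1 * torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2
    def forward(self, x, w_bits, a_bits):
        """Quantization-aware forward pass."""
        weight = quantize(self.weight, w_bits)  # quantize weights
        x = quantize(x, a_bits)                 # quantize activations
        # Remaining conv arguments are elided in the original; padding keeps the spatial size
        return F.conv2d(x, weight, padding=self.padding)
5.3 End-to-end deployment flow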
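┌─────────────────────────────────────────────────────────────┐
│             End-to-end hardware-aware NAS flow              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Define the search space                                 │
│     - Operation set (Conv, DWConv, Skip, etc.)              │
│     - Depth range (1-4 blocks)                              │
│     - Width range (0.65-1.0)                                │
│       ↓                                                     │
│  2. Build the latency predictor                             │
│     - Measure operation latency on the target hardware      │
│     - Build a lookup table or train a neural network        │
│       ↓                                                     │
│  3. Multi-objective search (accuracy + latency)             │
│     - NSGA-III or weighted sum                              │
│     - Produce the Pareto front                              │
│       ↓                                                     │
│  4. Sub-network extraction (optional: Once-for-All)         │
│     - Elastic depth/width/resolution                        │
│       ↓                                                     │
│  5. Training and fine-tuning                                │
│     - Knowledge distillation (e.g. when using OFA)          │
│       ↓                                                     │
│  6. Quantization and deployment                             │
│     - INT8 quantization (PTQ/QAT)                           │
│     - Compile for the target platform (TensorRT/TVM/CoreML) │
│                                                             │
└─────────────────────────────────────────────────────────────┘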
References
[1] Wu B, Dai X, Zhang P, et al. FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search. CVPR 2019.
[2] Cai H, Zhu L, Han S. ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware. ICLR 2019.
[3] Cai H, Gan C, Wang T, Zhang Z, Han S. Once-for-All: Train One Network and Specialize It for Efficient Deployment. ICLR 2020.