硬件感知NAS深度解析
1. 概述
硬件感知NAS(Hardware-Aware NAS)将硬件特性(延迟、能耗、内存、吞吐量)纳入架构搜索的优化目标1。随着深度学习在移动端、边缘设备和数据中心的大规模部署,硬件效率变得与准确率同等重要。
核心挑战:
- 多目标优化:准确率 vs 延迟 vs 能耗 vs 内存
- 硬件差异性:不同设备特性差异巨大
- 延迟建模困难:精确测量成本高昂
2. Jet-Nemotron方法
2.1 方法背景
论文: Gu et al., “Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search” (NeurIPS 2025)
核心贡献: PostNAS方法,从预训练的完整注意力模型出发,搜索混合SSM-Attention架构。
2.2 PostNAS核心思想
创新点: 不从随机初始化开始搜索,而是在预训练模型基础上进行架构优化。
传统NAS流程:
PostNAS流程:
2.3 Jet-Nemotron架构设计
混合架构组成:
| 组件 | 配置 |
|---|---|
| 基础模型 | Nemotron-8B |
| SSM层 | Mamba层替换 |
| Attention层 | 选择性保留 |
| MLP层 | SwiGLU激活 |
搜索决策变量:
对于每一层,决定:
- :是否使用SSM替代
- 如果使用SSM:选择SSM配置
2.4 搜索算法
class PostNAS:
def __init__(self, base_model):
self.model = base_model
self.decisions = {} # 每层的决策
def search(self, target_latency):
"""贪心搜索 + 延迟预算"""
layers = self.model.layers
remaining_budget = target_latency
for i, layer in enumerate(layers):
# 计算替换前后的延迟
latency_before = self.measure_latency(layer)
# 尝试SSM替换
ssr_layer = convert_to_ssm(layer)
latency_after = self.measure_latency(ssr_layer)
# 评估准确率影响
acc_drop = self.estimate_acc_drop(layer, ssr_layer)
# 决策
if latency_after < latency_before and acc_drop < threshold:
self.decisions[i] = 'ssm'
else:
self.decisions[i] = 'attention'
return self.apply_decisions()2.5 性能结果
吞吐量提升:
| 模型 | 准确率(Pile) | 吞吐量(Tokens/s) | 提升 |
|---|---|---|---|
| Nemotron-8B | 18.2 | 45 | 1x |
| Jet-Nemotron-8B | 18.1 | 180 | 4x |
内存效率:
| 模型 | 激活内存 | KV Cache | 总内存 |
|---|---|---|---|
| Full Attention | 100% | 100% | 100% |
| Hybrid (1:1) | 55% | 50% | 60% |
| Hybrid (1:3) | 40% | 25% | 45% |
3. Quasar-ViT方法
3.1 方法背景
论文: Li et al., “Quasar-ViT: Hardware-Oriented Quantization-Aware Architecture Search for Vision Transformers”
核心贡献: 联合优化架构和量化策略,针对ViT的硬件感知NAS。
3.2 量化感知搜索空间
量化配置搜索:
| 配置 | 位宽 | 搜索选项 |
|---|---|---|
| 权重 | 4-8 bit | 逐层选择 |
| 激活 | 4-8 bit | 逐层选择 |
| 注意力 | 8-16 bit | 固定 |
联合搜索变量:
3.3 延迟模型
分段线性延迟模型:
其中是参数量,是量化位宽。
4. 硬件感知NAS框架
4.1 多目标优化问题
硬件感知NAS可形式化为:
Pareto前沿: 所有无法在不牺牲其他目标的情况下改进任一目标的解的集合。
4.2 延迟建模方法
查找表方法
| 操作类型 | 延迟(cycles) |
|---|---|
| MatMul 64x64 | 1000 |
| MatMul 128x128 | 4500 |
| Attention 512 | 15000 |
| Softmax | 2000 |
神经网络预测器
class LatencyPredictor(nn.Module):
def __init__(self):
self.layers = nn.Sequential(
nn.Linear(num_features, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 1) # 预测延迟
)
def forward(self, arch_features):
return self.layers(arch_features)4.3 能耗建模
能耗分解:
DRAM访问能耗:
5. 硬件感知搜索算法
5.1 约束优化方法
拉格朗日松弛:
5.2 进化多目标优化
NSGA-II应用于NAS:
class HardwareAwareNSGA:
def __init__(self, pop_size=50):
self.pop_size = pop_size
def evolve(self, objectives=['acc', 'latency']):
population = self.initialize_population()
for generation in range(num_generations):
# 评估
fitness = []
for ind in population:
acc = self.evaluate_accuracy(ind)
lat = self.predict_latency(ind)
fitness.append([acc, -lat]) # 最大化acc,最小化latency
# 快速非支配排序
fronts = self.non_dominated_sort(fitness)
# 选择与变异
population = self.select_and_mutate(fronts)
return self.extract_pareto_front(fitness)5.3 梯度基方法
可微延迟近似:
class DifferentiableHardwareNAS:
def forward_with_latency(self, x, alpha):
# 软化延迟近似
latencies = torch.tensor([get_latency(op) for op in self.ops])
alpha_soft = F.softmax(alpha, dim=-1)
# 期望延迟
expected_latency = (alpha_soft * latencies).sum()
return output, expected_latency
def loss(self, output, target, latency, max_latency):
acc_loss = self.criterion(output, target)
latency_loss = F.relu(latency - max_latency)
return acc_loss + self.lambda_latency * latency_loss6. CIMNAS方法
6.1 计算内存架构
论文: “CIMNAS: Compute-In-Memory-Aware NAS” (arXiv 2509.25862)
核心场景: 面向存算一体(CIM)架构的NAS。
6.2 CIM特性建模
CIM操作能耗:
其中是权重矩阵的非零元素数量。
6.3 稀疏性感知搜索
搜索目标:
7. 硬件平台特性
7.1 平台对比
| 平台 | 算力 | 内存带宽 | 能效 | 特点 |
|---|---|---|---|---|
| NVIDIA A100 | 312 TFLOPS | 2TB/s | 低 | 通用GPU |
| NVIDIA H100 | 989 TFLOPS | 3.35TB/s | 中 | Hopper架构 |
| Apple M4 | 38 TOPS | 100GB/s | 高 | 移动端 |
| Google TPU | 275 TFLOPS | 900GB/s | 中 | 云端推理 |
| FPGA | 可变 | 中 | 高 | 定制化 |
7.2 操作级延迟差异
| 操作 | CPU延迟(ms) | GPU延迟(ms) | NPU延迟(ms) |
|---|---|---|---|
| 3x3 Conv | 5.2 | 0.1 | 0.3 |
| Attention | 50.0 | 0.8 | 2.5 |
| LayerNorm | 0.8 | 0.05 | 0.1 |
| Softmax | 1.5 | 0.1 | 0.2 |
8. 实践指南
8.1 搜索策略选择
| 场景 | 推荐方法 | 原因 |
|---|---|---|
| 快速部署 | PostNAS | 利用预训练模型 |
| 精确优化 | 进化算法 | 全局搜索 |
| 端侧部署 | 约束梯度方法 | 快速收敛 |
| 存算一体 | CIMNAS | 专门优化 |
8.2 延迟测量策略
class LatencyMeasurer:
def __init__(self, device):
self.device = device
self.warmup_runs = 10
self.measure_runs = 100
def measure(self, model, input_shape):
model.eval()
# Warmup
with torch.no_grad():
for _ in range(self.warmup_runs):
x = torch.randn(input_shape).to(self.device)
_ = model(x)
# 测量
torch.cuda.synchronize() if self.device == 'cuda' else None
times = []
with torch.no_grad():
for _ in range(self.measure_runs):
start = time.perf_counter()
x = torch.randn(input_shape).to(self.device)
_ = model(x)
torch.cuda.synchronize() if self.device == 'cuda' else None
end = time.perf_counter()
times.append(end - start)
return np.median(times)8.3 超参数配置
hardware_aware_config = {
'search': {
'algorithm': 'nsga_iii',
'population_size': 64,
'generations': 50,
'mutation_rate': 0.1,
'crossover_rate': 0.8,
},
'objectives': {
'accuracy': {'weight': 1.0, 'target': None},
'latency': {'weight': 0.5, 'target': '10ms'},
'energy': {'weight': 0.3, 'target': '100mJ'},
},
'constraints': {
'max_latency': '50ms',
'max_memory': '100MB',
'min_accuracy': '75%',
}
}9. Jet-Nemotron实现详解
9.1 模型架构
class JetNemotronBlock(nn.Module):
def __init__(self, hidden_size, mode='attention'):
super().__init__()
self.mode = mode
if mode == 'attention':
self.attn = Attention(hidden_size)
self.mlp = MLP(hidden_size)
else: # SSM mode
self.ssm = MambaBlock(hidden_size)
self.mlp = MLP(hidden_size)
self.norm1 = RMSNorm(hidden_size)
self.norm2 = RMSNorm(hidden_size)
def forward(self, x):
if self.mode == 'attention':
x = x + self.attn(self.norm1(x))
else:
x = x + self.ssm(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x9.2 搜索算法伪代码
Algorithm: Jet-Nemotron PostNAS
Input: Pretrained full-attention model M, target latency L
Output: Hybrid model M*
1. Initialize: decisions = []
2. For each layer l in M.layers:
a. if l.type == 'attention':
- latency_attn = measure_latency(l)
- latency_ssm = measure_latency(convert_to_ssm(l))
- acc_drop = estimate_acc_drop(l, convert_to_ssm(l))
- if latency_ssm < latency_attn and acc_drop < θ:
decisions.append('ssm')
else:
decisions.append('attention')
3. M* = apply_conversions(M, decisions)
4. Fine-tune M* on target data
5. Return M*
10. 相关主题
参考文献
Footnotes
-
Gu, Y., et al. (2025). Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search. NeurIPS 2025. ↩