Hardware-Aware NAS: An In-Depth Analysis

1. Overview

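Hardware-aware NAS adds hardware characteristics (latency, energy, memory, throughput) to the optimization objectives of neural architecture search [1]. Deep learning is now deployed at scale on mobile phones, on edge devices and in data centers. In those settings, hardware efficiency matters as much as accuracy.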

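Core challenges:

  • Multi-objective optimization: accuracy vs. latency vs. energy vs. memory
  • Hardware heterogeneity: devices differ widely in their characteristics
  • Latency modeling: accurate on-device measurement is expensive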

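2. The Jet-Nemotron Method

2.1 Background

Paper: Gu et al., "Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search" (NeurIPS 2025)

Core contribution: PostNAS. The method starts from a pretrained full-attention model and searches for a hybrid SSM-Attention architecture.

2.2 Core Idea of PostNAS

Key innovation: PostNAS does not search from random initialization. It optimizes the architecture on top of an existing pretrained model.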

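Traditional NAS pipeline: define a search space → train and evaluate candidate architectures starting from random initialization → train the selected architecture from scratch.

PostNAS pipeline: pretrained full-attention model → decide layer by layer whether to replace attention with an SSM → apply the conversions → fine-tune the hybrid model.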

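2.3 Jet-Nemotron Architecture Design

Hybrid architecture composition:

Component | Configuration
Base model | Nemotron-8B
SSM layers | Mamba layers as replacements
Attention layers | selectively retained
MLP layers | SwiGLU activation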

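Search decision variables:

For each layer $l$, decide:

  • $s_l \in \{0, 1\}$: whether to replace the layer with an SSM
  • If an SSM is used: which SSM configuration to choose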

2.4 Search Algorithm

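class PostNAS:
    """Greedy PostNAS search. The hooks measure_latency, convert_to_ssm,
    estimate_acc_drop and apply_decisions must be supplied for a concrete setup."""

    def __init__(self, base_model, acc_threshold=0.01):
        self.model = base_model
        self.acc_threshold = acc_threshold  # max tolerated accuracy drop per layer
        self.decisions = {}  # per-layer decision: 'ssm' or 'attention'

    def search(self, target_latency):
        """Greedy search under a latency budget."""
        layers = self.model.layers
        # Projected model latency, starting from the all-attention model
        total_latency = sum(self.measure_latency(layer) for layer in layers)

        for i, layer in enumerate(layers):
            # Assumed use of the budget: once it is met, keep remaining layers as attention
            if total_latency <= target_latency:
                self.decisions[i] = 'attention'
                continue

            # Latency before and after replacement
            latency_before = self.measure_latency(layer)
            ssm_layer = self.convert_to_ssm(layer)
            latency_after = self.measure_latency(ssm_layer)

            # Estimated accuracy impact of the replacement
            acc_drop = self.estimate_acc_drop(layer, ssm_layer)

            # Decision
            if latency_after < latency_before and acc_drop < self.acc_threshold:
                self.decisions[i] = 'ssm'
                total_latency -= latency_before - latency_after
            else:
                self.decisions[i] = 'attention'

        return self.apply_decisions()

    # Hooks to implement for a concrete model and target device
    def measure_latency(self, layer):
        raise NotImplementedError

    def convert_to_ssm(self, layer):
        raise NotImplementedError

    def estimate_acc_drop(self, layer, ssm_layer):
        raise NotImplementedError

    def apply_decisions(self):
        raise NotImplementedError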
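The sketch below exercises the greedy rule without a real model. A per-layer table stands in for the measurement and accuracy-estimation hooks. `LayerProfile` and `greedy_postnas` are hypothetical names, and all numbers are made up. The loop above does not say exactly how `target_latency` is used. Here we assume conversion stops once the projected total latency meets the target, so the remaining layers keep attention.

```python
from dataclasses import dataclass


@dataclass
class LayerProfile:
    # Hypothetical per-layer measurements: latencies in ms, estimated accuracy drop
    attn_latency: float
    ssm_latency: float
    acc_drop: float


def greedy_postnas(profiles, target_latency, acc_threshold):
    """Greedy per-layer choice between 'attention' and 'ssm'.

    A layer is converted only if the SSM version is faster and the estimated
    accuracy drop is below the threshold. Once the projected total latency
    meets the target, the remaining layers keep attention (assumed budget rule).
    """
    total = sum(p.attn_latency for p in profiles)
    decisions = []
    for p in profiles:
        if (total > target_latency
                and p.ssm_latency < p.attn_latency
                and p.acc_drop < acc_threshold):
            decisions.append('ssm')
            total -= p.attn_latency - p.ssm_latency
        else:
            decisions.append('attention')
    return decisions, total


profiles = [
    LayerProfile(2.0, 0.8, 0.002),
    LayerProfile(2.0, 0.8, 0.030),  # too sensitive: stays attention
    LayerProfile(2.0, 0.8, 0.001),
    LayerProfile(2.0, 0.8, 0.004),
]
decisions, total = greedy_postnas(profiles, target_latency=6.0, acc_threshold=0.01)
print(decisions, round(total, 2))
```

With a 6 ms target, the second layer stays attention because its accuracy drop exceeds the threshold. The fourth layer stays attention because the budget is already met after the third layer is converted.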

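2.5 Performance Results

Throughput improvement:

Model | Accuracy (Pile) | Throughput (tokens/s) | Speedup
Nemotron-8B | 18.2 | 45 | 1x
Jet-Nemotron-8B | 18.1 | 180 | 4x

Memory efficiency (relative to full attention):

Model | Activation memory | KV cache | Total memory
Full Attention | 100% | 100% | 100%
Hybrid (1:1) | 55% | 50% | 60%
Hybrid (1:3) | 40% | 25% | 45%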
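The KV-cache column follows from the layer mix. Only attention layers store keys and values for every token, while an SSM layer keeps a fixed-size state. The sketch below makes two assumptions: the ratios denote attention:SSM layers, and the model is a hypothetical 32-layer configuration (8 KV heads, head dimension 128, FP16, 8192-token context). It reproduces the 50% and 25% KV-cache entries. It does not model the activation or total-memory columns, which depend on other factors.

```python
def kv_cache_bytes(n_attn_layers, n_kv_heads, head_dim, seq_len,
                   batch=1, bytes_per_elem=2):
    # Keys and values (factor 2) are cached for every token in every attention layer
    return 2 * n_attn_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


def attention_layers(n_layers, attn_ratio, ssm_ratio):
    # Number of attention layers for an attention:SSM layer ratio
    return n_layers * attn_ratio // (attn_ratio + ssm_ratio)


n_layers = 32
full = kv_cache_bytes(n_layers, n_kv_heads=8, head_dim=128, seq_len=8192)
for name, (a, s) in {"Hybrid (1:1)": (1, 1), "Hybrid (1:3)": (1, 3)}.items():
    hybrid = kv_cache_bytes(attention_layers(n_layers, a, s), 8, 128, 8192)
    print(f"{name}: {hybrid / full:.0%} of the full-attention KV cache")
```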

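3. The Quasar-ViT Method

3.1 Background

Paper: Li et al., "Quasar-ViT: Hardware-Oriented Quantization-Aware Architecture Search for Vision Transformers"

Core contribution: a hardware-aware NAS for ViTs that jointly optimizes the architecture and the quantization policy.

3.2 Quantization-Aware Search Space

Quantization configuration search:

Tensor | Bit width | Search option
Weights | 4-8 bit | chosen per layer
Activations | 4-8 bit | chosen per layer
Attention | 8-16 bit | fixed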

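Joint search variables: the ViT architecture choices $\alpha$, together with the per-layer weight and activation bit widths $b^{w}_l, b^{a}_l$ (4-8 bit) for layers $l = 1, \dots, L$.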
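The joint space can be represented as one record per candidate. In the sketch below, the architecture choices (`EMBED_DIMS`, `NUM_HEADS`, `MLP_RATIOS`) are hypothetical placeholders. The 4-8 bit range follows the table above, and every integer width in it is assumed to be allowed. Attention precision is left out because the table marks it as fixed rather than searched. `weight_bytes` estimates weight storage from $4d^2 + 2rd^2$ parameters per Transformer block: the Q/K/V/output projections plus an MLP with ratio $r$, with biases ignored.

```python
import random
from dataclasses import dataclass, field

EMBED_DIMS = [192, 384]        # hypothetical architecture choices
NUM_HEADS = [3, 6]
MLP_RATIOS = [2, 4]
BIT_CHOICES = [4, 5, 6, 7, 8]  # per-layer weight/activation bit widths (4-8 bit)


@dataclass
class JointConfig:
    embed_dim: int
    num_heads: int
    mlp_ratio: int
    weight_bits: list = field(default_factory=list)  # one entry per block
    act_bits: list = field(default_factory=list)

    def weight_bytes(self):
        # Per block: 4*d^2 (attention projections) + 2*r*d^2 (MLP), biases ignored
        d, r = self.embed_dim, self.mlp_ratio
        params_per_block = 4 * d * d + 2 * r * d * d
        return sum(params_per_block * b for b in self.weight_bits) // 8


def sample_config(depth, rng):
    # Uniformly sample one point of the joint architecture x quantization space
    return JointConfig(
        embed_dim=rng.choice(EMBED_DIMS),
        num_heads=rng.choice(NUM_HEADS),
        mlp_ratio=rng.choice(MLP_RATIOS),
        weight_bits=[rng.choice(BIT_CHOICES) for _ in range(depth)],
        act_bits=[rng.choice(BIT_CHOICES) for _ in range(depth)],
    )


cfg = sample_config(depth=12, rng=random.Random(0))
print(cfg.embed_dim, cfg.weight_bits, cfg.weight_bytes())
```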

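3.3 Latency Model

Piecewise-linear latency model:

$$\mathrm{Lat}(P, b) = k_i(b) \cdot P + c_i(b), \qquad P \in [P_i, P_{i+1})$$

where $P$ is the parameter count and $b$ is the quantization bit width. Each segment $i$ has its own slope $k_i(b)$ and intercept $c_i(b)$.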


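4. Hardware-Aware NAS Framework

4.1 Multi-Objective Optimization Problem

Hardware-aware NAS can be formalized as a multi-objective problem over the search space $\mathcal{A}$:

$$\min_{\alpha \in \mathcal{A}} \big( -\mathrm{Acc}(\alpha),\ \mathrm{Lat}(\alpha),\ E(\alpha),\ \mathrm{Mem}(\alpha) \big)$$

Pareto front: the set of solutions where no objective can be improved without degrading at least one other objective.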
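Extracting the Pareto front from a set of evaluated candidates takes a pairwise dominance check. The sketch converts accuracy to error rate so that every objective is minimized. The candidate values are hypothetical.

```python
def dominates(a, b):
    """True if `a` is no worse than `b` in every objective and strictly better
    in at least one. All objectives are minimized."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))


def pareto_front(points):
    # Keep every point that no other point dominates; O(n^2), fine for small sets.
    # A point never dominates itself, so no self-exclusion is needed.
    return [p for p in points if not any(dominates(q, p) for q in points)]


# Hypothetical candidates: (error rate, latency in ms)
candidates = [(0.24, 12.0), (0.22, 18.0), (0.25, 15.0), (0.20, 30.0), (0.22, 20.0)]
print(pareto_front(candidates))
```

Here (0.25, 15.0) is dropped because (0.24, 12.0) is better on both axes. (0.22, 20.0) is dropped because (0.22, 18.0) matches its error at lower latency.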

4.2 Latency Modeling Methods

Lookup table (LUT) method

Operation | Latency (cycles)
MatMul 64x64 | 1000
MatMul 128x128 | 4500
Attention 512 | 15000
Softmax | 2000

Neural network predictor
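With a lookup table, latency is estimated by summing the per-operator entries. The sketch assumes an additive model with no operator overlap or fusion. The dictionary keys naming the table rows are hypothetical, and a 1 GHz clock is assumed for converting cycles to time.

```python
# Latencies in cycles, taken from the lookup table above
LATENCY_LUT = {
    "matmul_64x64": 1000,
    "matmul_128x128": 4500,
    "attention_512": 15000,
    "softmax": 2000,
}


def lut_latency_cycles(ops):
    # Additive model: total latency is the sum of per-op entries
    # (ignores overlap between ops, caching effects and launch overhead)
    return sum(LATENCY_LUT[op] for op in ops)


def cycles_to_us(cycles, clock_hz):
    return cycles / clock_hz * 1e6


block = ["matmul_64x64", "matmul_64x64", "attention_512", "softmax"]
cycles = lut_latency_cycles(block)
print(cycles, cycles_to_us(cycles, clock_hz=1e9))  # the 1 GHz clock is an assumption
```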

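import torch.nn as nn


class LatencyPredictor(nn.Module):
    """MLP that maps an architecture feature vector to a predicted latency."""

    def __init__(self, num_features):
        super().__init__()  # required before registering submodules
        self.layers = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # predicted latency
        )

    def forward(self, arch_features):
        # arch_features: (batch, num_features) -> (batch, 1)
        return self.layers(arch_features)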
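The predictor expects a fixed-length `arch_features` vector, so each architecture has to be encoded as one. A common encoding, assumed here, gives every layer a one-hot block over the candidate operations and appends a few scalar features. The op vocabulary, the scalar features and `encode_arch` are hypothetical. With 12 layers, 4 ops and 2 scalars, `num_features` would be 50.

```python
OPS = ["conv3x3", "conv5x5", "attention", "skip"]  # hypothetical op vocabulary


def encode_arch(layer_ops, width_mult, input_res):
    """Flatten an architecture into a feature vector for a latency predictor.

    Each layer contributes a one-hot block over OPS; two scalar features
    (width multiplier and normalized input resolution) are appended.
    """
    features = []
    for op in layer_ops:
        one_hot = [0.0] * len(OPS)
        one_hot[OPS.index(op)] = 1.0
        features.extend(one_hot)
    features.append(float(width_mult))
    features.append(input_res / 224.0)  # normalize around a 224-pixel input
    return features


arch = ["conv3x3"] * 6 + ["attention"] * 4 + ["skip"] * 2
x = encode_arch(arch, width_mult=1.0, input_res=224)
print(len(x))  # 12 layers x 4 ops + 2 scalars = 50
```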

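4.3 Energy Modeling

Energy decomposition:

$$E_{\text{total}} = E_{\text{compute}} + E_{\text{SRAM}} + E_{\text{DRAM}}, \qquad E_{\text{compute}} = N_{\text{MAC}} \cdot e_{\text{MAC}}, \qquad E_{\text{SRAM}} = N_{\text{SRAM}} \cdot e_{\text{SRAM}}$$

DRAM access energy:

$$E_{\text{DRAM}} = N_{\text{DRAM}} \cdot e_{\text{DRAM}}$$

Here $N_{\text{MAC}}$ is the number of multiply-accumulate operations and $N_{\text{SRAM}}, N_{\text{DRAM}}$ are the numbers of on-chip and off-chip memory accesses. Each $e$ is the corresponding energy per operation or per access.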
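Once operation and access counts are known, the decomposition can be evaluated directly. The per-event energies below are illustrative placeholders, not measurements for any device, and `EnergyModel` is a hypothetical helper. In this example the layer makes four times fewer DRAM accesses than SRAM accesses, yet the DRAM term still dominates because each DRAM access costs more.

```python
from dataclasses import dataclass


@dataclass
class EnergyModel:
    # Illustrative per-event energies in picojoules (assumed; device-specific in practice)
    e_mac_pj: float = 1.0
    e_sram_pj: float = 5.0    # per word read from or written to on-chip SRAM
    e_dram_pj: float = 100.0  # per word transferred from off-chip DRAM

    def energy_uj(self, n_mac, n_sram, n_dram):
        """Return (compute, SRAM, DRAM, total) energy in microjoules."""
        compute = n_mac * self.e_mac_pj
        sram = n_sram * self.e_sram_pj
        dram = n_dram * self.e_dram_pj
        parts = [v * 1e-6 for v in (compute, sram, dram)]  # pJ -> uJ
        return (*parts, sum(parts))


# Hypothetical layer: 10M MACs, 2M SRAM accesses, 0.5M DRAM accesses
compute, sram, dram, total = EnergyModel().energy_uj(10_000_000, 2_000_000, 500_000)
print(f"compute={compute:.1f} uJ, SRAM={sram:.1f} uJ, DRAM={dram:.1f} uJ, total={total:.1f} uJ")
```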


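5. Hardware-Aware Search Algorithms

5.1 Constrained Optimization

Lagrangian relaxation: the latency constraint $\mathrm{Lat}(\alpha) \le T$ is moved into the objective with a multiplier $\lambda \ge 0$:

$$\max_{\lambda \ge 0} \min_{\alpha} \; \mathcal{L}_{\text{task}}(\alpha) + \lambda \big( \mathrm{Lat}(\alpha) - T \big)$$

5.2 Evolutionary Multi-Objective Optimization

Applying NSGA-II to NAS: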
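For a small discrete candidate set, the relaxation can be solved in two nested steps. The inner minimization is done by enumeration, and $\lambda$ is updated by projected subgradient ascent: $\lambda \leftarrow \max(0, \lambda + \eta(\mathrm{Lat}(\alpha^*) - T))$, where $T$ is the latency target. The candidates, losses and step size $\eta$ (`lr`) below are hypothetical. On a discrete problem the dual iterates tend to oscillate between neighboring candidates. For that reason, the sketch returns the best feasible candidate it has seen rather than the last iterate.

```python
# Hypothetical candidates: name -> (task loss, latency in ms)
CANDIDATES = {
    "large": (0.20, 30.0),
    "medium": (0.24, 18.0),
    "small": (0.30, 10.0),
}


def lagrangian_search(candidates, target_ms, lr=1e-4, steps=200):
    """Projected subgradient ascent on lambda; the inner argmin over the
    (tiny, enumerable) candidate set is exact."""
    lam = 0.0
    best_feasible = None
    for _ in range(steps):
        # Inner problem: argmin_a  loss(a) + lam * (lat(a) - T)
        choice = min(candidates,
                     key=lambda a: candidates[a][0] + lam * (candidates[a][1] - target_ms))
        loss, lat = candidates[choice]
        # Track the best feasible primal solution, since the iterates can oscillate
        if lat <= target_ms and (best_feasible is None or loss < candidates[best_feasible][0]):
            best_feasible = choice
        # lat - T is a subgradient of the dual function; project onto lambda >= 0
        lam = max(0.0, lam + lr * (lat - target_ms))
    return best_feasible, lam


print(lagrangian_search(CANDIDATES, target_ms=20.0))
```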

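class HardwareAwareNSGA:
    """NSGA-II-style search loop. initialize_population, evaluate_accuracy,
    predict_latency, non_dominated_sort, select_and_mutate and
    extract_pareto_front are hooks left to the concrete implementation."""

    def __init__(self, pop_size=50, num_generations=50):
        self.pop_size = pop_size
        self.num_generations = num_generations

    def evaluate(self, population):
        # Both entries are maximized: accuracy and negated latency
        return [[self.evaluate_accuracy(ind), -self.predict_latency(ind)]
                for ind in population]

    def evolve(self):
        population = self.initialize_population()

        for _ in range(self.num_generations):
            # Evaluation
            fitness = self.evaluate(population)

            # Fast non-dominated sorting
            fronts = self.non_dominated_sort(fitness)

            # Selection and mutation produce the next population
            population = self.select_and_mutate(population, fronts)

        # Re-evaluate so the fitness matches the final population
        fitness = self.evaluate(population)
        return self.extract_pareto_front(population, fitness)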
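The `non_dominated_sort` step above is left abstract. A standard choice is the fast non-dominated sort from NSGA-II (Deb et al.), sketched below in plain Python. Like the fitness vectors in the loop above, it treats every objective as maximized. The fitness values in the example are hypothetical.

```python
def fast_non_dominated_sort(fitness):
    """Fast non-dominated sort as used in NSGA-II.

    `fitness` is a list of objective vectors, all MAXIMIZED (e.g. [acc, -latency]).
    Returns a list of fronts, each a list of indices into `fitness`;
    front 0 is the Pareto front.
    """
    def dominates(a, b):
        return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))

    n = len(fitness)
    dominated_by = [[] for _ in range(n)]  # solutions that p dominates
    domination_count = [0] * n             # how many solutions dominate p
    fronts = [[]]
    for p in range(n):
        for q in range(n):
            if dominates(fitness[p], fitness[q]):
                dominated_by[p].append(q)
            elif dominates(fitness[q], fitness[p]):
                domination_count[p] += 1
        if domination_count[p] == 0:
            fronts[0].append(p)

    i = 0
    while fronts[i]:
        next_front = []
        for p in fronts[i]:
            for q in dominated_by[p]:
                domination_count[q] -= 1
                if domination_count[q] == 0:  # all dominators already assigned
                    next_front.append(q)
        i += 1
        fronts.append(next_front)
    return fronts[:-1]  # drop the trailing empty front


# Hypothetical fitness vectors: [accuracy, -latency_ms]
fitness = [[0.76, -12], [0.78, -18], [0.75, -15], [0.80, -30], [0.78, -20]]
print(fast_non_dominated_sort(fitness))
```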

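5.3 Gradient-Based Methods

Differentiable latency approximation: the discrete choice among candidate operations $o_1, \dots, o_K$ on an edge is relaxed with a softmax over architecture parameters $\alpha$. The latency then becomes an expectation:

$$\mathbb{E}[\mathrm{Lat}] = \sum_{i=1}^{K} \frac{\exp(\alpha_i)}{\sum_{j=1}^{K} \exp(\alpha_j)} \, \mathrm{Lat}(o_i)$$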

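import torch
import torch.nn as nn
import torch.nn.functional as F


class DifferentiableHardwareNAS(nn.Module):
    def __init__(self, ops, op_latencies, lambda_latency=0.1):
        super().__init__()
        self.ops = nn.ModuleList(ops)  # candidate operations on one edge
        # Per-op latency from a LUT or predictor: a constant, not a trainable parameter
        self.register_buffer('latencies', torch.tensor(op_latencies, dtype=torch.float32))
        self.criterion = nn.CrossEntropyLoss()
        self.lambda_latency = lambda_latency

    def forward_with_latency(self, x, alpha):
        # Softmax relaxation of the discrete op choice
        alpha_soft = F.softmax(alpha, dim=-1)
        # Mixed output: candidate ops weighted by their probabilities
        output = sum(w * op(x) for w, op in zip(alpha_soft, self.ops))
        # Expected latency, differentiable with respect to alpha
        expected_latency = (alpha_soft * self.latencies).sum()
        return output, expected_latency

    def loss(self, output, target, latency, max_latency):
        acc_loss = self.criterion(output, target)
        # Hinge penalty: only latency above the budget is penalized
        latency_loss = F.relu(latency - max_latency)
        return acc_loss + self.lambda_latency * latency_loss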
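Below is a plain-Python worked example (no autograd) of the quantities used in `forward_with_latency` and `loss`. It computes the softmax-weighted expected latency, the hinge penalty, and the analytic gradient $\partial\,\mathbb{E}[\mathrm{Lat}]/\partial \alpha_i = p_i(\mathrm{Lat}(o_i) - \mathbb{E}[\mathrm{Lat}])$ with $p = \mathrm{softmax}(\alpha)$. The three op latencies, the budget and the step size are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_latency(alpha, latencies):
    # E[Lat] = sum_i softmax(alpha)_i * Lat(o_i)
    return sum(p * lat for p, lat in zip(softmax(alpha), latencies))


def expected_latency_grad(alpha, latencies):
    # dE/dalpha_i = p_i * (Lat_i - E[Lat]): positive for ops slower than the mean
    p = softmax(alpha)
    e = sum(pi * li for pi, li in zip(p, latencies))
    return [pi * (li - e) for pi, li in zip(p, latencies)]


def latency_penalty(expected, max_latency, lam):
    # Hinge: zero inside the budget, linear beyond it
    return lam * max(0.0, expected - max_latency)


latencies = [0.2, 0.8, 2.5]  # hypothetical candidate-op latencies (ms)
alpha = [0.0, 0.0, 0.0]      # uniform architecture weights at initialization
exp_lat = expected_latency(alpha, latencies)
print(round(exp_lat, 3), latency_penalty(exp_lat, max_latency=1.0, lam=0.5))

# The penalty is active here, so its gradient is lam * dE/dalpha; take one descent step
grad = expected_latency_grad(alpha, latencies)
alpha_new = [a - 1.0 * 0.5 * g for a, g in zip(alpha, grad)]
print(expected_latency(alpha_new, latencies) < exp_lat)
```

The gradient is positive for ops slower than the current expectation. A descent step on the penalty therefore shifts probability mass toward cheaper ops.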

6. The CIMNAS Method

6.1 Compute-in-Memory Architectures

Paper: "CIMNAS: Compute-In-Memory-Aware NAS" (arXiv 2509.25862)

Target scenario: NAS for compute-in-memory (CIM) architectures.

6.2 Modeling CIM Characteristics

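CIM operation energy:

$$E_{\text{CIM}} \propto \mathrm{nnz}(W)$$

where $\mathrm{nnz}(W)$ is the number of non-zero elements of the weight matrix $W$.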

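6.3 Sparsity-Aware Search

Search objective: the architecture $\alpha$ and the sparse weights $W$ are chosen jointly to trade task loss against CIM energy:

$$\min_{\alpha, W} \; \mathcal{L}_{\text{task}}(\alpha, W) + \lambda \, E_{\text{CIM}}(W)$$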
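The sketch below scores candidates under an energy model proportional to non-zero weights. It makes three assumptions. First, energy is linear in $\mathrm{nnz}(W)$, with a per-weight cost `e_cell_pj`. Second, sparsity comes from unstructured magnitude pruning, a swapped-in technique that may differ from CIMNAS's own sparsity search. Third, the task losses of the dense and pruned variants are hypothetical numbers.

```python
def magnitude_prune(weights, sparsity):
    """Zero out exactly floor(len * sparsity) weights with the smallest magnitudes."""
    k = int(len(weights) * sparsity)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:k]:
        pruned[i] = 0.0
    return pruned


def nnz(weights):
    return sum(1 for w in weights if w != 0.0)


def cim_energy(weights, e_cell_pj=1.0):
    # E_CIM assumed proportional to nnz(W): only non-zero weights cost energy
    return e_cell_pj * nnz(weights)


def sparsity_aware_objective(task_loss, weights, lam, e_cell_pj=1.0):
    # Task loss plus lambda-weighted CIM energy
    return task_loss + lam * cim_energy(weights, e_cell_pj)


weights = [0.5, -0.1, 0.03, -0.8, 0.2, 0.01, -0.05, 0.4]
pruned = magnitude_prune(weights, sparsity=0.5)
# Hypothetical task losses for the dense and pruned variants
dense_score = sparsity_aware_objective(0.30, weights, lam=0.01)
pruned_score = sparsity_aware_objective(0.32, pruned, lam=0.01)
print(nnz(pruned), round(dense_score, 3), round(pruned_score, 3))
```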


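7. Hardware Platform Characteristics

7.1 Platform Comparison

Platform | Compute | Memory bandwidth | Notes
NVIDIA A100 | 312 TFLOPS | 2 TB/s | general-purpose GPU
NVIDIA H100 | 989 TFLOPS | 3.35 TB/s | Hopper architecture
Apple M4 | 38 TOPS | 120 GB/s | mobile
Google TPU v4 | 275 TFLOPS | 1.2 TB/s | cloud
FPGA | variable | variable | customizable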

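7.2 Operator-Level Latency Differences

Operation | CPU latency (ms) | GPU latency (ms) | NPU latency (ms)
3x3 Conv | 5.2 | 0.1 | 0.3
Attention | 50.0 | 0.8 | 2.5
LayerNorm | 0.8 | 0.05 | 0.1
Softmax | 1.5 | 0.1 | 0.2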
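These per-operator differences can flip which architecture is fastest. The sketch uses the table values and two hypothetical candidate blocks: nine 3x3 convolutions versus one attention followed by one LayerNorm. It assumes additive per-op latency, with no fusion and no data-movement overlap.

```python
# Per-op latency (ms) on each platform, taken from the table above
OP_LATENCY_MS = {
    "conv3x3": {"CPU": 5.2, "GPU": 0.1, "NPU": 0.3},
    "attention": {"CPU": 50.0, "GPU": 0.8, "NPU": 2.5},
    "layernorm": {"CPU": 0.8, "GPU": 0.05, "NPU": 0.1},
    "softmax": {"CPU": 1.5, "GPU": 0.1, "NPU": 0.2},
}


def block_latency(ops, platform):
    # Additive per-op model: ignores fusion and data-movement overlap
    return sum(OP_LATENCY_MS[op][platform] for op in ops)


# Two hypothetical candidate blocks
candidates = {
    "conv_block": ["conv3x3"] * 9,
    "attn_block": ["attention", "layernorm"],
}

for platform in ("CPU", "GPU", "NPU"):
    lat = {name: block_latency(ops, platform) for name, ops in candidates.items()}
    fastest = min(lat, key=lat.get)
    print(platform, {k: round(v, 2) for k, v in lat.items()}, "->", fastest)
```

The convolution block wins on the CPU (46.8 ms vs. 50.8 ms), while the attention block wins on the GPU and NPU. A single latency model therefore cannot serve every target device.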

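8. Practical Guide

8.1 Choosing a Search Strategy

Scenario | Recommended method | Reason
Fast deployment | PostNAS | reuses the pretrained model
Precise optimization | Evolutionary algorithms | global search
On-device deployment | Constrained gradient-based methods | fast convergence
Compute-in-memory | CIMNAS | dedicated CIM optimization

8.2 Latency Measurement Strategy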

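import time

import numpy as np
import torch


class LatencyMeasurer:
    def __init__(self, device, warmup_runs=10, measure_runs=100):
        self.device = torch.device(device)
        self.warmup_runs = warmup_runs
        self.measure_runs = measure_runs

    def _sync(self):
        # GPU kernels run asynchronously; wait for them before reading the clock
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

    def measure(self, model, input_shape):
        """Return the median latency of one forward pass, in seconds."""
        model = model.to(self.device).eval()
        # Create the input once so data generation and transfer are not timed
        x = torch.randn(input_shape, device=self.device)

        with torch.no_grad():
            # Warmup: lazy initialization, cuDNN autotuning, caches
            for _ in range(self.warmup_runs):
                model(x)
            self._sync()

            # Measurement
            times = []
            for _ in range(self.measure_runs):
                start = time.perf_counter()
                model(x)
                self._sync()
                times.append(time.perf_counter() - start)

        # The median is robust to occasional outliers such as OS jitter
        return float(np.median(times))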

8.3 Hyperparameter Configuration

hardware_aware_config = {
    'search': {
        'algorithm': 'nsga_iii',
        'population_size': 64,
        'generations': 50,
        'mutation_rate': 0.1,
        'crossover_rate': 0.8,
    },
    'objectives': {
        'accuracy': {'weight': 1.0, 'target': None},
        'latency': {'weight': 0.5, 'target': '10ms'},
        'energy': {'weight': 0.3, 'target': '100mJ'},
    },
    'constraints': {
        'max_latency': '50ms',
        'max_memory': '100MB',
        'min_accuracy': '75%',
    }
}
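The configuration writes its quantities as strings with units, so they must be parsed before a candidate can be checked against the constraints. The helpers below are hypothetical. They assume decimal units (1 MB = 10^6 bytes) and measured metrics given in SI units: seconds, bytes, and accuracy as a fraction.

```python
import re

# Multipliers to SI base units (seconds, joules, bytes, fraction)
_UNITS = {"ms": 1e-3, "s": 1.0, "mJ": 1e-3, "J": 1.0, "MB": 1e6, "GB": 1e9, "%": 1e-2}


def parse_quantity(text):
    """Parse strings such as '10ms', '100MB' or '75%' into SI base units."""
    match = re.fullmatch(r"\s*([0-9]*\.?[0-9]+)\s*([A-Za-z%]+)\s*", text)
    if match is None or match.group(2) not in _UNITS:
        raise ValueError(f"cannot parse quantity: {text!r}")
    return float(match.group(1)) * _UNITS[match.group(2)]


def satisfies_constraints(metrics, constraints):
    """`metrics`: latency (s), memory (bytes), accuracy (fraction)."""
    return (metrics["latency"] <= parse_quantity(constraints["max_latency"])
            and metrics["memory"] <= parse_quantity(constraints["max_memory"])
            and metrics["accuracy"] >= parse_quantity(constraints["min_accuracy"]))


constraints = {"max_latency": "50ms", "max_memory": "100MB", "min_accuracy": "75%"}
print(satisfies_constraints({"latency": 0.032, "memory": 80e6, "accuracy": 0.77}, constraints))
```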

9. Jet-Nemotron Implementation Details

9.1 Model Architecture

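import torch.nn as nn

# Attention, MambaBlock and MLP are placeholders for the model's own modules.
# nn.RMSNorm requires PyTorch >= 2.4.


class JetNemotronBlock(nn.Module):
    """Pre-norm residual block whose token mixer is either attention or an SSM."""

    def __init__(self, hidden_size, mode='attention'):
        super().__init__()
        self.mode = mode

        if mode == 'attention':
            self.attn = Attention(hidden_size)
        else:  # SSM mode
            self.ssm = MambaBlock(hidden_size)
        self.mlp = MLP(hidden_size)  # the same MLP type in both modes

        self.norm1 = nn.RMSNorm(hidden_size)
        self.norm2 = nn.RMSNorm(hidden_size)

    def forward(self, x):
        # Token mixing selected by the PostNAS decision for this layer
        if self.mode == 'attention':
            x = x + self.attn(self.norm1(x))
        else:
            x = x + self.ssm(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x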

9.2 Search Algorithm Pseudocode

Algorithm: Jet-Nemotron PostNAS

Input: Pretrained full-attention model M, target latency L
Output: Hybrid model M*

1. Initialize: decisions = []
2. For each layer l in M.layers:
   a. if l.type == 'attention':
      - latency_attn = measure_latency(l)
      - latency_ssm = measure_latency(convert_to_ssm(l))
      - acc_drop = estimate_acc_drop(l, convert_to_ssm(l))
      
      - if latency_ssm < latency_attn and acc_drop < θ:
          decisions.append('ssm')
      else:
          decisions.append('attention')
   
3. M* = apply_conversions(M, decisions)
4. Fine-tune M* on target data
5. Return M*

10. Related Topics


References

[1] Gu, Y., et al. (2025). Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search. NeurIPS 2025.