LLM Inference Optimization

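Inference optimization for large language models (LLMs) is a core technology for deploying efficient AI services. Unlike training, inference faces its own challenges: latency that accumulates over autoregressive generation, heavy memory pressure from the KV Cache, and attention cost that grows quadratically with context length in long-context scenarios.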

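According to 2026 statistics, inference already accounts for more than 70% of enterprise AI infrastructure spending. Effective inference optimization can cut inference cost by 5-20× while raising throughput by an order of magnitude.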

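1. KV Cache Optimization

1.1 How the KV Cache Works

When an autoregressive language model generates the $t$-th token, that token must attend to the Key and Value vectors of all $t-1$ preceding tokens (and to its own). The KV Cache stores these vectors so they are not recomputed at every step.

Let the attention mechanism be:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

At inference time only the Query changes with each new token; $K$ and $V$ are reused from the cache and grow by one row per step:

$$K_{1:t} = \begin{bmatrix} K_{1:t-1} \\ k_t \end{bmatrix}, \qquad V_{1:t} = \begin{bmatrix} V_{1:t-1} \\ v_t \end{bmatrix}, \qquad o_t = \mathrm{softmax}\left(\frac{q_t K_{1:t}^\top}{\sqrt{d_k}}\right)V_{1:t}$$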
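To make the reuse concrete, here is a minimal pure-Python sketch (one head, toy dimension 4, random weights; the helpers `attend` and `project` are hypothetical names). It decodes token by token with a growing K/V cache and checks that every step matches attention recomputed from scratch over the whole prefix.

```python
import math
import random


def attend(q, keys, values):
    """softmax(q . K^T / sqrt(d)) V for a single query vector"""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def project(W, x):
    """Matrix-vector product W @ x"""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


random.seed(0)
d = 4
Wq, Wk, Wv = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(6)]  # hidden states of 6 tokens

k_cache, v_cache = [], []
cached_outputs, recomputed_outputs = [], []
for t, x in enumerate(xs):
    # Cached decoding: project the new token once and append k_t, v_t
    k_cache.append(project(Wk, x))
    v_cache.append(project(Wv, x))
    cached_outputs.append(attend(project(Wq, x), k_cache, v_cache))
    # Naive decoding: recompute K and V for the whole prefix at every step
    K = [project(Wk, xi) for xi in xs[: t + 1]]
    V = [project(Wv, xi) for xi in xs[: t + 1]]
    recomputed_outputs.append(attend(project(Wq, x), K, V))

max_diff = max(abs(a - b)
               for o1, o2 in zip(cached_outputs, recomputed_outputs)
               for a, b in zip(o1, o2))
print(f"max difference: {max_diff:.2e}")
```

The cached loop performs one K/V projection per step, while the naive loop performs $t$ of them at step $t$, i.e. $O(n^2)$ projections over $n$ generated tokens for identical outputs.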

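1.2 Memory Footprint

The memory footprint of the KV Cache is the central bottleneck of inference optimization.

KV memory per layer per token

$$M_{\text{layer}} = 2 \times d_{\text{model}} \times b$$

The factor 2 counts K and V, and $b$ is the number of bytes per element. For LLaMA-7B ($d_{\text{model}} = 4096$, FP16 so $b = 2$):

$$M_{\text{layer}} = 2 \times 4096 \times 2 = 16{,}384\ \text{bytes} = 16\ \text{KiB}$$

Full KV Cache (all layers)

$$M_{\text{KV}} = 2 \times n_{\text{layers}} \times d_{\text{model}} \times b \times L \times B$$

where $L$ is the context length and $B$ the batch size. For LLaMA-7B (32 layers, $d_{\text{model}} = 4096$):

$$2 \times 32 \times 4096 \times 2 = 512\ \text{KiB per token}, \qquad 512\ \text{KiB} \times 4096 = 2\ \text{GiB for one 4K-token sequence}$$

Long-context scenarios: at a 128K context, a single LLaMA-7B request needs $512\ \text{KiB} \times 131{,}072 = 64\ \text{GiB}$ of KV Cache on top of roughly 13 GB of FP16 weights, which nearly fills an 80 GB GPU; larger models or batches quickly exceed single-GPU memory.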
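The arithmetic above generalizes to a small calculator. This sketch assumes the cache width is `num_kv_heads * head_dim`, which equals $d_{\text{model}}$ for standard multi-head attention and is smaller for grouped-query attention (GQA); the function name is hypothetical.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """KV Cache size: 2 (K and V) x layers x KV width x tokens x batch x bytes per element"""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


GiB = 1024 ** 3

# LLaMA-7B: 32 layers, 32 heads x 128 = 4096 hidden, FP16
print(kv_cache_bytes(32, 32, 128, seq_len=1) // 1024, "KiB per token")       # 512 KiB per token
print(kv_cache_bytes(32, 32, 128, seq_len=4096) / GiB, "GiB at 4K")         # 2.0 GiB at 4K
print(kv_cache_bytes(32, 32, 128, seq_len=131072) / GiB, "GiB at 128K")     # 64.0 GiB at 128K

# A GQA model with 8 KV heads (e.g. a Llama-3-8B-like shape) stores 4x less per token
print(kv_cache_bytes(32, 8, 128, seq_len=1) // 1024, "KiB per token (GQA)")  # 128 KiB per token (GQA)
```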

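1.3 PagedAttention: Memory Management in vLLM

PagedAttention[1], introduced by vLLM, borrows the virtual-memory idea from operating systems and manages the KV Cache in pages:

Core ideas

  • Split the KV Cache into fixed-size blocks (16 tokens per block by default)
  • Use a Block Table to map logical blocks to physical blocks
  • Allow non-contiguous physical storage, which raises memory utilization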
import torch
import torch.nn.functional as F


class BlockTable:
    """Block table: maps one sequence's logical blocks to physical blocks in a shared KV pool"""
    def __init__(self, num_physical_blocks=256, block_size=16, num_heads=32, head_dim=128):
        self.block_size = block_size
        # Shared physical pool; each physical block holds block_size tokens of K (and of V)
        shape = (num_physical_blocks, num_heads, block_size, head_dim)
        self.k_pool = torch.zeros(shape)
        self.v_pool = torch.zeros(shape)
        self.free_blocks = list(range(num_physical_blocks))  # free-block list
        self.block_mapping = []  # block_mapping[logical_idx] = physical_block_id
        self.num_tokens = 0

    def allocate(self, num_tokens):
        """Reserve physical blocks for num_tokens more tokens (a new block only at block boundaries)"""
        needed = (self.num_tokens + num_tokens + self.block_size - 1) // self.block_size
        while len(self.block_mapping) < needed:
            if not self.free_blocks:
                raise RuntimeError("KV cache pool exhausted")
            self.block_mapping.append(self.free_blocks.pop())
        self.num_tokens += num_tokens
        return list(self.block_mapping)

    def get_physical_block(self, token_idx):
        """Return (physical block id, offset within the block) for a token"""
        logical_block = token_idx // self.block_size
        offset = token_idx % self.block_size
        physical_block = self.block_mapping[logical_block]
        return physical_block, offset

    def append(self, k, v):
        """Store one new token's K and V, each of shape (num_heads, head_dim)"""
        self.allocate(1)
        block, offset = self.get_physical_block(self.num_tokens - 1)
        self.k_pool[block, :, offset] = k
        self.v_pool[block, :, offset] = v


class PagedAttention:
    """
    PagedAttention (reference version)
    Core: attention over K/V stored in non-contiguous physical blocks.
    Real kernels read the blocks in place instead of gathering them first.
    """
    def __init__(self, block_size=16, num_heads=32, head_dim=128):
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(self, query, block_table):
        """
        Args:
            query: (num_heads, head_dim), the current decode token of one sequence
            block_table: that sequence's BlockTable
        Returns:
            (num_heads, head_dim)
        """
        n = block_table.num_tokens
        blocks = torch.tensor(block_table.block_mapping, dtype=torch.long)

        # 1. Gather K and V from the physical blocks in logical order -> (heads, n, head_dim)
        k_full = block_table.k_pool[blocks].permute(1, 0, 2, 3).reshape(self.num_heads, -1, self.head_dim)[:, :n]
        v_full = block_table.v_pool[blocks].permute(1, 0, 2, 3).reshape(self.num_heads, -1, self.head_dim)[:, :n]

        # 2. Standard attention
        scores = torch.einsum('hd,hnd->hn', query, k_full) / (self.head_dim ** 0.5)
        probs = F.softmax(scores, dim=-1)
        return torch.einsum('hn,hnd->hd', probs, v_full)

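PagedAttention memory efficiency

| Strategy | Memory utilization | Concurrent-request gain |
|---|---|---|
| Contiguous allocation | ~20-30% | baseline |
| PagedAttention | ~60-80% | 2-4× |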

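1.4 Cache Compression

StreamingLLM

StreamingLLM[2] processes text of unbounded length without any fine-tuning. Its key observation is the attention sink phenomenon: models attend strongly to the initial tokens even when those tokens carry little semantic meaning.

Strategy: keep 4 "attention sink" tokens plus a local window of the most recent tokens: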

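import torch
import torch.nn.functional as F


def standard_attention(query, key, value):
    """Plain scaled dot-product attention on (batch, heads, seq, head_dim) tensors"""
    return F.scaled_dot_product_attention(query, key, value)


def streaming_llm_attention(query, key, value, sink_size=4, window_size=512):
    """
    StreamingLLM: fixed-size sliding-window attention
    Keeps: sink tokens + the most recent window tokens
    query: current decode step; key/value: (batch, heads, seq, head_dim)
    In practice evicted entries are dropped from the cache, and RoPE positions are
    assigned by slot within the cache; both are omitted here.
    """
    seq_len = key.shape[2]

    if seq_len <= sink_size + window_size:
        # Cache not full yet: attend to the whole sequence
        return standard_attention(query, key, value)

    # 1. Sink tokens
    k_sink = key[:, :, :sink_size, :]
    v_sink = value[:, :, :sink_size, :]

    # 2. Most recent window tokens
    k_window = key[:, :, -window_size:, :]
    v_window = value[:, :, -window_size:, :]

    # 3. Concatenate, then attend
    k_combined = torch.cat([k_sink, k_window], dim=2)
    v_combined = torch.cat([v_sink, v_window], dim=2)

    return standard_attention(query, k_combined, v_combined)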

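Results: StreamingLLM handles continuous inputs of up to 4 million tokens on models as large as 40B parameters, with perplexity remaining essentially stable.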
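The eviction policy itself needs no tensors. This pure-Python sketch (hypothetical function names, a tiny window of 8 for readability) lists which original tokens survive in the cache and the positions used for RoPE. StreamingLLM assigns positions by slot inside the cache rather than by index in the original text, so positions never exceed what the model saw in training.

```python
def streaming_cache_indices(seq_len, sink_size=4, window_size=8):
    """Original token indices kept after seq_len tokens: attention sinks + recent window"""
    if seq_len <= sink_size + window_size:
        return list(range(seq_len))
    sinks = list(range(sink_size))
    recent = list(range(seq_len - window_size, seq_len))
    return sinks + recent


def cache_positions(kept_indices):
    """Positions used for RoPE: the slot index inside the cache, not the original index"""
    return list(range(len(kept_indices)))


kept = streaming_cache_indices(seq_len=20)
print(kept)                   # [0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19]
print(cache_positions(kept))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

However long the stream grows, the cache holds `sink_size + window_size` entries, so memory and per-token attention cost stay constant.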

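Approximate caching strategies

H2O (Heavy-Hitter Oracle)[3]: identifies "heavy hitter" tokens from their accumulated attention scores and keeps the KV only for those tokens (together with the most recent tokens):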

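import torch


def h2o_cache_selection(kv_cache, importance_scores, budget):
    """
    H2O: select which KV entries to keep based on attention importance
    kv_cache: (batch, heads, seq_len, head_dim)
    importance_scores: (num_heads, seq_len), attention accumulated per position so far
    Keeps the top-`budget` positions; the full H2O policy also reserves part
    of the budget for the most recent tokens.
    """
    scores = importance_scores.sum(dim=0)  # aggregate over heads -> (seq_len,)

    # Select the top-k positions
    _, indices = torch.topk(scores, k=min(budget, scores.shape[0]))
    indices = torch.sort(indices).values  # restore original token order

    return kv_cache[:, :, indices, :], indices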

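1.5 Cross-Request KV Reuse

Requests that share a prompt prefix can reuse its KV Cache:

import torch


class PrefixCaching:
    """
    Prefix caching: reuse the KV Cache of an identical prompt prefix
    """
    def __init__(self, compute_kv):
        # compute_kv(new_tokens, k_prefix, v_prefix) -> (k_new, v_new) runs the model on the
        # new tokens while attending to the cached prefix; supplied by the caller
        self.compute_kv = compute_kv
        # Keyed by the token tuple itself, so a hash collision cannot return the wrong KV
        self.cache = {}  # tuple(prompt_tokens) -> (k_cache, v_cache)

    def get_cached_prefix(self, prompt_tokens):
        """Return the cached prefix KV, or None"""
        return self.cache.get(tuple(prompt_tokens))

    def cache_prefix(self, prompt_tokens, k_cache, v_cache):
        """Cache the prefix KV"""
        self.cache[tuple(prompt_tokens)] = (k_cache, v_cache)

    def extend_with_cache(self, prompt_tokens, new_tokens):
        """Extend from the cached prefix; KV tensors are (batch, heads, seq, head_dim)"""
        cached_kv = self.get_cached_prefix(prompt_tokens)
        if cached_kv is None:
            return None

        k_prefix, v_prefix = cached_kv

        # Compute KV only for the new tokens
        k_new, v_new = self.compute_kv(new_tokens, k_prefix, v_prefix)

        # Concatenate along the sequence dimension
        return (
            torch.cat([k_prefix, k_new], dim=2),
            torch.cat([v_prefix, v_new], dim=2)
        )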
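Hashing the whole prompt only helps when prompts are identical. A finer-grained variant, similar in spirit to the block-level prefix caching in vLLM, hashes each full KV block chained with the hash of everything before it, so two prompts share cached blocks up to their first differing block. A stdlib-only sketch; the block size of 4 and the helper names are assumptions for illustration.

```python
import hashlib


def block_hashes(tokens, block_size=4):
    """Chained hash per full block: h_i = H(h_{i-1}, tokens of block i)"""
    hashes, prev = [], ""
    for start in range(0, len(tokens) - block_size + 1, block_size):
        block = tokens[start:start + block_size]
        prev = hashlib.sha256((prev + "|" + ",".join(map(str, block))).encode()).hexdigest()
        hashes.append(prev)
    return hashes


def cached_prefix_tokens(tokens, cache, block_size=4):
    """Number of leading tokens whose KV blocks are already in `cache` (a set of block hashes)"""
    matched = 0
    for h in block_hashes(tokens, block_size):
        if h not in cache:
            break
        matched += block_size
    return matched


# First request (10 tokens) leaves 2 full blocks in the cache; the partial block is not cached
cache = set(block_hashes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
print(cached_prefix_tokens([1, 2, 3, 4, 5, 6, 7, 8, 42, 43, 44, 45], cache))  # 8
print(cached_prefix_tokens([1, 2, 3, 9, 5, 6, 7, 8], cache))                  # 0
```

Chaining matters: block `[5, 6, 7, 8]` in the last request hashes differently from the cached one because its preceding block differs, so KV computed under a different prefix is never reused.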

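2. Quantization

2.1 INT8/FP16 Quantization Basics

Quantization maps high-precision values to a low-precision representation. For a bit width $b$:

Symmetric quantization

$$s = \frac{\max|W|}{2^{b-1}-1}, \qquad W_q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{W}{s}\right), -2^{b-1}, 2^{b-1}-1\right), \qquad \hat{W} = s\,W_q$$

Asymmetric quantization

$$s = \frac{\max W - \min W}{2^{b}-1}, \qquad z = \mathrm{round}\left(-\frac{\min W}{s}\right), \qquad W_q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{W}{s}\right) + z, 0, 2^{b}-1\right), \qquad \hat{W} = s\,(W_q - z)$$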

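import torch


def symmetric_quantize(W, bits=8):
    """Symmetric quantization (bits <= 8 so the codes fit in int8)"""
    scale = W.abs().max() / (2**(bits-1) - 1)
    W_q = torch.round(W / scale)
    W_q = torch.clamp(W_q, -(2**(bits-1)), 2**(bits-1)-1)
    return W_q.to(torch.int8), scale


def asymmetric_quantize(W, bits=8):
    """Asymmetric quantization (bits <= 8 so the codes fit in uint8)"""
    scale = (W.max() - W.min()) / (2**bits - 1)
    zero_point = torch.round(-W.min() / scale)
    W_q = torch.round(W / scale) + zero_point
    W_q = torch.clamp(W_q, 0, 2**bits-1)
    return W_q.to(torch.uint8), scale, zero_point


def dequantize(W_q, scale, zero_point=0):
    """Recover approximate weights: W ≈ scale * (W_q - zero_point)"""
    return scale * (W_q.float() - zero_point)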
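A pure-Python check of the two formulas (mirroring the functions above without PyTorch; the `*_roundtrip` names are hypothetical): after quantizing and dequantizing, every value inside the representable range is off by at most half a quantization step, $s/2$.

```python
def symmetric_roundtrip(values, bits=8):
    """Quantize then dequantize with a symmetric scale; returns (dequantized, scale)"""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / qmax
    q = [max(-qmax - 1, min(qmax, round(v / scale))) for v in values]
    return [scale * qi for qi in q], scale


def asymmetric_roundtrip(values, bits=8):
    """Quantize then dequantize with scale and zero point; returns (dequantized, scale)"""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (2 ** bits - 1)
    zero_point = round(-lo / scale)
    q = [max(0, min(2 ** bits - 1, round(v / scale) + zero_point)) for v in values]
    return [scale * (qi - zero_point) for qi in q], scale


weights = [-0.8, -0.05, 0.0, 0.31, 0.9, 1.7]
for fn in (symmetric_roundtrip, asymmetric_roundtrip):
    deq, scale = fn(weights, bits=4)
    err = max(abs(a - b) for a, b in zip(weights, deq))
    print(fn.__name__, f"scale={scale:.4f}", f"max_err={err:.4f}")
```

For this skewed sample the asymmetric variant gets the smaller step (it spends all 16 levels on $[-0.8, 1.7]$ instead of $[-1.7, 1.7]$), which is why asymmetric quantization is common for activations with non-zero-centered ranges.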

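2.2 GPTQ: Second-Order (Optimal Brain) Quantization

GPTQ[4] uses second-order information (the Hessian matrix) to guide quantization and keeps high accuracy even at 4 bits. It quantizes the weights one column at a time and spreads each column's rounding error onto the columns not yet quantized, weighted by the inverse of the layer-wise Hessian $H = 2X^\top X$ (with $X$ the calibration inputs of the layer, one row per sample).

Core algorithm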

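import torch


class GPTQQuantizer:
    """
    GPTQ (simplified): quantizer in the OBC framework
    Column-by-column quantization with error compensation through H^-1;
    omits the lazy batched updates of the reference implementation.
    """
    def __init__(self, bits=4, group_size=128, damp=0.01):
        self.bits = bits
        self.group_size = group_size
        self.damp = damp

    @torch.no_grad()
    def quantize_layer(self, layer, X):
        """
        layer: nn.Linear with weight (out_features, in_features)
        X: calibration inputs to this layer, (num_samples, in_features)
        """
        W = layer.weight.data.clone().float()
        out_features, in_features = W.shape
        qmax = 2 ** (self.bits - 1) - 1

        # Hessian of the layer-wise reconstruction loss, dampened for numerical stability
        H = 2 * X.float().T @ X.float()
        H += self.damp * H.diagonal().mean() * torch.eye(in_features, device=H.device)
        # Upper Cholesky factor of H^-1, as in the GPTQ paper
        Hinv = torch.linalg.cholesky(torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True)

        # Grouped quantization: one scale per row per group of columns
        num_groups = (in_features + self.group_size - 1) // self.group_size
        W_quant = torch.zeros_like(W, dtype=torch.int32)
        scales = torch.zeros(out_features, num_groups, device=W.device)

        for i in range(in_features):
            g = i // self.group_size
            if i % self.group_size == 0:
                # Symmetric scale from the already error-compensated weights of this group
                group = W[:, i:i + self.group_size]
                scales[:, g] = group.abs().amax(dim=1).clamp(min=1e-8) / qmax
            s = scales[:, g]
            q = torch.clamp(torch.round(W[:, i] / s), -qmax - 1, qmax)
            W_quant[:, i] = q.to(torch.int32)

            # Push this column's quantization error onto the remaining columns
            err = (W[:, i] - q * s) / Hinv[i, i]
            W[:, i:] -= err.unsqueeze(1) * Hinv[i, i:].unsqueeze(0)

        return W_quant, scales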

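2.3 AWQ: Activation-aware Weight Quantization

AWQ[5] observes that weights differ in their sensitivity to quantization: the small fraction (about 1%) of weight channels that multiply large-magnitude activations matters most. Scaling those input channels up before quantization, and folding the inverse scale into the activations, protects them and reduces the error: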

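import torch


def quantize_per_group(W, bits=4, group_size=128):
    """Symmetric per-group fake quantization (quantize, then dequantize); in_features % group_size == 0"""
    out_f, in_f = W.shape
    Wg = W.reshape(out_f, -1, group_size)
    qmax = 2 ** (bits - 1) - 1
    scale = Wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    Wq = torch.clamp(torch.round(Wg / scale), -qmax - 1, qmax) * scale
    return Wq.reshape(out_f, in_f)


@torch.no_grad()
def awq_search_scale(W, X, bits=4, group_size=128, n_grid=20):
    """
    AWQ: grid-search the per-input-channel scale s = s_X ** alpha

    W: (out_features, in_features); X: calibration activations (num_tokens, in_features)
    Core idea: salient channels (large mean |X|) are scaled up, so their relative
    rounding error shrinks; alpha in [0, 1) is chosen to minimize output error.
    """
    x_mean = X.abs().mean(dim=0)  # activation magnitude per input channel
    ref = X @ W.T                 # full-precision output
    best_err, best_s = float("inf"), torch.ones_like(x_mean)

    for k in range(n_grid):
        alpha = k / n_grid
        s = x_mean.pow(alpha).clamp(min=1e-4)
        s = s / (s.max() * s.min()).sqrt()  # keep the scale range centered
        # W * s is quantized; the 1/s factor is equivalently folded into X
        W_q = quantize_per_group(W * s, bits, group_size) / s
        err = (X @ W_q.T - ref).pow(2).mean().item()
        if err < best_err:
            best_err, best_s = err, s

    return best_s


def awq_quantize(W, X, bits=4, group_size=128):
    """AWQ quantization: returns fake-quantized weights in the original scale and the chosen s"""
    s = awq_search_scale(W, X, bits, group_size)
    W_q = quantize_per_group(W * s, bits, group_size) / s
    return W_q, s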

2.4 GGUF/llama.cpp Format

GGUF is the model file format of llama.cpp; it stores weights in many quantization types (from Q2_K up to Q8_0):

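import torch

# GGUF quantization types (nominal bits; "scale" = how per-block scales are stored)
GGUF_TYPES = {
    "Q8_0": {"bits": 8, "block_size": 32, "scale": "fp16"},
    "Q6_K": {"bits": 6, "block_size": 256, "scale": "quantized"},    # K-quant super-block
    "Q5_K_M": {"bits": 5, "block_size": 256, "scale": "quantized"},
    "Q4_K_M": {"bits": 4, "block_size": 256, "scale": "quantized"},
    "Q4_0": {"bits": 4, "block_size": 32, "scale": "fp16"},
    "Q3_K_M": {"bits": 3, "block_size": 256, "scale": "quantized"},
    "Q2_K": {"bits": 2, "block_size": 256, "scale": "quantized"},
}
# *_M types are mixes: most tensors use the named type, some sensitive ones a higher type


class GGUFQuantizer:
    """
    Simplified GGUF-style quantizer: one symmetric scale per block of `block_size`
    consecutive weights within each row (the Q8_0/Q4_0 scheme). K-quants also
    quantize their sub-block scales, which is not modeled here.
    """
    def quantize(self, W, quant_type="Q4_K_M"):
        config = GGUF_TYPES[quant_type]
        block_size = config["block_size"]
        bits = config["bits"]

        out_features, in_features = W.shape
        assert in_features % block_size == 0, "row length must be a multiple of block_size"

        # (rows, blocks per row, block_size): a block is block_size consecutive weights of a row
        blocks = W.reshape(out_features, in_features // block_size, block_size)

        # Block scales, then quantize every block
        qmax = 2 ** (bits - 1) - 1
        scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
        quantized_blocks = torch.clamp(torch.round(blocks / scales), -qmax - 1, qmax).to(torch.int8)

        return quantized_blocks, scales.squeeze(-1)

    def dequantize(self, quantized_blocks, scales):
        """Back to (out_features, in_features)"""
        W = quantized_blocks.float() * scales.unsqueeze(-1)
        return W.reshape(quantized_blocks.shape[0], -1)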
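The nominal bit widths hide the per-block scale overhead. This sketch estimates effective bits per weight and weight storage, assuming the llama.cpp block layouts Q8_0 = one FP16 scale + 32 int8 values, Q4_0 = one FP16 scale + 32 packed 4-bit values, and 256-weight super-blocks of 144 bytes (Q4_K) and 210 bytes (Q6_K). Tensors kept at higher precision and file metadata are ignored, so real files are somewhat larger.

```python
# (bytes per block, weights per block), assumed llama.cpp layouts
BLOCK_LAYOUTS = {
    "Q8_0": (2 + 32, 32),              # fp16 scale + 32 x int8
    "Q4_0": (2 + 16, 32),              # fp16 scale + 32 x 4-bit
    "Q4_K": (2 + 2 + 12 + 128, 256),   # d, dmin, packed sub-block scales/mins, 256 x 4-bit
    "Q6_K": (128 + 64 + 16 + 2, 256),  # low 4 bits, high 2 bits, int8 sub-block scales, d
}


def bits_per_weight(quant_type):
    block_bytes, block_weights = BLOCK_LAYOUTS[quant_type]
    return block_bytes * 8 / block_weights


def model_size_gb(num_params, quant_type):
    """Approximate weight storage in GB (10^9 bytes)"""
    return num_params * bits_per_weight(quant_type) / 8 / 1e9


for t in BLOCK_LAYOUTS:
    print(f"{t}: {bits_per_weight(t)} bits/weight, 7B model ≈ {model_size_gb(7e9, t):.2f} GB")
```

So "4-bit" Q4_0 really costs 4.5 bits per weight, because every 32 weights carry a 16-bit scale.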

2.5 FP8 Quantization

FP8 (8-bit floating point) is a data type natively supported by the Tensor Cores of NVIDIA's Hopper architecture (H100/H200):

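import torch
import torch.nn.functional as F

# Largest finite values of the two FP8 formats
E4M3_MAX = 448.0
E5M2_MAX = 57344.0


class FP8Quantizer:
    """
    FP8 quantization: native on H100/H200 (the dtypes need PyTorch >= 2.1)
    E4M3: sign(1) + exponent(4) + mantissa(3) -> more precision, range ±448
    E5M2: sign(1) + exponent(5) + mantissa(2) -> wider dynamic range, range ±57344
    This version fake-quantizes: it round-trips through FP8 and computes in bf16.
    """
    def __init__(self, weight):
        self.weight = weight

    @staticmethod
    def quantize_e4m3(W):
        """FP8 E4M3 with a per-tensor scale so that max|W| maps to 448"""
        scale = W.abs().max().clamp(min=1e-12) / E4M3_MAX
        return (W / scale).to(torch.float8_e4m3fn), scale

    @staticmethod
    def quantize_e5m2(W):
        """FP8 E5M2 with a per-tensor scale; mainly used for gradients in training"""
        scale = W.abs().max().clamp(min=1e-12) / E5M2_MAX
        return (W / scale).to(torch.float8_e5m2), scale

    def forward(self, x):
        # Inference typically uses E4M3 for both weights and activations
        w_fp8, w_scale = self.quantize_e4m3(self.weight)
        x_fp8, x_scale = self.quantize_e4m3(x)
        # Dequantize and compute in bf16 (real kernels run the FP8 GEMM directly)
        w = w_fp8.to(torch.bfloat16) * w_scale
        x = x_fp8.to(torch.bfloat16) * x_scale
        return F.linear(x, w)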

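FP8 vs INT8

| Property | FP8 (E4M3) | INT8 |
|---|---|---|
| Dynamic range | Wide (exponent bits; from ±448 down to 2^-9) | Narrow (256 uniform levels) |
| Precision | 3 mantissa bits; a per-tensor scale is usually enough | Uniform steps; depends on calibration |
| Hardware support | Native on H100/H200 | Broadly supported |
| Typical use | Weights and activations | General purpose |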
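The ranges in the table follow from the bit layouts. This pure-Python sketch decodes all 256 bit patterns of each format, assuming the common conventions: E4M3 ("fn" variant, bias 7) has no infinities and reserves only S.1111.111 for NaN, while E5M2 (bias 15) follows IEEE-style Inf/NaN rules. The function names are hypothetical.

```python
def decode_fp8(byte, exp_bits, man_bits, bias, ieee_specials):
    """Decode an 8-bit float; returns None for NaN/Inf encodings"""
    sign = -1.0 if byte >> 7 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    max_exp = (1 << exp_bits) - 1
    if ieee_specials and exp == max_exp:
        return None  # E5M2: all-ones exponent encodes Inf/NaN
    if not ieee_specials and exp == max_exp and man == (1 << man_bits) - 1:
        return None  # E4M3fn: only the all-ones pattern is NaN
    if exp == 0:  # subnormal numbers
        return sign * 2.0 ** (1 - bias) * (man / (1 << man_bits))
    return sign * 2.0 ** (exp - bias) * (1 + man / (1 << man_bits))


def finite_values(**fmt):
    return [v for b in range(256) if (v := decode_fp8(b, **fmt)) is not None]


e4m3 = finite_values(exp_bits=4, man_bits=3, bias=7, ieee_specials=False)
e5m2 = finite_values(exp_bits=5, man_bits=2, bias=15, ieee_specials=True)
for name, vals in (("E4M3", e4m3), ("E5M2", e5m2)):
    smallest = min(v for v in vals if v > 0)
    print(name, "max:", max(vals), "smallest positive:", smallest, "ratio:", max(vals) / smallest)
```

E4M3 spans a max-to-min ratio of about $2.3 \times 10^5$, versus 127 for INT8's nonzero magnitudes, which is why FP8 tolerates outliers with just a per-tensor scale.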

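3. Speculative Decoding

3.1 Basic Principle

Speculative decoding[6] uses a small Draft Model to propose candidate tokens quickly; the large Target Model then verifies all of them in a single parallel forward pass:

Context:       [The, cat, sat, on, the, mat, ...]
                    ↓ Draft (fast proposal)
Draft guess:   [The, cat, sat, on, the, mat, and, purr, ...]
                    ↓ Target (parallel verification)
Accepted:      [The, cat, sat, on, the, mat, and] ✓
Rejected:      [purr, ...] ✗ → discarded; the Target supplies the next token instead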

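Theoretical speedup

Let the per-token acceptance rate be $\alpha$, let the Draft take $t_d$ to generate $k$ tokens, and let the Target take $t_t$ for one verification pass (about the cost of one ordinary decoding step). Assuming acceptances are independent, the expected number of tokens produced per round and the speedup over plain autoregressive decoding are:

$$\mathbb{E}[N] = \frac{1 - \alpha^{k+1}}{1 - \alpha}, \qquad S = \frac{\mathbb{E}[N]\, t_t}{t_d + t_t}$$

As $\alpha \to 1$ and $t_d \ll t_t$, the speedup approaches $k + 1$.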
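The closed form for $\mathbb{E}[N]$ can be checked with a quick Monte Carlo simulation. This sketch assumes, like the formula, that each draft token is accepted independently with probability $\alpha$; the function names are hypothetical.

```python
import random


def expected_tokens(alpha, k):
    """Closed form: E[N] = (1 - alpha^(k+1)) / (1 - alpha)"""
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def simulate_round(alpha, k, rng):
    """One round: accept draft tokens until the first rejection, then the Target adds one token"""
    accepted = 0
    while accepted < k and rng.random() < alpha:
        accepted += 1
    return accepted + 1  # correction token on rejection, bonus token if all k were accepted


rng = random.Random(0)
alpha, k, rounds = 0.8, 4, 200_000
empirical = sum(simulate_round(alpha, k, rng) for _ in range(rounds)) / rounds
print(f"closed form {expected_tokens(alpha, k):.3f}, simulated {empirical:.3f}")
```

The "+1" is what makes speculative decoding never worse than one token per Target pass: even when the first draft token is rejected, the verification pass still yields the Target's own token.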

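3.2 Draft Model Design Principles

from dataclasses import dataclass, replace


@dataclass
class ModelConfig:
    """Minimal transformer config (hypothetical fields, for illustration)"""
    num_layers: int
    hidden_dim: int
    num_heads: int
    vocab_size: int


class DraftModelConfig:
    """
    Draft Model design principles
    """
    # 1. Parameter scale: 1/10 to 1/20 of the Target
    DRAFT_SCALE_RATIO = 1 / 16  # 7B Target -> ~0.4B Draft

    # 2. Share embedding + LM head
    # Reuse the Target's embedding and output projection

    # 3. Same vocabulary (required for token-level verification), simplified structure
    @staticmethod
    def create_draft_config(target: ModelConfig) -> ModelConfig:
        # Block parameters scale as num_layers * hidden_dim^2, so 1/4 of the layers
        # at 1/2 the width gives 1/4 * 1/4 = 1/16 of the Target's block parameters
        return replace(
            target,
            num_layers=target.num_layers // 4,   # fewer layers
            hidden_dim=target.hidden_dim // 2,   # smaller hidden dimension
            num_heads=target.num_heads // 2,     # fewer heads, same head_dim
        )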
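The 1/16 ratio can be checked with the usual estimate of about $12 L d^2$ parameters per transformer stack ($4d^2$ for the attention projections plus $8d^2$ for an MLP with 4× expansion). This is an approximation: it ignores the embeddings, which the Draft shares anyway, and LLaMA's SwiGLU sizing, which gives a similar count.

```python
def transformer_block_params(num_layers, hidden_dim):
    """~12 * L * d^2: 4d^2 attention projections + 8d^2 MLP with 4x expansion"""
    return 12 * num_layers * hidden_dim ** 2


target = transformer_block_params(32, 4096)           # LLaMA-7B-like shape
draft = transformer_block_params(32 // 4, 4096 // 2)  # 1/4 layers, 1/2 width
print(f"target ≈ {target / 1e9:.2f}B, draft ≈ {draft / 1e9:.2f}B, ratio = {draft / target}")
# target ≈ 6.44B, draft ≈ 0.40B, ratio = 0.0625
```

This is consistent with the "7B Target → 0.4B Draft" pairing in `DRAFT_SCALE_RATIO`.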

3.3 Acceptance Rate and Speedup Analysis

def analyze_speculative_decoding(draft_accept_rates, t_d, t_t, k):
    """
    Analyze speculative decoding performance (independent-acceptance assumption)

    Args:
        draft_accept_rates: list of per-token acceptance rates to evaluate
        t_d: time for the Draft to generate k tokens
        t_t: time for one Target verification pass over the k tokens
        k: number of tokens the Draft proposes per round
    """
    results = []

    for alpha in draft_accept_rates:
        # Expected tokens per round: accepted drafts + 1 token from the Target
        E_tokens = (1 - alpha ** (k + 1)) / (1 - alpha) if alpha < 1 else k + 1

        # Each round costs t_d + t_t; autoregressive decoding costs t_t per token
        speedup = E_tokens * t_t / (t_d + t_t)

        results.append({
            "accept_rate": alpha,
            "expected_tokens": E_tokens,
            "speedup": speedup
        })

    return results


# Example analysis
# Assume: t_d = 10 ms (0.4B Draft, all k tokens), t_t = 100 ms (7B Target)
results = analyze_speculative_decoding(
    draft_accept_rates=[0.5, 0.7, 0.8, 0.9, 0.95],
    t_d=10, t_t=100, k=4
)
# Output:
# alpha=0.7: speedup≈2.52x
# alpha=0.9: speedup≈3.72x
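3.4 EAGLE

EAGLE[7] (Extrapolation Algorithm for Greater Language-model Efficiency) drafts autoregressively at the feature level: a lightweight head predicts the Target's next feature (the hidden state right before the LM head) from the current feature plus the embedding of the token one step ahead, and reuses the Target's LM head to turn features into tokens. The Target then verifies the drafts with standard speculative decoding, so the output distribution is unchanged:

import torch
import torch.nn as nn


class EAGLEDraft(nn.Module):
    """
    EAGLE (simplified): feature-level autoregressive draft head + greedy verification
    Assumes an HF-style Target whose output has .logits; its embedding and
    LM head modules are passed in and shared.
    """
    def __init__(self, target_model, embed_tokens, lm_head, hidden_size, num_heads=8):
        super().__init__()
        self.target = target_model
        self.embed = embed_tokens   # shared with the Target
        self.lm_head = lm_head      # shared with the Target
        # Fuse (feature, next-token embedding), then one decoder layer
        self.fc = nn.Linear(2 * hidden_size, hidden_size)
        self.layer = nn.TransformerEncoderLayer(hidden_size, num_heads, batch_first=True)

    def draft_forward(self, features, shifted_tokens):
        """
        features: (batch, n, hidden) Target features before the LM head
        shifted_tokens: (batch, n) the token that follows each feature position
        Returns next-token logits and the predicted features
        """
        x = self.fc(torch.cat([features, self.embed(shifted_tokens)], dim=-1))
        n = x.shape[1]
        causal = nn.Transformer.generate_square_subsequent_mask(n).to(x.device)
        h = self.layer(x, src_mask=causal)
        return self.lm_head(h[:, -1]), h

    @torch.no_grad()
    def verify_and_advance(self, input_ids, draft_tokens):
        """
        Verify k draft tokens with one Target forward pass (greedy acceptance)
        input_ids: (1, n) current context; draft_tokens: (1, k)
        Returns the accepted draft tokens plus one token chosen by the Target
        """
        k = draft_tokens.shape[1]
        logits = self.target(input_ids=torch.cat([input_ids, draft_tokens], dim=1)).logits
        # Target's greedy choice at each draft position, plus one extra position
        target_tokens = logits[:, -k - 1:, :].argmax(dim=-1)  # (1, k + 1)

        # Accept the longest prefix on which Draft and Target agree
        matches = (target_tokens[:, :k] == draft_tokens)[0].long()
        n_accept = int(matches.cumprod(dim=0).sum())

        # The Target's token at the first mismatch (or after all k) is always kept
        return torch.cat([draft_tokens[:, :n_accept], target_tokens[:, n_accept:n_accept + 1]], dim=1)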

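3.5 Medusa

Medusa[8] attaches several extra decoding heads to the base model's last hidden state: head $k$ predicts the token $k$ positions after the one predicted by the base LM head, and the candidates are verified by the base model in a single forward pass:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MedusaHead(nn.Module):
    """
    Medusa: multiple prediction heads
    Head k (0-based) predicts the token k + 1 positions after the base model's next token.
    Each head: residual SiLU block + projection to the vocabulary.
    """
    def __init__(self, hidden_size, vocab_size, num_heads):
        super().__init__()
        self.res = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_heads)])
        self.out = nn.ModuleList([nn.Linear(hidden_size, vocab_size) for _ in range(num_heads)])

    def forward(self, hidden_state, top_k=5):
        """
        Predict tokens for several future positions in parallel
        hidden_state: (batch, hidden) hidden state at the last position
        Returns: [(token_ids, probs), ...], one (batch, top_k) pair per head
        """
        predictions = []
        for res, out in zip(self.res, self.out):
            logits = out(hidden_state + F.silu(res(hidden_state)))
            probs = F.softmax(logits, dim=-1)
            top = torch.topk(probs, k=top_k, dim=-1)
            predictions.append((top.indices, top.values))
        return predictions


class MedusaDecoding:
    """Medusa decoding (greedy, one top-1 candidate chain instead of tree attention)"""
    def __init__(self, base_model, medusa_heads):
        self.base = base_model      # HF-style causal LM (assumption)
        self.medusa = medusa_heads

    @torch.no_grad()
    def generate_step(self, input_ids):
        """One decoding step; input_ids: (1, n). Returns the newly accepted tokens (1, m)."""
        # 1. Base model: hidden state and its own next-token prediction (always accepted)
        outputs = self.base(input_ids, output_hidden_states=True)
        hidden = outputs.hidden_states[-1][:, -1]
        next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)

        # 2. Medusa heads guess the following tokens in parallel
        guesses = torch.cat([ids for ids, _ in self.medusa(hidden, top_k=1)], dim=-1)

        # 3. Verify all guesses with a single base-model forward pass
        candidate = torch.cat([input_ids, next_token, guesses], dim=1)
        logits = self.base(candidate).logits
        n, k = input_ids.shape[1], guesses.shape[1]
        verify = logits[:, n:n + k].argmax(dim=-1)  # base predictions at the guessed positions

        # Accept the longest prefix of guesses that the base model agrees with
        n_accept = int((verify == guesses)[0].long().cumprod(dim=0).sum())
        return torch.cat([next_token, guesses[:, :n_accept]], dim=1)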

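4. Batching and Parallelism

4.1 Continuous Batching

Traditional static batching pads the sequences of a batch to the same length and keeps the batch fixed until its longest sequence finishes, leaving many idle slots ("bubbles"). Continuous Batching[9] schedules at the iteration level instead: finished sequences leave the batch and waiting requests join it after every decoding step: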

from collections import deque

EOS_TOKEN = 2  # assumption: the EOS id depends on the tokenizer


class ContinuousBatcher:
    """
    Continuous Batching: dynamic batch management
    `model` is any callable mapping a list of token lists to one next token per
    sequence (a simplification: real engines keep a KV cache and separate prefill from decode)
    """
    def __init__(self, max_batch_size, max_seq_len):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.running_seqs = []        # sequences currently generating
        self.pending_seqs = deque()   # sequences waiting to be scheduled
        self.finished_seqs = []       # completed sequences

    def add_request(self, request):
        """Add a new request (an object with id, input_tokens, max_new_tokens)"""
        self.pending_seqs.append({
            "id": request.id,
            "tokens": list(request.input_tokens),
            "max_new_tokens": request.max_new_tokens,
            "finished": False,
            "output_tokens": []
        })

    def step(self, model):
        """
        Run one forward step
        Key point: sequences join and leave the batch at every step
        """
        # 1. Fill free slots with waiting requests
        self._refill_batch()
        if not self.running_seqs:
            return []

        # 2. Run the model on every running sequence
        batch = [seq["tokens"] + seq["output_tokens"] for seq in self.running_seqs]
        next_tokens = model(batch)

        # 3. Process results; a finished sequence frees its slot immediately
        still_running, done = [], []
        for seq, next_token in zip(self.running_seqs, next_tokens):
            seq["output_tokens"].append(next_token)
            total_len = len(seq["tokens"]) + len(seq["output_tokens"])

            # Check whether the sequence is complete
            if (len(seq["output_tokens"]) >= seq["max_new_tokens"]
                    or next_token == EOS_TOKEN
                    or total_len >= self.max_seq_len):
                seq["finished"] = True
                done.append(seq)
            else:
                still_running.append(seq)

        self.running_seqs = still_running
        self.finished_seqs.extend(done)
        return done

    def _refill_batch(self):
        """Move pending requests into the running batch, up to max_batch_size"""
        while self.pending_seqs and len(self.running_seqs) < self.max_batch_size:
            self.running_seqs.append(self.pending_seqs.popleft())
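The benefit shows up even in a toy simulation that counts decode steps only (prefill and per-step cost differences are ignored, and the output lengths are illustrative). Static batching waits for the longest member of each batch; continuous batching refills a slot as soon as it frees up.

```python
def static_batching_steps(output_lens, batch_size):
    """Each fixed batch runs until its longest sequence finishes"""
    steps = 0
    for i in range(0, len(output_lens), batch_size):
        steps += max(output_lens[i:i + batch_size])
    return steps


def continuous_batching_steps(output_lens, batch_size):
    """A finished sequence's slot is refilled from the queue at the next step"""
    queue = list(output_lens)
    running, steps = [], 0
    while queue or running:
        while queue and len(running) < batch_size:
            running.append(queue.pop(0))
        running = [r - 1 for r in running]    # one decode step for every running sequence
        running = [r for r in running if r > 0]
        steps += 1
    return steps


lens = [5, 100, 7, 90, 6, 80, 8, 70]  # output lengths in tokens (illustrative)
print(static_batching_steps(lens, 2), continuous_batching_steps(lens, 2))  # 340 184
```

With two slots the total work is 366 token-steps, so 183 steps is the lower bound; continuous batching gets within one step of it, while static batching spends most slots idle behind the long sequences.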

4.2 Prefix Caching

KV Cache reuse when requests share a prompt prefix:

import hashlib
from collections import OrderedDict


class PrefixCache:
    """
    Prefix cache: requests with the same prefix share KV
    """
    def __init__(self, cache_size=1000):
        self.cache_size = cache_size
        # prompt_hash -> KV cache, ordered from least to most recently used
        self.cache = OrderedDict()

    def compute_hash(self, tokens):
        """Hash of a sequence of token ids"""
        return hashlib.sha256(",".join(map(str, tokens)).encode()).hexdigest()

    def lookup(self, prompt_tokens):
        """Look up the cached prefix KV"""
        h = self.compute_hash(prompt_tokens)
        if h in self.cache:
            self.cache.move_to_end(h)  # mark as most recently used
            return self.cache[h]
        return None

    def store(self, prompt_tokens, kv_cache):
        """Store prefix KV"""
        h = self.compute_hash(prompt_tokens)

        # LRU eviction
        if h not in self.cache and len(self.cache) >= self.cache_size:
            self.cache.popitem(last=False)

        self.cache[h] = kv_cache
        self.cache.move_to_end(h)

    def prefix_match(self, new_tokens, cached_tokens):
        """
        Check how far new_tokens start with the cached prefix
        Returns the matching length
        """
        min_len = min(len(new_tokens), len(cached_tokens))
        match_len = 0

        for i in range(min_len):
            if new_tokens[i] == cached_tokens[i]:
                match_len += 1
            else:
                break

        return match_len

4.3 Sequence Parallelism

Ring Attention

Ring Attention[10] shards the sequence dimension across devices that compute attention cooperatively: each device keeps its Query shard while the K/V shards travel around a ring, and the partial results are merged with an online softmax so the final output equals full attention:

import torch


def ring_attention_simulated(q_shards, k_shards, v_shards):
    """
    Ring Attention, simulated in one process (non-causal)
    q_shards/k_shards/v_shards: lists with one (batch, heads, seq_len // P, head_dim)
    tensor per device. At step s, device r holds the K/V shard of device (r - s) mod P,
    as if every device had passed its K/V to its neighbor s times. A real implementation
    overlaps that point-to-point exchange (e.g. torch.distributed.isend/irecv) with compute.
    Returns the list of per-device outputs.
    """
    P = len(q_shards)
    scale = q_shards[0].shape[-1] ** -0.5
    outputs = []

    for rank in range(P):
        q = q_shards[rank] * scale
        m = l = acc = None  # running row max, softmax denominator, unnormalized output

        for step in range(P):
            src = (rank - step) % P  # K/V shard that has arrived through the ring
            s = q @ k_shards[src].transpose(-2, -1)
            block_max = s.amax(dim=-1)
            m_new = block_max if m is None else torch.maximum(m, block_max)
            p = torch.exp(s - m_new.unsqueeze(-1))

            if m is None:
                l, acc = p.sum(dim=-1), p @ v_shards[src]
            else:
                # Rescale earlier partial results to the new running max
                corr = torch.exp(m - m_new)
                l = corr * l + p.sum(dim=-1)
                acc = corr.unsqueeze(-1) * acc + p @ v_shards[src]
            m = m_new

        outputs.append(acc / l.unsqueeze(-1))

    return outputs
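The correctness of Ring Attention rests on the merge rule for partial softmax results. This pure-Python sketch (single query, six keys split into three blocks as if they arrived from three devices; helper names hypothetical) shows that combining per-block (max, sum, output) triples with log-sum-exp rescaling reproduces one softmax over all keys, regardless of arrival order.

```python
import math


def partial_attention(scores, values):
    """Attention over one K/V block: (block max, sum of exp, unnormalized output)"""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    out = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]
    return m, sum(weights), out


def merge(a, b):
    """Combine two partial results by rescaling both to the larger max"""
    (m1, l1, o1), (m2, l2, o2) = a, b
    m = max(m1, m2)
    c1, c2 = math.exp(m1 - m), math.exp(m2 - m)
    return m, c1 * l1 + c2 * l2, [c1 * x + c2 * y for x, y in zip(o1, o2)]


scores = [0.3, 2.1, -1.0, 0.7, 1.5, -0.2]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 1.0], [-1.0, 0.5], [0.5, -0.5]]

# Three blocks of two keys, merged in ring order
parts = [partial_attention(scores[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]
m, l, o = parts[0]
for p in parts[1:]:
    m, l, o = merge((m, l, o), p)
ring_result = [x / l for x in o]

# Reference: one softmax over all keys
_, l_full, o_full = partial_attention(scores, values)
full_result = [x / l_full for x in o_full]
print(ring_result)
print(full_result)
```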

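Ulysses Attention

Ulysses Attention (DeepSpeed-Ulysses) uses All-to-All communication to switch from sequence sharding to head sharding: each device then computes full-sequence attention for a subset of the heads, and a second All-to-All restores sequence sharding. It suits high-bandwidth interconnects, and its parallel degree cannot exceed the number of attention heads:

import torch
import torch.nn.functional as F


def all_to_all_seq_to_head(shards):
    """Simulated All-to-All: (b, H, N/P, d) per device -> (b, H/P, N, d) per device"""
    P = len(shards)
    return list(torch.chunk(torch.cat(shards, dim=2), P, dim=1))


def all_to_all_head_to_seq(shards):
    """Inverse All-to-All: (b, H/P, N, d) per device -> (b, H, N/P, d) per device"""
    P = len(shards)
    return list(torch.chunk(torch.cat(shards, dim=1), P, dim=2))


def ulysses_attention_simulated(q_shards, k_shards, v_shards):
    """
    Ulysses Attention, simulated in one process
    Inputs: per-device (batch, num_heads, seq_len // P, head_dim) shards; num_heads % P == 0.
    In a real deployment each All-to-All is torch.distributed.all_to_all.
    """
    # 1. All-to-All: gather the full sequence, scatter the heads
    q_h = all_to_all_seq_to_head(q_shards)
    k_h = all_to_all_seq_to_head(k_shards)
    v_h = all_to_all_seq_to_head(v_shards)

    # 2. Each device: full-sequence attention for its own heads
    attn = [F.scaled_dot_product_attention(q, k, v) for q, k, v in zip(q_h, k_h, v_h)]

    # 3. All-to-All: scatter the results back to sequence shards
    return all_to_all_head_to_seq(attn)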

4.4 Tensor-Parallel Inference

import torch
import torch.nn.functional as F


class TensorParallelMLP:
    """
    Tensor-parallel inference (Megatron-style), simulated in one process:
    the first Linear is split along its output dimension (column parallel) and the
    second along its input dimension (row parallel), so the shards run independently
    and a single All-Reduce combines their outputs.
    """
    def __init__(self, w1, w2, num_devices):
        # nn.Linear layout (y = x @ W^T): w1 is (ffn_dim, hidden), w2 is (hidden, ffn_dim)
        self.num_devices = num_devices
        self.w1_shards = torch.chunk(w1, num_devices, dim=0)  # column parallel
        self.w2_shards = torch.chunk(w2, num_devices, dim=1)  # row parallel

    def forward(self, x):
        """Parallel forward; each shard would run on its own GPU"""
        # 1. Every device receives the full input x
        partial_outputs = []
        for w1_i, w2_i in zip(self.w1_shards, self.w2_shards):
            # 2. Local computation: y_i = gelu(x @ W1_i^T) @ W2_i^T, no communication needed
            h_i = F.gelu(x @ w1_i.T)
            partial_outputs.append(h_i @ w2_i.T)

        # 3. All-Reduce: sum the partial outputs (torch.distributed.all_reduce in practice)
        return torch.stack(partial_outputs).sum(dim=0)

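5. Long-Context Inference

5.1 Position Encoding Interpolation

Extrapolating position encodings beyond the training length is a key challenge for long context. NTK-aware scaling[11] of RoPE is an effective training-free approach: instead of compressing all positions uniformly, it enlarges the RoPE base so that high frequencies stay almost unchanged while low frequencies are interpolated:

$$b' = b \cdot s^{\frac{d}{d-2}}, \qquad s = \frac{L_{\text{new}}}{L_{\text{orig}}}$$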

import torch


class NTKAwareScaling:
    """
    NTK-aware Scaling: position encoding extrapolation without fine-tuning
    """
    @staticmethod
    def compute_scaled_base(context_len, original_len, head_dim, base=10000.0):
        """
        NTK-scaled RoPE base: b' = b * s^(d / (d - 2)), s = context_len / original_len

        Core idea: adjust frequency components differently
        high-frequency components (short range) are barely scaled
        low-frequency components (long range) are scaled most (the lowest by exactly 1/s)
        """
        s = max(context_len / original_len, 1.0)
        return base * s ** (head_dim / (head_dim - 2))

    @staticmethod
    def apply_rope_with_scaling(q, k, position_ids, base):
        """Apply RoPE with a (possibly NTK-scaled) base; q, k: (..., seq_len, head_dim)"""
        d = q.shape[-1]
        inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))

        # Rotation angles: (seq_len, head_dim / 2)
        freqs = torch.outer(position_ids.float(), inv_freq)

        # Apply the rotation
        q_rot = NTKAwareScaling._rotate_half(q, freqs)
        k_rot = NTKAwareScaling._rotate_half(k, freqs)
        return q_rot, k_rot

    @staticmethod
    def _rotate_half(x, freqs):
        """Rotate each pair (x1[i], x2[i]) by angle freqs[i] ("rotate-half" RoPE layout)"""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        cos, sin = torch.cos(freqs), torch.sin(freqs)

        # Complex multiplication: (x1 + i*x2) * exp(i * theta)
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
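The claim that NTK-aware scaling leaves high frequencies alone while dividing the lowest one by exactly $s$ can be checked numerically. A pure-Python sketch with head dimension 128 and a 4× extension (helper names hypothetical):

```python
def rope_inv_freqs(head_dim, base):
    """theta_i = base^(-2i/d) for i = 0 .. d/2 - 1"""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def ntk_base(base, scale, head_dim):
    """b' = b * s^(d / (d - 2))"""
    return base * scale ** (head_dim / (head_dim - 2))


d, s = 128, 4.0
orig = rope_inv_freqs(d, 10000.0)
ntk = rope_inv_freqs(d, ntk_base(10000.0, s, d))
ratios = [a / b for a, b in zip(ntk, orig)]
print(ratios[0])   # 1.0: highest frequency unchanged
print(ratios[-1])  # ≈ 0.25: lowest frequency divided by s
```

In between, the ratio is $s^{-2i/(d-2)}$, a smooth transition from extrapolation at high frequencies to interpolation at low frequencies, which is the behavior described in the docstring above.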

5.2 Sparse Attention

Sliding Window Attention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SlidingWindowAttention(nn.Module):
    """
    Sliding Window Attention: causal local attention + optional global tokens
    Mistral uses the causal sliding window alone (window 4096); the global
    tokens are a Longformer-style extension
    """
    def __init__(self, window_size=4096):
        super().__init__()
        self.window_size = window_size

    def forward(self, q, k, v, is_global_token=None):
        """
        Args:
            q, k, v: (batch, heads, seq_len, head_dim)
            is_global_token: optional bool tensor (seq_len,) marking global tokens
        """
        seq_len = q.shape[2]
        i = torch.arange(seq_len, device=q.device)[:, None]  # query positions
        j = torch.arange(seq_len, device=q.device)[None, :]  # key positions

        # 1. Build the mask (True = may attend): causal and inside the window
        allowed = (j <= i) & (i - j < self.window_size)
        if is_global_token is not None:
            g = is_global_token.to(q.device)
            # Global tokens attend to, and are attended by, every position
            allowed = allowed | g[:, None] | g[None, :]

        # 2. Compute attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        scores = scores.masked_fill(~allowed, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, v)
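Two consequences of a causal sliding window, sketched in pure Python with hypothetical names. First, information can move up to `window` positions per layer, so stacking $L$ layers gives a theoretical span of about $L \times W$ tokens (the figure Mistral reports for its 32-layer, 4096-window configuration). Second, the KV cache never needs more than $W$ entries: token $i$ can overwrite slot $i \bmod W$ (a rolling buffer).

```python
def theoretical_span(num_layers, window):
    """Each layer can move information up to `window` positions, so about L * W in total"""
    return num_layers * window


class RollingKVBuffer:
    """Fixed-size cache: token i goes to slot i % window, overwriting token i - window"""
    def __init__(self, window):
        self.window = window
        self.slots = [None] * window

    def append(self, token_idx, kv):
        self.slots[token_idx % self.window] = (token_idx, kv)

    def cached_tokens(self):
        return sorted(t for t, _ in filter(None, self.slots))


print(theoretical_span(32, 4096))  # 131072
buf = RollingKVBuffer(window=4)
for i in range(10):
    buf.append(i, kv=f"kv{i}")
print(buf.cached_tokens())  # [6, 7, 8, 9]
```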

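HiPPO State Space

HiPPO (High-order Polynomial Projection Operators) compresses the history of a sequence into a fixed-size state by projecting it onto orthogonal polynomials (Legendre polynomials for the LegS variant); it is the foundation of state space models such as S4 for long sequences:

import torch
import torch.nn as nn


class HippoSSM(nn.Module):
    """
    HiPPO-LegS state space layer (simplified, single input channel, the
    time-invariant form used by S4). Continuous dynamics x'(t) = A x(t) + B u(t),
    discretized with the bilinear method; the N-dimensional state summarizes the
    whole history and is updated in linear time.
    """
    def __init__(self, state_size=64, dt=0.01):
        super().__init__()
        A, B = self.compute_legendre_matrix(state_size)
        # Bilinear discretization: Ad = (I - dt/2 A)^-1 (I + dt/2 A), Bd = (I - dt/2 A)^-1 dt B
        I = torch.eye(state_size)
        inv = torch.linalg.inv(I - dt / 2 * A)
        self.register_buffer("Ad", inv @ (I + dt / 2 * A))
        self.register_buffer("Bd", (inv @ (dt * B)).squeeze(-1))
        self.C = nn.Linear(state_size, 1, bias=False)  # learned readout

    @staticmethod
    def compute_legendre_matrix(N):
        """
        HiPPO-LegS matrices, projecting the history onto N Legendre coefficients:
        A[n, k] = -sqrt(2n+1) sqrt(2k+1) if n > k, -(n+1) if n == k, 0 if n < k
        B[n] = sqrt(2n+1)
        """
        n = torch.arange(N, dtype=torch.float32)
        r = torch.sqrt(2 * n + 1)
        A = torch.tril(-(r[:, None] * r[None, :]), diagonal=-1) - torch.diag(n + 1)
        return A, r[:, None]

    def forward(self, u):
        """u: (batch, seq_len) scalar input sequence -> (batch, seq_len) output"""
        batch, seq_len = u.shape
        x = torch.zeros(batch, self.Ad.shape[0], device=u.device)
        outputs = []

        for t in range(seq_len):
            # State update: x_t = Ad x_{t-1} + Bd u_t, then read out y_t = C x_t
            x = x @ self.Ad.T + u[:, t:t + 1] * self.Bd
            outputs.append(self.C(x).squeeze(-1))

        return torch.stack(outputs, dim=1)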

5.3 Memory-Efficient Attention Variants

import torch


class FlashAttention:
    """
    Flash Attention: IO-aware, efficient attention
    Core: tiled computation with an online softmax to reduce HBM traffic.
    This PyTorch reference shows the math only; the speed comes from a fused kernel.
    """
    @staticmethod
    def forward(q, k, v, block_size=128, causal=True):
        """
        Flash Attention forward pass; q, k, v: (B, H, N, D)
        Tiling avoids materializing the full N×N attention matrix
        """
        B, H, N, D = q.shape
        scale = D ** -0.5
        output = torch.empty(B, H, N, D, device=q.device, dtype=torch.float32)

        # Outer loop over Q tiles (the FlashAttention-2 ordering)
        for i in range(0, N, block_size):
            q_block = q[:, :, i:i + block_size, :].float() * scale
            bq = q_block.shape[2]
            q_idx = torch.arange(i, i + bq, device=q.device)
            m = torch.full((B, H, bq), float("-inf"), device=q.device)  # running row max
            l = torch.zeros(B, H, bq, device=q.device)                  # running row sum
            acc = torch.zeros(B, H, bq, D, device=q.device)             # unnormalized output

            # Inner loop over K/V tiles
            for j in range(0, N, block_size):
                if causal and j > i + bq - 1:
                    break  # remaining K tiles lie entirely in the future
                k_block = k[:, :, j:j + block_size, :].float()
                v_block = v[:, :, j:j + block_size, :].float()

                # S = Q @ K^T for this tile
                s = q_block @ k_block.transpose(-2, -1)

                # Causal mask: key position must not exceed query position
                if causal:
                    k_idx = torch.arange(j, j + k_block.shape[2], device=q.device)
                    s = s.masked_fill(k_idx[None, :] > q_idx[:, None], float("-inf"))

                # Online softmax: rescale earlier statistics to the new row max
                m_new = torch.maximum(m, s.amax(dim=-1))
                p = torch.exp(s - m_new.unsqueeze(-1))
                correction = torch.exp(m - m_new)
                l = correction * l + p.sum(dim=-1)
                acc = correction.unsqueeze(-1) * acc + p @ v_block
                m = m_new

            # Normalize once per Q tile
            output[:, :, i:i + bq, :] = acc / l.unsqueeze(-1)

        return output.to(q.dtype)

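5.4 Context-Length Extrapolation Techniques

import math
import torch


class ContextExtrapolation:
    """
    A collection of context-length extrapolation techniques
    """

    @staticmethod
    def YaRN_transform(factor, original_len, head_dim=128, base=10000.0, beta_fast=32, beta_slow=1):
        """
        YaRN: Yet another RoPE extensioN
        "NTK-by-parts" interpolation + attention temperature.
        Dimensions completing more than beta_fast rotations within original_len keep
        their frequency; those completing fewer than beta_slow are interpolated by
        1/factor; in between, the two are blended linearly.
        Returns the new inverse frequencies and mscale (scale q and k by it).
        """
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

        # Number of full rotations each dimension completes within the original context
        rotations = original_len * inv_freq / (2 * math.pi)
        ramp = ((rotations - beta_slow) / (beta_fast - beta_slow)).clamp(0, 1)  # 1 = keep, 0 = interpolate
        new_inv_freq = inv_freq / factor * (1 - ramp) + inv_freq * ramp

        # Attention temperature: logits are effectively scaled by mscale^2
        mscale = 0.1 * math.log(factor) + 1.0
        return new_inv_freq, mscale

    @staticmethod
    def longrope_interpolation(context_len, factor):
        """
        LongRoPE: progressive, non-uniform position interpolation
        Simplified here to uniform linear Position Interpolation; LongRoPE itself
        searches for per-dimension rescale factors
        """
        return 1.0 / factor  # position scale

    @staticmethod
    def self_extend_rel_pos(seq_len, group_size=4, neighbor_window=512):
        """
        Self-Extend: remap relative positions, no fine-tuning
        Nearby tokens (distance < neighbor_window) keep exact relative positions;
        distant tokens use grouped positions floor(pos / group_size), shifted so that
        both regimes meet at the window boundary. Returns a (seq_len, seq_len) matrix of
        relative positions; only the causal part (j <= i) is used. Defaults are illustrative.
        """
        i = torch.arange(seq_len)[:, None]
        j = torch.arange(seq_len)[None, :]
        rel = i - j

        # Coarse-grained: group-level relative positions for distant tokens
        grouped = i // group_size - j // group_size + (neighbor_window - neighbor_window // group_size)

        # Fine-grained: exact relative positions inside the neighbor window
        return torch.where(rel < neighbor_window, rel, grouped)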

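6. Practical Guide

6.1 Choosing an Inference Engine

| Engine | Strengths | Weaknesses | Best for |
|---|---|---|---|
| vLLM | PagedAttention, Continuous Batching, active development | Broad feature set, but some optimizations trail specialized engines | General inference serving, high concurrency |
| TensorRT-LLM | Deep H100 optimization, peak performance | NVIDIA-only, complex configuration | Production high-performance inference |
| llama.cpp | Runs on CPU and GPU, GGUF support, cross-platform | Slower than specialized engines | Edge deployment, inference research |
| SGLang | RadixAttention, strong prefix caching | Relatively new | Long context, high prefix reuse |
| Ollama | Easy to use, local deployment | Limited optimization | Local development and testing |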
# vLLM inference example
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=2,  # shard the model across 2 GPUs
    gpu_memory_utilization=0.9,
    max_model_len=8192,
    enable_prefix_caching=True  # reuse KV cache for shared prompt prefixes
)
 
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=512
)
 
outputs = llm.generate(["Hello, how are you?"], sampling_params)
print(outputs[0].outputs[0].text)
 
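# TensorRT-LLM example (high-level LLM API, which builds the engine from a
# Hugging Face checkpoint; requires the tensorrt_llm package and an NVIDIA GPU)
# from tensorrt_llm import LLM as TRTLLM
# trt_llm = TRTLLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
# outputs = trt_llm.generate(["Hello, how are you?"])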
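
The `enable_prefix_caching=True` flag lets requests that share a prompt prefix (a system prompt, few-shot examples) reuse already-computed KV blocks. The toy model below shows the idea under stated assumptions: KV is cached per full 16-token block, and each block is keyed by a hash of its tokens chained with the hash of its prefix, which is how vLLM describes its automatic prefix caching. `block_hashes` and `PrefixCache` are hypothetical names, not vLLM code, and eviction and partial blocks are ignored.

```python
from hashlib import sha256
from typing import Dict, List, Sequence, Tuple

BLOCK_SIZE = 16  # tokens per KV block, matching vLLM's default block size


def block_hashes(token_ids: Sequence[int], block_size: int = BLOCK_SIZE) -> List[str]:
    """Hash each *full* block together with the hash of everything before it."""
    hashes: List[str] = []
    parent = ""
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = token_ids[start:start + block_size]
        digest = sha256((parent + "|" + ",".join(map(str, block))).encode()).hexdigest()
        hashes.append(digest)
        parent = digest  # chaining: identical blocks after different prefixes differ
    return hashes


class PrefixCache:
    """Toy prefix cache: block hash -> physical block id (no eviction)."""

    def __init__(self) -> None:
        self.table: Dict[str, int] = {}
        self.next_block = 0

    def admit(self, token_ids: Sequence[int]) -> Tuple[int, int]:
        """Return (reused_blocks, newly_computed_blocks) for one request."""
        reused = 0
        hashes = block_hashes(token_ids)
        for digest in hashes:
            if digest in self.table:
                reused += 1  # KV for this block is already on the GPU
            else:
                self.table[digest] = self.next_block
                self.next_block += 1
        return reused, len(hashes) - reused


if __name__ == "__main__":
    cache = PrefixCache()
    system_prompt = list(range(32))  # 2 full blocks shared by both requests
    print(cache.admit(system_prompt + [100] * 16))  # (0, 3)
    print(cache.admit(system_prompt + [200] * 16))  # (2, 1)
```

The second request recomputes only its last block; the prefill cost of the shared system prompt is paid once.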

6.2 Deployment Best Practices

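# docker-compose.yml for vLLM
version: '3.8'
services:
  vllm:
    image: vllm/vllm-openai:latest
    ports:
      - "8000:8000"
    volumes:
      # Reuse the Hugging Face cache so weights are not re-downloaded
      - ~/.cache/huggingface:/root/.cache/huggingface
    environment:
      # Needed for gated models such as Llama 3
      - HUGGING_FACE_HUB_TOKEN=${HF_TOKEN}
    ipc: host  # shared memory for tensor-parallel workers
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 2
              capabilities: [gpu]
    # The image's entrypoint is the OpenAI-compatible server, configured via
    # CLI flags rather than environment variables. CUDA graphs are enabled by
    # default (pass --enforce-eager to disable them).
    command: >
      --model meta-llama/Meta-Llama-3-8B-Instruct
      --host 0.0.0.0 --port 8000
      --gpu-memory-utilization 0.9
      --max-model-len 8192
      --tensor-parallel-size 2
      --enable-prefix-caching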
 
# Alternatively, deploy on Kubernetes
# apiVersion: v1
# kind: ConfigMap
# metadata:
#   name: vllm-config
# data:
#   model: "meta-llama/Meta-Llama-3-70B-Instruct"
#   tensor-parallel-size: "4"
#   gpu-memory-utilization: "0.9"
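
Once the container is up, vLLM's server speaks the OpenAI-compatible HTTP API, so any HTTP client works. The standard-library sketch below assumes the service is reachable at `localhost:8000`. The `model` field must match whatever model the server was started with, and the helper names `build_chat_request` and `chat` are hypothetical.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000/v1"  # assumes the compose service above, on localhost
MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # must match the model the server loaded


def build_chat_request(prompt: str, max_tokens: int = 256,
                       temperature: float = 0.7) -> urllib.request.Request:
    """Build a POST request for the OpenAI-compatible /chat/completions endpoint."""
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        f"{BASE_URL}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt: str) -> str:
    """Send one chat request and return the assistant's reply text."""
    with urllib.request.urlopen(build_chat_request(prompt), timeout=120) as resp:
        body = json.loads(resp.read())
    return body["choices"][0]["message"]["content"]


if __name__ == "__main__":
    print(chat("Hello, how are you?"))
```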

6.3 Performance Benchmarking

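import time

import torch
from vllm import LLM, SamplingParams


def benchmark_throughput(model_path, num_requests=1000, concurrency=32):
    """Offline throughput benchmark.

    LLM.generate() has no concurrency argument; the number of sequences
    scheduled together is capped by max_num_seqs when the engine is built.
    """
    llm = LLM(model=model_path, tensor_parallel_size=2, max_num_seqs=concurrency)
    params = SamplingParams(max_tokens=256)

    prompts = ["Write a short story." for _ in range(num_requests)]

    # Warm-up (CUDA graph capture, kernel initialization)
    llm.generate(["warmup"], SamplingParams(max_tokens=8))

    start = time.perf_counter()
    outputs = llm.generate(prompts, params)
    elapsed = time.perf_counter() - start

    req_throughput = num_requests / elapsed
    generated = sum(len(o.outputs[0].token_ids) for o in outputs)
    tok_throughput = generated / elapsed

    print(f"Total time: {elapsed:.2f}s")
    print(f"Throughput: {req_throughput:.2f} req/s, {tok_throughput:.1f} output tok/s")
    # elapsed / num_requests is amortized time per request (1 / throughput),
    # not per-request latency: the requests run concurrently.
    print(f"Amortized time per request: {elapsed / num_requests * 1000:.2f}ms")

    return {"req_per_s": req_throughput, "tok_per_s": tok_throughput}


def benchmark_memory(model_path, max_model_len):
    """GPU memory usage after loading the model.

    vLLM may run the model in a separate worker process, so
    torch.cuda.memory_allocated() in this process can under-report.
    mem_get_info() reports device-wide free/total memory instead; note that
    vLLM pre-allocates up to gpu_memory_utilization for the KV cache.
    """
    llm = LLM(
        model=model_path,
        max_model_len=max_model_len,
        gpu_memory_utilization=0.95,
    )

    free, total = torch.cuda.mem_get_info()
    used_gb = (total - free) / 1e9
    total_gb = total / 1e9

    print(f"GPU memory used: {used_gb:.2f} GB / {total_gb:.2f} GB")

    return {"used": used_gb, "total": total_gb}


def benchmark_long_context(model_path, context_lengths):
    """Long-context latency benchmark.

    Builds one engine sized for the longest context rather than one engine per
    length: a second LLM instance in the same process usually does not fit in
    the GPU memory the first one has already claimed.
    """
    llm = LLM(model=model_path, max_model_len=max(context_lengths))
    # ignore_eos forces exactly 64 output tokens so runs are comparable
    params = SamplingParams(max_tokens=64, ignore_eos=True)
    results = []

    for ctx_len in context_lengths:
        # Roughly ctx_len // 2 input tokens, leaving room for the output
        prompt = "Hello " * (ctx_len // 2)

        start = time.perf_counter()
        llm.generate([prompt], params)
        latency = time.perf_counter() - start

        results.append({
            "context_len": ctx_len,
            "latency": latency,
            # End-to-end, so this includes prefill time for the long prompt
            "tokens_per_sec": 64 / latency,
        })

    return results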
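
An offline `llm.generate` call only yields the wall time of the whole batch, which gives throughput but not what an individual user experiences. If a streaming client records, for each request, when it was sent, when the first token arrived, and when the last token arrived, the usual serving metrics follow directly. The sketch below assumes such timestamps are available; `RequestTiming`, `percentile`, and `summarize` are hypothetical names, and the percentile uses the nearest-rank definition.

```python
import math
from dataclasses import dataclass
from statistics import mean
from typing import Dict, List


@dataclass
class RequestTiming:
    start: float        # request submitted (seconds)
    first_token: float  # first output token received
    end: float          # last output token received
    output_tokens: int


def percentile(values: List[float], q: float) -> float:
    """Nearest-rank percentile for q in (0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100 * len(ordered)))
    return ordered[rank - 1]


def summarize(timings: List[RequestTiming]) -> Dict[str, float]:
    wall = max(t.end for t in timings) - min(t.start for t in timings)
    ttft = [t.first_token - t.start for t in timings]  # time to first token
    e2e = [t.end - t.start for t in timings]           # end-to-end latency
    # Time per output token after the first: the decode speed a user sees
    tpot = [(t.end - t.first_token) / (t.output_tokens - 1)
            for t in timings if t.output_tokens > 1]
    return {
        "req_per_s": len(timings) / wall,
        "output_tok_per_s": sum(t.output_tokens for t in timings) / wall,
        "ttft_mean_s": mean(ttft),
        "e2e_p50_s": percentile(e2e, 50),
        "e2e_p99_s": percentile(e2e, 99),
        "tpot_mean_s": mean(tpot) if tpot else float("nan"),
    }


if __name__ == "__main__":
    runs = [RequestTiming(0.0, 0.5, 2.5, 11), RequestTiming(0.0, 1.0, 3.0, 21)]
    print(summarize(runs))
```

TTFT is dominated by prefill and queueing, while TPOT reflects decode speed. Long-context tuning mostly moves the former, and speculative decoding mostly moves the latter.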

6.4 Recommended Configurations

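# Inference optimization config templates

# 1. Recommended vLLM settings (keyword arguments to vllm.LLM)
vllm_config = {
    # Memory
    "gpu_memory_utilization": 0.95,  # fraction of GPU memory vLLM may claim (weights + KV cache)
    "enable_prefix_caching": True,

    # Throughput
    "max_num_batched_tokens": 8192,  # token budget per scheduler step (larger batches)
    "max_num_seqs": 256,             # max concurrent sequences
    "enforce_eager": False,          # False keeps CUDA graphs enabled
    # Needed when max_num_batched_tokens < max_model_len: long prompts are
    # prefilled in chunks instead of being rejected
    "enable_chunked_prefill": True,

    # Long context
    "max_model_len": 32768,  # adjust to the workload

    # Quantization (optional): point `model` at an AWQ-quantized checkpoint
    # "quantization": "awq",
}

# 2. TensorRT-LLM settings (illustrative: these map to engine-build options
#    such as trtllm-build's --max_batch_size / --max_input_len, not to a dict
#    accepted verbatim by any API). Shapes below are Llama-3-8B.
trt_config = {
    "precision": "fp16",  # or fp8 (Hopper/Ada GPUs)
    "tensor_parallel": 2,
    "num_layers": 32,
    "num_heads": 32,
    "hidden_size": 4096,
    "vocab_size": 128256,
    "max_batch_size": 64,
    "max_input_len": 4096,
    "max_output_len": 512,
    "use_gpt_attention": True,     # fused GPT attention plugin
    "remove_input_padding": True,  # pack variable-length inputs without padding
}

# 3. llama.cpp with a quantized GGUF model (keyword arguments of llama_cpp.Llama)
llama_cpp_config = {
    "model_path": "./model.Q4_K_M.gguf",
    "n_ctx": 8192,
    "n_gpu_layers": 35,  # layers offloaded to the GPU (-1 = all)
    "n_threads": 8,
    "n_batch": 512,      # prompt-processing batch size
    "use_mlock": True,   # lock the model in RAM to prevent swapping
    "use_mmap": True,    # memory-map the model file
    # RoPE base for long context; only override it if the model was trained
    # with this base (0 = read it from the GGUF metadata)
    "rope_freq_base": 1000000,
}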
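
Whether `max_model_len` and `max_num_seqs` are realistic depends on how many tokens of KV cache fit in the memory left after the weights. A back-of-the-envelope check, assuming a Llama-3-8B shape (32 layers, 8 KV heads via GQA, head_dim 128, FP16 cache), a single 80 GiB GPU, about 15 GiB of FP16 weights, and a flat 1 GiB for activations and CUDA graphs, can be sketched as follows. vLLM measures the real overhead at startup, and the helper names here are hypothetical.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    """KV cache bytes per token: K and V (x2) for every layer and KV head."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def max_cached_tokens(gpu_gib: float, utilization: float, weights_gib: float,
                      bytes_per_token: int, overhead_gib: float = 1.0) -> int:
    """Tokens of KV cache that fit in the memory left after weights and overhead."""
    budget_bytes = (gpu_gib * utilization - weights_gib - overhead_gib) * 1024**3
    return max(0, int(budget_bytes // bytes_per_token))


if __name__ == "__main__":
    # Llama-3-8B: 32 layers, 8 KV heads (GQA), head_dim 128, FP16 cache
    per_token = kv_bytes_per_token(32, 8, 128)  # 131072 bytes = 128 KiB
    tokens = max_cached_tokens(80, 0.95, 15, per_token)
    print(per_token, tokens)
    for max_model_len in (8192, 32768):
        # Sequences that can sit at full length in the cache at the same time
        print(max_model_len, tokens // max_model_len)
```

Under these assumptions the cache holds 491,520 tokens, or only 15 sequences at the full 32K length. A setting of `max_num_seqs=256` therefore pays off only when most requests are much shorter than `max_model_len`.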

7. Summary

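LLM inference optimization is a systems-engineering problem that spans memory management, compute optimization, parallelism strategy, and more: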

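| Optimization area | Key techniques | Benefit |
|---|---|---|
| KV Cache | PagedAttention, prefix caching | 2-4× more concurrent requests |
| Quantization | GPTQ, AWQ, FP8 | 2-4× lower GPU memory |
| Speculative decoding | EAGLE, Medusa | 2-3× faster generation |
| Batching | Continuous Batching | 5-10× higher throughput |
| Parallelism | Ring Attention, tensor parallelism | Near-linear scaling with device count |
| Long context | NTK scaling, sparse attention | Support for longer sequences |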
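
The 2-4× memory figure for quantization follows from bytes per weight. The sketch below assumes an 8B-parameter model and counts weights only. Weight-only schemes such as GPTQ and AWQ add a few percent for per-group scales (ignored here) and leave the KV cache and activations unchanged. The helper name `weight_gb` is hypothetical.

```python
def weight_gb(num_params: float, bits_per_weight: float) -> float:
    """Weight memory in GB (1e9 bytes), ignoring quantization scales/zero-points."""
    return num_params * bits_per_weight / 8 / 1e9


if __name__ == "__main__":
    params = 8e9  # an 8B-parameter model
    baseline = weight_gb(params, 16)
    for name, bits in (("FP16", 16), ("FP8", 8), ("INT4 (GPTQ/AWQ)", 4)):
        gb = weight_gb(params, bits)
        print(f"{name:>16}: {gb:5.1f} GB  ({baseline / gb:.0f}x smaller than FP16)")
```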

Recommended paths

  1. Fast deployment: vLLM + AWQ quantization
  2. Maximum performance: TensorRT-LLM + FP16/FP8
  3. Long context: SGLang + prefix caching
  4. Edge deployment: llama.cpp + GGUF quantization

References


Footnotes

  1. Kwon W, Li Z, Zhuang S, et al. Efficient memory management for large language model serving using paged attention. SOSP, 2023. arXiv:2309.06180

  2. Xiao G, Tian Y, Chen B, et al. Efficient streaming language models with attention sinks. arXiv:2309.17453, 2023.

  3. Zhang Z, Sheng Y, Zhou T, et al. H2O: Heavy-hitter oracle for efficient generative inference of large language models. NeurIPS, 2023. arXiv:2306.14048

  4. Frantar E, Ashkboos S, Hoefler T, Alistarh D. GPTQ: Accurate post-training quantization for generative pre-trained transformers. ICLR, 2023. arXiv:2210.17323

  5. Lin J, Tang J, Tang H, et al. AWQ: Activation-aware weight quantization for LLM compression and acceleration. MLSys, 2024. arXiv:2306.00978

  6. Leviathan Y, Kalman M, Matias Y. Fast inference from transformers via speculative decoding. ICML, 2023.

  7. Li Y, Wei F, Zhang C, Zhang H. EAGLE: Speculative sampling requires rethinking feature uncertainty. ICML, 2024. arXiv:2401.15077

  8. Cai T, Li Y, Geng Z, et al. Medusa: Simple LLM inference acceleration framework with multiple decoding heads. arXiv:2401.10774, 2024.

  9. Yu G I, Jeong J S, Kim G W, et al. Orca: A distributed serving system for transformer-based generative models. OSDI, 2022.

  10. Liu H, Zaharia M, Abbeel P. Ring attention with blockwise transformers for near-infinite context. arXiv:2310.01889, 2023.

  11. Press O, Smith N A, Levy O. Improving language understanding by generative pre-training. arXiv:1801.06146, 2018.