LLM Inference Optimization
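Large language model (LLM) inference optimization is a core technique for deploying efficient AI services. Unlike training, inference has its own challenges: latency that accumulates over autoregressive generation, heavy memory pressure from the KV Cache, and attention cost that grows quadratically with context length in long-context workloads.
According to 2026 estimates, inference accounts for more than 70% of enterprise AI infrastructure spending. Effective inference optimization can cut inference cost by 5-20× while raising throughput by an order of magnitude.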
1. KV Cache Optimization
1.1 How the KV Cache Works
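When generating the $t$-th token, an autoregressive language model must attend to the Key and Value vectors of all preceding tokens. The KV Cache stores these vectors so that they are not recomputed at every step.
The attention mechanism is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

At inference time only the query changes with each new token; $K$ and $V$ are reused from the cache, and each step appends a single new row:

$$K_{1:t} = \begin{bmatrix} K_{1:t-1} \\ k_t \end{bmatrix}, \quad V_{1:t} = \begin{bmatrix} V_{1:t-1} \\ v_t \end{bmatrix}, \quad o_t = \mathrm{softmax}\left(\frac{q_t K_{1:t}^\top}{\sqrt{d_k}}\right) V_{1:t}$$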
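The reuse described above can be checked numerically. The sketch below uses one toy attention head whose projection matrices `W_q`, `W_k`, `W_v` are random stand-ins (hypothetical, not taken from any model): `cached_decode` processes one token per step and only appends the new key and value, while `full_causal_attention` recomputes everything from scratch. The two agree up to floating-point error.

```python
import math

import torch

torch.manual_seed(0)
d = 16  # head dimension (toy size)
# Random projections scaled so that attention scores stay O(1)
W_q, W_k, W_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))
x = torch.randn(10, d)  # embeddings of 10 tokens


def full_causal_attention(x):
    """Recompute causal attention for every position (no cache)."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.T / math.sqrt(d)
    future = torch.triu(torch.ones(len(x), len(x), dtype=torch.bool), diagonal=1)
    return torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1) @ V


def cached_decode(x):
    """One token per step: compute only its q, k, v and reuse the cached K, V."""
    K_cache, V_cache, outputs = [], [], []
    for x_t in x:
        q_t = x_t @ W_q
        K_cache.append(x_t @ W_k)  # append k_t; earlier keys are never recomputed
        V_cache.append(x_t @ W_v)
        K, V = torch.stack(K_cache), torch.stack(V_cache)
        outputs.append(torch.softmax(q_t @ K.T / math.sqrt(d), dim=-1) @ V)
    return torch.stack(outputs)


print(torch.allclose(full_causal_attention(x), cached_decode(x), atol=1e-5))  # True
```

Per step, the cached version does $O(t \cdot d)$ attention work instead of recomputing all $t$ projections, which is exactly what the KV Cache trades memory for.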
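1.2 Memory Footprint Analysis
The KV Cache's memory footprint is the central bottleneck of inference optimization.
KV memory per token per layer (K and V, $n_h$ heads of dimension $d_h$, $b$ bytes per element):
$$M_{\text{token,layer}} = 2 \times n_h \times d_h \times b = 2 \times d_{\text{model}} \times b$$
For LLaMA-7B ($d_{\text{model}} = 4096$, FP16 so $b = 2$): $2 \times 4096 \times 2 = 16{,}384$ bytes $= 16$ KiB.
Full KV Cache across all layers, for sequence length $L$ and batch size $B$:
$$M_{\text{KV}} = 2 \times n_{\text{layers}} \times d_{\text{model}} \times L \times B \times b$$
For LLaMA-7B (32 layers, FP16): $32 \times 16$ KiB $= 512$ KiB per token, so a single 4,096-token sequence needs $512\ \text{KiB} \times 4096 = 2$ GiB.
Long-context scenario: at a 128K context, a single LLaMA-7B request's KV Cache reaches $512\ \text{KiB} \times 131{,}072 = 64$ GiB. Together with roughly 13 GB of FP16 weights, this no longer fits on a single 80 GB GPU.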
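The arithmetic in this subsection fits in a small calculator. This is a minimal sketch: `kv_cache_bytes` is a hypothetical helper, and the model shapes are the published configurations (LLaMA-7B: 32 layers, 32 heads of dimension 128; Llama-3-8B: 32 layers but only 8 KV heads because it uses grouped-query attention).

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """KV Cache size in bytes: 2 (K and V) x layers x KV heads x head_dim x tokens x batch x bytes."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


KiB, GiB = 1024, 1024 ** 3

# LLaMA-7B: 32 layers, 32 heads x 128 dims (d_model = 4096), FP16
print(kv_cache_bytes(32, 32, 128, seq_len=1) / KiB, "KiB per token")     # 512.0 KiB per token
print(kv_cache_bytes(32, 32, 128, seq_len=4096) / GiB, "GiB at 4K")      # 2.0 GiB at 4K
print(kv_cache_bytes(32, 32, 128, seq_len=131072) / GiB, "GiB at 128K")  # 64.0 GiB at 128K

# With grouped-query attention only the KV heads count: Llama-3-8B has 8 KV heads
print(kv_cache_bytes(32, 8, 128, seq_len=1) / KiB, "KiB per token")      # 128.0 KiB per token
```

The last line shows why grouped-query attention is itself a KV Cache optimization: with the same layer count and head dimension, 8 KV heads instead of 32 cut the cache by 4×.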
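1.3 PagedAttention: Memory Management in vLLM
PagedAttention [1], introduced by vLLM, borrows the idea of virtual memory from operating systems and manages the KV Cache in pages:
Core ideas:
- Split the KV Cache into fixed-size blocks (16 tokens per block by default)
- Use a block table to map each sequence's logical blocks to physical blocks
- Allow non-contiguous physical storage, which raises memory utilization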
```python
import torch
import torch.nn.functional as F


class BlockTable:
    """Block table: maps one sequence's logical blocks to physical blocks in a shared pool"""

    def __init__(self, num_physical_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_physical_blocks))  # free physical block ids
        self.block_mapping = []  # logical block index -> physical block id

    def allocate(self, num_tokens):
        """Make sure enough physical blocks exist to hold num_tokens tokens"""
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size
        while len(self.block_mapping) < num_blocks:
            if not self.free_blocks:
                raise RuntimeError("out of KV cache blocks")
            # Physical blocks need not be contiguous: take any free one
            self.block_mapping.append(self.free_blocks.pop())
        return list(self.block_mapping)

    def get_physical_block(self, token_idx):
        """Return (physical block id, offset within the block) for a token"""
        logical_block = token_idx // self.block_size
        offset = token_idx % self.block_size
        return self.block_mapping[logical_block], offset


class PagedAttention:
    """
    PagedAttention (simplified): decode-step attention over KV stored in
    non-contiguous physical blocks of a shared pool
    """

    def __init__(self, num_physical_blocks, block_size=16, num_heads=32, head_dim=128):
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Shared KV pool: (num_blocks, num_heads, block_size, head_dim)
        shape = (num_physical_blocks, num_heads, block_size, head_dim)
        self.k_pool = torch.zeros(shape)
        self.v_pool = torch.zeros(shape)

    def write(self, table, token_idx, k, v):
        """Store one token's k, v of shape (num_heads, head_dim) in its physical slot"""
        block, offset = table.get_physical_block(token_idx)
        self.k_pool[block, :, offset] = k
        self.v_pool[block, :, offset] = v

    def forward(self, query, table, seq_len):
        """
        Args:
            query: (num_heads, head_dim), the current token of one sequence
            table: that sequence's BlockTable
            seq_len: number of valid tokens in the cache
        """
        blocks = table.block_mapping
        # 1. Gather K and V from the physical blocks (a real kernel reads them in place)
        k_full = torch.cat([self.k_pool[b] for b in blocks], dim=1)[:, :seq_len]
        v_full = torch.cat([self.v_pool[b] for b in blocks], dim=1)[:, :seq_len]
        # 2. Standard scaled dot-product attention
        scores = torch.einsum('hd,hnd->hn', query, k_full) / (self.head_dim ** 0.5)
        probs = F.softmax(scores, dim=-1)
        return torch.einsum('hn,hnd->hd', probs, v_full)
```
PagedAttention memory efficiency:
| Strategy | KV memory utilization | Throughput gain |
|---|---|---|
| Contiguous preallocation | ~20-40% | 1× |
| PagedAttention | >90% (waste under ~4%) | 2-4× |
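Where the utilization gap comes from can be illustrated with toy numbers. The sketch assumes a hypothetical system that reserves `max_len = 2048` KV slots per request, compared with paging in 16-token blocks, where at most `block_size - 1` slots per request are wasted; the request lengths are made up.

```python
def kv_slot_utilization(seq_lens, max_len=2048, block_size=16):
    """Fraction of reserved KV slots that hold real tokens: (contiguous, paged)."""
    used = sum(seq_lens)
    contiguous_reserved = max_len * len(seq_lens)
    # Paged: each sequence gets ceil(len / block_size) blocks
    paged_reserved = sum(-(-n // block_size) * block_size for n in seq_lens)
    return used / contiguous_reserved, used / paged_reserved


lens = [300, 517, 1024, 130]  # hypothetical request lengths
contig, paged = kv_slot_utilization(lens)
print(f"contiguous: {contig:.1%}, paged: {paged:.1%}")
```

Here contiguous reservation uses about 24% of the reserved slots while paging uses over 98%; real numbers depend on the length distribution and on how aggressively a system preallocates.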
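1.4 Cache Compression
StreamingLLM
StreamingLLM [2] handles text streams of effectively unbounded length without any fine-tuning. The key observation is the attention sink phenomenon: models put strong attention on the initial tokens even when those tokens carry little semantic meaning, so evicting them destabilizes generation.
Strategy: keep 4 "attention sink" tokens plus a window of the most recent tokens: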
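```python
import math

import torch


def standard_attention(query, key, value):
    """Scaled dot-product attention; tensors are (batch, heads, seq, head_dim)"""
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    return torch.softmax(scores, dim=-1) @ value


def streaming_llm_attention(query, key, value, sink_size=4, window_size=512):
    """
    StreamingLLM: fixed-size cache = sink tokens + the most recent window tokens.
    query holds the current decoding step(s); key/value hold the history.
    """
    seq_len = key.shape[2]
    if seq_len <= sink_size + window_size:
        # Short sequence: attend to everything
        return standard_attention(query, key, value)
    # 1. The attention-sink tokens at the start
    k_sink = key[:, :, :sink_size, :]
    v_sink = value[:, :, :sink_size, :]
    # 2. The most recent window tokens
    k_window = key[:, :, -window_size:, :]
    v_window = value[:, :, -window_size:, :]
    # 3. Concatenate and attend (everything in between is evicted)
    k_combined = torch.cat([k_sink, k_window], dim=2)
    v_combined = torch.cat([v_sink, v_window], dim=2)
    return standard_attention(query, k_combined, v_combined)
```
Experimental results: StreamingLLM processes continuous input of 4 million tokens on models of up to 40B parameters with stable perplexity.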
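One detail matters for RoPE models: StreamingLLM assigns positions by slot within the cache rather than by each token's index in the original text, and to do so it caches keys before the rotary transformation and re-applies it at each step. A sketch of that bookkeeping (`streaming_cache_indices` is a hypothetical helper):

```python
def streaming_cache_indices(seq_len, sink_size=4, window_size=512):
    """
    Return (kept original token indices, positions used for RoPE).
    Positions are contiguous cache slots, not the tokens' original indices.
    """
    if seq_len <= sink_size + window_size:
        kept = list(range(seq_len))
    else:
        kept = list(range(sink_size)) + list(range(seq_len - window_size, seq_len))
    positions = list(range(len(kept)))  # positions within the cache
    return kept, positions


kept, pos = streaming_cache_indices(seq_len=9, sink_size=4, window_size=3)
print(kept)  # [0, 1, 2, 3, 6, 7, 8]
print(pos)   # [0, 1, 2, 3, 4, 5, 6]
```

With 4 sink tokens and a window of 3, tokens 4 and 5 have been evicted, and the 7 cached tokens get the contiguous positions 0 to 6, so relative distances never exceed what the model saw in training.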
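Approximate Cache Strategies
H2O (Heavy-Hitter Oracle) [3] observes that a small set of "heavy hitter" tokens receives most of the accumulated attention; it keeps the KV of these tokens (plus the most recent ones) and evicts the rest: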
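```python
import torch


def h2o_cache_selection(kv_cache, importance_scores, budget):
    """
    H2O: keep the KV entries with the highest accumulated attention scores.
    kv_cache: (batch, heads, seq_len, head_dim)
    importance_scores: (heads, seq_len), attention each position has received so far
    Simplified: one selection shared by all heads, and no separate recent-token window.
    """
    # Sum over heads to get one score per position
    scores = importance_scores.sum(dim=0)  # (seq_len,)
    # Top-k positions within the budget
    _, indices = torch.topk(scores, k=min(budget, scores.numel()))
    indices = torch.sort(indices).values  # keep the original token order
    return kv_cache[:, :, indices, :], indices
```
1.5 Cross-Request KV Reuse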
Requests that share a prefix can reuse its KV Cache:
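```python
import torch


class PrefixCaching:
    """Prefix caching: reuse the KV Cache of an identical prompt prefix"""

    def __init__(self):
        # Hash table of already computed prefix KV
        # (production systems also compare tokens to rule out hash collisions)
        self.cache = {}  # prompt_hash -> (k_cache, v_cache)

    def get_cached_prefix(self, prompt_tokens):
        """Return the cached prefix KV, or None"""
        return self.cache.get(hash(tuple(prompt_tokens)))

    def cache_prefix(self, prompt_tokens, k_cache, v_cache):
        """Store the prefix KV"""
        self.cache[hash(tuple(prompt_tokens))] = (k_cache, v_cache)

    def extend_with_cache(self, prompt_tokens, new_tokens, compute_kv):
        """
        Extend from a cached prefix. compute_kv(new_tokens, k_prefix, v_prefix) is a
        caller-supplied function that runs the model on the new tokens only, attending
        to the prefix KV, and returns their (k, v) of shape (batch, heads, n_new, head_dim).
        """
        cached_kv = self.get_cached_prefix(prompt_tokens)
        if cached_kv is None:
            return None
        k_prefix, v_prefix = cached_kv
        k_new, v_new = compute_kv(new_tokens, k_prefix, v_prefix)
        # Concatenate along the sequence dimension
        return (
            torch.cat([k_prefix, k_new], dim=2),
            torch.cat([v_prefix, v_new], dim=2),
        )
```
2. Quantization Techniques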
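2.1 INT8/FP16 Quantization Basics
Quantization maps high-precision values to a low-precision representation. With $b$ bits:
Symmetric quantization:
$$s = \frac{\max |W|}{2^{b-1} - 1}, \qquad W_q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{W}{s}\right), -2^{b-1}, 2^{b-1} - 1\right), \qquad \hat{W} = s\,W_q$$
Asymmetric quantization:
$$s = \frac{\max W - \min W}{2^{b} - 1}, \qquad z = \mathrm{round}\left(-\frac{\min W}{s}\right), \qquad W_q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{W}{s}\right) + z, 0, 2^{b} - 1\right), \qquad \hat{W} = s\,(W_q - z)$$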
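```python
import torch


def symmetric_quantize(W, bits=8):
    """Symmetric quantization: zero maps to 0, one scale per tensor"""
    scale = W.abs().max() / (2 ** (bits - 1) - 1)
    W_q = torch.round(W / scale)
    W_q = torch.clamp(W_q, -(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
    return W_q.to(torch.int8), scale  # dequantize: W_q * scale


def asymmetric_quantize(W, bits=8):
    """Asymmetric quantization: a zero point shifts the range onto [0, 2^bits - 1]"""
    scale = (W.max() - W.min()) / (2 ** bits - 1)
    zero_point = torch.round(-W.min() / scale)
    W_q = torch.round(W / scale) + zero_point
    W_q = torch.clamp(W_q, 0, 2 ** bits - 1)
    return W_q.to(torch.uint8), scale, zero_point  # dequantize: (W_q - zero_point) * scale
```
2.2 GPTQ: Building on Optimal Brain Compression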
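GPTQ [4] uses second-order information, namely the Hessian $H = 2XX^\top$ of the layer-wise reconstruction loss $\lVert WX - \hat{W}X \rVert_2^2$ (with $X$ the calibration inputs), to guide quantization, and keeps accuracy high even at 4 bits.
Core algorithm: quantize the weight columns one at a time and push each column's quantization error onto the columns not yet quantized, weighted by the inverse Hessian:
$$\delta = \frac{w_q - \mathrm{quant}(w_q)}{[H^{-1}]_{qq}}, \qquad w_F \leftarrow w_F - \delta\,[H^{-1}]_{q,F}$$
where $q$ is the current column and $F$ the set of remaining columns.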
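```python
import torch


class GPTQQuantizer:
    """
    Simplified GPTQ: column-by-column quantization with error compensation
    driven by the inverse Hessian of the layer's calibration inputs
    """

    def __init__(self, bits=4, group_size=128, damp=0.01):
        self.bits = bits
        self.group_size = group_size
        self.damp = damp
        self.qmax = 2 ** (bits - 1) - 1

    def _scale(self, w_group):
        # Symmetric per-row scale for one group of columns: (out_features, 1)
        return w_group.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / self.qmax

    def quantize_layer(self, W, X):
        """
        Args:
            W: weight matrix (out_features, in_features)
            X: calibration inputs to this layer (num_samples, in_features)
        Returns: integer weights (out_features, in_features), scales (out_features, num_groups)
        """
        W = W.clone().float()
        out_features, in_features = W.shape
        X = X.float()
        # Hessian of the layer-wise reconstruction loss
        H = 2 * X.T @ X / X.shape[0]
        H += self.damp * H.diag().mean() * torch.eye(in_features)  # dampening for stability
        # Upper Cholesky factor of H^{-1}, as in the GPTQ paper
        Hinv = torch.linalg.cholesky(torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True)

        Q = torch.zeros_like(W, dtype=torch.int32)
        scales = []
        for i in range(in_features):
            if i % self.group_size == 0:
                # New group: scale from the already error-compensated weights
                scale = self._scale(W[:, i:i + self.group_size])
                scales.append(scale)
            w = W[:, i:i + 1]
            q = torch.clamp(torch.round(w / scale), -self.qmax - 1, self.qmax)
            Q[:, i:i + 1] = q.to(torch.int32)
            # Quantization error, normalized by the inverse-Hessian diagonal
            err = (w - q * scale) / Hinv[i, i]
            # Push the error onto the not-yet-quantized columns
            W[:, i:] -= err @ Hinv[i:i + 1, i:]
        return Q, torch.cat(scales, dim=1)
```
2.3 AWQ: Activation-Aware Weight Quantization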
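AWQ [5] observes that weights are not equally important: a small fraction of salient weight channels, identified by the magnitude of their input activations rather than of the weights themselves, dominates quantization error. Instead of keeping those channels in higher precision, AWQ scales them up by a per-channel factor $s$ before quantization and divides the input by $s$, computing $Q(W \cdot s)\,(X / s)$: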
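```python
import torch


def quantize_int(W, bits=4, group_size=128):
    """Symmetric per-group fake quantization; returns dequantized weights.
    in_features must be a multiple of group_size."""
    qmax = 2 ** (bits - 1) - 1
    out_f, in_f = W.shape
    Wg = W.reshape(out_f, in_f // group_size, group_size)
    scale = Wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    Wq = torch.clamp(torch.round(Wg / scale), -qmax - 1, qmax)
    return (Wq * scale).reshape(out_f, in_f)


def awq_search_scale(W, X, bits=4, group_size=128, n_grid=20):
    """
    AWQ: grid-search per-input-channel scales s = s_X^alpha minimizing
    || Q(W * s) (X / s)^T - W X^T ||, where s_X is the mean |activation| per channel.
    W: (out_features, in_features); X: calibration activations (num_samples, in_features)
    """
    s_x = X.abs().mean(dim=0)  # activation-based channel importance
    ref = X @ W.T
    best_err, best_s = float("inf"), torch.ones_like(s_x)
    for k in range(n_grid):
        alpha = k / n_grid
        s = s_x.pow(alpha).clamp(min=1e-4)
        s = s / (s.max() * s.min()).sqrt()  # normalize the scale range
        W_q = quantize_int(W * s, bits, group_size)  # salient channels are scaled up
        err = ((X / s) @ W_q.T - ref).pow(2).mean().item()
        if err < best_err:
            best_err, best_s = err, s
    return best_s


def awq_quantize(W, X, bits=4, group_size=128):
    """AWQ quantization; returns fake-quantized Q(W * s) and the scales s"""
    s = awq_search_scale(W, X, bits, group_size)
    # At runtime 1/s is folded into the preceding operator (e.g. a norm layer)
    return quantize_int(W * s, bits, group_size), s
```
2.4 GGUF / llama.cpp Format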
GGUF is the model file format of llama.cpp (the successor to GGML); it supports many quantization types, from Q2_K up to Q8_0:
```python
import torch

# Simplified view of common GGUF quantization types. Real K-quants use 256-weight
# super-blocks with per-sub-block scales, so effective bits per weight are somewhat
# higher; the _M/_S suffixes denote per-tensor mixes of these types.
GGUF_TYPES = {
    "Q8_0":   {"bits": 8, "block_size": 32,  "family": "legacy"},
    "Q6_K":   {"bits": 6, "block_size": 256, "family": "k-quant"},
    "Q5_K_M": {"bits": 5, "block_size": 256, "family": "k-quant"},
    "Q4_K_M": {"bits": 4, "block_size": 256, "family": "k-quant"},
    "Q4_0":   {"bits": 4, "block_size": 32,  "family": "legacy"},
    "Q3_K_M": {"bits": 3, "block_size": 256, "family": "k-quant"},
    "Q2_K":   {"bits": 2, "block_size": 256, "family": "k-quant"},
}


class GGUFQuantizer:
    """Toy block-wise symmetric quantizer in the spirit of GGUF (not the exact on-disk layout)"""

    def quantize(self, W, quant_type="Q4_K_M"):
        config = GGUF_TYPES[quant_type]
        block_size, bits = config["block_size"], config["bits"]
        out_features, in_features = W.shape
        assert in_features % block_size == 0, "row length must be a multiple of block_size"
        # A block is a run of consecutive weights within one row
        blocks = W.reshape(out_features, in_features // block_size, block_size)
        qmax = 2 ** (bits - 1) - 1
        # One scale per block
        scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
        q = torch.clamp(torch.round(blocks / scales), -qmax - 1, qmax).to(torch.int8)
        return q, scales.squeeze(-1)

    def dequantize(self, q, scales):
        return (q.float() * scales.unsqueeze(-1)).reshape(q.shape[0], -1)
```
2.5 FP8 Quantization
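FP8 (8-bit floating point) is a data type with native Tensor Core support starting with NVIDIA's Hopper (H100/H200) and Ada Lovelace architectures: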
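```python
import torch
import torch.nn.functional as F

# Largest finite values of the two FP8 formats
FP8_MAX = {torch.float8_e4m3fn: 448.0, torch.float8_e5m2: 57344.0}


class FP8Quantizer:
    """
    FP8 quantization (native on Hopper H100/H200)
    E4M3: sign(1) + exponent(4) + mantissa(3) -> more precision, range ±448
    E5M2: sign(1) + exponent(5) + mantissa(2) -> more dynamic range, range ±57344
    """

    @staticmethod
    def quantize(x, dtype=torch.float8_e4m3fn):
        """Per-tensor scaled cast; returns (x_fp8, scale) with x ≈ x_fp8 * scale"""
        fmax = FP8_MAX[dtype]
        scale = x.abs().max().clamp(min=1e-12) / fmax
        x_fp8 = (x / scale).clamp(-fmax, fmax).to(dtype)
        return x_fp8, scale

    @staticmethod
    def linear(x, weight):
        """
        Simulated FP8 linear layer. Inference commonly uses E4M3 for both weights and
        activations (E5M2 is mainly used for gradients in training). Real kernels multiply
        the FP8 tensors on Tensor Cores; here they are upcast to show the numerics.
        """
        w_fp8, w_scale = FP8Quantizer.quantize(weight)
        x_fp8, x_scale = FP8Quantizer.quantize(x)
        return F.linear(x_fp8.float(), w_fp8.float()) * (x_scale * w_scale)
```
FP8 vs. INT8 comparison: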
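| Property | FP8 (E4M3) | INT8 |
|---|---|---|
| Dynamic range | Large (floating exponent) | Small (uniform grid) |
| Precision | Roughly constant relative precision; tolerates outliers | Uniform steps; depends on calibration and outliers |
| Hardware support | Native on Hopper (H100/H200) and Ada | Broad (GPUs and CPUs) |
| Typical use | Weights and activations | General |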
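The ranges behind this table follow from the bit layouts. The derivation below assumes the `fn` variant of E4M3 (as in `torch.float8_e4m3fn`: no infinities, and only the all-ones exponent with all-ones mantissa encodes NaN) and an IEEE-style E5M2 (top exponent reserved for infinities and NaN):

$$
\begin{aligned}
\text{E4M3:}\quad & \text{bias} = 7, & \max &= 2^{15-7} \times \left(1 + \tfrac{6}{8}\right) = 448, & \min_{\text{subnormal}} &= 2^{1-7} \times 2^{-3} = 2^{-9} \\
\text{E5M2:}\quad & \text{bias} = 15, & \max &= 2^{30-15} \times \left(1 + \tfrac{3}{4}\right) = 57344, & \min_{\text{subnormal}} &= 2^{1-15} \times 2^{-2} = 2^{-16}
\end{aligned}
$$

Under a single scale factor, E4M3 therefore spans a ratio of $448 / 2^{-9} \approx 2.3 \times 10^5$ between its largest and smallest nonzero magnitudes, while symmetric INT8 spans only $127$. That is why FP8 tolerates outliers better, at the cost of coarser steps near the top of its range.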
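3. Speculative Decoding
3.1 Basic Principle
Speculative decoding [6] uses a small draft model to generate candidate tokens quickly; the large target model then verifies all of them in a single parallel forward pass: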
```text
Context:          [The, cat, sat, on, the]
        ↓ Draft (fast, proposes k = 3 tokens)
Draft proposal:   [mat, and, purr]
        ↓ Target (one parallel verification pass)
Accepted:         [mat, and]    ✓
Rejected:         [purr]        ✗ → replaced by the target's own token
```
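What keeps the output identical in distribution to the target model is the acceptance rule of speculative sampling (Leviathan et al.): a drafted token $x$ with draft probability $q(x)$ and target probability $p(x)$ is accepted with probability $\min(1, p(x)/q(x))$; on rejection, a replacement is drawn from the normalized residual $\max(0, p - q)$ and the round ends. A minimal single-token sketch over explicit probability vectors (`verify_draft_token` is a hypothetical name), with an empirical check that the resulting tokens follow $p$:

```python
import torch


def verify_draft_token(p, q, x, generator=None):
    """
    Speculative-sampling check for one drafted token.
    p, q: target and draft probability vectors over the vocabulary
    x: token id proposed by the draft model (sampled from q)
    Returns (accepted, token): the draft token if accepted, else a resampled one.
    """
    accept_prob = torch.clamp(p[x] / q[x], max=1.0)
    if torch.rand((), generator=generator) < accept_prob:
        return True, x
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    return False, int(torch.multinomial(residual, 1, generator=generator))


# Empirical check: accepted-or-resampled tokens follow p, not q
g = torch.Generator().manual_seed(0)
p = torch.tensor([0.6, 0.3, 0.1])
q = torch.tensor([0.2, 0.2, 0.6])
counts = torch.zeros(3)
for _ in range(20000):
    x = int(torch.multinomial(q, 1, generator=g))
    _, tok = verify_draft_token(p, q, x, generator=g)
    counts[tok] += 1
print(counts / counts.sum())  # close to p = [0.6, 0.3, 0.1]
```

Even though the draft distribution `q` is badly miscalibrated here, the output matches `p`; miscalibration only lowers the acceptance rate, and therefore the speedup.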
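Theoretical speedup:
Let $\alpha$ be the per-token acceptance rate, $T_d$ the time for the draft to generate $k$ tokens, and $T_t$ the time for one target verification pass (scoring $k$ tokens in parallel costs about as much as decoding one). Each round yields $\frac{1-\alpha^{k+1}}{1-\alpha}$ tokens on average, so relative to plain autoregressive decoding:
$$S = \frac{1-\alpha^{k+1}}{1-\alpha} \cdot \frac{T_t}{T_d + T_t}$$
When $\alpha \to 1$ and $T_d \ll T_t$, the speedup approaches $k + 1$.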
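Under the simplifying assumption that each drafted token is accepted independently with probability $\alpha$, the expected number of tokens per round is the geometric sum $1 + \alpha + \dots + \alpha^k = \frac{1 - \alpha^{k+1}}{1 - \alpha}$ (the accepted drafts plus the one token the target always contributes), which tends to $k + 1$ as $\alpha \to 1$. A quick Monte Carlo check of that sum (helper names are hypothetical):

```python
import random


def expected_tokens(alpha, k):
    """Closed form: 1 + alpha + ... + alpha^k = (1 - alpha^(k+1)) / (1 - alpha)."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def simulate_tokens_per_round(alpha, k, rounds=200_000, seed=0):
    """Draft tokens are accepted independently with prob alpha until the first
    rejection; the target always contributes one more token (correction or bonus)."""
    rng = random.Random(seed)
    total = 0
    for _ in range(rounds):
        accepted = 0
        while accepted < k and rng.random() < alpha:
            accepted += 1
        total += accepted + 1
    return total / rounds


for alpha in (0.7, 0.9):
    print(alpha, round(expected_tokens(alpha, 4), 3), round(simulate_tokens_per_round(alpha, 4), 3))
```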
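3.2 Draft Model Design Principles
```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ModelConfig:
    # Hypothetical minimal config; real frameworks have their own config classes
    num_layers: int
    hidden_dim: int
    num_heads: int
    vocab_size: int


class DraftModelConfig:
    """Draft model design principles"""
    # 1. Parameter scale: about 1/10 to 1/20 of the target
    DRAFT_SCALE_RATIO = 1 / 16  # 7B target -> ~0.4B draft
    # 2. Share the target's embedding and LM head (output projection)
    # 3. Same vocabulary (required for token-level verification), simpler structure


def create_draft_config(target: ModelConfig) -> ModelConfig:
    """
    Derive a smaller draft architecture. 1/4 of the layers at 1/2 the width gives
    about 1/4 * (1/2)^2 = 1/16 of the non-embedding parameters. The draft must be
    trained or distilled; copying the target and editing attributes does not shrink it.
    """
    return replace(
        target,
        num_layers=target.num_layers // 4,  # fewer layers
        hidden_dim=target.hidden_dim // 2,  # narrower hidden dimension
        num_heads=target.num_heads // 2,    # fewer attention heads (same head_dim)
        # vocab_size stays identical to the target's
    )
```
3.3 Acceptance Rate and Speedup Analysis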
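```python
def analyze_speculative_decoding(draft_accept_rates, t_d, t_t, k):
    """
    Estimate speculative decoding performance.
    Args:
        draft_accept_rates: list of per-token acceptance rates alpha
        t_d: time for the draft to generate k tokens
        t_t: time for one target pass (verifying k tokens in parallel costs
             about as much as decoding one token)
        k: number of tokens the draft proposes per round
    """
    results = []
    for alpha in draft_accept_rates:
        # Expected tokens per round: accepted drafts + 1 token from the target
        e_tokens = (1 - alpha ** (k + 1)) / (1 - alpha) if alpha < 1 else k + 1
        # Each round costs one draft pass plus one target pass;
        # plain autoregressive decoding costs t_t per token
        speedup = e_tokens * t_t / (t_d + t_t)
        results.append({
            "accept_rate": alpha,
            "expected_tokens": e_tokens,
            "speedup": speedup,
        })
    return results


# Example: t_d = 10ms (0.4B draft, k tokens), t_t = 100ms (7B target)
results = analyze_speculative_decoding(
    draft_accept_rates=[0.5, 0.7, 0.8, 0.9, 0.95],
    t_d=10, t_t=100, k=4,
)
for r in results:
    print(f"alpha={r['accept_rate']}: speedup={r['speedup']:.2f}x")
# alpha=0.7: speedup=2.52x
# alpha=0.9: speedup=3.72x
```
3.4 The EAGLE Method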
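EAGLE [7] (Extrapolation Algorithm for Greater Language-model Efficiency) drafts autoregressively at the feature level: a lightweight head predicts the target model's next second-to-top-layer feature from the current feature and the embedding of the token one step ahead, then reuses the target's LM head to turn predicted features into draft tokens. The drafts are verified with the standard speculative-sampling rule, so the output distribution matches the target model's: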
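```python
import torch
import torch.nn as nn


class EAGLEDraftHead(nn.Module):
    """
    Simplified EAGLE draft head: predicts the target's next feature (second-to-top
    layer hidden state) from the current feature and the next token's embedding.
    The real head uses one Transformer decoder layer; an MLP stands in here.
    """

    def __init__(self, hidden_size, embed: nn.Embedding, lm_head: nn.Linear):
        super().__init__()
        self.embed = embed      # shared with the target (kept frozen)
        self.lm_head = lm_head  # shared with the target (kept frozen)
        self.fuse = nn.Linear(2 * hidden_size, hidden_size)
        self.block = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.SiLU(), nn.Linear(hidden_size, hidden_size)
        )

    def step(self, feature, next_token):
        # Fuse f_t with the embedding of token t+1, then extrapolate f_{t+1}
        x = self.fuse(torch.cat([feature, self.embed(next_token)], dim=-1))
        return self.block(x)

    @torch.no_grad()
    def draft(self, feature, next_token, k):
        """Greedily propose k tokens by autoregressing at the feature level"""
        tokens = []
        for _ in range(k):
            feature = self.step(feature, next_token)
            next_token = self.lm_head(feature).argmax(dim=-1)
            tokens.append(next_token)
        return torch.stack(tokens, dim=-1)


def greedy_verify(draft_tokens, target_logits):
    """
    Greedy verification (temperature 0): accept the longest prefix on which the target's
    argmax agrees with the draft, then append the target's own token at the first mismatch.
    draft_tokens: (k,); target_logits: (k+1, vocab) from ONE target pass over the drafts.
    """
    target_tokens = target_logits.argmax(dim=-1)
    n = 0
    while n < len(draft_tokens) and draft_tokens[n] == target_tokens[n]:
        n += 1
    return torch.cat([draft_tokens[:n], target_tokens[n:n + 1]])
```
3.5 The Medusa Method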
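Medusa [8] attaches several extra decoding heads to the base model's last hidden state: while the original LM head predicts position $t+1$, the $i$-th Medusa head predicts position $t+i+1$. Candidates from all heads are verified together in a single forward pass (the full method arranges them in a tree and uses tree attention):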
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """Medusa head trunk: h + SiLU(W h)"""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return x + F.silu(self.linear(x))


class MedusaHeads(nn.Module):
    """Medusa: head i (0-based) predicts the token i + 2 positions after the last input"""

    def __init__(self, hidden_size, vocab_size, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([
            nn.Sequential(ResBlock(hidden_size), nn.Linear(hidden_size, vocab_size, bias=False))
            for _ in range(num_heads)
        ])

    def forward(self, hidden_state, top_k=5):
        """Top-k (token ids, probs) per head; real Medusa builds a candidate tree from these"""
        preds = []
        for head in self.heads:
            probs = F.softmax(head(hidden_state), dim=-1)
            top = torch.topk(probs, k=top_k, dim=-1)
            preds.append((top.indices, top.values))
        return preds


@torch.no_grad()
def medusa_step(base_model, medusa, input_ids):
    """
    One greedy Medusa step with a single candidate chain (no tree).
    base_model(ids) is a hypothetical interface returning (logits, last_hidden)
    with shapes (seq, vocab) and (seq, hidden); input_ids is 1-D.
    """
    logits, hidden = base_model(input_ids)
    next_token = logits[-1].argmax()                      # base model's own next token
    guesses = [ids[0] for ids, _ in medusa(hidden[-1])]   # top-1 of each head
    candidates = torch.stack([next_token] + guesses)
    # One base-model pass scores every candidate position at once
    cand_logits, _ = base_model(torch.cat([input_ids, candidates]))
    n = input_ids.shape[0]
    target = cand_logits[n:].argmax(dim=-1)  # target[i]: base prediction for position n+i+1
    accepted = [next_token]
    for i, g in enumerate(guesses):
        if g != target[i]:
            break
        accepted.append(g)
    # The base model's prediction after the last accepted token comes for free
    accepted.append(target[len(accepted) - 1])
    return torch.stack(accepted)
```
4. Batching and Parallelism
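4.1 Continuous Batching
Traditional static batching keeps a batch fixed until its longest sequence finishes and pads sequences to equal length, leaving many idle slots (bubbles). Continuous Batching [9] schedules at the iteration level instead: finished sequences leave the batch and new requests join it after every decoding step: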
```python
from collections import deque


class ContinuousBatcher:
    """Continuous batching: sequences join and leave the batch at every decoding step"""

    def __init__(self, max_batch_size, eos_token):
        self.max_batch_size = max_batch_size
        self.eos_token = eos_token
        self.running_seqs = []          # sequences currently being generated
        self.pending_seqs = deque()     # requests waiting to be scheduled
        self.finished = {}              # request id -> output tokens

    def add_request(self, req_id, input_tokens, max_new_tokens):
        """Queue a new request"""
        self.pending_seqs.append({
            "id": req_id,
            "tokens": list(input_tokens),
            "max_new_tokens": max_new_tokens,
            "output_tokens": [],
        })

    def _refill_batch(self):
        # Fill free slots from the queue right away instead of waiting for the whole batch
        while self.pending_seqs and len(self.running_seqs) < self.max_batch_size:
            self.running_seqs.append(self.pending_seqs.popleft())

    def step(self, model):
        """
        One decoding iteration. model(list_of_token_lists) -> list of next-token ids is a
        hypothetical interface; a real engine reuses each sequence's KV cache instead.
        """
        self._refill_batch()
        if not self.running_seqs:
            return
        next_tokens = model([s["tokens"] + s["output_tokens"] for s in self.running_seqs])
        still_running = []
        for seq, tok in zip(self.running_seqs, next_tokens):
            seq["output_tokens"].append(tok)
            done = tok == self.eos_token or len(seq["output_tokens"]) >= seq["max_new_tokens"]
            if done:
                self.finished[seq["id"]] = seq["output_tokens"]  # slot is freed immediately
            else:
                still_running.append(seq)
        self.running_seqs = still_running
```
4.2 Prefix Caching
KV Cache reuse when requests share a prompt prefix:
```python
import hashlib
from collections import OrderedDict


class PrefixCache:
    """Prefix cache: requests with an identical prefix share its KV (LRU eviction)"""

    def __init__(self, cache_size=1000):
        self.cache_size = cache_size
        self.cache = OrderedDict()  # prefix hash -> (tokens, KV cache), oldest first

    @staticmethod
    def compute_hash(tokens):
        """Hash a sequence of token ids"""
        return hashlib.sha256(",".join(map(str, tokens)).encode()).hexdigest()

    def lookup(self, prompt_tokens):
        """Return the cached prefix KV, or None"""
        h = self.compute_hash(prompt_tokens)
        if h in self.cache:
            self.cache.move_to_end(h)  # mark as most recently used
            return self.cache[h][1]
        return None

    def store(self, prompt_tokens, kv_cache):
        """Store a prefix KV, evicting the least recently used entry when full"""
        h = self.compute_hash(prompt_tokens)
        self.cache[h] = (list(prompt_tokens), kv_cache)
        self.cache.move_to_end(h)
        if len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)

    @staticmethod
    def prefix_match(new_tokens, cached_tokens):
        """Length of the common prefix of new_tokens and cached_tokens"""
        match_len = 0
        for a, b in zip(new_tokens, cached_tokens):
            if a != b:
                break
            match_len += 1
        return match_len
```
4.3 Sequence Parallelism
Ring Attention
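Ring Attention [10] shards the sequence dimension across devices arranged in a ring. Each device keeps its block of queries and passes its key/value block to the next device, overlapping communication with blockwise attention computation; partial results are merged with the same online-softmax rescaling that FlashAttention uses: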
```python
import math

import torch


def ring_attention_simulated(q_shards, k_shards, v_shards):
    """
    Single-process simulation of Ring Attention (non-causal).
    q_shards[i], k_shards[i], v_shards[i]: (batch, heads, seq_len // N, head_dim) on device i.
    In round r, device i holds the KV block originally owned by device (i - r) mod N;
    a real implementation sends/receives KV blocks between ring neighbours while computing.
    """
    n = len(q_shards)
    outputs = []
    for i in range(n):  # the work done by device i
        q = q_shards[i]
        scale = 1.0 / math.sqrt(q.shape[-1])
        acc = torch.zeros_like(q)                           # unnormalized output
        row_max = torch.full(q.shape[:-1], float("-inf"))   # running max per query
        row_sum = torch.zeros(q.shape[:-1])                 # running softmax denominator
        for r in range(n):
            src = (i - r) % n                               # KV block received in round r
            scores = q @ k_shards[src].transpose(-2, -1) * scale
            new_max = torch.maximum(row_max, scores.amax(dim=-1))
            correction = torch.exp(row_max - new_max)       # rescale earlier partial results
            p = torch.exp(scores - new_max.unsqueeze(-1))
            acc = acc * correction.unsqueeze(-1) + p @ v_shards[src]
            row_sum = row_sum * correction + p.sum(dim=-1)
            row_max = new_max
        outputs.append(acc / row_sum.unsqueeze(-1))
    return torch.cat(outputs, dim=2)
```
Ulysses Attention
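Ulysses Attention (DeepSpeed-Ulysses) instead uses all-to-all communication to switch from sequence sharding to head sharding, so that each device computes full-sequence attention for a subset of heads: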
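```python
import math

import torch


def ulysses_attention_simulated(q_shards, k_shards, v_shards):
    """
    Single-process simulation of DeepSpeed-Ulysses sequence parallelism.
    Inputs: per-device shards (batch, heads, seq_len // N, head_dim); heads divisible by N.
    Parallelism is limited by the number of heads, not by sequence length.
    """
    n = len(q_shards)

    def all_to_all_seq_to_head(shards):
        # After the all-to-all, device j holds the full sequence for head group j
        full = torch.cat(shards, dim=2)           # (batch, heads, seq_len, head_dim)
        return list(torch.chunk(full, n, dim=1))  # split by heads

    q_h, k_h, v_h = (all_to_all_seq_to_head(s) for s in (q_shards, k_shards, v_shards))
    out_h = []
    for q, k, v in zip(q_h, k_h, v_h):
        # Each device runs ordinary full-sequence attention on heads / N heads
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        out_h.append(torch.softmax(scores, dim=-1) @ v)
    # Second all-to-all: back to sequence sharding
    full_out = torch.cat(out_h, dim=1)
    return list(torch.chunk(full_out, n, dim=2))
```
4.4 Tensor-Parallel Inference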
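The sharding pattern is easiest to see on a two-layer MLP $Y = \mathrm{GeLU}(XA)\,B$ split over two devices, in the style of Megatron-LM: $A$ is split by columns and $B$ by rows, so the nonlinearity is applied locally and a single all-reduce sums the partial outputs:

$$
A = \begin{bmatrix} A_1 & A_2 \end{bmatrix}, \quad
B = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}, \qquad
Y = \underbrace{\mathrm{GeLU}(XA_1)\,B_1}_{\text{device 1}} + \underbrace{\mathrm{GeLU}(XA_2)\,B_2}_{\text{device 2}}
$$

Splitting $A$ by rows instead would require a synchronization before the nonlinearity, because $\mathrm{GeLU}(X_1A_1 + X_2A_2) \neq \mathrm{GeLU}(X_1A_1) + \mathrm{GeLU}(X_2A_2)$.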
```python
import torch
import torch.nn.functional as F


class TensorParallelMLP:
    """
    Single-process simulation of tensor-parallel inference for an MLP block:
    W_up is split along its output dim (column parallel), W_down along its input dim
    (row parallel), so the devices need only one all-reduce per block.
    """

    def __init__(self, w_up, w_down, num_devices):
        # nn.Linear layout (y = x @ W^T): w_up (ffn_dim, hidden), w_down (hidden, ffn_dim)
        self.num_devices = num_devices
        self.up_shards = torch.chunk(w_up, num_devices, dim=0)
        self.down_shards = torch.chunk(w_down, num_devices, dim=1)

    def forward(self, x):
        """x is replicated on every device"""
        partials = []
        for i in range(self.num_devices):  # in practice each iteration runs on device i
            h_i = F.gelu(x @ self.up_shards[i].T)          # local slice of the hidden activations
            partials.append(h_i @ self.down_shards[i].T)   # partial sum of the output
        # All-reduce (sum) combines the partial outputs
        return torch.stack(partials).sum(dim=0)
```
5. Long-Context Inference
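5.1 Positional Encoding Interpolation
Extrapolating positional encodings beyond the training length is a key challenge for long context. NTK-aware scaling of RoPE [11] is an effective method that needs no fine-tuning: it enlarges the RoPE base so that high-frequency dimensions are barely changed while low-frequency dimensions are interpolated: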
```python
import torch


def ntk_scaled_inv_freq(head_dim, scale, base=10000.0):
    """
    NTK-aware scaling: enlarge the RoPE base instead of compressing positions.
    base' = base * scale^(d / (d - 2))
    High-frequency components (short range) are scaled little: the highest stays unchanged.
    Low-frequency components (long range) are scaled most: the lowest is divided by scale.
    """
    new_base = base * scale ** (head_dim / (head_dim - 2))
    return 1.0 / (new_base ** (torch.arange(0, head_dim, 2).float() / head_dim))


def apply_rope(x, position_ids, inv_freq):
    """
    Apply RoPE in the 'rotate half' layout.
    x: (..., seq_len, head_dim); position_ids: (seq_len,)
    """
    angles = torch.outer(position_ids.float(), inv_freq)  # (seq_len, head_dim / 2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    # Rotate each (x1_i, x2_i) pair by theta_i: complex multiplication by exp(i * theta)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```
5.2 Sparse Attention
Sliding Window Attention
```python
import math

import torch
import torch.nn as nn


class SlidingWindowAttention(nn.Module):
    """
    Causal sliding-window attention (as in Mistral) plus optional global tokens
    (in the style of Longformer)
    """

    def __init__(self, window_size=4096):
        super().__init__()
        self.window_size = window_size

    def forward(self, q, k, v, is_global_token=None):
        """
        q, k, v: (batch, heads, seq_len, head_dim)
        is_global_token: optional bool tensor (seq_len,) marking global tokens
        """
        seq_len = q.shape[2]
        i = torch.arange(seq_len, device=q.device).unsqueeze(1)  # query positions
        j = torch.arange(seq_len, device=q.device).unsqueeze(0)  # key positions
        # 1. Local causal window: the last window_size tokens, including itself
        allowed = (j <= i) & (i - j < self.window_size)
        if is_global_token is not None:
            g = is_global_token.to(q.device)
            # Global tokens attend to all positions and are attended to by all
            allowed = allowed | g.unsqueeze(1) | g.unsqueeze(0)
        # 2. Masked attention (real kernels skip masked blocks instead of building a mask)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        scores = scores.masked_fill(~allowed, float("-inf"))
        return torch.softmax(scores, dim=-1) @ v
```
HiPPO State Space Models
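HiPPO (High-order Polynomial Projection Operators) compresses the entire input history online into a fixed-size state: the coefficients of its projection onto orthogonal polynomials (Legendre polynomials in the LegS variant). It underlies state space models such as S4, which process long sequences in linear time instead of attention's quadratic cost: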
```python
import torch


def hippo_legs_matrices(N):
    """
    HiPPO-LegS matrices for the dynamics dc/dt = -(1/t) A c(t) + (1/t) B f(t):
    A[n, k] = sqrt(2n+1) sqrt(2k+1) if n > k, n + 1 if n == k, else 0;  B[n] = sqrt(2n+1)
    """
    n = torch.arange(N, dtype=torch.float64)
    r = torch.sqrt(2 * n + 1)
    A = torch.tril(torch.outer(r, r), diagonal=-1) + torch.diag(n + 1)
    return A, r.clone()


def hippo_legs_compress(f, N=64):
    """
    Compress a 1-D signal f (length L) online into N Legendre coefficients,
    using a bilinear-style discretization of the LegS ODE at step k:
        (I + A / 2k) c_k = (I - A / 2k) c_{k-1} + (1 / k) B f_k
    The state size N stays fixed however long the history grows.
    """
    A, B = hippo_legs_matrices(N)
    I = torch.eye(N, dtype=torch.float64)
    c = torch.zeros(N, dtype=torch.float64)
    for k in range(1, len(f) + 1):
        lhs = I + A / (2 * k)
        rhs = (I - A / (2 * k)) @ c + B * float(f[k - 1]) / k
        c = torch.linalg.solve(lhs, rhs)
    return c
```
5.3 Memory-Efficient Attention Variants
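FlashAttention performs the same $O(N^2 d)$ FLOPs as standard attention; its gain is in memory traffic. The analysis in the FlashAttention paper (Dao et al., 2022), for sequence length $N$, head dimension $d$, and on-chip SRAM size $M$ with $d \le M \le Nd$, counts HBM accesses as:

$$
\text{standard attention: } \Theta\left(Nd + N^2\right), \qquad
\text{FlashAttention: } \Theta\left(\frac{N^2 d^2}{M}\right)
$$

The paper notes that for typical head dimensions (64 to 128) and SRAM sizes (around 100 KB), $d^2$ is many times smaller than $M$, so the tiled algorithm needs many times fewer HBM accesses.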
```python
import math

import torch


def flash_attention_reference(q, k, v, block_size=128, causal=True):
    """
    Pure-PyTorch reference of the FlashAttention algorithm: tiling + online softmax.
    Never materializes the full N x N score matrix; the real kernel keeps each tile
    in on-chip SRAM to cut HBM traffic. q, k, v: (B, H, N, D)
    """
    B, H, N, D = q.shape
    scale = 1.0 / math.sqrt(D)
    output = torch.zeros_like(q)
    for i in range(0, N, block_size):  # outer loop over query blocks
        q_blk = q[:, :, i:i + block_size] * scale
        last_row = min(i + block_size, N) - 1
        rows = torch.arange(i, last_row + 1, device=q.device)
        acc = torch.zeros_like(q_blk)                                      # unnormalized output
        m = torch.full(q_blk.shape[:-1], float("-inf"), device=q.device)   # running row max
        l = torch.zeros(q_blk.shape[:-1], device=q.device)                 # running row sum
        for j in range(0, N, block_size):  # inner loop over key/value blocks
            if causal and j > last_row:
                break  # this block lies entirely in the future
            k_blk, v_blk = k[:, :, j:j + block_size], v[:, :, j:j + block_size]
            s = q_blk @ k_blk.transpose(-2, -1)
            if causal:
                cols = torch.arange(j, j + k_blk.shape[2], device=q.device)
                s = s.masked_fill(cols[None, :] > rows[:, None], float("-inf"))
            # Stable online softmax: rescale previous partial results to the new row max
            m_new = torch.maximum(m, s.amax(dim=-1))
            p = torch.exp(s - m_new.unsqueeze(-1))
            correction = torch.exp(m - m_new)
            l = l * correction + p.sum(dim=-1)
            acc = acc * correction.unsqueeze(-1) + p @ v_blk
            m = m_new
        # Normalize once per query block
        output[:, :, i:i + block_size] = acc / l.unsqueeze(-1)
    return output
```
5.4 Context Length Extrapolation Techniques
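```python
import math

import torch


def yarn_inv_freq(head_dim, scale, original_max_len, base=10000.0, alpha=1.0, beta=32.0):
    """
    YaRN (Yet another RoPE extensioN): "NTK-by-parts" interpolation + attention temperature.
    r = number of full rotations a RoPE dimension completes within the original context:
    r < alpha -> interpolate fully (divide the frequency by scale); r > beta -> unchanged.
    Returns (new inv_freq, multiplier for the attention logits).
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    r = original_max_len * inv_freq / (2 * math.pi)
    gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)
    new_inv_freq = (1 - gamma) * inv_freq / scale + gamma * inv_freq
    # sqrt(1/t) = 0.1 * ln(scale) + 1 scales both q and k, i.e. the logits by 1/t
    logit_scale = (0.1 * math.log(scale) + 1.0) ** 2
    return new_inv_freq, logit_scale


def longrope_inv_freq(head_dim, rescale_factors, base=10000.0):
    """
    LongRoPE: non-uniform interpolation with one rescale factor per RoPE dimension
    (found by evolutionary search in the paper; passed in here). If every factor equals
    the extension ratio, this reduces to plain Position Interpolation.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    return inv_freq / rescale_factors


def self_extend_relative_positions(seq_len, neighbor_window=512, group_size=4):
    """
    Self-Extend: nearby tokens keep exact relative positions; distant tokens use grouped
    (floor-divided) positions, shifted so the two ranges meet at the window boundary.
    The largest relative position is about neighbor_window + seq_len / group_size,
    so it stays inside the range seen in pretraining.
    Returns a (seq_len, seq_len) matrix; -1 marks masked future positions.
    """
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    rel = i - j
    grouped = i // group_size - j // group_size + neighbor_window - neighbor_window // group_size
    rel = torch.where(rel < neighbor_window, rel, grouped)
    return rel.masked_fill(j > i, -1)
```
6. Practical Guide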
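6.1 Choosing an Inference Engine
| Engine | Strengths | Weaknesses | Best for |
|---|---|---|---|
| vLLM | PagedAttention, Continuous Batching, active development | Broad feature set, but some optimizations trail specialized engines | General-purpose serving, high concurrency |
| TensorRT-LLM | Deep optimization for NVIDIA GPUs such as H100; top performance | NVIDIA-only; complex configuration | Production high-performance inference |
| llama.cpp | Runs on CPU and GPU; GGUF support; cross-platform | Slower than dedicated GPU serving engines | Edge deployment, local inference experiments |
| SGLang | RadixAttention; strong prefix caching | Relatively new | Long-context, high prefix-reuse workloads |
| Ollama | Easy to use; local deployment | Limited optimization | Local development and testing |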
```python
# vLLM inference example
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=2,          # shard across 2 GPUs
    gpu_memory_utilization=0.9,
    max_model_len=8192,
    enable_prefix_caching=True,      # automatic prefix caching
)
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=512,
)
outputs = llm.generate(["Hello, how are you?"], sampling_params)
print(outputs[0].outputs[0].text)

# TensorRT-LLM example (high-level LLM API in recent releases; check your version's docs)
# from tensorrt_llm import LLM as TRTLLM
# llm = TRTLLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
# outputs = llm.generate(["Hello, how are you?"])
```
6.2 Service Deployment Best Practices
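```yaml
# docker-compose.yml for vLLM
version: '3.8'
services:
  vllm:
    image: vllm/vllm-openai:latest
    ports:
      - "8000:8000"
    volumes:
      - ~/.cache/huggingface:/root/.cache/huggingface   # reuse downloaded weights
    environment:
      - HUGGING_FACE_HUB_TOKEN=${HF_TOKEN}             # needed for gated models such as Llama 3
    ipc: host                                          # shared memory for tensor parallelism
    # The image runs vLLM's OpenAI-compatible server, so settings are CLI arguments;
    # CUDA graphs stay enabled unless --enforce-eager is passed
    command: >
      --model meta-llama/Meta-Llama-3-8B-Instruct
      --tensor-parallel-size 2
      --gpu-memory-utilization 0.9
      --max-model-len 8192
      --enable-prefix-caching
      --host 0.0.0.0 --port 8000
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 2
              capabilities: [gpu]

# Or keep the settings in a Kubernetes ConfigMap (excerpt):
# apiVersion: v1
# kind: ConfigMap
# metadata:
#   name: vllm-config
# data:
#   model: "meta-llama/Meta-Llama-3-70B-Instruct"
#   tensor-parallel-size: "4"
#   gpu-memory-utilization: "0.9"
```
6.3 Performance Benchmarking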
```python
import time

import torch
from vllm import LLM, SamplingParams


def benchmark_throughput(model_path, num_requests=1000, max_tokens=256):
    """Offline throughput: the engine batches all requests itself (continuous batching)"""
    llm = LLM(model=model_path, tensor_parallel_size=2)
    prompts = ["Write a short story." for _ in range(num_requests)]
    params = SamplingParams(max_tokens=max_tokens)
    llm.generate(["warmup"], params)  # warm-up

    start = time.time()
    outputs = llm.generate(prompts, params)
    elapsed = time.time() - start

    num_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    print(f"Total time: {elapsed:.2f}s")
    print(f"Throughput: {num_requests / elapsed:.2f} req/s, {num_tokens / elapsed:.1f} tok/s")
    # elapsed / num_requests is amortized time per request, not per-request latency
    return {"req_per_s": num_requests / elapsed, "tok_per_s": num_tokens / elapsed}


def benchmark_memory(model_path, max_model_len):
    """GPU memory after loading"""
    llm = LLM(model=model_path, max_model_len=max_model_len, gpu_memory_utilization=0.95)
    # vLLM preallocates KV Cache up to gpu_memory_utilization, so reserved memory is near
    # that fraction by design. If the model runs in separate worker processes
    # (e.g. tensor parallelism), these counters miss it; use nvidia-smi instead.
    memory_allocated = torch.cuda.memory_allocated() / 1e9  # GB
    memory_reserved = torch.cuda.memory_reserved() / 1e9
    print(f"Memory allocated: {memory_allocated:.2f} GB")
    print(f"Memory reserved: {memory_reserved:.2f} GB")
    return {"allocated": memory_allocated, "reserved": memory_reserved}


def benchmark_long_context(model_path, context_lengths, max_tokens=64):
    """Latency vs. prompt length, using one engine sized for the longest context"""
    llm = LLM(model=model_path, max_model_len=max(context_lengths))
    params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)  # always emit max_tokens
    results = []
    for ctx_len in context_lengths:
        prompt = "Hello " * (ctx_len // 2)  # roughly ctx_len / 2 tokens, leaving room for output
        start = time.time()
        llm.generate([prompt], params)
        latency = time.time() - start
        results.append({
            "context_len": ctx_len,
            "latency": latency,
            "tokens_per_sec": max_tokens / latency,  # end-to-end, includes prefill
        })
    return results
```
6.4 Recommended Optimization Configurations
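```python
# Inference optimization configuration templates

# 1. vLLM (keyword arguments accepted by vllm.LLM)
vllm_config = {
    # Memory
    "gpu_memory_utilization": 0.95,  # high utilization leaves more room for KV Cache
    "enable_prefix_caching": True,
    # Throughput
    "max_num_batched_tokens": 8192,  # token budget per scheduling step
    "max_num_seqs": 256,             # max concurrent sequences
    "enforce_eager": False,          # keep CUDA graphs enabled
    # Long context
    "max_model_len": 32768,          # adjust to your needs
    # Quantization (optional): load an AWQ-quantized checkpoint and set
    # "quantization": "awq",
}

# 2. TensorRT-LLM: illustrative build parameters for a Llama-3-8B-like model
#    (descriptive names, not TensorRT-LLM's exact config schema)
trt_config = {
    "precision": "fp16",             # or fp8 on Hopper
    "tensor_parallel": 2,
    "num_layers": 32,
    "num_heads": 32,
    "hidden_size": 4096,
    "vocab_size": 128256,
    "max_batch_size": 64,
    "max_input_len": 4096,
    "max_output_len": 512,
    "use_gpt_attention": True,       # fused attention plugin
    "remove_input_padding": True,    # pack sequences without padding
}

# 3. llama.cpp (parameters of llama-cpp-python's Llama class)
llama_cpp_config = {
    "model_path": "./model.Q4_K_M.gguf",
    "n_ctx": 8192,
    "n_gpu_layers": 35,              # layers offloaded to the GPU
    "n_threads": 8,
    "n_batch": 512,                  # prompt processing batch size
    "use_mlock": True,               # lock memory to prevent swapping
    "use_mmap": True,                # memory-map the model file
    "rope_freq_base": 1000000,       # only if the model was trained with this RoPE base
                                     # (0 = read it from the GGUF file)
}
```
7. Summary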
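LLM inference optimization is a systems-engineering problem that spans memory management, compute optimization, and parallelism:
| Optimization area | Key techniques | Typical gain |
|---|---|---|
| KV Cache | PagedAttention, prefix caching | 2-4× more concurrency |
| Quantization | GPTQ, AWQ, FP8 | 2-4× less memory |
| Speculative decoding | EAGLE, Medusa | 2-3× faster generation |
| Batching | Continuous Batching | 5-10× higher throughput |
| Parallelism | Ring Attention, tensor parallelism | Near-linear scaling |
| Long context | NTK scaling, sparse attention | Longer supported sequences |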
Recommended paths:
- Fast launch: vLLM + AWQ quantization
- Performance first: TensorRT-LLM + FP16/FP8
- Long context: SGLang + prefix caching
- Edge deployment: llama.cpp + GGUF quantization
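References
Further reading:
- Model Quantization Techniques: more detail on quantization methods
- Transformer Scaling Laws: how model scale relates to performance
- RAG (Retrieval-Augmented Generation): LLM application architecture
- LLM Evaluation: methods for evaluating model performance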
Footnotes
1. Kwon W, Li Z, Zhuang S, et al. Efficient memory management for large language model serving with PagedAttention. SOSP, 2023. arXiv:2309.06180
2. Xiao G, Tian Y, Chen B, et al. Efficient streaming language models with attention sinks. ICLR, 2024. arXiv:2309.17453
3. Zhang Z, Sheng Y, Zhou T, et al. H2O: Heavy-Hitter Oracle for efficient generative inference of large language models. NeurIPS, 2023. arXiv:2306.14048
4. Frantar E, Ashkboos S, Hoefler T, Alistarh D. GPTQ: Accurate post-training quantization for generative pre-trained transformers. ICLR, 2023. arXiv:2210.17323
5. Lin J, Tang J, Tang H, et al. AWQ: Activation-aware weight quantization for LLM compression and acceleration. MLSys, 2024. arXiv:2306.00978
6. Leviathan Y, Kalman M, Matias Y. Fast inference from transformers via speculative decoding. ICML, 2023. arXiv:2211.17192
7. Li Y, Wei F, Zhang C, Zhang H. EAGLE: Speculative sampling requires rethinking feature uncertainty. ICML, 2024. arXiv:2401.15077
8. Cai T, Li Y, Geng Z, et al. Medusa: Simple LLM inference acceleration framework with multiple decoding heads. ICML, 2024. arXiv:2401.10774
9. Yu G-I, Jeong J S, Kim G-W, et al. Orca: A distributed serving system for transformer-based generative models. OSDI, 2022.
10. Liu H, Zaharia M, Abbeel P. Ring attention with blockwise transformers for near-infinite context. arXiv:2310.01889, 2023.
11. Peng B, Quesnelle J, Fan H, Shippole E. YaRN: Efficient context window extension of large language models. arXiv:2309.00071, 2023.