连续批处理与Titan推理系统
1. 概述
连续批处理(Continuous Batching)是LLM推理系统中提升GPU利用率的核心技术。Titan推理系统则代表了2025年SOTA的推理优化方案,通过融合连续批处理、投机解码和智能调度来实现极致性能。
核心问题
传统批处理的局限性:
- 静态批处理:所有请求必须同时完成
- GPU空闲:短请求等待长请求
- 资源浪费:无法动态调整
连续批处理的解决方案
动态批处理:新请求可替换已完成请求
结果:GPU利用率从30%提升至90%+
2. 批处理范式对比
2.1 静态批处理
时间 →
┌─────────────────────────────────────────────────────────┐
│ Request A (10 tokens) │████████│ │
│ Request B (50 tokens) │████████░░░░░░░░░░░░░░░░░░░░░░│
│ Request C (30 tokens) │████████░░░░░░░░░░░░░░░░░░░░░░│
└─────────────────────────────────────────────────────────┘
↑
GPU空闲等待
利用率: ~30%
2.2 连续批处理
时间 →
┌─────────────────────────────────────────────────────────┐
│ Batch 1: A, B, C │
│ │▓▓│▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓│
│ │
│ Batch 2: D, E, B, C ← A完成,D/E加入 │
│ │▓▓│▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓│ │
│ │
│ Batch 3: F, G, B, C ← D/E完成,F/G加入 │
│ │▓▓│▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓│ │
└─────────────────────────────────────────────────────────┘
↑
GPU持续工作
利用率: ~90%
3. 连续批处理实现
3.1 核心算法
from dataclasses import dataclass, field
from typing import List, Optional
import heapq
@dataclass
class Request:
"""推理请求"""
request_id: str
prompt_ids: List[int]
max_new_tokens: int
generated_ids: List[int] = field(default_factory=list)
is_finished: bool = False
@property
def total_length(self) -> int:
return len(self.prompt_ids) + len(self.generated_ids)
@property
def remaining_tokens(self) -> int:
return self.max_new_tokens - len(self.generated_ids)
class ContinuousBatcher:
"""
连续批处理调度器
核心思想:
1. 保持batch动态更新
2. 已完成请求立即替换为新请求
3. 最大化GPU利用率
"""
def __init__(
self,
model, # 推理模型
max_batch_size: int = 32,
max_sequence_length: int = 4096,
prefill_chunk_size: int = 512
):
self.model = model
self.max_batch_size = max_batch_size
self.max_sequence_length = max_sequence_length
self.prefill_chunk_size = prefill_chunk_size
# 请求管理
self.pending_requests: List[Request] = [] # 等待调度的请求
self.running_requests: List[Request] = [] # 正在运行的请求
self.completed_requests: List[Request] = [] # 已完成的请求
# 批处理状态
self.prompt_dicts = [] # 当前batch的prompt
self.max_running_length = 0 # 当前batch最大长度
def add_request(self, request: Request):
"""添加新请求"""
self.pending_requests.append(request)
def _prepare_batch(self) -> bool:
"""
准备下一个batch
Returns:
是否准备好进行推理
"""
# 计算当前batch的最大长度
if not self.running_requests:
self.max_running_length = 0
else:
self.max_running_length = max(
req.total_length for req in self.running_requests
)
# 计算可以加入的新请求数量
available_slots = self.max_batch_size - len(self.running_requests)
# 填充新请求(考虑长度约束)
while (self.pending_requests and
available_slots > 0):
next_req = self.pending_requests[0]
# 检查长度约束
new_total_length = max(
self.max_running_length,
len(next_req.prompt_ids)
) + 1 # 至少要生成1个token
if new_total_length <= self.max_sequence_length:
# 可以加入
self.running_requests.append(
self.pending_requests.pop(0)
)
available_slots -= 1
else:
# 长度超出,标记失败
next_req.is_finished = True
next_req.error = "Sequence too long"
self.completed_requests.append(
self.pending_requests.pop(0)
)
return len(self.running_requests) > 0
def _execute_batch(self):
"""执行一个推理步骤"""
if not self.running_requests:
return
# 1. 准备输入
# 检测哪些请求需要prefill
prefill_indices = []
decode_indices = []
for i, req in enumerate(self.running_requests):
if len(req.generated_ids) == 0:
prefill_indices.append(i)
else:
decode_indices.append(i)
# 2. 合并prefill请求(chunk处理)
if prefill_indices:
self._execute_prefill(prefill_indices)
# 3. 执行decode
if decode_indices:
self._execute_decode(decode_indices)
def _execute_prefill(self, indices: List[int]):
"""执行prefill阶段"""
# 将多个prefill请求合并为一个batch
# 实际实现中可能需要chunk处理
prompts = [
self.running_requests[i].prompt_ids
for i in indices
]
# 执行prefill
outputs = self.model.prefill(prompts)
# 更新请求状态
for idx, output in zip(indices, outputs):
req = self.running_requests[idx]
req.generated_ids.append(output.token_id)
# 检查是否结束
if output.token_id == self.model.eos_token_id:
req.is_finished = True
def _execute_decode(self, indices: List[int]):
"""执行decode阶段"""
# 获取当前running的最大长度
max_len = max(
self.running_requests[i].total_length
for i in indices
)
# 获取当前token
current_tokens = [
self.running_requests[i].generated_ids[-1]
for i in indices
]
# 执行decode
outputs = self.model.decode(current_tokens)
# 更新请求状态
finished = []
for idx, output in zip(indices, outputs):
req = self.running_requests[idx]
req.generated_ids.append(output.token_id)
# 检查结束条件
if (output.token_id == self.model.eos_token_id or
len(req.generated_ids) >= req.max_new_tokens):
req.is_finished = True
finished.append(idx)
# 移除已完成的请求
for idx in sorted(finished, reverse=True):
req = self.running_requests.pop(idx)
self.completed_requests.append(req)
def step(self) -> List[Request]:
"""
执行一个完整的推理步骤
Returns:
本步完成的所有请求
"""
# 1. 准备batch
self._prepare_batch()
# 2. 执行推理
self._execute_batch()
# 3. 返回完成的请求
completed = self.completed_requests
self.completed_requests = []
return completed
def run(self):
"""持续运行直到所有请求完成"""
while self.pending_requests or self.running_requests:
completed = self.step()
yield completed3.2 Chunked Prefill
长prompt的prefill非常耗时,会阻塞其他请求。Chunked Prefill解决方案:
def chunked_prefill(
model,
prompt_ids: List[int],
kv_cache: Any,
chunk_size: int = 512
):
"""
分块处理长prompt的prefill
避免长prompt独占整个batch
"""
num_chunks = (len(prompt_ids) + chunk_size - 1) // chunk_size
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(prompt_ids))
chunk_ids = prompt_ids[start:end]
# 执行chunk的prefill
outputs, kv_cache = model.prefill_step(
chunk_ids,
kv_cache,
is_first_chunk=(chunk_idx == 0),
is_last_chunk=(chunk_idx == num_chunks - 1)
)
yield chunk_idx, num_chunks, outputs
class ChunkedContinuousBatcher(ContinuousBatcher):
"""
支持Chunked Prefill的连续批处理
"""
def __init__(self, *args, prefill_chunk_size: int = 512, **kwargs):
super().__init__(*args, **kwargs)
self.prefill_chunk_size = prefill_chunk_size
self.chunking_requests = {} # 正在chunk prefill的请求
def _execute_prefill(self, indices: List[int]):
"""执行分块prefill"""
new_prefill = []
chunking = []
for idx in indices:
req = self.running_requests[idx]
if req.request_id not in self.chunking_requests:
# 新请求,开始chunk prefill
self.chunking_requests[req.request_id] = {
'request': req,
'chunk_idx': 0,
'num_chunks': (len(req.prompt_ids) +
self.prefill_chunk_size - 1) //
self.prefill_chunk_size,
'position': idx
}
chunking.append(req.request_id)
else:
# 正在chunk的请求
chunking.append(req.request_id)
# 优先处理正在chunk的请求
# 它们有更高的优先级(已部分完成)
ordered_indices = []
for rid in chunking:
info = self.chunking_requests[rid]
if info['chunk_idx'] > 0: # 优先已完成部分chunk的
ordered_indices.append(info['position'])
# 执行chunk prefill(简化版)
for rid in chunking:
info = self.chunking_requests[rid]
req = info['request']
chunk_idx = info['chunk_idx']
start = chunk_idx * self.prefill_chunk_size
end = min(start + self.prefill_chunk_size, len(req.prompt_ids))
# 执行chunk
output = model.prefill_chunk(
req.prompt_ids[start:end],
info.get('kv_cache', None)
)
info['kv_cache'] = output.kv_cache
info['chunk_idx'] += 1
# 检查是否完成
if info['chunk_idx'] >= info['num_chunks']:
# Prefill完成,生成第一个token
req.generated_ids.append(output.first_token)
del self.chunking_requests[rid]
# 检查是否结束
if output.first_token == model.eos_token_id:
req.is_finished = True4. Titan推理系统
4.1 Titan架构概览
Titan是SOTA的LLM推理系统,融合多种优化技术:
┌─────────────────────────────────────────────────────────────┐
│ Request Scheduler │
│ (连续批处理 + 智能调度 + 优先级队列) │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Prefill Engine │
│ (Chunked Prefill + 算子融合) │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Decode Engine │
│ (投机解码 + KV Cache管理 + 连续批处理) │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Memory Manager │
│ (PagedAttention + 分页KV Cache + 量化) │
└─────────────────────────────────────────────────────────────┘
4.2 PagedAttention
Titan使用PagedAttention管理KV Cache:
class PagedKVCache:
"""
PagedAttention的KV Cache管理
将KV Cache组织为固定大小的block
类似操作系统的分页内存管理
"""
def __init__(
self,
block_size: int = 16, # 每个block的token数
num_blocks: int = 1024, # 总block数
num_heads: int = 32,
head_dim: int = 128
):
self.block_size = block_size
self.num_blocks = num_blocks
self.num_heads = num_heads
self.head_dim = head_dim
# 预分配的KV block池
self.kv_blocks = torch.zeros(
num_blocks, num_heads, block_size, head_dim
)
self.block_allocated = [False] * num_blocks
# 请求的block映射
self.request_blocks = {} # request_id -> [block_ids]
def allocate(self, request_id: str, num_tokens: int) -> List[int]:
"""为请求分配block"""
num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
block_ids = []
for i in range(num_blocks_needed):
# 寻找空闲block
for j, allocated in enumerate(self.block_allocated):
if not allocated:
self.block_allocated[j] = True
block_ids.append(j)
break
self.request_blocks[request_id] = block_ids
return block_ids
def write(
self,
request_id: str,
kv: torch.Tensor, # [num_heads, num_tokens, head_dim]
position: int
):
"""写入KV数据"""
block_ids = self.request_blocks[request_id]
for i, block_id in enumerate(block_ids):
start = i * self.block_size
end = min(start + self.block_size, kv.shape[1])
self.kv_blocks[block_id, :, :end-start, :] = kv[:, start:end, :]
def read(
self,
request_id: str,
positions: List[int]
) -> torch.Tensor:
"""读取KV数据"""
block_ids = self.request_blocks[request_id]
# 收集所有需要的block
kv_data = []
for pos in positions:
block_idx = pos // self.block_size
offset = pos % self.block_size
if block_idx < len(block_ids):
kv_data.append(
self.kv_blocks[block_ids[block_idx], :, offset, :]
)
return torch.stack(kv_data, dim=1)
def free(self, request_id: str):
"""释放请求的block"""
if request_id in self.request_blocks:
for block_id in self.request_blocks[request_id]:
self.block_allocated[block_id] = False
del self.request_blocks[request_id]4.3 投机解码集成
Titan集成了投机解码加速:
class TitanInferenceEngine:
"""
Titan推理引擎
整合连续批处理、PagedAttention和投机解码
"""
def __init__(
self,
model,
draft_model, # 投机解码的小模型
max_batch_size: int = 32,
block_size: int = 16
):
self.model = model
self.draft_model = draft_model
# 批处理调度器
self.batcher = ContinuousBatcher(
model, max_batch_size
)
# Paged KV Cache
self.kv_cache = PagedKVCache(block_size=block_size)
# 投机解码调度
self.speculator = SpeculativeDecoder(
main_model=model,
draft_model=draft_model,
max_draft_tokens=6
)
def generate(
self,
prompts: List[str],
max_tokens: int = 100
) -> List[str]:
"""生成文本"""
# 添加请求
requests = [
Request(
request_id=str(i),
prompt_ids=self.tokenizer.encode(p),
max_new_tokens=max_tokens
)
for i, p in enumerate(prompts)
]
for req in requests:
self.batcher.add_request(req)
# 推理循环
while self.batcher.pending_requests or self.batcher.running_requests:
# 1. 连续批处理调度
self.batcher._prepare_batch()
# 2. 执行prefill或decode
for req in self.batcher.running_requests:
if len(req.generated_ids) == 0:
# Prefill
self._prefill_request(req)
else:
# 投机解码
self._speculative_decode(req)
# 3. 完成请求
completed = [
req for req in self.batcher.running_requests
if req.is_finished
]
for req in completed:
self.batcher.running_requests.remove(req)
self.batcher.completed_requests.append(req)
# 解码结果
return [
self.tokenizer.decode(req.generated_ids)
for req in self.batcher.completed_requests
]
def _prefill_request(self, request: Request):
"""Prefill请求"""
# 使用chunked prefill
output = self.model.prefill(
[request.prompt_ids],
kv_cache=self.kv_cache
)
request.generated_ids.append(output.token_id)
def _speculative_decode(self, request: Request):
"""投机解码"""
# 使用draft model生成多个候选
draft_tokens = self.speculator.speculate(
request.generated_ids[-1],
num_tokens=6
)
# 用main model验证
accepted = self.speculator.verify(
request.generated_ids + draft_tokens
)
# 更新generated_ids
request.generated_ids = accepted5. 调度策略
5.1 优先级调度
class PriorityScheduler:
"""
优先级调度器
根据请求类型和截止时间分配优先级
"""
def __init__(self):
self.queues = {
'critical': [], # 关键任务(实时性要求高)
'normal': [], # 普通任务
'batch': [] # 批处理任务(可延迟)
}
def add_request(self, request: Request, priority: str = 'normal'):
"""添加请求"""
heapq.heappush(
self.queues[priority],
(-request.remaining_tokens, request) # 短请求优先
)
def get_next_batch(self, batch_size: int) -> List[Request]:
"""获取下一批请求"""
batch = []
# 按优先级获取
for priority in ['critical', 'normal', 'batch']:
while self.queues[priority] and len(batch) < batch_size:
_, request = heapq.heappop(self.queues[priority])
batch.append(request)
return batch5.2 公平调度
class FairScheduler:
"""
公平调度器
确保各用户的请求都能得到处理
"""
def __init__(self, max_batch_size: int = 32):
self.max_batch_size = max_batch_size
self.user_credits = defaultdict(lambda: 1.0) # 每用户的信用
self.user_queues = defaultdict(list)
def add_request(self, request: Request, user_id: str):
"""添加请求"""
self.user_queues[user_id].append(request)
def get_next_batch(self) -> List[Request]:
"""获取公平分配的batch"""
batch = []
# 计算各用户可用slot
user_slots = {
uid: min(
len(queue),
int(self.max_batch_size * self.user_credits[uid])
)
for uid, queue in self.user_queues.items()
}
# 分配slot
for uid, slots in user_slots.items():
queue = self.user_queues[uid]
for _ in range(slots):
if queue:
batch.append(queue.pop(0))
self.user_credits[uid] *= 0.9 # 消耗信用
# 补充信用
for uid in self.user_credits:
self.user_credits[uid] = min(1.0, self.user_credits[uid] + 0.1)
return batch6. 性能对比
6.1 吞吐量对比
| 方法 | 吞吐量 (tokens/s) | 相对提升 |
|---|---|---|
| 无批处理 | 45 | 1.0x |
| 静态批处理 | 120 | 2.7x |
| 连续批处理 | 280 | 6.2x |
| Titan | 420 | 9.3x |
6.2 延迟对比
| 百分位 | 无批处理 (ms) | 连续批处理 (ms) | Titan (ms) |
|---|---|---|---|
| P50 | 120 | 150 | 80 |
| P90 | 350 | 400 | 180 |
| P99 | 800 | 900 | 350 |
6.3 GPU利用率
| 方法 | 平均GPU利用率 | 峰值GPU利用率 |
|---|---|---|
| 无批处理 | 25% | 40% |
| 静态批处理 | 45% | 70% |
| 连续批处理 | 75% | 95% |
| Titan | 88% | 98% |
7. 实践建议
7.1 配置推荐
# Titan配置推荐
titan_config = {
'max_batch_size': 32, # 最大batch size
'prefill_chunk_size': 512, # prefill分块大小
'block_size': 16, # PagedAttention块大小
'use_speculative': True, # 启用投机解码
'draft_model': 'llama-135m', # draft模型
'max_draft_tokens': 6, # 最大draft token数
}7.2 瓶颈诊断
def diagnose_inference_bottleneck(model, profiler):
"""
诊断推理瓶颈
"""
results = profiler.profile()
bottlenecks = {}
# 计算各阶段耗时
prefill_time = results['prefill'] / results['total']
decode_time = results['decode'] / results['total']
# 诊断
if prefill_time > 0.5:
bottlenecks['prefill'] = {
'issue': 'Prefill is dominant',
'suggestion': 'Increase prefill chunk size or parallelize'
}
if decode_time > 0.5:
bottlenecks['decode'] = {
'issue': 'Decode is dominant (memory bound)',
'suggestion': 'Use smaller precision or KV cache compression'
}
# 检查内存
if results['memory_utilization'] < 0.7:
bottlenecks['memory'] = {
'issue': 'Low memory utilization',
'suggestion': 'Increase batch size'
}
return bottlenecks8. 总结
连续批处理与Titan系统的核心贡献:
- 连续批处理:动态batch调度,最大化GPU利用率
- Chunked Prefill:避免长prompt阻塞
- PagedAttention:高效的KV Cache管理
- 投机解码:加速生成
- 智能调度:优先级和公平性保证