连续批处理与Titan推理系统

1. 概述

连续批处理(Continuous Batching)是LLM推理系统中提升GPU利用率的核心技术。Titan推理系统则代表了2025年SOTA的推理优化方案,通过融合连续批处理、投机解码和智能调度来实现极致性能。

核心问题

传统批处理的局限性:

  • 静态批处理:所有请求必须同时完成
  • GPU空闲:短请求等待长请求
  • 资源浪费:无法动态调整

连续批处理的解决方案

动态批处理:新请求可替换已完成请求

结果:GPU利用率从30%提升至90%+

2. 批处理范式对比

2.1 静态批处理

时间 →
┌─────────────────────────────────────────────────────────┐
│ Request A (10 tokens)  │████████│                         │
│ Request B (50 tokens)   │████████░░░░░░░░░░░░░░░░░░░░░░│
│ Request C (30 tokens)   │████████░░░░░░░░░░░░░░░░░░░░░░│
└─────────────────────────────────────────────────────────┘
                              ↑
                         GPU空闲等待

利用率: ~30%

2.2 连续批处理

时间 →
┌─────────────────────────────────────────────────────────┐
│ Batch 1: A, B, C                                       │
│         │▓▓│▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓│
│                                                          │
│ Batch 2: D, E, B, C  ← A完成,D/E加入                   │
│            │▓▓│▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓│            │
│                                                          │
│ Batch 3: F, G, B, C  ← D/E完成,F/G加入                 │
│               │▓▓│▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓│            │
└─────────────────────────────────────────────────────────┘
                              ↑
                         GPU持续工作

利用率: ~90%

3. 连续批处理实现

3.1 核心算法

from dataclasses import dataclass, field
from typing import List, Optional
import heapq
 
@dataclass
class Request:
    """推理请求"""
    request_id: str
    prompt_ids: List[int]
    max_new_tokens: int
    generated_ids: List[int] = field(default_factory=list)
    is_finished: bool = False
    
    @property
    def total_length(self) -> int:
        return len(self.prompt_ids) + len(self.generated_ids)
    
    @property
    def remaining_tokens(self) -> int:
        return self.max_new_tokens - len(self.generated_ids)
 
class ContinuousBatcher:
    """
    连续批处理调度器
    
    核心思想:
    1. 保持batch动态更新
    2. 已完成请求立即替换为新请求
    3. 最大化GPU利用率
    """
    
    def __init__(
        self,
        model,  # 推理模型
        max_batch_size: int = 32,
        max_sequence_length: int = 4096,
        prefill_chunk_size: int = 512
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        self.prefill_chunk_size = prefill_chunk_size
        
        # 请求管理
        self.pending_requests: List[Request] = []  # 等待调度的请求
        self.running_requests: List[Request] = []  # 正在运行的请求
        self.completed_requests: List[Request] = []  # 已完成的请求
        
        # 批处理状态
        self.prompt_dicts = []  # 当前batch的prompt
        self.max_running_length = 0  # 当前batch最大长度
        
    def add_request(self, request: Request):
        """添加新请求"""
        self.pending_requests.append(request)
    
    def _prepare_batch(self) -> bool:
        """
        准备下一个batch
        
        Returns:
            是否准备好进行推理
        """
        # 计算当前batch的最大长度
        if not self.running_requests:
            self.max_running_length = 0
        else:
            self.max_running_length = max(
                req.total_length for req in self.running_requests
            )
        
        # 计算可以加入的新请求数量
        available_slots = self.max_batch_size - len(self.running_requests)
        
        # 填充新请求(考虑长度约束)
        while (self.pending_requests and 
               available_slots > 0):
            next_req = self.pending_requests[0]
            
            # 检查长度约束
            new_total_length = max(
                self.max_running_length,
                len(next_req.prompt_ids)
            ) + 1  # 至少要生成1个token
            
            if new_total_length <= self.max_sequence_length:
                # 可以加入
                self.running_requests.append(
                    self.pending_requests.pop(0)
                )
                available_slots -= 1
            else:
                # 长度超出,标记失败
                next_req.is_finished = True
                next_req.error = "Sequence too long"
                self.completed_requests.append(
                    self.pending_requests.pop(0)
                )
        
        return len(self.running_requests) > 0
    
    def _execute_batch(self):
        """执行一个推理步骤"""
        if not self.running_requests:
            return
        
        # 1. 准备输入
        # 检测哪些请求需要prefill
        prefill_indices = []
        decode_indices = []
        
        for i, req in enumerate(self.running_requests):
            if len(req.generated_ids) == 0:
                prefill_indices.append(i)
            else:
                decode_indices.append(i)
        
        # 2. 合并prefill请求(chunk处理)
        if prefill_indices:
            self._execute_prefill(prefill_indices)
        
        # 3. 执行decode
        if decode_indices:
            self._execute_decode(decode_indices)
    
    def _execute_prefill(self, indices: List[int]):
        """执行prefill阶段"""
        # 将多个prefill请求合并为一个batch
        # 实际实现中可能需要chunk处理
        
        prompts = [
            self.running_requests[i].prompt_ids 
            for i in indices
        ]
        
        # 执行prefill
        outputs = self.model.prefill(prompts)
        
        # 更新请求状态
        for idx, output in zip(indices, outputs):
            req = self.running_requests[idx]
            req.generated_ids.append(output.token_id)
            
            # 检查是否结束
            if output.token_id == self.model.eos_token_id:
                req.is_finished = True
    
    def _execute_decode(self, indices: List[int]):
        """执行decode阶段"""
        # 获取当前running的最大长度
        max_len = max(
            self.running_requests[i].total_length
            for i in indices
        )
        
        # 获取当前token
        current_tokens = [
            self.running_requests[i].generated_ids[-1]
            for i in indices
        ]
        
        # 执行decode
        outputs = self.model.decode(current_tokens)
        
        # 更新请求状态
        finished = []
        for idx, output in zip(indices, outputs):
            req = self.running_requests[idx]
            req.generated_ids.append(output.token_id)
            
            # 检查结束条件
            if (output.token_id == self.model.eos_token_id or
                len(req.generated_ids) >= req.max_new_tokens):
                req.is_finished = True
                finished.append(idx)
        
        # 移除已完成的请求
        for idx in sorted(finished, reverse=True):
            req = self.running_requests.pop(idx)
            self.completed_requests.append(req)
    
    def step(self) -> List[Request]:
        """
        执行一个完整的推理步骤
        
        Returns:
            本步完成的所有请求
        """
        # 1. 准备batch
        self._prepare_batch()
        
        # 2. 执行推理
        self._execute_batch()
        
        # 3. 返回完成的请求
        completed = self.completed_requests
        self.completed_requests = []
        return completed
    
    def run(self):
        """持续运行直到所有请求完成"""
        while self.pending_requests or self.running_requests:
            completed = self.step()
            yield completed

3.2 Chunked Prefill

长prompt的prefill非常耗时,会阻塞其他请求。Chunked Prefill解决方案:

def chunked_prefill(
    model,
    prompt_ids: List[int],
    kv_cache: Any,
    chunk_size: int = 512
):
    """
    分块处理长prompt的prefill
    
    避免长prompt独占整个batch
    """
    num_chunks = (len(prompt_ids) + chunk_size - 1) // chunk_size
    
    for chunk_idx in range(num_chunks):
        start = chunk_idx * chunk_size
        end = min(start + chunk_size, len(prompt_ids))
        
        chunk_ids = prompt_ids[start:end]
        
        # 执行chunk的prefill
        outputs, kv_cache = model.prefill_step(
            chunk_ids, 
            kv_cache,
            is_first_chunk=(chunk_idx == 0),
            is_last_chunk=(chunk_idx == num_chunks - 1)
        )
        
        yield chunk_idx, num_chunks, outputs
 
 
class ChunkedContinuousBatcher(ContinuousBatcher):
    """
    支持Chunked Prefill的连续批处理
    """
    
    def __init__(self, *args, prefill_chunk_size: int = 512, **kwargs):
        super().__init__(*args, **kwargs)
        self.prefill_chunk_size = prefill_chunk_size
        self.chunking_requests = {}  # 正在chunk prefill的请求
    
    def _execute_prefill(self, indices: List[int]):
        """执行分块prefill"""
        new_prefill = []
        chunking = []
        
        for idx in indices:
            req = self.running_requests[idx]
            
            if req.request_id not in self.chunking_requests:
                # 新请求,开始chunk prefill
                self.chunking_requests[req.request_id] = {
                    'request': req,
                    'chunk_idx': 0,
                    'num_chunks': (len(req.prompt_ids) + 
                                   self.prefill_chunk_size - 1) // 
                                  self.prefill_chunk_size,
                    'position': idx
                }
                chunking.append(req.request_id)
            else:
                # 正在chunk的请求
                chunking.append(req.request_id)
        
        # 优先处理正在chunk的请求
        # 它们有更高的优先级(已部分完成)
        ordered_indices = []
        for rid in chunking:
            info = self.chunking_requests[rid]
            if info['chunk_idx'] > 0:  # 优先已完成部分chunk的
                ordered_indices.append(info['position'])
        
        # 执行chunk prefill(简化版)
        for rid in chunking:
            info = self.chunking_requests[rid]
            req = info['request']
            chunk_idx = info['chunk_idx']
            
            start = chunk_idx * self.prefill_chunk_size
            end = min(start + self.prefill_chunk_size, len(req.prompt_ids))
            
            # 执行chunk
            output = model.prefill_chunk(
                req.prompt_ids[start:end],
                info.get('kv_cache', None)
            )
            
            info['kv_cache'] = output.kv_cache
            info['chunk_idx'] += 1
            
            # 检查是否完成
            if info['chunk_idx'] >= info['num_chunks']:
                # Prefill完成,生成第一个token
                req.generated_ids.append(output.first_token)
                del self.chunking_requests[rid]
                
                # 检查是否结束
                if output.first_token == model.eos_token_id:
                    req.is_finished = True

4. Titan推理系统

4.1 Titan架构概览

Titan是SOTA的LLM推理系统,融合多种优化技术:

┌─────────────────────────────────────────────────────────────┐
│                      Request Scheduler                        │
│         (连续批处理 + 智能调度 + 优先级队列)                  │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│                      Prefill Engine                          │
│              (Chunked Prefill + 算子融合)                    │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│                      Decode Engine                           │
│        (投机解码 + KV Cache管理 + 连续批处理)                 │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│                    Memory Manager                            │
│         (PagedAttention + 分页KV Cache + 量化)                │
└─────────────────────────────────────────────────────────────┘

4.2 PagedAttention

Titan使用PagedAttention管理KV Cache:

class PagedKVCache:
    """
    PagedAttention的KV Cache管理
    
    将KV Cache组织为固定大小的block
    类似操作系统的分页内存管理
    """
    
    def __init__(
        self,
        block_size: int = 16,  # 每个block的token数
        num_blocks: int = 1024,  # 总block数
        num_heads: int = 32,
        head_dim: int = 128
    ):
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.head_dim = head_dim
        
        # 预分配的KV block池
        self.kv_blocks = torch.zeros(
            num_blocks, num_heads, block_size, head_dim
        )
        self.block_allocated = [False] * num_blocks
        
        # 请求的block映射
        self.request_blocks = {}  # request_id -> [block_ids]
        
    def allocate(self, request_id: str, num_tokens: int) -> List[int]:
        """为请求分配block"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        block_ids = []
        
        for i in range(num_blocks_needed):
            # 寻找空闲block
            for j, allocated in enumerate(self.block_allocated):
                if not allocated:
                    self.block_allocated[j] = True
                    block_ids.append(j)
                    break
        
        self.request_blocks[request_id] = block_ids
        return block_ids
    
    def write(
        self,
        request_id: str,
        kv: torch.Tensor,  # [num_heads, num_tokens, head_dim]
        position: int
    ):
        """写入KV数据"""
        block_ids = self.request_blocks[request_id]
        
        for i, block_id in enumerate(block_ids):
            start = i * self.block_size
            end = min(start + self.block_size, kv.shape[1])
            
            self.kv_blocks[block_id, :, :end-start, :] = kv[:, start:end, :]
    
    def read(
        self,
        request_id: str,
        positions: List[int]
    ) -> torch.Tensor:
        """读取KV数据"""
        block_ids = self.request_blocks[request_id]
        
        # 收集所有需要的block
        kv_data = []
        for pos in positions:
            block_idx = pos // self.block_size
            offset = pos % self.block_size
            
            if block_idx < len(block_ids):
                kv_data.append(
                    self.kv_blocks[block_ids[block_idx], :, offset, :]
                )
        
        return torch.stack(kv_data, dim=1)
    
    def free(self, request_id: str):
        """释放请求的block"""
        if request_id in self.request_blocks:
            for block_id in self.request_blocks[request_id]:
                self.block_allocated[block_id] = False
            del self.request_blocks[request_id]

4.3 投机解码集成

Titan集成了投机解码加速:

class TitanInferenceEngine:
    """
    Titan推理引擎
    整合连续批处理、PagedAttention和投机解码
    """
    
    def __init__(
        self,
        model,
        draft_model,  # 投机解码的小模型
        max_batch_size: int = 32,
        block_size: int = 16
    ):
        self.model = model
        self.draft_model = draft_model
        
        # 批处理调度器
        self.batcher = ContinuousBatcher(
            model, max_batch_size
        )
        
        # Paged KV Cache
        self.kv_cache = PagedKVCache(block_size=block_size)
        
        # 投机解码调度
        self.speculator = SpeculativeDecoder(
            main_model=model,
            draft_model=draft_model,
            max_draft_tokens=6
        )
    
    def generate(
        self,
        prompts: List[str],
        max_tokens: int = 100
    ) -> List[str]:
        """生成文本"""
        # 添加请求
        requests = [
            Request(
                request_id=str(i),
                prompt_ids=self.tokenizer.encode(p),
                max_new_tokens=max_tokens
            )
            for i, p in enumerate(prompts)
        ]
        
        for req in requests:
            self.batcher.add_request(req)
        
        # 推理循环
        while self.batcher.pending_requests or self.batcher.running_requests:
            # 1. 连续批处理调度
            self.batcher._prepare_batch()
            
            # 2. 执行prefill或decode
            for req in self.batcher.running_requests:
                if len(req.generated_ids) == 0:
                    # Prefill
                    self._prefill_request(req)
                else:
                    # 投机解码
                    self._speculative_decode(req)
            
            # 3. 完成请求
            completed = [
                req for req in self.batcher.running_requests
                if req.is_finished
            ]
            for req in completed:
                self.batcher.running_requests.remove(req)
                self.batcher.completed_requests.append(req)
        
        # 解码结果
        return [
            self.tokenizer.decode(req.generated_ids)
            for req in self.batcher.completed_requests
        ]
    
    def _prefill_request(self, request: Request):
        """Prefill请求"""
        # 使用chunked prefill
        output = self.model.prefill(
            [request.prompt_ids],
            kv_cache=self.kv_cache
        )
        
        request.generated_ids.append(output.token_id)
    
    def _speculative_decode(self, request: Request):
        """投机解码"""
        # 使用draft model生成多个候选
        draft_tokens = self.speculator.speculate(
            request.generated_ids[-1],
            num_tokens=6
        )
        
        # 用main model验证
        accepted = self.speculator.verify(
            request.generated_ids + draft_tokens
        )
        
        # 更新generated_ids
        request.generated_ids = accepted

5. 调度策略

5.1 优先级调度

class PriorityScheduler:
    """
    优先级调度器
    
    根据请求类型和截止时间分配优先级
    """
    
    def __init__(self):
        self.queues = {
            'critical': [],   # 关键任务(实时性要求高)
            'normal': [],     # 普通任务
            'batch': []       # 批处理任务(可延迟)
        }
    
    def add_request(self, request: Request, priority: str = 'normal'):
        """添加请求"""
        heapq.heappush(
            self.queues[priority],
            (-request.remaining_tokens, request)  # 短请求优先
        )
    
    def get_next_batch(self, batch_size: int) -> List[Request]:
        """获取下一批请求"""
        batch = []
        
        # 按优先级获取
        for priority in ['critical', 'normal', 'batch']:
            while self.queues[priority] and len(batch) < batch_size:
                _, request = heapq.heappop(self.queues[priority])
                batch.append(request)
        
        return batch

5.2 公平调度

class FairScheduler:
    """
    公平调度器
    
    确保各用户的请求都能得到处理
    """
    
    def __init__(self, max_batch_size: int = 32):
        self.max_batch_size = max_batch_size
        self.user_credits = defaultdict(lambda: 1.0)  # 每用户的信用
        self.user_queues = defaultdict(list)
    
    def add_request(self, request: Request, user_id: str):
        """添加请求"""
        self.user_queues[user_id].append(request)
    
    def get_next_batch(self) -> List[Request]:
        """获取公平分配的batch"""
        batch = []
        
        # 计算各用户可用slot
        user_slots = {
            uid: min(
                len(queue),
                int(self.max_batch_size * self.user_credits[uid])
            )
            for uid, queue in self.user_queues.items()
        }
        
        # 分配slot
        for uid, slots in user_slots.items():
            queue = self.user_queues[uid]
            for _ in range(slots):
                if queue:
                    batch.append(queue.pop(0))
                    self.user_credits[uid] *= 0.9  # 消耗信用
        
        # 补充信用
        for uid in self.user_credits:
            self.user_credits[uid] = min(1.0, self.user_credits[uid] + 0.1)
        
        return batch

6. 性能对比

6.1 吞吐量对比

方法吞吐量 (tokens/s)相对提升
无批处理451.0x
静态批处理1202.7x
连续批处理2806.2x
Titan4209.3x

6.2 延迟对比

百分位无批处理 (ms)连续批处理 (ms)Titan (ms)
P5012015080
P90350400180
P99800900350

6.3 GPU利用率

方法平均GPU利用率峰值GPU利用率
无批处理25%40%
静态批处理45%70%
连续批处理75%95%
Titan88%98%

7. 实践建议

7.1 配置推荐

# Titan配置推荐
 
titan_config = {
    'max_batch_size': 32,          # 最大batch size
    'prefill_chunk_size': 512,    # prefill分块大小
    'block_size': 16,              # PagedAttention块大小
    'use_speculative': True,       # 启用投机解码
    'draft_model': 'llama-135m',   # draft模型
    'max_draft_tokens': 6,        # 最大draft token数
}

7.2 瓶颈诊断

def diagnose_inference_bottleneck(model, profiler):
    """
    诊断推理瓶颈
    """
    results = profiler.profile()
    
    bottlenecks = {}
    
    # 计算各阶段耗时
    prefill_time = results['prefill'] / results['total']
    decode_time = results['decode'] / results['total']
    
    # 诊断
    if prefill_time > 0.5:
        bottlenecks['prefill'] = {
            'issue': 'Prefill is dominant',
            'suggestion': 'Increase prefill chunk size or parallelize'
        }
    
    if decode_time > 0.5:
        bottlenecks['decode'] = {
            'issue': 'Decode is dominant (memory bound)',
            'suggestion': 'Use smaller precision or KV cache compression'
        }
    
    # 检查内存
    if results['memory_utilization'] < 0.7:
        bottlenecks['memory'] = {
            'issue': 'Low memory utilization',
            'suggestion': 'Increase batch size'
        }
    
    return bottlenecks

8. 总结

连续批处理与Titan系统的核心贡献:

  1. 连续批处理:动态batch调度,最大化GPU利用率
  2. Chunked Prefill:避免长prompt阻塞
  3. PagedAttention:高效的KV Cache管理
  4. 投机解码:加速生成
  5. 智能调度:优先级和公平性保证

参考文献