DuoAttention: Dual-Head Attention

1. Overview

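DuoAttention is an attention optimization method for long-context inference, proposed in 2024. Its core idea is to split attention heads into two classes: retrieval heads, which must look up information anywhere in the context, and streaming heads, which mainly serve streaming generation. Retrieval heads keep a full KV cache, while streaming heads keep only a short, constant-length one, so each kind of attention computation can be optimized on its own terms.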

Key insight

Different attention heads in a language model play different functional roles:

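  • Retrieval heads: must attend to the entire history; used to retrieve information
  • Streaming heads: attend mainly to recent context (plus a few initial "attention sink" tokens); used for local language modeling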

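Distinguishing the two classes enables more efficient inference, because only the retrieval heads need the full, ever-growing KV cache.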
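
To make the distinction concrete, here is a minimal sketch of which past token positions each head type keeps after a given number of tokens. The helper `kept_positions` and its `num_sink`/`window` parameters are illustrative assumptions; the sink-plus-recent-window policy for streaming heads follows the StreamingLLM-style layout and is not any particular library's API.

```python
from typing import List


def kept_positions(seq_len: int, head_type: str, num_sink: int = 4, window: int = 8) -> List[int]:
    """Token positions whose K/V a head keeps after seeing seq_len tokens (hypothetical helper).

    Retrieval heads keep everything; streaming heads keep the first num_sink
    "attention sink" tokens plus the most recent `window` tokens.
    """
    if head_type == "retrieval":
        return list(range(seq_len))
    sinks = list(range(min(num_sink, seq_len)))
    recent = list(range(max(num_sink, seq_len - window), seq_len))
    return sinks + recent


print(len(kept_positions(20, "retrieval")))                       # 20: grows with the context
print(kept_positions(20, "streaming", num_sink=2, window=4))       # [0, 1, 16, 17, 18, 19]: constant size
```

The retrieval list grows linearly with the context, while the streaming list never exceeds `num_sink + window` entries, which is the source of the memory savings.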

2. Background

2.1 The long-context challenge

Existing methods face trade-offs when handling long contexts:

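| Method | Long context | Memory efficiency | Implementation complexity |
|---|---|---|---|
| Full Attention | ✓ Perfect | ✗ Poor | |
| StreamingLLM | ✓ Unlimited | ✓ Good | |
| H2O | ✓ Good | ✓ Good | |
| DuoAttention | ✓ Good | ✓ Better | |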
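
The memory side of this trade-off is easy to quantify: a full-attention KV cache grows linearly with sequence length. A back-of-the-envelope sketch follows; the function name is ours, and the model shape is LLaMA-2 7B's published configuration (32 layers, 32 heads, head_dim 128) in fp16.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """Full-attention KV cache size: K and V (factor 2) for every layer, KV head and token."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem * batch


# LLaMA-2 7B shape, fp16 (2 bytes per element)
size = kv_cache_bytes(32, 32, 128, seq_len=32_768)
print(size / 2**30)  # 16.0 GiB for a single 32K-token sequence
```

Every extra token costs the same 0.5 MiB here, which is why methods that cap the cache for most heads pay off at long contexts.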

2.2 Limitations of existing methods

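  • Uniform compression: the same strategy is applied to every attention head
  • Head roles are ignored: different heads need different treatment
  • No targeted optimization: the method cannot be tuned for specific tasks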

3. DuoAttention theoretical framework

3.1 Functional classification of attention heads

from typing import Dict, List, Tuple

import numpy as np
import torch


def classify_attention_heads(
    model,
    tokenizer,
    retrieval_prompts: List[str],
    language_modeling_texts: List[str],
    window: int = 64,
    ratio: float = 1.5
) -> Dict[str, List[Tuple[int, int]]]:
    """
    Classify the functional type of each attention head, keyed by (layer, head).

    Method:
    1. Heads whose long-range attention rises sharply on retrieval prompts -> retrieval heads
    2. All other heads (streaming or mixed) -> streaming heads

    The plain mean of softmax attention weights is 1/T for every head (each row
    sums to 1), so the score used here is the share of attention mass on keys
    more than `window` tokens behind the query. The model must be loaded with
    eager attention so that output_attentions=True returns attention weights.
    """

    def long_range_scores(texts: List[str]) -> np.ndarray:
        per_text = []
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)

            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)

            attn = torch.stack(outputs.attentions)  # [L, B, H, T, T]
            T = attn.shape[-1]
            pos = torch.arange(T, device=attn.device)
            far = (pos[:, None] - pos[None, :]) > window  # [T, T], True = far from the query
            # Mass on far keys per query, averaged over batch and queries -> [L, H]
            score = (attn * far).sum(dim=-1).mean(dim=(1, 3))
            per_text.append(score.float().cpu().numpy())
        return np.mean(per_text, axis=0)  # [L, H]

    # 1. Long-range attention on retrieval tasks
    retrieval_score = long_range_scores(retrieval_prompts)

    # 2. Long-range attention on plain language modeling
    lm_score = long_range_scores(language_modeling_texts)

    # 3. Classify
    retrieval_heads = []
    streaming_heads = []

    num_layers, num_heads = retrieval_score.shape
    for layer in range(num_layers):
        for head in range(num_heads):
            if retrieval_score[layer, head] > lm_score[layer, head] * ratio:
                retrieval_heads.append((layer, head))
            else:
                # Streaming and mixed heads both default to the streaming path
                streaming_heads.append((layer, head))

    return {
        'retrieval_heads': retrieval_heads,
        'streaming_heads': streaming_heads
    }
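
The score used for classification matters. A plain average of softmax attention weights is the same (1/T) for every head, since each causal row sums to 1. The toy check below uses hand-built attention maps (a synthetic "needle" head and a purely local head, not real model heads) and one alternative score, the share of attention mass on keys more than `window` positions back; all names are illustrative.

```python
import torch


def causal_softmax(scores: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax over a [T, T] score matrix with a causal mask."""
    T = scores.shape[-1]
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    return torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)


def long_range_share(attn: torch.Tensor, window: int) -> float:
    """Average share of attention mass that falls on keys more than `window` positions back."""
    T = attn.shape[-1]
    pos = torch.arange(T)
    far = (pos[:, None] - pos[None, :]) > window
    return (attn * far).sum(dim=-1).mean().item()


T, NEEDLE = 64, 5
torch.manual_seed(0)
random_head = causal_softmax(torch.randn(T, T))

needle_scores = torch.zeros(T, T)
needle_scores[:, NEEDLE] = 10.0                    # every query looks back at one fixed "needle" token
needle_head = causal_softmax(needle_scores)
local_head = causal_softmax(torch.eye(T) * 10.0)   # every query looks at itself

print(random_head.mean().item(), 1 / T)            # both ~0.0156: the plain mean cannot tell heads apart
print(long_range_share(needle_head, window=8))     # high, about 0.78
print(long_range_share(local_head, window=8))      # close to 0
```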

3.2 Empirical observations

Typical attention head classification results for LLaMA-2 7B:

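| Type | Number of heads | Share | Characteristics |
|---|---|---|---|
| Retrieval heads | 4-8 | 5-10% | Attend to specific keywords and entities |
| Streaming heads | 24-28 | 30-35% | Attend to recent context |
| Mixed heads | 16-24 | 20-30% | Both behaviors |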

4. DuoAttention architecture

4.1 Dual-path design

Standard attention:
┌──────────────────────────────────────────────────────────────┐
│  Input → [Q, K, V] → Attention → Output                      │
│              ↑                                               │
│   all tokens participate                                     │
└──────────────────────────────────────────────────────────────┘

DuoAttention:
┌──────────────────────────────────────────────────────────────┐
│                                                              │
│  Retrieval-head path:                                        │
│  Input → Q_ret → [K_all, V_all] → Attention → Output         │
│                  ↑                                           │
│       full-history retrieval                                 │
│                                                              │
│  Streaming-head path:                                        │
│  Input → Q_str → [K_window, V_window] → Attention → Output   │
│                  ↑                                           │
│           sliding window                                     │
│                                                              │
└──────────────────────────────────────────────────────────────┘
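
The two paths differ only in which keys each query may see. A small sketch of the two masks (the function name is illustrative, and attention sinks are omitted for brevity):

```python
from typing import Optional

import torch


def attention_mask(T: int, window: Optional[int] = None) -> torch.Tensor:
    """Boolean [T, T] mask, True where query i may attend to key j.

    window=None: full causal mask (retrieval-head path).
    window=W: causal mask limited to the last W keys (streaming-head path, no sinks).
    """
    i = torch.arange(T)[:, None]
    j = torch.arange(T)[None, :]
    allowed = j <= i
    if window is not None:
        allowed = allowed & (i - j < window)
    return allowed


print(attention_mask(5).int())            # lower triangle: every past token stays visible
print(attention_mask(5, window=2).int())  # band of width 2: older tokens drop out
```

Because the banded mask never looks further back than `window` keys, a streaming head can discard older KV entries without changing its output.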

4.2 Implementation

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DuoAttention(nn.Module):
    """
    DuoAttention: dual-path attention (simplified sketch)

    Retrieval heads: full KV cache over the whole history
    Streaming heads: sliding-window KV cache (last window_size tokens)

    Simplifications: no positional encoding, no padding mask, no attention sinks.
    Prefill vs. decode is inferred from whether a cache already exists; call
    reset_cache() before processing a new sequence.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        retrieval_head_indices: List[int],
        window_size: int = 512,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size

        # Head types as sorted lists, so splitting and merging use the same order
        self.retrieval_idx = sorted(set(retrieval_head_indices))
        self.streaming_idx = [i for i in range(num_heads) if i not in self.retrieval_idx]

        # Shared QKV projections
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.v_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size)

        self.dropout = nn.Dropout(dropout)
        self.reset_cache()

    def reset_cache(self):
        # Each cache stores only its own heads; that is where the memory saving comes from
        self.ret_k = self.ret_v = None  # [B, H_ret, all tokens, D]
        self.str_k = self.str_v = None  # [B, H_str, <= window_size, D]
        self.seen = 0                   # number of tokens already cached

    def forward(
        self,
        hidden_states: torch.Tensor,
        use_cache: bool = True
    ) -> torch.Tensor:
        B, T, _ = hidden_states.shape

        # QKV projection, then [B, T, H, D] -> [B, H, T, D]
        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Split Q/K/V between retrieval heads and streaming heads
        q_ret, k_ret, v_ret = (x[:, self.retrieval_idx] for x in (q, k, v))
        q_str, k_str, v_str = (x[:, self.streaming_idx] for x in (q, k, v))

        past = self.seen if use_cache else 0
        if use_cache and self.ret_k is not None:
            # Decode (or chunked prefill): prepend the cached history
            k_ret = torch.cat([self.ret_k, k_ret], dim=2)
            v_ret = torch.cat([self.ret_v, v_ret], dim=2)
            k_str = torch.cat([self.str_k, k_str], dim=2)
            v_str = torch.cat([self.str_v, v_str], dim=2)

        # Retrieval heads: causal attention over the full history
        output_ret = self._attention(q_ret, k_ret, v_ret, past)
        # Streaming heads: causal attention limited to the last window_size keys
        output_str = self._attention(q_str, k_str, v_str, past, self.window_size)

        if use_cache:
            # Update caches: retrieval keeps everything, streaming evicts old tokens
            self.ret_k, self.ret_v = k_ret, v_ret
            self.str_k = k_str[:, :, -self.window_size:]
            self.str_v = v_str[:, :, -self.window_size:]
            self.seen += T

        # Merge back into the original head order: [B, H, T, D]
        output = q.new_zeros(B, self.num_heads, T, self.head_dim)
        output[:, self.retrieval_idx] = output_ret
        output[:, self.streaming_idx] = output_str

        # Output projection
        output = output.transpose(1, 2).reshape(B, T, -1)
        return self.o_proj(output)

    def _attention(self, q, k, v, past: int, window: Optional[int] = None):
        """Causal attention; query i sits at absolute position past + i, the last key at past + T - 1."""
        T, S = q.shape[2], k.shape[2]
        q_pos = torch.arange(past, past + T, device=q.device)[:, None]
        k_pos = torch.arange(past + T - S, past + T, device=q.device)[None, :]
        allowed = k_pos <= q_pos                          # causal mask
        if window is not None:
            allowed = allowed & (q_pos - k_pos < window)  # sliding window
        scale = self.head_dim ** -0.5
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_scores = attn_scores.masked_fill(~allowed, float('-inf'))
        attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))
        return torch.matmul(attn_weights, v)
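
Why the streaming path may drop old KV entries: for a streaming head, decoding the newest token against a cache truncated to the last W keys gives the same output as sliding-window-masked attention over the whole sequence. A small numerical check (sizes are arbitrary, attention sinks are left out, and `masked_attention` is an illustrative helper):

```python
import torch
import torch.nn.functional as F


def masked_attention(q, k, v, allowed):
    """Scaled dot-product attention; `allowed` is a boolean mask (True = may attend)."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1) @ v


torch.manual_seed(0)
T, D, W = 12, 8, 4
q, k, v = torch.randn(T, D), torch.randn(T, D), torch.randn(T, D)

# Reference: sliding-window causal attention over the whole sequence
i, j = torch.arange(T)[:, None], torch.arange(T)[None, :]
reference = masked_attention(q, k, v, (j <= i) & (i - j < W))

# Decode step: the newest query sees only a cache truncated to the last W keys
decoded = masked_attention(q[-1:], k[-W:], v[-W:], torch.ones(1, W, dtype=torch.bool))

print(torch.allclose(reference[-1:], decoded, atol=1e-5))  # True
```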

5. KV cache management

5.1 Separated cache strategy

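from typing import List, Tuple

import torch


class SeparatedKVCache:
    """
    Separated KV cache for DuoAttention (batch size 1 for simplicity)

    Retrieval heads: full history, preallocated up to max_seq_len
    Streaming heads: sliding window stored as a ring buffer
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        retrieval_head_indices: List[int],
        window_size: int = 512,
        max_seq_len: int = 65536,
        dtype: torch.dtype = torch.float16,
        device: str = "cpu"
    ):
        self.head_dim = head_dim
        self.window_size = window_size
        self.max_seq_len = max_seq_len
        self.retrieval_idx = sorted(set(retrieval_head_indices))
        self.streaming_idx = [h for h in range(num_heads) if h not in self.retrieval_idx]

        # Retrieval heads: only these heads get a max_seq_len-long buffer
        self.retrieval_k = torch.zeros(
            len(self.retrieval_idx), max_seq_len, head_dim, dtype=dtype, device=device
        )
        self.retrieval_v = torch.zeros_like(self.retrieval_k)
        self.retrieval_len = 0

        # Streaming heads: fixed-size ring buffer
        self.streaming_k = torch.zeros(
            len(self.streaming_idx), window_size, head_dim, dtype=dtype, device=device
        )
        self.streaming_v = torch.zeros_like(self.streaming_k)
        self.streaming_ptr = 0  # next slot to write
        self.streaming_len = 0  # number of valid slots (<= window_size)

    def update(
        self,
        k_new: torch.Tensor,  # [1, H, T, D], all heads
        v_new: torch.Tensor
    ):
        """Append new tokens; each head type stores only its own heads"""
        T = k_new.shape[2]
        end = self.retrieval_len + T
        if end > self.max_seq_len:
            raise ValueError("sequence length exceeds max_seq_len")

        # Retrieval heads: append after the existing history
        self.retrieval_k[:, self.retrieval_len:end] = k_new[0, self.retrieval_idx]
        self.retrieval_v[:, self.retrieval_len:end] = v_new[0, self.retrieval_idx]
        self.retrieval_len = end

        # Streaming heads: ring buffer; at most the last window_size tokens survive
        keep = min(T, self.window_size)
        offsets = torch.arange(T - keep, T, device=self.streaming_k.device)
        slots = (self.streaming_ptr + offsets) % self.window_size
        self.streaming_k[:, slots] = k_new[0, self.streaming_idx, T - keep:]
        self.streaming_v[:, slots] = v_new[0, self.streaming_idx, T - keep:]
        self.streaming_ptr = (self.streaming_ptr + T) % self.window_size
        self.streaming_len = min(self.streaming_len + T, self.window_size)

    def get_cache(
        self,
        head_type: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (k, v), each [1, h, S, D], for one head type, oldest token first"""
        if head_type == 'retrieval':
            k = self.retrieval_k[:, :self.retrieval_len]
            v = self.retrieval_v[:, :self.retrieval_len]
        elif self.streaming_len < self.window_size:
            # Streaming, buffer not yet wrapped: valid entries are slots [0, streaming_len)
            k = self.streaming_k[:, :self.streaming_len]
            v = self.streaming_v[:, :self.streaming_len]
        else:
            # Streaming, buffer full: the oldest entry sits at streaming_ptr
            k = torch.roll(self.streaming_k, shifts=-self.streaming_ptr, dims=1)
            v = torch.roll(self.streaming_v, shifts=-self.streaming_ptr, dims=1)

        return k.unsqueeze(0), v.unsqueeze(0)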
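
Why reading a circular buffer needs a reorder step: once the buffer wraps, physical slot order no longer matches token order. A tiny demonstration with scalar "tokens" standing in for K/V rows, using `collections.deque(maxlen=W)` as the reference sliding window:

```python
from collections import deque

import torch

W = 4
ring = torch.zeros(W, dtype=torch.long)  # one scalar "token" per slot stands in for a K/V row
ptr = 0
for token in range(1, 7):                # write tokens 1..6 into a 4-slot ring buffer
    ring[ptr] = token
    ptr = (ptr + 1) % W

reference = deque(range(1, 7), maxlen=W)  # a true sliding window keeps the last W tokens

print(ring.tolist())                            # [5, 6, 3, 4]: physical slot order
print(torch.roll(ring, shifts=-ptr).tolist())   # [3, 4, 5, 6]: chronological order
print(list(reference))                          # [3, 4, 5, 6]
```

Rolling by `-ptr` moves the oldest entry (the next slot to be overwritten) to the front, which is what attention with positional information expects.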

5.2 Memory efficiency analysis

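| Configuration | Retrieval-head cache | Streaming-head cache | Total memory | Savings |
|---|---|---|---|---|
| All full | 100% | 100% | 100% | - |
| All windowed | 0% | 100% | 50% | 50% |
| DuoAttention | 100% | 50% | 75% | 25% |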
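
These percentages follow a simple formula. Assuming all heads share the same head_dim and every layer uses the same split, a fraction r of retrieval heads with full history and the rest with a window of W tokens gives a cache of r + (1 - r) * min(W, N) / N relative to full attention at sequence length N. The DuoAttention row's 75% corresponds to r = 0.5 and W = N / 2. A sketch (the function name is illustrative):

```python
def duo_cache_fraction(retrieval_ratio: float, seq_len: int, window: int) -> float:
    """KV cache size relative to full attention.

    Retrieval heads keep seq_len tokens; streaming heads keep at most `window`.
    Assumes equal head_dim for all heads and the same split in every layer.
    """
    streaming_tokens = min(window, seq_len)
    return retrieval_ratio + (1 - retrieval_ratio) * streaming_tokens / seq_len


print(duo_cache_fraction(0.5, seq_len=1024, window=512))   # 0.75
print(duo_cache_fraction(0.1, seq_len=32768, window=512))  # about 0.114
```

At long contexts the streaming term vanishes, so the cache approaches the retrieval-head ratio r.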

6. Task adaptability

6.1 Retrieval tasks

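class RetrievalTaskOptimizer:
    """
    DuoAttention tuning for retrieval tasks (conceptual sketch)

    _extract_entities, _find_entity_positions, _retrieve and
    duo_attention.set_retrieval_importance are hypothetical hooks
    that are not defined here.
    """

    def __init__(self, duo_attention):
        self.duo_attention = duo_attention

        # Retrieval tasks need a larger share of retrieval heads
        self.adaptive_retrieval_ratio = 0.2  # 20% retrieval heads

    def adapt_for_retrieval(
        self,
        query: str,
        document: str,
        top_k: int = 5
    ):
        """
        Adjust attention for a retrieval task
        """
        # Identify key entities / keywords in the query
        key_entities = self._extract_entities(query)

        # Raise retrieval-head importance at the positions where those entities occur
        for entity_pos in self._find_entity_positions(document, key_entities):
            self.duo_attention.set_retrieval_importance(entity_pos, 1.0)

        # Run retrieval
        return self._retrieve(document, top_k)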

6.2 Streaming generation

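class StreamingTaskOptimizer:
    """
    DuoAttention tuning for streaming generation
    """

    def __init__(self, duo_attention):
        self.duo_attention = duo_attention

    def adapt_for_streaming(
        self,
        window_size: int = 1024,
        retrieval_ratio: float = 0.05
    ):
        """
        Tune for streaming generation
        """
        # Streaming workloads: a smaller share of retrieval heads (5% by default).
        # Changing this share means re-running head classification, so it is only recorded here.
        self.retrieval_ratio = retrieval_ratio

        # Widen the streaming heads' window (assumes the module reads window_size on every call)
        self.duo_attention.window_size = window_size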

7. Experimental results

7.1 Retrieval performance

| Method | Needle-in-Haystack | PassKey | Average |
|---|---|---|---|
| Full Attention | 98.2% | 97.5% | 97.9% |
| StreamingLLM | 72.3% | 68.1% | 70.2% |
| H2O | 91.5% | 89.2% | 90.4% |
| DuoAttention | 96.8% | 95.3% | 96.1% |

7.2 Language modeling

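| Method | WikiText PPL | Pile PPL | Perplexity change |
|---|---|---|---|
| Full Attention | 12.45 | 8.92 | Baseline |
| StreamingLLM | 12.68 | 9.15 | +3.5% |
| H2O | 12.52 | 9.02 | +1.2% |
| DuoAttention | 12.48 | 8.98 | +0.6% |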

7.3 Memory efficiency

| Model size | Full Attention | DuoAttention | Memory saving |
|---|---|---|---|
| 7B | 80GB | 60GB | 25% |
| 13B | 130GB | 95GB | 27% |
| 70B | 280GB | 210GB | 25% |

8. Comparison with other methods

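| Dimension | DuoAttention | StreamingLLM | H2O | PyramidKV |
|---|---|---|---|---|
| Retrieval ability | ★★★★★ | ★★☆☆☆ | ★★★★☆ | ★★★★☆ |
| Streaming generation | ★★★★★ | ★★★★★ | ★★★★☆ | ★★★★☆ |
| Memory efficiency | ★★★★☆ | ★★★★★ | ★★★★☆ | ★★★★☆ |
| Implementation complexity | | | | |
| Training-free | | | | |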

9. Practical guide

9.1 Head classification method

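import numpy as np
import torch


def auto_classify_heads(model, calibration_data, num_retrieval_heads=None,
                        entropy_threshold=2.0, max_attn_threshold=0.3):
    """
    Automatically classify attention heads

    Uses attention statistics to identify retrieval heads; returns (layer, head) pairs.
    calibration_data: iterable of tokenized input dicts (e.g. tokenizer(..., return_tensors="pt")).
    """
    # Method 1: entropy-based classification
    def compute_attention_entropy(attn_weights):
        """Attention entropy: low = focused (retrieval-like), high = diffuse (streaming-like)"""
        entropy = -(attn_weights * torch.log(attn_weights + 1e-10)).sum(dim=-1)
        return entropy.mean().item()

    # Collect statistics per (layer, head); heads in different layers are different heads
    head_stats = {}

    with torch.no_grad():
        for data in calibration_data:
            outputs = model(**data, output_attentions=True)

            for layer_idx, attn in enumerate(outputs.attentions):
                avg_attn = attn.float().mean(dim=0)  # average over batch -> [H, T, T]

                for h in range(avg_attn.shape[0]):
                    stats = head_stats.setdefault((layer_idx, h), {'entropy': [], 'max_attn': []})
                    stats['entropy'].append(compute_attention_entropy(avg_attn[h]))
                    # Per-query maximum averaged over queries; the global max is always 1.0
                    # because the first query can only attend to itself
                    stats['max_attn'].append(avg_attn[h].max(dim=-1).values.mean().item())

    # Classify: low entropy + high peak attention = retrieval head
    retrieval_heads = [
        key for key, stats in head_stats.items()
        if np.mean(stats['entropy']) < entropy_threshold
        and np.mean(stats['max_attn']) > max_attn_threshold
    ]

    # If a target count is given and the thresholds miss it, take the top-k heads by peak attention
    if num_retrieval_heads is not None and len(retrieval_heads) != num_retrieval_heads:
        sorted_heads = sorted(
            head_stats.items(),
            key=lambda x: np.mean(x[1]['max_attn']),
            reverse=True
        )
        retrieval_heads = [key for key, _ in sorted_heads[:num_retrieval_heads]]

    return retrieval_heads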
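
Two details matter in these statistics. The global maximum of any causal attention map is 1.0, because the first query can only attend to itself, so the per-query maximum averaged over queries is the more informative statistic; and entropy separates focused rows from diffuse ones. A synthetic check with two hand-built attention maps (not real model heads; `row_entropy` mirrors `compute_attention_entropy`):

```python
import math

import torch


def row_entropy(attn: torch.Tensor) -> float:
    """Mean entropy (nats) over query rows, same formula as compute_attention_entropy."""
    return (-(attn * torch.log(attn + 1e-10)).sum(dim=-1)).mean().item()


T = 64
causal = torch.tril(torch.ones(T, T))
diffuse = causal / causal.sum(dim=-1, keepdim=True)  # each query spreads weight evenly over visible keys
focused = torch.eye(T)                                # each query puts all weight on a single key

for name, attn in [("diffuse", diffuse), ("focused", focused)]:
    global_max = attn.max().item()                    # 1.0 for both: query 0 can only see itself
    per_query_max = attn.max(dim=-1).values.mean().item()
    print(name, global_max, round(per_query_max, 3), round(row_entropy(attn), 3))

# Closed form for the diffuse map: row i has entropy ln(i + 1)
print(sum(math.log(n) for n in range(1, T + 1)) / T)
```

Note that the `entropy_threshold` of 2.0 nats is length-dependent: a fully diffuse row over n keys has entropy ln(n), so thresholds tuned at one calibration length may need adjusting at another.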

9.2 Recommended configurations

# DuoAttention configuration presets

# General purpose
general_config = {
    'retrieval_head_ratio': 0.1,  # 10% retrieval heads
    'window_size': 512
}

# Retrieval-heavy workloads
retrieval_config = {
    'retrieval_head_ratio': 0.2,  # 20% retrieval heads
    'window_size': 256
}

# Streaming-heavy workloads
streaming_config = {
    'retrieval_head_ratio': 0.05,  # 5% retrieval heads
    'window_size': 1024
}
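
A `retrieval_head_ratio` still has to be turned into a concrete set of heads. One simple way, assuming a per-head retrieval score from the classification step is available (the helper name and toy scores are illustrative), is to keep the top fraction by score:

```python
from typing import Dict, List, Tuple


def select_retrieval_heads(scores: Dict[Tuple[int, int], float],
                           retrieval_head_ratio: float) -> List[Tuple[int, int]]:
    """Pick the top `retrieval_head_ratio` share of (layer, head) pairs by score.

    `scores` can be any per-head retrieval score, e.g. a long-range attention
    share or the mean per-query max attention.
    """
    k = round(len(scores) * retrieval_head_ratio)
    ranked = sorted(scores, key=scores.get, reverse=True)
    return sorted(ranked[:k])


# Toy example: 2 layers x 4 heads
toy_scores = {(0, 0): 0.9, (0, 1): 0.1, (0, 2): 0.2, (0, 3): 0.05,
              (1, 0): 0.3, (1, 1): 0.8, (1, 2): 0.15, (1, 3): 0.4}
print(select_retrieval_heads(toy_scores, 0.25))  # [(0, 0), (1, 1)]
```

Ranking globally across layers, rather than picking a fixed count per layer, lets layers with more retrieval-like heads keep more of them.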

10. Summary

Core contributions of DuoAttention:

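  1. Functional separation: identifies retrieval heads and streaming heads and separates their roles
  2. Targeted optimization: each head type uses its own cache strategy
  3. Task adaptation: the ratio of head types can be adjusted to the task
  4. No weight fine-tuning: a plug-and-play method for existing pretrained models; only a lightweight head-identification step on calibration data is needed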
