DuoAttention:检索头与流式头双类注意力
1. 概述
DuoAttention是2024年提出的新型注意力架构优化方法,其核心思想是将注意力头分为两类:专门处理检索任务的检索头(Retrieval Heads)和专门处理流式生成的流式头(Streaming Heads)。通过这种分离,可以针对性地优化不同类型的注意力计算。
核心洞察
语言模型的不同注意力头有不同的功能角色:
- 检索头:需要关注整个历史,用于信息检索
- 流式头:主要关注最近上下文(以及序列开头少量的注意力汇聚token,即attention sinks),用于语言建模
区分这两类头可以实现更高效的推理
2. 问题背景
2.1 长上下文的挑战
现有方法在处理长上下文时面临权衡:
| 方法 | 长上下文 | 内存效率 | 实现复杂度 |
|---|---|---|---|
| Full Attention | ✓ 完美 | ✗ 差 | 低 |
| StreamingLLM | ⚠ 长度无限,但丢失远距离信息 | ✓ 好 | 低 |
| H2O | ✓ 好 | ✓ 好 | 中 |
| DuoAttention | ✓ 好 | ✓ 更好 | 中 |
2.2 现有方法的局限
- 均匀压缩:对所有注意力头使用相同策略
- 忽略头功能差异:不同头需要不同处理
- 缺乏针对性优化:无法针对特定任务优化
3. DuoAttention理论框架
3.1 注意力头的功能分类
def classify_attention_heads(
    model,
    tokenizer,
    retrieval_prompts: List[str],
    language_modeling_texts: List[str],
    ratio: float = 1.5,
) -> Dict[str, List[int]]:
    """Classify attention heads as retrieval heads or streaming heads.

    Each head is scored by its mean *peak* attention weight: for every query
    position take the largest weight over the keys, then average over query
    positions, layers, batch and texts. (The plain mean of a softmax-normalised
    attention map is 1/T for every head, so it cannot tell heads apart.)

    A head whose score on ``retrieval_prompts`` exceeds ``ratio`` times its
    score on ``language_modeling_texts`` is a retrieval head. Every other head,
    including ambiguous ones, is classified as a streaming head.

    Note: head index ``h`` is pooled across all layers, so the result classifies
    head *positions*, not individual (layer, head) pairs.

    Args:
        model: callable that, given ``output_attentions=True``, returns an
            object whose ``.attentions`` holds one ``[B, H, T, T]`` tensor per
            layer; must expose ``.device`` and ``.config.num_attention_heads``.
        tokenizer: callable ``tokenizer(text, return_tensors="pt")`` whose
            result supports ``.to(device)`` and ``**`` unpacking.
        retrieval_prompts: texts that exercise retrieval behaviour.
        language_modeling_texts: texts that exercise plain language modelling.
        ratio: factor by which the retrieval score must exceed the LM score
            (default 1.5, the original hard-coded value).

    Returns:
        ``{'retrieval_heads': [...], 'streaming_heads': [...]}`` with head
        indices in ascending order.

    Raises:
        ValueError: if either text list is empty.
    """
    if not retrieval_prompts or not language_modeling_texts:
        raise ValueError(
            "retrieval_prompts and language_modeling_texts must both be non-empty"
        )

    retrieval_scores = _mean_peak_attention_per_head(model, tokenizer, retrieval_prompts)
    lm_scores = _mean_peak_attention_per_head(model, tokenizer, language_modeling_texts)

    retrieval_heads = []
    streaming_heads = []
    for head_idx in range(model.config.num_attention_heads):
        retrieval_score = retrieval_scores[head_idx].item()
        lm_score = lm_scores[head_idx].item()
        if retrieval_score > lm_score * ratio:
            retrieval_heads.append(head_idx)
        else:
            # LM-dominant and ambiguous heads both default to streaming heads.
            streaming_heads.append(head_idx)

    return {
        'retrieval_heads': retrieval_heads,
        'streaming_heads': streaming_heads,
    }


def _mean_peak_attention_per_head(model, tokenizer, texts: List[str]) -> torch.Tensor:
    """Return an ``[H]`` tensor of per-query peak attention, averaged over queries, layers, batch and texts."""
    per_text = []
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        attn = torch.stack(outputs.attentions).float()  # [L, B, H, T, T]
        # Max over keys -> [L, B, H, T]; mean over layers, batch and queries -> [H]
        per_text.append(attn.amax(dim=-1).mean(dim=(0, 1, 3)))
    return torch.stack(per_text).mean(dim=0)
3.2 实验观察
典型LLaMA-2 7B的注意力头分类结果:
| 类型 | 头数量 | 占比 | 特征 |
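|---|---|---|---|
| 检索头 | 4-8 | 5-10% | 关注特定关键词、实体 |
| 流式头 | 24-28 | 30-35% | 关注最近上下文 |
| 混合头 | 16-24 | 20-30% | 两者皆有 |
注:LLaMA-2 7B共32层、每层32个注意力头。上表的头数之和超过32,占比之和又不足100%,两列数值彼此不一致。此外,3.1节的分类函数只输出检索头与流式头两类(中间类型归入流式头),并不产生"混合头"。表中数值仅供定性参考。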
4. DuoAttention架构
4.1 双路径设计
标准注意力:
┌─────────────────────────────────────────────────┐
│ Input → [Q, K, V] → Attention → Output │
│ ↑ │
│ 全部Token参与 │
└─────────────────────────────────────────────────┘
DuoAttention:
┌─────────────────────────────────────────────────┐
│ │
│ 检索头路径: │
│ Input → Q_ret → [K_all, V_all] → Attention → Output
│ ↑ │
│ 完整历史检索 │
│ │
│ 流式头路径: │
│ Input → Q_str → [K_window, V_window] → Attn → Output
│ ↑ │
│ 滑动窗口 │
│ │
└─────────────────────────────────────────────────┘
4.2 实现
class DuoAttention(nn.Module):
    """DuoAttention: dual-path attention.

    Retrieval heads attend causally over the full KV history. Streaming heads
    attend causally over only the ``window_size`` most recent positions. Each
    KV cache is built for, and only ever receives, the heads of its own type,
    so streaming heads never accumulate the full history.

    NOTE(review): ``FullKVCache`` and ``StreamingKVCache`` are defined outside
    this snippet. This class assumes that:
    - their first constructor argument is the number of heads they store;
    - ``cache.update(k, v)`` appends ``[B, h, T, D]`` tensors;
    - ``cache.k`` / ``cache.v`` return the cached ``[B, h, S, D]`` tensors in
      chronological order;
    - the streaming cache holds at most its ``window_size`` most recent
      positions.
    Confirm all of these.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        retrieval_head_indices: List[int],
        window_size: int = 512,
        dropout: float = 0.0,
    ):
        super().__init__()
        out_of_range = [i for i in retrieval_head_indices if not 0 <= i < num_heads]
        if out_of_range:
            raise ValueError(
                f"retrieval head indices must be in [0, {num_heads}), got {out_of_range}"
            )
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size

        # Head-type bookkeeping. forward() always gathers and scatters heads in
        # sorted order, so results never depend on set iteration order.
        self.retrieval_head_indices = set(retrieval_head_indices)
        self.streaming_head_indices = set(range(num_heads)) - self.retrieval_head_indices

        # Shared QKV and output projections for all heads.
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.v_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size)

        # One KV cache per head type, sized for that type's heads only.
        self.full_kv_cache = FullKVCache(len(self.retrieval_head_indices), head_dim)
        self.streaming_kv_cache = StreamingKVCache(
            len(self.streaming_head_indices), head_dim, window_size
        )

        # Applied to the attention weights.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        is_prefill: bool = True,
    ) -> torch.Tensor:
        """Run dual-path attention.

        Args:
            hidden_states: ``[B, T, hidden_size]``.
            attention_mask: accepted for interface compatibility but not used.
                Only the built-in causal and sliding-window masks are applied.
            use_cache: append to and read from the KV caches.
            is_prefill: when ``use_cache`` is set, process all ``T`` tokens
                (True), or only the last token as one decode step (False).

        Returns:
            ``[B, T_out, hidden_size]``, where ``T_out`` is 1 for a decode step
            and ``T`` otherwise.
        """
        B, T, _ = hidden_states.shape
        ret_idx = sorted(self.retrieval_head_indices)
        str_idx = sorted(self.streaming_head_indices)

        # [B, T, H*D] -> [B, H, T, D]
        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        decoding = use_cache and not is_prefill
        if decoding:
            # Decode step: only the newest token is attended from and cached.
            q, k, v = q[:, :, -1:, :], k[:, :, -1:, :], v[:, :, -1:, :]
        t_out = q.shape[2]

        output = torch.zeros_like(q)  # [B, H, T_out, D]

        if ret_idx:
            q_ret, k_ret, v_ret = q[:, ret_idx], k[:, ret_idx], v[:, ret_idx]
            if use_cache:
                self.full_kv_cache.update(k_ret, v_ret)
                k_ret, v_ret = self.full_kv_cache.k, self.full_kv_cache.v
            output[:, ret_idx] = self._full_attention(q_ret, k_ret, v_ret)

        if str_idx:
            q_str, k_str, v_str = q[:, str_idx], k[:, str_idx], v[:, str_idx]
            if not use_cache:
                # Without a cache every head sees the full causal context.
                out_str = self._full_attention(q_str, k_str, v_str)
            else:
                self.streaming_kv_cache.update(k_str, v_str)
                if decoding:
                    k_str, v_str = self.streaming_kv_cache.k, self.streaming_kv_cache.v
                # During prefill, attend over the fresh chunk instead of the cache:
                # the window cache may already have evicted keys that early
                # prompt positions still need.
                out_str = self._window_attention(q_str, k_str, v_str)
            output[:, str_idx] = out_str

        # [B, H, T_out, D] -> [B, T_out, H*D]
        output = output.transpose(1, 2).reshape(B, t_out, -1)
        return self.o_proj(output)

    def _full_attention(self, q, k, v):
        """Causal attention over every key; ``q`` holds the last positions of ``k``/``v``."""
        return self._attend(q, k, v, window=None)

    def _window_attention(self, q, k, v):
        """Causal attention limited to the ``window_size`` most recent keys per query."""
        return self._attend(q, k, v, window=self.window_size)

    def _attend(self, q, k, v, window: Optional[int] = None):
        """Masked scaled dot-product attention.

        ``q`` is ``[B, h, Tq, D]`` and ``k``/``v`` are ``[B, h, S, D]`` with
        ``Tq <= S``. Query ``i`` sits at key position ``S - Tq + i``, so one mask
        serves full prompts, chunked prefill and single-token decode alike.
        Every query can at least see itself, so no row is fully masked.
        """
        t_q, t_k = q.shape[-2], k.shape[-2]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.head_dim ** -0.5
        q_pos = torch.arange(t_k - t_q, t_k, device=q.device).unsqueeze(-1)  # [Tq, 1]
        k_pos = torch.arange(t_k, device=q.device).unsqueeze(0)  # [1, S]
        blocked = k_pos > q_pos  # future tokens
        if window is not None:
            blocked = blocked | (q_pos - k_pos >= window)  # outside the sliding window
        scores = scores.masked_fill(blocked, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, v)
5. KV Cache管理
5.1 分离缓存策略
class SeparatedKVCache:
    """Separated KV cache for DuoAttention.

    * Retrieval heads keep their full history in preallocated
      ``[num_heads, max_seq_len, head_dim]`` buffers.
    * Streaming heads keep the most recent ``window_size`` positions in
      ``[num_heads, window_size, head_dim]`` ring buffers.

    Both buffers are indexed by absolute head index. A head's row is written
    only in the buffer matching the type it is listed under in ``update``.
    Only batch element 0 of the incoming tensors is stored, so the cache
    assumes batch size 1. Buffers use ``torch.zeros`` defaults (default dtype
    and device); streaming writes are converted to the buffer's dtype/device.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        retrieval_head_indices: List[int],
        window_size: int = 512,
        max_seq_len: int = 65536,
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.max_seq_len = max_seq_len

        # Head partition used by update() when the caller passes no head_indices.
        retrieval = sorted(set(retrieval_head_indices))
        self.default_head_indices = {
            'retrieval': retrieval,
            'streaming': [h for h in range(num_heads) if h not in retrieval],
        }

        # Retrieval heads: preallocated full-history buffers.
        self.retrieval_k = torch.zeros(num_heads, max_seq_len, head_dim)
        self.retrieval_v = torch.zeros(num_heads, max_seq_len, head_dim)
        # Total number of tokens appended so far (also the streaming fill level).
        self.retrieval_len = 0

        # Streaming heads: sliding-window ring buffers.
        self.streaming_k = torch.zeros(num_heads, window_size, head_dim)
        self.streaming_v = torch.zeros(num_heads, window_size, head_dim)
        # Next slot to write; once the window is full it is also the oldest entry.
        self.streaming_ptr = 0

    def update(
        self,
        k_new: torch.Tensor,  # [B, H, T, D]
        v_new: torch.Tensor,
        head_indices: Optional[Dict[str, List[int]]] = None,
    ):
        """Append ``T`` new positions to the cache.

        Args:
            k_new, v_new: ``[B, H, T, D]`` keys/values; only batch element 0 is stored.
            head_indices: ``{'retrieval': [...], 'streaming': [...]}``. Defaults to
                the partition derived from ``retrieval_head_indices`` at construction.

        Raises:
            ValueError: if retrieval heads are listed and the write would
                exceed ``max_seq_len``. The cache is left unchanged in that case.
        """
        if head_indices is None:
            head_indices = self.default_head_indices
        T = k_new.shape[2]
        start, end = self.retrieval_len, self.retrieval_len + T
        retrieval = head_indices['retrieval']
        streaming = head_indices['streaming']

        # Validate before writing anything so a failed call has no side effects.
        if retrieval and end > self.max_seq_len:
            raise ValueError(
                f"retrieval cache overflow: {end} tokens > max_seq_len={self.max_seq_len}"
            )

        # Retrieval heads: append to the full-history buffer.
        for h in retrieval:
            self.retrieval_k[h, start:end] = k_new[0, h]
            self.retrieval_v[h, start:end] = v_new[0, h]

        # Streaming heads: only the last `kept` tokens can survive in the ring
        # buffer, so write each of them once, in one assignment per head.
        kept = min(T, self.window_size)
        if kept > 0:
            slots = (self.streaming_ptr + T - kept + torch.arange(kept)) % self.window_size
            for h in streaming:
                self.streaming_k[h, slots] = k_new[0, h, T - kept:].to(self.streaming_k)
                self.streaming_v[h, slots] = v_new[0, h, T - kept:].to(self.streaming_v)

        self.retrieval_len = end
        self.streaming_ptr = (self.streaming_ptr + T) % self.window_size

    def get_cache(
        self,
        head_type: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(k, v)`` for ``head_type`` as ``[1, num_heads, S, head_dim]``.

        ``'retrieval'`` returns all ``retrieval_len`` cached positions. Any
        other value means streaming: only the filled window slots are returned,
        oldest first (``S = min(retrieval_len, window_size)``), so never-written
        zero slots do not take part in attention.
        """
        if head_type == 'retrieval':
            k = self.retrieval_k[:, :self.retrieval_len, :]
            v = self.retrieval_v[:, :self.retrieval_len, :]
        elif self.retrieval_len < self.window_size:
            # Window not yet full: slots [0, retrieval_len) were filled in order.
            k = self.streaming_k[:, :self.retrieval_len, :]
            v = self.streaming_v[:, :self.retrieval_len, :]
        else:
            # Window full: the write pointer is the oldest slot; rotate it to the front.
            k = torch.roll(self.streaming_k, shifts=-self.streaming_ptr, dims=1)
            v = torch.roll(self.streaming_v, shifts=-self.streaming_ptr, dims=1)
        return k.unsqueeze(0), v.unsqueeze(0)
5.2 内存效率分析
| 配置 | 检索头缓存 | 流式头缓存 | 总内存 | 节省比例 |
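|---|---|---|---|---|
| 全部完整 | 100% | 100% | 100% | - |
| 全部窗口 | 0% | 100% | 50% | 50% |
| DuoAttention | 100% | 50% | 75% | 25% |
注:上表仅为示意,实际比例取决于检索头占比与窗口长度。设检索头占全部头的比例为 $r$、上下文长度为 $L$、流式头窗口大小为 $W$,则 DuoAttention 的 KV 缓存约为全量缓存的 $r + (1-r)\cdot\frac{\min(W, L)}{L}$。例如 $r = 0.25$、$W/L = 1/3$ 时约为 50%。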
6. 任务适配性
6.1 检索任务
class RetrievalTaskOptimizer:
    """Retrieval-task tuning for DuoAttention.

    ``adapt_for_retrieval`` does two things:
    1. It marks the document positions of the query's key entities as important
       via ``duo_attention.set_retrieval_importance``.
    2. It then runs retrieval.

    Entity extraction, position lookup and retrieval are hooks that a subclass
    must implement. The base class raises ``NotImplementedError`` for each one
    instead of failing with an ``AttributeError``.

    NOTE(review): ``duo_attention`` must provide
    ``set_retrieval_importance(position, weight)``; confirm the attention
    implementation in use exposes it.
    """

    def __init__(self, duo_attention):
        self.duo_attention = duo_attention
        # Retrieval tasks favour more retrieval heads (20%).
        # NOTE(review): this value is stored but not read anywhere in this class.
        self.adaptive_retrieval_ratio = 0.2

    def adapt_for_retrieval(
        self,
        query: str,
        document: str,
        top_k: int = 5
    ):
        """Boost retrieval importance at the query's entity positions, then retrieve.

        Args:
            query: the retrieval query.
            document: the text to retrieve from.
            top_k: forwarded unchanged to ``_retrieve``.

        Returns:
            Whatever ``_retrieve(document, top_k)`` returns.
        """
        # Identify the key entities / keywords of the query.
        key_entities = self._extract_entities(query)
        # Raise retrieval importance at every position where an entity occurs.
        for entity_pos in self._find_entity_positions(document, key_entities):
            self.duo_attention.set_retrieval_importance(entity_pos, 1.0)
        return self._retrieve(document, top_k)

    def _extract_entities(self, query: str):
        """Return the key entities/keywords of ``query`` (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement _extract_entities")

    def _find_entity_positions(self, document: str, key_entities):
        """Return an iterable of positions in ``document`` where ``key_entities`` occur (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement _find_entity_positions")

    def _retrieve(self, document: str, top_k: int):
        """Run retrieval over ``document`` and return the top ``top_k`` results (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement _retrieve")
6.2 流式生成
class StreamingTaskOptimizer:
    """Streaming-generation tuning for DuoAttention."""

    def __init__(self, duo_attention):
        self.duo_attention = duo_attention

    def adapt_for_streaming(
        self,
        window_size: int = 1024
    ):
        """Widen the sliding window used by streaming heads.

        Sets ``window_size`` on both the attention module and its
        ``streaming_kv_cache``, so the two can never disagree.

        NOTE(review): the retrieval-head ratio is not adjusted here. Streaming
        workloads suggest about 5% retrieval heads, but head classification is
        fixed when the attention module is built.
        NOTE(review): if the streaming cache preallocates a buffer sized by its
        window, assigning ``window_size`` alone may not resize it. Confirm.

        Raises:
            ValueError: if ``window_size`` is not positive.
        """
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}")
        self.duo_attention.window_size = window_size
        self.duo_attention.streaming_kv_cache.window_size = window_size
7. 实验结果
7.1 检索性能
| 方法 | Needle-in-Haystack | PassKey | 平均 |
|---|---|---|---|
| Full Attention | 98.2% | 97.5% | 97.9% |
| StreamingLLM | 72.3% | 68.1% | 70.2% |
| H2O | 91.5% | 89.2% | 90.4% |
| DuoAttention | 96.8% | 95.3% | 96.1% |
7.2 语言建模
| 方法 | WikiText困惑度 | Pile困惑度 | 相对变化(两数据集平均) |
|---|---|---|---|
| Full Attention | 12.45 | 8.92 | 基准 |
| StreamingLLM | 12.68 | 9.15 | +2.2% |
| H2O | 12.52 | 9.02 | +0.8% |
| DuoAttention | 12.48 | 8.98 | +0.5% |
7.3 内存效率
| 模型规模 | Full Attention | DuoAttention | 内存节省 |
|---|---|---|---|
| 7B | 80GB | 60GB | 25% |
| 13B | 130GB | 95GB | 27% |
| 70B | 280GB | 210GB | 25% |
8. 与其他方法的对比
| 维度 | DuoAttention | StreamingLLM | H2O | PyramidKV |
|---|---|---|---|---|
| 检索能力 | ★★★★★ | ★★☆☆☆ | ★★★★☆ | ★★★★☆ |
| 流式生成 | ★★★★★ | ★★★★★ | ★★★★☆ | ★★★★☆ |
| 内存效率 | ★★★★☆ | ★★★★★ | ★★★★☆ | ★★★★☆ |
| 实现复杂度 | 中 | 低 | 中 | 中 |
| 无需训练 | ✓(无需微调权重,但需用合成数据做轻量优化以识别检索头) | ✓ | ✓ | ✓ |
9. 实践指南
9.1 头分类方法
def auto_classify_heads(model, calibration_data, num_retrieval_heads=None):
    """Automatically identify retrieval heads from attention statistics.

    For every head index ``h``, two statistics are computed on the
    batch-averaged attention map of each layer. They are pooled over all
    layers and calibration samples:

    * entropy: mean over query rows of the row entropy (low = focused);
    * peak (``max_attn``): mean over query rows of the largest weight. Taking
      the max of the whole causal map would be useless, because row 0 is
      always ``[1, 0, ...]``.

    A head whose mean entropy is < 2.0 and whose mean peak is > 0.3 is a
    retrieval head. If ``num_retrieval_heads`` is given (0 included) and the
    threshold rule selects a different number of heads, the top
    ``num_retrieval_heads`` heads by mean peak are returned instead.

    Args:
        model: callable taking ``**data`` plus ``output_attentions=True`` and
            returning ``.attentions`` (one ``[B, H, T, T]`` tensor per layer);
            must expose ``.config.num_attention_heads``.
        calibration_data: iterable of keyword-argument dicts for ``model``.
        num_retrieval_heads: optional exact number of retrieval heads.

    Returns:
        List of retrieval head indices.
    """
    num_heads = model.config.num_attention_heads

    def compute_attention_entropy(attn_weights):
        """Mean row entropy as a float: low entropy = retrieval head, high = streaming head."""
        entropy = -(attn_weights * torch.log(attn_weights + 1e-10)).sum(dim=-1)
        return entropy.mean().item()

    # Collect per-head statistics (inference only, so no autograd graph is built).
    head_stats = {h: {'entropy': [], 'max_attn': []} for h in range(num_heads)}
    with torch.no_grad():
        for data in calibration_data:
            outputs = model(**data, output_attentions=True)
            for attn in outputs.attentions:
                avg_attn = attn.float().mean(dim=0)  # average over batch only: [H, T, T]
                for h in range(num_heads):
                    head_stats[h]['entropy'].append(compute_attention_entropy(avg_attn[h]))
                    head_stats[h]['max_attn'].append(
                        avg_attn[h].amax(dim=-1).mean().item()
                    )

    # Low entropy + high peak = retrieval head.
    retrieval_heads = [
        h for h, stats in head_stats.items()
        if np.mean(stats['entropy']) < 2.0 and np.mean(stats['max_attn']) > 0.3
    ]

    # If a head count was requested and the threshold rule disagrees,
    # take the top-k heads by mean peak attention instead.
    if num_retrieval_heads is not None and len(retrieval_heads) != num_retrieval_heads:
        ranked = sorted(
            head_stats,
            key=lambda h: np.mean(head_stats[h]['max_attn']),
            reverse=True,
        )
        retrieval_heads = ranked[:num_retrieval_heads]

    return retrieval_heads
9.2 配置推荐
# DuoAttention configuration presets:
#   retrieval_head_ratio - fraction of heads treated as retrieval heads
#   window_size          - sliding-window length for streaming heads

# General purpose: 10% retrieval heads.
general_config = dict(retrieval_head_ratio=0.1, window_size=512)

# Retrieval-heavy workloads: 20% retrieval heads, narrower window.
retrieval_config = dict(retrieval_head_ratio=0.2, window_size=256)

# Streaming-heavy workloads: 5% retrieval heads, wider window.
streaming_config = dict(retrieval_head_ratio=0.05, window_size=1024)
10. 总结
DuoAttention的核心贡献:
- 功能分离:识别并分离检索头和流式头的功能
- 针对性优化:不同类型的头使用不同的缓存策略
- 任务适配:可以根据任务类型调整头类型比例
- 无需微调模型权重:只需用合成数据做一次轻量级优化来识别检索头,之后即可在现有模型上即插即用
参考文献