概述

温度参数是控制注意力行为的核心超参数,它通过调节Softmax函数的平滑程度来影响注意力的锐度(sharpness)。温度与注意力的紧密相关,共同决定了模型的表达能力和泛化特性。12

本文系统性地分析:

  1. 温度的数学定义与几何意义
  2. 注意力熵的理论分析
  3. 温度-熵-模型行为的关系
  4. 自适应温度调度的实践方法

温度参数基础

标准注意力温度

Scaled Dot-Product Attention的温度形式:

其中 温度参数

极限行为分析

定理1(温度极限行为)

  1. (零温度)

    注意力退化为硬注意力(Hard Attention)——完全关注最相似的Key。

  2. (无穷温度)

    注意力退化为均匀注意力——等权重关注所有Key。

  3. (标准温度)
    使得初始化时注意力权重的方差为 ,避免饱和。

几何解释

温度 对应高斯核的标准差:

其中

  • :窄核,注意力聚焦
  • :宽核,注意力分散

注意力熵

定义

注意力熵衡量注意力分布的分散程度:

注意:使用自然对数,单位为nats;除以 可得bits。

熵的界

定理2(熵的极值)

对于长度为 的注意力分布:

  • :退化注意力,One-hot分布(
  • :均匀注意力(

熵的导数

温度对熵的影响:

结论:温度增加,熵单调增加(注意力更均匀)。


温度与模型行为

训练动力学

温度在训练不同阶段的作用:

阶段推荐温度注意力特点作用
早期均匀、探索发现多种模式、避免早熟收敛
中期中等适度聚焦学习主要依赖关系
后期锐利、利用精炼关键依赖、避免过拟合

泛化影响

定理3(温度与泛化)

在PAC-Bayes框架下,注意力的熵影响后验复杂度:

其中 随温度 变化。

直觉

  • 高温 → 高熵 → 更大的探索空间 → 更高的KL项
  • 低温 → 低熵 → 更集中的表示 → 更低的KL项

表示学习

温度影响学习的表示特性:

高温(高熵)

  • 学习更均匀的表示
  • 避免表示坍塌
  • 促进特征多样性

低温(低熵)

  • 学习更稀疏的表示
  • 提高判别能力
  • 可能导致过度自信

条件温度理论

任务自适应温度

不同注意力头可以学习不同的”最优温度”:

class ConditionalTemperatureAttention(torch.nn.Module):
    """
    条件温度注意力
    
    每个注意力头自适应学习其温度参数
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # 温度预测网络
        self.temp_net = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_model // 2),
            torch.nn.GELU(),
            torch.nn.Linear(d_model // 2, num_heads),
            torch.nn.Softplus()  # 确保温度为正
        )
        
        # 标准注意力参数
        self.q_proj = torch.nn.Linear(d_model, d_model)
        self.k_proj = torch.nn.Linear(d_model, d_model)
        self.v_proj = torch.nn.Linear(d_model, d_model)
        self.out_proj = torch.nn.Linear(d_model, d_model)
    
    def forward(self, x, context=None):
        batch_size, seq_len, _ = x.shape
        context = context if context is not None else x
        
        # 计算Query, Key, Value
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(context).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(context).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 预测温度(每个头独立)
        # 使用全局平均池化的特征预测温度
        pooled = x.mean(dim=1)  # (batch, d_model)
        raw_temp = self.temp_net(pooled)  # (batch, num_heads)
        
        # 将温度reshape为(batch, num_heads, 1, 1)用于广播
        temperature = raw_temp.view(batch_size, self.num_heads, 1, 1)
        
        # 计算注意力分数(使用条件温度)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5 * temperature)
        
        # Softmax归一化
        attn_weights = F.softmax(scores, dim=-1)
        
        # 输出
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        return self.out_proj(output), attn_weights

层级温度设计

不同层可以设计不同的温度分布:

class LayerwiseTemperatureSchedule:
    """
    层级温度调度
    
    设计原则:
    - 低层(捕获局部特征):较低温度
    - 高层(捕获全局依赖):较高温度
    """
    def __init__(self, num_layers, T_low=0.5, T_high=1.0, pattern='linear'):
        self.num_layers = num_layers
        self.T_low = T_low
        self.T_high = T_high
        self.pattern = pattern
    
    def get_temperature(self, layer_idx):
        ratio = layer_idx / max(1, self.num_layers - 1)
        
        if self.pattern == 'linear':
            # 线性增长
            return self.T_low + ratio * (self.T_high - self.T_low)
        elif self.pattern == 'u':
            # U型:中间低,两端高
            return self.T_high - 0.5 * (self.T_high - self.T_low) * abs(2 * ratio - 1)
        elif self.pattern == 'exp':
            # 指数增长
            return self.T_low * ((self.T_high / self.T_low) ** ratio)
        else:
            return 1.0

熵正则化

熵正则化目标

在损失函数中加入注意力熵项:

其中 控制熵正则化强度。

实现

def entropy_regularized_attention_loss(attn_weights, lambda_ent=0.01):
    """
    熵正则化损失
    
    Args:
        attn_weights: 注意力权重 (batch, num_heads, seq_len, seq_len)
        lambda_ent: 熵正则化系数
    
    Returns:
        熵正则化项
    """
    # 避免log(0)
    eps = 1e-10
    ent = -torch.sum(attn_weights * torch.log(attn_weights + eps), dim=-1)
    return -lambda_ent * ent.mean()
 
 
class EntropyRegularizedTransformer(torch.nn.Module):
    """
    带熵正则化的Transformer
    """
    def __init__(self, model, lambda_ent=0.01):
        super().__init__()
        self.model = model
        self.lambda_ent = lambda_ent
    
    def forward(self, input_ids, labels=None):
        outputs = self.model(input_ids, output_attentions=True)
        
        logits = outputs.logits
        
        if labels is not None:
            # 计算任务损失
            task_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1)
            )
            
            # 计算熵正则化项
            attn_weights = outputs.attentions  # tuple of (batch, num_heads, seq, seq)
            ent_loss = sum(entropy_regularized_attention_loss(w, self.lambda_ent) 
                          for w in attn_weights)
            
            total_loss = task_loss + ent_loss
            return {'loss': total_loss, 'task_loss': task_loss, 'ent_loss': ent_loss}
        
        return outputs

熵调度策略

class EntropySchedule:
    """
    熵/温度调度器
    
    策略1:预热-退火
    策略2:课程熵(从高到低)
    策略3:对比熵(控制注意力多样性)
    """
    def __init__(self, schedule_type='warmup_anneal', T_init=2.0, T_final=0.1):
        self.schedule_type = schedule_type
        self.T_init = T_init
        self.T_final = T_final
    
    def get_temperature(self, step, total_steps):
        ratio = step / total_steps
        
        if self.schedule_type == 'warmup_anneal':
            # 训练初期高温(探索),后期低温(利用)
            if ratio < 0.1:
                # 预热阶段:线性增加到初始温度
                warmup_ratio = ratio / 0.1
                return self.T_init * warmup_ratio
            else:
                # 退火阶段:从T_init降到T_final
                anneal_ratio = (ratio - 0.1) / 0.9
                return self.T_init - (self.T_init - self.T_final) * self._cosine_schedule(anneal_ratio)
        
        elif self.schedule_type == 'curriculum':
            # 课程学习:从高熵到低熵
            return self.T_init - (self.T_init - self.T_final) * ratio
        
        elif self.schedule_type == 'constant':
            return 1.0
        
        return 1.0
    
    def _cosine_schedule(self, t):
        return 0.5 * (1 + math.cos(math.pi * t))

熵作为可解释性指标

注意力熵监控

class AttentionEntropyMonitor:
    """
    注意力熵监控器
    
    用于分析训练过程中的注意力行为变化
    """
    def __init__(self, num_layers, num_heads):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.entropy_history = defaultdict(list)
    
    def compute_entropy(self, attn_weights):
        """
        计算注意力熵
        
        Args:
            attn_weights: (batch, num_heads, seq_len, seq_len)
        """
        eps = 1e-10
        ent = -torch.sum(attn_weights * torch.log(attn_weights + eps), dim=-1)
        # 平均到每个头
        ent_per_head = ent.mean(dim=(0, 2))  # (num_heads,)
        return ent_per_head
    
    def analyze(self, layer_idx, attn_weights):
        ent = self.compute_entropy(attn_weights)
        self.entropy_history[f'layer_{layer_idx}'].append(ent.detach().cpu().numpy())
    
    def get_statistics(self):
        stats = {}
        for key, values in self.entropy_history.items():
            values = np.array(values)
            stats[key] = {
                'mean': values.mean(axis=0),
                'std': values.std(axis=0),
                'trend': np.polyfit(range(len(values)), values.mean(axis=1), 1)[0]
            }
        return stats

熵异常检测

def detect_attention_anomaly(entropy_stats, threshold=2.0):
    """
    检测注意力熵异常
    
    高熵可能表示:信息未有效聚合、过度分散
    低熵可能表示:信息过度集中、表示坍塌风险
    """
    anomalies = []
    
    for layer_name, stats in entropy_stats.items():
        mean_ent = stats['mean']
        
        # 检测低熵异常(可能表示坍塌)
        if mean_ent < threshold:
            anomalies.append({
                'layer': layer_name,
                'type': 'low_entropy',
                'value': mean_ent,
                'message': f'Potential collapse: entropy={mean_ent:.3f} < {threshold}'
            })
        
        # 检测高熵异常(可能表示学习失败)
        if mean_ent > math.log(512) * 0.9:  # 假设max_seq_len=512
            anomalies.append({
                'layer': layer_name,
                'type': 'high_entropy', 
                'value': mean_ent,
                'message': f'Potential learning failure: entropy={mean_ent:.3f} too high'
            })
    
    return anomalies

理论深度:熵与信息论

互信息视角

注意力熵与互信息的关系:

注意力的熵 近似于给定Query时Key分布的条件熵。

速率-失真视角

从信息瓶颈理论:

  • 低温度(低熵):小 → 高压缩 → 潜在信息丢失
  • 高温度(高熵):大 → 低压缩 → 可能过拟合

最优温度理论

定理4(信息论最优温度)

设任务目标为最大化 ,则在温和条件下存在最优温度 使得:

该温度可通过验证集上的网格搜索或梯度下降估计。


实践指南

温度选择建议

场景建议温度原因
预训练1.0平衡探索与利用
微调(小数据)0.5-1.0减少过拟合风险
微调(大数据)0.1-0.5提高特定任务性能
长上下文< 1.0避免注意力稀释
检索增强< 1.0精确匹配任务

调试技巧

def debug_attention_entropy(model, batch, threshold_low=0.5, threshold_high=2.0):
    """
    调试注意力熵
    
    检查各层、各头的熵是否在合理范围内
    """
    model.eval()
    with torch.no_grad():
        outputs = model(**batch, output_attentions=True)
    
    entropies = []
    for layer_idx, attn in enumerate(outputs.attentions):
        ent = attention_entropy(attn)
        entropies.append((layer_idx, ent))
        
        # 检测异常
        for head_idx, h in enumerate(ent):
            if h < threshold_low:
                print(f"[WARN] Layer {layer_idx}, Head {head_idx}: Low entropy {h:.3f}")
            elif h > threshold_high:
                print(f"[WARN] Layer {layer_idx}, Head {head_idx}: High entropy {h:.3f}")
    
    return entropies

参考资料


相关文档

Footnotes

  1. [arXiv:2506.01562] “Temperature Scaling in Attention” - 系统分析温度对注意力行为的影响

  2. [arXiv:2412.16545] “Attention Entropy Predicts Model Behavior” - 注意力熵与模型行为的关联