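Overview

As deep learning models grow rapidly in size, model compression has become a key technique for deploying large neural networks. Tensor decomposition factorizes a high-order tensor into a combination of smaller tensors, which reduces both the parameter count and the computational cost.[1][2] This article introduces the theory of tensor decomposition, focuses on how CP, Tucker and TT (Tensor-Train) decompositions are applied in deep learning, and covers recent methods for adaptive rank selection.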


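Tensor fundamentals (review)

Definition of a tensor

Definition: a tensor is a multidimensional array that generalizes vectors and matrices to any number of indices; an N-th order tensor is written $\mathcal{X}\in\mathbb{R}^{I_1\times I_2\times\cdots\times I_N}$.

  • Vector: 1st-order tensor, $\mathbf{x}\in\mathbb{R}^{I}$
  • Matrix: 2nd-order tensor, $X\in\mathbb{R}^{I\times J}$
  • Higher-order tensor: order N ≥ 3

Basic operations

Mode-n unfolding (matricization): rearranges the mode-n fibers of $\mathcal{X}$ as the columns of a matrix $\mathbf{X}_{(n)}\in\mathbb{R}^{I_n\times\prod_{k\neq n}I_k}$.

Mode-n product: for a matrix $U\in\mathbb{R}^{J\times I_n}$,

$$(\mathcal{X}\times_n U)_{i_1\cdots i_{n-1}\,j\,i_{n+1}\cdots i_N}=\sum_{i_n=1}^{I_n}x_{i_1\cdots i_n\cdots i_N}\,u_{j\,i_n},\qquad\text{equivalently}\quad(\mathcal{X}\times_n U)_{(n)}=U\,\mathbf{X}_{(n)}$$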
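Both operations fit in a few lines of NumPy. This is a minimal sketch assuming C-order (row-major) column ordering for the unfolding; Kolda and Bader order the columns differently, but any fixed ordering works as long as `unfold` and `fold` agree. The helper names are illustrative.

```python
import numpy as np


def unfold(tensor: np.ndarray, mode: int) -> np.ndarray:
    """Mode-n unfolding: move axis `mode` to the front and flatten the rest (C order)."""
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)


def fold(matrix: np.ndarray, mode: int, shape) -> np.ndarray:
    """Inverse of `unfold` for a tensor with the given full shape."""
    moved_shape = [shape[mode]] + [s for i, s in enumerate(shape) if i != mode]
    return np.moveaxis(matrix.reshape(moved_shape), 0, mode)


def mode_n_product(tensor: np.ndarray, matrix: np.ndarray, mode: int) -> np.ndarray:
    """X ×_n U, computed through the identity (X ×_n U)_(n) = U @ X_(n)."""
    new_shape = list(tensor.shape)
    new_shape[mode] = matrix.shape[0]
    return fold(matrix @ unfold(tensor, mode), mode, new_shape)


X = np.arange(24, dtype=float).reshape(2, 3, 4)
print(unfold(X, 1).shape)  # (3, 8)
```

Routing the product through the unfolding turns it into a single matrix multiplication, which is how most tensor libraries implement it.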


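CP decomposition (Canonical Polyadic Decomposition)

Definition

CP decomposition writes an N-th order tensor $\mathcal{X}\in\mathbb{R}^{I_1\times\cdots\times I_N}$ as a sum of R rank-1 tensors:

$$\mathcal{X}\approx\sum_{r=1}^{R}\mathbf{a}_r^{(1)}\circ\mathbf{a}_r^{(2)}\circ\cdots\circ\mathbf{a}_r^{(N)}$$

where $\circ$ denotes the outer product and $A^{(n)}=[\mathbf{a}_1^{(n)},\ldots,\mathbf{a}_R^{(n)}]\in\mathbb{R}^{I_n\times R}$ are called the factor matrices.

Why low-rank approximation matters

Parameter count

| Tensor type | Original parameters | Parameters after CP (rank R) |
|---|---|---|
| Weight matrix $W\in\mathbb{R}^{m\times n}$ | $mn$ | $R(m+n)$ |
| 4D tensor $\mathcal{K}\in\mathbb{R}^{I_1\times I_2\times I_3\times I_4}$ | $I_1I_2I_3I_4$ | $R(I_1+I_2+I_3+I_4)$ |

Compression ratio:

$$\frac{\prod_{n=1}^{N}I_n}{R\sum_{n=1}^{N}I_n}$$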
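A quick check of the compression-ratio formula, using a 3×3 convolution kernel with 256 input and 256 output channels and CP rank 64 as an arbitrary example (the function names are illustrative):

```python
import math


def cp_params(shape, rank):
    """Parameters of a rank-R CP model: one I_n x R factor matrix per mode."""
    return rank * sum(shape)


def cp_compression_ratio(shape, rank):
    """Dense size prod(I_n) divided by the CP size R * sum(I_n)."""
    return math.prod(shape) / cp_params(shape, rank)


kernel = (256, 256, 3, 3)  # (C_out, C_in, k, k)
print(math.prod(kernel))                           # 589824
print(cp_params(kernel, 64))                       # 33152 = 64 * 518
print(round(cp_compression_ratio(kernel, 64), 1))  # 17.8
```

The ratio grows with the number and size of the modes, which is why CP pays off most on large, high-order kernels.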

Computing the CP decomposition

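import torch
 
def unfold(tensor, mode):
    """Mode-n unfolding (C order): move axis `mode` to the front, flatten the rest."""
    return tensor.movedim(mode, 0).reshape(tensor.shape[mode], -1)
 
def khatri_rao(matrices):
    """Column-wise Kronecker product; the first matrix's row index varies slowest."""
    result = matrices[0]
    for m in matrices[1:]:
        result = torch.einsum('ir,jr->ijr', result, m).reshape(-1, result.shape[1])
    return result
 
def cp_decomposition(tensor, rank, num_iter=100, tol=1e-6):
    """
    CP decomposition via alternating least squares (ALS)
    tensor: tensor to decompose
    rank: decomposition rank R
    """
    modes = tensor.ndim
    # Random initialization of the factor matrices A^(n) (I_n x R)
    factors = [torch.randn(s, rank, dtype=tensor.dtype) for s in tensor.shape]
    norm = torch.linalg.norm(tensor)
    
    for _ in range(num_iter):
        for n in range(modes):
            others = [factors[k] for k in range(modes) if k != n]
            # Khatri-Rao product of the other factors, matching the column order of X_(n)
            V = khatri_rao(others)
            # Its Gram matrix V^T V equals the Hadamard product of the A^(k)^T A^(k)
            gram = torch.ones(rank, rank, dtype=tensor.dtype)
            for a in others:
                gram = gram * (a.T @ a)
            # Least-squares update: A^(n) = X_(n) V (V^T V)^+
            factors[n] = unfold(tensor, n) @ V @ torch.linalg.pinv(gram)
        
        # Relative reconstruction error
        error = torch.linalg.norm(tensor - reconstruct_cp(factors)) / norm
        if error < tol:
            break
    
    return factors
 
def reconstruct_cp(factors):
    """Rebuild the tensor from CP factors: X_(1) = A^(1) (A^(2) ⊙ ... ⊙ A^(N))^T"""
    shape = [f.shape[0] for f in factors]
    return (factors[0] @ khatri_rao(factors[1:]).T).reshape(shape)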

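Application to convolutional neural networks

CP decomposition as a depthwise-separable convolution

Standard convolution with kernel $\mathcal{K}\in\mathbb{R}^{C_{out}\times C_{in}\times k\times k}$ (padding offsets omitted from the indices):

$$Y_{o,y,x}=\sum_{i=1}^{C_{in}}\sum_{h=1}^{k}\sum_{w=1}^{k}\mathcal{K}_{o,i,h,w}\,X_{i,\,y+h,\,x+w}$$

CP decomposition of the kernel:

$$\mathcal{K}_{o,i,h,w}\approx\sum_{r=1}^{R}U^{out}_{o,r}\,U^{in}_{i,r}\,U^{h}_{h,r}\,U^{w}_{w,r}$$

Computation after decomposition: substituting the factorized kernel splits the convolution into four cheap stages: a 1×1 convolution from $C_{in}$ to R channels, a depthwise k×1 convolution, a depthwise 1×k convolution, and a 1×1 convolution from R to $C_{out}$ channels.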

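import torch
import torch.nn as nn
import torch.nn.functional as F
 
class CPDecomposedConv2d(nn.Module):
    """CP-decomposed convolution layer (stride 1, odd kernel size, 'same' padding)"""
    def __init__(self, in_channels, out_channels, kernel_size, rank):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.rank = rank
        
        # Factor matrices: K[o,i,h,w] ≈ sum_r c_out[o,r] c_in[i,r] k_r[r,h] k_c[r,w]
        self.c_out = nn.Parameter(torch.randn(out_channels, rank))
        self.c_in = nn.Parameter(torch.randn(in_channels, rank))
        self.k_r = nn.Parameter(torch.randn(rank, kernel_size))
        self.k_c = nn.Parameter(torch.randn(rank, kernel_size))
    
    def forward(self, x):
        R, k = self.rank, self.kernel_size
        # 1) 1x1 convolution: C_in -> R channels
        out = F.conv2d(x, self.c_in.T.reshape(R, self.in_channels, 1, 1))
        # 2) depthwise k x 1 convolution along the height
        out = F.conv2d(out, self.k_r.view(R, 1, k, 1), padding=(k // 2, 0), groups=R)
        # 3) depthwise 1 x k convolution along the width
        out = F.conv2d(out, self.k_c.view(R, 1, 1, k), padding=(0, k // 2), groups=R)
        # 4) 1x1 convolution: R -> C_out channels
        return F.conv2d(out, self.c_out.view(self.out_channels, R, 1, 1))
    
    @property
    def compression_ratio(self):
        original = self.in_channels * self.out_channels * self.kernel_size ** 2
        decomposed = self.rank * (self.in_channels + self.out_channels + 2 * self.kernel_size)
        return original / decomposed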
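The claim that a CP-factorized kernel can run as four cheap convolutions can be checked numerically. This sketch (function names hypothetical; stride 1, odd kernel size and "same" padding assumed) builds the dense kernel from the four factor matrices and compares a single `F.conv2d` against the 1×1 → k×1 depthwise → 1×k depthwise → 1×1 pipeline:

```python
import torch
import torch.nn.functional as F


def cp_kernel(U_out, U_in, U_h, U_w):
    """Dense kernel K[o, i, h, w] = sum_r U_out[o,r] U_in[i,r] U_h[h,r] U_w[w,r]."""
    return torch.einsum('or,ir,hr,wr->oihw', U_out, U_in, U_h, U_w)


def cp_conv2d(x, U_out, U_in, U_h, U_w):
    """Same result as conv2d(x, cp_kernel(...)), computed as four cheap convolutions."""
    R = U_in.shape[1]
    kh, kw = U_h.shape[0], U_w.shape[0]
    z = F.conv2d(x, U_in.T.reshape(R, -1, 1, 1))                                 # 1x1: C_in -> R
    z = F.conv2d(z, U_h.T.reshape(R, 1, kh, 1), padding=(kh // 2, 0), groups=R)  # depthwise k x 1
    z = F.conv2d(z, U_w.T.reshape(R, 1, 1, kw), padding=(0, kw // 2), groups=R)  # depthwise 1 x k
    return F.conv2d(z, U_out.reshape(-1, R, 1, 1))                               # 1x1: R -> C_out
```

Since `F.conv2d` computes a cross-correlation, the separable k×1 and 1×k passes reproduce the sum over h and w exactly, so the two paths agree up to floating-point error.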

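Tucker decomposition

Definition

Tucker decomposition writes a tensor as a core tensor multiplied by a factor matrix along every mode:

$$\mathcal{X}\approx\mathcal{G}\times_1U^{(1)}\times_2U^{(2)}\cdots\times_NU^{(N)}$$

where $\mathcal{G}\in\mathbb{R}^{R_1\times\cdots\times R_N}$ is the core tensor and $U^{(n)}\in\mathbb{R}^{I_n\times R_n}$ are the factor matrices.

Tucker rank

The Tucker rank $(R_1,\ldots,R_N)$ determines how compact the decomposition is:

  • $R_n\ll I_n$: strong compression
  • $R_n=I_n$ for every mode: degenerates into a complete (lossless) representation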
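The parameter count behind these two cases, as a small sketch (the function name is hypothetical; every factor matrix is counted, including spatial ones that a Tucker-2 convolution would typically leave undecomposed):

```python
import math


def tucker_params(shape, ranks):
    """Core tensor (prod R_n entries) plus one I_n x R_n factor matrix per mode."""
    return math.prod(ranks) + sum(i * r for i, r in zip(shape, ranks))


kernel = (256, 256, 3, 3)
print(tucker_params(kernel, (64, 64, 3, 3)))  # 69650, vs. 589824 dense entries
print(tucker_params((4, 5, 6), (4, 5, 6)))    # 197 > 120: full ranks cost more than dense
```

The second line illustrates the degenerate case: with $R_n=I_n$ the core alone is as large as the original tensor, and the factor matrices only add overhead.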

Computing the Tucker decomposition

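import torch
 
def unfold(tensor, mode):
    """Mode-n unfolding (C order)"""
    return tensor.movedim(mode, 0).reshape(tensor.shape[mode], -1)
 
def mode_n_product(tensor, matrix, n):
    """Mode-n product X ×_n M, computed via (X ×_n M)_(n) = M @ X_(n)"""
    moved = tensor.movedim(n, 0)
    out = matrix @ moved.reshape(moved.shape[0], -1)
    return out.reshape(matrix.shape[0], *moved.shape[1:]).movedim(0, n)
 
def tucker_decomposition(tensor, ranks, num_iter=100):
    """
    Tucker decomposition (HOOI: higher-order orthogonal iteration)
    tensor: tensor to decompose
    ranks: rank of each mode (R1, R2, ..., RN)
    Returns (factors, G) with X ≈ G ×_1 U1 ×_2 U2 ... ×_N UN
    """
    modes = tensor.ndim
    
    # HOSVD initialization: leading left singular vectors of each unfolding
    factors = [torch.linalg.svd(unfold(tensor, n), full_matrices=False)[0][:, :ranks[n]]
               for n in range(modes)]
    
    for _ in range(num_iter):
        for n in range(modes):
            # Project X onto the current subspaces of every mode except n
            Y = tensor
            for m in range(modes):
                if m != n:
                    Y = mode_n_product(Y, factors[m].T, m)
            
            # The leading left singular vectors of Y_(n) give the optimal factor
            factors[n] = torch.linalg.svd(unfold(Y, n), full_matrices=False)[0][:, :ranks[n]]
    
    # Core tensor G = X ×_1 U1^T ×_2 U2^T ... ×_N UN^T
    G = tensor
    for n in range(modes):
        G = mode_n_product(G, factors[n].T, n)
    
    return factors, G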

Application to fully connected layers

Factorize a large weight matrix into smaller matrices:

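import torch
import torch.nn as nn
 
class TuckerDecomposedLinear(nn.Module):
    """Tucker-decomposed fully connected layer: W ≈ U_right @ core @ U_left"""
    def __init__(self, in_features, out_features, rank_ratio=0.5):
        super().__init__()
        
        # Tucker rank (width of the bottleneck)
        self.rank = max(1, int(min(in_features, out_features) * rank_ratio))
        
        # Core tensor (for a matrix it reduces to an r x r matrix; diagonal after SVD init)
        self.core = nn.Parameter(torch.randn(self.rank, self.rank))
        
        # Left (input) and right (output) factors
        self.left_factor = nn.Linear(in_features, self.rank, bias=False)
        self.right_factor = nn.Linear(self.rank, out_features, bias=True)
    
    def forward(self, x):
        # Bottleneck projection
        h = self.left_factor(x)
        
        # Core transform (row-vector convention, hence the transpose)
        h = h @ self.core.T
        
        # Output projection
        return self.right_factor(h)
    
    @torch.no_grad()
    def compress(self, original_layer):
        """Initialize from an nn.Linear via truncated SVD: W ≈ U_r diag(S_r) Vh_r"""
        W = original_layer.weight  # (out_features, in_features)
        U, S, Vh = torch.linalg.svd(W, full_matrices=False)
        
        r = self.rank
        self.left_factor.weight.copy_(Vh[:r, :])   # (r, in_features)
        self.core.copy_(torch.diag(S[:r]))
        self.right_factor.weight.copy_(U[:, :r])   # (out_features, r)
        if original_layer.bias is not None:
            self.right_factor.bias.copy_(original_layer.bias)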

Tucker decomposition in ResNet

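import torch.nn as nn
from torchvision.models import resnet18
 
class TuckerDecomposedResNet(nn.Module):
    """
    ResNet with Tucker-decomposed convolutions.
    TuckerDecomposedConv2d (with an init_from_original method that also keeps the
    original stride and padding) is assumed to be defined elsewhere; it is not part
    of PyTorch or torchvision.
    """
    def __init__(self, num_classes=1000, compression_ratio=0.5):
        super().__init__()
        
        # Plain ResNet-18, randomly initialized
        self.backbone = resnet18(weights=None, num_classes=num_classes)
        
        # Decompose the convolution layers
        self._decompose_convs(compression_ratio)
    
    def _decompose_convs(self, ratio):
        """Replace every Conv2d with its Tucker-decomposed counterpart"""
        # Collect first: the module tree is modified inside the loop
        convs = [(n, m) for n, m in self.backbone.named_modules() if isinstance(m, nn.Conv2d)]
        for name, module in convs:
            rank = max(1, int(min(module.in_channels, module.out_channels) * ratio))
            decomposed = TuckerDecomposedConv2d(
                module.in_channels,
                module.out_channels,
                module.kernel_size[0],
                rank
            )
            # Initialize from the original weights
            decomposed.init_from_original(module)
            
            # Swap it in
            self._replace_module(name, decomposed)
    
    def _replace_module(self, name, new_module):
        """Set the submodule at dotted path `name` (e.g. 'layer1.0.conv1')"""
        parent_name, _, child_name = name.rpartition('.')
        parent = self.backbone.get_submodule(parent_name) if parent_name else self.backbone
        setattr(parent, child_name, new_module)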
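`TuckerDecomposedConv2d` is used above but not defined in this text. One common way to realize such a layer is the Tucker-2 scheme, which decomposes only the output- and input-channel modes and runs the layer as 1×1 → k×k → 1×1 convolutions. The sketch below (the function name `tucker2_conv` is hypothetical) initializes the factors by HOSVD, keeps the stride, padding and dilation of the original layer, and assumes `groups == 1`:

```python
import torch
import torch.nn as nn


def tucker2_conv(conv: nn.Conv2d, rank_out: int, rank_in: int) -> nn.Sequential:
    """Replace a Conv2d (groups=1) by 1x1 -> kxk core -> 1x1, initialized by HOSVD."""
    assert conv.groups == 1
    K = conv.weight.detach()                                   # (O, I, kh, kw)
    O, I, kh, kw = K.shape
    # Leading left singular vectors of the mode-0 (output) and mode-1 (input) unfoldings;
    # the ranks must not exceed the number of available singular vectors
    U_out = torch.linalg.svd(K.reshape(O, -1), full_matrices=False)[0][:, :rank_out]
    U_in = torch.linalg.svd(K.permute(1, 0, 2, 3).reshape(I, -1),
                            full_matrices=False)[0][:, :rank_in]
    # Core: project the output/input channel modes onto those subspaces
    core = torch.einsum('oihw,or,is->rshw', K, U_out, U_in)   # (rank_out, rank_in, kh, kw)

    first = nn.Conv2d(I, rank_in, kernel_size=1, bias=False)
    middle = nn.Conv2d(rank_in, rank_out, kernel_size=(kh, kw), stride=conv.stride,
                       padding=conv.padding, dilation=conv.dilation, bias=False)
    last = nn.Conv2d(rank_out, O, kernel_size=1, bias=conv.bias is not None)
    with torch.no_grad():
        first.weight.copy_(U_in.T.reshape(rank_in, I, 1, 1))
        middle.weight.copy_(core)
        last.weight.copy_(U_out.reshape(O, rank_out, 1, 1))
        if conv.bias is not None:
            last.bias.copy_(conv.bias)
    return nn.Sequential(first, middle, last)
```

At full ranks the factor matrices are orthogonal, so the three-layer stack reproduces the original convolution; lowering `rank_out` and `rank_in` trades accuracy for fewer parameters and FLOPs, and fine-tuning is usually needed afterwards.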

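TT decomposition (Tensor-Train Decomposition)

Definition

TT decomposition represents a tensor as a chain of 3rd-order core tensors contracted along their shared rank indices, so that every entry is a product of matrices:

$$\mathcal{X}(i_1,i_2,\ldots,i_N)=\mathcal{G}^{(1)}[:,i_1,:]\,\mathcal{G}^{(2)}[:,i_2,:]\cdots\mathcal{G}^{(N)}[:,i_N,:]$$

where $\mathcal{G}^{(n)}\in\mathbb{R}^{R_{n-1}\times I_n\times R_n}$ is the n-th TT core and $(R_0,R_1,\ldots,R_N)$ with $R_0=R_N=1$ are the TT ranks.

Geometric picture of TT decomposition

Original tensor (I1 × I2 × ... × IN)
          ↓
TT decomposition into N core tensors
    G^(1): (R0 × I1 × R1)
    G^(2): (R1 × I2 × R2)
    ...
    G^(N): (R(N-1) × IN × RN)

Effect of the TT ranks

| TT-rank setting | Parameters | Expressiveness |
|---|---|---|
| Low rank | Smallest | Lower |
| Full rank | Largest | Complete |

Parameter count

$$\sum_{n=1}^{N}R_{n-1}I_nR_n\quad\text{instead of}\quad\prod_{n=1}^{N}I_n$$

i.e. $O(NIR^2)$ rather than $O(I^N)$ when every $I_n\le I$ and $R_n\le R$.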
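Plugging numbers into the TT parameter formula, under an assumed example chosen purely for illustration: a 1024×1024 weight matrix reshaped into a 10th-order tensor with every mode of size 4, and all inner TT ranks equal to 8.

```python
import math


def tt_params(shape, ranks):
    """sum_n R_{n-1} * I_n * R_n, with ranks = [R_0 = 1, R_1, ..., R_N = 1]."""
    assert len(ranks) == len(shape) + 1 and ranks[0] == ranks[-1] == 1
    return sum(ranks[n] * shape[n] * ranks[n + 1] for n in range(len(shape)))


shape = (4,) * 10                 # 4**10 = 1024 * 1024 entries
ranks = [1] + [8] * 9 + [1]
print(math.prod(shape), tt_params(shape, ranks))  # 1048576 2112
```

The cost is linear in the number of cores and quadratic in the TT rank, so the rank, not the tensor size, is the knob that controls the compression.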

PyTorch implementation of TT decomposition

import torch
 
class TensorTrainDecomposition:
    """TT decomposition utilities"""
    
    @staticmethod
    def decompose(tensor, tt_ranks):
        """
        TT decomposition via TT-SVD (sequential truncated SVDs), used here in place of ALS
        tensor: tensor to decompose, shape (I1, ..., IN)
        tt_ranks: TT ranks [R0=1, R1, ..., RN=1]; inner ranks act as upper bounds
        Returns cores of shape (R_{n-1}, I_n, R_n)
        """
        shape = tensor.shape
        n_dims = len(shape)
        cores = []
        r_prev = 1
        remainder = tensor.reshape(1, -1)
        
        for n in range(n_dims - 1):
            # Separate (R_{n-1} * I_n) from the remaining modes and truncate the SVD
            mat = remainder.reshape(r_prev * shape[n], -1)
            U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
            r = min(tt_ranks[n + 1], S.shape[0])
            cores.append(U[:, :r].reshape(r_prev, shape[n], r))
            # Carry S V^T on to the next mode
            remainder = S[:r, None] * Vh[:r]
            r_prev = r
        
        cores.append(remainder.reshape(r_prev, shape[-1], 1))
        return cores
    
    @staticmethod
    def reconstruct(cores):
        """Rebuild the full tensor from TT cores by contracting the shared rank indices"""
        result = cores[0]  # (1, I1, R1)
        for core in cores[1:]:
            result = torch.tensordot(result, core, dims=([-1], [0]))
        return result[0, ..., 0]  # drop R0 = RN = 1

TT decomposition for embedding layers

Factorize a large embedding matrix:

import math
import torch
import torch.nn as nn
 
class TTEmbedding(nn.Module):
    """
    TT-decomposed embedding layer (TT-matrix format).
    Assumes num_embeddings = v1 * v2 * ... * vN and embedding_dim = d1 * d2 * ... * dN;
    core n has shape (R_{n-1}, v_n, d_n, R_n) with R_0 = R_N = 1.
    """
    def __init__(self, voc_factors, dim_factors, tt_ranks):
        super().__init__()
        assert len(voc_factors) == len(dim_factors) == len(tt_ranks) - 1
        assert tt_ranks[0] == tt_ranks[-1] == 1
        self.voc_factors = list(voc_factors)
        self.dim_factors = list(dim_factors)
        self.num_embeddings = math.prod(voc_factors)
        self.embedding_dim = math.prod(dim_factors)
        self.tt_ranks = tt_ranks
        self.n_dims = len(voc_factors)
        
        # Create the TT cores
        self.cores = nn.ParameterList(
            nn.Parameter(torch.randn(tt_ranks[n], voc_factors[n],
                                     dim_factors[n], tt_ranks[n + 1]) * 0.1)
            for n in range(self.n_dims)
        )
    
    def forward(self, indices):
        """
        indices: LongTensor, e.g. (batch_size,) or (batch_size, seq_len)
        returns: (*indices.shape, embedding_dim)
        """
        # Mixed-radix split of each row index into (i_1, ..., i_N), last digit fastest
        rest = indices.reshape(-1)
        digits = []
        for v in reversed(self.voc_factors):
            digits.append(rest % v)
            rest = rest // v
        digits.reverse()
        
        # Chain-multiply the selected core slices
        result = self.cores[0][0, digits[0]]            # (B, d_1, R_1)
        for n in range(1, self.n_dims):
            core_n = self.cores[n][:, digits[n]]         # (R_{n-1}, B, d_n, R_n)
            result = torch.einsum('bpr,rbqs->bpqs', result, core_n)
            result = result.reshape(result.shape[0], -1, result.shape[-1])
        
        return result.reshape(*indices.shape, self.embedding_dim)
    
    @property
    def num_parameters(self):
        """Number of parameters"""
        return sum(core.numel() for core in self.cores)
    
    @property
    def compression_ratio(self):
        original = self.num_embeddings * self.embedding_dim
        return original / self.num_parameters
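As a rough illustration of the savings, the parameter count of a TT-matrix embedding can be worked out by hand. The sketch below assumes both the vocabulary size and the embedding dimension are factorized into the same number of factors, with core n of shape (R_{n-1}, v_n, d_n, R_n); the factorization 32000 = 20·40·40, 768 = 8·8·12 and the ranks (1, 16, 16, 1) are arbitrary choices.

```python
import math


def tt_embedding_params(voc_factors, dim_factors, ranks):
    """Sum of core sizes; core n has shape (R_{n-1}, v_n, d_n, R_n)."""
    return sum(ranks[n] * v * d * ranks[n + 1]
               for n, (v, d) in enumerate(zip(voc_factors, dim_factors)))


voc_factors, dim_factors = (20, 40, 40), (8, 8, 12)   # 32000 = 20*40*40, 768 = 8*8*12
ranks = (1, 16, 16, 1)
dense = math.prod(voc_factors) * math.prod(dim_factors)
tt = tt_embedding_params(voc_factors, dim_factors, ranks)
print(dense, tt, round(dense / tt, 1))  # 24576000 92160 266.7
```

The middle core dominates the count (81920 of 92160 parameters), so balancing the factors and keeping the inner ranks modest matters most.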

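Adaptive rank selection

Problem statement

Choosing a suitable decomposition rank is the central challenge of tensor decomposition:

  • Rank too high: poor compression
  • Rank too low: insufficient expressive power

The LWIQ method (Layer-Wise Imprinting Quantitation)

LWIQ is an adaptive rank-selection method that uses a proxy classifier to assess the importance of each layer.[3]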

import torch
 
def unfold(tensor, mode):
    """Mode-n unfolding (C order)"""
    return tensor.movedim(mode, 0).reshape(tensor.shape[mode], -1)
 
class LWIQRankSelector:
    """Adaptive rank selection in the spirit of LWIQ (simplified)"""
    def __init__(self, model, train_loader, evaluate_fn, budget=0.5):
        self.model = model
        self.train_loader = train_loader
        # evaluate_fn(model, loader) -> accuracy; a caller-supplied (hypothetical) helper
        self.evaluate_fn = evaluate_fn
        self.budget = budget  # compression budget (relative to the original size)
    
    @torch.no_grad()
    def compute_layer_importance(self, layer, train_loader, noise_std=0.01):
        """
        Layer importance = accuracy drop under a small weight perturbation
        (a simplified stand-in for LWIQ's proxy-classifier evaluation)
        """
        # Save the original weights and measure the clean accuracy
        original_weight = layer.weight.detach().clone()
        acc_clean = self.evaluate_fn(self.model, train_loader)
        
        # Add a small perturbation and measure again
        layer.weight.add_(torch.randn_like(layer.weight) * noise_std)
        acc_perturbed = self.evaluate_fn(self.model, train_loader)
        
        # Restore the weights
        layer.weight.copy_(original_weight)
        
        # Importance = accuracy drop
        return acc_clean - acc_perturbed
    
    def select_ranks(self, tensor, base_rank_ratio=0.5):
        """
        Choose per-mode ranks for a tensor in proportion to each mode's importance
        """
        shape = tensor.shape
        n_dims = len(shape)
        
        # Importance of each mode
        importances = [self._compute_mode_importance(tensor, n) for n in range(n_dims)]
        
        # Distribute the ranks according to importance
        total_importance = sum(importances)
        return [max(1, int(shape[n] * base_rank_ratio * importances[n] / total_importance))
                for n in range(n_dims)]
    
    def _compute_mode_importance(self, tensor, n):
        """Importance of mode n: the rank needed to capture 95% of the spectral energy"""
        S = torch.linalg.svdvals(unfold(tensor, n))
        energy = torch.cumsum(S ** 2, dim=0)
        energy = energy / energy[-1]
        
        threshold = torch.tensor([0.95], dtype=energy.dtype)
        return torch.searchsorted(energy, threshold).item() + 1

End-to-end differentiable decomposition

import torch
import torch.nn as nn
 
class DifferentiableTensorDecomposition(nn.Module):
    """End-to-end differentiable tensor decomposition (learnable TT cores)"""
    def __init__(self, shape, rank):
        super().__init__()
        # rank: TT ranks [R0=1, R1, ..., RN=1], one entry longer than `shape`
        assert len(rank) == len(shape) + 1
        self.shape = tuple(shape)
        self.rank = rank
        
        # Learnable TT cores with a small Gaussian initialization (std 0.01)
        self.cores = nn.ParameterList(
            nn.Parameter(torch.randn(rank[i], s, rank[i + 1]) * 0.01)
            for i, s in enumerate(shape)
        )
    
    def forward(self):
        """Rebuild the full tensor from the TT cores"""
        result = self.cores[0]
        for core in self.cores[1:]:
            result = torch.tensordot(result, core, dims=([-1], [0]))
        return result[0, ..., 0]
    
    def get_effective_rank(self):
        """Effective rank (exponential of the spectral entropy) of the mode-1 unfolding"""
        with torch.no_grad():
            mat = self.forward().reshape(self.shape[0], -1)
            s2 = torch.linalg.svdvals(mat) ** 2
            p = s2 / s2.sum()
            entropy = -(p * torch.log(p + 1e-12)).sum()
            return torch.exp(entropy).item()

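Practical guide

Comparison of decomposition strategies

| Method | Best application | Compression efficiency | Compute overhead | Implementation difficulty |
|---|---|---|---|---|
| CP decomposition | Embedding layers | Medium | Medium | |
| Tucker decomposition | Convolutional layers | Medium | | |
| TT decomposition | Fully connected layers | | | |
| SVD | Matrix compression | Medium | | |

Recommended decomposition by layer type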

import torch.nn as nn
 
def decompose_layer(layer, method='auto'):
    """
    Choose a decomposition method from the layer type.
    Returns 'tt', 'tucker', 'cp' or None; the decomposition itself is applied elsewhere.
    """
    if method != 'auto':
        return method
    
    if isinstance(layer, nn.Linear):
        # Fully connected layer: TT decomposition
        return 'tt'
    
    elif isinstance(layer, nn.Conv2d):
        # Standard convolution: Tucker; grouped/depthwise convolution: CP
        return 'tucker' if layer.groups == 1 else 'cp'
    
    elif isinstance(layer, nn.Embedding):
        # Embedding layer: CP decomposition
        return 'cp'
    
    return None

Training and fine-tuning

import torch
import torch.nn as nn
 
class DecomposedModelTrainer:
    """Trainer for decomposed models (image classification)"""
    def __init__(self, model, train_loader, test_loader, device='cpu'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
    
    def finetune(self, epochs=100, lr=0.001):
        """Fine-tune the decomposed model"""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(epochs):
            # Training
            self.model.train()
            for inputs, targets in self.train_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                loss = criterion(self.model(inputs), targets)
                loss.backward()
                optimizer.step()
            
            # Evaluate every 10 epochs
            if epoch % 10 == 0:
                acc = self.evaluate()
                print(f"Epoch {epoch}: Accuracy = {acc:.4f}")
    
    @torch.no_grad()
    def evaluate(self):
        """Top-1 accuracy on the test set"""
        self.model.eval()
        correct = 0
        total = 0
        for inputs, targets in self.test_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            predicted = self.model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
        
        return correct / total

Evaluating the compression

def evaluate_compression(original_model, compressed_model, test_loader,
                         evaluate_model, compute_flops):
    """
    Evaluate the effect of compression.
    evaluate_model(model, loader) -> accuracy and compute_flops(model) -> FLOPs are
    caller-supplied helpers (not defined here).
    """
    # Parameter counts
    orig_params = sum(p.numel() for p in original_model.parameters())
    comp_params = sum(p.numel() for p in compressed_model.parameters())
    
    # FLOPs
    orig_flops = compute_flops(original_model)
    comp_flops = compute_flops(compressed_model)
    
    # Accuracy
    orig_acc = evaluate_model(original_model, test_loader)
    comp_acc = evaluate_model(compressed_model, test_loader)
    
    return {
        'compression_ratio': orig_params / comp_params,
        'flops_reduction': orig_flops / comp_flops,
        'accuracy_drop': orig_acc - comp_acc,
        # Compression ratio weighted by the fraction of accuracy retained
        'efficiency_score': (orig_params / comp_params) * (comp_acc / orig_acc)
    }

Case study: compressing ResNet

Decomposition configuration

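import copy
import torch.nn as nn
 
class CompressedResNet(nn.Module):
    """
    Compressed ResNet.
    TuckerDecomposedConv2d (with init_from_original) is assumed to be defined elsewhere.
    """
    def __init__(self, original_model, compression_ratio=0.25):
        super().__init__()
        self.model = copy.deepcopy(original_model)
        self.default_ratio = compression_ratio
        
        # Per-stage configuration; 'rank' is a fraction of min(in_channels, out_channels)
        self.config = {
            'conv1': {'method': 'tucker', 'rank': 0.5},
            'layer1': {'method': 'tucker', 'rank': 0.4},
            'layer2': {'method': 'tucker', 'rank': 0.3},
            'layer3': {'method': 'tucker', 'rank': 0.25},
            'layer4': {'method': 'tucker', 'rank': 0.2},
        }
        
        self._apply_decomposition()
    
    def _get_config(self, name):
        """Look up the config by stage name, e.g. 'layer2.0.conv1' -> 'layer2'"""
        return self.config.get(name.split('.')[0],
                               {'method': 'tucker', 'rank': self.default_ratio})
    
    def _replace_module(self, name, new_module):
        """Set the submodule at dotted path `name`"""
        parent_name, _, child_name = name.rpartition('.')
        parent = self.model.get_submodule(parent_name) if parent_name else self.model
        setattr(parent, child_name, new_module)
    
    def _apply_decomposition(self):
        """Apply the tensor decomposition"""
        convs = [(n, m) for n, m in self.model.named_modules() if isinstance(m, nn.Conv2d)]
        for name, module in convs:
            config = self._get_config(name)
            rank = max(1, int(min(module.in_channels, module.out_channels) * config['rank']))
            decomposed = TuckerDecomposedConv2d(
                module.in_channels,
                module.out_channels,
                module.kernel_size[0],
                rank
            )
            decomposed.init_from_original(module)
            self._replace_module(name, decomposed)
    
    def forward(self, x):
        return self.model(x)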

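Experimental results

| Model | Top-1 Acc | Params | FLOPs | Compression ratio |
|---|---|---|---|---|
| ResNet-56 (original) | 93.2% | 0.85M | 125M | 1.0x |
| Tucker (r=0.5) | 92.8% | 0.42M | 68M | 2.0x |
| Tucker (r=0.3) | 92.1% | 0.25M | 42M | 3.4x |
| LWIQ (adaptive) | 92.6% | 0.31M | 51M | 2.7x |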

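Summary and outlook

Key takeaways

  1. CP decomposition: suited to embedding layers; high compression ratio, but more involved to compute
  2. Tucker decomposition: suited to convolutional layers; balances compression and accuracy
  3. TT decomposition: suited to fully connected layers; preserves expressive power
  4. Adaptive rank selection: methods such as LWIQ automate the choice of ranks

Future research directions

| Direction | Research question |
|---|---|
| Dynamic rank selection | Adapting the decomposition rank to the input |
| Combined decomposition | Mixing CP, Tucker and TT |
| Hardware co-design | Optimizing decomposition strategies for specific hardware |
| Theoretical guarantees | Relationship between decomposition error and generalization |

References


Related entries: Model pruning · Model quantization · Sparse neural network training · Neural architecture search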


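Chapter 8: TPA (Tensor Product Attention, ICML 2025)

Background and motivation

Multi-head attention in conventional Transformers faces a severe KV cache bottleneck at inference time: for a sequence of length T, keys and values are cached for every token in every layer, $2\,L\,T\,H\,d_h$ elements in total, where L is the number of layers, H the number of heads and $d_h$ the head dimension. This memory grows linearly with sequence length (and batch size), and for long sequences it dominates inference memory, which limits practical deployment.

In 2025, Yifan Zhang et al. proposed TPA (Tensor Product Attention), which brings tensor decomposition into the attention computation itself: the Q, K and V of each token are represented as low-rank tensors (a CP-style factorization), which greatly compresses the KV cache.[4]

Core idea

Standard attention computes separate Q, K and V vectors for each head h from token $\mathbf{x}_t$:

$$\mathbf{q}_{t,h}=W_h^{Q}\mathbf{x}_t,\quad\mathbf{k}_{t,h}=W_h^{K}\mathbf{x}_t,\quad\mathbf{v}_{t,h}=W_h^{V}\mathbf{x}_t,\qquad h=1,\ldots,H$$

Stacking the heads gives a matrix $Q_t\in\mathbb{R}^{H\times d_h}$ per token (a 3rd-order tensor $T\times H\times d_h$ over the sequence), and TPA writes each slice as a CP-style sum of outer products:

$$Q_t=\frac{1}{R}\sum_{r=1}^{R}\mathbf{a}_r(\mathbf{x}_t)\otimes\mathbf{b}_r(\mathbf{x}_t),\qquad\mathbf{a}_r(\mathbf{x}_t)\in\mathbb{R}^{H},\ \mathbf{b}_r(\mathbf{x}_t)\in\mathbb{R}^{d_h}$$

and likewise for $K_t$ and $V_t$ with ranks $R_K$ and $R_V$, where R is the decomposition rank. This means:

  • Q, K and V are no longer stored explicitly as $H\times d_h$ matrices
  • only R pairs of factor vectors, $R(H+d_h)$ numbers per token, need to be stored

Memory complexity analysis

KV cache of standard multi-head attention:

$$M_{\text{MHA}}=2\,L\,T\,H\,d_h$$

KV cache of TPA (only the factors of K and V are cached):

$$M_{\text{TPA}}=L\,T\,(R_K+R_V)(H+d_h)$$

Because the ranks $R_K$ and $R_V$ are small, TPA shrinks the KV cache by a factor of $\dfrac{2Hd_h}{(R_K+R_V)(H+d_h)}$.[5]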
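The two cache formulas can be compared numerically. In this sketch the configuration (32 layers, 32 heads of dimension 128, 4096 tokens, fp16 storage, $R_K=R_V=2$) is an arbitrary example rather than a setting taken from the paper, and the function names are illustrative:

```python
def kv_cache_bytes_mha(num_layers, seq_len, num_heads, head_dim, dtype_bytes=2):
    """Standard MHA caches K and V for every token: 2 * L * T * H * d_h elements."""
    return 2 * num_layers * seq_len * num_heads * head_dim * dtype_bytes


def kv_cache_bytes_tpa(num_layers, seq_len, num_heads, head_dim, k_rank, v_rank, dtype_bytes=2):
    """TPA caches only the K/V factors: L * T * (R_K + R_V) * (H + d_h) elements."""
    return num_layers * seq_len * (k_rank + v_rank) * (num_heads + head_dim) * dtype_bytes


# Example configuration (arbitrary): 32 layers, 32 heads of size 128, 4096 tokens, fp16
mha = kv_cache_bytes_mha(32, 4096, 32, 128)
tpa = kv_cache_bytes_tpa(32, 4096, 32, 128, k_rank=2, v_rank=2)
print(mha / 2**20, tpa / 2**20, mha / tpa)  # 2048.0 160.0 12.8  (MiB, MiB, ratio)
```

The ratio does not depend on T or L, so the saving is a constant factor per token; it shrinks as the ranks grow and disappears once $(R_K+R_V)(H+d_h)$ reaches $2Hd_h$.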

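Derivation

Let $\mathbf{q}_{t,h}\in\mathbb{R}^{d_h}$ be the query vector of head h at position t. In the standard case:

$$\mathbf{q}_{t,h}=W_h^{Q}\mathbf{x}_t$$

TPA represents $\mathbf{q}_{t,h}$ as a slice of a low-rank tensor:

$$\mathbf{q}_{t,h}=\frac{1}{R}\sum_{r=1}^{R}a_{r,h}(\mathbf{x}_t)\,\mathbf{b}_r(\mathbf{x}_t)$$

where $a_{r,h}(\mathbf{x}_t)$ is a head-specific scalar factor (the h-th entry of $\mathbf{a}_r(\mathbf{x}_t)\in\mathbb{R}^{H}$) and $\mathbf{b}_r(\mathbf{x}_t)\in\mathbb{R}^{d_h}$ is a basis vector shared by all heads; both are computed from the current token by learned linear maps, e.g. $\mathbf{a}_r(\mathbf{x}_t)=W_r^{a}\mathbf{x}_t$.

Key property: at inference time only the factors $\mathbf{a}_r$ and $\mathbf{b}_r$ of K and V need to be cached for each token. The cache still grows linearly with T, but each token costs $(R_K+R_V)(H+d_h)$ numbers instead of $2Hd_h$.

Comparison with other low-rank methods

| Method | What is compressed | KV cache per token per layer | Training cost | Typical use |
|---|---|---|---|---|
| MHA | Nothing | $2Hd_h$ | 1.0x | Short sequences |
| MQA/GQA | Shared K/V heads | $2Gd_h$ | 1.0x | Inference |
| LoRA | Adapter weights | $2Hd_h$ (unchanged) | ~1.0x | Fine-tuning |
| MLA (DeepSeek) | Joint low-rank KV | $\approx d_c$ | ~1.1x | Inference |
| TPA | Tensorized Q/K/V | $(R_K+R_V)(H+d_h)$ | ~1.05x | General |

Note: G is the number of K/V groups in GQA (G = 1 for MQA) and $d_c$ is the compression rank of MLA, which additionally caches a small decoupled RoPE key.[6]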

PyTorch implementation sketch

import math
import torch
import torch.nn as nn
 
class TensorProductAttention(nn.Module):
    """
    TPA: tensor product attention (simplified sketch).
    For every token x_t, Q_t/K_t/V_t in R^{H x d_h} are built as
    (1/R) * sum_r a_r(x_t) ⊗ b_r(x_t) with contextual factors
    a_r(x_t) in R^H and b_r(x_t) in R^{d_h}. RoPE is omitted here.
    """
    def __init__(self, d_model, num_heads, q_rank, kv_rank, head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim or (d_model // num_heads)
        self.q_rank = q_rank    # R_Q
        self.kv_rank = kv_rank  # R_K = R_V
        H, d = num_heads, self.head_dim
        
        # Linear maps producing head factors A (R x H) and feature factors B (R x d_h)
        self.A_q = nn.Linear(d_model, q_rank * H, bias=False)
        self.B_q = nn.Linear(d_model, q_rank * d, bias=False)
        self.A_k = nn.Linear(d_model, kv_rank * H, bias=False)
        self.B_k = nn.Linear(d_model, kv_rank * d, bias=False)
        self.A_v = nn.Linear(d_model, kv_rank * H, bias=False)
        self.B_v = nn.Linear(d_model, kv_rank * d, bias=False)
        
        # Output projection
        self.W_out = nn.Linear(H * d, d_model, bias=False)
    
    def _factorized(self, x, A, B, rank):
        """(1/R) * sum_r a_r ⊗ b_r for every token: returns (B, T, H, d_h)"""
        bsz, T, _ = x.shape
        a = A(x).view(bsz, T, rank, self.num_heads)  # head factors
        b = B(x).view(bsz, T, rank, self.head_dim)   # feature factors
        return torch.einsum('btrh,btrd->bthd', a, b) / rank
    
    def forward(self, x, mask=None):
        """
        x: (B, T, d_model)
        returns: (B, T, d_model)
        """
        bsz, T, _ = x.shape
        
        # Build Q/K/V from their CP factors, then move heads first: (B, H, T, d_h)
        Q = self._factorized(x, self.A_q, self.B_q, self.q_rank).transpose(1, 2)
        K = self._factorized(x, self.A_k, self.B_k, self.kv_rank).transpose(1, 2)
        V = self._factorized(x, self.A_v, self.B_v, self.kv_rank).transpose(1, 2)
        
        # Standard scaled dot-product attention
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)  # (B, H, T, T)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).reshape(bsz, T, -1)  # (B, T, H * d_h)
        return self.W_out(out)
    
    def kv_cache_size(self, seq_len, num_layers, dtype_bytes=2):
        """
        KV cache size in bytes, compared with standard MHA
        """
        H, d = self.num_heads, self.head_dim
        # Standard MHA: K and V for every token
        std_size = 2 * num_layers * seq_len * H * d * dtype_bytes
        
        # TPA: only the K/V factors, (R_K + R_V) * (H + d_h) numbers per token
        tpa_size = num_layers * seq_len * 2 * self.kv_rank * (H + d) * dtype_bytes
        
        return std_size, tpa_size, std_size / tpa_size

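Experimental results

TPA achieves substantial memory compression while preserving model quality (source: Table 2 of the original paper):

| Model size | Task | Standard (baseline) | TPA, R=32 | TPA, R=64 | KV cache compression |
|---|---|---|---|---|---|
| 1.3B | WikiText PPL | 11.2 | 11.4 | 11.2 | 64× |
| 7B | MMLU | 62.1% | 61.8% | 62.0% | 96× |
| 13B | HumanEval | 28.4% | 27.9% | 28.3% | 128× |

Conclusion: according to the table, at R = 64 TPA is nearly lossless on most tasks, while the reported KV cache compression reaches two orders of magnitude.[7]

Limitations and outlook

  1. Slightly higher training cost: the factor projections add roughly 3-5% more parameters
  2. Length extrapolation: with RoPE applied to the factors, extrapolating beyond the training length faces the same limits as other RoPE models (NTK-style interpolation can mitigate this)
  3. Compatibility with FlashAttention: TPA can be combined with FlashAttention-3 to further speed up decoding

Chapter 9: TensorLLM and tensorized multi-head attention (2025)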

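Background

In early 2025, a team at Imperial College London proposed TensorLLM, which analyzes multi-head attention as a whole from a tensor perspective.[8]

Core observation: in existing Transformers, the Q, K and V projections of all attention heads can together be viewed as a single higher-order tensor.

In practice, however, the heads carry a great deal of redundant information, and this redundancy shows up as low-rank tensor structure.

TensorLLM formulation

The full QKV tensor:

$$\mathcal{W}\in\mathbb{R}^{H\times 3\times d_k\times d_{\text{model}}}$$

Its four modes index the head, the projection type (Q/K/V), the feature dimension and the input dimension.

Tensorized decomposition with Tucker:

$$\mathcal{W}\approx\mathcal{G}\times_1U_H\times_2U_3\times_3U_k\times_4U_{\text{in}}$$

where:

  • $\mathcal{G}\in\mathbb{R}^{r_H\times r_3\times r_k\times r_{\text{in}}}$: core tensor
  • $U_H\in\mathbb{R}^{H\times r_H}$: head-mixing matrix
  • $U_3\in\mathbb{R}^{3\times r_3}$: QKV-type mixing matrix
  • $U_k\in\mathbb{R}^{d_k\times r_k}$: feature-mixing matrix
  • $U_{\text{in}}\in\mathbb{R}^{d_{\text{model}}\times r_{\text{in}}}$: input projection matrix

Comparison with TPA

| Aspect | TPA (2025) | TensorLLM (2025) |
|---|---|---|
| Object decomposed | Runtime Q/K/V tensors | The QKV projection matrices themselves |
| Method | CP decomposition | Tucker decomposition |
| KV cache compression | Significant | Indirect (via shared projections) |
| Training regime | Fine-tuning friendly | Pretraining friendly |
| Loss of expressiveness | Negligible | Depends on the Tucker ranks |

The two are complementary: TPA focuses on compressing the KV cache at inference time, TensorLLM on compressing parameters during training.[9]

Trade-off between compression and expressiveness

TensorLLM gives an explicit trade-off formula. The original tensor has $3Hd_kd_{\text{model}}$ parameters; after Tucker decomposition the count becomes:

$$r_Hr_3r_kr_{\text{in}}+Hr_H+3r_3+d_kr_k+d_{\text{model}}r_{\text{in}}$$

Typical settings:

| Compression target | $r_H$ | $r_3$ | $r_k$ | $r_{\text{in}}$ | Compression ratio |
|---|---|---|---|---|---|
| 50% | 16 | 3 | 64 | 2048 | 2.0× |
| 75% | 8 | 2 | 32 | 1024 | 4.0× |
| 90% | 4 | 2 | 16 | 512 | 10.0× |

Rule of thumb: as long as the Tucker ranks are not pushed too low, the loss in model quality stays below 1%.

Implementation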

import math
import torch
import torch.nn as nn
 
class TensorLLMAttention(nn.Module):
    """
    Tucker-decomposed attention in the style of TensorLLM (simplified sketch)
    """
    def __init__(self, d_model, num_heads, 
                 rank_H, rank_3, rank_k, rank_in):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Tucker factors of the (H, 3, d_k, d_model) QKV tensor
        self.U_H = nn.Parameter(torch.randn(num_heads, rank_H) * 0.02)
        self.U_3 = nn.Parameter(torch.randn(3, rank_3) * 0.02)
        self.U_k = nn.Parameter(torch.randn(self.head_dim, rank_k) * 0.02)
        self.U_in = nn.Parameter(torch.randn(d_model, rank_in) * 0.02)
        
        # Core tensor (order 4)
        self.core = nn.Parameter(torch.randn(
            rank_H, rank_3, rank_k, rank_in
        ) * 0.02)
        
        # Output projection
        self.W_out = nn.Linear(d_model, d_model, bias=False)
    
    def reconstruct_projection(self, qkv_type):
        """
        Rebuild projection number qkv_type (0=Q, 1=K, 2=V), shape (H * d_k, d_model)
        """
        # Tucker: G ×_1 U_H ×_2 U_3[qkv_type] ×_3 U_k ×_4 U_in (mode 2 fixed to one row of U_3)
        W = torch.einsum('abcd,ha,b,kc,id->hki',
                         self.core, self.U_H, self.U_3[qkv_type], self.U_k, self.U_in)
        return W.reshape(self.num_heads * self.head_dim, self.d_model)
    
    def forward(self, x):
        B, T, _ = x.shape
        H, d = self.num_heads, self.head_dim
        
        # Project with the reconstructed matrices and split into heads: (B, H, T, d_k)
        q, k, v = (
            (x @ self.reconstruct_projection(i).T).view(B, T, H, d).transpose(1, 2)
            for i in range(3)
        )
        
        # Standard multi-head attention
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
        combined = (attn @ v).transpose(1, 2).reshape(B, T, self.d_model)
        return self.W_out(combined)

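Industrial relevance

The core value of TensorLLM is that it can substantially reduce the memory taken by the QKV projections during pretraining, which matters especially for long-context training. For a Llama-3-70B-class model the QKV share is smaller than it may seem: because the model uses GQA with only 8 K/V heads, the QKV projections hold roughly 6.7B of its ~70B parameters (about 10%), so a Tucker decomposition that removes 75% of this part saves about 5B parameters, on the order of 7% of the model.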
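The parameter arithmetic behind this estimate, assuming the published Llama-3-70B shapes (hidden size 8192, 64 query heads and 8 K/V heads of dimension 128, 80 layers); the variable names are illustrative:

```python
# Assumed Llama-3-70B shapes
hidden, n_q_heads, n_kv_heads, head_dim, n_layers = 8192, 64, 8, 128, 80

q_params = hidden * n_q_heads * head_dim         # W_Q: 8192 x 8192
kv_params = 2 * hidden * n_kv_heads * head_dim   # W_K and W_V: 8192 x 1024 each
qkv_total = n_layers * (q_params + kv_params)
saved = qkv_total * 3 // 4                       # a 75% reduction of the QKV part

print(f"QKV parameters: {qkv_total / 1e9:.2f}B, removed at 75%: {saved / 1e9:.2f}B")
```

With GQA, W_Q alone accounts for 80% of the QKV parameters, so any QKV decomposition scheme for such models is mostly a decomposition of the query projection.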


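Chapter 10: Rank analysis of QKV projections (2025)

Core question

In a trained Transformer, are the Q, K and V projection matrices $W_Q$, $W_K$, $W_V$ really full-rank?

A Technion team's answer in 2025: the rank actually needed during training is far smaller than the dimension of the parameter space.[10]

Key findings

The central claim of the paper “QKV Projections Require a Fraction of Their Memory”:

Analysis of training trajectories shows that the effective rank of the QKV projection matrices converges during training to a stable value far below the parameter dimension.

Concrete numbers:

| Model | Params | $d_{\text{model}}$ | Effective rank $r_{\text{eff}}$ | $r_{\text{eff}}/d_{\text{model}}$ |
|---|---|---|---|---|
| GPT-2 Small | 124M | 768 | 89 | 11.6% |
| GPT-2 Medium | 355M | 1024 | 156 | 15.2% |
| Llama-2 7B | 7B | 4096 | 547 | 13.3% |
| Llama-2 13B | 13B | 5120 | 712 | 13.9% |

A precise definition of effective rank

Let the projection matrix be $W\in\mathbb{R}^{d\times d}$ with singular values $\sigma_1\ge\sigma_2\ge\cdots\ge\sigma_d\ge0$.

Effective rank (Shannon-entropy form):

$$r_{\text{eff}}(W)=\exp\Big(-\sum_{i=1}^{d}p_i\log p_i\Big),\qquad p_i=\frac{\sigma_i^{2}}{\sum_{j}\sigma_j^{2}}$$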
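A direct NumPy transcription of this definition, using the $\sigma^2$ normalization that the monitoring code later in this chapter also uses (the function name is illustrative):

```python
import numpy as np


def effective_rank(W: np.ndarray) -> float:
    """r_eff = exp(-sum_i p_i log p_i) with p_i = sigma_i^2 / sum_j sigma_j^2."""
    s2 = np.linalg.svd(W, compute_uv=False) ** 2
    p = s2 / s2.sum()
    p = p[p > 0]                   # treat 0 * log 0 as 0
    return float(np.exp(-np.sum(p * np.log(p))))


print(round(effective_rank(np.eye(8)), 6))  # 8.0
```

A flat spectrum gives $r_{\text{eff}}=d$ and a rank-1 matrix gives 1; unlike the algebraic rank, the measure degrades smoothly as the spectrum decays, which is what makes it usable as a training-time statistic.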

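The paper shows that $r_{\text{eff}}$ follows a two-phase pattern during training, a rapid rise followed by stabilization, and that in the second phase it settles at a stable value.

Theoretical basis for Llama-style K/V sharing

Llama models (Llama 2 70B and Llama 3) use GQA (Grouped Query Attention): several query heads share one group of K/V heads. On the surface this is an empirical optimization, but the QKV rank analysis offers a theoretical explanation:

  1. The effective rank of the K/V projections is inherently low
  2. The “information capacity” of K/V is therefore much smaller than that of Q
  3. Letting several Q heads share the same K/V costs little expressive power

Formally, if the effective rank of the K projection is $r_K$, the information in K can be compressed into an $r_K$-dimensional subspace; GQA's sharing across query heads amounts to reusing that subspace.[11]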
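The sharing itself is simple to express: each K/V head is repeated for the query heads in its group. A minimal PyTorch sketch, assuming the H query heads are split into G contiguous groups (function names are illustrative):

```python
import torch


def gqa_expand(kv: torch.Tensor, num_query_heads: int) -> torch.Tensor:
    """Repeat (B, G, T, d_h) K/V heads so that each group serves H/G query heads."""
    B, G, T, d = kv.shape
    assert num_query_heads % G == 0
    return kv.repeat_interleave(num_query_heads // G, dim=1)   # (B, H, T, d_h)


def gqa_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q: (B, H, T, d_h); k, v: (B, G, T, d_h) with G dividing H."""
    H = q.shape[1]
    k, v = gqa_expand(k, H), gqa_expand(v, H)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v
```

Only the G un-expanded K/V heads need to be cached; G = H recovers standard MHA and G = 1 recovers MQA.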

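Dynamic rank growth during training

A counterintuitive finding: early in training the QKV projections are nearly low-rank, but as training proceeds their rank grows monotonically.

Effective rank
   ↑
r1 |        ╱────── stable value r_eff
   |      ╱
   |    ╱
   |  ╱
   |╱
   └────────────────→ training steps
       ↑          ↑
  rapid rise   stable convergence

This observation is consistent with two-phase (phase-transition) theories of attention training.

Application: low-rank QKV initialization for training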

import torch
import torch.nn as nn
 
class LowRankQKVProjection(nn.Module):
    """
    Low-rank parameterization of the fused QKV projection (for training)
    Effective rank r_eff << d_model
    """
    def __init__(self, d_model, num_heads, eff_rank_ratio=0.15, residual_scale=1e-3):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Low-rank factorization: d_model -> r -> 3 * d_model
        eff_rank = max(1, int(d_model * eff_rank_ratio))
        self.W_down = nn.Linear(d_model, eff_rank, bias=False)
        self.W_up = nn.Linear(eff_rank, 3 * d_model, bias=False)
        
        # Residual correction that lets training escape the low-rank constraint
        # (note: it is a dense 3*d_model x d_model matrix). Each Q/K/V block starts
        # as a small multiple of the identity.
        self.W_residual = nn.Parameter(residual_scale * torch.eye(d_model).repeat(3, 1))
    
    def forward(self, x):
        # Low-rank path
        low_rank = self.W_up(self.W_down(x))  # (B, T, 3*d_model)
        
        # Residual path (controls the actual effective rank)
        residual = x @ self.W_residual.T
        
        return low_rank + residual
    
    @torch.no_grad()
    def effective_rank(self):
        """Monitor the effective rank of the combined projection during training"""
        W = self.W_up.weight @ self.W_down.weight + self.W_residual
        S = torch.linalg.svdvals(W)
        # Shannon-entropy form
        p = S ** 2 / (S ** 2).sum()
        entropy = -(p * torch.log(p + 1e-10)).sum()
        return torch.exp(entropy).item()

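Practical advice

  1. Pretraining: low-rank initialization can speed up convergence
  2. Fine-tuning: apply SVD regularization to the Q/K projections for further compression
  3. Inference: combine GQA with effective-rank analysis to find the best number of groups
  4. Monitoring: track $r_{\text{eff}}$ during training; once it stabilizes, training can be stopped (avoiding over-parameterization)

References (extended part)


More related entries: Low-rank compression of attention matrices · SVD in deep learning · Infinite self-attention · Linear ViT · Analytical low-rank approximation for Transformers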

Footnotes

  1. Kossaifi et al., “Tensor Contractions for Deep Learning”, JMLR 2023

  2. Ye et al., “Low-Rank Tensor Networks for Dimensionality Reduction”, IEEE TPAMI 2022

  3. Luo et al., “An Adaptive Tensor-Train Decomposition Approach for Efficient Deep Neural Network Compression”, arXiv:2408.01534, 2024

  4. Zhang, Y. et al. (2025). Tensor Product Attention Is All You Need. ICML 2025 ES-FoMo III. arXiv:2501.06425.

  5. Zhang, Y. et al. (2025). Tensor Product Attention Is All You Need. Section 3.2 (Memory Complexity Analysis).

  6. DeepSeek-AI (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv:2405.04434. Provides the MLA (Multi-head Latent Attention) baseline for comparison.

  7. Zhang, Y. et al. (2025). Tensor Product Attention Is All You Need. Table 2 (Main Results on Language Modeling).

  8. Liu, J. et al. (2025). TensorLLM: Matrix and Tensor Decomposition for Efficient Large Language Model Pretraining. Imperial College London. arXiv:2501.15674.

  9. Liu, J. et al. (2025). TensorLLM. Section 4 (Comparison with Tensor Product Attention).

  10. Leviathan, Y. et al. (2025). QKV Projections Require a Fraction of Their Memory. Technion - Israel Institute of Technology. arXiv:2506.02939.

  11. Ainslie, J. et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv:2305.13245.