1. 引言

权重表示学习(Weight Representation Learning)是将神经网络的权重向量视为高维空间中的点,通过学习其在某种度量下的表示来实现模型分析、压缩、迁移等任务的技术范式。与传统关注输入空间或特征空间的表示学习不同,权重空间表示学习将焦点转向模型参数本身的几何性质。

这一研究方向的兴起源于两个核心观察:

  1. 参数化模型的等价性:神经网络存在大量参数化对称性(如权重缩放、排列等),导致参数空间中存在大量等价的权重配置
  2. 损失景观的几何结构:损失函数在权重空间中呈现出复杂的几何结构,理解这一结构对于理解泛化、优化和迁移至关重要

2. 权重空间的度量与几何结构

2.1 距离度量

权重空间中的距离度量是理解和分析模型行为的基础工具。不同的度量捕捉了权重的不同方面。

欧氏距离

最简单的度量是两个权重向量之间的欧氏距离:

欧氏距离的优势在于计算简单且直观,但它忽略了权重的统计特性,且对尺度敏感。

余弦相似度

余弦相似度衡量两个权重向量的方向一致性:

余弦相似度对权重的尺度不敏感,专注于方向信息,这在分析权重演化轨迹时特别有用。

Fisher-Rao 度量

Fisher-Rao 度量是信息几何中的自然度量,它利用Fisher信息矩阵定义黎曼度量:

两个分布 之间的Fisher-Rao距离定义为:

对于高斯分布 ,Fisher-Rao距离有闭式解:

2.2 权重空间的流形结构

神经网络的权重空间并非平坦的欧几里得空间,而是嵌入在高维空间中的黎曼流形。这种流形结构源于:

  1. 激活函数的非线性:ReLU、GeLU等激活函数将参数空间划分为多个线性区域
  2. 对称性群作用:权重空间中存在由置换群、缩放群等组成的等价类
  3. 损失函数的曲率:损失景观中的曲率变化反映了权重的内在几何性质

子流形结构

为所有可实现的输入-输出映射的集合,权重空间 通过映射 关联到函数空间。权重空间中的一条曲线 对应函数空间中的曲线

切空间 由所有可行的权重扰动方向构成,其维度等于自由参数数量。

3. SANE方法详解

SANE(Scalable and Versatile Weight Space Learning)是一种将权重空间表示学习扩展到大规模模型的方法框架。1

3.1 方法概述

SANE的核心思想是将预训练模型集合映射到一个紧凑的表示空间,使得在该空间中可以高效执行模型发现、比较和组合等操作。

给定 个预训练模型 ,每个模型 由权重 参数化。SANE学习一个表示函数 ,将权重映射到低维表示空间

3.2 表示学习目标

SANE的表示学习目标结合了三个组件:

其中:

  1. 功能一致性损失 :确保相似的功能在表示空间中接近

其中 衡量两个模型输出的相似度。

  1. 几何一致性损失 :保留权重空间的几何结构
  1. 正则化损失 :防止表示空间坍缩

3.3 规模化技术

处理大规模模型时,SANE采用以下技术:

层次化表示:不直接处理完整权重向量,而是提取层次化表示:

其中 是第 层的权重, 是层特定的表示函数。

随机投影:对于极大规模的模型,使用随机投影加速:

4. 权重空间的流形结构分析

4.1 线性模式连通性

线性模式连通性(Linear Mode Connectivity)指出,在特定条件下,两个训练好的模型可以通过线性路径连接而不经历显著的损失增加。2

为两个局部最优点,定义插值路径:

如果对于所有 都有 ,则称 -线性连通的。

4.2 子流形维度估计

权重空间中有效自由度(子流形维度)的估计是一个重要问题。使用PCA或固有维度估计方法:

其中 是权重矩阵协方差矩阵的特征值。

4.3 曲率分析

损失景观的曲率由Hessian矩阵描述:

Hessian的特征值分解揭示了权重空间的局部几何:

  • 负曲率方向(负特征值):指向更差的局部极大点
  • 正曲率方向(大正特征值):指向尖锐的局部极小
  • 零曲率方向(零特征值):对应冗余参数或对称性

5. 权重表示的维度约简技术

5.1 主成分分析(PCA)

PCA通过线性投影将高维权重映射到低维空间:

其中 包含前 个主成分, 是权重均值。

5.2 自编码器方法

变分自编码器(VAE)在权重空间学习紧凑表示:

编码器

解码器

损失函数

5.3 对比学习

对比权重表示学习使用InfoNCE损失:

正样本对 可以来自:

  • 相同训练轨迹的不同检查点
  • 具有相似功能的模型
  • 同一任务的微调变体

6. 与模型可解释性的联系

6.1 功能聚类

权重表示学习使得功能相似但参数不同的模型能够被聚类:

这种聚类在以下场景中特别有用:

  • 模型选择:从大量微调模型中选择最适合特定任务的
  • 集成学习:选择多样且有效的模型组合
  • 知识迁移:识别哪些模型包含可迁移的知识

6.2 特征归因

在权重表示空间中分析单个神经元或权重通道的贡献:

6.3 因果干预分析

通过在表示空间中的插值和外推来理解功能变化:

分析 的模式可以揭示权重变化如何影响模型行为。

7. 代码示例

以下代码展示如何计算权重间的相似度并执行维度约简:

import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import numpy as np
 
class WeightSpaceMetrics:
    """权重空间度量工具类"""
    
    @staticmethod
    def euclidean_distance(theta1: torch.Tensor, theta2: torch.Tensor) -> float:
        """计算两个权重向量之间的欧氏距离"""
        return torch.norm(theta1 - theta2).item()
    
    @staticmethod
    def cosine_similarity(theta1: torch.Tensor, theta2: torch.Tensor) -> float:
        """计算两个权重向量之间的余弦相似度"""
        dot_prod = torch.dot(theta1, theta2)
        norm_prod = torch.norm(theta1) * torch.norm(theta2)
        return (dot_prod / (norm_prod + 1e-8)).item()
    
    @staticmethod
    def frobenius_distance(theta1: torch.Tensor, theta2: torch.Tensor) -> float:
        """计算两个权重矩阵之间的Frobenius距离"""
        return torch.norm(theta1 - theta2, p='fro').item()
    
    @staticmethod
    def extract_weight_vector(model: nn.Module) -> torch.Tensor:
        """从模型中提取展平的权重向量"""
        weights = []
        for param in model.parameters():
            weights.append(param.data.flatten())
        return torch.cat(weights)
    
    @staticmethod
    def layer_wise_similarity(theta1: torch.Tensor, theta2: torch.Tensor, 
                             layer_shapes: list) -> list:
        """计算每层的权重相似度"""
        similarities = []
        idx = 0
        for shape in layer_shapes:
            size = np.prod(shape)
            layer1 = theta1[idx:idx+size]
            layer2 = theta2[idx:idx+size]
            sim = WeightSpaceMetrics.cosine_similarity(layer1, layer2)
            similarities.append(sim)
            idx += size
        return similarities
 
 
class WeightSpaceAnalyzer:
    """权重空间分析器"""
    
    def __init__(self, models: list):
        """
        初始化分析器
        
        Args:
            models: 模型列表,每个模型用于提取权重
        """
        self.models = models
        self.weight_vectors = [
            WeightSpaceMetrics.extract_weight_vector(m) for m in models
        ]
        self.weight_matrix = torch.stack(self.weight_vectors)
    
    def compute_distance_matrix(self, metric: str = 'euclidean') -> np.ndarray:
        """计算权重之间的距离矩阵"""
        n = len(self.weight_vectors)
        dist_matrix = np.zeros((n, n))
        
        for i in range(n):
            for j in range(i + 1, n):
                if metric == 'euclidean':
                    d = WeightSpaceMetrics.euclidean_distance(
                        self.weight_vectors[i], self.weight_vectors[j]
                    )
                elif metric == 'cosine':
                    d = 1 - WeightSpaceMetrics.cosine_similarity(
                        self.weight_vectors[i], self.weight_vectors[j]
                    )
                dist_matrix[i, j] = d
                dist_matrix[j, i] = d
        
        return dist_matrix
    
    def reduce_dimensions_pca(self, n_components: int = 10) -> np.ndarray:
        """使用PCA进行维度约简"""
        pca = PCA(n_components=n_components)
        representations = pca.fit_transform(self.weight_matrix.numpy())
        return representations, pca.explained_variance_ratio_.sum()
    
    def reduce_dimensions_tsne(self, perplexity: float = 5.0) -> np.ndarray:
        """使用t-SNE进行非线性维度约简"""
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        representations = tsne.fit_transform(self.weight_matrix.numpy())
        return representations
 
 
class WeightSpaceInterpolation:
    """权重空间插值工具"""
    
    @staticmethod
    def linear_interpolate(theta1: torch.Tensor, theta2: torch.Tensor, 
                          t: float) -> torch.Tensor:
        """线性插值"""
        return (1 - t) * theta1 + t * theta2
    
    @staticmethod
    def slerp(theta1: torch.Tensor, theta2: torch.Tensor, 
              t: float, eps: float = 1e-8) -> torch.Tensor:
        """球面线性插值(适用于归一化权重)"""
        theta1_norm = theta1 / (torch.norm(theta1) + eps)
        theta2_norm = theta2 / (torch.norm(theta2) + eps)
        
        omega = torch.acos(torch.clamp(
            torch.dot(theta1_norm, theta2_norm), -1.0 + eps, 1.0 - eps
        ))
        
        if omega.item() < eps:
            return theta1
        
        sin_omega = torch.sin(omega)
        w1 = torch.sin((1 - t) * omega) / sin_omega
        w2 = torch.sin(t * omega) / sin_omega
        
        return w1 * theta1_norm + w2 * theta2_norm
    
    @staticmethod
    def geodesic_path(theta1: torch.Tensor, theta2: torch.Tensor, 
                      n_steps: int = 10) -> list:
        """计算测地线路径"""
        path = []
        for i in range(n_steps + 1):
            t = i / n_steps
            interp = WeightSpaceInterpolation.slerp(theta1, theta2, t)
            path.append(interp)
        return path
 
 
def example_usage():
    """示例用法"""
    # 创建简单的测试模型
    torch.manual_seed(42)
    models = []
    for i in range(5):
        model = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        )
        # 添加随机扰动模拟不同训练阶段
        with torch.no_grad():
            for param in model.parameters():
                param += 0.1 * torch.randn_like(param) * i
        models.append(model)
    
    # 分析权重空间
    analyzer = WeightSpaceAnalyzer(models)
    
    # 计算距离矩阵
    dist_matrix = analyzer.compute_distance_matrix('euclidean')
    print("Euclidean距离矩阵:")
    print(dist_matrix)
    
    # 计算余弦相似度矩阵
    cos_dist_matrix = analyzer.compute_distance_matrix('cosine')
    print("\n余弦距离矩阵:")
    print(cos_dist_matrix)
    
    # PCA降维
    representations, variance_explained = analyzer.reduce_dimensions_pca(n_components=3)
    print(f"\nPCA保留方差比例: {variance_explained:.4f}")
    
    # t-SNE可视化
    tsne_repr = analyzer.reduce_dimensions_tsne(perplexity=2.0)
    print(f"\nt-SNE表示形状: {tsne_repr.shape}")
    
    # 插值示例
    theta1 = analyzer.weight_vectors[0]
    theta2 = analyzer.weight_vectors[1]
    
    print("\n--- 插值示例 ---")
    print(f"起点权重范数: {torch.norm(theta1).item():.4f}")
    print(f"终点权重范数: {torch.norm(theta2).item():.4f}")
    
    for t in [0.25, 0.5, 0.75]:
        interp = WeightSpaceInterpolation.linear_interpolate(theta1, theta2, t)
        print(f"t={t}: 插值权重范数 = {torch.norm(interp).item():.4f}")
 
 
if __name__ == "__main__":
    example_usage()

关键输出说明

运行上述代码会产生以下分析结果:

  1. 距离矩阵:揭示模型间的全局相似性结构
  2. PCA方差保留比例:反映权重空间的有效维度
  3. 插值路径:验证线性模式连通性假设

8. 总结与展望

权重表示学习为理解和操作深度学习模型提供了新的视角。通过将权重空间视为几何结构,我们可以:

  1. 量化模型差异:使用合适的度量比较不同训练阶段或来源的模型
  2. 发现功能模式:识别功能相似的模型簇,实现高效模型选择
  3. 指导模型操作:在表示空间中的操作可以直接转化为权重空间的编辑

未来研究方向包括:

  • 更精细的度量设计:结合任务性能的几何感知度量
  • 动态表示学习:随训练过程自适应更新的表示
  • 因果表示:揭示权重变化与功能变化之间的因果关系

参考资料

Footnotes

  1. Weight Space Learning的相关工作首次在[NeurIPS 2024]发表,提出将权重空间表示学习扩展到大规模预训练模型集合

  2. Garipov et al., “Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs”, NeurIPS 2018