Fourier分析与神经网络

傅里叶分析为理解神经网络的学习行为、频率偏差和泛化特性提供了强有力的工具。本篇介绍频谱分析的基础及其在神经网络研究中的应用。[^1]


傅里叶分析基础

傅里叶变换回顾

对于函数 $f(x)$,其傅里叶变换为:

$$\hat{f}(\xi) = \int_{-\infty}^{\infty} f(x)\, e^{-2\pi i x \xi}\, dx$$

逆变换:

$$f(x) = \int_{-\infty}^{\infty} \hat{f}(\xi)\, e^{2\pi i x \xi}\, d\xi$$

离散傅里叶变换(DFT)

import numpy as np
import matplotlib.pyplot as plt
 
def discrete_fourier_transform(signal):
    """Compute the DFT of a 1-D signal with the explicit DFT matrix.

    X[k] = Σ_{n=0}^{N-1} x[n] * exp(-2πi*k*n/N)

    O(N^2) — meant to illustrate the definition, not to replace
    np.fft.fft.
    """
    N = len(signal)
    indices = np.arange(N)
    # Outer product k*n builds the phase grid of the DFT matrix in one shot.
    dft_matrix = np.exp(-2j * np.pi * np.outer(indices, indices) / N)
    return dft_matrix @ signal
 
def inverse_dft(X):
    """Invert a DFT spectrum back to the time domain.

    x[n] = (1/N) Σ_{k=0}^{N-1} X[k] * exp(+2πi*k*n/N)
    """
    N = len(X)
    indices = np.arange(N)
    # Same matrix as the forward DFT but with positive phase and 1/N scale.
    inverse_matrix = np.exp(2j * np.pi * np.outer(indices, indices) / N) / N
    return inverse_matrix @ X
 
def fft_demo():
    """Demo: recover the tones of a noisy two-frequency signal.

    Builds a 5 Hz + 15 Hz sinusoid mixture with Gaussian noise, runs the
    O(N^2) DFT above, and returns the one-sided amplitude spectrum as
    (positive_frequencies, amplitudes).
    """
    np.random.seed(42)
    t = np.linspace(0, 1, 1000)

    # Signal = strong low-frequency tone + weaker high tone + noise.
    clean = np.sin(2 * np.pi * 5 * t) + 0.5 * np.sin(2 * np.pi * 15 * t)
    signal = clean + 0.2 * np.random.randn(1000)

    spectrum = discrete_fourier_transform(signal)
    freqs = np.fft.fftfreq(signal.size, t[1] - t[0])

    # Two-sided amplitude, then fold onto positive frequencies
    # (factor 2 gives the single-sided spectrum).
    amplitude = np.abs(spectrum) / spectrum.size
    keep = freqs > 0
    frequencies_pos = freqs[keep]
    amplitude_pos = 2 * amplitude[keep]

    print(f"信号频率成分检测:")
    print(f"  5 Hz 成分幅值: {amplitude_pos[np.argmin(np.abs(frequencies_pos - 5))]:.4f}")
    print(f"  15 Hz 成分幅值: {amplitude_pos[np.argmin(np.abs(frequencies_pos - 15))]:.4f}")

    return frequencies_pos, amplitude_pos

神经网络的频谱表示

权重矩阵的频域视角

神经网络中的线性变换 $y = Wx$ 可以在频域理解:

def frequency_domain_linear(W, x):
    """Apply a linear map both directly and via the frequency domain.

    Returns (y_time, y_freq): the direct product W @ x, and the inverse
    FFT of the element-wise product of the two spectra.  The two agree
    only when W encodes a circular convolution (circulant structure),
    in which case the frequency-domain operator is diagonal.
    """
    # Direct (time-domain) application of the linear map.
    y_time = W @ x

    # Spectral route: element-wise product of spectra, then invert.
    spectrum_product = np.fft.fft(W) * np.fft.fft(x)
    y_freq = np.fft.ifft(spectrum_product)

    return y_time, y_freq
 
 
class CirculantMatrix:
    """Circulant matrix: diagonalized by the DFT.

    Each row is a cyclic shift of the previous one, so multiplication
    by the matrix reduces to element-wise multiplication of spectra in
    the frequency domain.
    """

    def __init__(self, c):
        # c: the first row of the circulant matrix.
        self.c = c
        self.n = len(c)

    @staticmethod
    def from_vector(c):
        """Build the dense circulant matrix whose i-th row is roll(c, i)."""
        size = len(c)
        matrix = np.zeros((size, size))
        for shift in range(size):
            matrix[shift] = np.roll(c, shift)
        return matrix

    def to_frequency_domain(self):
        """Eigenvalues of the circulant matrix: the DFT of its first row."""
        return np.fft.fft(self.c)

    def multiply_frequency(self, x_hat):
        """Apply the matrix to a spectrum x_hat entirely in the frequency domain."""
        return self.to_frequency_domain() * x_hat

卷积神经网络的频率偏差

卷积定理

时域中的卷积 $x * h$ 对应频域中的逐元素乘法:$\mathcal{F}\{x * h\} = \mathcal{F}\{x\} \cdot \mathcal{F}\{h\}$

def convolution_frequency():
    """Demonstrate the convolution theorem: conv in time == product in frequency.

    FFT-based multiplication implements *circular* convolution starting
    at index 0, while np.convolve(mode='same') is *linear* convolution
    centred on the input.  The original comparison mixed the two
    alignments, so the reported "error" was large instead of ~machine
    epsilon.  Here the circular result is shifted to the 'same'
    alignment and the comparison excludes the wrap-around boundary
    samples.

    Returns (y_time, y_freq), where y_freq is aligned to y_time; the
    two agree away from the edges.
    """
    rng = np.random.default_rng(0)  # seeded so the demo is reproducible
    x = rng.standard_normal(1000)

    # Symmetric 5-tap smoothing filter.
    h = np.array([0.1, 0.2, 0.4, 0.2, 0.1])

    # Linear convolution, centred ('same' keeps len(x) samples).
    y_time = np.convolve(x, h, mode='same')

    # Circular convolution via the FFT (zero-pad h to len(x)).
    X = np.fft.fft(x)
    H = np.fft.fft(h, len(x))
    y_circular = np.fft.ifft(X * H).real

    # mode='same' drops the first (len(h)-1)//2 samples of the full
    # linear convolution; shift the circular result to match.
    shift = (len(h) - 1) // 2
    y_freq = np.roll(y_circular, -shift)

    # Wrap-around contaminates ~len(h) samples at each edge; compare interior.
    m = len(h)
    error = np.max(np.abs(y_time[m:-m] - y_freq[m:-m]))
    print(f"时域/频域卷积最大误差: {error:.2e}")

    return y_time, y_freq
 
 
class ConvolutionalLayerFrequency:
    """Frequency-domain analysis of a convolutional layer's kernel."""

    @staticmethod
    def kernel_frequency_response(kernel):
        """Return (magnitude, phase) of the kernel's 2-D frequency response.

        The kernel is zero-padded to a 32x32 grid (assumed input size);
        the magnitude determines the layer's sensitivity to each
        spatial frequency.
        """
        response = np.fft.fft2(kernel, s=(32, 32))  # assumes a 32x32 input
        return np.abs(response), np.angle(response)

    @staticmethod
    def effective_bandwidth(kernel, threshold=0.1):
        """Fraction of the central frequency row whose normalized magnitude exceeds threshold."""
        magnitude, _ = ConvolutionalLayerFrequency.kernel_frequency_response(kernel)

        # Normalize so the peak response is 1.
        normalized = magnitude / magnitude.max()

        # Scan the middle row from low to high frequency.
        mid = normalized.shape[0] // 2
        central_row = normalized[mid, :]

        passing = np.flatnonzero(central_row > threshold)
        if passing.size == 0:
            return 0
        return (passing[-1] - passing[0]) / central_row.size

ReLU 的频率效应

ReLU 的频谱

ReLU 激活函数 $\mathrm{ReLU}(x) = \max(0, x)$ 在频域中会产生所有频率的谐波。

class ReLUFrequencyEffect:
    """Analyze how the (nonlinear) ReLU reshapes a signal's spectrum."""

    @staticmethod
    def frequency_response(x, threshold=0):
        """Per-bin ratio of spectral magnitude after/before clipping at threshold.

        ReLU is nonlinear, so it redistributes energy across frequencies
        (harmonic generation); the ratio shows which bins grew or shrank.
        """
        clipped = np.maximum(x, threshold)
        before = np.abs(np.fft.fft(x))
        after = np.abs(np.fft.fft(clipped))
        # Small epsilon guards against division by empty bins.
        return after / (before + 1e-10)

    @staticmethod
    def harmonic_generation(frequency, n_harmonics=10):
        """One-sided spectrum of ReLU(sin(2π·frequency·t)) on t ∈ [0, 1].

        n_harmonics is kept for interface compatibility; the whole
        positive-frequency spectrum is returned as (freqs, amplitudes).
        """
        t = np.linspace(0, 1, 1000)
        rectified = np.maximum(np.sin(2 * np.pi * frequency * t), 0)

        spectrum = np.fft.fft(rectified)
        freq_axis = np.fft.fftfreq(t.size, t[1] - t[0])

        positive = freq_axis > 0
        return freq_axis[positive], np.abs(spectrum)[positive]

    @staticmethod
    def spectral_decomposition_analysis():
        """Print, for several input tones, the peak bin and the harmonic energy ratio."""
        print("ReLU 频率效应分析:")

        for freq in [1, 5, 10, 20]:
            freqs, amplitude = ReLUFrequencyEffect.harmonic_generation(freq)

            # Locate the dominant spectral peak.
            peak_idx = np.argmax(amplitude)
            peak_freq = freqs[peak_idx]
            peak_amplitude = amplitude[peak_idx]

            # Energy outside the fundamental, as a share of the total.
            total_energy = np.sum(amplitude ** 2)
            harmonic_ratio = (total_energy - peak_amplitude ** 2) / total_energy

            print(f"  输入频率 {freq} Hz:")
            print(f"    检测到峰值: {peak_freq:.1f} Hz, 幅值: {peak_amplitude:.4f}")
            print(f"    谐波能量比: {harmonic_ratio:.2%}")

神经网络的频率偏差(Frequency Bias)

经验观察

神经网络表现出对低频成分的偏好,这被称为频率偏差(frequency bias)。

class FrequencyBiasAnalysis:
    """Probe the low-frequency preference ("frequency bias") of neural networks."""

    @staticmethod
    def train_on_frequency_mixture(n_samples=1000, hidden_dim=100):
        """Build a low-frequency-only training set and its spectral features.

        Each sample is a pure sinusoid at 1 or 5 Hz on t ∈ [0, 1];
        hidden_dim is unused here and kept for interface compatibility.
        Returns (X_freq, y_train): the first 50 FFT-magnitude bins per
        sample and the generating frequency as the label.
        """
        np.random.seed(42)

        t = np.linspace(0, 1, 100)

        # Training data: low frequencies only (1 Hz or 5 Hz).
        signals, labels = [], []
        for _ in range(n_samples):
            freq = np.random.choice([1, 5])
            signals.append(np.sin(2 * np.pi * freq * t))
            labels.append(freq)

        X_train = np.array(signals)
        y_train = np.array(labels)

        # Magnitude spectrum, truncated to the 50 lowest bins.
        X_freq = np.abs(np.fft.fft(X_train))[:, :50]

        return X_freq, y_train

    @staticmethod
    def measure_frequency_learning(model, test_signals):
        """Mean model response to pure tones at several frequencies.

        test_signals is unused; kept for interface compatibility.
        Returns {frequency: mean model output}.
        """
        model.eval()

        responses = {}
        for freq in [1, 5, 10, 20, 50]:
            t = np.linspace(0, 1, 100)
            batch = np.array([np.sin(2 * np.pi * freq * t) for _ in range(10)])

            with torch.no_grad():
                out = model(torch.tensor(batch, dtype=torch.float32))
                # Summarize the response by its mean activation.
                responses[freq] = out.mean().item()

        return responses

    @staticmethod
    def spectral_decomposition_training():
        """Inspect the singular-value spectrum of a freshly initialized MLP."""
        print("神经网络频率偏差分析:")
        print("=" * 50)

        # Observation: a randomly initialized network.
        print("\n1. 随机初始化网络:")
        np.random.seed(42)

        # Simple MLP probe network.
        model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # SVD spectrum of the first layer's weights as a crude frequency proxy.
        W = model[0].weight.data.numpy()
        _, s, _ = np.linalg.svd(W, full_matrices=False)

        print(f"  权重矩阵奇异值分布:")
        print(f"    最大: {s[0]:.4f}")
        print(f"    最小: {s[-1]:.4f}")
        print(f"    均值: {s.mean():.4f}")
        print(f"    条件数: {s[0]/s[-1]:.2f}")

谱域学习(Spectral Learning)

谱归一化与频率

class SpectralNormalization:
    """Spectral normalization: bound each layer's largest singular value.

    Limits how much any linear layer can amplify its input (and hence
    high-frequency noise).  The norm is *estimated* with a single random
    power-iteration step, so it is a cheap stochastic lower bound on the
    exact spectral norm, not the exact value.
    """

    def __init__(self, model):
        self.model = model
        # Number of power-iteration steps (kept at 1: one-shot estimate).
        self.power_iteration_steps = 1

    def compute_spectral_norm(self, weight):
        """Estimate ||W||_2 with one power-iteration step from a random vector.

        The weight is flattened to 2-D (rows = output channels).  The
        estimate u^T W v equals a vector norm in both branches, so it is
        non-negative, and it lower-bounds the true spectral norm.
        """
        weight_mat = weight.view(weight.size(0), -1)

        if weight_mat.shape[0] > weight_mat.shape[1]:
            # Tall matrix: start from a random left vector u.
            u = torch.randn(weight_mat.shape[0], 1, device=weight.device)
            u = u / u.norm()
            v = weight_mat.T @ u
            v = v / v.norm()
            return (u.T @ weight_mat @ v).item()
        else:
            # Wide (or square) matrix: start from a random right vector v.
            v = torch.randn(weight_mat.shape[1], 1, device=weight.device)
            v = v / v.norm()
            u = weight_mat @ v
            u = u / u.norm()
            return (u.T @ weight_mat @ v).item()

    def normalize_layer(self, name='weight'):
        """Divide each module's `name` tensor by its estimated spectral norm.

        Fix: the original re-assigned the attribute via setattr, which
        raises TypeError on nn.Module layers because a plain Tensor
        cannot replace a registered nn.Parameter.  The division is now
        done in place under no_grad, matching normalize_all().
        """
        for module in self.model.modules():
            weight = getattr(module, name, None)
            if not isinstance(weight, torch.Tensor):
                continue  # module has no such tensor (e.g. ReLU, containers)
            spectral_norm = self.compute_spectral_norm(weight)
            if spectral_norm > 0:
                with torch.no_grad():
                    weight.div_(spectral_norm)

    def normalize_all(self):
        """In-place spectral normalization of every parameter named '*weight*'."""
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                spectral_norm = self.compute_spectral_norm(param)
                if spectral_norm > 0:
                    with torch.no_grad():
                        param.div_(spectral_norm)

低通与高通网络的构造

低通网络(平滑网络)

class LowPassNetwork:
    """Filter weight matrices in the frequency domain to make a network low-pass."""

    def __init__(self, cutoff_frequency=0.1):
        # Cutoff expressed as a fraction of the larger spatial dimension.
        self.cutoff = cutoff_frequency

    def apply_frequency_mask(self, weight, mask_type='ideal_low'):
        """Return the weight with its 2-D spectrum multiplied by a low-pass mask.

        mask_type: 'ideal_low' (hard circular cutoff), 'gaussian_low'
        (smooth roll-off); anything else leaves the spectrum unchanged.
        """
        spectrum = np.fft.fft2(weight)

        if mask_type == 'ideal_low':
            mask = self._ideal_lowpass_mask(spectrum.shape, self.cutoff)
        elif mask_type == 'gaussian_low':
            mask = self._gaussian_lowpass_mask(spectrum.shape, self.cutoff)
        else:
            # No filtering: identity mask.
            mask = np.ones_like(spectrum)

        # Filter in frequency, then return to the spatial domain.
        return np.fft.ifft2(spectrum * mask).real

    def _ideal_lowpass_mask(self, shape, cutoff):
        """Binary disc of radius cutoff*max(shape), centred on the array."""
        rows, cols = shape
        r, c = np.ogrid[:rows, :cols]
        radius = np.sqrt((r - rows // 2) ** 2 + (c - cols // 2) ** 2)
        return (radius <= cutoff * max(rows, cols)).astype(float)

    def _gaussian_lowpass_mask(self, shape, cutoff):
        """Gaussian roll-off with sigma = cutoff*max(shape), centred on the array."""
        rows, cols = shape
        r, c = np.ogrid[:rows, :cols]
        radius = np.sqrt((r - rows // 2) ** 2 + (c - cols // 2) ** 2)
        sigma = cutoff * max(rows, cols)
        return np.exp(-radius ** 2 / (2 * sigma ** 2))

频域中的泛化分析

频率成分与泛化

class FrequencyGeneralization:
    """Generalization diagnostics from a frequency-content point of view."""

    @staticmethod
    def decompose_train_test_error(model, X_train, X_test):
        """Compare mean spectral magnitude of train vs test inputs by band.

        model is unused; kept for interface compatibility.  Bands: low
        (0-20%) and mid (20-50%) of the first half of the FFT bins, high
        (everything beyond 50% of that half).  Returns the per-split
        band means plus the test/train high-frequency ratio.
        """
        def band_means(X):
            # Mean |FFT| per frequency band, averaged over all samples.
            mag = np.abs(np.fft.fft(X, axis=1))
            half = mag.shape[1] // 2
            lo_edge, mid_edge = int(0.2 * half), int(0.5 * half)
            return {
                'low': mag[:, :lo_edge].mean(),
                'mid': mag[:, lo_edge:mid_edge].mean(),
                'high': mag[:, mid_edge:].mean(),
            }

        train_bands = band_means(X_train)
        test_bands = band_means(X_test)

        return {
            'train': train_bands,
            'test': test_bands,
            'high_freq_ratio': test_bands['high'] / (train_bands['high'] + 1e-10),
        }

    @staticmethod
    def spectral_distance(train_weights, current_weights):
        """Mean frequency-domain L1 distance between matched weight matrices."""
        per_layer = [
            np.abs(np.fft.fft2(w_ref) - np.fft.fft2(w_cur)).mean()
            for w_ref, w_cur in zip(train_weights, current_weights)
        ]
        return np.mean(per_layer)

应用:频域正则化

频域损失

class FrequencyDomainLoss:
    """Loss terms defined on the spectra of predictions and embeddings."""

    @staticmethod
    def frequency_smooth_loss(prediction, target, alpha=0.1):
        """MSE plus a penalty on the prediction's high-frequency energy share.

        Penalizing the upper half of the spectrum encourages smooth
        predictions.  Returns (total_loss, {'mse', 'high_freq_ratio'}).
        """
        spectrum = np.fft.fft(prediction)
        half = len(spectrum) // 2

        # Energy split; the DC bin is excluded from the low band.
        energy_high = np.sum(np.abs(spectrum[half:]) ** 2)
        energy_low = np.sum(np.abs(spectrum[1:half]) ** 2)
        ratio_high = energy_high / (energy_low + energy_high + 1e-10)

        mse = np.mean((prediction - target) ** 2)
        total = mse + alpha * ratio_high

        return total, {'mse': mse, 'high_freq_ratio': ratio_high}

    @staticmethod
    def spectral_contrastive_loss(embeddings1, embeddings2, temperature=0.1):
        """Negative mean phase alignment of two embeddings' spectra.

        Each spectrum is normalized to unit magnitude per bin, so the
        real part of the conjugate product measures phase agreement.
        temperature is unused; kept for interface compatibility.
        """
        def unit_spectrum(emb):
            spec = np.fft.fft(emb, axis=-1)
            return spec / (np.abs(spec) + 1e-10)

        s1 = unit_spectrum(embeddings1)
        s2 = unit_spectrum(embeddings2)

        alignment = np.real(s1 * np.conj(s2))
        return -np.mean(alignment)
 
 
class FrequencyRegularizedTraining:
    """Training wrapper that penalizes high-frequency weight energy.

    Adds a frequency-domain regularizer to the task loss, nudging each
    Linear layer's weight spectrum toward low frequencies.
    """

    def __init__(self, model, alpha=0.01):
        self.model = model
        # Weight of the frequency penalty relative to the task loss.
        self.alpha = alpha

    def training_step(self, x, y, optimizer):
        """One optimization step of cross-entropy + alpha * frequency penalty.

        Returns a dict of scalar losses: 'total', 'task', 'freq'.
        """
        optimizer.zero_grad()

        output = self.model(x)

        # Standard classification loss.
        loss_task = F.cross_entropy(output, y)

        # Frequency-domain regularization on the weights.
        loss_freq = self._frequency_penalty()

        loss = loss_task + self.alpha * loss_freq

        loss.backward()
        optimizer.step()

        return {'total': loss.item(), 'task': loss_task.item(), 'freq': loss_freq.item()}

    def _frequency_penalty(self):
        """Sum over Linear layers of the high-frequency share of |FFT2(W)|.

        Fix: the original used `magnitude.numel() // 2` as an index into
        the LAST dimension; for any weight with more than 2 columns the
        high-frequency slice came back empty, its mean() was NaN, and
        the NaN poisoned the whole loss.  The split now uses half of the
        last dimension's size.
        """
        penalty = 0.0

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                W_fft = torch.fft.fft2(module.weight)
                magnitude = torch.abs(W_fft)

                # Split the last (input-feature) axis into low/high halves.
                half = magnitude.shape[-1] // 2
                if half == 0:
                    continue  # 1-column weight: no high band to penalize

                high_freq = magnitude[..., half:].mean()
                low_freq = magnitude[..., :half].mean()

                # Ratio in [0, 1): fraction of spectral mass in the high band.
                penalty += high_freq / (low_freq + high_freq + 1e-10)

        return penalty

参考

Footnotes

  1. Mallat, S. (1999). *A Wavelet Tour of Signal Processing*. Academic Press.