Transformer Full Block Hessian Analysis

Abstract

This article provides a comprehensive analysis of the Hessian matrix for complete Transformer blocks, extending prior work on self-attention mechanisms to include explicit second-order expressions for LayerNorm and feed-forward networks (FFNs). The Hessian analysis reveals critical insights into curvature propagation, convergence dynamics, and the mathematical foundations of empirical scaling laws. 1


1. Introduction: Why Complete Hessian Analysis Matters

1.1 The Curvature Gap in Transformer Theory

While recent studies have derived Hessian expressions for self-attention mechanisms, the full Transformer block—including LayerNorm and feed-forward networks (FFNs)—lacks a comprehensive theoretical characterization. This gap limits our understanding of:

Aspect                | Without Full Hessian  | With Full Hessian
Optimization Dynamics | Partial understanding | Complete picture
Convergence Rates     | Empirical estimates   | Theoretical bounds
Scaling Laws          | Phenomenological      | Mechanistic explanation
Curvature Propagation | Layer-wise            | End-to-end

1.2 Why Hessian Analysis is Essential

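The Hessian matrix encodes the curvature of the loss landscape. For a loss $\mathcal{L}(\theta)$ with parameters $\theta$:

$$H(\theta) = \nabla_\theta^2 \mathcal{L}(\theta), \qquad H_{ij} = \frac{\partial^2 \mathcal{L}}{\partial \theta_i \, \partial \theta_j}$$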

Key properties revealed by Hessian analysis:

  1. Spectral structure: Distribution of eigenvalues determines optimization difficulty
  2. Conditioning: The condition number $\kappa(H) = \lambda_{\max}(H) / \lambda_{\min}(H)$ affects convergence speed
  3. Negative curvature: Presence of saddle points and local maxima
  4. Block structure: Reveals interactions between sub-layers
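
The properties above can be read off numerically for a small example. The following is a minimal sketch, assuming a hypothetical two-parameter loss (`toy_loss` is not from the analysis itself); it uses `torch.autograd.functional.hessian` to form the Hessian and inspects its eigenvalues at two points.

```python
import torch
from torch.autograd.functional import hessian

def toy_loss(w):
    # Hypothetical loss: steep along w[0], flat along w[1], plus a coupling term
    return 5.0 * w[0] ** 2 + 0.05 * w[1] ** 2 - 0.5 * w[0] * w[1] ** 2

def curvature_report(w):
    H = hessian(toy_loss, w)              # (2, 2) matrix of second derivatives
    eigvals = torch.linalg.eigvalsh(H)    # ascending; real because H is symmetric
    return H, eigvals

w_origin = torch.tensor([0.0, 0.0], dtype=torch.float64)
w_other = torch.tensor([1.0, 0.0], dtype=torch.float64)

H_origin, eig_origin = curvature_report(w_origin)
H_other, eig_other = curvature_report(w_other)

# Conditioning at the origin: ratio of largest to smallest eigenvalue
cond_origin = (eig_origin[-1] / eig_origin[0]).item()

# Negative curvature: any eigenvalue below zero
has_negative_curvature = bool((eig_other < 0).any())

print(f"condition number at origin: {cond_origin:.1f}")
print(f"negative curvature at w_other: {has_negative_curvature}")
```

At the origin the Hessian is diagonal and badly conditioned; at the second point the curvature along `w[1]` turns negative, which is the kind of direction that appears near saddle points.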

1.3 Connections to Prior Work

This work builds upon and generalizes prior analyses:

  • Self-attention Hessian: Extends the framework from prior work on attention curvature 1
  • LayerNorm derivatives: First complete second-order treatment
  • FFN Hessian: Explicit derivations for feed-forward blocks
  • Block-wise assembly: Complete Transformer layer characterization

2. LayerNorm Hessian Derivation

2.1 LayerNorm Definition

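For an input matrix $X \in \mathbb{R}^{L \times d_V}$, LayerNorm computes:

$$\mathrm{LayerNorm}(X)_{i,j} = \gamma_j \, \frac{X_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} + \beta_j$$

where:

$$\mu_i = \frac{1}{d_V} \sum_{j=1}^{d_V} X_{i,j}, \qquad \sigma_i^2 = \frac{1}{d_V} \sum_{j=1}^{d_V} \left(X_{i,j} - \mu_i\right)^2$$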

2.2 Jacobian of LayerNorm

Theorem 2 (Jacobian of LayerNorm): 1

Let . Define:

Then LayerNorm can be expressed as:

Jacobian with respect to input:

where:

2.3 Hessian of LayerNorm

Theorem 3 (Hessian of LayerNorm): 1

The Hessian of LayerNorm with respect to its input is:

Key insight: Since , the Hessian depends only on the second derivatives of .
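
The first- and second-order derivatives of LayerNorm can be checked against automatic differentiation for a single token. This is a minimal sketch, assuming $\gamma = 1$, $\beta = 0$ and the standard per-token closed form $\frac{1}{\sigma}\left(I - \frac{1}{d}\mathbf{1}\mathbf{1}^\top - \frac{1}{d}\hat{x}\hat{x}^\top\right)$ for the Jacobian; the variable names are illustrative and are not the notation of the theorems above.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import hessian, jacobian

torch.manual_seed(0)
d, eps = 6, 1e-5
x = torch.randn(d, dtype=torch.float64)   # one token (one row of X)

def layer_norm(v):
    # gamma = 1 and beta = 0, so only the normalization itself is differentiated
    return F.layer_norm(v, (d,), eps=eps)

# Jacobian from autograd, shape (d, d)
J_auto = jacobian(layer_norm, x)

# Closed form: (1/sigma) * (I - (1/d) 1 1^T - (1/d) x_hat x_hat^T)
sigma = torch.sqrt(x.var(unbiased=False) + eps)
x_hat = (x - x.mean()) / sigma
I = torch.eye(d, dtype=torch.float64)
ones = torch.ones(d, d, dtype=torch.float64)
J_formula = (I - ones / d - torch.outer(x_hat, x_hat) / d) / sigma

# Second-order structure: Hessian of one output coordinate w.r.t. the input
H_first_output = hessian(lambda v: layer_norm(v)[0], x)

print(f"max |J_auto - J_formula| = {(J_auto - J_formula).abs().max().item():.2e}")
print(f"Hessian of output 0 has shape {tuple(H_first_output.shape)}")
```

The Jacobian annihilates the all-ones vector (LayerNorm is invariant to shifting every feature by the same constant), which is a quick sanity check on any analytical expression.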

2.4 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LayerNormHessian(nn.Module):
    """
    LayerNorm with Hessian computation for analysis
    
    LayerNorm(x)_{i,j} = γ_j * (x_{i,j} - μ_i) / sqrt(σ_i² + ε) + β_j
    """
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        
        # Learnable parameters
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
    
    def forward(self, x):
        """
        x: (batch, seq_len, d_model) or (batch, d_model)
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        
        batch, seq_len, d_model = x.shape
        
        # Compute mean and variance per token across the feature (last) dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        
        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        
        # Scale and shift
        return self.weight * x_norm + self.bias
    
    def compute_jacobian(self, x):
        """
        Analytical Jacobian ∂vec(LayerNorm(x))/∂vec(x) for one sample

        x: (seq_len, d_model) or (batch, seq_len, d_model); for batched
           input only the first sample is analyzed, since samples do not
           interact
        Returns: Jacobian matrix of shape (L*d_V, L*d_V), row-major vec
        """
        if x.dim() == 3:
            x = x[0]
        x = x.detach()
        L, d = x.shape

        # Per-token statistics
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        std = torch.sqrt(var + self.eps)
        x_hat = (x - mean) / std

        I_L = torch.eye(L, device=x.device, dtype=x.dtype)
        I_d = torch.eye(d, device=x.device, dtype=x.dtype)
        ones_d = torch.ones(d, d, device=x.device, dtype=x.dtype)

        # Centering term: G = I_{Ld} - (1/d)(I_L ⊗ 1_{d×d})
        G = torch.eye(L * d, device=x.device, dtype=x.dtype) - \
            (1/d) * torch.kron(I_L, ones_d)

        # Variance term: σ_i depends on x, contributing (1/d) x̂_i x̂_iᵀ per token
        V = torch.block_diag(*[torch.outer(r, r) / d for r in x_hat])

        # P = diag(1/σ_i), Γ = I_L ⊗ diag(γ)
        P = torch.diag(1.0 / std.squeeze(-1))
        Gamma = torch.kron(I_L, torch.diag(self.weight.detach().to(x.dtype)))

        # Jacobian = Γ @ (P ⊗ I_d) @ (G - V)
        return Gamma @ torch.kron(P, I_d) @ (G - V)

    def compute_hessian_wrt_weight(self, x):
        """
        Hessian ∂²Loss/∂γ² under an assumed loss 0.5 * ||LayerNorm(x) - t||²

        The output is linear in γ, so for this loss the Hessian is exactly
        diagonal and independent of the target: H_jj = Σ_tokens x̂_{·,j}².
        Other losses couple the entries through ∂²Loss/∂output².
        """
        x = x.detach()
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)

        # Sum x̂² over every leading (batch / sequence) dimension
        diag = (x_hat ** 2).reshape(-1, x_hat.shape[-1]).sum(dim=0)
        return torch.diag(diag)
 
 
def verify_layernorm_hessian():
    """Check the analytical LayerNorm Jacobian against central differences"""
    torch.manual_seed(42)
    
    # Single sample in float64 so that finite differences are accurate
    seq_len, d_model = 8, 16
    x = torch.randn(seq_len, d_model, dtype=torch.float64)
    ln = LayerNormHessian(d_model).double()
    
    # Compute analytically
    jac_analytic = ln.compute_jacobian(x)
    
    # Verify numerically on the top-left corner of the Jacobian
    eps = 1e-6
    n_check = min(5, seq_len * d_model)
    jac_numerical = torch.zeros(n_check, n_check, dtype=torch.float64)
    
    with torch.no_grad():
        for i in range(n_check):  # i indexes the perturbed input entry
            x_plus = x.clone()
            x_plus.view(-1)[i] += eps
            x_minus = x.clone()
            x_minus.view(-1)[i] -= eps
            
            # forward() inserts a singleton dim for 2-D input; flatten removes it
            diff = (ln(x_plus) - ln(x_minus)).flatten() / (2 * eps)
            
            # Column i of the Jacobian holds ∂out_j/∂x_i
            jac_numerical[:, i] = diff[:n_check]
    
    max_diff = (jac_analytic[:n_check, :n_check] - jac_numerical).abs().max().item()
    print(f"Jacobian shape: {jac_analytic.shape}")
    print(f"Numerical Jacobian approx:\n{jac_numerical}")
    print(f"Max difference: {max_diff:.2e}")
 
 
if __name__ == "__main__":
    verify_layernorm_hessian()

3. Feed-Forward Network Hessian Derivation

3.1 FFN Definition

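The feed-forward network in a Transformer is:

$$\mathrm{FFN}(X) = \sigma(X W_1)\, W_2$$

where:

  • $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ (up-projection)
  • $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ (down-projection)
  • $\sigma$ is typically GELU or ReLU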

3.2 ReLU Activation Derivatives

Lemma 1 (ReLU derivative and Hessian): 1

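For $\sigma(x) = \mathrm{ReLU}(x) = \max(0, x)$, almost everywhere:

$$\sigma'(x) = \mathbb{1}[x > 0], \qquad \sigma''(x) = 0$$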

Key insight: ReLU’s Hessian is almost everywhere zero, making FFN Hessian computation more tractable.
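
The contrast between ReLU and a smooth activation can be seen directly with double backward. This is a small sketch, not part of the lemma: `second_derivative` is a hypothetical helper, and the closed form for the exact (erf-based) GELU, $\mathrm{GELU}''(x) = \varphi(x)(2 - x^2)$ with $\varphi$ the standard normal density, is included only as a cross-check.

```python
import torch
import torch.nn.functional as F

def second_derivative(fn, x):
    """Elementwise second derivative of an activation via double backward."""
    x = x.clone().requires_grad_(True)
    (g,) = torch.autograd.grad(fn(x).sum(), x, create_graph=True)
    (h,) = torch.autograd.grad(g.sum(), x, allow_unused=True)
    return torch.zeros_like(x) if h is None else h

# Evaluation points chosen away from the ReLU kink at 0
points = torch.tensor([-2.0, -0.5, 0.5, 2.0], dtype=torch.float64)

relu_dd = second_derivative(F.relu, points)
gelu_dd = second_derivative(F.gelu, points)

# Closed form for exact GELU: GELU''(x) = phi(x) * (2 - x^2)
phi = torch.exp(-0.5 * points ** 2) / (2 * torch.pi) ** 0.5
gelu_dd_closed = phi * (2 - points ** 2)

print("ReLU'':", relu_dd)
print("GELU'':", gelu_dd)
```

With GELU, which the implementation in this section uses by default, the activation contributes its own second-derivative term to the FFN Hessian, so the ReLU simplification does not carry over unchanged.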

3.3 FFN Block Hessian Structure

For a Transformer block with FFN:

Theorem 4 (Transformer block derivative): 1

The derivatives with respect to FFN parameters are:

3.4 FFN Hessian Code Implementation

class FFNHessian(nn.Module):
    """
    Feed-Forward Network with Hessian computation
    
    FFN(x) = W2 @ σ(W1 @ x) + x  (with residual)
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        
        # FFN weights
        self.w1 = nn.Linear(d_model, d_ff, bias=True)
        self.w2 = nn.Linear(d_ff, d_model, bias=True)
        
        # Activation
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        """
        return self.dropout(self.w2(self.activation(self.w1(x)))) + x
    
    def compute_ffn_jacobians(self, x):
        """
        Compute Jacobians ∂FFN/∂W1 and ∂FFN/∂W2
        """
        batch, seq_len, d_model = x.shape
        
        # Pre-compute intermediate values
        h = self.w1(x)  # (batch, seq_len, d_ff)
        h_act = self.activation(h)  # (batch, seq_len, d_ff)
        
        # ∂σ/∂(W1 @ x) - derivative of the exact (erf-based) GELU used by nn.GELU()
        # GELU(x) = x * Φ(x)  =>  GELU'(x) = Φ(x) + x * φ(x)
        cdf = 0.5 * (1 + torch.erf(h / 2 ** 0.5))              # Φ(h), standard normal CDF
        pdf = torch.exp(-0.5 * h**2) / (2 * torch.pi) ** 0.5   # φ(h), standard normal PDF
        dgelu = cdf + h * pdf
        
        # Reshape for batch matrix multiplication
        # x: (batch, seq, d) -> (batch, seq*d)
        x_flat = x.reshape(batch, -1)
        h_act_flat = h_act.reshape(batch, -1)
        
        # ∂FFN/∂W1: chain rule through activation
        # ∂FFN/∂W1 = (∂σ/∂(W1x) ⊗ I) @ (x ⊗ I) @ W2
        d_ff = self.d_ff
        d_m = self.d_model
        
        # Simplified: use Kronecker structure
        # In practice, compute per-sample and aggregate
        
        return {
            'd_ffn_d_w1': None,  # Requires Kronecker product
            'd_ffn_d_w2': h_act,
            'activation_derivative': dgelu,
            'intermediate': h
        }
    
    def compute_hessian_wrt_w2(self, x):
        """
        Compute the Gauss-Newton Hessian of the loss with respect to W2
        
        H_GN = J^T @ H_out @ J, with J = ∂FFN/∂vec(W2) and H_out the Hessian
        of the loss with respect to the FFN output. This sketch assumes a
        squared-error loss 0.5 * ||FFN(x) - t||², so H_out = I; because the
        output is linear in W2, H_GN is then the exact Hessian.
        
        Returns: (d_model * d_ff, d_model * d_ff), indexed like
        self.w2.weight.flatten(). Dropout is ignored.
        """
        batch, seq_len, d_model = x.shape
        
        with torch.no_grad():
            # Hidden activations a = σ(W1 x + b1), one row per token
            h_act = self.activation(self.w1(x))   # (batch, seq_len, d_ff)
            a = h_act.reshape(-1, self.d_ff)      # (batch * seq_len, d_ff)
            
            # Sum of rank-1 terms a aᵀ over sequence positions,
            # averaged over the batch
            gram = a.T @ a / batch                # (d_ff, d_ff)
            
            # Block diagonal structure: row o of W2 only affects output o,
            # so every output dimension gets the same block: I_{d_model} ⊗ gram
            eye = torch.eye(d_model, device=x.device, dtype=x.dtype)
            return torch.kron(eye, gram)
 
 
class TransformerBlockHessian(nn.Module):
    """
    Full Transformer Block with Hessian Analysis
    
    Pre-norm form, matching forward():
    Y = X + Attention(LayerNorm(X))
    Z = Y + FFN(LayerNorm(Y))
    """
    def __init__(self, d_model, d_ff, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        
        # Components
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FFNHessian(d_model, d_ff, dropout)
    
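    def forward(self, x):
        # Pre-norm architecture (more stable gradients)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # FFNHessian.forward already adds its own input as a residual,
        # so subtract it to keep a single skip connection around the FFN
        h = self.norm2(x)
        x = x + (self.ffn(h) - h)
        return x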
    
    def compute_block_hessian(self, x):
        """
        Compute full block Hessian structure
        
        Returns block matrix:
        [ H_aa  H_an  H_af
          H_na  H_nn  H_nf
          H_fa  H_fn  H_ff ]
        
        where a=attention, n=norm, f=ffn parameters
        """
        # This is a simplified representation
        # Full computation requires careful handling of all cross-terms
        
        block_sizes = {
            'attention': 3 * self.d_model**2,  # Q, K, V projections
            'norm1': self.d_model * 2,          # gamma, beta
            'ffn': 2 * self.d_model * self.ffn.d_ff,  # W1, W2
            'norm2': self.d_model * 2
        }
        
        total_params = sum(block_sizes.values())
        
        # Initialize block Hessian
        H = torch.zeros(total_params, total_params)
        
        return {
            'H': H,
            'block_sizes': block_sizes,
            'structure': 'block_sparse'  # Many zero blocks
        }

4. Self-Attention Hessian (Building on Prior Work)

4.1 Self-Attention Mechanism

Consider a single-head self-attention layer:

where the attention matrix is:

4.2 Attention Hessian Decomposition

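For a loss $\mathcal{L}(\theta) = \ell(F(\theta))$ with model output $F$ and Jacobian $J = \partial F / \partial \theta$, the Hessian splits via the Gauss-Newton decomposition $H = H_o + H_f$:

Component             | Definition                                                             | Origin
Outer-product Hessian | $H_o = J^\top (\nabla_F^2 \ell)\, J$                                   | Curvature of the loss with respect to the model output
Functional Hessian    | $H_f = \sum_k \frac{\partial \ell}{\partial F_k} \nabla_\theta^2 F_k$  | Curvature of the model output itself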
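
The split into an outer-product part and a functional part can be verified on a tiny model. This is a minimal sketch under stated assumptions: a hypothetical scalar model `theta[1] * tanh(theta[0] * x)` with a squared-error loss, for which the loss curvature with respect to the output is 1 and the loss gradient with respect to the output is the residual.

```python
import torch
from torch.autograd.functional import hessian, jacobian

x = torch.tensor(0.7, dtype=torch.float64)   # single input (illustrative value)
t = torch.tensor(0.3, dtype=torch.float64)   # single target (illustrative value)

def model(theta):
    return theta[1] * torch.tanh(theta[0] * x)

def loss(theta):
    return 0.5 * (model(theta) - t) ** 2

theta = torch.tensor([0.5, -1.2], dtype=torch.float64)

# Full Hessian of the loss with respect to the parameters
H = hessian(loss, theta)

# Outer-product part: J^T (d²l/dF²) J, with d²l/dF² = 1 for squared error
J = jacobian(model, theta)        # shape (2,) because the output is a scalar
H_outer = torch.outer(J, J)

# Functional part: (dl/dF) * Hessian of the model output, dl/dF = residual
residual = model(theta) - t
H_functional = residual * hessian(model, theta)

print(f"max |H - (H_o + H_f)| = {(H - (H_outer + H_functional)).abs().max().item():.2e}")
```

The outer-product term is positive semi-definite by construction, so any negative curvature in this example has to come from the functional term.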

4.3 Spectral Norm Bound for Attention

Theorem 1 (Hessian spectral norm): 1

Let be the spectral norm. For a single self-attention layer:

where involves terms like:

The full expression includes terms from:

  • - interactions
  • - interactions
  • - interactions
  • Diagonal terms (-)
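
For small sizes, the parameter-block interactions can be computed exactly and their spectral norms inspected. The sketch below is an illustration under stated assumptions, not the bound of the theorem: it uses a hypothetical single-head attention layer without output projection, a squared-error loss, and toy dimensions `L = 4`, `d = 3`.

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)
L, d = 4, 3                                   # hypothetical toy sizes
X = torch.randn(L, d, dtype=torch.float64)    # input sequence
T = torch.randn(L, d, dtype=torch.float64)    # regression target
W_Q, W_K, W_V = (0.5 * torch.randn(d, d, dtype=torch.float64) for _ in range(3))

def attention_loss(w_q, w_k, w_v):
    scores = (X @ w_q) @ (X @ w_k).T / d ** 0.5
    A = torch.softmax(scores, dim=-1)         # row-stochastic attention matrix
    out = A @ (X @ w_v)
    return 0.5 * ((out - T) ** 2).sum()

# hessian() returns a 3x3 nested tuple; block [i][j] has shape (d, d, d, d)
raw = hessian(attention_loss, (W_Q, W_K, W_V))
names = ["Q", "K", "V"]
blocks = {
    (names[i], names[j]): raw[i][j].reshape(d * d, d * d)
    for i in range(3) for j in range(3)
}

# Spectral norm of every block, e.g. block_norms[("Q", "K")]
block_norms = {k: torch.linalg.matrix_norm(v, ord=2).item() for k, v in blocks.items()}

# Assemble the full (3d², 3d²) Hessian from its blocks
H_full = torch.cat(
    [torch.cat([blocks[(r, c)] for c in names], dim=1) for r in names], dim=0
)

for key, value in block_norms.items():
    print(f"||H_{key[0]}{key[1]}||_2 = {value:.4f}")
```

Because the output is linear in the value projection and the loss is quadratic, the value-value block is positive semi-definite, while the query and key blocks pass through the softmax and can carry negative curvature.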

4.4 Attention Hessian Implementation

class SelfAttentionHessian(nn.Module):
    """
    Self-Attention with Hessian Analysis
    """
    def __init__(self, d_model, n_heads, d_k=None, d_v=None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k or d_model // n_heads
        self.d_v = d_v or d_model // n_heads
        
        # Projections
        self.W_Q = nn.Linear(d_model, self.d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, self.d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads, bias=False)
        self.out_proj = nn.Linear(self.d_v * n_heads, d_model)
    
    def forward(self, x, return_attention=False):
        batch, seq_len, d_model = x.shape
        
        # Linear projections
        Q = self.W_Q(x).view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(batch, seq_len, self.n_heads, self.d_v).transpose(1, 2)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.d_k ** 0.5
        attn_weights = F.softmax(scores, dim=-1)
        
        # Output
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        output = self.out_proj(context)
        
        if return_attention:
            return output, attn_weights
        return output
    
    def compute_attention_hessian(self, x, target):
        """
        Cheap curvature proxies for the attention projections
        
        Note: the returned values are scaled squared gradient norms (an
        empirical-Fisher-style quantity), not Hessian norms. They track the
        Hessian only where the Fisher approximation is reasonable.
        """
        # Clear stale gradients so repeated calls do not accumulate
        self.zero_grad()
        
        # Forward pass
        output, attn = self.forward(x, return_attention=True)
        
        # MSE loss
        loss = F.mse_loss(output, target)
        
        # Backward to get gradients
        loss.backward()
        
        # Extract gradients for each projection
        grad_WQ = self.W_Q.weight.grad.detach()
        grad_WK = self.W_K.weight.grad.detach()
        grad_WV = self.W_V.weight.grad.detach()
        
        # Proxy: ||g||² / sqrt(numel), stored as plain floats
        hessian_estimates = {
            'H_WQ': (grad_WQ.norm() ** 2 / grad_WQ.numel() ** 0.5).item(),
            'H_WK': (grad_WK.norm() ** 2 / grad_WK.numel() ** 0.5).item(),
            'H_WV': (grad_WV.norm() ** 2 / grad_WV.numel() ** 0.5).item(),
        }
        
        return {
            'loss': loss.item(),
            'hessian_estimates': hessian_estimates,
            'attention': attn.detach()
        }
 
 
def analyze_attention_curvature(model, dataloader, device='cuda'):
    """
    Analyze attention Hessian across training
    """
    model.eval()
    curvature_history = []
    
    for batch_idx, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        
        # Compute attention Hessian
        result = model.compute_attention_hessian(x, y)
        
        curvature_history.append({
            'batch': batch_idx,
            'loss': result['loss'],
            'h_WQ': result['hessian_estimates']['H_WQ'],
            'h_WK': result['hessian_estimates']['H_WK'],
            'h_WV': result['hessian_estimates']['H_WV']
        })
        
        if batch_idx >= 100:  # Analyze first 100 batches
            break
    
    return curvature_history

5. Curvature Propagation Through Transformer Blocks

5.1 Block Structure Overview

A complete Transformer block (Pre-norm architecture) is:

Parameter set:

5.2 Block Hessian Assembly

The full block Hessian has the following structure:

Key properties:

  • Sparse structure: Many zero blocks due to residual connections
  • Cross-layer coupling: Through LayerNorm
  • Block diagonal dominance: FFN and attention have relatively independent curvature

5.3 Curvature Flow Analysis

class CurvaturePropagationAnalyzer:
    """
    Analyze how curvature propagates through Transformer blocks
    """
    
    def __init__(self, model):
        self.model = model
        self.curvature_history = []
    
    def compute_layer_curvature(self, x, layer_idx):
        """
        Compute effective curvature for a specific layer
        """
        layer = self.model.transformer.layers[layer_idx]
        
        with torch.no_grad():
            # Forward through layer
            x_norm = layer.norm1(x)
            attn_out, attn = layer.self_attn(x_norm, x_norm, x_norm, return_attention=True)
            x = x + attn_out
            
            x_norm2 = layer.norm2(x)
            ffn_out = layer.mlp(x_norm2)
            x = x + ffn_out
        
        # Cheap per-layer statistics used as curvature proxies.
        # These are not Hessian quantities: 'attn_curvature' is the spread of
        # the attention weights, and 'gradient_scale' is the norm of the
        # layer output (the key name is kept for the plotting code below).
        return {
            'layer': layer_idx,
            'attn_curvature': attn.std().item(),
            'gradient_scale': x.norm().item(),
            'attention_scores': attn
        }
    
    def analyze_curvature_flow(self, x):
        """
        Track curvature through all layers
        """
        curvatures = []
        
        for layer_idx in range(len(self.model.transformer.layers)):
            layer_curv = self.compute_layer_curvature(x, layer_idx)
            curvatures.append(layer_curv)
            
            # Update x for next layer (in practice, this is done in forward pass)
            x = self.model.transformer.layers[layer_idx](x)
        
        self.curvature_history.append(curvatures)
        return curvatures
    
    def visualize_curvature_flow(self):
        """
        Visualize how curvature changes across layers
        """
        import matplotlib.pyplot as plt
        
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        layers = list(range(len(self.curvature_history[0])))
        
        # Plot attention curvature
        ax1 = axes[0, 0]
        for step, curvatures in enumerate(self.curvature_history[:10]):
            attn_curvs = [c['attn_curvature'] for c in curvatures]
            ax1.plot(layers, attn_curvs, label=f'Step {step}', alpha=0.7)
        ax1.set_xlabel('Layer')
        ax1.set_ylabel('Attention Curvature')
        ax1.set_title('Attention Curvature Flow')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot gradient scale
        ax2 = axes[0, 1]
        for step, curvatures in enumerate(self.curvature_history[:10]):
            grad_scales = [c['gradient_scale'] for c in curvatures]
            ax2.plot(layers, grad_scales, label=f'Step {step}', alpha=0.7)
        ax2.set_xlabel('Layer')
        ax2.set_ylabel('Gradient Scale')
        ax2.set_title('Gradient Scale Flow')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Average curvature by layer
        ax3 = axes[1, 0]
        avg_attn_curv = torch.stack([
            torch.tensor([c['attn_curvature'] for c in step]) 
            for step in self.curvature_history
        ]).mean(0)
        ax3.bar(layers, avg_attn_curv.numpy())
        ax3.set_xlabel('Layer')
        ax3.set_ylabel('Average Curvature')
        ax3.set_title('Average Attention Curvature by Layer')
        
        # Heatmap of attention curvature
        ax4 = axes[1, 1]
        curv_matrix = torch.stack([
            torch.tensor([c['attn_curvature'] for c in step]) 
            for step in self.curvature_history
        ]).numpy()
        im = ax4.imshow(curv_matrix, aspect='auto', cmap='viridis')
        ax4.set_xlabel('Layer')
        ax4.set_ylabel('Training Step')
        ax4.set_title('Curvature Heatmap')
        plt.colorbar(im, ax=ax4)
        
        plt.tight_layout()
        plt.savefig('curvature_flow.png', dpi=150)
        plt.show()

5.4 Mathematical Framework for Curvature Propagation

Key theorems for curvature flow:

  1. LayerNorm stabilizes curvature: The normalization ensures that the input-dependent Hessian terms remain bounded

  2. FFN preserves block structure: Due to ReLU’s zero Hessian, FFN Hessian is relatively simple

  3. Attention introduces data-dependent coupling: The softmax Jacobian creates correlations across the sequence


6. Implications for Convergence Dynamics

6.1 Loss Landscape Convergence

As the dataset size increases, the loss landscape converges to a more stable structure:

Key insight: Larger datasets lead to:

  • Smoother optimization landscapes
  • Better-conditioned Hessians
  • More predictable convergence
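
The stabilization with dataset size can be illustrated in the simplest setting where the Hessian is available in closed form. This is a sketch under stated assumptions (least-squares regression on synthetic Gaussian inputs, where the Hessian on the first $k$ samples is $\frac{1}{k} X_k^\top X_k$); it says nothing quantitative about Transformers.

```python
import torch

torch.manual_seed(0)
d, n_max = 5, 2000
X = torch.randn(n_max, d, dtype=torch.float64)   # synthetic inputs (assumption)

def hessian_on_first_k(k):
    # Hessian of (1/2k) * sum_i (x_i^T w - y_i)^2 with respect to w
    Xk = X[:k]
    return Xk.T @ Xk / k

def hessian_change(k):
    # Spectral norm of the change caused by adding sample k+1
    diff = hessian_on_first_k(k + 1) - hessian_on_first_k(k)
    return torch.linalg.matrix_norm(diff, ord=2).item()

def mean_change(start, width=20):
    # Average over a short window to smooth out sample-to-sample noise
    return sum(hessian_change(k) for k in range(start, start + width)) / width

def condition_number(k):
    eig = torch.linalg.eigvalsh(hessian_on_first_k(k))
    return (eig[-1] / eig[0]).item()

small_data_change = mean_change(10)
large_data_change = mean_change(1000)

print(f"mean ||H_(k+1) - H_k|| near k=10:   {small_data_change:.4f}")
print(f"mean ||H_(k+1) - H_k|| near k=1000: {large_data_change:.4f}")
print(f"condition number at k=10:   {condition_number(10):.2f}")
print(f"condition number at k=2000: {condition_number(2000):.2f}")
```

In this toy case each additional sample perturbs the Hessian by roughly $1/k$, and the condition number approaches that of the population covariance as $k$ grows.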

6.2 Hessian-Based Convergence Analysis

Theorem (Convergence bound): 1

For empirical loss with Hessian :

where depends on the spectral norm bounds derived in previous sections.

6.3 Practical Convergence Diagnostics

class ConvergenceDiagnostics:
    """
    Use Hessian analysis for convergence diagnostics
    """
    
    def __init__(self, model):
        self.model = model
        self.loss_history = []
        self.grad_history = []
        self.curvature_history = []
    
    def compute_effective_conditioning(self, x, y):
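        """
        Heuristic conditioning proxy from gradient statistics

        This is not the Hessian condition number λ_max/λ_min; it only
        measures how heavy-tailed the gradient entries are.
        """
        self.model.zero_grad()  # avoid accumulating gradients across calls
        output = self.model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()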
        
        # Collect gradients
        grads = []
        for p in self.model.parameters():
            if p.grad is not None:
                grads.append(p.grad.flatten())
        
        all_grads = torch.cat(grads)
        
        # Estimate "local" condition number using gradient statistics
        grad_mean = all_grads.mean()
        grad_std = all_grads.std()
        grad_max = all_grads.abs().max()
        
        # Effective condition number proxy
        eff_cond = grad_max / (grad_std + 1e-8)
        
        return {
            'loss': loss.item(),
            'grad_norm': all_grads.norm().item(),
            'grad_mean': grad_mean.item(),
            'grad_std': grad_std.item(),
            'eff_condition_number': eff_cond.item()
        }
    
    def analyze_convergence_trajectory(self, dataloader, n_steps=100):
        """
        Analyze convergence using Hessian-based metrics
        """
        self.model.train()
        
        for step, (x, y) in enumerate(dataloader):
            if step >= n_steps:
                break
            
            # Compute diagnostics
            diag = self.compute_effective_conditioning(x, y)
            
            self.loss_history.append(diag['loss'])
            self.grad_history.append(diag['grad_norm'])
            self.curvature_history.append(diag['eff_condition_number'])
            
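            # No optimizer step is taken here: the diagnostics only observe
            # the model. Clear gradients before the next batch.
            self.model.zero_grad()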
        
        return self._summarize_convergence()
    
    def _summarize_convergence(self):
        """
        Summarize convergence properties
        """
        import numpy as np
        
        loss_arr = np.array(self.loss_history)
        grad_arr = np.array(self.grad_history)
        cond_arr = np.array(self.curvature_history)
        
        return {
            'loss_decreased': loss_arr[-1] < loss_arr[0],
            'loss_reduction': (loss_arr[0] - loss_arr[-1]) / loss_arr[0],
            'gradient_decay_rate': np.polyfit(range(len(grad_arr)), np.log(grad_arr + 1e-10), 1)[0],
            'condition_number_trend': 'improving' if cond_arr[-1] < cond_arr[0] else 'degrading',
            'final_condition': cond_arr[-1]
        }

6.4 Optimization Recommendations

Based on Hessian analysis:

Observation                           | Recommendation
High condition number in early layers | Use adaptive optimizers (Adam, AdaGrad)
Negative eigenvalues present          | Apply gradient clipping, use momentum
Curvature varies across layers        | Layer-wise learning rate scaling
Hessian norm grows with depth         | Use Pre-norm architecture

7. Connections to Empirical Scaling Laws

7.1 Neural Scaling Laws Background

Empirical scaling laws describe how test loss scales with model size $N$ and dataset size $D$:

Key observations:

  • Larger models are more sample-efficient
  • Performance improves predictably with scale
  • Double descent phenomena at certain scales

7.2 Hessian Perspective on Scaling Laws

New perspective from full Hessian analysis:

  1. Dataset size and curvature:

    • Larger datasets lead to better-conditioned Hessians
    • A smoother loss landscape leads to better generalization
  2. Model size and condition number:

    • Larger models have larger parameter space
    • Hessian spectrum widens: more large/small eigenvalues
    • This explains “critical batch size” phenomena
  3. Training dynamics:

    • Phase transitions in curvature during training
    • Initial phase: rapid curvature changes
    • Later phase: landscape stabilization

7.3 Theoretical Bounds for Scaling

Theorem (Loss landscape evolution): 1

The difference between loss functions trained on $k$ and $k+1$ samples satisfies:

This provides:

  • Theoretical justification for scaling law exponents
  • Quantitative predictions for dataset size requirements
  • Framework for compute-optimal training

7.4 Scaling Law Implementation

class ScalingLawAnalyzer:
    """
    Analyze scaling laws from Hessian perspective
    """
    
    def __init__(self, model_factory):
        self.model_factory = model_factory
    
    def compute_hessian_norm(self, model, dataloader, n_samples=100, n_iters=20):
        """
        Estimate the Hessian spectral norm by power iteration on exact
        Hessian-vector products (double backward)
        """
        model.eval()
        
        # Get a batch
        x, y = next(iter(dataloader))
        x, y = x[:n_samples], y[:n_samples]
        
        params = [p for p in model.parameters() if p.requires_grad]
        loss = F.cross_entropy(model(x), y)
        
        # Keep the graph so the gradient itself can be differentiated
        grads = torch.autograd.grad(loss, params, create_graph=True)
        flat_grad = torch.cat([g.reshape(-1) for g in grads])
        
        # Power iteration: v <- Hv / ||Hv||
        v = torch.randn_like(flat_grad)
        v = v / v.norm()
        eigenvalue = 0.0
        for _ in range(n_iters):
            # Hv = ∇(gᵀv); parameters with no second-order dependence give zeros
            Hv = torch.autograd.grad(flat_grad @ v, params,
                                     retain_graph=True, allow_unused=True)
            Hv = torch.cat([
                (h if h is not None else torch.zeros_like(p)).reshape(-1)
                for h, p in zip(Hv, params)
            ])
            eigenvalue = torch.dot(v, Hv).item()  # Rayleigh quotient
            v = Hv / (Hv.norm() + 1e-12)
        
        return {
            'top_eigenvalue': eigenvalue,       # eigenvalue of largest magnitude
            'spectral_norm': abs(eigenvalue),
            'grad_norm': flat_grad.norm().item()
        }
    
    def analyze_scaling_with_size(self, model_sizes, dataset_sizes, dataloader):
        """
        Analyze how Hessian properties scale with model/data size
        
        dataloader: yields (x, y) batches; n_samples caps how many examples
        of the first batch are used for each estimate
        """
        import pandas as pd  # local import: only needed for the summary table
        
        results = []
        
        for n_params in model_sizes:
            for n_samples in dataset_sizes:
                # Create model
                model = self.model_factory(n_params)
                
                # Compute Hessian properties
                hess_props = self.compute_hessian_norm(
                    model, dataloader, n_samples=n_samples
                )
                
                results.append({
                    'n_params': n_params,
                    'n_samples': n_samples,
                    **hess_props
                })
        
        return pd.DataFrame(results)
    
    def predict_optimal_dataset_size(self, model_size, target_condition):
        """
        Predict optimal dataset size for target condition number
        
        Based on theory: condition_number ∝ 1/√(dataset_size)
        """
        # Empirical fit: use prior observations
        alpha = 0.5  # Theoretical exponent
        C = 1.0  # Scale constant
        
        n_samples = (C / target_condition) ** (1/alpha)
        
        return int(n_samples)

8. Taylor Expansion Framework for Loss Difference Analysis

8.1 Taylor Expansion Foundation

Given two loss functions $\mathcal{L}_k$ and $\mathcal{L}_{k+1}$:

Second-order Taylor expansion around :
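
The accuracy of a second-order Taylor model can be checked numerically. This is a minimal sketch, assuming a hypothetical smooth loss (`loss_fn`, a log-sum-exp of a random linear map) in place of the Transformer loss; it compares the true loss at a perturbed point with the quadratic prediction built from the gradient and Hessian.

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)
A = torch.randn(6, 4, dtype=torch.float64)

def loss_fn(w):
    # Hypothetical smooth loss used only for illustration
    return torch.logsumexp(A @ w, dim=0)

def taylor_error(w, delta):
    """|L(w + delta) - [L(w) + g^T delta + 0.5 * delta^T H delta]|"""
    w_req = w.clone().requires_grad_(True)
    value = loss_fn(w_req)
    (grad,) = torch.autograd.grad(value, w_req)
    H = hessian(loss_fn, w)
    quadratic = value.detach() + grad @ delta + 0.5 * delta @ H @ delta
    return (loss_fn(w + delta) - quadratic).abs().item()

w0 = torch.randn(4, dtype=torch.float64)
direction = torch.randn(4, dtype=torch.float64)
direction = direction / direction.norm()

err_full = taylor_error(w0, 0.1 * direction)
err_half = taylor_error(w0, 0.05 * direction)

print(f"error at step 0.10: {err_full:.3e}")
print(f"error at step 0.05: {err_half:.3e}")
```

Halving the step should shrink the error by roughly a factor of eight, since the leading neglected term is cubic in the step size.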

8.2 Convergence Trajectory Analysis

Framework for quantifying convergence:

import numpy as np
import torch
import torch.nn.functional as F


class TaylorExpansionAnalyzer:
    """
    Taylor expansion-based framework for loss difference analysis
    """
    
    def __init__(self, model):
        self.model = model
        self.hessian_cache = {}
    
    def compute_taylor_coefficients(self, x, y, x_train):
        """
        Compute Taylor expansion coefficients for loss difference
        
        ΔL = L_{k+1} - L_k ≈ 1/2 * δw^T * H * δw
        """
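        # Compute losses (compute_loss is assumed to be supplied by the user,
        # e.g. in a subclass: it should return the loss on (x, y) after
        # fitting on the given training subset)
        L_k = self.compute_loss(x, y, x_train[:-1])
        L_k1 = self.compute_loss(x, y, x_train)
        
        delta_L = L_k1 - L_k
        
        # Estimate local curvature from the gradient at the current weights
        self.model.zero_grad()
        output = self.model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()  # populate p.grad; otherwise the list below is empty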
        
        # Get gradients
        grads = []
        for p in self.model.parameters():
            if p.grad is not None:
                grads.append(p.grad.flatten())
        
        all_grads = torch.cat(grads)
        
        # Quadratic coefficient from gradient statistics
        grad_mom2 = (all_grads ** 2).mean()
        
        return {
            'delta_loss': delta_L,
            'quadratic_coefficient': grad_mom2.item(),
            'gradient_norm': all_grads.norm().item()
        }
    
    def predict_convergence_trajectory(self, n_steps, learning_rate, hessian_eigenvalue):
        """
        Predict convergence using Taylor expansion
        
        Based on: δw_{t+1} ≈ (I - ηH) * δw_t, tracked along one eigendirection
        with eigenvalue `hessian_eigenvalue` (e.g. an estimate of λ_max)
        """
        trajectory = []
        delta_w_norm = 1.0  # Initial distance to optimum
        
        for t in range(n_steps):
            # Taylor prediction: the norm contracts by |1 - ηλ| per step
            delta_w_next = delta_w_norm * abs(1 - learning_rate * hessian_eigenvalue)
            
            trajectory.append({
                'step': t,
                'delta_w_norm': delta_w_norm,
                'delta_w_predicted': delta_w_next
            })
            
            delta_w_norm = delta_w_next
        
        return trajectory
    
    def estimate_hessian_from_trajectory(self, loss_trajectory, learning_rate):
        """
        Estimate average Hessian eigenvalue from loss trajectory
        
        Using: L_{t+1} - L_t ≈ -η * δw_t^T H ∇L
        """
        losses = np.array([t['loss'] for t in loss_trajectory])
        
        # Finite difference approximation
        dL = np.diff(losses)
        ddL = np.diff(dL)
        
        # Approximate average curvature
        avg_curvature = -2 * ddL.mean() / (learning_rate ** 2 + 1e-8)
        
        return {
            'avg_curvature': avg_curvature,
            'convergence_rate': np.exp(np.polyfit(range(len(dL)), np.log(np.abs(dL) + 1e-10), 1)[0])
        }

8.3 Mathematical Framework

Key results from Taylor expansion analysis:

  1. Second-order term dominance: For well-conditioned regions, quadratic term dominates

  2. Convergence rate prediction:

  3. Critical learning rate:
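
For reference, under the linearized update δw_{t+1} ≈ (I - ηH) δw_t used in the code above, the standard quadratic-model results take the following form. This is a textbook sketch that assumes a symmetric positive-definite H with eigenvalues between λ_min and λ_max. The symbols ρ, η_crit, η* and κ are illustrative notation and are not taken from the paper's statements.

```latex
% Quadratic model around a minimizer w^*, with \delta w_t = w_t - w^*
L(w^* + \delta w) \approx L(w^*) + \tfrac{1}{2}\, \delta w^\top H\, \delta w

% Gradient descent with step size \eta
\delta w_{t+1} = (I - \eta H)\, \delta w_t
\quad\Longrightarrow\quad
\|\delta w_t\| \le \rho(\eta)^t\, \|\delta w_0\|,
\qquad \rho(\eta) = \max_i |1 - \eta \lambda_i|

% Critical learning rate: contraction requires \rho(\eta) < 1
0 < \eta < \eta_{\mathrm{crit}} = \frac{2}{\lambda_{\max}}

% Best fixed step size and resulting rate, with \kappa = \lambda_{\max} / \lambda_{\min}
\eta^\star = \frac{2}{\lambda_{\min} + \lambda_{\max}},
\qquad
\rho(\eta^\star) = \frac{\kappa - 1}{\kappa + 1}
```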


9. Code Examples in PyTorch

9.1 Complete Transformer Hessian Analysis Pipeline

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
 
 
class TransformerHessianAnalyzer:
    """
    Complete Hessian analysis for Transformer models
    """
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.history = {
            'hessian_norms': [],
            'condition_numbers': [],
            'eigenvalue_spectra': [],
            'layer_curvatures': []
        }
    
    def compute_hessian_eigenvalues(self, x: torch.Tensor, y: torch.Tensor, 
                                   k: int = 20) -> torch.Tensor:
        """
        Estimate the dominant Hessian eigenvalue by power iteration.
        
        Returns the k successive Rayleigh-quotient estimates; they converge to
        the eigenvalue of largest magnitude, not to the top-k spectrum.
        Hessian-vector products use a forward finite difference of gradients,
        so call this with the model in eval mode to keep dropout from adding noise.
        """
        params = [p for p in self.model.parameters() if p.requires_grad]
        device = params[0].device
        originals = [p.detach().clone() for p in params]
        
        def flat_grad() -> torch.Tensor:
            # Gradient of the loss at the current parameters, as an (n, 1) column.
            # zero_grad() first, otherwise backward() accumulates onto old gradients.
            self.model.zero_grad()
            output = self.model(x)
            loss = F.cross_entropy(output.view(-1, output.size(-1)), y.view(-1))
            loss.backward()
            return torch.cat([
                (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
                for p in params
            ]).unsqueeze(1)
        
        def set_params(direction: torch.Tensor, scale: float) -> None:
            # Set w <- w_original + scale * direction (scale = 0 restores the model)
            idx = 0
            with torch.no_grad():
                for p, p0 in zip(params, originals):
                    size = p.numel()
                    p.copy_(p0 + scale * direction[idx:idx + size].view_as(p))
                    idx += size
        
        all_grads = flat_grad()
        n = all_grads.numel()
        
        # Random unit start vector on the same device as the model
        v = torch.randn(n, 1, device=device)
        v = v / v.norm()
        
        # Finite-difference step; much smaller steps are lost to float32 round-off
        eta = 1e-3
        eigenvalues = []
        
        for _ in range(k):
            # Hessian-vector product: H @ v ≈ (g(w + ηv) - g(w)) / η
            set_params(v.squeeze(1), eta)
            all_grads_new = flat_grad()
            set_params(v.squeeze(1), 0.0)
            
            Hv = (all_grads_new - all_grads) / eta
            
            # Rayleigh quotient for the current vector v
            eigenvalue = (v.t() @ Hv).item() / (v.t() @ v).item()
            eigenvalues.append(eigenvalue)
            
            # Power iteration update
            v = Hv / (Hv.norm() + 1e-10)
        
        self.model.zero_grad()
        return torch.tensor(eigenvalues)
    
    def _set_model_params(self, flat_params: torch.Tensor):
        """Set model parameters from flat tensor"""
        idx = 0
        for p in self.model.parameters():
            size = p.numel()
            p.data = flat_params[idx:idx+size].view(p.shape)
            idx += size
    
    def analyze_full_hessian(self, dataloader: torch.utils.data.DataLoader,
                            n_batches: int = 10) -> Dict:
        """
        Summarize dominant-eigenvalue estimates over a few batches
        """
        was_training = self.model.training
        self.model.eval()  # disable dropout so repeated gradients are comparable
        device = next(self.model.parameters()).device
        
        all_eigenvalues = []
        hessian_norms = []
        
        for i, (x, y) in enumerate(dataloader):
            if i >= n_batches:
                break
            
            x, y = x.to(device), y.to(device)
            
            # Power-iteration estimates for this batch
            eigenvals = self.compute_hessian_eigenvalues(x, y, k=10)
            all_eigenvalues.append(eigenvals)
            
            # Spectral-norm estimate: magnitude of the final, most converged iterate
            hessian_norm = eigenvals[-1].abs().item() if len(eigenvals) > 0 else 0.0
            hessian_norms.append(hessian_norm)
            
            self.history['hessian_norms'].append(hessian_norm)
        
        # Restore the mode the model was in before the analysis
        self.model.train(was_training)
        all_eigenvals_cat = torch.cat(all_eigenvalues)
        
        return {
            'mean_hessian_norm': float(np.mean(hessian_norms)),
            'eigenvalue_mean': all_eigenvals_cat.mean().item(),
            'eigenvalue_std': all_eigenvals_cat.std().item(),
            'top_eigenvalue': all_eigenvals_cat.max().item(),
            # Spread of the collected estimates, not a gap between distinct eigenvalues
            'spectral_gap': all_eigenvals_cat.max().item() - all_eigenvals_cat.min().item()
        }
    
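    def analyze_layer_wise_curvature(self, x: torch.Tensor) -> List[Dict]:
        """
        Record per-layer output norms and output-gradient norms.
        
        The gradient norm is only a first-order proxy for curvature, and the
        targets are random, so the values are illustrative.
        """
        self.model.eval()
        device = next(self.model.parameters()).device
        x = x.to(device)  # integer token ids cannot require gradients
        
        layer_curvatures = []
        
        def hook_fn(name):
            def hook(module, inputs, output):
                # nn.MultiheadAttention returns (attn_output, attn_weights)
                out = output[0] if isinstance(output, tuple) else output
                record = {
                    'name': name,
                    'output_norm': out.norm().item(),
                    'curvature_proxy': 0.0
                }
                layer_curvatures.append(record)
                
                if out.requires_grad:
                    # Gradients exist only during backward, so fill the proxy in then
                    def save_grad(grad, record=record):
                        record['curvature_proxy'] = grad.norm().item()
                    out.register_hook(save_grad)
            return hook
        
        # Register hooks
        hooks = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.MultiheadAttention)):
                hooks.append(module.register_forward_hook(hook_fn(name)))
        
        # Forward pass, then backward against random targets
        self.model.zero_grad()
        output = self.model(x)
        targets = torch.randint(0, output.size(-1), output.shape[:-1], device=device)
        loss = F.cross_entropy(output.view(-1, output.size(-1)), targets.view(-1))
        loss.backward()
        self.model.zero_grad()
        
        # Remove hooks
        for h in hooks:
            h.remove()
        
        return layer_curvatures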
 
 
class CurvatureAwareOptimizer:
    """
    Optimizer that adapts based on Hessian information
    """
    
    def __init__(self, model: nn.Module, base_lr: float = 1e-3,
                 curvature_adaptation: bool = True):
        self.model = model
        self.base_lr = base_lr
        self.curvature_adaptation = curvature_adaptation
        self.hessian_estimates = {}
        
        # Adam optimizer
        self.optimizer = torch.optim.Adam(
            model.parameters(), 
            lr=base_lr,
            betas=(0.9, 0.999),
            eps=1e-8
        )
    
    def estimate_local_curvature(self, x: torch.Tensor, y: torch.Tensor) -> Dict:
        """
        Compute gradients for the batch and a heuristic conditioning proxy.
        
        The ratio of the largest to the smallest per-parameter gradient norm is
        a cheap stand-in for the condition number, not a Hessian-based estimate.
        The gradients are left in place for the subsequent optimizer step.
        """
        device = next(self.model.parameters()).device
        x, y = x.to(device), y.to(device)
        
        # Clear stale gradients so backward() does not accumulate onto them
        self.optimizer.zero_grad()
        
        # Compute gradient
        output = self.model(x)
        loss = F.cross_entropy(output.view(-1, output.size(-1)), y.view(-1))
        loss.backward()
        
        # Collect gradient statistics
        grad_norms = []
        for p in self.model.parameters():
            if p.grad is not None:
                grad_norms.append(p.grad.norm().item())
        
        # Estimate condition number proxy
        max_grad = max(grad_norms) if grad_norms else 1
        min_grad = min(grad_norms) if grad_norms else 1
        
        cond_proxy = max_grad / (min_grad + 1e-8)
        
        self.hessian_estimates['condition_number'] = cond_proxy
        
        return {
            'condition_number': cond_proxy,
            'max_gradient': max_grad,
            'min_gradient': min_grad,
            'loss': loss.item()
        }
    
    def step_with_curvature_adaptation(self, x: torch.Tensor, y: torch.Tensor):
        """
        Perform optimization step with curvature-aware learning rate
        """
        # Always run forward/backward so optimizer.step() has fresh gradients
        curv = self.estimate_local_curvature(x, y)
        
        if self.curvature_adaptation:
            # Shrink the learning rate as the conditioning proxy grows
            adaptation_factor = min(1.0, 1.0 / (curv['condition_number'] ** 0.25 + 1e-8))
            
            # Update optimizer's learning rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.base_lr * adaptation_factor
        
        # Standard optimization step
        self.optimizer.step()
        self.optimizer.zero_grad()
        
        return self.hessian_estimates.copy()
 
 
def demo_hessian_analysis():
    """
    Demonstration of Transformer Hessian analysis
    """
    # Create a small Transformer model
    class SimpleTransformer(nn.Module):
        def __init__(self, vocab_size=1000, d_model=64, nhead=4, num_layers=2):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, d_model)
            self.pos_encoder = nn.Parameter(torch.randn(1, 100, d_model) * 0.1)
            self.transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=256, batch_first=True),
                num_layers=num_layers
            )
            self.fc = nn.Linear(d_model, vocab_size)
        
        def forward(self, x):
            x = self.embedding(x) + self.pos_encoder[:, :x.size(1), :]
            x = self.transformer(x)
            return self.fc(x)
    
    # Create model and analyzer (fall back to CPU when CUDA is unavailable)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleTransformer().to(device)
    analyzer = TransformerHessianAnalyzer(model)
    optimizer = CurvatureAwareOptimizer(model)
    
    # Create dummy data
    batch_size = 8
    seq_len = 32
    vocab_size = 1000
    
    # Training loop (simplified) on random tokens, for illustration only
    for step in range(50):
        x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        y = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        
        # Analyze every 10 steps
        if step % 10 == 0:
            results = analyzer.analyze_full_hessian([(x, y)], n_batches=1)
            print(f"Step {step}: Hessian norm = {results['mean_hessian_norm']:.4f}")
        
        # Optimizer step
        optimizer.step_with_curvature_adaptation(x, y)
    
    print("Hessian analysis complete!")
    
    return analyzer, optimizer
 
 
if __name__ == "__main__":
    demo_hessian_analysis()
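
The finite-difference Hessian-vector product in the analyzer can be cross-checked against an exact one on a problem small enough to build the dense Hessian. The sketch below swaps in the double-backward product (often called Pearlmutter's trick) and is not the pipeline above. The toy MLP, the flat parameter layout, and the names `loss_fn`, `hvp` and `dominant_eigenvalue` are hypothetical choices made for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy regression problem, float64 for a tight comparison
X = torch.randn(32, 4, dtype=torch.float64)
Y = torch.randn(32, 1, dtype=torch.float64)
n_hidden = 3
n_params = 4 * n_hidden + n_hidden  # W1 is 4x3, w2 has 3 entries


def loss_fn(theta: torch.Tensor) -> torch.Tensor:
    """Mean squared error of a one-hidden-layer tanh MLP with flat parameters."""
    W1 = theta[: 4 * n_hidden].view(4, n_hidden)
    w2 = theta[4 * n_hidden:].view(n_hidden, 1)
    pred = torch.tanh(X @ W1) @ w2
    return ((pred - Y) ** 2).mean()


def hvp(theta: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Exact Hessian-vector product via double backward."""
    theta = theta.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(loss_fn(theta), theta, create_graph=True)
    (Hv,) = torch.autograd.grad(grad @ v, theta)
    return Hv


def dominant_eigenvalue(theta: torch.Tensor, n_iter: int = 200) -> float:
    """Power iteration; returns the last Rayleigh quotient."""
    v = torch.randn_like(theta)
    v = v / v.norm()
    eig = 0.0
    for _ in range(n_iter):
        Hv = hvp(theta, v)
        eig = torch.dot(v, Hv).item()
        v = Hv / (Hv.norm() + 1e-12)
    return eig


theta0 = torch.randn(n_params, dtype=torch.float64)

# Dense reference Hessian, feasible only because the model has 15 parameters
H = torch.autograd.functional.hessian(loss_fn, theta0)
v = torch.randn(n_params, dtype=torch.float64)

print("max |hvp - H @ v|:", (hvp(theta0, v) - H @ v).abs().max().item())
print("power iteration estimate:", dominant_eigenvalue(theta0))
print("largest |eigenvalue| (dense):", torch.linalg.eigvalsh(H).abs().max().item())
```

The exact product avoids choosing a finite-difference step, at the cost of building a second-order graph on every call. How close the power-iteration estimate gets to the dense value depends on the gap between the two largest eigenvalue magnitudes.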

9.2 Visualization and Analysis Utilities

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from typing import Dict, List
 
 
def visualize_hessian_spectrum(eigenvalues: torch.Tensor, 
                               title: str = "Hessian Eigenvalue Spectrum"):
    """
    Visualize Hessian eigenvalue distribution
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    eigenvalues = eigenvalues.cpu().numpy()
    
    # Histogram
    axes[0].hist(eigenvalues, bins=50, edgecolor='black', alpha=0.7)
    axes[0].set_xlabel('Eigenvalue')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Eigenvalue Distribution')
    
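    # Sorted eigenvalues (descending)
    sorted_eig = np.sort(eigenvalues)[::-1]
    axes[1].plot(range(len(sorted_eig)), sorted_eig, 'b-', linewidth=2)
    axes[1].set_xlabel('Eigenvalue Index')
    axes[1].set_ylabel('Eigenvalue')
    axes[1].set_title('Sorted Eigenvalues')
    # symlog keeps negative or zero eigenvalues visible; a plain log axis would drop them
    axes[1].set_yscale('symlog', linthresh=1e-6)
    
    # Statistics; the condition number uses magnitudes and guards against division by zero
    abs_eig = np.abs(eigenvalues)
    condition = abs_eig.max() / max(abs_eig.min(), 1e-12)
    stats_text = f"""Statistics:
    Max: {eigenvalues.max():.4f}
    Min: {eigenvalues.min():.4f}
    Mean: {eigenvalues.mean():.4f}
    Std: {eigenvalues.std():.4f}
    Condition: {condition:.2f}"""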
    
    axes[2].text(0.1, 0.5, stats_text, transform=axes[2].transAxes,
                fontsize=12, verticalalignment='center',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    axes[2].axis('off')
    axes[2].set_title('Summary Statistics')
    
    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig('hessian_spectrum.png', dpi=150)
    plt.show()
 
 
def visualize_curvature_flow(curvature_history: List[List[Dict]]):
    """
    Visualize curvature propagation through layers
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Aggregate data
    n_layers = len(curvature_history[0])
    n_steps = len(curvature_history)
    
    layer_curvatures = np.zeros((n_steps, n_layers))
    layer_norms = np.zeros((n_steps, n_layers))
    
    for step, layers in enumerate(curvature_history):
        for layer_idx, data in enumerate(layers):
            layer_curvatures[step, layer_idx] = data.get('curvature_proxy', 0)
            layer_norms[step, layer_idx] = data.get('output_norm', 0)
    
    # Heatmap: curvature (rows are steps, columns are layers).
    # sns.heatmap forwards extra keywords to pcolormesh, which has no 'aspect' option.
    sns.heatmap(layer_curvatures, ax=axes[0, 0], cmap='viridis')
    axes[0, 0].set_xlabel('Layer')
    axes[0, 0].set_ylabel('Training Step')
    axes[0, 0].set_title('Curvature Flow (Gradient Norm)')
    
    # Heatmap: activation norms
    sns.heatmap(layer_norms, ax=axes[0, 1], cmap='plasma')
    axes[0, 1].set_xlabel('Layer')
    axes[0, 1].set_ylabel('Training Step')
    axes[0, 1].set_title('Activation Norm Flow')
    
    # Line plot: curvature proxy of each layer over training
    for layer_idx in range(n_layers):
        axes[1, 0].plot(layer_curvatures[:, layer_idx], 
                       label=f'Layer {layer_idx}', alpha=0.7)
    axes[1, 0].set_xlabel('Training Step')
    axes[1, 0].set_ylabel('Curvature Proxy')
    axes[1, 0].set_title('Curvature by Layer Over Training')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Bar plot: final curvature
    final_curvatures = layer_curvatures[-1]
    axes[1, 1].bar(range(n_layers), final_curvatures)
    axes[1, 1].set_xlabel('Layer')
    axes[1, 1].set_ylabel('Final Curvature')
    axes[1, 1].set_title('Final Layer Curvatures')
    axes[1, 1].set_xticks(range(n_layers))
    
    plt.tight_layout()
    plt.savefig('curvature_flow.png', dpi=150)
    plt.show()
 
 
def plot_scaling_law_predictions(model_sizes: List[int],
                                  hessian_norms: List[float],
                                  dataset_sizes: List[int]):
    """
    Plot predictions from Hessian-based scaling law analysis
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Hessian norm vs model size
    axes[0].loglog(model_sizes, hessian_norms, 'bo-', linewidth=2, markersize=8)
    axes[0].set_xlabel('Model Parameters')
    axes[0].set_ylabel('Hessian Norm')
    axes[0].set_title('Hessian Norm vs Model Size')
    axes[0].grid(True, alpha=0.3)
    
    # Fit power law
    log_params = np.log(model_sizes)
    log_hessian = np.log(hessian_norms)
    slope, intercept = np.polyfit(log_params, log_hessian, 1)
    
    # Plot fit
    fit_line = np.exp(intercept) * np.array(model_sizes) ** slope
    axes[0].loglog(model_sizes, fit_line, 'r--', linewidth=2,
                   label=f'Power law: N^{slope:.2f}')
    axes[0].legend()
    
    # Dataset size effect
    dataset_sizes_arr = np.array(dataset_sizes)
    # Theoretical: Hessian norm ∝ 1/√(dataset_size)
    theoretical = 1.0 / np.sqrt(dataset_sizes_arr)
    theoretical = theoretical / theoretical[0] * hessian_norms[0]
    
    axes[1].plot(dataset_sizes_arr, hessian_norms[:len(dataset_sizes)], 
                 'bo-', linewidth=2, markersize=8, label='Observed')
    axes[1].plot(dataset_sizes_arr, theoretical, 'r--', linewidth=2, 
                 label=r'Theoretical: $1/\sqrt{D}$')
    axes[1].set_xlabel('Dataset Size')
    axes[1].set_ylabel('Hessian Norm')
    axes[1].set_title('Hessian Norm vs Dataset Size')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('scaling_laws.png', dpi=150)
    plt.show()
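
The power-law fit inside `plot_scaling_law_predictions` can be exercised on its own, without plotting, to confirm that it recovers a known exponent. This is a minimal sketch on synthetic data: the helper name `fit_power_law`, the exponent -0.5, the prefactor 3.0 and the noise level are assumptions chosen for illustration, not measurements.

```python
import numpy as np


def fit_power_law(sizes, values):
    """Least-squares fit of values ≈ prefactor * sizes**exponent in log-log space."""
    log_x = np.log(np.asarray(sizes, dtype=float))
    log_y = np.log(np.asarray(values, dtype=float))
    exponent, log_prefactor = np.polyfit(log_x, log_y, 1)
    return exponent, np.exp(log_prefactor)


# Synthetic (not measured) data: norm = 3 * N^(-0.5) with mild multiplicative noise
rng = np.random.default_rng(0)
model_sizes = np.array([1e5, 1e6, 1e7, 1e8, 1e9])
synthetic_norms = 3.0 * model_sizes ** -0.5 * np.exp(0.02 * rng.standard_normal(5))

exponent, prefactor = fit_power_law(model_sizes, synthetic_norms)
print(f"fitted exponent: {exponent:.3f}, fitted prefactor: {prefactor:.3f}")
```

The exponent is usually well determined when the sizes span several decades. The prefactor is an extrapolation to a size of 1 and is far more sensitive to noise, so it should be read with more caution.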

10. References to arXiv:2510.16927

10.1 Paper Overview

Title: Closing the Curvature Gap: Full Transformer Hessians and Their Implications for Scaling Laws

Authors: Egor Petrov, Nikita Kiselev, Vladislav Meshkov, Andrey Grabovoy

Institution: Yandex BRAIn Lab, Moscow State University

arXiv: arXiv:2510.16927

10.2 Key Contributions

| Contribution | Description |
| --- | --- |
| Complete Hessian Characterization | First full Hessian analysis including LayerNorm and FFN |
| Explicit Second-Order Expressions | Closed-form derivatives for all Transformer components |
| Loss Landscape Convergence Bounds | Rigorous convergence rates |
| Scaling Law Implications | Theoretical framework connecting Hessian to empirical scaling laws |
| Taylor Expansion Framework | New method for analyzing convergence trajectories |

10.3 Theoretical Results Summary

Theorem Summary:

  1. LayerNorm Jacobian (Theorem 2): Complete derivation of the Jacobian of LayerNorm with respect to its input

  2. LayerNorm Hessian (Theorem 3): First complete second-order expression

  3. Self-Attention Hessian Bound (Theorem 1): Spectral norm bound for attention Hessian

  4. Transformer Block Derivative (Theorem 4): Full derivative expressions for complete blocks

  5. Loss Convergence Bound: A convergence rate for loss landscape stabilization
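
The LayerNorm Jacobian in item 1 lends itself to a quick numerical check. The sketch below compares autograd against the standard closed form for a single vector, assuming no learnable scale or shift and the biased variance. The symbols `y`, `sigma` and `d`, and both function names, are illustrative and need not match the notation of Theorem 2, which is stated for matrix inputs.

```python
import torch


def layernorm_no_affine(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """LayerNorm of a single vector without learnable scale and shift."""
    mu = x.mean()
    var = x.var(unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)


def layernorm_jacobian_closed_form(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """J = (I - 11^T / d - y y^T / d) / sigma, with y the normalized vector."""
    d = x.numel()
    sigma = torch.sqrt(x.var(unbiased=False) + eps)
    y = (x - x.mean()) / sigma
    eye = torch.eye(d, dtype=x.dtype)
    ones = torch.ones(d, d, dtype=x.dtype)
    return (eye - ones / d - torch.outer(y, y) / d) / sigma


torch.manual_seed(0)
x = torch.randn(6, dtype=torch.float64)

J_autograd = torch.autograd.functional.jacobian(layernorm_no_affine, x)
J_closed = layernorm_jacobian_closed_form(x)

print("max abs difference:", (J_autograd - J_closed).abs().max().item())
# Shifting every input by the same constant leaves the output unchanged,
# so each row of the Jacobian sums to zero
print("max |row sum|:", J_closed.sum(dim=1).abs().max().item())
```

A learnable scale multiplies this Jacobian on the left by a diagonal matrix, and the shift does not affect it. Differentiating the closed form once more is one way to sanity-check a second-order expression numerically.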

10.4 Implications and Applications

The paper’s results enable:

  1. Better Optimizer Design: Curvature-aware optimization strategies

  2. Data Budgeting: Theoretical guidance on dataset size requirements

  3. Architecture Search: Hessian-informed architecture design

  4. Generalization Bounds: Connection between curvature and generalization

  5. Training Diagnostics: Early detection of optimization difficulties


11. Summary and Future Directions

11.1 Key Takeaways

  1. Complete Hessian Analysis: This work completes the theoretical characterization of Transformer Hessians by providing explicit expressions for LayerNorm and FFN components.

  2. Curvature Propagation: Understanding how curvature flows through Transformer blocks reveals optimization challenges and informs architectural decisions.

  3. Scaling Law Connection: Hessian analysis provides a mechanistic explanation for empirical scaling laws, connecting dataset size to loss landscape smoothness.

  4. Practical Implications: The theoretical framework enables better optimization strategies, data budgeting, and training diagnostics.

11.2 Open Questions

| Question | Direction |
| --- | --- |
| Higher-order interactions | Extend analysis to third-order derivatives |
| Non-Gaussian losses | Analyze for contrastive, GAN, and other losses |
| Dynamic analysis | Time-varying Hessian during training |
| Large-scale validation | Empirical validation on billion-parameter models |
| Alternative architectures | Apply framework to Mamba, RWKV, etc. |


Last updated: 2026-05-03

Footnotes

  1. Petrov, E., Kiselev, N., Meshkov, V., & Grabovoy, A. (2025). “Closing the Curvature Gap: Full Transformer Hessians and Their Implications for Scaling Laws.” arXiv:2510.16927. https://arxiv.org/abs/2510.16927