Transformer Full Block Hessian Analysis
Abstract
This article provides a comprehensive analysis of the Hessian matrix for complete Transformer blocks, extending prior work on self-attention mechanisms to include explicit second-order expressions for LayerNorm and feed-forward networks (FFNs). The Hessian analysis reveals critical insights into curvature propagation, convergence dynamics, and the mathematical foundations of empirical scaling laws [1].
1. Introduction: Why Complete Hessian Analysis Matters
1.1 The Curvature Gap in Transformer Theory
While recent studies have derived Hessian expressions for self-attention mechanisms, the full Transformer block—including LayerNorm and feed-forward networks (FFNs)—lacks a comprehensive theoretical characterization. This gap limits our understanding of:
| Aspect | Without Full Hessian | With Full Hessian |
|---|---|---|
| Optimization Dynamics | Partial understanding | Complete picture |
| Convergence Rates | Empirical estimates | Theoretical bounds |
| Scaling Laws | Phenomenological | Mechanistic explanation |
| Curvature Propagation | Layer-wise | End-to-end |
1.2 Why Hessian Analysis is Essential
The Hessian matrix $H = \nabla^2_\theta L(\theta)$ encodes the curvature of the loss landscape.
Key properties revealed by Hessian analysis:
- Spectral structure: Distribution of eigenvalues determines optimization difficulty
- Conditioning: The condition number $\kappa(H) = \lambda_{\max}/\lambda_{\min}$ affects convergence speed
- Negative curvature: Presence of saddle points and local maxima
- Block structure: Reveals interactions between sub-layers
1.3 Connections to Prior Work
This work builds upon and generalizes prior analyses:
- Self-attention Hessian: Extends the framework from prior work on attention curvature [1]
- LayerNorm derivatives: First complete second-order treatment
- FFN Hessian: Explicit derivations for feed-forward blocks
- Block-wise assembly: Complete Transformer layer characterization
2. LayerNorm Hessian Derivation
2.1 LayerNorm Definition
For an input matrix $X \in \mathbb{R}^{L \times d}$, LayerNorm computes row-wise

$$\mathrm{LN}(X)_{i,j} = \gamma_j \,\frac{x_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_j,$$

where:

$$\mu_i = \frac{1}{d}\sum_{j=1}^{d} x_{i,j}, \qquad \sigma_i^2 = \frac{1}{d}\sum_{j=1}^{d} \big(x_{i,j} - \mu_i\big)^2.$$
2.2 Jacobian of LayerNorm
Theorem 2 (Jacobian of LayerNorm) [1]:

Let $X \in \mathbb{R}^{L \times d}$. Define, for each row $i$, the normalized vector $\hat{x}_i = (x_i - \mu_i \mathbf{1})/s_i$ with $s_i = \sqrt{\sigma_i^2 + \epsilon}$.

Then LayerNorm can be expressed row-wise as $\mathrm{LN}(X)_i = \mathrm{diag}(\gamma)\,\hat{x}_i + \beta$.

Jacobian with respect to the input (block-diagonal over rows):

$$\frac{\partial\,\mathrm{LN}(X)_i}{\partial x_i} = \frac{\mathrm{diag}(\gamma)}{s_i}\left(I_d - \frac{1}{d}\mathbf{1}\mathbf{1}^\top - \frac{1}{d}\,\hat{x}_i\hat{x}_i^\top\right),$$

where:
- $I_d - \frac{1}{d}\mathbf{1}\mathbf{1}^\top$ is the centering projector, and
- $\frac{1}{d}\,\hat{x}_i\hat{x}_i^\top$ arises from differentiating the per-row scale $s_i$.
2.3 Hessian of LayerNorm
Theorem 3 (Hessian of LayerNorm) [1]:

The Hessian of LayerNorm with respect to its input follows by differentiating the Jacobian above once more: for each row $i$ and output coordinate $j$ it is a $d \times d$ matrix built from $\mathbf{1}\mathbf{1}^\top$, $\hat{x}_i$, and $\hat{x}_i\hat{x}_i^\top$, scaled by $\gamma_j / s_i^2$ (see the paper for the explicit expression).

Key insight: Since $\mathrm{LN}(X)$ is affine in the parameters $(\gamma, \beta)$, all curvature in those directions vanishes; the Hessian depends only on the second derivatives of the normalization map $x \mapsto \hat{x}$.
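As a quick numerical sanity check (mine, not from the paper), the per-row Hessian of LayerNorm can be assembled directly with torch.autograd.functional.hessian; the helper row_layernorm below is illustrative, with γ = 1 and β = 0:

import torch

def row_layernorm(x, eps=1e-5):
    # Single-row LayerNorm with gamma = 1, beta = 0
    mu = x.mean()
    var = x.var(unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)

d = 4
x = torch.randn(d, dtype=torch.float64)
# Hessian of each output coordinate w.r.t. the input row: shape (d, d, d)
hess = torch.stack([
    torch.autograd.functional.hessian(lambda t, j=j: row_layernorm(t)[j], x)
    for j in range(d)
])
print(hess.shape)        # torch.Size([4, 4, 4])
print(hess.abs().max())  # strictly positive: LayerNorm is curved in x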
2.4 PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math  # used by later listings in this article
class LayerNormHessian(nn.Module):
"""
LayerNorm with Hessian computation for analysis
LayerNorm(x)_{i,j} = γ_j * (x_{i,j} - μ_i) / sqrt(σ_i² + ε) + β_j
"""
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
self.normalized_shape = normalized_shape
self.eps = eps
# Learnable parameters
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
def forward(self, x):
"""
x: (batch, seq_len, d_model) or (batch, d_model)
"""
if x.dim() == 2:
x = x.unsqueeze(1)
batch, seq_len, d_model = x.shape
        # Per-token mean and variance over the last (feature) dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# Normalize
x_norm = (x - mean) / torch.sqrt(var + self.eps)
# Scale and shift
return self.weight * x_norm + self.bias
    def compute_jacobian(self, x):
        """
        Analytical Jacobian ∂LayerNorm/∂x for the first batch element.
        Per row i: J_i = diag(γ)/s_i · (I_d − (1/d)·11ᵀ − (1/d)·x̂_i x̂_iᵀ)
        Returns: block-diagonal matrix of shape (L*d, L*d)
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)
        _, L, d = x.shape
        x0 = x[0].detach()
        # Per-row statistics
        mean = x0.mean(dim=-1, keepdim=True)
        var = x0.var(dim=-1, keepdim=True, unbiased=False)
        s = torch.sqrt(var + self.eps)   # (L, 1)
        x_hat = (x0 - mean) / s          # (L, d)
        I_d = torch.eye(d, device=x.device, dtype=x.dtype)
        ones_dd = torch.ones(d, d, device=x.device, dtype=x.dtype)
        jacobian = torch.zeros(L * d, L * d, device=x.device, dtype=x.dtype)
        for i in range(L):
            # Centering projector minus the variance term, scaled by 1/s_i
            block = (I_d - ones_dd / d
                     - torch.outer(x_hat[i], x_hat[i]) / d) / s[i]
            # Row j of each block is additionally scaled by γ_j
            jacobian[i*d:(i+1)*d, i*d:(i+1)*d] = \
                self.weight.detach().unsqueeze(1) * block
        return jacobian
    def compute_hessian_wrt_weight(self, x):
        """
        The Hessian of the LN output w.r.t. γ is exactly zero, since LN is
        linear in γ; curvature in γ therefore enters only through the loss
        (Gauss-Newton term). A placeholder identity is returned so callers
        can experiment with preconditioning.
        """
        d_model = x.shape[-1]
        return torch.eye(d_model, device=x.device, dtype=x.dtype)
def verify_layernorm_hessian():
    """Verify the analytical LayerNorm Jacobian against finite differences"""
    torch.manual_seed(42)
    # Setup (single batch element so indices match the analytic Jacobian)
    seq_len, d_model = 8, 16
    x = torch.randn(1, seq_len, d_model)
    ln = LayerNormHessian(d_model)
    # Analytical Jacobian, shape (L*d, L*d)
    jac_analytic = ln.compute_jacobian(x)
    # Central finite differences on a small corner of the Jacobian
    eps = 1e-5
    n = 5
    grad_numerical = torch.zeros(n, n)
    with torch.no_grad():
        for j in range(n):        # output index
            for i in range(n):    # input index
                x_plus = x.clone()
                x_plus.view(-1)[i] += eps
                x_minus = x.clone()
                x_minus.view(-1)[i] -= eps
                out_plus = ln(x_plus).reshape(-1)[j]
                out_minus = ln(x_minus).reshape(-1)[j]
                grad_numerical[j, i] = (out_plus - out_minus) / (2 * eps)
    print(f"Jacobian shape: {jac_analytic.shape}")
    print(f"Numerical gradient approx:\n{grad_numerical}")
    print(f"Max difference: {(jac_analytic[:n, :n] - grad_numerical).abs().max():.6f}")
if __name__ == "__main__":
    verify_layernorm_hessian()

3. Feed-Forward Network Hessian Derivation
3.1 FFN Definition
The feed-forward network in a Transformer is

$$\mathrm{FFN}(X) = \sigma(X W_1 + b_1)\, W_2 + b_2,$$

where:
- $W_1 \in \mathbb{R}^{d_{\mathrm{model}} \times d_{\mathrm{ff}}}$ (up-projection)
- $W_2 \in \mathbb{R}^{d_{\mathrm{ff}} \times d_{\mathrm{model}}}$ (down-projection)
- $\sigma$ is typically GELU or ReLU
3.2 ReLU Activation Derivatives
Lemma 1 (ReLU derivative and Hessian) [1]:

For $\sigma(z) = \max(0, z)$, almost everywhere:

$$\sigma'(z) = \mathbb{1}\{z > 0\}, \qquad \sigma''(z) = 0.$$

Key insight: ReLU’s Hessian is almost everywhere zero, making FFN Hessian computation more tractable.
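A minimal autograd check of the lemma (illustrative, not the paper's code): the Hessian of $z \mapsto \sum_i \mathrm{ReLU}(z_i)$ returned by autograd is exactly zero away from the kink, while the first derivative is the 0/1 indicator:

import torch

z = torch.randn(5, dtype=torch.float64)
# Second derivative: all zeros almost everywhere
H = torch.autograd.functional.hessian(lambda t: torch.relu(t).sum(), z)
print(H)

# First derivative: the indicator 1{z > 0}
z_req = z.clone().requires_grad_(True)
(g,) = torch.autograd.grad(torch.relu(z_req).sum(), z_req)
print(g)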
3.3 FFN Block Hessian Structure
For a Transformer block, consider the FFN sub-block with residual, $F(X) = X + \sigma(X W_1)\, W_2$ (biases omitted). The residual contributes only an identity term to the input Jacobian and does not affect the parameter derivatives.

Theorem 4 (Transformer block derivative) [1]:

With column-stacking vec and $H = \sigma(X W_1)$, the derivatives with respect to the FFN parameters are

$$\frac{\partial\, \mathrm{vec}(F(X))}{\partial\, \mathrm{vec}(W_2)} = I_{d_{\mathrm{model}}} \otimes H, \qquad \frac{\partial\, \mathrm{vec}(F(X))}{\partial\, \mathrm{vec}(W_1)} = \big(W_2^\top \otimes I_L\big)\, \mathrm{diag}\!\big(\mathrm{vec}(\sigma'(X W_1))\big)\, \big(I_{d_{\mathrm{ff}}} \otimes X\big).$$
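A small autograd check of the $W_2$ Jacobian structure in Theorem 4 (shapes and names are illustrative). Note that PyTorch flattens tensors row-major, so the Kronecker factors appear in the order $H \otimes I_d$ rather than $I_d \otimes H$:

import torch

torch.manual_seed(0)
L, d_model, d_ff = 3, 4, 5
X = torch.randn(L, d_model, dtype=torch.float64)
W1 = torch.randn(d_model, d_ff, dtype=torch.float64)
W2 = torch.randn(d_ff, d_model, dtype=torch.float64)
H = torch.relu(X @ W1)  # post-activation features, constant w.r.t. W2

def ffn_of_w2(w2):
    return (H @ w2).reshape(-1)  # row-major flattened FFN output

J = torch.autograd.functional.jacobian(ffn_of_w2, W2)
J = J.reshape(L * d_model, d_ff * d_model)
J_kron = torch.kron(H, torch.eye(d_model, dtype=torch.float64))
print(torch.allclose(J, J_kron))  # True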
3.4 FFN Hessian Code Implementation
class FFNHessian(nn.Module):
"""
Feed-Forward Network with Hessian computation
FFN(x) = W2 @ σ(W1 @ x) + x (with residual)
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.d_model = d_model
self.d_ff = d_ff
# FFN weights
self.w1 = nn.Linear(d_model, d_ff, bias=True)
self.w2 = nn.Linear(d_ff, d_model, bias=True)
# Activation
self.activation = nn.GELU()
self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        """
        # Residual is added by the enclosing Transformer block
        return self.dropout(self.w2(self.activation(self.w1(x))))
def compute_ffn_jacobians(self, x):
"""
Compute Jacobians ∂FFN/∂W1 and ∂FFN/∂W2
"""
batch, seq_len, d_model = x.shape
# Pre-compute intermediate values
h = self.w1(x) # (batch, seq_len, d_ff)
h_act = self.activation(h) # (batch, seq_len, d_ff)
        # GELU derivative: GELU'(h) = Φ(h) + h·φ(h), with Φ from the
        # tanh approximation and φ the standard normal density
        cdf = 0.5 * (1 + torch.tanh(math.sqrt(2 / math.pi) *
                                    (h + 0.044715 * h**3)))
        pdf = torch.exp(-0.5 * h**2) / math.sqrt(2 * math.pi)
        dgelu = cdf + h * pdf
        # ∂FFN/∂W2 is linear in the post-activation features h_act;
        # ∂FFN/∂W1 additionally chains through dgelu and has the
        # Kronecker structure of Theorem 4 (computed per sample in
        # practice, omitted here)
return {
'd_ffn_d_w1': None, # Requires Kronecker product
'd_ffn_d_w2': h_act,
'activation_derivative': dgelu,
'intermediate': h
}
def compute_hessian_wrt_w2(self, x):
"""
Compute Hessian ∂²Loss/∂W2²
Using Gauss-Newton approximation:
H_GN ≈ J^T @ diag(r) @ J
where r are the residuals
"""
batch, seq_len, d_model = x.shape
d_ff = self.d_ff
# Forward pass
h = self.w1(x)
h_act = self.activation(h)
        # Block structure: for unit output curvature, the per-sample
        # Gauss-Newton block for W2 is  Σ_i (h_i h_iᵀ) ⊗ I_d
        I_dm = torch.eye(d_model, device=x.device, dtype=x.dtype)
        hessian_blocks = []
        for b in range(batch):
            h_sample = h_act[b]  # (seq_len, d_ff)
            hessian = torch.zeros(d_ff * d_model, d_ff * d_model,
                                  device=x.device, dtype=x.dtype)
            for i in range(seq_len):
                v = h_sample[i]  # (d_ff,)
                # Accumulate the Kronecker-structured rank-d update
                hessian += torch.kron(torch.outer(v, v), I_dm)
            hessian_blocks.append(hessian)
        return torch.stack(hessian_blocks).mean(0)  # average over batch
class TransformerBlockHessian(nn.Module):
"""
Full Transformer Block with Hessian Analysis
Z = LayerNorm(Y + FFN(LayerNorm(X + Attention(X))))
"""
def __init__(self, d_model, d_ff, n_heads, dropout=0.1):
super().__init__()
self.d_model = d_model
# Components
self.norm1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = FFNHessian(d_model, d_ff, dropout)
    def forward(self, x):
        # Pre-norm architecture (more stable gradients)
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.ffn(self.norm2(x))
        return x
def compute_block_hessian(self, x):
"""
Compute full block Hessian structure
Returns block matrix:
[ H_aa H_an H_af
H_na H_nn H_nf
H_fa H_fn H_ff ]
where a=attention, n=norm, f=ffn parameters
"""
# This is a simplified representation
# Full computation requires careful handling of all cross-terms
block_sizes = {
            'attention': 4 * self.d_model**2,  # Q, K, V and output projections
'norm1': self.d_model * 2, # gamma, beta
'ffn': 2 * self.d_model * self.ffn.d_ff, # W1, W2
'norm2': self.d_model * 2
}
total_params = sum(block_sizes.values())
# Initialize block Hessian
H = torch.zeros(total_params, total_params)
return {
'H': H,
'block_sizes': block_sizes,
'structure': 'block_sparse' # Many zero blocks
        }

4. Self-Attention Hessian (Building on Prior Work)
4.1 Self-Attention Mechanism
Consider a single-head self-attention layer

$$\mathrm{Attn}(X) = A\, X W_V,$$

where the attention matrix is

$$A = \mathrm{softmax}\!\left(\frac{X W_Q W_K^\top X^\top}{\sqrt{d_k}}\right),$$

with the softmax applied row-wise.
4.2 Attention Hessian Decomposition
The Hessian decomposes into an outer-product (Gauss-Newton) part and a functional part, $H = H_{\mathrm{o}} + H_{\mathrm{f}}$:

| Component | Definition | Origin |
|---|---|---|
| Outer-product Hessian $H_{\mathrm{o}}$ | $J^\top (\nabla^2_F \ell)\, J$ | First derivatives of the network, curvature of the loss |
| Functional Hessian $H_{\mathrm{f}}$ | $\sum_k [\nabla_F \ell]_k\, \nabla^2_\theta F_k$ | Second derivatives of the network itself |

Here $F$ is the network output, $\ell$ the loss, and $J = \partial F / \partial \theta$.
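To make the split concrete, here is a minimal sketch (toy data, not from the paper) on a nonlinear least-squares problem: subtracting the Gauss-Newton term $J^\top J$ from the autograd Hessian leaves exactly the residual-weighted functional part:

import torch

torch.manual_seed(0)
A = torch.randn(6, 3, dtype=torch.float64)
y = torch.randn(6, dtype=torch.float64)
theta = torch.randn(3, dtype=torch.float64)

def residuals(t):
    return torch.tanh(A @ t) - y

def loss(t):
    return 0.5 * (residuals(t) ** 2).sum()

H_full = torch.autograd.functional.hessian(loss, theta)   # (3, 3)
J = torch.autograd.functional.jacobian(residuals, theta)  # (6, 3)
H_outer = J.T @ J              # Gauss-Newton / outer-product term
H_func = H_full - H_outer      # functional (residual-weighted) term
print(H_func.norm() / H_full.norm())  # vanishes as residuals -> 0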
4.3 Spectral Norm Bound for Attention
Theorem 1 (Hessian spectral norm) [1]:

Let $\|\cdot\|_2$ be the spectral norm. For a single self-attention layer, the loss Hessian admits a bound of the form

$$\|H\|_2 \;\le\; C\!\left(\|X\|_2,\, \|W_Q\|_2,\, \|W_K\|_2,\, \|W_V\|_2\right),$$

where $C$ is polynomial in its arguments and collects products of the input norm and the projection-weight norms (see the paper for the explicit constants).

The full expression includes terms from:
- $W_Q$–$W_K$ interactions
- $W_Q$–$W_V$ interactions
- $W_K$–$W_V$ interactions
- Diagonal terms ($W_Q$–$W_Q$, $W_K$–$W_K$, $W_V$–$W_V$)
4.4 Attention Hessian Implementation
class SelfAttentionHessian(nn.Module):
"""
Self-Attention with Hessian Analysis
"""
def __init__(self, d_model, n_heads, d_k=None, d_v=None):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_k or d_model // n_heads
self.d_v = d_v or d_model // n_heads
# Projections
self.W_Q = nn.Linear(d_model, self.d_k * n_heads, bias=False)
self.W_K = nn.Linear(d_model, self.d_k * n_heads, bias=False)
self.W_V = nn.Linear(d_model, self.d_v * n_heads, bias=False)
self.out_proj = nn.Linear(self.d_v * n_heads, d_model)
def forward(self, x, return_attention=False):
batch, seq_len, d_model = x.shape
# Linear projections
Q = self.W_Q(x).view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_K(x).view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_V(x).view(batch, seq_len, self.n_heads, self.d_v).transpose(1, 2)
# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
# Output
context = torch.matmul(attn_weights, V)
context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
output = self.out_proj(context)
if return_attention:
return output, attn_weights
return output
def compute_attention_hessian(self, x, target):
"""
Compute Hessian of loss w.r.t. attention parameters
Uses Gauss-Newton approximation for efficiency
"""
x = x.clone().requires_grad_(True)
# Forward pass
output, attn = self.forward(x, return_attention=True)
# MSE loss
loss = F.mse_loss(output, target)
        # Backward to get gradients (clear stale grads first)
        self.zero_grad()
        loss.backward()
# Compute Jacobian structure
with torch.no_grad():
batch, seq_len, d_model = x.shape
L = seq_len
d_V = self.d_v * self.n_heads
# Extract gradients for each projection
grad_WQ = self.W_Q.weight.grad.clone()
grad_WK = self.W_K.weight.grad.clone()
grad_WV = self.W_V.weight.grad.clone()
            # Crude curvature-scale proxies from gradient norms
            # (not true spectral norms; see Theorem 1 for actual bounds)
hessian_estimates = {
'H_WQ': torch.norm(grad_WQ) ** 2 / (grad_WQ.numel() ** 0.5),
'H_WK': torch.norm(grad_WK) ** 2 / (grad_WK.numel() ** 0.5),
'H_WV': torch.norm(grad_WV) ** 2 / (grad_WV.numel() ** 0.5),
}
return {
'loss': loss.item(),
'hessian_estimates': hessian_estimates,
'attention': attn
}
def analyze_attention_curvature(model, dataloader, device='cuda'):
"""
Analyze attention Hessian across training
"""
model.eval()
curvature_history = []
for batch_idx, (x, y) in enumerate(dataloader):
x, y = x.to(device), y.to(device)
# Compute attention Hessian
result = model.compute_attention_hessian(x, y)
curvature_history.append({
'batch': batch_idx,
'loss': result['loss'],
'h_WQ': result['hessian_estimates']['H_WQ'],
'h_WK': result['hessian_estimates']['H_WK'],
'h_WV': result['hessian_estimates']['H_WV']
})
if batch_idx >= 100: # Analyze first 100 batches
break
    return curvature_history

5. Curvature Propagation Through Transformer Blocks
5.1 Block Structure Overview
A complete Transformer block (pre-norm architecture) is

$$Y = X + \mathrm{Attn}(\mathrm{LN}_1(X)), \qquad Z = Y + \mathrm{FFN}(\mathrm{LN}_2(Y)).$$

Parameter set:

$$\theta = \{W_Q, W_K, W_V, W_O\} \cup \{\gamma_1, \beta_1, \gamma_2, \beta_2\} \cup \{W_1, b_1, W_2, b_2\}.$$
5.2 Block Hessian Assembly
The full block Hessian has the following structure (a = attention, n = norm, f = FFN parameters):

$$H = \begin{pmatrix} H_{aa} & H_{an} & H_{af} \\ H_{na} & H_{nn} & H_{nf} \\ H_{fa} & H_{fn} & H_{ff} \end{pmatrix}$$
Key properties:
- Sparse structure: Many zero blocks due to residual connections
- Cross-layer coupling: Through LayerNorm
- Block diagonal dominance: FFN and attention have relatively independent curvature
5.3 Curvature Flow Analysis
class CurvaturePropagationAnalyzer:
"""
Analyze how curvature propagates through Transformer blocks
"""
def __init__(self, model):
self.model = model
self.curvature_history = []
def compute_layer_curvature(self, x, layer_idx):
"""
Compute effective curvature for a specific layer
"""
layer = self.model.transformer.layers[layer_idx]
with torch.no_grad():
# Forward through layer
x_norm = layer.norm1(x)
            attn_out, attn = layer.self_attn(x_norm, x_norm, x_norm, need_weights=True)
x = x + attn_out
x_norm2 = layer.norm2(x)
ffn_out = layer.mlp(x_norm2)
x = x + ffn_out
            # Cheap local-curvature proxies: attention-weight dispersion
            # and activation scale (Fisher-style surrogates, H ≈ E[∇ℓ∇ℓᵀ])
return {
'layer': layer_idx,
'attn_curvature': attn.std().item(),
'gradient_scale': x.norm().item(),
'attention_scores': attn
}
def analyze_curvature_flow(self, x):
"""
Track curvature through all layers
"""
curvatures = []
for layer_idx in range(len(self.model.transformer.layers)):
layer_curv = self.compute_layer_curvature(x, layer_idx)
curvatures.append(layer_curv)
# Update x for next layer (in practice, this is done in forward pass)
x = self.model.transformer.layers[layer_idx](x)
self.curvature_history.append(curvatures)
return curvatures
def visualize_curvature_flow(self):
"""
Visualize how curvature changes across layers
"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
layers = list(range(len(self.curvature_history[0])))
# Plot attention curvature
ax1 = axes[0, 0]
for step, curvatures in enumerate(self.curvature_history[:10]):
attn_curvs = [c['attn_curvature'] for c in curvatures]
ax1.plot(layers, attn_curvs, label=f'Step {step}', alpha=0.7)
ax1.set_xlabel('Layer')
ax1.set_ylabel('Attention Curvature')
ax1.set_title('Attention Curvature Flow')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot gradient scale
ax2 = axes[0, 1]
for step, curvatures in enumerate(self.curvature_history[:10]):
grad_scales = [c['gradient_scale'] for c in curvatures]
ax2.plot(layers, grad_scales, label=f'Step {step}', alpha=0.7)
ax2.set_xlabel('Layer')
ax2.set_ylabel('Gradient Scale')
ax2.set_title('Gradient Scale Flow')
ax2.legend()
ax2.grid(True, alpha=0.3)
# Average curvature by layer
ax3 = axes[1, 0]
avg_attn_curv = torch.stack([
torch.tensor([c['attn_curvature'] for c in step])
for step in self.curvature_history
]).mean(0)
ax3.bar(layers, avg_attn_curv.numpy())
ax3.set_xlabel('Layer')
ax3.set_ylabel('Average Curvature')
ax3.set_title('Average Attention Curvature by Layer')
# Heatmap of attention curvature
ax4 = axes[1, 1]
curv_matrix = torch.stack([
torch.tensor([c['attn_curvature'] for c in step])
for step in self.curvature_history
]).numpy()
im = ax4.imshow(curv_matrix, aspect='auto', cmap='viridis')
ax4.set_xlabel('Layer')
ax4.set_ylabel('Training Step')
ax4.set_title('Curvature Heatmap')
plt.colorbar(im, ax=ax4)
plt.tight_layout()
plt.savefig('curvature_flow.png', dpi=150)
        plt.show()

5.4 Mathematical Framework for Curvature Propagation
Key theorems for curvature flow:
1. LayerNorm stabilizes curvature: The normalization ensures that the input-dependent Hessian terms remain bounded.
2. FFN preserves block structure: Due to ReLU’s zero second derivative, the FFN Hessian is relatively simple.
3. Attention introduces data-dependent coupling: The softmax Jacobian creates correlations across the sequence (a quick autograd check follows below).
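The coupling in item 3 comes from the softmax Jacobian $\mathrm{diag}(p) - p p^\top$, a standard identity that is easy to confirm with autograd:

import torch

z = torch.randn(6, dtype=torch.float64)
p = torch.softmax(z, dim=0)
J = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=0), z)
J_closed = torch.diag(p) - torch.outer(p, p)
print(torch.allclose(J, J_closed))  # True: every score moves every weight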
6. Implications for Convergence Dynamics
6.1 Loss Landscape Convergence
As the dataset size $N$ increases, the empirical loss landscape $L_N$ converges to a stable limiting structure:
Key insight: Larger datasets lead to:
- Smoother optimization landscapes
- Better-conditioned Hessians
- More predictable convergence
6.2 Hessian-Based Convergence Analysis
Theorem (Convergence bound) [1]:

For the empirical loss $L_N$ with Hessian $H_N$, adding one sample perturbs the landscape by at most

$$\big|L_{N+1}(w) - L_N(w)\big| \;\le\; \frac{C}{N+1},$$

where $C$ depends on the spectral norm bounds derived in previous sections.
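A toy illustration of this rate (synthetic quadratic data and a fixed random parameter vector; mine, not the paper's experiment): the empirical mean loss moves by roughly $1/N$ when one sample is added:

import torch

torch.manual_seed(0)
w = torch.randn(5)
X = torch.randn(4096, 5)
y = torch.randn(4096)
per_sample = (X @ w - y) ** 2  # per-sample losses at the fixed w

for N in [64, 256, 1024, 4095]:
    L_N = per_sample[:N].mean()
    L_N1 = per_sample[:N + 1].mean()
    print(N, abs((L_N1 - L_N).item()))  # shrinks roughly like 1/N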
6.3 Practical Convergence Diagnostics
class ConvergenceDiagnostics:
"""
Use Hessian analysis for convergence diagnostics
"""
def __init__(self, model):
self.model = model
self.loss_history = []
self.grad_history = []
self.curvature_history = []
def compute_effective_conditioning(self, x, y):
"""
Estimate condition number from gradient statistics
"""
        output = self.model(x)
        loss = F.cross_entropy(output, y)
        self.model.zero_grad()
        loss.backward()
# Collect gradients
grads = []
for p in self.model.parameters():
if p.grad is not None:
grads.append(p.grad.flatten())
all_grads = torch.cat(grads)
# Estimate "local" condition number using gradient statistics
grad_mean = all_grads.mean()
grad_std = all_grads.std()
grad_max = all_grads.abs().max()
# Effective condition number proxy
eff_cond = grad_max / (grad_std + 1e-8)
return {
'loss': loss.item(),
'grad_norm': all_grads.norm().item(),
'grad_mean': grad_mean.item(),
'grad_std': grad_std.item(),
'eff_condition_number': eff_cond.item()
}
def analyze_convergence_trajectory(self, dataloader, n_steps=100):
"""
Analyze convergence using Hessian-based metrics
"""
self.model.train()
for step, (x, y) in enumerate(dataloader):
if step >= n_steps:
break
# Compute diagnostics
diag = self.compute_effective_conditioning(x, y)
self.loss_history.append(diag['loss'])
self.grad_history.append(diag['grad_norm'])
self.curvature_history.append(diag['eff_condition_number'])
            # Clear gradients (diagnostics only; no optimizer update here)
            self.model.zero_grad()
return self._summarize_convergence()
def _summarize_convergence(self):
"""
Summarize convergence properties
"""
import numpy as np
loss_arr = np.array(self.loss_history)
grad_arr = np.array(self.grad_history)
cond_arr = np.array(self.curvature_history)
return {
'loss_decreased': loss_arr[-1] < loss_arr[0],
'loss_reduction': (loss_arr[0] - loss_arr[-1]) / loss_arr[0],
'gradient_decay_rate': np.polyfit(range(len(grad_arr)), np.log(grad_arr + 1e-10), 1)[0],
'condition_number_trend': 'improving' if cond_arr[-1] < cond_arr[0] else 'degrading',
'final_condition': cond_arr[-1]
        }

6.4 Optimization Recommendations
Based on Hessian analysis:
| Observation | Recommendation |
|---|---|
| High condition number in early layers | Use adaptive optimizers (Adam, AdaGrad) |
| Negative eigenvalues present | Apply gradient clipping, use momentum |
| Curvature varies across layers | Layer-wise learning rate scaling |
| Hessian norm grows with depth | Use Pre-norm architecture |
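One way to act on the layer-wise learning-rate recommendation is through optimizer parameter groups. A minimal sketch follows; the attribute path model.transformer.layers and the geometric decay factor are assumptions to adapt to your model:

import torch

def layerwise_param_groups(model, base_lr=1e-3, decay=0.9):
    # One param group per block; deeper blocks get smaller learning rates.
    # Invert the schedule if curvature is largest near the input instead.
    return [
        {'params': layer.parameters(), 'lr': base_lr * (decay ** i)}
        for i, layer in enumerate(model.transformer.layers)
    ]

# Usage sketch:
# optimizer = torch.optim.Adam(layerwise_param_groups(model))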
7. Connections to Empirical Scaling Laws
7.1 Neural Scaling Laws Background
Empirical scaling laws describe how the test loss scales with model size $N$ and dataset size $D$; a commonly used form is

$$L(N, D) \;\approx\; E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}.$$
Key observations:
- Larger models are more sample-efficient
- Performance improves predictably with scale
- Double descent phenomena at certain scales
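To connect these observations to measurements, the saturating power-law form above can be fit directly to (dataset size, loss) pairs. A sketch on synthetic numbers, using SciPy's curve_fit (an extra dependency, not used elsewhere in this article):

import numpy as np
from scipy.optimize import curve_fit

def scaling_form(D, E, B, beta):
    return E + B / D ** beta

D = np.array([1e3, 1e4, 1e5, 1e6, 1e7])
rng = np.random.default_rng(0)
L = 0.5 + 3.0 / D ** 0.3 + rng.normal(0, 1e-3, D.shape)  # synthetic losses

(E, B, beta), _ = curve_fit(scaling_form, D, L, p0=(L.min(), 1.0, 0.5))
print(f"E={E:.3f}, B={B:.3f}, beta={beta:.3f}")  # recovers ~(0.5, 3.0, 0.3)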
7.2 Hessian Perspective on Scaling Laws
New perspective from full Hessian analysis:
1. Dataset size and curvature: Larger datasets lead to better-conditioned Hessians, and a smoother loss landscape supports better generalization.
2. Model size and condition number: Larger models have a larger parameter space; the Hessian spectrum widens (more very large and very small eigenvalues), which helps explain “critical batch size” phenomena.
3. Training dynamics: Curvature undergoes phase transitions during training: an initial phase of rapid curvature changes, then landscape stabilization.
7.3 Theoretical Bounds for Scaling
Theorem (Loss landscape evolution) [1]:

The difference between loss functions trained on $k$ and $k+1$ samples satisfies

$$\big|L_{k+1}(w) - L_k(w)\big| = O\!\left(\frac{1}{k}\right).$$
This provides:
- Theoretical justification for scaling law exponents
- Quantitative predictions for dataset size requirements
- Framework for compute-optimal training
7.4 Scaling Law Implementation
import pandas as pd

class ScalingLawAnalyzer:
    """
    Analyze scaling laws from the Hessian perspective
    """
    def __init__(self, model_factory, dataloader):
        self.model_factory = model_factory
        self.dataloader = dataloader
def compute_hessian_norm(self, model, dataloader, n_samples=100):
"""
Compute Hessian spectral norm (approximation)
"""
# Use power iteration for top eigenvalue
model.eval()
# Get a batch
x, y = next(iter(dataloader))
x, y = x[:n_samples], y[:n_samples]
# Simplified: compute gradient covariance
output = model(x)
loss = F.cross_entropy(output, y)
loss.backward()
# Collect gradients
grads = []
for p in model.parameters():
if p.grad is not None:
grads.append(p.grad.flatten())
all_grads = torch.cat(grads)
        # Gradient outer product as a Hessian proxy (H ≈ E[∇ℓ∇ℓᵀ]);
        # note this single-batch estimate is rank-1 and only feasible
        # for small models (it materializes an n_params² matrix)
        hessian_proxy = torch.outer(all_grads, all_grads)
        # Eigenvalues in ascending order
        eigenvalues = torch.linalg.eigvalsh(hessian_proxy)
        return {
            'top_eigenvalue': eigenvalues[-1].item(),
            'bottom_eigenvalue': eigenvalues[0].item(),
            'condition_number': eigenvalues[-1].item() / (eigenvalues[0].item() + 1e-8),
            'effective_rank': (eigenvalues.sum() / (eigenvalues[-1] + 1e-8)).item()
        }
def analyze_scaling_with_size(self, model_sizes, dataset_sizes):
"""
Analyze how Hessian properties scale with model/data size
"""
results = []
for n_params in model_sizes:
for n_samples in dataset_sizes:
# Create model
model = self.model_factory(n_params)
                # Compute Hessian properties on n_samples examples
                hess_props = self.compute_hessian_norm(model, self.dataloader,
                                                       n_samples=n_samples)
results.append({
'n_params': n_params,
'n_samples': n_samples,
**hess_props
})
return pd.DataFrame(results)
def predict_optimal_dataset_size(self, model_size, target_condition):
"""
Predict optimal dataset size for target condition number
Based on theory: condition_number ∝ 1/√(dataset_size)
"""
# Empirical fit: use prior observations
alpha = 0.5 # Theoretical exponent
C = 1.0 # Scale constant
n_samples = (C / target_condition) ** (1/alpha)
        return int(n_samples)

8. Taylor Expansion Framework for Loss Difference Analysis
8.1 Taylor Expansion Foundation
Given two loss functions $L_k$ and $L_{k+1}$ (trained on $k$ and $k+1$ samples, respectively):

Second-order Taylor expansion around a reference point $w^*$, with $\delta w = w - w^*$:

$$L(w) \;\approx\; L(w^*) + \nabla L(w^*)^\top \delta w + \tfrac{1}{2}\, \delta w^\top H(w^*)\, \delta w.$$
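A small check (toy tanh regression, mine) of how well the quadratic model tracks the true loss locally; the gap should shrink roughly cubically in the step size:

import torch

torch.manual_seed(0)
w = torch.randn(4, dtype=torch.float64)
A = torch.randn(8, 4, dtype=torch.float64)
y = torch.randn(8, dtype=torch.float64)

def loss(t):
    return ((torch.tanh(A @ t) - y) ** 2).mean()

g = torch.autograd.functional.jacobian(loss, w)  # gradient, shape (4,)
H = torch.autograd.functional.hessian(loss, w)   # Hessian, shape (4, 4)

direction = torch.randn(4, dtype=torch.float64)
direction /= direction.norm()
for eps in [1e-1, 1e-2, 1e-3]:
    delta = eps * direction
    quad = loss(w) + g @ delta + 0.5 * delta @ H @ delta
    print(eps, abs((loss(w + delta) - quad).item()))  # ~O(eps³)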
8.2 Convergence Trajectory Analysis
Framework for quantifying convergence:
import numpy as np

class TaylorExpansionAnalyzer:
    """
    Taylor expansion-based framework for loss difference analysis
    """
    def __init__(self, model, estimated_hessian_eigenvalue=1.0):
        self.model = model
        self.hessian_cache = {}
        # Top-eigenvalue estimate used by predict_convergence_trajectory
        self.estimated_hessian_eigenvalue = estimated_hessian_eigenvalue
def compute_taylor_coefficients(self, x, y, x_train):
"""
Compute Taylor expansion coefficients for loss difference
ΔL = L_{k+1} - L_k ≈ 1/2 * δw^T * H * δw
"""
        # Losses before and after adding the newest training sample
        # (compute_loss is an assumed helper returning the model's loss)
        L_k = self.compute_loss(x, y, x_train[:-1])
        L_k1 = self.compute_loss(x, y, x_train)
        delta_L = L_k1 - L_k
        # Estimate local curvature from fresh gradients
        self.model.zero_grad()
        output = self.model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()
        # Get gradients
        grads = []
        for p in self.model.parameters():
            if p.grad is not None:
                grads.append(p.grad.flatten())
        all_grads = torch.cat(grads)
# Quadratic coefficient from gradient statistics
grad_mom2 = (all_grads ** 2).mean()
return {
'delta_loss': delta_L,
'quadratic_coefficient': grad_mom2.item(),
'gradient_norm': all_grads.norm().item()
}
def predict_convergence_trajectory(self, n_steps, learning_rate):
"""
Predict convergence using Taylor expansion
Based on: δw_{t+1} ≈ (I - ηH) * δw_t
"""
trajectory = []
delta_w_norm = 1.0 # Initial distance to optimum
for t in range(n_steps):
# Taylor prediction
delta_w_next = delta_w_norm * (1 - learning_rate * self.estimated_hessian_eigenvalue)
trajectory.append({
'step': t,
'delta_w_norm': delta_w_norm,
'delta_w_predicted': delta_w_next
})
delta_w_norm = delta_w_next
return trajectory
def estimate_hessian_from_trajectory(self, loss_trajectory, learning_rate):
"""
Estimate average Hessian eigenvalue from loss trajectory
Using: L_{t+1} - L_t ≈ -η * δw_t^T H ∇L
"""
losses = np.array([t['loss'] for t in loss_trajectory])
# Finite difference approximation
dL = np.diff(losses)
ddL = np.diff(dL)
# Approximate average curvature
avg_curvature = -2 * ddL.mean() / (learning_rate ** 2 + 1e-8)
return {
'avg_curvature': avg_curvature,
'convergence_rate': np.exp(np.polyfit(range(len(dL)), np.log(np.abs(dL) + 1e-10), 1)[0])
        }

8.3 Mathematical Framework
Key results from Taylor expansion analysis:
1. Second-order term dominance: In well-conditioned regions the quadratic term dominates the expansion.
2. Convergence rate prediction: Under gradient descent, $\delta w_{t+1} \approx (I - \eta H)\, \delta w_t$, so each Hessian eigendirection contracts by a factor $|1 - \eta \lambda_i|$.
3. Critical learning rate: $\eta_{\mathrm{crit}} = 2/\lambda_{\max}(H)$; above it the iteration diverges along the top eigendirection (a toy demonstration follows below).
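A toy demonstration of the critical learning rate on a quadratic with a known spectrum: gradient descent contracts while $\eta \lambda_{\max} < 2$ and diverges just above it:

import torch

torch.manual_seed(0)
M = torch.randn(5, 5, dtype=torch.float64)
H = M @ M.T + torch.eye(5, dtype=torch.float64)  # SPD Hessian
lam_max = torch.linalg.eigvalsh(H)[-1].item()

for eta in [0.5 / lam_max, 1.9 / lam_max, 2.1 / lam_max]:
    w = torch.randn(5, dtype=torch.float64)
    for _ in range(200):
        w = w - eta * (H @ w)  # gradient step on (1/2)·wᵀHw
    print(f"eta*lam_max={eta * lam_max:.2f}  ||w||={w.norm().item():.3e}")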
9. Code Examples in PyTorch
9.1 Complete Transformer Hessian Analysis Pipeline
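Before the full pipeline, a minimal exact Hessian-vector product via double backpropagation, a standard alternative to the finite-difference HVP used inside the analyzer below (the helper name hvp is mine):

import torch

def hvp(loss, params, vec):
    # Exact H @ vec for `loss` w.r.t. the concatenated `params`
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat = torch.cat([g.reshape(-1) for g in grads])
    hv = torch.autograd.grad((flat * vec).sum(), params)
    return torch.cat([h.reshape(-1) for h in hv])

# Usage sketch:
# loss = criterion(model(x), y)
# v = torch.randn(sum(p.numel() for p in model.parameters()))
# Hv = hvp(loss, list(model.parameters()), v)

The full analysis pipeline follows.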
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
class TransformerHessianAnalyzer:
"""
Complete Hessian analysis for Transformer models
"""
def __init__(self, model: nn.Module):
self.model = model
self.history = {
'hessian_norms': [],
'condition_numbers': [],
'eigenvalue_spectra': [],
'layer_curvatures': []
}
    def compute_hessian_eigenvalues(self, x: torch.Tensor, y: torch.Tensor,
                                    k: int = 20) -> torch.Tensor:
        """
        Estimate the top Hessian eigenvalue via k power-iteration steps with
        finite-difference Hessian-vector products; returns the sequence of
        Rayleigh-quotient estimates
        """
        # Differentiate w.r.t. parameters; token inputs need no gradients
        x = x.clone().detach()
        # Forward and backward pass for the baseline gradient
        self.model.zero_grad()
        output = self.model(x)
        loss = F.cross_entropy(output.view(-1, output.size(-1)), y.view(-1))
        loss.backward()
# Collect flat gradient
grads = []
params = []
for p in self.model.parameters():
if p.grad is not None:
grads.append(p.grad.flatten())
params.append(p.data.flatten())
all_grads = torch.cat(grads).unsqueeze(1) # (n, 1)
all_params = torch.cat(params)
n = len(all_params)
        # Random initialization for power iteration (same device as params)
        v = torch.randn(n, 1, device=all_grads.device)
        v = v / v.norm()
eigenvalues = []
        for _ in range(k):
            # Finite-difference Hessian-vector product:
            # H @ v ≈ (g(w + ηv) - g(w)) / η
            eta = 1e-5
            # Save original parameters
            orig_params = all_params.clone()
            # Perturb
            all_params.data = orig_params + eta * v.squeeze()
            # Compute new gradient (clear stale grads first)
            self._set_model_params(all_params)
            self.model.zero_grad()
            output = self.model(x)
            loss_new = F.cross_entropy(output.view(-1, output.size(-1)), y.view(-1))
            loss_new.backward()
grads_new = []
for p in self.model.parameters():
if p.grad is not None:
grads_new.append(p.grad.flatten())
all_grads_new = torch.cat(grads_new).unsqueeze(1)
# Restore original
all_params.data = orig_params
self._set_model_params(all_params)
# Compute approximate Hv
Hv = (all_grads_new - all_grads) / eta
# Power iteration
v_new = Hv / (Hv.norm() + 1e-10)
# Rayleigh quotient
eigenvalue = (v.t() @ Hv).item() / (v.t() @ v).item()
eigenvalues.append(eigenvalue)
v = v_new
return torch.tensor(eigenvalues)
def _set_model_params(self, flat_params: torch.Tensor):
"""Set model parameters from flat tensor"""
idx = 0
for p in self.model.parameters():
size = p.numel()
p.data = flat_params[idx:idx+size].view(p.shape)
idx += size
def analyze_full_hessian(self, dataloader: torch.utils.data.DataLoader,
n_batches: int = 10) -> Dict:
"""
Comprehensive Hessian analysis
"""
self.model.eval()
all_eigenvalues = []
hessian_norms = []
for i, (x, y) in enumerate(dataloader):
if i >= n_batches:
break
x, y = x.cuda(), y.cuda()
# Compute eigenvalues
eigenvals = self.compute_hessian_eigenvalues(x, y, k=10)
all_eigenvalues.append(eigenvals)
# Hessian norm estimate
hessian_norm = eigenvals.max().item() if len(eigenvals) > 0 else 0
hessian_norms.append(hessian_norm)
self.history['hessian_norms'].append(hessian_norm)
all_eigenvals_cat = torch.cat(all_eigenvalues)
return {
'mean_hessian_norm': np.mean(hessian_norms),
'eigenvalue_mean': all_eigenvals_cat.mean().item(),
'eigenvalue_std': all_eigenvals_cat.std().item(),
'top_eigenvalue': all_eigenvals_cat.max().item(),
'spectral_gap': all_eigenvals_cat.max().item() - all_eigenvals_cat.min().item()
}
    def analyze_layer_wise_curvature(self, x: torch.Tensor) -> List[Dict]:
        """
        Analyze curvature layer by layer, using gradient norms as a proxy
        """
        self.model.eval()
        x = x.cuda()
        stats = {}
        def fwd_hook(name):
            def hook(module, inputs, output):
                if isinstance(output, tuple):
                    output = output[0]
                stats.setdefault(name, {})['output_norm'] = output.norm().item()
            return hook
        def bwd_hook(name):
            def hook(module, grad_input, grad_output):
                # Norm of the gradient flowing into this module's output
                if grad_output[0] is not None:
                    stats.setdefault(name, {})['curvature_proxy'] = grad_output[0].norm().item()
            return hook
        # Register forward and backward hooks
        hooks = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.MultiheadAttention)):
                hooks.append(module.register_forward_hook(fwd_hook(name)))
                hooks.append(module.register_full_backward_hook(bwd_hook(name)))
        # Forward/backward pass against random targets
        output = self.model(x)
        target = torch.randint(0, output.size(-1), x.shape[:2], device=x.device)
        loss = F.cross_entropy(output.view(-1, output.size(-1)), target.view(-1))
        loss.backward()
        # Remove hooks
        for h in hooks:
            h.remove()
        return [{'name': name, **vals} for name, vals in stats.items()]
class CurvatureAwareOptimizer:
"""
Optimizer that adapts based on Hessian information
"""
def __init__(self, model: nn.Module, base_lr: float = 1e-3,
curvature_adaptation: bool = True):
self.model = model
self.base_lr = base_lr
self.curvature_adaptation = curvature_adaptation
self.hessian_estimates = {}
# Adam optimizer
self.optimizer = torch.optim.Adam(
model.parameters(),
lr=base_lr,
betas=(0.9, 0.999),
eps=1e-8
)
def estimate_local_curvature(self, x: torch.Tensor, y: torch.Tensor) -> Dict:
"""
Estimate local curvature for adaptive learning rate
"""
        self.model.eval()
        x, y = x.cuda(), y.cuda()
        # Compute fresh gradients
        self.model.zero_grad()
        output = self.model(x)
        loss = F.cross_entropy(output.view(-1, output.size(-1)), y.view(-1))
        loss.backward()
# Collect gradient statistics
grad_norms = []
for p in self.model.parameters():
if p.grad is not None:
grad_norms.append(p.grad.norm().item())
# Estimate condition number proxy
max_grad = max(grad_norms) if grad_norms else 1
min_grad = min(grad_norms) if grad_norms else 1
cond_proxy = max_grad / (min_grad + 1e-8)
self.hessian_estimates['condition_number'] = cond_proxy
return {
'condition_number': cond_proxy,
'max_gradient': max_grad,
'min_gradient': min_grad,
'loss': loss.item()
}
def step_with_curvature_adaptation(self, x: torch.Tensor, y: torch.Tensor):
"""
Perform optimization step with curvature-aware learning rate
"""
if self.curvature_adaptation:
curv = self.estimate_local_curvature(x, y)
# Adapt learning rate based on condition number
adaptation_factor = min(1.0, 1.0 / (curv['condition_number'] ** 0.25 + 1e-8))
# Update optimizer's learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.base_lr * adaptation_factor
# Standard optimization step
self.optimizer.step()
self.optimizer.zero_grad()
return self.hessian_estimates.copy()
def demo_hessian_analysis():
"""
Demonstration of Transformer Hessian analysis
"""
# Create a small Transformer model
class SimpleTransformer(nn.Module):
def __init__(self, vocab_size=1000, d_model=64, nhead=4, num_layers=2):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = nn.Parameter(torch.randn(1, 100, d_model) * 0.1)
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=256, batch_first=True),
num_layers=num_layers
)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
x = self.embedding(x) + self.pos_encoder[:, :x.size(1), :]
x = self.transformer(x)
return self.fc(x)
# Create model and analyzer
model = SimpleTransformer().cuda()
analyzer = TransformerHessianAnalyzer(model)
optimizer = CurvatureAwareOptimizer(model)
# Create dummy data
batch_size = 8
seq_len = 32
vocab_size = 1000
# Training loop (simplified)
for step in range(50):
x = torch.randint(0, vocab_size, (batch_size, seq_len)).cuda()
y = torch.randint(0, vocab_size, (batch_size, seq_len)).cuda()
# Analyze every 10 steps
if step % 10 == 0:
results = analyzer.analyze_full_hessian([(x, y)], n_batches=1)
print(f"Step {step}: Hessian norm = {results['mean_hessian_norm']:.4f}")
# Optimizer step
optimizer.step_with_curvature_adaptation(x, y)
print("Hessian analysis complete!")
return analyzer, optimizer
if __name__ == "__main__":
    demo_hessian_analysis()

9.2 Visualization and Analysis Utilities
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from typing import Dict, List
def visualize_hessian_spectrum(eigenvalues: torch.Tensor,
title: str = "Hessian Eigenvalue Spectrum"):
"""
Visualize Hessian eigenvalue distribution
"""
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
eigenvalues = eigenvalues.cpu().numpy()
# Histogram
axes[0].hist(eigenvalues, bins=50, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Eigenvalue')
axes[0].set_ylabel('Count')
axes[0].set_title('Eigenvalue Distribution')
# Cumulative
sorted_eig = np.sort(eigenvalues)[::-1]
axes[1].plot(range(len(sorted_eig)), sorted_eig, 'b-', linewidth=2)
axes[1].set_xlabel('Eigenvalue Index')
axes[1].set_ylabel('Eigenvalue')
axes[1].set_title('Sorted Eigenvalues')
axes[1].set_yscale('log')
# Statistics
stats_text = f"""Statistics:
Max: {eigenvalues.max():.4f}
Min: {eigenvalues.min():.4f}
Mean: {eigenvalues.mean():.4f}
Std: {eigenvalues.std():.4f}
Condition: {eigenvalues.max()/eigenvalues.min():.2f}"""
axes[2].text(0.1, 0.5, stats_text, transform=axes[2].transAxes,
fontsize=12, verticalalignment='center',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
axes[2].axis('off')
axes[2].set_title('Summary Statistics')
plt.suptitle(title)
plt.tight_layout()
plt.savefig('hessian_spectrum.png', dpi=150)
plt.show()
def visualize_curvature_flow(curvature_history: List[List[Dict]]):
"""
Visualize curvature propagation through layers
"""
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Aggregate data
n_layers = len(curvature_history[0])
n_steps = len(curvature_history)
layer_curvatures = np.zeros((n_steps, n_layers))
layer_norms = np.zeros((n_steps, n_layers))
for step, layers in enumerate(curvature_history):
for layer_idx, data in enumerate(layers):
layer_curvatures[step, layer_idx] = data.get('curvature_proxy', 0)
layer_norms[step, layer_idx] = data.get('output_norm', 0)
    # Heatmap: curvature
    sns.heatmap(layer_curvatures, ax=axes[0, 0], cmap='viridis')
    axes[0, 0].set_xlabel('Layer')
    axes[0, 0].set_ylabel('Training Step')
    axes[0, 0].set_title('Curvature Flow (Gradient Norm)')
    # Heatmap: activation norms
    sns.heatmap(layer_norms, ax=axes[0, 1], cmap='plasma')
    axes[0, 1].set_xlabel('Layer')
    axes[0, 1].set_ylabel('Training Step')
    axes[0, 1].set_title('Activation Norm Flow')
# Line plot: average curvature per layer
for layer_idx in range(n_layers):
axes[1, 0].plot(layer_curvatures[:, layer_idx],
label=f'Layer {layer_idx}', alpha=0.7)
axes[1, 0].set_xlabel('Training Step')
axes[1, 0].set_ylabel('Curvature Proxy')
axes[1, 0].set_title('Curvature by Layer Over Training')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# Bar plot: final curvature
final_curvatures = layer_curvatures[-1]
axes[1, 1].bar(range(n_layers), final_curvatures)
axes[1, 1].set_xlabel('Layer')
axes[1, 1].set_ylabel('Final Curvature')
axes[1, 1].set_title('Final Layer Curvatures')
axes[1, 1].set_xticks(range(n_layers))
plt.tight_layout()
plt.savefig('curvature_flow.png', dpi=150)
plt.show()
def plot_scaling_law_predictions(model_sizes: List[int],
hessian_norms: List[float],
dataset_sizes: List[int]):
"""
Plot predictions from Hessian-based scaling law analysis
"""
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Hessian norm vs model size
axes[0].loglog(model_sizes, hessian_norms, 'bo-', linewidth=2, markersize=8)
axes[0].set_xlabel('Model Parameters')
axes[0].set_ylabel('Hessian Norm')
axes[0].set_title('Hessian Norm vs Model Size')
axes[0].grid(True, alpha=0.3)
# Fit power law
log_params = np.log(model_sizes)
log_hessian = np.log(hessian_norms)
slope, intercept = np.polyfit(log_params, log_hessian, 1)
# Plot fit
fit_line = np.exp(intercept) * np.array(model_sizes) ** slope
axes[0].loglog(model_sizes, fit_line, 'r--', linewidth=2,
label=f'Power law: N^{slope:.2f}')
axes[0].legend()
    # Dataset size effect (assumes hessian_norms align with dataset_sizes)
    dataset_sizes_arr = np.array(dataset_sizes)
# Theoretical: Hessian norm ∝ 1/√(dataset_size)
theoretical = 1.0 / np.sqrt(dataset_sizes_arr)
theoretical = theoretical / theoretical[0] * hessian_norms[0]
axes[1].plot(dataset_sizes_arr, hessian_norms[:len(dataset_sizes)],
'bo-', linewidth=2, markersize=8, label='Observed')
axes[1].plot(dataset_sizes_arr, theoretical, 'r--', linewidth=2,
label=r'Theoretical: $1/\sqrt{D}$')
axes[1].set_xlabel('Dataset Size')
axes[1].set_ylabel('Hessian Norm')
axes[1].set_title('Hessian Norm vs Dataset Size')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('scaling_laws.png', dpi=150)
    plt.show()

10. References to arXiv:2510.16927
10.1 Paper Overview
Title: Closing the Curvature Gap: Full Transformer Hessians and Their Implications for Scaling Laws
Authors: Egor Petrov, Nikita Kiselev, Vladislav Meshkov, Andrey Grabovoy
Institution: Yandex BRAIn Lab, Moscow State University
arXiv: 2510.16927
10.2 Key Contributions
| Contribution | Description |
|---|---|
| Complete Hessian Characterization | First full Hessian analysis including LayerNorm and FFN |
| Explicit Second-Order Expressions | Closed-form derivatives for all Transformer components |
| Loss Landscape Convergence Bounds | Rigorous convergence rates |
| Scaling Law Implications | Theoretical framework connecting Hessian to empirical scaling laws |
| Taylor Expansion Framework | New method for analyzing convergence trajectories |
10.3 Theoretical Results Summary
Theorem Summary:
1. LayerNorm Jacobian (Theorem 2): Complete derivation of $\partial\,\mathrm{LN}(X)/\partial X$
2. LayerNorm Hessian (Theorem 3): First complete second-order expression
3. Self-Attention Hessian Bound (Theorem 1): Spectral norm bound for the attention Hessian
4. Transformer Block Derivative (Theorem 4): Full derivative expressions for complete blocks
5. Loss Convergence Bound: An $O(1/N)$ rate for landscape stabilization
10.4 Implications and Applications
The paper’s results enable:
1. Better Optimizer Design: Curvature-aware optimization strategies
2. Data Budgeting: Theoretical guidance on dataset size requirements
3. Architecture Search: Hessian-informed architecture design
4. Generalization Bounds: Connection between curvature and generalization
5. Training Diagnostics: Early detection of optimization difficulties
11. Summary and Future Directions
11.1 Key Takeaways
1. Complete Hessian Analysis: This work completes the theoretical characterization of Transformer Hessians by providing explicit expressions for LayerNorm and FFN components.
2. Curvature Propagation: Understanding how curvature flows through Transformer blocks reveals optimization challenges and informs architectural decisions.
3. Scaling Law Connection: Hessian analysis provides a mechanistic explanation for empirical scaling laws, connecting dataset size to loss landscape smoothness.
4. Practical Implications: The theoretical framework enables better optimization strategies, data budgeting, and training diagnostics.
11.2 Open Questions
| Question | Direction |
|---|---|
| Higher-order interactions | Extend analysis to third-order derivatives |
| Non-Gaussian losses | Analyze for contrastive, GAN, and other losses |
| Dynamic analysis | Time-varying Hessian during training |
| Large-scale validation | Empirical validation on billion-parameter models |
| Alternative architectures | Apply framework to Mamba, RWKV, etc. |
11.3 Further Reading
- Hessian Spectral Analysis of Transformers — Prior work on attention Hessian
- Normalization and Gradient Flow — LayerNorm’s role in training stability
- Adaptive Optimizer Theory — Connection to Adam and related methods
- Transformer Scaling Laws — Empirical scaling behavior
- Second-Order Optimization — Practical Hessian-based methods
References
[1] E. Petrov, N. Kiselev, V. Meshkov, A. Grabovoy. Closing the Curvature Gap: Full Transformer Hessians and Their Implications for Scaling Laws. arXiv:2510.16927.
Last updated: 2026-05-03