TL;DR
Normalization keeps activations in a reasonable range during training. LayerNorm is the standard for transformers; RMSNorm is a cheaper variant used in modern LLMs such as Llama. Pre-norm (normalize before each sublayer) is more stable than post-norm for deep networks.
Visual Overview
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  During training, each layer's input distribution         │
│  changes as previous layers update.                       │
│                                                           │
│  Epoch 1: Layer 3 receives inputs with mean=0, std=1      │
│  Epoch 2: Layer 2 weights changed → Layer 3 now sees      │
│           mean=0.5, std=2                                 │
│  Epoch 3: More drift → Layer 3 sees mean=1.2, std=3.5     │
│                                                           │
│  Layer 3 keeps having to re-adapt to shifting inputs.     │
│  Training is slower and less stable.                      │
│                                                           │
│  SOLUTION: Normalize activations to have consistent       │
│  statistics.                                              │
│                                                           │
└───────────────────────────────────────────────────────────┘
Batch Normalization
Normalizes across the batch dimension. Each feature is normalized using statistics from the current batch.
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  Input: x with shape (batch_size, features)               │
│                                                           │
│  For each feature f:                                      │
│    mu_f = mean(x[:, f])    # mean across batch            │
│    sigma_f = std(x[:, f])  # std across batch             │
│                                                           │
│    x_norm[:, f] = (x[:, f] - mu_f) / (sigma_f + eps)      │
│                                                           │
│  Then apply learnable scale and shift:                    │
│    output = gamma × x_norm + beta                         │
│                                                           │
│  gamma and beta are learned per feature.                  │
│                                                           │
└───────────────────────────────────────────────────────────┘

VISUAL: NORMALIZE ACROSS BATCH
┌───────────────────────────────────────────────────────────┐
│                                                           │
│           Feature 1  Feature 2  Feature 3                 │
│          ┌──────────┬──────────┬──────────┐               │
│  Sample 1│   2.1    │   0.5    │   -1.2   │               │
│          ├──────────┼──────────┼──────────┤               │
│  Sample 2│   1.8    │   0.7    │   -0.9   │ ← Normalize   │
│          ├──────────┼──────────┼──────────┤   down each   │
│  Sample 3│   2.3    │   0.4    │   -1.1   │   column      │
│          └──────────┴──────────┴──────────┘               │
│               ▼          ▼          ▼                     │
│            mu=2.07    mu=0.53    mu=-1.07                 │
│                                                           │
└───────────────────────────────────────────────────────────┘
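The steps above can be sketched in plain Python on the same 3×3 example. This is a minimal illustration, not a library implementation: the function name `batch_norm` is made up for this sketch, `gamma` and `beta` are fixed at 1 and 0 rather than learned, and it follows the box's `(sigma + eps)` denominator (many libraries use `sqrt(var + eps)` instead, which differs only slightly).

```python
import math

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize each feature (column) with statistics taken across the batch."""
    n_samples, n_features = len(x), len(x[0])
    out = [[0.0] * n_features for _ in range(n_samples)]
    for f in range(n_features):
        column = [row[f] for row in x]
        mu = sum(column) / n_samples  # mean across batch
        sigma = math.sqrt(sum((v - mu) ** 2 for v in column) / n_samples)
        for i in range(n_samples):
            x_norm = (x[i][f] - mu) / (sigma + eps)
            out[i][f] = gamma[f] * x_norm + beta[f]  # scale and shift
    return out

x = [[2.1, 0.5, -1.2],
     [1.8, 0.7, -0.9],
     [2.3, 0.4, -1.1]]
y = batch_norm(x, gamma=[1.0, 1.0, 1.0], beta=[0.0, 0.0, 0.0])

# After normalization every column has mean ~0 and standard deviation ~1.
column_means = [sum(row[f] for row in y) / len(y) for f in range(3)]
column_stds = [
    math.sqrt(sum((row[f] - column_means[f]) ** 2 for row in y) / len(y))
    for f in range(3)
]
print(column_means)
print(column_stds)
```

Each column ends up with the same statistics regardless of its original scale, which is the point of the "normalize down each column" picture.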
When it works well:
- CNNs (computer vision)
- Large batch sizes (stable statistics)
- Training (has batch to compute stats)
Problems:
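- At inference there is no batch to compute statistics from, so running averages collected during training are used instead
- Small batches -> noisy statistics -> unstable training
- Batch size 1 -> degenerate (each feature's batch variance is zero, so there is nothing to normalize over)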
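To make the first and last problems concrete, here is a small sketch assuming an exponential moving average with momentum 0.1 (a common convention, but an assumption here). `RunningStats` and `normalize_with_batch_stats` are hypothetical helpers written for this example only.

```python
import math

class RunningStats:
    """Exponential moving average of per-feature batch mean and variance."""

    def __init__(self, n_features, momentum=0.1):
        self.mean = [0.0] * n_features
        self.var = [1.0] * n_features
        self.momentum = momentum

    def update(self, batch):
        n = len(batch)
        m = self.momentum
        for f in range(len(self.mean)):
            column = [row[f] for row in batch]
            mu = sum(column) / n
            var = sum((v - mu) ** 2 for v in column) / n
            self.mean[f] = (1 - m) * self.mean[f] + m * mu
            self.var[f] = (1 - m) * self.var[f] + m * var

    def normalize(self, sample, eps=1e-5):
        # Inference path: uses stored statistics, so one sample is enough.
        return [(v - mu) / (math.sqrt(var) + eps)
                for v, mu, var in zip(sample, self.mean, self.var)]

def normalize_with_batch_stats(batch, eps=1e-5):
    # Training path: statistics come from the current batch only.
    n = len(batch)
    out = [[0.0] * len(batch[0]) for _ in batch]
    for f in range(len(batch[0])):
        column = [row[f] for row in batch]
        mu = sum(column) / n
        sigma = math.sqrt(sum((v - mu) ** 2 for v in column) / n)
        for i in range(n):
            out[i][f] = (batch[i][f] - mu) / (sigma + eps)
    return out

stats = RunningStats(n_features=3)
for _ in range(50):  # pretend the same batch is seen repeatedly in training
    stats.update([[2.1, 0.5, -1.2], [1.8, 0.7, -0.9], [2.3, 0.4, -1.1]])

single = [2.1, 0.5, -1.2]
from_batch = normalize_with_batch_stats([single])
from_running = stats.normalize(single)
print(from_batch)    # [[0.0, 0.0, 0.0]]: with batch size 1 the input is erased
print(from_running)  # non-trivial values, because stored statistics are used
```

With a batch of one, every value equals its own batch mean, so batch statistics wipe out the signal; the running averages avoid that at inference time.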
Layer Normalization
Normalizes across the feature dimension. Each sample is normalized independently.
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  Input: x with shape (batch_size, features)               │
│                                                           │
│  For each sample i:                                       │
│    mu_i = mean(x[i, :])    # mean across features         │
│    sigma_i = std(x[i, :])  # std across features          │
│                                                           │
│    x_norm[i, :] = (x[i, :] - mu_i) / (sigma_i + eps)      │
│                                                           │
│  Then apply learnable scale and shift:                    │
│    output = gamma × x_norm + beta                         │
│                                                           │
└───────────────────────────────────────────────────────────┘

VISUAL: NORMALIZE ACROSS FEATURES
┌───────────────────────────────────────────────────────────┐
│                                                           │
│           Feature 1  Feature 2  Feature 3                 │
│          ┌──────────┬──────────┬──────────┐               │
│  Sample 1│   2.1    │   0.5    │   -1.2   │ → Normalize   │
│          ├──────────┼──────────┼──────────┤   this row    │
│  Sample 2│   1.8    │   0.7    │   -0.9   │ → Normalize   │
│          ├──────────┼──────────┼──────────┤   this row    │
│  Sample 3│   2.3    │   0.4    │   -1.1   │ → Normalize   │
│          └──────────┴──────────┴──────────┘   this row    │
│                                                           │
└───────────────────────────────────────────────────────────┘
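The same example can be run through a per-row version. This is a minimal sketch in plain Python; `layer_norm` is an illustrative name, `gamma` and `beta` are fixed rather than learned, and the denominator follows the box's `(sigma + eps)` form.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each sample (row) with statistics taken across its features."""
    out = []
    for row in x:
        mu = sum(row) / len(row)  # mean across features
        sigma = math.sqrt(sum((v - mu) ** 2 for v in row) / len(row))
        out.append([g * (v - mu) / (sigma + eps) + b
                    for v, g, b in zip(row, gamma, beta)])
    return out

x = [[2.1, 0.5, -1.2],
     [1.8, 0.7, -0.9],
     [2.3, 0.4, -1.1]]
y = layer_norm(x, gamma=[1.0, 1.0, 1.0], beta=[0.0, 0.0, 0.0])

# Every row now has mean ~0, independent of the other rows.
row_means = [sum(row) / len(row) for row in y]
print(row_means)
```

The loop never looks at more than one row at a time, which is why the batch size does not matter.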
When it works well:
- Transformers (the standard)
- RNNs, LSTMs
- Any batch size (including 1)
- Inference (no batch dependency)
Why transformers use LayerNorm:
- Sequence lengths vary (and are often padded) -> batch statistics are unreliable
- Inference often batch_size=1
- Each token normalized independently
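A short sketch of the "each token normalized independently" point, assuming inputs shaped as a list of sequences of token vectors. The helper names are hypothetical and the learnable scale and shift are omitted for brevity.

```python
import math

def layer_norm_token(token, eps=1e-5):
    """Normalize one token vector using only its own features."""
    mu = sum(token) / len(token)
    sigma = math.sqrt(sum((v - mu) ** 2 for v in token) / len(token))
    return [(v - mu) / (sigma + eps) for v in token]

def layer_norm_batch(batch):
    """batch: list of sequences; each sequence is a list of token vectors."""
    return [[layer_norm_token(tok) for tok in seq] for seq in batch]

short_seq = [[1.0, 2.0, 3.0, 4.0]]                     # 1 token
long_seq = [[1.0, 2.0, 3.0, 4.0],
            [0.5, -0.5, 2.0, 0.0],
            [9.0, 8.0, 7.0, 6.0]]                      # 3 tokens

alone = layer_norm_batch([short_seq])                  # batch_size = 1
mixed = layer_norm_batch([long_seq, short_seq])        # batch_size = 2, ragged lengths

# The token's normalized value does not depend on batch size or neighbours.
print(alone[0][0] == mixed[1][0])  # True
```

Because no statistic crosses a token boundary, ragged sequence lengths and single-sample inference need no special handling.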
RMSNorm (Root Mean Square Normalization)
Simplified LayerNorm: rescales by the root mean square only, with no mean centering.
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  Standard LayerNorm:                                      │
│    x_norm = (x - mean(x)) / std(x)                        │
│                                                           │
│  RMSNorm:                                                 │
│    x_norm = x / RMS(x)                                    │
│                                                           │
│    where RMS(x) = sqrt(mean(x²))                          │
│                                                           │
│  No mean subtraction. Just scale by root-mean-square.     │
│                                                           │
└───────────────────────────────────────────────────────────┘
Why it works:
- Mean centering turns out to matter less than rescaling
- Skipping the mean computation saves roughly 7% of training time, though the exact gain depends on model and hardware
- Quality is comparable to LayerNorm in practice
Used in: Llama, Llama 2, Mistral, and many other modern LLMs
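# RMSNorm implementation
import torch

def rmsnorm(x, weight, eps=1e-6):
    # RMS over the last (feature) dimension; eps inside the sqrt avoids division by zero
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    # weight is the learnable per-feature scale (gamma); there is no shift term
    return weight * (x / rms)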
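To see what dropping the mean subtraction changes, the sketch below re-implements the same formula in plain Python (so it runs without torch) and compares it with a LayerNorm that has no scale or shift. Both helper names are illustrative, and `eps` is placed inside the square root in both so the two are directly comparable.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]

def layer_norm_no_affine(x, eps=1e-6):
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

ones = [1.0, 1.0, 1.0, 1.0]
centered = [-1.5, -0.5, 0.5, 1.5]         # mean is already 0
shifted = [v + 10.0 for v in centered]    # same spread, mean of 10

# With zero-mean input the two normalizations agree.
rms_centered = rms_norm(centered, ones)
ln_centered = layer_norm_no_affine(centered)
print(rms_centered)
print(ln_centered)

# With a large mean they differ: RMSNorm keeps the offset, LayerNorm removes it.
rms_shifted = rms_norm(shifted, ones)
ln_shifted = layer_norm_no_affine(shifted)
print(sum(rms_shifted) / len(rms_shifted))
print(sum(ln_shifted) / len(ln_shifted))
```

RMSNorm only controls the scale of the vector; any offset in the input survives, which in practice appears to matter little for transformer training.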
Pre-Norm vs Post-Norm
Where you place normalization matters for training stability.
POST-NORM (Original Transformer)
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  x = x + Attention(x)                                     │
│  x = LayerNorm(x)          ← Norm AFTER residual          │
│  x = x + FFN(x)                                           │
│  x = LayerNorm(x)                                         │
│                                                           │
│  Problem: Gradients must flow through LayerNorm           │
│  Can cause instability in deep networks                   │
│                                                           │
└───────────────────────────────────────────────────────────┘

PRE-NORM (Modern Transformers)
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  x = x + Attention(LayerNorm(x))   ← Norm BEFORE          │
│  x = x + FFN(LayerNorm(x))                                │
│                                                           │
│  Advantages:                                              │
│  • Residual stream is "clean" (just additions)            │
│  • Gradients flow directly through residual path          │
│  • More stable for deep networks                          │
│  • Easier to train without careful LR tuning              │
│                                                           │
└───────────────────────────────────────────────────────────┘
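The two orderings can be sketched as plain functions. This is a toy illustration only: `attention` and `ffn` are passed in as arbitrary callables, and `zero_sublayer` is a hypothetical stand-in that contributes nothing, used here to expose what the residual path alone does to the input.

```python
import math

def layer_norm(x, eps=1e-5):
    mu = sum(x) / len(x)
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / len(x))
    return [(v - mu) / (sigma + eps) for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def post_norm_block(x, attention, ffn):
    x = layer_norm(add(x, attention(x)))   # norm AFTER the residual add
    x = layer_norm(add(x, ffn(x)))
    return x

def pre_norm_block(x, attention, ffn):
    x = add(x, attention(layer_norm(x)))   # norm BEFORE the sublayer
    x = add(x, ffn(layer_norm(x)))
    return x

def zero_sublayer(x):
    # Hypothetical sublayer that outputs zeros, so only the residual path remains.
    return [0.0] * len(x)

x = [4.0, -2.0, 6.0]
pre_out = pre_norm_block(x, zero_sublayer, zero_sublayer)
post_out = post_norm_block(x, zero_sublayer, zero_sublayer)
print(pre_out)   # [4.0, -2.0, 6.0]: the residual path is an exact identity
print(post_out)  # rescaled: the signal still had to pass through LayerNorm
```

In the pre-norm block the input reaches the output through additions only, which is the "clean" residual stream the box refers to; in the post-norm block even the skip path is transformed by every LayerNorm.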
Which to use: Pre-norm for new models. Post-norm only if replicating the original Transformer or BERT (GPT-2 already uses pre-norm).
Comparison Table
| Aspect | BatchNorm | LayerNorm | RMSNorm |
|---|---|---|---|
| Normalizes across | Batch | Features | Features |
| Works with batch=1 | No | Yes | Yes |
| Needs running stats | Yes | No | No |
| Mean centering | Yes | Yes | No |
| Speed | Baseline | Baseline | ~7% faster |
| Used in | CNNs | Transformers | Modern LLMs |
Debugging Normalization Issues
TRAINING INSTABILITY (LOSS SPIKES)
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  Symptoms:                                                │
│  • Loss suddenly spikes during training                   │
│  • Gradients explode intermittently                       │
│                                                           │
│  Causes:                                                  │
│  • Post-norm architecture with deep network               │
│  • Missing normalization somewhere                        │
│  • Norm placed incorrectly                                │
│                                                           │
│  Debug steps:                                             │
│  1. Switch to pre-norm if using post-norm                 │
│  2. Check every sublayer has normalization                │
│  3. Verify norm is before attention/FFN, not after        │
│  4. Reduce learning rate                                  │
│                                                           │
└───────────────────────────────────────────────────────────┘

ACTIVATIONS GROWING UNBOUNDED
┌───────────────────────────────────────────────────────────┐
│                                                           │
│  Symptoms:                                                │
│  • Activation magnitudes grow over layers                 │
│  • Eventually overflow to NaN                             │
│                                                           │
│  Causes:                                                  │
│  • Missing normalization layer                            │
│  • Residual accumulation without norm                     │
│  • Wrong norm dimension                                   │
│                                                           │
│  Debug steps:                                             │
│  1. Print activation statistics per layer                 │
│  2. Verify norm is applied (gamma, beta params exist)     │
│  3. Check norm dimension matches input shape              │
│                                                           │
└───────────────────────────────────────────────────────────┘
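The "print activation statistics per layer" step might look like the following toy sketch. The stack here is entirely hypothetical: `toy_sublayer` simply amplifies its input by 1.5x, and RMS is used as the magnitude statistic. In a real model you would record the same statistic from each layer's output instead.

```python
import math

def rms(x):
    return math.sqrt(sum(v * v for v in x) / len(x))

def rms_normalize(x, eps=1e-6):
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]

def toy_sublayer(x):
    # Hypothetical sublayer that amplifies its input by 1.5x.
    return [1.5 * v for v in x]

def activation_rms_per_layer(x, n_layers, use_norm):
    """Run a residual stack and record the activation RMS after each layer."""
    stats = []
    for _ in range(n_layers):
        h = rms_normalize(x) if use_norm else x
        x = [a + b for a, b in zip(x, toy_sublayer(h))]  # residual add
        stats.append(rms(x))
    return stats

x0 = [1.0, -1.0, 1.0, -1.0]
without_norm = activation_rms_per_layer(x0, n_layers=8, use_norm=False)
with_norm = activation_rms_per_layer(x0, n_layers=8, use_norm=True)

for layer, (a, b) in enumerate(zip(without_norm, with_norm), start=1):
    print(f"layer {layer}: no norm rms={a:10.2f}   pre-norm rms={b:6.2f}")
```

Without normalization the magnitude multiplies at every layer (exponential growth); with the sublayer input normalized it grows only by a fixed amount per layer. A per-layer table like this makes a missing or misplaced norm easy to spot.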
When This Matters
| Situation | What to know |
|---|---|
| Reading transformer code | LayerNorm before attention/FFN (pre-norm) |
| Understanding Llama/Mistral | RMSNorm, not LayerNorm |
| Training instability | Switch to pre-norm, check norm placement |
| Batch size constraints | LayerNorm works with any batch size |
| Optimizing inference speed | RMSNorm is slightly faster |
| Porting CNN techniques | BatchNorm is a poor fit for transformers |
| Understanding model configs | "norm_eps" is the epsilon added in the denominator |
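As a rough illustration of what the epsilon does, the sketch below applies an RMS normalization with and without it. The function is a plain-Python stand-in written for this example; the parameter is named `norm_eps` only to mirror the config key mentioned in the table.

```python
import math

def rms_norm(x, norm_eps):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + norm_eps)
    return [v / rms for v in x]

zeros = [0.0, 0.0, 0.0, 0.0]
safe = rms_norm(zeros, norm_eps=1e-6)
print(safe)  # [0.0, 0.0, 0.0, 0.0]

# Without epsilon an all-zero input divides by zero. Plain Python raises;
# tensor libraries typically produce NaN instead.
try:
    rms_norm(zeros, norm_eps=0.0)
    failed = False
except ZeroDivisionError:
    failed = True
print(failed)  # True

# Epsilon also damps very small inputs instead of blowing them up to unit scale.
tiny = rms_norm([1e-4, 1e-4, 1e-4, 1e-4], norm_eps=1e-6)
print(tiny)
```

The value is a trade-off: too small and near-zero activations become numerically fragile, too large and legitimately small activations get suppressed.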
Production signal