TL;DR
AdamW is the standard optimizer for transformers. Use warmup to prevent early instability, cosine decay for pre-training, linear decay for fine-tuning, and gradient clipping to prevent explosions. Fine-tuning needs 10-100x smaller learning rates than pre-training.
Visual Overview
SGD UPDATE

    w = w - lr x gradient

Where:
- w = weights
- lr = learning rate
- gradient = dLoss/dw

SGD PROBLEMS

1. OSCILLATION IN VALLEYS
   The loss surface has steep sides and a shallow floor. SGD bounces side to side and makes slow progress forward.
   [Diagram: zigzag path bouncing between the valley walls while moving toward the minimum]

2. SAME LEARNING RATE FOR ALL
   Some parameters need big updates, others small. A single LR can't satisfy both.
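The SGD update and the valley oscillation can be sketched in a few lines of plain Python. This is a minimal illustration, not a training loop: the toy loss `0.5 * (100 * w[0]**2 + w[1]**2)`, the helper names `sgd_step` and `valley_grad`, and the learning rate `0.019` are all assumptions chosen to make the effect visible.

```python
def sgd_step(w, grad, lr):
    """One plain SGD update, element-wise: w = w - lr * gradient."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


def valley_grad(w):
    """Gradient of a toy loss 0.5 * (100 * w[0]**2 + w[1]**2).

    Assumed for illustration: steep along w[0], shallow along w[1].
    """
    return [100.0 * w[0], 1.0 * w[1]]


w = [1.0, 1.0]
path = [w]
for _ in range(5):
    w = sgd_step(w, valley_grad(w), lr=0.019)
    path.append(w)

# The steep coordinate flips sign every step (bouncing between the walls),
# while the shallow coordinate barely moves toward the minimum at 0.
for step, (steep, shallow) in enumerate(path):
    print(f"step {step}: steep={steep:+.3f}  shallow={shallow:+.3f}")
```

With a single learning rate, the steep direction forces a small step size while the shallow direction would benefit from a much larger one.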
Momentum
Add “velocity” to SGD. Accumulate gradient direction over time.
MOMENTUM UPDATE

    velocity = beta x velocity + gradient
    w = w - lr x velocity

Where:
- beta = momentum coefficient (typically 0.9)
- velocity = accumulated gradient direction

MOMENTUM INTUITION

Ball rolling downhill.

Without momentum:

    Step 1: gradient = [1, 0]     → move [1, 0]
    Step 2: gradient = [-1, 0.1]  → move [-1, 0.1]  (oscillating!)

With momentum (beta=0.9):

    Step 1: velocity = [1, 0]  → move [1, 0]
    Step 2: velocity = 0.9 x [1, 0] + [-1, 0.1]
                     = [0.9, 0] + [-1, 0.1]
                     = [-0.1, 0.1]  → much smaller oscillation!

Consistent direction gets amplified. Oscillating direction gets dampened.
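The two-step example above can be reproduced with a small sketch of the momentum update. The function name `momentum_step` and the choice `lr=1.0` are assumptions made so the velocity values are easy to read.

```python
def momentum_step(w, velocity, grad, lr, beta=0.9):
    """One SGD+momentum update.

    velocity = beta * velocity + gradient
    w        = w - lr * velocity
    """
    velocity = [beta * v + g for v, g in zip(velocity, grad)]
    w = [wi - lr * v for wi, v in zip(w, velocity)]
    return w, velocity


w = [0.0, 0.0]
velocity = [0.0, 0.0]
history = []
for grad in ([1.0, 0.0], [-1.0, 0.1]):
    w, velocity = momentum_step(w, velocity, grad, lr=1.0)
    history.append(velocity)

# history[0] is [1.0, 0.0]; history[1] is approximately [-0.1, 0.1]:
# the flip in the first coordinate is mostly cancelled by the stored velocity.
print(history)
```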
Adam (Adaptive Moment Estimation)
Combines momentum with adaptive learning rates per parameter.
ADAM UPDATE

    m = beta1 x m + (1-beta1) x gradient      # 1st moment
    v = beta2 x v + (1-beta2) x gradient^2    # 2nd moment

    m_hat = m / (1 - beta1^t)                 # Bias correct
    v_hat = v / (1 - beta2^t)

    w = w - lr x m_hat / (sqrt(v_hat) + eps)

Default hyperparameters:
- beta1 = 0.9 (momentum decay)
- beta2 = 0.999 (variance decay)
- eps = 1e-8 (numerical stability)

WHY ADAM WORKS WELL

1. MOMENTUM (m)
   Same as SGD+momentum: damps oscillation.

2. ADAPTIVE LR (v)
   Parameters with large gradients → smaller LR.
   Parameters with small gradients → larger LR.

       High-gradient param: lr / sqrt(large_v) = small step
       Low-gradient param:  lr / sqrt(small_v) = large step

3. BIAS CORRECTION
   Early steps: m and v are biased toward 0 because they start at 0. The correction compensates for this.
Adam is the default choice for most deep learning.
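A minimal sketch of one Adam step, following the update rule above, shows the adaptive scaling directly. The function name `adam_step` and the example gradients are assumptions; `t` is the 1-based step count used for bias correction.

```python
import math


def adam_step(w, m, v, grad, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update over lists of parameters; t starts at 1."""
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    # Bias correction: m and v start at 0, so early estimates are too small.
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    w = [
        wi - lr * mh / (math.sqrt(vh) + eps)
        for wi, mh, vh in zip(w, m_hat, v_hat)
    ]
    return w, m, v


# Two parameters whose gradients differ by a factor of 10,000.
w, m, v = adam_step(
    w=[0.0, 0.0], m=[0.0, 0.0], v=[0.0, 0.0], grad=[100.0, 0.01], t=1
)
# Both parameters move by roughly lr (1e-3), regardless of gradient scale.
print(w)
```

Plain SGD with the same learning rate would move the first parameter 10,000 times further than the second.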
AdamW
Adam with proper weight decay. The standard for transformers.
Adam + L2 regularization (WRONG):

    gradient = task_gradient + lambda x w
    ... adam update with this gradient ...

Problem: weight decay is mixed into the adaptive LR, so high-variance params get less regularization.

AdamW (CORRECT):

    gradient = task_gradient       # NO weight decay here
    ... adam update with this gradient ...
    w = w - lr x lambda x w        # Decay applied AFTER

Weight decay truly decoupled from optimization.
Always use AdamW for transformers, not Adam with L2.
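The difference is easiest to see when the task gradient is zero, so that only the decay term acts. The sketch below is an illustration under assumptions: scalar parameters, hypothetical helper names, and decay applied after the Adam step as in the rule above (implementations may order the two differently).

```python
import math


def _adam_direction(m, v, grad, t, beta1, beta2, eps):
    """Shared Adam machinery: returns the normalized step plus new m, v."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat / (math.sqrt(v_hat) + eps), m, v


def adam_l2_step(w, m, v, task_grad, t, lr, wd,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam with L2: decay is folded into the gradient, then rescaled."""
    grad = task_grad + wd * w
    direction, m, v = _adam_direction(m, v, grad, t, beta1, beta2, eps)
    return w - lr * direction, m, v


def adamw_step(w, m, v, task_grad, t, lr, wd,
               beta1=0.9, beta2=0.999, eps=1e-8):
    """AdamW: Adam update on the task gradient, then decoupled decay."""
    direction, m, v = _adam_direction(m, v, task_grad, t, beta1, beta2, eps)
    w = w - lr * direction
    w = w - lr * wd * w
    return w, m, v


w_l2, _, _ = adam_l2_step(1.0, 0.0, 0.0, task_grad=0.0, t=1, lr=0.01, wd=0.1)
w_adamw, _, _ = adamw_step(1.0, 0.0, 0.0, task_grad=0.0, t=1, lr=0.01, wd=0.1)
# Adam+L2 shrinks w by about lr, because the decay term is normalized away.
# AdamW shrinks w by exactly lr * wd * w, as the decay setting intends.
print(w_l2, w_adamw)
```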
Learning Rate Schedules
The learning rate should change during training: high early on (after a short warmup), lower later.
Linear Decay
    lr = initial_lr x (1 - step/total_steps)

[Plot: lr falls in a straight line from initial_lr at step 0 to 0 at step T]
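As a sketch, the linear decay formula translates to a one-line function. The name `linear_decay_lr` and the clamping of `step` to the range `[0, total_steps]` are assumptions added so the function stays well-behaved outside the schedule.

```python
def linear_decay_lr(step, total_steps, initial_lr):
    """lr = initial_lr * (1 - step / total_steps), clamped to [0, initial_lr]."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return initial_lr * (1.0 - progress)


schedule = [linear_decay_lr(s, 1000, 5e-5) for s in (0, 250, 500, 1000)]
# Falls by the same amount per step: 100%, 75%, 50%, 0% of the initial LR.
print(schedule)
```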
Cosine Decay
    lr = lr_min + 0.5 x (lr_max - lr_min) x (1 + cos(pi x step/total_steps))

[Plot: lr follows a half-cosine curve from lr_max at step 0 down to lr_min at step T]

Smooth decay: stays near the peak LR early, drops fastest mid-training, and flattens out near lr_min at the end. Often works better than linear.
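A minimal sketch of the cosine formula, with an assumed function name `cosine_decay_lr` and the same clamping convention as a typical schedule helper:

```python
import math


def cosine_decay_lr(step, total_steps, lr_max, lr_min=0.0):
    """Half-cosine decay from lr_max (step 0) to lr_min (step total_steps)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Fraction of the peak LR at 10%, 25%, 50%, and 90% of training.
fractions = [cosine_decay_lr(s, 100, 1.0) for s in (10, 25, 50, 90)]
# Roughly 0.976, 0.854, 0.5, 0.024. A linear schedule would give
# 0.9, 0.75, 0.5, 0.1: cosine stays higher early and lower late.
print(fractions)
```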
Warmup
Start with very low LR, ramp up, then decay.

[Plot: lr ramps up from near 0 at step 0 to its peak at the end of warmup, then decays toward 0 at step T]

Warmup prevents early instability. Gradients are noisy at the start (random weights). High LR + noisy gradients = explosion.
Typical warmup: 1-5% of total training steps.
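Warmup is usually combined with a decay schedule. The sketch below assumes a linear ramp from 0 followed by cosine decay; the function name `warmup_cosine_lr` and the example step counts are illustrative choices, not a fixed recipe.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps, peak_lr, lr_min=0.0):
    """Linear warmup from 0 to peak_lr, then cosine decay to lr_min."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return lr_min + 0.5 * (peak_lr - lr_min) * (1.0 + math.cos(math.pi * progress))


total, warmup = 100_000, 2_000  # warmup is 2% of training in this example
for s in (0, 1_000, 2_000, 51_000, 100_000):
    print(s, warmup_cosine_lr(s, total, warmup, peak_lr=3e-4))
```

The rate is 0 at step 0, half the peak midway through warmup, at the peak when warmup ends, and back near `lr_min` at the final step.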
Common Configurations
Pre-training LLM
- Optimizer: AdamW (beta1 = 0.9, beta2 = 0.95)
- Weight decay: 0.1
- LR: 1e-4 to 3e-4
- Schedule: Cosine decay with warmup
- Warmup: 2000 steps (or 1% of total)
Fine-tuning
- Optimizer: AdamW (beta1 = 0.9, beta2 = 0.999)
- Weight decay: 0.01
- LR: 1e-5 to 5e-5 (10-100x smaller than pre-training)
- Schedule: Linear decay with warmup
- Warmup: 3-5% of total steps
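The two configurations can be written down as plain dictionaries, which makes the differences easy to compare. The key names, the helper `warmup_steps`, and the specific values picked from within each range (for example `lr=3e-4` and `lr=2e-5`) are assumptions for illustration only.

```python
PRETRAIN = {
    "optimizer": "AdamW",
    "betas": (0.9, 0.95),
    "weight_decay": 0.1,
    "lr": 3e-4,               # picked from the 1e-4 to 3e-4 range
    "schedule": "cosine",
    "warmup_fraction": 0.01,  # 1% of total steps
}

FINETUNE = {
    "optimizer": "AdamW",
    "betas": (0.9, 0.999),
    "weight_decay": 0.01,
    "lr": 2e-5,               # picked from the 1e-5 to 5e-5 range
    "schedule": "linear",
    "warmup_fraction": 0.03,  # 3% of total steps
}


def warmup_steps(config, total_steps):
    """Convert a warmup fraction into a whole number of steps (at least 1)."""
    return max(1, round(config["warmup_fraction"] * total_steps))


print(warmup_steps(PRETRAIN, 200_000))  # 2000
print(warmup_steps(FINETUNE, 10_000))   # 300
```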
Quick Reference
| Scenario | LR | Schedule | Warmup |
|---|---|---|---|
| Pre-training LLM | 1e-4 - 3e-4 | Cosine | 1-2% |
| Fine-tuning LLM | 1e-5 - 5e-5 | Linear | 3-5% |
| Fine-tuning BERT | 2e-5 - 5e-5 | Linear | 10% |
| Training CNN | 1e-3 | Step decay | None |
Gradient Clipping
Limit gradient magnitude to prevent explosions.
Clip by global norm (most common):

    total_norm = sqrt(SUM(gradient^2))
    if total_norm > max_norm:
        gradient = gradient x (max_norm / total_norm)

Typical max_norm: 1.0

Prevents a single bad batch from destroying the model.
When to use:
- Always for transformers
- When training is unstable
- When loss spikes occasionally
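Clipping by global norm, as described in this section, can be sketched as follows. The function name `clip_by_global_norm` and the representation of gradients as a list of per-parameter lists are assumptions; the norm is computed over all gradients together, so the update direction is preserved and only its length changes.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale all gradients so their combined L2 norm is at most max_norm.

    grads: list of per-parameter gradient lists.
    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [[g * scale for g in grad] for grad in grads]
    return grads, total_norm


clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
# norm_before is 5.0; clipped is approximately [[0.6], [0.8]] with norm 1.0.
print(norm_before, clipped)
```

Gradients whose global norm is already below `max_norm` pass through unchanged.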
Debugging Optimization
LOSS NOT DECREASING

Symptoms:
- Loss stays flat from start
- Or decreases very slowly

Causes:
- Learning rate too low
- Warmup too long
- Wrong optimizer config

Debug steps:
1. Try 10x higher LR
2. Reduce warmup steps
3. Check gradient values (should be non-zero)
4. Verify optimizer is updating weights

LOSS EXPLODES (NaN)

Symptoms:
- Loss suddenly becomes NaN
- Training crashes

Causes:
- Learning rate too high
- Missing gradient clipping
- No warmup

Debug steps:
1. Add gradient clipping (max_norm=1.0)
2. Reduce LR by 10x
3. Add warmup (5% of steps)
4. Check for numerical issues in data

FINE-TUNING DESTROYS MODEL

Symptoms:
- After fine-tuning, model is worse than base
- "Catastrophic forgetting"

Causes:
- Learning rate too high
- No warmup
- Weight decay too high

Debug steps:
1. Use much smaller LR (1e-5 or lower)
2. Add warmup
3. Reduce weight decay
4. Consider LoRA instead of full fine-tuning
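Several of the debug steps above (non-zero gradients, weights actually changing, NaN loss) can be checked automatically on each step. The helper below is a hypothetical sketch over flat lists of floats; the name `diagnose_step` and the message strings are assumptions, and a real training loop would read these values from its framework.

```python
import math


def diagnose_step(loss, grads, weights_before, weights_after):
    """Return a list of human-readable problems found in one training step."""
    issues = []
    if not math.isfinite(loss):
        issues.append("loss is NaN or inf: lower LR, add clipping or warmup")
    grad_norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(grad_norm):
        issues.append("gradient norm is NaN or inf")
    elif grad_norm == 0.0:
        issues.append("all gradients are zero")
    if weights_before == weights_after:
        issues.append("optimizer did not change the weights")
    return issues


healthy = diagnose_step(2.3, [0.1, -0.2], [1.0, 1.0], [0.99, 1.02])
broken = diagnose_step(float("nan"), [0.0, 0.0], [1.0, 1.0], [1.0, 1.0])
print(healthy)  # []
print(broken)   # three issues: bad loss, zero gradients, frozen weights
```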
When This Matters
| Situation | What to know |
|---|---|
| Training transformers | Use AdamW, not Adam |
| Fine-tuning | LR 10-100x smaller than pre-training |
| Training unstable | Add warmup, gradient clipping |
| Loss not decreasing | Try higher LR |
| Loss exploding | Lower LR, add gradient clipping |
| Understanding configs | beta1=momentum, beta2=variance averaging |
| Choosing schedule | Cosine for pre-training, linear for fine-tuning |
See It In Action
- Backpropagation Explainer: a ~120-second animated visual explanation of gradient descent weight updates
Production signal