AdamW is the standard optimizer for transformers. Use warmup to prevent early instability, cosine decay for pre-training, linear decay for fine-tuning, and gradient clipping to prevent explosions. Fine-tuning needs 10-100x smaller learning rates than pre-training.
Visual Overview
SGD Update and Problems
SGD UPDATE

    w = w - lr x gradient

    Where:
      w        = weights
      lr       = learning rate
      gradient = dLoss/dw

SGD PROBLEMS

    1. OSCILLATION IN VALLEYS
       The loss surface has steep sides and a shallow floor.
       SGD bounces side-to-side, making slow progress toward the minimum.

    2. SAME LEARNING RATE FOR ALL
       Some parameters need big updates, others small.
       A single LR can't satisfy both.
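A minimal sketch of this update in NumPy (the toy quadratic loss is illustrative, not from the text):

    import numpy as np

    def sgd_step(w, grad, lr=0.1):
        # Plain SGD: move the weights against the gradient, scaled by the learning rate.
        return w - lr * grad

    # Toy example: minimize loss = w^2, so dLoss/dw = 2w
    w = np.array([5.0])
    for _ in range(20):
        grad = 2 * w
        w = sgd_step(w, grad)
    print(w)  # approaches 0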
Momentum
Add “velocity” to SGD. Accumulate gradient direction over time.
Momentum Update and Intuition
MOMENTUM UPDATE

    velocity = beta x velocity + gradient
    w = w - lr x velocity

    Where:
      beta     = momentum coefficient (typically 0.9)
      velocity = accumulated gradient direction

MOMENTUM INTUITION

    Think of a ball rolling downhill.

    Without momentum:
      Step 1: gradient = [1, 0]    → move [1, 0]
      Step 2: gradient = [-1, 0.1] → move [-1, 0.1]   (oscillating!)

    With momentum (beta = 0.9):
      Step 1: velocity = [1, 0]                   → move [1, 0]
      Step 2: velocity = 0.9 x [1, 0] + [-1, 0.1]
                       = [0.9, 0] + [-1, 0.1]
                       = [-0.1, 0.1]              → much smaller oscillation!

    A consistent direction gets amplified.
    An oscillating direction gets dampened.
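A minimal sketch of the momentum update above in NumPy; in PyTorch the same behavior comes from torch.optim.SGD(..., momentum=0.9):

    import numpy as np

    def momentum_step(w, velocity, grad, lr=0.1, beta=0.9):
        # Accumulate the gradient direction, then step along the velocity.
        velocity = beta * velocity + grad
        w = w - lr * velocity
        return w, velocity

    # Consistent gradients grow the velocity; oscillating gradients cancel out.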
Adam (Adaptive Moment Estimation)
Combines momentum with adaptive learning rates per parameter.
Adam Update and Why It Works
ADAM UPDATE

    m = beta1 x m + (1 - beta1) x gradient     # 1st moment
    v = beta2 x v + (1 - beta2) x gradient^2   # 2nd moment

    m_hat = m / (1 - beta1^t)                  # Bias correction
    v_hat = v / (1 - beta2^t)

    w = w - lr x m_hat / (sqrt(v_hat) + eps)

    Default hyperparameters:
      beta1 = 0.9    (momentum decay)
      beta2 = 0.999  (variance decay)
      eps   = 1e-8   (numerical stability)

WHY ADAM WORKS WELL

    1. MOMENTUM (m)
       Same as SGD+momentum: damps oscillation.

    2. ADAPTIVE LR (v)
       Parameters with large gradients → smaller LR
       Parameters with small gradients → larger LR

       High-gradient param: lr / sqrt(large_v) = small step
       Low-gradient param:  lr / sqrt(small_v) = large step

    3. BIAS CORRECTION
       Early steps: m and v are biased toward 0.
       The correction compensates for this.
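The same update written as a minimal NumPy function (a sketch of the formulas above, not a production optimizer):

    import numpy as np

    def adam_step(w, m, v, grad, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # t is the 1-based step count, needed for bias correction.
        m = beta1 * m + (1 - beta1) * grad       # 1st moment (momentum)
        v = beta2 * v + (1 - beta2) * grad**2    # 2nd moment (per-parameter scale)
        m_hat = m / (1 - beta1**t)               # bias correction
        v_hat = v / (1 - beta2**t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
        return w, m, v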
Adam is the default choice for most deep learning.
AdamW
Adam with proper weight decay. The standard for transformers.
AdamW vs Adam + L2
    Adam + L2 regularization (WRONG):
      gradient = task_gradient + lambda x w
      ... adam update with this gradient ...

      Problem: weight decay is mixed into the adaptive LR.
      High-variance params get less regularization.

    AdamW (CORRECT):
      gradient = task_gradient        # NO weight decay here
      ... adam update with this gradient ...
      w = w - lr x lambda x w         # Decay applied AFTER

      Weight decay is truly decoupled from the optimization.
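A minimal sketch of the decoupled update in NumPy (illustrative only; in practice torch.optim.AdamW implements this for you):

    import numpy as np

    def adamw_step(w, m, v, grad, t, lr=1e-3, beta1=0.9, beta2=0.999,
                   eps=1e-8, weight_decay=0.01):
        m = beta1 * m + (1 - beta1) * grad       # task gradient only, no decay term
        v = beta2 * v + (1 - beta2) * grad**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
        w = w - lr * weight_decay * w            # decay applied separately, after
        return w, m, v

    # PyTorch equivalent: torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)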
Always use AdamW for transformers, not Adam with L2.
Learning Rate Schedules
The learning rate should change during training: high early on (after a short warmup), lower as training progresses.
Linear Decay
    lr = initial_lr x (1 - step / total_steps)

    LR falls in a straight line from initial_lr at step 0 to 0 at the final step T.
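As a function (a direct translation of the formula above):

    def linear_decay_lr(step, total_steps, initial_lr):
        # LR shrinks linearly from initial_lr (step 0) to 0 (step total_steps).
        return initial_lr * (1 - step / total_steps)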
Cosine Decay
    lr = lr_min + 0.5 x (lr_max - lr_min)
               x (1 + cos(pi x step / total_steps))

    LR follows a half-cosine from lr_max at step 0 down to lr_min at the final step T.

    Smooth decay, lingers longer at mid-range LR.
    Often works better than linear.
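As a function (again a direct translation of the formula):

    import math

    def cosine_decay_lr(step, total_steps, lr_max, lr_min=0.0):
        # Half-cosine from lr_max (step 0) down to lr_min (step total_steps).
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps))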
Warmup
    Start with a very low LR, ramp up to the peak LR over the warmup steps, then decay for the rest of training.

    Warmup prevents early instability.
    Gradients are noisy at the start (random weights).
    High LR + noisy gradients = explosion.
Typical warmup: 1-5% of total training steps.
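A minimal sketch combining linear warmup with cosine decay (the function name and arguments are illustrative):

    import math

    def warmup_cosine_lr(step, warmup_steps, total_steps, lr_max, lr_min=0.0):
        # Linear warmup from 0 to lr_max, then cosine decay down to lr_min.
        if step < warmup_steps:
            return lr_max * step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))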
Common Configurations
Pre-training LLM
Pre-training LLM Config
Optimizer: AdamW
beta1 = 0.9, beta2 = 0.95
Weight decay: 0.1
LR: 1e-4 to 3e-4
Schedule: Cosine decay with warmup
Warmup: 2000 steps (or 1% of total)
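A minimal sketch of this config in PyTorch; model and total_steps are placeholders, and the scheduler helper assumes the Hugging Face transformers library is available:

    import torch
    from transformers import get_cosine_schedule_with_warmup  # Hugging Face helper

    optimizer = torch.optim.AdamW(
        model.parameters(),        # `model` is your transformer
        lr=3e-4,
        betas=(0.9, 0.95),
        weight_decay=0.1,
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=2000,
        num_training_steps=total_steps,   # total optimizer steps for the run
    )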
Fine-tuning
Fine-tuning Config
Optimizer: AdamW
beta1 = 0.9, beta2 = 0.999
Weight decay: 0.01
LR: 1e-5 to 5e-5 (10-100x smaller than pre-training)
Schedule: Linear decay with warmup
Warmup: 3-5% of total steps
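The same pattern for fine-tuning, with a smaller LR, standard betas, and a linear schedule (again assuming the Hugging Face scheduler helpers; model and total_steps are placeholders):

    import torch
    from transformers import get_linear_schedule_with_warmup

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        betas=(0.9, 0.999),
        weight_decay=0.01,
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05 * total_steps),   # 5% warmup
        num_training_steps=total_steps,
    )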
Quick Reference
    Scenario            LR            Schedule     Warmup
    Pre-training LLM    1e-4 - 3e-4   Cosine       1-2%
    Fine-tuning LLM     1e-5 - 5e-5   Linear       3-5%
    Fine-tuning BERT    2e-5 - 5e-5   Linear       10%
    Training CNN        1e-3          Step decay   None
Gradient Clipping
Limit gradient magnitude to prevent explosions.
    Clip by global norm (most common):
      total_norm = sqrt(SUM(gradient^2))
      if total_norm > max_norm:
          gradient = gradient x (max_norm / total_norm)

    Typical max_norm: 1.0

    Prevents a single bad batch from destroying the model.
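In PyTorch, global-norm clipping is a single call between the backward pass and the optimizer step; a minimal helper sketch (names are illustrative):

    import torch

    def clip_and_step(model, optimizer, loss, max_norm=1.0):
        # Backprop, rescale gradients if their global norm exceeds max_norm, then update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()
        optimizer.zero_grad()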
When to use:
Always for transformers
When training is unstable
When loss spikes occasionally
Debugging Optimization
LOSS NOT DECREASING

    Symptoms:
      • Loss stays flat from the start
      • Or decreases very slowly

    Causes:
      • Learning rate too low
      • Warmup too long
      • Wrong optimizer config

    Debug steps:
      1. Try a 10x higher LR
      2. Reduce warmup steps
      3. Check gradient values (should be non-zero)
      4. Verify the optimizer is updating weights

LOSS EXPLODES (NaN)

    Symptoms:
      • Loss suddenly becomes NaN
      • Training crashes

    Causes:
      • Learning rate too high
      • Missing gradient clipping
      • No warmup

    Debug steps:
      1. Add gradient clipping (max_norm=1.0)
      2. Reduce LR by 10x
      3. Add warmup (5% of steps)
      4. Check for numerical issues in the data

FINE-TUNING DESTROYS MODEL

    Symptoms:
      • After fine-tuning, the model is worse than the base
      • "Catastrophic forgetting"

    Causes:
      • Learning rate too high
      • No warmup
      • Weight decay too high

    Debug steps:
      1. Use a much smaller LR (1e-5 or lower)
      2. Add warmup
      3. Reduce weight decay
      4. Consider LoRA instead of full fine-tuning
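For the "check gradient values" step above, a small helper like this can be run right after loss.backward() to spot zero, exploding, or NaN gradients (a sketch; names are illustrative):

    import torch

    def report_grad_norms(model):
        # One line per parameter tensor: all-zero norms suggest a dead graph,
        # huge or NaN norms suggest an exploding loss.
        for name, p in model.named_parameters():
            if p.grad is None:
                print(f"{name}: no gradient")
            else:
                print(f"{name}: grad norm = {p.grad.norm().item():.3e}")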