Optimization

Summary

SGD, Adam, AdamW, learning rate schedules, warmup, and gradient clipping for training

TL;DR

AdamW is the standard optimizer for transformers. Use warmup to prevent early instability, cosine decay for pre-training, linear decay for fine-tuning, and gradient clipping to prevent explosions. Fine-tuning needs 10-100x smaller learning rates than pre-training.

Visual Overview

SGD Update and Problems
SGD UPDATE

                                                           
   w = w - lr x gradient                                   
                                                           
   Where:                                                  
     w = weights                                           
     lr = learning rate                                    
     gradient = dLoss/dw                                   
                                                           


SGD PROBLEMS

 
 1. OSCILLATION IN VALLEYS 
 Loss surface has steep sides, shallow floor 
 SGD bounces side-to-side, slow progress forward 
 
 [Diagram: the path zigzags from wall to wall (bounce, bounce) 
 while creeping toward the minimum] 
 
 2. SAME LEARNING RATE FOR ALL 
 Some parameters need big updates, others small 
 Single LR can't satisfy both 
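
The oscillation problem can be reproduced on a toy loss. The sketch below is illustrative only: the quadratic loss, the helper names `sgd_step` and `valley_grad`, and the learning rate 0.019 are assumptions, chosen so that the steep direction overshoots while the shallow direction barely moves.

```python
def sgd_step(w, grad, lr):
    """One SGD update: w = w - lr * gradient, applied element-wise."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


def valley_grad(w):
    """Gradient of the toy loss 0.5 * (100 * w[0]**2 + w[1]**2).

    w[0] is the steep direction (valley wall), w[1] the shallow one (valley floor).
    """
    return [100.0 * w[0], 1.0 * w[1]]


w = [1.0, 1.0]
history = [w]
for _ in range(5):
    w = sgd_step(w, valley_grad(w), lr=0.019)
    history.append(w)

for step, point in enumerate(history):
    print(step, point)
```

In this run `w[0]` flips sign on every step (the bounce), while `w[1]` shrinks by under 2% per step (the slow progress along the valley floor).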
 


Momentum

Add “velocity” to SGD. Accumulate gradient direction over time.

Momentum Update and Intuition
MOMENTUM UPDATE

                                                           
   velocity = beta x velocity + gradient                   
   w = w - lr x velocity                                   
                                                           
   Where:                                                  
     beta = momentum coefficient (typically 0.9)           
     velocity = accumulated gradient direction             
                                                           


MOMENTUM INTUITION

 
 Ball rolling downhill. 
 
 Without momentum: 
 Step 1: gradient = [1, 0] -> move [1, 0] 
 Step 2: gradient = [-1, 0.1] -> move [-1, 0.1] 
 (oscillating!) 
 
 With momentum (beta=0.9): 
 Step 1: velocity = [1, 0] -> move [1, 0] 
 Step 2: velocity = 0.9 x [1, 0] + [-1, 0.1] 
                  = [0.9, 0] + [-1, 0.1] 
                  = [-0.1, 0.1] -> much smaller oscillation! 
 
 Consistent direction gets amplified. 
 Oscillating direction gets dampened. 
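
The two-step example above can be replayed in code. This is a minimal sketch: the function name `momentum_step` and the learning rate 0.1 are assumptions for illustration; the update rule itself follows the MOMENTUM UPDATE box.

```python
def momentum_step(w, velocity, grad, lr, beta=0.9):
    """One SGD-with-momentum update.

    velocity = beta * velocity + gradient
    w        = w - lr * velocity
    """
    velocity = [beta * v + g for v, g in zip(velocity, grad)]
    w = [wi - lr * v for wi, v in zip(w, velocity)]
    return w, velocity


# The two oscillating gradients from the intuition box.
grads = [[1.0, 0.0], [-1.0, 0.1]]
w, velocity = [0.0, 0.0], [0.0, 0.0]
for g in grads:
    w, velocity = momentum_step(w, velocity, g, lr=0.1)
    print("velocity:", velocity)
```

After the second step the velocity is approximately `[-0.1, 0.1]`: the first component, which flipped sign between steps, has mostly cancelled out.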
 


Adam (Adaptive Moment Estimation)

Combines momentum with adaptive learning rates per parameter.

Adam Update and Why It Works
ADAM UPDATE

                                                           
   m = beta1 x m + (1-beta1) x gradient     # 1st moment   
   v = beta2 x v + (1-beta2) x gradient^2   # 2nd moment   
                                                           
   m_hat = m / (1 - beta1^t)                # Bias correct 
   v_hat = v / (1 - beta2^t)                               
                                                           
   w = w - lr x m_hat / (sqrt(v_hat) + eps)                
                                                           
   Default hyperparameters:                                
     beta1 = 0.9    (momentum decay)                       
     beta2 = 0.999  (variance decay)                       
     eps = 1e-8     (numerical stability)                  
                                                           


WHY ADAM WORKS WELL

 
 1. MOMENTUM (m) 
 Same as SGD+momentum - damps oscillation 
 
 2. ADAPTIVE LR (v) 
 Parameters with large gradients -> smaller effective LR 
 Parameters with small gradients -> larger effective LR 
 
 High-gradient param: lr / sqrt(large_v) = small step 
 Low-gradient param: lr / sqrt(small_v) = large step 
 
 3. BIAS CORRECTION 
 Early steps: m and v are biased toward 0 
 Correction compensates for this 
 

Adam is the default choice for most deep learning.
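
The Adam update can be sketched in plain Python as follows. The function name `adam_step` and the list-of-floats parameter representation are assumptions for illustration; real frameworks operate on tensors. The defaults match the hyperparameters listed above, plus an assumed `lr=1e-3`.

```python
import math


def adam_step(w, m, v, grad, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update over a list of scalar parameters.

    t is the 1-based step count, needed for bias correction.
    """
    new_w, new_m, new_v = [], [], []
    for wi, mi, vi, gi in zip(w, m, v, grad):
        mi = beta1 * mi + (1 - beta1) * gi        # 1st moment (momentum)
        vi = beta2 * vi + (1 - beta2) * gi * gi   # 2nd moment (squared gradients)
        m_hat = mi / (1 - beta1 ** t)             # bias correction
        v_hat = vi / (1 - beta2 ** t)
        new_w.append(wi - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_w, new_m, new_v


# Two parameters whose gradients differ by a factor of 10,000.
w, m, v = [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]
w, m, v = adam_step(w, m, v, grad=[100.0, 0.01], t=1)
print(w)
```

Both parameters move by roughly `lr` on this first step even though their gradients differ by four orders of magnitude, which is the adaptive scaling described in point 2.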


AdamW

Adam with proper weight decay. The standard for transformers.

AdamW vs Adam + L2

                                                           
   Adam + L2 regularization (WRONG):                       
     gradient = task_gradient + lambda x w                 
     ... adam update with this gradient ...                
                                                           
     Problem: Weight decay mixed into adaptive LR          
              Weights with large gradients (large v) get less regularization 
                                                           
   AdamW (CORRECT):                                        
     gradient = task_gradient       # NO weight decay here 
     ... adam update with this gradient ...                
     w = w - lr x lambda x w        # Decay applied AFTER  
                                                           
     Weight decay truly decoupled from optimization.       
                                                           

Always use AdamW for transformers, not Adam with L2.
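
The difference shows up clearly on a single weight with zero task gradient. The sketch below is a toy, not a framework implementation: `adam_direction`, `adam_l2_step`, and `adamw_step` are hypothetical helper names, and the decay is applied after the Adam update to match the ordering in the box above.

```python
import math


def adam_direction(grad, m, v, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the bias-corrected Adam direction m_hat / (sqrt(v_hat) + eps)."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat / (math.sqrt(v_hat) + eps), m, v


def adam_l2_step(w, task_grad, m, v, t, lr, weight_decay):
    """Adam + L2: the decay term is folded into the gradient and gets rescaled."""
    grad = task_grad + weight_decay * w
    direction, m, v = adam_direction(grad, m, v, t)
    return w - lr * direction, m, v


def adamw_step(w, task_grad, m, v, t, lr, weight_decay):
    """AdamW: Adam sees only the task gradient; decay is applied separately."""
    direction, m, v = adam_direction(task_grad, m, v, t)
    w = w - lr * direction
    w = w - lr * weight_decay * w   # decoupled weight decay
    return w, m, v


w_l2, _, _ = adam_l2_step(2.0, 0.0, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.01)
w_adamw, _, _ = adamw_step(2.0, 0.0, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.01)
print(w_l2, w_adamw)
```

With Adam + L2 the weight drops by about `lr` (0.1) because the adaptive scaling normalizes the decay term, so `lambda` has almost no say in the step size. With AdamW it drops by `lr x lambda x w` (0.002), as intended.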


Learning Rate Schedules

The learning rate should change during training: high early on (after any warmup), lower later.

Linear Decay

                                                           
   lr = initial_lr x (1 - step/total_steps)                
                                                           
   [Plot: lr (y-axis, relative to initial_lr) falls in a   
    straight line from 1 at step 0 to 0 at step T]         
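
As a minimal sketch, the linear decay formula translates directly into a function. The name `linear_decay_lr` and the clamping of `step` to `total_steps` are assumptions added for illustration.

```python
def linear_decay_lr(step, total_steps, initial_lr):
    """lr = initial_lr * (1 - step / total_steps), clamped so it never goes negative."""
    progress = min(step, total_steps) / total_steps
    return initial_lr * (1.0 - progress)


for step in (0, 250, 500, 750, 1000):
    print(step, linear_decay_lr(step, total_steps=1000, initial_lr=5e-5))
```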
                                                           

Cosine Decay

                                                           
   lr = lr_min + 0.5 x (lr_max - lr_min)                   
        x (1 + cos(pi x step/total_steps))                 
                                                           
   [Plot: lr (y-axis, relative to lr_max) follows a        
    half-cosine curve from 1 at step 0 to 0 at step T]     
                                                           
   Smooth decay: stays near the peak LR early and near     
   the minimum late, dropping fastest mid-run.             
   Often works better than linear.                         
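
The cosine formula can be sketched the same way. The function name `cosine_decay_lr`, the default `lr_min=0.0`, and the clamping of `step` are assumptions for illustration.

```python
import math


def cosine_decay_lr(step, total_steps, lr_max, lr_min=0.0):
    """lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * step / total_steps))."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


for step in (0, 25, 50, 75, 100):
    print(step, round(cosine_decay_lr(step, total_steps=100, lr_max=1.0), 4))
```

A quarter of the way through, this schedule is still at about 85% of `lr_max`, where a linear schedule would already be down to 75%.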
                                                           

Warmup

                                                           
   Start with very low LR, ramp up, then decay.            
                                                           
   [Plot: lr ramps up from 0 at step 0 to its peak at the  
    end of warmup, then decays back toward 0 by step T]    
                                                           
   Warmup prevents early instability.                      
   Gradients are noisy at start (random weights).          
   High LR + noisy gradients = explosion.                  
                                                           

Typical warmup: 1-5% of total training steps.
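
Warmup is usually combined with one of the decay schedules. The sketch below assumes a linear ramp followed by cosine decay; the function name `warmup_cosine_lr` and the exact ramp shape are illustrative choices, not something the text above prescribes.

```python
import math


def warmup_cosine_lr(step, total_steps, peak_lr, warmup_steps, lr_min=0.0):
    """Linear warmup from 0 to peak_lr, then cosine decay from peak_lr to lr_min."""
    if step < warmup_steps:
        # Ramp phase: lr grows in proportion to the step count.
        return peak_lr * step / warmup_steps
    # Decay phase: measure progress from the end of warmup to the end of training.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return lr_min + 0.5 * (peak_lr - lr_min) * (1.0 + math.cos(math.pi * progress))


for step in (0, 10, 20, 510, 1000):
    print(step, warmup_cosine_lr(step, total_steps=1000, peak_lr=3e-4, warmup_steps=20))
```

Here 20 warmup steps out of 1000 is 2% of training, inside the typical 1-5% range.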


Common Configurations

Pre-training LLM

Pre-training LLM Config
Optimizer: AdamW
beta1 = 0.9, beta2 = 0.95
Weight decay: 0.1
LR: 1e-4 to 3e-4
Schedule: Cosine decay with warmup
Warmup: 2000 steps (or 1% of total)

Fine-tuning

Fine-tuning Config
Optimizer: AdamW
beta1 = 0.9, beta2 = 0.999
Weight decay: 0.01
LR: 1e-5 to 5e-5 (10-100x smaller than pre-training)
Schedule: Linear decay with warmup
Warmup: 3-5% of total steps

Quick Reference

Scenario           LR             Schedule     Warmup
Pre-training LLM   1e-4 - 3e-4    Cosine       1-2%
Fine-tuning LLM    1e-5 - 5e-5    Linear       3-5%
Fine-tuning BERT   2e-5 - 5e-5    Linear       10%
Training CNN       1e-3           Step decay   None

Gradient Clipping

Limit gradient magnitude to prevent explosions.

                                                           
   Clip by global norm (most common):                      
     total_norm = sqrt(SUM(gradient^2))   # over ALL params 
     if total_norm > max_norm:                             
         gradient = gradient x (max_norm / total_norm)     
                                                           
   Typical max_norm: 1.0                                   
                                                           
   Prevents single bad batch from destroying model.        
                                                           

When to use:

  • Always for transformers
  • When training is unstable
  • When loss spikes occasionally
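
Clipping by global norm can be sketched in plain Python. The function name `clip_by_global_norm` and the list-of-lists gradient layout (one inner list per parameter tensor) are assumptions for illustration; frameworks provide their own tensor-based versions.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale all gradients together if their combined L2 norm exceeds max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [[g * scale for g in tensor] for tensor in grads]
    return grads, total_norm


# Two "tensors" whose combined norm is sqrt(3^2 + 0^2 + 4^2) = 5.
clipped, norm_before = clip_by_global_norm([[3.0, 0.0], [4.0]], max_norm=1.0)
print(norm_before, clipped)
```

Because every gradient is scaled by the same factor, the update direction is preserved and only its length is capped.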

Debugging Optimization

LOSS NOT DECREASING

                                                           
   Symptoms:                                               
     • Loss stays flat from start                          
     • Or decreases very slowly                            
                                                           
   Causes:                                                 
     • Learning rate too low                               
     • Warmup too long                                     
     • Wrong optimizer config                              
                                                           
   Debug steps:                                            
     1. Try 10x higher LR                                  
     2. Reduce warmup steps                                
     3. Check gradient values (should be non-zero)         
     4. Verify optimizer is updating weights               
                                                           


LOSS EXPLODES (NaN)

 
 Symptoms: 
 • Loss suddenly becomes NaN 
 • Training crashes 
 
 Causes: 
 • Learning rate too high 
 • Missing gradient clipping 
 • No warmup 
 
 Debug steps: 
 1. Add gradient clipping (max_norm=1.0) 
 2. Reduce LR by 10x 
 3. Add warmup (5% of steps) 
 4. Check for numerical issues in data 
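
Step 4 can be partly automated by scanning inputs, losses, or gradients for NaN and infinity before they reach the optimizer. The helper below is a hypothetical sketch (`first_non_finite` is not from any library) that works on a flat list of floats.

```python
import math


def first_non_finite(values):
    """Return the index of the first NaN or +/-inf value, or None if all are finite."""
    for i, x in enumerate(values):
        if not math.isfinite(x):
            return i
    return None


batch = [0.5, 1.2, float("nan"), 0.1]
bad_index = first_non_finite(batch)
if bad_index is not None:
    print(f"non-finite value at index {bad_index}; skip or inspect this batch")
```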
 


FINE-TUNING DESTROYS MODEL

 
 Symptoms: 
 • After fine-tuning, model is worse than base 
 • "Catastrophic forgetting" 
 
 Causes: 
 • Learning rate too high 
 • No warmup 
 • Weight decay too high 
 
 Debug steps: 
 1. Use much smaller LR (1e-5 or lower) 
 2. Add warmup 
 3. Reduce weight decay 
 4. Consider LoRA instead of full fine-tuning 
 


When This Matters

Situation               What to know
Training transformers   Use AdamW, not Adam
Fine-tuning             LR 10-100x smaller than pre-training
Training unstable       Add warmup, gradient clipping
Loss not decreasing     Try higher LR
Loss exploding          Lower LR, add gradient clipping
Understanding configs   beta1=momentum, beta2=variance averaging
Choosing schedule       Cosine for pre-training, linear for fine-tuning

Production signal

Why this concept matters

Interview: 60% of ML interviews
Production: Every training and fine-tuning job
Performance: The right optimizer can speed up training 2-3x