
Normalization

Summary

LayerNorm, BatchNorm, RMSNorm: what they do, when to use them, and Pre-Norm vs Post-Norm

TL;DR

Normalization keeps activations in a reasonable range during training. LayerNorm is the standard for transformers; RMSNorm is a cheaper variant used in modern LLMs such as Llama. Pre-norm (normalizing before each sublayer) is more stable than post-norm in deep networks.

Visual Overview

The Problem: Internal Covariate Shift

                                                           
   During training, each layer's input distribution
   changes as previous layers update.

   Epoch 1: Layer 3 receives inputs with mean=0, std=1
   Epoch 2: Layer 2 weights changed -> Layer 3 now sees
            mean=0.5, std=2
   Epoch 3: More drift -> Layer 3 sees mean=1.2, std=3.5

   Layer 3 keeps having to re-adapt to shifting inputs.
   Training is slower and less stable.

   SOLUTION: Normalize activations to have consistent
   statistics.
                                                           


Batch Normalization

Normalizes across the batch dimension. Each feature is normalized using statistics from the current batch.


                                                           
   Input: x with shape (batch_size, features)

   For each feature f:
     mu_f  = mean(x[:, f])       # mean across batch
     var_f = var(x[:, f])        # variance across batch

     x_norm[:, f] = (x[:, f] - mu_f) / sqrt(var_f + eps)

   Then apply learnable scale and shift:
     output = gamma × x_norm + beta

   gamma and beta are learned per feature.
   eps sits inside the square root so the denominator
   can never be exactly zero.
                                                           


VISUAL: NORMALIZE ACROSS BATCH

 
              Feature 1   Feature 2   Feature 3
   Sample 1      2.1         0.5        -1.2
   Sample 2      1.8         0.7        -0.9
   Sample 3      2.3         0.4        -1.1
               mu=2.07     mu=0.53    mu=-1.07

   Normalize down each column (across the samples in the batch).
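
As a rough sketch, the column-wise computation can be reproduced in plain Python on the three samples from the visual. The helper name `batch_norm` and the defaults `gamma=1`, `beta=0` are assumptions made for illustration; real layers learn those parameters, and frameworks ship optimized implementations.

```python
import math

def batch_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize each feature (column) using statistics across the batch (rows)."""
    n_samples, n_features = len(x), len(x[0])
    gamma = gamma or [1.0] * n_features   # assumed default: no rescaling
    beta = beta or [0.0] * n_features     # assumed default: no shift
    out = [[0.0] * n_features for _ in range(n_samples)]
    means = []
    for f in range(n_features):
        column = [row[f] for row in x]
        mu = sum(column) / n_samples
        var = sum((v - mu) ** 2 for v in column) / n_samples  # biased variance
        means.append(mu)
        for i in range(n_samples):
            out[i][f] = gamma[f] * (column[i] - mu) / math.sqrt(var + eps) + beta[f]
    return out, means

x = [[2.1, 0.5, -1.2],
     [1.8, 0.7, -0.9],
     [2.3, 0.4, -1.1]]
normalized, means = batch_norm(x)
print([round(m, 2) for m in means])  # [2.07, 0.53, -1.07]
```

The printed per-feature means match the `mu` values in the visual, and each column of `normalized` ends up with a mean of roughly zero.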
 

When it works well:

  • CNNs (computer vision)
  • Large batch sizes (stable statistics)
  • Training (has batch to compute stats)

Problems:

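  • At inference there may be no usable batch, so running averages collected during training stand in for the batch statistics
  • Small batches -> noisy statistics -> unstable training
  • Batch size 1 in training -> for (batch_size, features) inputs the per-feature variance is zero, so every output collapses to beta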
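
The running-average workaround and the batch-size-1 failure can be sketched for a single feature as below. The class name `RunningBatchNorm` is hypothetical, and `momentum=0.1` is an assumed value chosen for the example; the update is also simplified compared with what real frameworks do.

```python
import math

class RunningBatchNorm:
    """Toy single-feature BatchNorm that tracks running statistics (hypothetical helper)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0

    def forward(self, values, training):
        if training:
            mu = sum(values) / len(values)
            var = sum((v - mu) ** 2 for v in values) / len(values)
            # Exponential moving average of the batch statistics.
            self.running_mean += self.momentum * (mu - self.running_mean)
            self.running_var += self.momentum * (var - self.running_var)
        else:
            # Inference: no batch statistics, fall back to the running estimates.
            mu, var = self.running_mean, self.running_var
        return [(v - mu) / math.sqrt(var + self.eps) for v in values]

bn = RunningBatchNorm()
for _ in range(200):                          # simulate training on the same batch
    bn.forward([2.1, 1.8, 2.3], training=True)
single = bn.forward([2.1], training=False)    # batch size 1 is fine at inference

# Training-mode statistics on a batch of one: variance is zero, output collapses.
degenerate = RunningBatchNorm().forward([2.1], training=True)
print(degenerate)  # [0.0]
```

In evaluation mode the single sample is normalized with the learned running mean and variance, while the training-mode call on one sample loses all information about the input.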

Layer Normalization

Normalizes across the feature dimension. Each sample is normalized independently.


                                                           
   Input: x with shape (batch_size, features)

   For each sample i:
     mu_i  = mean(x[i, :])       # mean across features
     var_i = var(x[i, :])        # variance across features

     x_norm[i, :] = (x[i, :] - mu_i) / sqrt(var_i + eps)

   Then apply learnable scale and shift:
     output = gamma × x_norm + beta

   gamma and beta are learned per feature.
                                                           


VISUAL: NORMALIZE ACROSS FEATURES

 
              Feature 1   Feature 2   Feature 3
   Sample 1      2.1         0.5        -1.2     <- normalize this row
   Sample 2      1.8         0.7        -0.9     <- normalize this row
   Sample 3      2.3         0.4        -1.1     <- normalize this row
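
The row-wise version can be sketched the same way. The helper name `layer_norm` and the identity defaults for `gamma` and `beta` are assumptions for illustration. The point of the example is that each sample is normalized on its own, so the result does not change when the rest of the batch is removed.

```python
import math

def layer_norm(row, gamma=None, beta=None, eps=1e-5):
    """Normalize one sample across its features; no other sample is involved."""
    n = len(row)
    gamma = gamma or [1.0] * n   # assumed default: no rescaling
    beta = beta or [0.0] * n     # assumed default: no shift
    mu = sum(row) / n
    var = sum((v - mu) ** 2 for v in row) / n
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for v, g, b in zip(row, gamma, beta)]

batch = [[2.1, 0.5, -1.2],
         [1.8, 0.7, -0.9],
         [2.3, 0.4, -1.1]]
normalized = [layer_norm(row) for row in batch]
alone = layer_norm(batch[0])        # same sample, processed as a batch of one
print(alone == normalized[0])       # True
```

Because nothing is shared between rows, the same function serves training and inference, and no running statistics need to be stored.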
 

When it works well:

  • Transformers (the standard)
  • RNNs, LSTMs
  • Any batch size (including 1)
  • Inference (no batch dependency)

Why transformers use LayerNorm:

  • Sequence lengths vary -> statistics pooled across a batch are unreliable
  • Inference often batch_size=1
  • Each token normalized independently

RMSNorm (Root Mean Square Normalization)

Simplified LayerNorm: rescales each sample by the root mean square of its features, with no mean centering and only a learned scale (no shift).


                                                           
   Standard LayerNorm:                                     
     x_norm = (x - mean(x)) / std(x)                       
                                                           
   RMSNorm:                                                
     x_norm = x / RMS(x)                                   
                                                           
     where RMS(x) = sqrt(mean(x²))                         
                                                           
   No mean subtraction. Just scale by root-mean-square.    
                                                           

Why it works:

  • Mean centering turns out to be less important than rescaling
  • Skipping the mean computation saves ~7% training time (the exact gain depends on model and hardware)
  • Quality is reported as equivalent or better in practice

Used in: Llama, Llama 2, Mistral, most modern LLMs

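# RMSNorm implementation (PyTorch)
import torch

def rmsnorm(x, weight, eps=1e-6):
    # Root mean square over the last (feature) dimension; eps keeps the denominator above zero
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    # Rescale, then apply the learned per-feature gain (there is no shift term)
    return weight * (x / rms)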
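
To make the relationship to LayerNorm concrete, here is a small dependency-free sketch that omits the learned parameters. The helper names `rms_norm` and `layer_norm` are assumptions for this example. When the input already has zero mean the two normalizations coincide, and they diverge once the input is shifted.

```python
import math

def rms_norm(row, eps=1e-6):
    """Divide by the root mean square; the mean is left untouched."""
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v / rms for v in row]

def layer_norm(row, eps=1e-6):
    """Subtract the mean, then divide by the standard deviation."""
    mu = sum(row) / len(row)
    var = sum((v - mu) ** 2 for v in row) / len(row)
    return [(v - mu) / math.sqrt(var + eps) for v in row]

centered = [1.5, -0.5, -1.0]               # mean is exactly 0
shifted = [v + 10.0 for v in centered]     # same shape, mean moved to 10

same = all(math.isclose(a, b)
           for a, b in zip(rms_norm(centered), layer_norm(centered)))
differs = any(not math.isclose(a, b)
              for a, b in zip(rms_norm(shifted), layer_norm(shifted)))
print(same, differs)  # True True
```

RMSNorm therefore relies on activations not carrying a large offset, which is the empirical observation behind the "mean centering is less important" point.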

Pre-Norm vs Post-Norm

Where you place normalization matters for training stability.

POST-NORM (Original Transformer)

                                                           
   x = x + Attention(x)
   x = LayerNorm(x)          <- Norm AFTER residual
   x = x + FFN(x)
   x = LayerNorm(x)

   Problem: Gradients must flow through LayerNorm
            Can cause instability in deep networks
                                                           


PRE-NORM (Modern Transformers)

 
   x = x + Attention(LayerNorm(x))   <- Norm BEFORE
   x = x + FFN(LayerNorm(x))

   Advantages:
     • Residual stream is "clean" (just additions)
     • Gradients flow directly through residual path
     • More stable for deep networks
     • Easier to train without careful LR tuning
 

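Which to use: Pre-norm for new models. Post-norm only if replicating the original Transformer or BERT. (GPT-2 already moved LayerNorm to the input of each sublayer, so it is a pre-norm model.)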
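
The two orderings can be compared side by side in a short sketch. Everything here is illustrative: `toy_attention` and `toy_ffn` are stand-in functions, not real attention or feed-forward layers, and `layer_norm` omits the learned parameters.

```python
import math

def layer_norm(x, eps=1e-5):
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def rms(x):
    return math.sqrt(sum(v * v for v in x) / len(x))

def post_norm_block(x, attention, ffn):
    x = layer_norm(add(x, attention(x)))   # norm sits on the residual stream
    x = layer_norm(add(x, ffn(x)))
    return x

def pre_norm_block(x, attention, ffn):
    x = add(x, attention(layer_norm(x)))   # residual stream only sees additions
    x = add(x, ffn(layer_norm(x)))
    return x

# Stand-in sublayers (assumptions for illustration only).
def toy_attention(x):
    return [0.5 * v for v in x]

def toy_ffn(x):
    return [max(0.0, v) for v in x]

x_post = x_pre = [1.0, -2.0, 3.0, 0.5]
for _ in range(12):                         # stack 12 blocks
    x_post = post_norm_block(x_post, toy_attention, toy_ffn)
    x_pre = pre_norm_block(x_pre, toy_attention, toy_ffn)

print(f"post-norm output RMS: {rms(x_post):.3f}")
print(f"pre-norm  output RMS: {rms(x_pre):.3f}")
```

In this toy setup the post-norm output is re-normalized at every block, while the pre-norm residual stream accumulates the sublayer outputs and grows in magnitude. That growth is one reason pre-norm models commonly apply a final normalization after the last block.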


Comparison Table

Aspect                BatchNorm   LayerNorm      RMSNorm
Normalizes across     Batch       Features       Features
Works with batch=1    No          Yes            Yes
Needs running stats   Yes         No             No
Mean centering        Yes         Yes            No
Speed                 Baseline    Baseline       ~7% faster
Used in               CNNs        Transformers   Modern LLMs

Debugging Normalization Issues

TRAINING INSTABILITY (LOSS SPIKES)

                                                           
   Symptoms:
     • Loss suddenly spikes during training
     • Gradients explode intermittently

   Causes:
     • Post-norm architecture with deep network
     • Missing normalization somewhere
     • Norm placed incorrectly

   Debug steps:
     1. Switch to pre-norm if using post-norm
     2. Check every sublayer has normalization
     3. Verify norm is before attention/FFN, not after
     4. Reduce learning rate
                                                           


ACTIVATIONS GROWING UNBOUNDED

 
   Symptoms:
     • Activation magnitudes grow over layers
     • Eventually overflow to NaN

   Causes:
     • Missing normalization layer
     • Residual accumulation without norm
     • Wrong norm dimension

   Debug steps:
     1. Print activation statistics per layer
     2. Verify norm is applied (gamma, beta params exist)
     3. Check norm dimension matches input shape
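
Debug step 1 can be sketched without any framework. The function name `activation_stats` and the toy layers are assumptions for illustration; in a PyTorch model the same idea is usually wired up with forward hooks on each module.

```python
import math

def activation_stats(layers, x):
    """Run x through each layer in turn and record simple per-layer statistics."""
    report = []
    for index, layer in enumerate(layers):
        x = layer(x)
        mean = sum(x) / len(x)
        rms = math.sqrt(sum(v * v for v in x) / len(x))
        non_finite = any(math.isnan(v) or math.isinf(v) for v in x)
        report.append({"layer": index, "mean": mean, "rms": rms,
                       "non_finite": non_finite})
    return report

# Toy residual layers with no normalization (assumption for illustration):
# each one adds twice its input, so magnitudes triple per layer.
layers = [lambda x: [v + 2.0 * v for v in x] for _ in range(6)]

report = activation_stats(layers, [0.1, -0.2, 0.3])
for row in report:
    print(f"layer {row['layer']}: mean={row['mean']:.3f} "
          f"rms={row['rms']:.3f} non_finite={row['non_finite']}")
```

A steadily increasing `rms` column, as in this toy stack, is the signature of residual accumulation without normalization; a `non_finite` flag turning true pinpoints the first layer that overflowed.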
 


When This Matters

Situation                     What to know
Reading transformer code      LayerNorm before attention/FFN (pre-norm)
Understanding Llama/Mistral   RMSNorm, not LayerNorm
Training instability          Switch to pre-norm, check norm placement
Batch size constraints        LayerNorm works with any batch size
Optimizing inference speed    RMSNorm is slightly faster
Porting CNN techniques        BatchNorm doesn't work for transformers
Understanding model configs   "norm_eps" is the epsilon in the denominator
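
What the epsilon does is easiest to see on an all-zero input. This is a minimal sketch with a hypothetical `rms_norm` helper; the `eps` argument plays the role that a config value such as "norm_eps" plays in a model.

```python
import math

def rms_norm(row, eps):
    """RMS normalization without a learned gain; eps sits inside the square root."""
    denom = math.sqrt(sum(u * u for u in row) / len(row) + eps)
    return [v / denom for v in row]

zeros = [0.0, 0.0, 0.0, 0.0]
safe = rms_norm(zeros, eps=1e-6)     # denominator is sqrt(1e-6), output stays finite

try:
    rms_norm(zeros, eps=0.0)         # denominator is exactly 0
    failed = False
except ZeroDivisionError:
    failed = True

print(safe, failed)  # [0.0, 0.0, 0.0, 0.0] True
```

Plain Python raises an error on the division by zero; tensor libraries typically produce NaN values instead, which then propagate through the rest of the network.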

Production signal

Why this concept matters

Interview     55% of architecture interviews
Production    Understanding transformer architecture
Performance   RMSNorm ~7% faster than LayerNorm