Normalization

LayerNorm, BatchNorm, RMSNorm: what they do, when to use them, and Pre-Norm vs Post-Norm

TL;DR

Normalization keeps activations in a reasonable range during training. LayerNorm is the standard for transformers; RMSNorm is a faster variant used in modern LLMs like Llama. Pre-norm (normalize before each sublayer) is more stable than post-norm for deep networks.

The Problem: Internal Covariate Shift

As a network trains, the distribution of each layer's inputs keeps shifting because the layers feeding into it keep updating. Normalization counters this by rescaling activations to a consistent range, which stabilizes gradients and makes training less sensitive to learning rate and initialization.

Batch Normalization

Normalizes across the batch dimension. Each feature is normalized using statistics from the current batch.
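A minimal sketch of the training-time computation, assuming a 2D input of shape (batch, features); gamma and beta stand in for the learnable scale and shift, and the names here are illustrative rather than any specific library API:

# Batch normalization (training mode) -- illustrative sketch
import torch

def batchnorm(x, gamma, beta, eps=1e-5):
    # Statistics are computed over the batch dimension (dim=0), one per feature
    mean = x.mean(dim=0, keepdim=True)
    var = x.var(dim=0, unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    # Learnable per-feature scale and shift
    return gamma * x_hat + beta

At inference there is no batch to average over, so frameworks substitute running estimates of the mean and variance collected during training.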

When it works well:

  • CNNs (computer vision)
  • Large batch sizes (stable statistics)
  • Training (has batch to compute stats)

Problems:

  • Needs batch statistics at inference (use running average)
  • Small batches -> noisy statistics -> unstable
  • Batch size 1 -> undefined (no batch to normalize over)

Layer Normalization

Normalizes across the feature dimension. Each sample is normalized independently.
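A minimal sketch, assuming the features live in the last dimension; gamma and beta are again illustrative names for the learnable scale and shift:

# Layer normalization -- illustrative sketch
import torch

def layernorm(x, gamma, beta, eps=1e-5):
    # Statistics are computed per sample (per token), over the feature dimension
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

Because the statistics come from each sample alone, the same code works for batch size 1 and at inference.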

When it works well:

  • Transformers (the standard)
  • RNNs, LSTMs
  • Any batch size (including 1)
  • Inference (no batch dependency)

Why transformers use LayerNorm:

  • Sequence length varies -> batch statistics meaningless
  • Inference often batch_size=1
  • Each token normalized independently

RMSNorm (Root Mean Square Normalization)

Simplified LayerNorm: it rescales by the root mean square of the activations, with no mean centering.

Why it works:

  • Mean centering turns out to be less important than variance scaling
  • Removing mean computation saves ~7% training time
  • Quality is equivalent or better in practice

Used in: Llama, Llama 2, Mistral, most modern LLMs

# RMSNorm implementation
import torch

def rmsnorm(x, weight, eps=1e-6):
    # Root mean square over the feature (last) dimension; note: no mean subtraction
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    # Rescale, then apply the learnable per-feature gain
    return weight * (x / rms)
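For example (shapes are illustrative): with x = torch.randn(2, 16, 512) and weight = torch.ones(512), rmsnorm(x, weight) returns a tensor of the same shape with each token rescaled to roughly unit RMS.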

Pre-Norm vs Post-Norm

Where you place normalization matters for training stability.
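To make the placement concrete, here is a sketch of the two arrangements for a single transformer block; attn, ffn, norm1, and norm2 are hypothetical stand-ins for the attention sublayer, the feed-forward sublayer, and two norm layers (LayerNorm or RMSNorm):

# Norm placement in a transformer block -- illustrative sketch
def post_norm_block(x, attn, ffn, norm1, norm2):
    # Post-norm: normalize after each residual add (original Transformer, BERT)
    x = norm1(x + attn(x))
    x = norm2(x + ffn(x))
    return x

def pre_norm_block(x, attn, ffn, norm1, norm2):
    # Pre-norm: normalize the input of each sublayer (GPT-2, Llama)
    x = x + attn(norm1(x))
    x = x + ffn(norm2(x))
    return x

Pre-norm leaves a clean residual path from input to output, which is commonly cited as the reason it trains more stably as depth grows.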

Which to use: Pre-norm for new models. Post-norm only if replicating the original Transformer or BERT.


Comparison Table

| Aspect | BatchNorm | LayerNorm | RMSNorm |
|---|---|---|---|
| Normalizes across | Batch | Features | Features |
| Works with batch=1 | No | Yes | Yes |
| Needs running stats | Yes | No | No |
| Mean centering | Yes | Yes | No |
| Speed | Baseline | Baseline | ~7% faster |
| Used in | CNNs | Transformers | Modern LLMs |

Debugging Normalization Issues

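A quick first check is to look at per-layer activation statistics and confirm they stay in a reasonable range after each norm. A minimal sketch using PyTorch forward hooks (log_activation_stats is a hypothetical helper, not a library function):

# Print per-module output statistics -- illustrative debugging helper
import torch

def log_activation_stats(model):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            out = output.detach().float()
            print(f"{module.__class__.__name__}: mean={out.mean().item():+.3f}, std={out.std().item():.3f}")
    # Returns the hook handles; call .remove() on each when finished
    return [m.register_forward_hook(hook) for m in model.modules()]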

When This Matters

| Situation | What to know |
|---|---|
| Reading transformer code | LayerNorm before attention/FFN (pre-norm) |
| Understanding Llama/Mistral | RMSNorm, not LayerNorm |
| Training instability | Switch to pre-norm, check norm placement |
| Batch size constraints | LayerNorm works with any batch size |
| Optimizing inference speed | RMSNorm is slightly faster |
| Porting CNN techniques | BatchNorm doesn't work for transformers |
| Understanding model configs | "norm_eps" is the epsilon in the denominator |
Interview Notes

  • Interview relevance: 55% of architecture interviews
  • Production impact: understanding transformer architecture
  • Performance: RMSNorm is ~7% faster than LayerNorm