Embeddings to Attention - Relating Tokens to Each Other
Deep dive into attention mechanisms: why transformers replaced RNNs, scaled dot-product attention, multi-head attention, and how context length affects performance
15 minutes • Intermediate Level • Dec 2024
Building On Previous Knowledge
In the previous progression, you learned how tokens become embeddings—vectors that capture meaning. Each token has its own embedding vector.
But there’s a problem: embeddings are independent. The vector for “bank” doesn’t know whether it’s about finance or rivers until it sees the surrounding words.
Attention solves this by letting each token’s representation incorporate information from other tokens. After attention, “bank” in “river bank” has a different representation than “bank” in “savings bank”—because it attended to different context.
What Goes Wrong Without This:
Attention Failure Patterns
Symptom: Your model truncates long documents and misses important information.
Cause: You treated context as infinite. Attention is O(n²) in memory.
128K context doesn't mean you can use 128K without consequences.
Symptom: Model gives inconsistent answers to the same question.
Cause: In long contexts, attention can miss relevant information.
"Lost in the middle" - models attend more to beginning and end.
Symptom: Reasoning fails on complex multi-step problems.
Cause: Attention struggles to carry information across many hops.
Each hop through attention layers is lossy.
Why Attention Matters
Before attention, sequence models used recurrence (RNNs, LSTMs):
RNN Problems
Process sequentially:
token_1 → state_1 → token_2 → state_2 → ... → token_n → state_n
Problems:
1. Can't parallelize (each step depends on previous)
2. Long-range dependencies are hard (gradient vanishing)
3. Information bottleneck (fixed-size state)
A 1000-word document must compress through a single state vector.
Attention allows direct connections:
Attention Benefits
Every token can directly access every other token:
token_1 ←→ token_2 ←→ token_3 ←→ ... ←→ token_n
          (all pairwise connections)
Benefits:
1. Fully parallelizable (all attention computed at once)
2. Direct long-range access (no bottleneck)
3. Dynamic weighting (attend more to relevant tokens)
This is why Transformers replaced RNNs everywhere.
The Core Idea: Weighted Mixing
Attention is surprisingly simple at its core:
Attention as Weighted Mixing
Input: Sequence of token embeddings [v1, v2, v3, v4]
For each token, compute a new representation by
MIXING all tokens weighted by relevance:
new_v2 = 0.1*v1 + 0.6*v2 + 0.2*v3 + 0.1*v4
          (weights come from a softmax and sum to 1.0)
The weights (attention scores) determine how much
each token contributes to the new representation.
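A toy version of this mixing in PyTorch (vectors and weights are chosen only for illustration):

import torch

# Four toy token embeddings, dimension 4 (values are arbitrary)
v1, v2, v3, v4 = (torch.randn(4) for _ in range(4))

# Relevance weights for token 2 (sum to 1.0, as a softmax output would)
weights = torch.tensor([0.1, 0.6, 0.2, 0.1])

# New representation of token 2: a weighted mix of all tokens
new_v2 = weights[0] * v1 + weights[1] * v2 + weights[2] * v3 + weights[3] * v4

# Equivalent as a single matrix product
tokens = torch.stack([v1, v2, v3, v4])   # (4, 4)
print(torch.allclose(new_v2, weights @ tokens, atol=1e-6))   # True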
For the sentence “The cat sat on the mat”:
Attention Example
When processing "sat":
• High attention to "cat" (subject of sat)
• Medium attention to "mat" (related to sitting)
• Low attention to "the" (less informative)
Result: "sat" embedding now contains information
about WHAT sat (cat) and WHERE (mat).
Query, Key, Value
The Q, K, V framework formalizes how attention scores are computed:
Query, Key, Value Intuition
INTUITION: the library metaphor

Query (Q): What am I looking for?
    "I need books about machine learning"

Key (K): What does each item contain?
    Book 1: "Introduction to AI"
    Book 2: "Cooking recipes"
    Book 3: "Deep Learning fundamentals"

Value (V): The actual content to retrieve
    The book's actual contents

Match the Query against the Keys → weight the Values by match quality.
In practice, Q, K, V are linear projections of the input embeddings:
Q, K, V Projections
Input embedding: x (dimension d_model)
Q = x @ W_Q # project to query space
K = x @ W_K # project to key space
V = x @ W_V # project to value space
Where W_Q, W_K, W_V are learned weight matrices.
Each token gets its own Q, K, V vectors.
Token i's query asks: "What should I attend to?"
Token j's key advertises: "Here's what I contain"
Token j's value provides: "Here's my information if you want it"
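A minimal sketch of these projections, with the learned matrices implemented as bias-free nn.Linear layers (dimension and variable names are illustrative):

import torch
import torch.nn as nn

d_model = 64   # embedding dimension (illustrative)
n = 6          # sequence length

# Learned projection matrices W_Q, W_K, W_V
W_Q = nn.Linear(d_model, d_model, bias=False)
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)

x = torch.randn(n, d_model)   # one token embedding per row

Q = W_Q(x)   # each token's query:  "what am I looking for?"
K = W_K(x)   # each token's key:    "what do I contain?"
V = W_V(x)   # each token's value:  "what information do I provide?"

print(Q.shape, K.shape, V.shape)   # torch.Size([6, 64]) for each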
Scaled Dot-Product Attention
The standard attention formula:
Attention Formula
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Let's break this down:
Step 1: Compute Attention Scores
scores = Q @ K^T
For a sequence of n tokens, each with d_k dimensional Q and K:
Q: (n, d_k)
K: (n, d_k)
K^T: (d_k, n)
Q @ K^T: (n, n)  ← attention scores matrix

scores[i][j] = how much token i should attend to token j
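A quick shape check with toy sizes (values are random; only the shapes matter):

import torch

n, d_k = 6, 64
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)

scores = Q @ K.T          # (n, d_k) @ (d_k, n) -> (n, n)
print(scores.shape)       # torch.Size([6, 6])
print(scores[2, 1])       # how much token 2 attends to token 1 (pre-softmax)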
Step 2: Scale
scaled_scores = scores / √d_k
Why scale?
Dot products grow with dimension size.
Large dot products → softmax becomes very peaked → gradients vanish (nearly all weight on one token)
Dividing by √d_k keeps the variance stable regardless of dimension.
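A small experiment (illustrative sample sizes) shows why: for unit-variance Q and K, the variance of a raw dot product grows linearly with d_k, while the scaled version stays near 1:

import torch

for d_k in [16, 64, 256, 1024]:
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)        # raw dot products
    scaled = dots / d_k ** 0.5        # after dividing by sqrt(d_k)
    print(d_k, round(dots.var().item()), round(scaled.var().item(), 2))
    # unscaled variance grows with d_k; scaled variance stays ~1.0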
Step 3: Softmax
attention_weights = softmax(scaled_scores)
Softmax converts scores to probabilities:
• All values between 0 and 1
• Each row sums to 1.0
• High scores → high weights, low scores → near zero
Example row: [2.1, 0.5, -1.0, 0.8]
After softmax: [0.66, 0.13, 0.03, 0.18]
↑
Token with score 2.1 gets most attention
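You can reproduce this row yourself (the numbers above are rounded):

import torch

scores = torch.tensor([2.1, 0.5, -1.0, 0.8])
weights = torch.softmax(scores, dim=0)

print(weights)               # ≈ [0.66, 0.13, 0.03, 0.18]
print(weights.sum().item())  # 1.0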
Step 4: Weighted Sum
output = attention_weights @ V
Each output vector is a weighted combination of all value vectors:
output_i = Σ (attention_weight[i][j] * V[j])
This is where information actually flows between tokens.
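A small check with toy tensors that the matrix form and the explicit sum agree:

import torch

n, d_v = 4, 8
weights = torch.softmax(torch.randn(n, n), dim=-1)   # each row sums to 1
V = torch.randn(n, d_v)

output = weights @ V                                  # (n, d_v)

# Row i is the weighted sum of all value vectors
i = 1
manual = sum(weights[i, j] * V[j] for j in range(n))
print(torch.allclose(manual, output[i], atol=1e-6))   # True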
Complete Picture
Scaled Dot-Product Attention
SCALED DOT-PRODUCT ATTENTION:

  Q (n×d_k)      K (n×d_k)
      │              │  (transpose)
      ▼              ▼
  ┌────────────┐
  │   MatMul   │   Q @ K^T = (n×n) attention scores
  └─────┬──────┘
        ▼
  ┌────────────┐
  │   Scale    │   divide by √d_k
  └─────┬──────┘
        ▼
  ┌────────────┐
  │  Softmax   │   convert to probabilities (each row)
  └─────┬──────┘
        │          V (n×d_v)
        ▼              ▼
  ┌────────────┐
  │   MatMul   │   weights @ V = output (n×d_v)
  └─────┬──────┘
        ▼
   Output (n×d_v)
Multi-Head Attention
One attention pattern isn’t enough. Different relationships need different attention:
Why Multiple Heads
"The animal didn't cross the street because it was too tired."
Different questions need different attention patterns:
• Q: What is "it"? → attend "it" to "animal" (coreference)
• Q: What action? → attend verbs to subjects
• Q: What's the reason? → attend "tired" to "didn't cross"
Solution: Multiple attention "heads", each learning different patterns.
Multi-head attention runs h parallel attention operations:
┌──────────────────┬───────────────┬───────────────┐
│ Model            │ d_model       │ Heads (h)     │
├──────────────────┼───────────────┼───────────────┤
│ BERT-base        │ 768           │ 12            │
│ GPT-2            │ 768           │ 12            │
│ GPT-3 (175B)     │ 12288         │ 96            │
│ LLaMA 7B         │ 4096          │ 32            │
└──────────────────┴───────────────┴───────────────┘
Each head has d_k = d_model / h dimensions.
More heads = more diverse attention patterns.
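A minimal multi-head attention sketch in PyTorch (module and variable names are illustrative; real implementations add masking, dropout, and KV caching):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads          # per-head dimension
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)   # mixes the concatenated heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n, d_model = x.shape

        # Project, then split d_model into (num_heads, d_k)
        def split(t):
            return t.view(batch, n, self.num_heads, self.d_k).transpose(1, 2)

        Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
        # shapes: (batch, heads, n, d_k)

        # Each head runs scaled dot-product attention independently
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5   # (batch, heads, n, n)
        weights = F.softmax(scores, dim=-1)
        heads = weights @ V                                  # (batch, heads, n, d_k)

        # Concatenate heads back to d_model and mix with W_o
        out = heads.transpose(1, 2).reshape(batch, n, d_model)
        return self.W_o(out)

# Example: 2 sequences of 10 tokens, d_model=64, 8 heads (d_k = 8 per head)
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
print(mha(x).shape)   # torch.Size([2, 10, 64])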
Context Window and Attention
The context window limit exists because attention is O(n²):
Attention Cost Scaling
For sequence length n:
• Attention matrix: n × n
• Memory: O(n²)
• Compute: O(n²)

┌──────────────────┬───────────────┬───────────────┐
│ Context Length   │ Attention     │ Memory        │
├──────────────────┼───────────────┼───────────────┤
│ 1K tokens        │ 1M entries    │ ~4 MB         │
│ 4K tokens        │ 16M entries   │ ~64 MB        │
│ 32K tokens       │ 1B entries    │ ~4 GB         │
│ 128K tokens      │ 16B entries   │ ~64 GB        │
└──────────────────┴───────────────┴───────────────┘
This is why long-context models are expensive.
128K context doesn't mean free 128K—it means 128K² cost.
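The table above is easy to reproduce. A back-of-the-envelope sketch, assuming one float32 (4 bytes) per attention entry for a single head in a single layer:

def attention_matrix_memory(n_tokens: int, bytes_per_entry: int = 4) -> str:
    entries = n_tokens ** 2   # the n × n attention matrix
    gib = entries * bytes_per_entry / 2**30
    return f"{n_tokens:>7} tokens: {entries:>20,} entries, {gib:8.3f} GiB"

for n in (1_024, 4_096, 32_768, 131_072):
    print(attention_matrix_memory(n))
# 131,072 tokens → ~17 billion entries → ~64 GiB at float32 (per head, per layer)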
Techniques for Longer Context
Sparse Attention
1. SPARSE ATTENTION
Instead of n² full attention, attend to subset:
• Local attention: only nearby tokens
• Strided attention: every k-th token
• Random attention: sample positions
BigBird, Longformer use O(n) attention patterns.
Trade: some information paths are blocked.
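To make the idea concrete, here is a hypothetical helper that builds a local (sliding-window) attention mask; real systems like Longformer and BigBird combine several such patterns. The mask would be passed to the attention function so that positions outside the window receive -inf scores:

import torch

def local_attention_mask(n: int, window: int) -> torch.Tensor:
    """1 where token i may attend to token j (|i - j| <= window), else 0."""
    idx = torch.arange(n)
    return ((idx[None, :] - idx[:, None]).abs() <= window).int()

mask = local_attention_mask(n=8, window=2)
print(mask)
# Each token attends only to its 2 neighbours on each side,
# so the number of active entries grows as O(n * window), not O(n²).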
Flash Attention
2. FLASH ATTENTION
Not mathematically different—same result.
But implements attention in a memory-efficient way:
• Never materializes full n×n matrix
• Computes in tiles that fit in GPU SRAM
• 2-4x faster, same memory as single forward pass
This is why modern context windows keep growing.
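In PyTorch 2.x, F.scaled_dot_product_attention can dispatch to fused, FlashAttention-style kernels when the hardware and dtype support it, so you rarely need to materialize the n×n matrix yourself. A usage sketch (shapes are illustrative; kernel selection depends on your GPU):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64)   # (batch, heads, seq_len, d_k)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

# Same mathematical result as the manual implementation, but computed in
# tiles without storing the full 1024×1024 score matrix when a fused
# kernel is available.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # torch.Size([2, 8, 1024, 64])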
Sliding Window & RoPE
3. SLIDING WINDOW / RoPE
Combine:
• Rotary Position Embeddings (RoPE) for relative positions
• Sliding window for bounded attention
• Global tokens that always attend everywhere
LLaMA, Mistral use these patterns.
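For a flavor of RoPE, here is a minimal sketch in the "rotate-half" form used by LLaMA-style models (simplified: no caching, a single sequence, and an even dimension assumed):

import torch

def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotary position embedding applied to (seq_len, d) queries or keys."""
    n, d = x.shape
    half = d // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)       # (half,)
    angles = torch.arange(n, dtype=torch.float32)[:, None] * freqs[None, :]  # (n, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]
    # Rotate each (x1, x2) pair by a position-dependent angle
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = apply_rope(torch.randn(16, 64))
k = apply_rope(torch.randn(16, 64))
# Dot products q_i · k_j now depend only on the relative offset i - j.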
The “Lost in the Middle” Problem
Lost in the Middle
Position in context vs attention received:
Attention │ ████                                      ████
  Score   │ ████████                              ████████
          │ ████████████                      ████████████
          │ ████████████████          ████████████████████
          └────────────────────────────────────────────────
            Beginning            Middle                 End

Beginning and end get more attention.
Middle content can be "lost."
Practical impact:
• Put critical information at the beginning or end of prompts
• Don’t bury important context in the middle of long documents
• Test your application with information at different positions
Attention Visualization
What attention patterns look like:
Attention Weights Matrix
"The cat sat on the mat"
Attention weights (simplified, one head):
The cat sat on the mat
The [0.3 0.2 0.1 0.1 0.2 0.1]
cat [0.2 0.4 0.2 0.0 0.1 0.1]
sat [0.1 0.5 0.2 0.1 0.0 0.1] ← "sat" attends heavily to "cat"
on [0.1 0.1 0.3 0.2 0.1 0.2]
the [0.1 0.1 0.1 0.2 0.3 0.2]
mat [0.1 0.1 0.2 0.2 0.2 0.2]
Different heads learn different patterns:
• Head 1: Subject-verb relationships
• Head 2: Positional (nearby tokens)
• Head 3: Syntactic structure
Code Example
Minimal implementation of scaled dot-product attention:
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(
    Q: torch.Tensor,            # (batch, n, d_k)
    K: torch.Tensor,            # (batch, n, d_k)
    V: torch.Tensor,            # (batch, n, d_v)
    mask: torch.Tensor = None,  # optional mask
) -> torch.Tensor:
    """
    Compute scaled dot-product attention.

    Returns:
        Output tensor of shape (batch, n, d_v)
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, n, n)

    # Step 2: Scale
    scores = scores / (d_k ** 0.5)

    # Optional: Apply mask (for causal/padding)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # (batch, n, n)

    # Step 4: Weighted sum of values
    output = torch.matmul(attention_weights, V)  # (batch, n, d_v)

    return output


# Example usage
batch_size, seq_len, d_model = 2, 10, 64

# Random Q, K, V (in practice, these come from linear projections)
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 10, 64)
Key Takeaways
1. Attention lets tokens incorporate information from all other tokens
2. Q, K, V are projections that define what to attend to and what to retrieve
3. Scaled dot-product attention: softmax(QK^T / √d_k) @ V
4. Multi-head attention runs h parallel attention operations
- Each head can learn different relationship patterns
5. Context window limits exist because attention is O(n²)
- 128K context = 128K² computation
6. "Lost in the middle" is real
- Critical information should be at beginning or end
Verify Your Understanding
Before proceeding, you should be able to:
Draw the attention formula and explain each component — What does the softmax do? Why scale by √d_k? What does multiplying by V accomplish?
Explain why multi-head attention is better than single-head — Give a concrete example of different “types” of relationships different heads might learn.
Your LLM has 128K context but struggles to answer questions about content in the middle. What’s happening? How would you restructure your prompt?
Calculate the memory required for full attention with 32K tokens at float16 precision. How does this change with 64K tokens?