
AI Engineering Series

Embeddings to Attention - Relating Tokens to Each Other

Deep dive into attention mechanisms: why transformers replaced RNNs, scaled dot-product attention, multi-head attention, and how context length affects performance

Building On Previous Knowledge

In the previous progression, you learned how tokens become embeddings—vectors that capture meaning. Each token has its own embedding vector.

But there’s a problem: embeddings are independent. The vector for “bank” doesn’t know whether it’s about finance or rivers until it sees the surrounding words.

Attention solves this by letting each token’s representation incorporate information from other tokens. After attention, “bank” in “river bank” has a different representation than “bank” in “savings bank”—because it attended to different context.

What Goes Wrong Without This:

Attention Failure Patterns

Why Attention Matters

Before attention, sequence models used recurrence (RNNs, LSTMs):

RNN Problems

Attention allows direct connections:

Attention Benefits

The Core Idea: Weighted Mixing

Attention is surprisingly simple at its core:

Attention as Weighted Mixing

For the sentence “The cat sat on the mat”:

Attention Example
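
To make this concrete, here is a tiny PyTorch sketch (the weights are made up purely for illustration): the new representation of "sat" is simply a weighted average of every token's value vector, with most of the weight on "cat" and "sat".

import torch

# Toy value vectors for the 6 tokens of "The cat sat on the mat" (4 dims each)
values = torch.randn(6, 4)

# Hypothetical attention weights for the token "sat" (they sum to 1)
weights = torch.tensor([0.05, 0.40, 0.30, 0.05, 0.05, 0.15])

# The updated representation of "sat" is a weighted mix of all token vectors
sat_repr = weights @ values  # shape: (4,)
print(sat_repr)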

Query, Key, Value

The Q, K, V framework formalizes how attention scores are computed:

Query, Key, Value Intuition

In practice, Q, K, V are linear projections of the input embeddings:

Q, K, V Projections
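
As a minimal sketch (the dimensions are chosen arbitrarily for illustration), the three projections are just learned linear layers applied to the same input embeddings:

import torch
import torch.nn as nn

d_model = 64  # embedding size, assumed for illustration

# Three learned projections of the same input embeddings
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

x = torch.randn(1, 6, d_model)    # (batch, seq_len, d_model) token embeddings
Q, K, V = W_q(x), W_k(x), W_v(x)  # each has shape (1, 6, d_model)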

Scaled Dot-Product Attention

The standard attention formula:

Attention Formula
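
Written out, the formula is:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

where Q, K, and V are the query, key, and value matrices and d_k is the dimensionality of the key vectors.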

Step 1: Compute Attention Scores

Step 1: Attention Scores

Step 2: Scale

Step 2: Scaling
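
The reason for the √d_k factor: dot products of higher-dimensional random vectors tend to have larger magnitudes (their variance grows with d_k), which would push the softmax into near-one-hot, vanishing-gradient territory. A quick sanity check with random vectors (illustrative numbers only):

import torch

# Dot products of random d_k-dimensional vectors have std ~ sqrt(d_k);
# dividing by sqrt(d_k) keeps the scores at roughly unit variance.
for d_k in (16, 64, 256):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    raw = (q * k).sum(-1)
    scaled = raw / d_k ** 0.5
    print(f"d_k={d_k:>3}  std(raw)={raw.std().item():5.1f}  std(scaled)={scaled.std().item():.2f}")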

Step 3: Softmax


Step 4: Weighted Sum

Step 4: Weighted Sum

Complete Picture

Scaled Dot-Product Attention

Multi-Head Attention

One attention pattern isn’t enough. Different relationships need different attention:

Why Multiple Heads

Multi-head attention runs h parallel attention operations:

Multi-Head Attention
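
Here is a minimal self-attention-only sketch of the idea (no masking, dropout, or caching), assuming d_model splits evenly across the heads:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Minimal multi-head self-attention sketch, for illustration only."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must divide evenly across heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # mixes the heads' outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n, d_model = x.shape

        # Project, then split the model dimension into (num_heads, d_k)
        def split(t: torch.Tensor) -> torch.Tensor:
            return t.view(batch, n, self.num_heads, self.d_k).transpose(1, 2)

        Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))

        # Each head runs scaled dot-product attention independently
        scores = Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)  # (batch, heads, n, n)
        weights = F.softmax(scores, dim=-1)
        heads = weights @ V                                   # (batch, heads, n, d_k)

        # Concatenate the heads back to d_model and apply the output projection
        out = heads.transpose(1, 2).reshape(batch, n, d_model)
        return self.W_o(out)

x = torch.randn(2, 10, 64)
mha = MultiHeadAttention(d_model=64, num_heads=8)
print(mha(x).shape)  # torch.Size([2, 10, 64])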

Typical configurations:

Model Configurations

Context Window and Attention

The context window limit exists because attention's compute and memory grow quadratically with sequence length, O(n²):

Attention Cost Scaling
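
To make the quadratic growth concrete, here is a back-of-the-envelope calculation for storing a single raw n × n score matrix at float16 (2 bytes per element). Real implementations such as FlashAttention avoid materializing this matrix, so treat it as an upper-bound illustration:

def attention_matrix_bytes(n_tokens: int, bytes_per_element: int = 2) -> int:
    """Memory for one n x n attention-score matrix (single head, single layer)."""
    return n_tokens * n_tokens * bytes_per_element

for n in (4_096, 32_768, 65_536):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"{n:>6} tokens -> {gib:8.3f} GiB per head, per layer")

# Doubling the sequence length quadruples the memory: O(n^2).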

Techniques for Longer Context

  • Sparse Attention
  • Flash Attention
  • Sliding Window & RoPE (a minimal sliding-window mask sketch follows this list)
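
As one illustration of these ideas, here is a minimal sliding-window causal mask (the function name and window size are mine, chosen for illustration). A mask like this can be passed to the scaled_dot_product_attention function shown in the code example below:

import torch

def sliding_window_causal_mask(n: int, window: int) -> torch.Tensor:
    """1 where attention is allowed, 0 where it is blocked."""
    i = torch.arange(n).unsqueeze(1)  # query positions
    j = torch.arange(n).unsqueeze(0)  # key positions
    causal = j <= i                   # never attend to future tokens
    recent = (i - j) < window         # only the most recent `window` tokens
    return (causal & recent).int()

print(sliding_window_causal_mask(6, window=3))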

The “Lost in the Middle” Problem

Lost in the Middle

Practical impact:

  • Put critical information at beginning or end of prompts
  • Don’t bury important context in the middle of long documents
  • Test your application with information at different positions

Attention Visualization

What attention patterns look like:

Attention Weights Matrix

Code Example

Minimal implementation of scaled dot-product attention:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, n, d_k)
    K: torch.Tensor,  # (batch, n, d_k)
    V: torch.Tensor,  # (batch, n, d_v)
    mask: torch.Tensor | None = None,  # optional mask
) -> torch.Tensor:
    """
    Compute scaled dot-product attention.

    Returns:
        Output tensor of shape (batch, n, d_v)
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, n, n)

    # Step 2: Scale
    scores = scores / (d_k ** 0.5)

    # Optional: Apply mask (for causal/padding)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # (batch, n, n)

    # Step 4: Weighted sum of values
    output = torch.matmul(attention_weights, V)  # (batch, n, d_v)

    return output

# Example usage
batch_size, seq_len, d_model = 2, 10, 64

# Random Q, K, V (in practice, these come from linear projections)
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 10, 64)

Key Takeaways


Verify Your Understanding

Before proceeding, you should be able to:

Draw the attention formula and explain each component — What does the softmax do? Why scale by √d_k? What does multiplying by V accomplish?

Explain why multi-head attention is better than single-head — Give a concrete example of different “types” of relationships different heads might learn.

Your LLM has 128K context but struggles to answer questions about content in the middle. What’s happening? How would you restructure your prompt?

Calculate the memory required for full attention with 32K tokens at float16 precision. How does this change with 64K tokens?


What’s Next

After this, you can:

  • Continue → Attention → Generation — how models produce text token by token
  • Go deeper → Explore transformer architectures, pre-training objectives
