TL;DR
- Self-attention computes Attention(Q, K, V) = softmax(Q Kᵀ / sqrt(d_k)) V — a weighted sum of value vectors where the weights come from the softmax of query-key dot products, scaled by 1/sqrt(d_k) to keep gradients sane (Vaswani et al., 2017, arXiv:1706.03762).
- The 'self' part means Q, K and V are all linear projections of the same input — every token simultaneously asks 'who matters to me?' and answers 'how do I matter to others?' in one batched matmul.
- Naive cost is O(n² · d) FLOPs and O(n²) HBM memory for the score matrix. Flash Attention 2/3 keep the maths exact but tile the computation through on-chip SRAM, so the extra memory drops to O(n) and HBM reads and writes fall sharply, typically giving 2-4x wall-clock speed-ups over naive attention on H100/B200.
- The variants you will meet today (single-head as a teaching baseline; multi-head; grouped-query in Llama 3 and Qwen 3; multi-query in PaLM; multi-head latent in DeepSeek-V3; sliding window in Mistral 7B v0.1; causal vs bidirectional masking) all share the same Q·Kᵀ·V skeleton; the differences are masking, head sharing and memory layout.
- At inference, the KV cache (the K and V tensors retained per request) dominates HBM pressure; PagedAttention, prefix caching and sink-aware cache eviction (built on the attention-sink phenomenon) all exist to manage it.
Overview
Self-attention is the operation that turned sequence modelling into a GPU-native problem. Before 2017, the dominant primitive was a recurrent hidden state computed token by token, a sequential dependency that maps poorly onto the parallel-matmul hardware NVIDIA was already shipping. Vaswani et al. proposed replacing recurrence with a content-addressed lookup: every position projects a query, a key and a value; every query is compared to every key by dot product; the resulting weights mix the values. Once the projections are computed, the operation reduces to two batched matrix multiplications and a softmax, with no sequential dependency across tokens.
Mechanically the operation is small — a few lines of PyTorch. Conceptually it is the foundation. Every decoder-only LLM in production in mid-2026 (Llama 3.1, Qwen 3, DeepSeek-V3, Mistral Large 2, Gemma 3, GPT-4o, Claude Sonnet, Gemini 2.5) has self-attention as its central operator. Every encoder embedding model used in retrieval-augmented generation has it. Every Diffusion Transformer (FLUX, SD3, Sora) has it. Inference engines (vLLM, TensorRT-LLM, SGLang) are built around making the attention kernel fast, memory-efficient and KV-cache-friendly. Hardware features ship to accelerate it: H100's FP8 Tensor Cores, B200's FP4 Tensor Cores, Hopper's Tensor Memory Accelerator (TMA), Blackwell's larger SRAM tile.
This entry is a reference for operators who need to reason about self-attention as a systems primitive: what the maths is, why the scaling factor exists, where the variants differ, how Flash Attention changes the performance picture without changing the output, and what the operational pitfalls are when serving long-context decoders in 2026. The goal is practical: pick the right variant (MHA vs GQA vs MLA vs sliding-window) for your workload, predict KV cache pressure before it OOMs your serving fleet, and diagnose the silent quality bugs that come from mishandled masking or position encoding. If you deploy models on Yobibyte or train on Yobitel NeoCloud, this matters because every workload (Llama 3.1 chat, BGE embeddings, Whisper transcription, FLUX image generation) hits this primitive in the inner loop, and the H100 / H200 / B200 inventory it runs on was sized assuming Flash Attention 3 plus GQA.
How it works: the scaled dot-product attention operation
Given an input sequence X of shape (n, d_model), a single attention head learns three weight matrices W_Q, W_K, W_V of shape (d_model, d_k); the value width d_v is conventionally set equal to d_k, as assumed here. The projections Q = X W_Q, K = X W_K, V = X W_V give three matrices of shape (n, d_k), and Attention(Q, K, V) = softmax(Q Kᵀ / sqrt(d_k)) V produces an output of shape (n, d_k).
Read it row by row. Row i of the output is a weighted sum of all rows of V. The weights come from row i of softmax(Q Kᵀ / sqrt(d_k)): a probability distribution over the n key positions, derived from the dot products between query i and every key. A large dot product means query i and key j are well aligned in the d_k-dimensional projection space; the softmax turns that into a high weight, so value row j contributes strongly to output row i. The whole operation is a soft, differentiable, content-addressed lookup.
The 1/sqrt(d_k) factor is the single detail that is easy to skip and disastrous to forget. If the components of a query and a key are independent with mean 0 and variance 1, their dot product has mean 0 and variance d_k (Vaswani et al. make this argument in a footnote). Without the rescaling, scores grow with d_k, push the softmax into saturation (one weight near 1, everything else near 0), and the gradient through the softmax vanishes. With the rescaling, score variance stays near 1 regardless of d_k, the softmax behaves smoothly, and gradients flow.
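The variance argument is easy to check empirically. The sketch below assumes i.i.d. standard-normal queries and keys (an idealised input distribution, not real activations), measures the variance of raw and rescaled dot products, and reports the average largest softmax weight over 64 keys with and without the 1/sqrt(d_k) factor. A value near 1 means the softmax has saturated.

```python
import numpy as np

rng = np.random.default_rng(0)

def score_variances(d_k, n_pairs=20_000):
    """Empirical variance of q.k and q.k / sqrt(d_k), assuming i.i.d. N(0, 1) components."""
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    raw = np.einsum("ij,ij->i", q, k)  # one dot product per (q, k) pair
    return raw.var(), (raw / np.sqrt(d_k)).var()

def mean_max_weight(d_k, n_keys=64, trials=200, scale=True):
    """Average largest softmax weight per query; close to 1 means saturated."""
    q = rng.standard_normal((trials, d_k))
    k = rng.standard_normal((trials, n_keys, d_k))
    s = np.einsum("td,tkd->tk", q, k)  # (trials, n_keys) raw scores
    if scale:
        s = s / np.sqrt(d_k)
    w = np.exp(s - s.max(axis=-1, keepdims=True))  # numerically stable softmax
    w /= w.sum(axis=-1, keepdims=True)
    return w.max(axis=-1).mean()

for d_k in (16, 64, 128):
    raw_var, scaled_var = score_variances(d_k)
    print(f"d_k={d_k:3d}  var raw={raw_var:7.1f}  var scaled={scaled_var:4.2f}  "
          f"max weight raw={mean_max_weight(d_k, scale=False):.2f}  "
          f"scaled={mean_max_weight(d_k):.2f}")
```

The raw variance tracks d_k while the scaled variance stays near 1, and only the unscaled scores drive the largest weight towards 1.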
Two matmuls dominate the FLOP count: Q Kᵀ at O(n² · d_k) and softmax(·) V at O(n² · d_k), for O(n² · d_k) compute in total. Memory is the real problem: the attention matrix softmax(Q Kᵀ) has shape (n, n), and materialising it in HBM costs O(n²). At n = 32,768 in BF16 that is 2 GiB per head per layer; for an 80-layer Llama 3 70B with 64 query heads, storing every such matrix for a single sequence would take 10 TiB. This is why production kernels since 2022 (Flash Attention, xFormers' memory-efficient attention) avoid materialising that matrix at all.
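The memory figures are plain arithmetic. A short sketch, assuming one materialised (n, n) BF16 matrix per query head per layer for a single sequence (no batching, and not counting any extra copies a training framework might keep):

```python
GiB, TiB = 2**30, 2**40

def naive_attention_matrix_bytes(n, bytes_per_elem=2):
    """Bytes for one (n, n) attention matrix: one head, one layer, one sequence."""
    return n * n * bytes_per_elem

per_head_layer = naive_attention_matrix_bytes(32_768)  # BF16: 2 bytes per element
all_heads_layers = per_head_layer * 64 * 80            # 64 query heads x 80 layers

print(per_head_layer / GiB)     # 2.0
print(all_heads_layers / TiB)   # 10.0
print(all_heads_layers / 1e12)  # about 11.0 decimal terabytes
```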
```python
# self_attention.py: runs with `pip install torch && python self_attention.py`
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def scaled_dot_product_attention(q, k, v, mask=None):
    """Vaswani et al. (2017) eq. 1: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / (d_k ** 0.5)  # (..., n_q, n_k)
    if mask is not None:
        # mask == 0 marks positions a query may NOT attend to.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # (..., n_q, n_k)
    return weights @ v, weights  # (..., n_q, d_v)


# Single-sequence sanity check: 8 tokens, one 16-dim head, d_model = 64.
n, d_model, d_k = 8, 64, 16
x = torch.randn(n, d_model)
# Scale the random projections by 1/sqrt(d_model) so the scores stay O(1).
w_q, w_k, w_v = (torch.randn(d_model, d_k) / d_model ** 0.5 for _ in range(3))
q, k, v = x @ w_q, x @ w_k, x @ w_v
causal = torch.tril(torch.ones(n, n))
out, attn = scaled_dot_product_attention(q, k, v, mask=causal)
print("output:", out.shape)  # torch.Size([8, 16])
print("attn row 0:", attn[0])  # only position 0 has weight (1.0)
print("attn row 7:", attn[7])  # distribution over positions 0..7

# Production code should call torch.nn.functional.scaled_dot_product_attention,
# which dispatches to a fused FlashAttention or memory-efficient kernel on
# supported GPUs and to a math/CPU kernel otherwise. It expects inputs shaped
# (batch, heads, seq, head_dim), so add the two leading dimensions.
out_fast = F.scaled_dot_product_attention(
    q[None, None], k[None, None], v[None, None], is_causal=True
)[0, 0]
print("max abs diff vs naive:", (out - out_fast).abs().max().item())
```

Always test causal-mask correctness with a 4-token sanity check. Row 0 of the attention matrix should be [1, 0, 0, 0]; row 3 should be a distribution over positions 0-3 summing to 1. Off-by-one mask bugs (letting position i also see position i + 1) are the most common silent regression in custom attention code: teacher-forced perplexity can look fine, or even suspiciously good, while generation and long-context reasoning degrade sharply.
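The 4-token check can be written as a tiny test harness. In this NumPy sketch the helper names are hypothetical (not from any library): it builds attention weights under a boolean mask, asserts the row properties just described, and shows that an off-by-one mask built with np.tril(..., k=1), which lets position i see i + 1, is caught.

```python
import numpy as np

def masked_attention_weights(q, k, mask):
    """softmax(q k^T / sqrt(d_k)); mask[i, j] is True if query i may see key j."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

def check_causal(mask, n=4, d=8, seed=0):
    """Hypothetical test helper: assert the causal sanity properties, return weights."""
    rng = np.random.default_rng(seed)
    q, k = rng.standard_normal((2, n, d))
    w = masked_attention_weights(q, k, mask)
    assert np.allclose(w[0], np.eye(n)[0]), "row 0 must attend only to position 0"
    assert np.allclose(w.sum(axis=-1), 1.0), "each row must sum to 1"
    assert np.allclose(np.triu(w, k=1), 0.0), "no weight may land on future positions"
    return w

n = 4
good = np.tril(np.ones((n, n), dtype=bool))             # position i sees 0..i
off_by_one = np.tril(np.ones((n, n), dtype=bool), k=1)  # bug: i also sees i + 1

print(np.round(check_causal(good), 3))
try:
    check_causal(off_by_one)
except AssertionError as err:
    print("caught mask bug:", err)
```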
Variants: how every modern Transformer changes the basic recipe
The Q·Kᵀ·V skeleton is universal. What varies in 2026 production models is (a) how many heads share K and V, (b) what mask is applied, and (c) what window the attention operates over. The list below covers the five variants you actually meet in modern LLM and DiT code, plus the masking choice; the table that follows summarises them alongside the single-head and bidirectional cases.
- MHA — Vaswani et al. 2017. h independent (Q, K, V) projections, concatenated and projected back. Conceptual baseline; rare in new dense models after 2023.
- MQA — Shazeer 2019 (arXiv:1911.02150). All h query heads share one K head and one V head. Shrinks KV cache dramatically but causes small quality regression and training instability at scale.
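- GQA: Ainslie et al. 2023 (arXiv:2305.13245). The compromise: h query heads share g K/V heads, with 1 < g < h. Llama 3, Qwen 3, Mistral Large and Gemma 2/3 all use it; g = 8 is the most common choice (Llama 3 8B pairs h = 32 with g = 8, Llama 3 70B pairs h = 64 with g = 8), though some model sizes use 4 or 16. Quality stays close to MHA, and the KV cache shrinks by h/g (4x for Llama 3 8B, 8x for Llama 3 70B).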
- MLA — DeepSeek-V2 (2024). Compresses K and V into a small learned latent (rank ~512), then expands per-head at attention time. Smaller cache than GQA at similar quality, but with more complex kernel support.
- SWA — Beltagy et al. 2020 (Longformer); operationalised by Mistral 7B v0.1. Each token attends to a fixed window w (Mistral used w = 4,096). Linear in n. Often interleaved with full attention or replaced in later versions because window edges lose information; Mistral dropped it in v0.2.
- Causal mask (lower triangular: position i attends to positions 1..i) is universal in decoder-only LLMs. Bidirectional attention (no mask) is used in encoder-only models for classification and retrieval. Encoder-decoder models mix both: bidirectional self-attention over the source, causal self-attention over the target, and unmasked cross-attention from every target position to every source position.
| Variant | Heads (Q : KV) | Mask | KV cache vs MHA | Used by |
|---|---|---|---|---|
| Single-head | 1 : 1 | Causal or bidirectional | 1x | Pedagogical only |
| Multi-Head Attention (MHA) | h : h | Causal (decoder), full (encoder) | 1x (baseline) | BERT, GPT-2, early Llama |
| Multi-Query Attention (MQA) | h : 1 | Causal | 1/h (often 32x-64x smaller) | PaLM, Falcon |
| Grouped-Query Attention (GQA) | h : g (1 < g < h; g = 8 is common) | Causal | g/h (typically 4x-8x smaller) | Llama 3, Qwen 3, Mistral, Gemma |
| Multi-Head Latent (MLA) | h : one shared low-rank latent | Causal | About 1/57 (576 vs 32,768 cached values per token per layer in DeepSeek-V2/V3) | DeepSeek-V2 / V3 |
| Sliding-Window Attention (SWA) | Any (Mistral 7B v0.1 pairs it with GQA) | Causal + window | Capped at w tokens per windowed layer | Mistral 7B v0.1, Gemma 2 (interleaved with global layers) |
| Bidirectional | h : h | None (full) | n/a (no autoregressive cache) | BERT, embedding models, encoders |
When a 2024+ LLM paper says 'multi-head attention', check the model config: it often means GQA. Pure MHA is essentially extinct in new frontier models because its KV cache cost is prohibitive at long context.
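To make head sharing concrete, here is a minimal NumPy sketch of grouped-query attention, assuming the common convention that each KV head serves h // g consecutive query heads. It copies K and V with np.repeat for clarity; fused kernels index the shared head instead of materialising copies. Setting g = h recovers MHA and g = 1 recovers MQA.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def gqa_attention(q, k, v, causal=True):
    """q: (h, n, d) query heads; k, v: (g, n, d) shared KV heads, h % g == 0."""
    h, n, d = q.shape
    g = k.shape[0]
    assert h % g == 0, "query heads must split evenly into KV groups"
    # Broadcast each KV head to its group of h // g query heads (the cache
    # itself only ever stores g heads).
    k_rep = np.repeat(k, h // g, axis=0)                  # (h, n, d)
    v_rep = np.repeat(v, h // g, axis=0)
    scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(d)    # (h, n, n)
    if causal:
        mask = np.tril(np.ones((n, n), dtype=bool))
        scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v_rep                        # (h, n, d)

rng = np.random.default_rng(0)
h, g, n, d = 8, 2, 5, 16
q = rng.standard_normal((h, n, d))
k = rng.standard_normal((g, n, d))
v = rng.standard_normal((g, n, d))
out = gqa_attention(q, k, v)
print(out.shape)                  # (8, 5, 16)
print("KV cache vs MHA:", g / h)  # 0.25, i.e. 4x smaller
```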
Where it is used today: every Transformer block in production
Essentially every frontier model shipping in mid-2026 is Transformer-based, and every Transformer block contains a self-attention operator (hybrid SSM-attention models keep attention in a subset of their blocks). The deployment patterns split by family.
Decoder-only causal LLMs (Llama 3.1 8B/70B/405B, Qwen 3, DeepSeek-V3, Mistral Large 2, Gemma 3, GPT-4o, Claude Sonnet 4, Gemini 2.5) use causal-masked self-attention; the open-weight models in this list use GQA or MLA, while the closed models do not publish their attention variant. At inference they generate one token at a time, caching K and V from prior tokens in the KV cache. This is the workload that vLLM, TensorRT-LLM and SGLang are built to serve.
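The KV cache works because causal attention for token t only needs the keys and values of tokens 0..t, which never change once computed. This single-head, single-layer NumPy toy (the KVCache class is hypothetical; it grows arrays with np.vstack where real engines preallocate paged blocks, and it uses precomputed random q/k/v rows instead of projecting newly generated tokens) decodes one token per step and checks the result against full causal attention.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(q, k, v):
    """Reference: one prefill-style pass with a lower-triangular mask."""
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    return softmax(scores) @ v

class KVCache:
    """Hypothetical single-head, single-layer KV cache for illustration."""

    def __init__(self, d):
        self.k = np.empty((0, d))
        self.v = np.empty((0, d))

    def decode_step(self, q_t, k_t, v_t):
        # Append this token's key and value, then attend from its query over
        # the whole cache. No mask is needed: the cache only holds past and
        # current positions.
        self.k = np.vstack([self.k, k_t])
        self.v = np.vstack([self.v, v_t])
        scores = q_t @ self.k.T / np.sqrt(q_t.shape[-1])  # (1, t + 1)
        return softmax(scores) @ self.v                   # (1, d)

rng = np.random.default_rng(0)
n, d = 6, 16
q, k, v = rng.standard_normal((3, n, d))
cache = KVCache(d)
decoded = np.vstack([cache.decode_step(q[t:t + 1], k[t:t + 1], v[t:t + 1])
                     for t in range(n)])
print(np.allclose(decoded, full_causal_attention(q, k, v)))  # True
print(cache.k.shape)  # (6, 16): one cached row per token so far
```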
Encoder-only bidirectional models (BERT-family, DeBERTa-v3, BGE, E5, GTE, text-embedding-3, ModernBERT) use unmasked self-attention. They process the whole input in one forward pass and emit either per-token hidden states (for token classification) or a pooled embedding (for retrieval). There is no KV cache because there is no autoregressive generation.
Encoder-decoder stacks (T5, NLLB, Whisper) use bidirectional self-attention in the encoder, causal self-attention in the decoder, and cross-attention from decoder to encoder. Still common in translation and constrained-output tasks; largely displaced by decoder-only in general-purpose chat.
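Cross-attention reuses the same formula with queries and keys/values drawn from different sequences. A NumPy sketch with random, hypothetical projection weights: queries come from decoder states, keys and values from the encoder output, and no mask is applied. In an encoder-decoder model such as Whisper, the encoder-side K and V can be computed once per request and reused at every decode step.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(dec_h, enc_h, w_q, w_k, w_v):
    """Queries from the decoder, keys/values from the encoder, no mask."""
    q = dec_h @ w_q  # (n_tgt, d_k)
    k = enc_h @ w_k  # (n_src, d_k)
    v = enc_h @ w_v  # (n_src, d_k)
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (n_tgt, n_src)
    return weights @ v, weights

rng = np.random.default_rng(0)
d_model, d_k, n_src, n_tgt = 32, 16, 7, 3
enc_h = rng.standard_normal((n_src, d_model))  # stand-in for encoder output
dec_h = rng.standard_normal((n_tgt, d_model))  # stand-in for decoder states
# Hypothetical random projections, scaled to keep scores O(1).
w_q, w_k, w_v = rng.standard_normal((3, d_model, d_k)) / np.sqrt(d_model)
out, weights = cross_attention(dec_h, enc_h, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (3, 16) (3, 7)
```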
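Diffusion Transformers (FLUX.1, Stable Diffusion 3, Sora) tokenise the latent image or video into patches and apply bidirectional self-attention across the patches. Text conditioning enters either through cross-attention to text-encoder embeddings or, as in SD3's MMDiT blocks and FLUX.1, by concatenating text and image tokens into one joint self-attention. It is the same primitive, applied to vision tokens instead of language tokens.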
Vision Transformers (ViT, DINOv2, EVA) and audio Transformers (Whisper encoder, Wav2Vec 2) follow the encoder-only pattern on patch or frame tokens. Self-attention is modality-agnostic; the tokeniser changes per modality, the operator does not.
On Yobibyte, the same primitive carries every catalogue model — a Llama 3.1 70B chat workload, a DeepSeek-V3 reasoning workload, a Qwen 3 long-context document workload — through the same Flash Attention 3 kernel path on Yobitel's Hopper and Blackwell inventory. Customers never tune attention flags directly; the routing logic in Omniscient Compute picks GQA-aware kernels and sizes KV cache headroom based on the workload pattern this entry describes.
Trade-offs and known limitations
Self-attention has well-understood structural costs and a handful of failure modes that bite teams in production. None are existential — every one has a standard mitigation in 2026 — but they are the operational reality of serving modern decoders.
Quadratic compute in sequence length is the headline limitation. At n = 256k tokens, prefill of a 70B model costs on the order of 10^17 FLOPs, most of it in attention: minutes of compute on a single H100, and still tens of seconds spread across an 8-GPU node. Mitigations: chunked prefill (vLLM's enable_chunked_prefill, SGLang's chunked-prefill scheduler) interleaves prefill chunks with decode steps so latency for concurrent requests stays bounded; Ring Attention (Liu et al. 2023) shards the sequence across devices and passes K/V blocks around a ring, so no single device computes the full n² interaction, enabling distributed long-context training and inference.
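A back-of-envelope sketch shows why long prefill is dominated by attention. It assumes roughly 2 FLOPs per parameter per token for the dense layers, Llama-3-70B-like shapes (d_model = 8192, 80 layers, query heads × head_dim equal to d_model), 4·n²·d_model FLOPs per layer for Q Kᵀ plus P V, and a causal kernel that skips masked tiles (halving the attention term). These are estimates, not measurements.

```python
def prefill_flops(n, params=70e9, d_model=8192, layers=80, causal=True):
    """Rough forward FLOPs for prefilling n tokens of one sequence.

    Returns (dense_flops, attention_flops) under the stated assumptions.
    """
    dense = 2 * params * n                    # weight matmuls: ~2 FLOPs/param/token
    attention = 4 * n * n * d_model * layers  # Q K^T and P V, all heads and layers
    if causal:
        attention /= 2                        # masked upper triangle is skipped
    return dense, attention

for n in (8_192, 32_768, 262_144):
    dense, attn = prefill_flops(n)
    share = attn / (dense + attn)
    print(f"n={n:7d}  total={dense + attn:.2e} FLOPs  attention share={share:.0%}")
```

Under these assumptions attention is a small fraction of prefill at 8k tokens but the majority at 256k, which is why long-context prefill scales so badly.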
Quadratic memory in sequence length is the second wall. Flash Attention (2022) and its successors Flash Attention 2 and 3 reduce it to O(n) by tiling the computation and using an online softmax: same output, far less HBM traffic. PyTorch's torch.nn.functional.scaled_dot_product_attention uses a Flash Attention 2 backend (since PyTorch 2.2) whenever the inputs qualify, for example FP16/BF16 on Ampere or newer GPUs. With these kernels, the bottleneck shifts from the attention matrix to the KV cache.
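The idea that makes O(n) memory possible is the online softmax: process K and V in blocks while keeping, per query row, a running maximum and a running denominator, and rescale the partial output whenever the maximum grows. This NumPy sketch shows only the arithmetic (non-causal for brevity); real kernels also tile Q, keep blocks in SRAM and fuse everything into one GPU kernel. The block size of 4 is arbitrary.

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=4):
    """Exact attention over K/V blocks with an online softmax.

    Only an (n_q, block) score tile exists at any time; the full
    (n_q, n_k) matrix is never built.
    """
    n_q, d = q.shape
    out = np.zeros((n_q, v.shape[-1]))    # unnormalised running output
    row_max = np.full((n_q, 1), -np.inf)  # running max score per query row
    row_sum = np.zeros((n_q, 1))          # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / np.sqrt(d)                          # (n_q, block)
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        rescale = np.exp(row_max - new_max)                # 0 on the first block
        p = np.exp(s - new_max)
        row_sum = row_sum * rescale + p.sum(axis=-1, keepdims=True)
        out = out * rescale + p @ vb
        row_max = new_max
    return out / row_sum                                   # normalise once at the end

rng = np.random.default_rng(0)
q = rng.standard_normal((10, 16))
k = rng.standard_normal((23, 16))  # 23 keys, so the last block is partial
v = rng.standard_normal((23, 16))
# Differences are floating-point rounding only.
print(np.abs(tiled_attention(q, k, v) - naive_attention(q, k, v)).max())
```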
KV cache memory grows linearly with sequence length, layers, KV heads and batch size. For Llama 3.1 70B (80 layers, 8 KV heads with GQA, head dimension 128) at 128k (131,072-token) context in BF16, the cache is 320 KiB per token, roughly 40 GiB (about 43 GB) per request. PagedAttention (vLLM's block-table indexing) and prefix caching (sharing prefix KV across requests) are the standard memory-management techniques. At ~1M-token contexts, even GQA-shrunk caches exceed a single GPU's HBM and require multi-GPU sharding or CPU/NVMe offload, the latter with order-of-magnitude latency penalties.
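The KV cache figure follows from one formula per token: 2 (K and V) × layers × KV heads × head dimension × bytes per element. A sketch with Llama 3.1 70B's shape; the 64-KV-head "MHA twin" and the 80 GiB KV budget used for the concurrency estimate are hypothetical numbers for comparison, not measured figures.

```python
GiB = 2**30

def kv_bytes_per_token(layers, kv_heads, head_dim, bytes_per_elem=2):
    # Factor 2: each layer stores one K and one V vector per KV head.
    return 2 * layers * kv_heads * head_dim * bytes_per_elem

def max_concurrent_seqs(kv_budget_bytes, per_token, context_len):
    # Worst case: every sequence fills its whole context window.
    return kv_budget_bytes // (per_token * context_len)

gqa = kv_bytes_per_token(layers=80, kv_heads=8, head_dim=128)   # BF16
mha = kv_bytes_per_token(layers=80, kv_heads=64, head_dim=128)  # hypothetical MHA twin
ctx = 131_072

print(gqa)                                      # 327680 bytes (320 KiB) per token
print(gqa * ctx / GiB)                          # 40.0 GiB per 128k-token request
print(mha // gqa)                               # 8: the GQA saving
print(max_concurrent_seqs(80 * GiB, gqa, ctx))  # 2 full-length requests
print(max_concurrent_seqs(80 * GiB, mha, ctx))  # 0: MHA would not fit even one
```

Worst-case sizing like this is what lowering max_num_seqs protects against; paged allocation lets real engines admit more requests when most do not fill their full context.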
Attention sinks (Xiao et al. 2023, arXiv:2309.17453) are a subtle quality issue: causal-attention softmax over very long contexts concentrates probability mass on the first few tokens regardless of content, because the model needs somewhere to 'park' attention when no key is relevant. Sliding-window methods that drop these initial tokens degrade quality sharply; the fix is to keep a small fixed set of leading 'sink tokens' in cache always.
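The sink-plus-window policy is simple to express. This sketch only computes which token positions stay in the cache; the defaults of 4 sinks and an 8-token window are illustrative (StreamingLLM reports that a few initial tokens, 4 by default, suffice), and the real method also re-assigns positions relative to the cache, which is omitted here. Both function names are hypothetical.

```python
def streaming_cache_positions(seq_len, n_sink=4, window=8):
    """Token positions kept by a sink-plus-recent-window eviction policy."""
    if seq_len <= n_sink + window:
        return list(range(seq_len))                  # everything still fits
    sinks = list(range(n_sink))                      # always keep the first tokens
    recent = list(range(seq_len - window, seq_len))  # plus the latest window
    return sinks + recent

def plain_window_positions(seq_len, budget=12):
    """The naive alternative: keep only the most recent `budget` tokens."""
    return list(range(max(0, seq_len - budget), seq_len))

print(streaming_cache_positions(20))  # [0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19]
print(plain_window_positions(20))     # positions 8..19: the sink tokens are gone
```

Both policies hold the same 12 tokens; the difference is whether the first tokens, where the model parks attention, survive eviction.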
Long-sequence position generalisation is structurally a positional-encoding problem, not an attention problem, but it manifests at the attention output: dot products between Q at position 1,048,576 and K at position 100 are not well-behaved unless RoPE is YaRN-scaled (or ALiBi is used). See the rotary-position-embedding and alibi-position-bias entries for the cure.
Sub-quadratic alternatives have all carved real niches: Mamba (Gu & Dao 2023) and Mamba-2 (Dao & Gu 2024) for selective state spaces, RWKV for linear-attention RNNs with parallel training, and Hyena for long convolutions. None has matched dense self-attention at frontier scale through mid-2026. Hybrid architectures (Jamba, Zamba) interleave SSM blocks with attention blocks to combine strengths.
Practical implementation notes
On modern hardware, the right answer is almost never to write your own attention kernel. Use PyTorch's torch.nn.functional.scaled_dot_product_attention: since PyTorch 2.0 it dispatches to a fused Flash Attention kernel (upgraded to Flash Attention 2 in PyTorch 2.2) or an xFormers-derived memory-efficient kernel on supported NVIDIA GPUs, and to a math or CPU kernel otherwise; the torch.nn.attention.sdpa_kernel context manager restricts which backends are eligible. For training, the standalone FlashAttention package (flash-attn) ships the latest kernel updates faster than the version bundled with PyTorch. For serving, vLLM and TensorRT-LLM use their own fused paged-attention kernels.
Flash Attention (Dao et al. 2022) introduced the IO-aware tiling that keeps the (n, n) score matrix out of HBM, dropping attention memory from O(n²) to O(n). Flash Attention 2 (Dao 2023, arXiv:2307.08691) made long-context training practical by reworking parallelism and work partitioning (fewer non-matmul FLOPs, parallelism across the sequence dimension, reordered loops and better warp partitioning for Tensor Core occupancy), for roughly a 2x speed-up over the original. Flash Attention 3 (Shah et al. 2024, arXiv:2407.08608) adds Hopper-specific asynchrony (TMA plus WGMMA overlap via warp specialisation), FP8 support, and another ~1.5-2x on H100; FA3 or similarly Hopper-tuned kernels ship in current vLLM, SGLang and TensorRT-LLM builds.
The KV cache is what dominates serving memory; PagedAttention (Kwon et al. 2023, arXiv:2309.06180) treats KV memory like virtual memory in an OS — fixed-size pages, a block table mapping logical request positions to physical pages, near-zero fragmentation. Prefix caching (vLLM's enable_prefix_caching = True) shares pages across requests that begin with the same prompt prefix, which is huge for system-prompt-heavy chat workloads.
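A toy block-table allocator shows both ideas together: logical token positions map to (physical block, offset) pairs, and full prompt blocks whose entire prefix matches an earlier request are shared by reference count. The class and its methods are hypothetical, written for illustration; vLLM's real implementation hashes blocks rather than keying on whole token tuples, and also handles eviction, copy-on-write and freeing, none of which is modelled here.

```python
class PagedKVAllocator:
    """Hypothetical toy allocator in the spirit of PagedAttention plus prefix caching."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))  # physical block ids not in use
        self.refcount = {}                   # physical block id -> number of users
        self.prefix_index = {}               # full token prefix -> physical block id
        self.tables = {}                     # request id -> block table

    def _alloc(self):
        if not self.free:
            raise MemoryError("out of KV blocks")
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def add_request(self, req_id, prompt_tokens):
        bs, table = self.block_size, []
        for start in range(0, len(prompt_tokens), bs):
            end = start + bs
            key = tuple(prompt_tokens[:end])    # whole prefix, not just this block
            full = end <= len(prompt_tokens)
            if full and key in self.prefix_index:
                block = self.prefix_index[key]  # prefix-cache hit: share the page
                self.refcount[block] += 1
            else:
                block = self._alloc()
                if full:                        # only full blocks are shareable
                    self.prefix_index[key] = block
            table.append(block)
        self.tables[req_id] = table
        return table

    def locate(self, req_id, pos):
        """Logical token position -> (physical block, offset within block)."""
        return self.tables[req_id][pos // self.block_size], pos % self.block_size


alloc = PagedKVAllocator(num_blocks=8, block_size=4)
system_prompt = list(range(100, 108))  # 8 tokens = 2 full blocks
a = alloc.add_request("a", system_prompt + [1, 2, 3])
b = alloc.add_request("b", system_prompt + [7, 8])
print(a, b)               # first two entries match: system-prompt pages are shared
print(len(alloc.free))    # 4 blocks left: 4 used instead of 6 without sharing
print(alloc.locate("a", 9))
```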
Numerical precision in attention is straightforward in mid-2026: BF16 is the safe training default (its 8-bit exponent matches FP32's range, so overflow is rare and no loss scaling is needed); FP16 attention is fine provided the softmax subtracts the row maximum before exponentiating, as fused kernels do, because FP16 overflows above 65,504; FP8 attention (Flash Attention 3, TensorRT-LLM FP8 mode) needs per-tensor scaling and a representative calibration set but unlocks ~1.5-2x throughput on H100; FP4 attention is emerging on Blackwell B200 for inference but is not yet stable for training.
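The FP16 failure mode is easy to reproduce. NumPy has no BF16 dtype, so this sketch uses FP16 only, with a made-up row of attention logits: exp(12) exceeds FP16's maximum of 65,504, the naive softmax produces NaN, and subtracting the row maximum first keeps every exponent at or below 0.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_stable(x):
    # After subtracting the row max every exponent is <= 0, so exp() <= 1
    # and neither numerator nor denominator can overflow.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[12.0, 10.0, 3.0, -5.0]], dtype=np.float16)  # illustrative logits
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(scores))  # contains nan: exp(12) overflows to inf in FP16
print(softmax_stable(scores))     # finite, sums to ~1
```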
Inference engine support for the variants is uneven. GQA, MQA and bidirectional are universal. MLA is supported in vLLM 0.6+, TensorRT-LLM 0.13+ and SGLang. Sliding-window is supported in vLLM and SGLang but is no longer the design point of new models. Ring Attention is supported in research stacks (Megatron-LM, Axolotl) for training; production inference rarely needs it.
If you see KV cache OOMs at long context: first verify the model is GQA (not MHA), then enable prefix caching, then lower max_num_seqs to reduce concurrent KV slots, then move to H200 or B200 if the workload demands warrant it. Switching to a sliding-window model is usually the wrong fix — modern long-context decoders work because of full attention, not despite it.
Where self-attention sits in the Yobitel stack
Every model in the Yobibyte catalogue — Llama 3.1, Qwen 3, DeepSeek-V3, Mistral Large 2, FLUX.1, Whisper, BGE — sits on self-attention. The platform routes inference through industry-standard runtimes (the vLLM / TensorRT-LLM / SGLang family), all of which use Flash Attention 3 kernels on Hopper and Blackwell hardware in the Yobitel GPU Cloud inventory. Customers do not see the attention plumbing; they choose a model, a region and a spend cap, and the kernel selection happens transparently.
InferenceBench measures the empirical consequences of attention-implementation choices — Flash Attention 2 vs 3, GQA vs MHA, prefix caching on vs off, FP8 vs BF16 — across the same model and hardware combinations a customer would deploy on. For teams designing a serving stack, the architectural reasoning here pairs with the measured tokens/sec/$ in InferenceBench.
References
- Attention Is All You Need (Vaswani et al., 2017) · arXiv
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022) · arXiv
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023) · arXiv
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Shah et al., 2024) · arXiv
- Efficient Memory Management for Large Language Model Serving with PagedAttention (Kwon et al., 2023) · arXiv
- Efficient Streaming Language Models with Attention Sinks (Xiao et al., 2023) · arXiv
- GQA: Training Generalized Multi-Query Transformer Models (Ainslie et al., 2023) · arXiv
- The Annotated Transformer (Harvard NLP) · Harvard NLP