TL;DR
- Speculative-decoding scheme by Cai et al. (arXiv:2401.10774, 2024).
- Adds k small prediction heads on top of the target model's last hidden state. The original LM head still predicts the next token, and Medusa head i predicts the token i+1 positions ahead. All heads run in the same forward pass.
- Combined with tree attention over the proposed combinations, gives 2-3x end-to-end speedup with no separate draft model.
- Particularly popular because it integrates with existing serving stacks and requires only a short fine-tune to train the heads.
Overview
Medusa removes the need for a separate draft model in speculative decoding by attaching several auxiliary heads to the target model. Each head predicts a token at a fixed future offset. The original LM head still predicts the next token, head 1 predicts the token after that, head 2 the one after that, and so on through head k.
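The head design can be sketched in plain Python. This follows the residual feed-forward head described in the Medusa paper, logits = W2 · (SiLU(W1 · h) + h), with W1 initialized to zero and W2 copied from the LM head. It also follows the paper's offset convention: the original LM head covers offset 1, and Medusa head k covers offset k + 1. The names `MedusaHead`, `propose` and `top_k` are hypothetical. Real implementations operate on batched tensors rather than Python lists.

```python
import math


def silu(x: float) -> float:
    # SiLU (swish) activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def matvec(w: list[list[float]], v: list[float]) -> list[float]:
    # Dense matrix-vector product: one dot product per row of w
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def top_k(logits: list[float], k: int) -> list[int]:
    # Indices of the k largest logits, best first
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


class MedusaHead:
    """One residual feed-forward head: logits = W2 @ (SiLU(W1 @ h) + h)."""

    def __init__(self, lm_head: list[list[float]]):
        d = len(lm_head[0])
        # W1 starts at zero, so SiLU(W1 h) = 0 and the residual passes h through unchanged
        self.w1 = [[0.0] * d for _ in range(d)]
        # W2 starts as a copy of the base LM head (vocab x d), so an untrained head mimics it
        self.w2 = [row[:] for row in lm_head]

    def logits(self, h: list[float]) -> list[float]:
        hidden = [silu(a) + b for a, b in zip(matvec(self.w1, h), h)]
        return matvec(self.w2, hidden)


def propose(h, lm_head, heads, top=2):
    """Map each future offset to candidate token ids, all from the same hidden state h."""
    proposals = {1: top_k(matvec(lm_head, h), 1)}  # base LM head: the next token
    for k, head in enumerate(heads, start=1):
        proposals[k + 1] = top_k(head.logits(h), top)  # Medusa head k: offset k + 1
    return proposals
```

At initialization every head reproduces the base model's next-token distribution. Training then specializes head k to its own offset, which is why this initialization gives the heads a sensible starting point.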
On each decoding step the heads' top candidates are combined into a set of candidate continuations. A single tree-attention forward pass of the target model then verifies all of them, and the longest accepted prefix is committed.
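The commit rule can be illustrated with a simplified greedy verifier. `target_next` is a hypothetical stand-in for the target model's argmax prediction. In Medusa, all of those predictions come out of one tree-attention forward pass rather than repeated calls, and candidates share prefixes instead of being checked independently.

```python
from typing import Callable, Sequence


def accept_longest(
    context: Sequence[int],
    candidates: Sequence[Sequence[int]],
    target_next: Callable[[Sequence[int]], int],
) -> list[int]:
    """Greedy verification: keep the candidate whose prefix the target agrees with longest."""
    best: list[int] = []
    for cand in candidates:
        accepted: list[int] = []
        for tok in cand:
            # Stop at the first token the target model would not have produced itself
            if target_next(list(context) + accepted) != tok:
                break
            accepted.append(tok)
        if len(accepted) > len(best):
            best = accepted
    # The target's own prediction after the accepted prefix is always valid,
    # so every step commits at least one token
    return best + [target_next(list(context) + best)]
```

With a toy "counting" target that always predicts the previous token plus one, `accept_longest([3], [[4, 5, 9], [4, 5, 6, 0], [7]], lambda p: (p[-1] + 1) % 10)` keeps `[4, 5, 6]` from the second candidate and appends the target's own next token, committing four tokens in one step.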
Training
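Medusa-1 trains only the new heads and keeps the base model's weights frozen, so the base model's own outputs are unchanged. Medusa-2 instead fine-tunes the heads jointly with the backbone. This raises head accuracy but needs a careful training recipe to preserve the base model's quality. The training data is ideally the same data the base model was fine-tuned on, and the paper falls back to self-distillation when that data is unavailable. The loss is teacher-forced cross-entropy: at position t, head k is supervised with the ground-truth token at position t + k + 1. Training heads for a Llama 70B target takes on the order of a few hundred GPU-hours, much less than training a full draft model.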
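A minimal sketch of the per-position training loss follows, assuming the Medusa-1 setup with a frozen backbone. The decaying weight `decay ** k` (0.8 by default) mirrors the weighting suggested in the paper, but treat the exact schedule as an assumption. `medusa1_loss` and its arguments are hypothetical names.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def medusa1_loss(head_logits: list[list[float]], tokens: list[int], t: int, decay: float = 0.8) -> float:
    """Weighted teacher-forced cross-entropy for all heads at one position t.

    head_logits[k - 1] holds head k's logits computed from the hidden state at position t;
    head k is supervised with the ground-truth token at t + k + 1.
    """
    loss = 0.0
    for k, logits in enumerate(head_logits, start=1):
        target_pos = t + k + 1
        if target_pos >= len(tokens):
            break  # no ground truth this far ahead of position t
        loss -= decay ** k * log_softmax(logits)[tokens[target_pos]]
    return loss
```

Down-weighting later heads reflects that far-ahead predictions are harder, so their gradients should not dominate training. The full training loss averages this quantity over all positions in the batch.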
Tree Attention
Because each head proposes its top-K candidates, verification has to consider many combinations across heads. Medusa arranges these combinations as a tree and uses a tree attention mask in which each token attends only to its own ancestors. The target model can then evaluate every combination in one forward pass while sharing computation across overlapping prefixes.
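The tree and its attention mask can be sketched as follows, assuming a full Cartesian tree over each head's top candidates. Practical deployments usually prune this to a sparser tree. Nodes are tuples of candidate indices. The root token and the cached context, which every node may attend to, are omitted for brevity. All function names are hypothetical.

```python
from itertools import product


def build_tree(top_per_head: list[int]) -> list[tuple[int, ...]]:
    """All prefixes of the Cartesian product of per-head candidate indices, deduplicated.

    top_per_head[i] is how many candidates to take from head i + 1, e.g. [2, 2] -> 2 x 2 paths.
    Shared prefixes become shared nodes, which is what lets one pass cover every combination.
    """
    paths = product(*(range(n) for n in top_per_head))
    prefixes = {path[:depth] for path in paths for depth in range(1, len(path) + 1)}
    return sorted(prefixes, key=lambda p: (len(p), p))  # breadth-first order


def tree_attention_mask(nodes: list[tuple[int, ...]]) -> list[list[bool]]:
    # Node i may attend to node j only if j lies on i's own path (an ancestor or i itself)
    return [[ni[: len(nj)] == nj for nj in nodes] for ni in nodes]


def position_ids(nodes: list[tuple[int, ...]], offset: int) -> list[int]:
    # Nodes at the same depth share a position id: they are alternatives for the same slot
    return [offset + len(p) - 1 for p in nodes]
```

For `build_tree([2, 2])` this yields six nodes: two first-level candidates and four second-level ones. The mask row for node `(1, 0)` is true only at `(1,)` and at itself, so its attention never mixes tokens from sibling branches.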
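Medusa's speedup grows with the number of heads, but only up to a point. Each extra head predicts further ahead and is accepted less often, while the verification tree grows multiplicatively. Three to five heads is typical; beyond that, the extra verification cost outpaces the gain in accepted tokens.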
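A toy model makes the diminishing returns concrete. It assumes each head's proposal is accepted with a fixed conditional probability. The rates below are made-up illustrative numbers, not measurements. The model also counts the nodes in a full tree with `top` candidates per head. The function names are hypothetical.

```python
def expected_tokens_per_step(accept_rates: list[float]) -> float:
    """Expected tokens committed per step in a toy chain model.

    accept_rates[i] is an assumed probability that head i + 1's proposal is accepted,
    given that all earlier proposals were; the leading 1.0 is the token the target
    model always contributes itself.
    """
    total, survive = 1.0, 1.0
    for a in accept_rates:
        survive *= a  # probability that the chain is still unbroken at this depth
        total += survive
    return total


def full_tree_size(num_heads: int, top: int) -> int:
    # Nodes in a full Cartesian tree: top + top**2 + ... + top**num_heads
    return sum(top ** i for i in range(1, num_heads + 1))


rates = [0.6, 0.5, 0.4, 0.3, 0.2]  # hypothetical per-head acceptance rates
for k in range(1, 6):
    print(k, round(expected_tokens_per_step(rates[:k]), 4), full_tree_size(k, 3))
```

With these rates, going from one head to five raises expected tokens per step only from 1.6 to about 2.06. Over the same range, a top-3 full tree grows from 3 to 363 nodes.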
Integration
- TensorRT-LLM, vLLM, TGI and SGLang all ship Medusa support.
- Pre-trained Medusa heads are available on the Hugging Face Hub for many open-weight models.
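- Distribution preservation: with the backbone frozen (Medusa-1), standard speculative-decoding verification (greedy matching or rejection sampling) preserves the target model's output distribution. The paper's typical-acceptance scheme relaxes this, accepting any sufficiently plausible candidate and trading a small distributional difference for a higher acceptance rate and speedup. Medusa-2 also fine-tunes the backbone, so its outputs can differ from the original model's regardless of the acceptance scheme.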
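The typical-acceptance test from the paper accepts a candidate token x when the target model assigns it a probability above min(ε, δ·exp(-H(p))), where H(p) is the entropy of the target distribution at that position. A sketch follows. The defaults ε = 0.09 and δ = 0.3 are illustrative assumptions, and the function names are hypothetical.

```python
import math


def entropy(probs: list[float]) -> float:
    # Shannon entropy in nats; zero-probability entries contribute nothing
    return -sum(p * math.log(p) for p in probs if p > 0)


def typical_accept(probs: list[float], token: int, epsilon: float = 0.09, delta: float = 0.3) -> bool:
    """Typical-acceptance test for one candidate token.

    probs is the target model's distribution at this position. High-entropy (uncertain)
    distributions lower the threshold, so more diverse candidates get accepted.
    """
    threshold = min(epsilon, delta * math.exp(-entropy(probs)))
    return probs[token] > threshold


def accepted_prefix_length(step_probs: list[list[float]], candidate: list[int]) -> int:
    # Accept candidate tokens left to right until the first one that fails the test
    n = 0
    for probs, tok in zip(step_probs, candidate):
        if not typical_accept(probs, tok):
            break
        n += 1
    return n
```

Unlike rejection sampling, this test needs no draft-model probabilities and no random draws. That fits Medusa, where the heads are not a calibrated draft distribution. The cost is that the committed text is no longer an exact sample from the target model.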
When to Use
Pick Medusa when EAGLE-2 is not yet available for your model family, when operational simplicity matters (no draft-model lifecycle) or when integrating into a stack that already supports it. EAGLE-2 typically wins on raw speedup; Medusa wins on ecosystem maturity and training simplicity.
References
- Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads · arXiv (Cai et al., 2024)
- Medusa on GitHub · GitHub
- TensorRT-LLM Medusa Documentation · NVIDIA