TL;DR
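- Saves activation memory by storing only a subset of intermediate activations. The discarded ones are recomputed by re-running the relevant part of the forward pass during the backward pass.
- Formalised for deep-learning training by Chen et al., 2016 (arXiv:1604.06174) under the name 'sublinear memory cost'. The underlying checkpointing idea comes from earlier work on reverse-mode automatic differentiation.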
- Costs roughly 30-40 % extra compute time but can reduce activation memory by 5-10× — often the difference between fitting and not fitting a long-context training step.
Overview
During backpropagation, computing each layer's gradients requires the activations that layer saw in the forward pass. Naively, all of them are stored during the forward pass and read back during backward. For deep networks trained on long sequences, this activation memory can dominate GPU memory consumption.
Gradient checkpointing, also called activation recomputation, stores only a subset of activations (typically at transformer-block boundaries). It recomputes the rest by re-running the forward pass during backward. The price is extra compute time, and in return activation memory shrinks by a large factor.
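The mechanism can be sketched without any framework. Below is a minimal, hypothetical example in plain Python. A chain of scalar layers `tanh(w * x)` stands in for transformer blocks, and the network output is used directly as the loss. The checkpointed version keeps only the input of every `segment`-th layer. Just before back-propagating through each segment, it re-runs that segment's forward pass. It should produce the same gradients as the version that stores everything.

```python
import math
import random
from typing import Dict, List, Tuple


def layer_forward(w: float, x: float) -> float:
    """Toy layer y = tanh(w * x), a hypothetical stand-in for a network block."""
    return math.tanh(w * x)


def layer_backward(w: float, x: float, grad_y: float) -> Tuple[float, float]:
    """Return (dL/dx, dL/dw) from the layer input x and the upstream gradient dL/dy."""
    y = math.tanh(w * x)          # cheap to redo, so only the input x is kept
    g = grad_y * (1.0 - y * y)    # dL/d(w*x), using tanh' = 1 - tanh^2
    return g * w, g * x


def grads_store_all(ws: List[float], x0: float) -> Tuple[List[float], int]:
    """Baseline: keep every layer's input. The loss is the network output itself."""
    inputs, x = [], x0
    for w in ws:
        inputs.append(x)
        x = layer_forward(w, x)
    grad, grads = 1.0, [0.0] * len(ws)
    for i in reversed(range(len(ws))):
        grad, grads[i] = layer_backward(ws[i], inputs[i], grad)
    return grads, len(inputs)


def grads_checkpointed(ws: List[float], x0: float, segment: int) -> Tuple[List[float], int, int]:
    """Keep only each segment's input; re-run the segment's forward during backward."""
    n = len(ws)
    checkpoints: Dict[int, float] = {}
    x, forward_calls = x0, 0
    for i, w in enumerate(ws):
        if i % segment == 0:
            checkpoints[i] = x        # segment boundary: this activation is kept
        x = layer_forward(w, x)       # all other activations are discarded
        forward_calls += 1

    grad, grads, peak = 1.0, [0.0] * n, 0
    for start in reversed(range(0, n, segment)):
        seg_inputs, x = [], checkpoints.pop(start)
        for i in range(start, min(start + segment, n)):  # recompute this segment
            seg_inputs.append(x)
            x = layer_forward(ws[i], x)
            forward_calls += 1
        # Activations held right now: remaining checkpoints + this segment's inputs.
        peak = max(peak, len(checkpoints) + len(seg_inputs))
        for i in reversed(range(start, start + len(seg_inputs))):
            grad, grads[i] = layer_backward(ws[i], seg_inputs[i - start], grad)
    return grads, peak, forward_calls


random.seed(0)
ws = [random.uniform(0.5, 1.5) for _ in range(16)]
full, stored = grads_store_all(ws, 0.3)
ckpt, peak, calls = grads_checkpointed(ws, 0.3, segment=4)
print(stored, peak, calls)  # 16 7 32
print(max(abs(g1 - g2) for g1, g2 in zip(full, ckpt)))  # 0.0
```

With 16 layers and segments of 4, at most 7 activations are held at once instead of 16. The layer forward runs twice, for 32 evaluations instead of 16. This simple version also recomputes the last segment, which a tuned implementation could keep from the first forward pass.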
Mechanism
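Chen et al.'s original formulation divides an N-layer network into segments and stores only the activations at segment boundaries. With roughly √N segments of √N layers each, activation memory falls from O(N) to O(√N) at the cost of about one extra forward pass. Memory therefore becomes sublinear in depth, while compute grows only by a constant factor. In modern practice the heuristic is simpler. Checkpoint at every transformer block, and accept roughly one extra forward pass per step (about 30-40 % more total compute) in exchange for a ~5-10× activation-memory reduction.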
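Both claims can be checked with a back-of-the-envelope derivation. It assumes N layers of equal cost, counts memory in units of one layer's stored activations, and takes a backward pass to cost about twice the forward. That last figure is a common rule of thumb, not something stated above.

```latex
% Memory: N/k stored segment boundaries, plus the k layers of the one
% segment being rematerialised during backward.
M(k) = \frac{N}{k} + k,
\qquad
\frac{dM}{dk} = 1 - \frac{N}{k^{2}} = 0
\;\Rightarrow\; k^{\ast} = \sqrt{N},
\qquad
M(k^{\ast}) = 2\sqrt{N} = O(\sqrt{N}).

% Compute: forward cost F, backward cost \approx 2F; recomputing every
% discarded activation adds at most one more forward pass.
\text{overhead} \approx \frac{F}{F + 2F} = \frac{1}{3} \approx 33\,\%.
```

For example, with N = 64 the optimal segment length is 8, and about 16 layer-activations are held instead of 64.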
Selective recomputation (Megatron-LM v3 / Korthikanti et al., 2022) takes a more targeted approach. It recomputes only the operations whose activations are large but cheap to recompute: the attention core, made up of the QKᵀ product, softmax, softmax dropout, and attention over V. These activations grow quadratically with sequence length. The activations of the expensive linear layers are stored as usual. This captures most of the memory benefit at much lower compute cost, around 5 % overhead instead of 30 %.
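The following estimate shows why recomputing the attention core pays off. It is a sketch that follows the per-layer activation estimate in Korthikanti et al. (2022): s·b·h·(34 + 5·a·s/h) bytes, assuming 16-bit activations, no tensor or sequence parallelism, and a materialised attention-score matrix. Selective recomputation removes the 5·a·s/h attention-core term. The model shape below is an illustrative, GPT-3-175B-like assumption. The figures cover stored activations only and exclude the working set of the layer being recomputed.

```python
def activation_bytes_per_layer(s: int, b: int, h: int, a: int, selective: bool) -> float:
    """Estimated activation bytes stored per transformer layer.

    After Korthikanti et al. (2022): s*b*h*(34 + 5*a*s/h), 16-bit activations,
    no tensor/sequence parallelism. Selective recomputation drops the 5*a*s/h term.
    s: sequence length, b: micro-batch size, h: hidden size, a: attention heads.
    """
    attention_core = 5 * a * s / h  # QK^T scores, softmax output, dropout mask and output
    other = 34                      # inputs to linear layers, GELU, LayerNorms, other dropouts
    return s * b * h * (other + (0 if selective else attention_core))


GIB = 2 ** 30
# Assumed GPT-3-175B-like shape: hidden 12288, 96 heads, 96 layers, micro-batch 1.
h, a, layers, b = 12288, 96, 96, 1
for s in (2048, 8192):
    none = activation_bytes_per_layer(s, b, h, a, selective=False) * layers / GIB
    sel = activation_bytes_per_layer(s, b, h, a, selective=True) * layers / GIB
    print(f"s={s}: no recompute {none:.1f} GiB, selective {sel:.1f} GiB")
# s=2048: no recompute 256.5 GiB, selective 76.5 GiB
# s=8192: no recompute 3186.0 GiB, selective 306.0 GiB
```

Under this estimate, the recomputed term accounts for 80/114 ≈ 70 % of per-layer activations at s = 2048 and 320/354 ≈ 90 % at s = 8192. That share is why the selective approach recovers most of the memory benefit.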
Performance Characteristics
- Memory savings: 5-10× activation memory at full checkpointing; 3-5× at selective.
- Compute cost: ~30 % at full; ~5 % at selective recomputation.
- Composability: orthogonal to the parallelism strategy; composes with DP, TP, PP, and FSDP.
- Sequence length: often the binding constraint that checkpointing relieves, because per-layer activation memory grows with sequence length.
When to Use
Use full checkpointing when a training step otherwise runs out of memory and roughly 30 % extra training time is acceptable. For long-context transformer training, selective recomputation (Megatron-LM's `--recompute-activations` mode) gives the best memory-compute trade-off and is the sensible default. Reserve full checkpointing for runs that still do not fit. For short sequences and small models that fit comfortably, checkpointing is a pessimisation: it adds recomputation time without buying memory you need.
Pitfalls
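- Stochastic operations such as dropout must draw the same random numbers in the original forward and the recomputed forward. That means restoring the RNG state, not just reusing the seed. PyTorch's `checkpoint` does this by default (`preserve_rng_state=True`), but custom recomputation code may not.
- Other side effects are replayed too. Anything that mutates state during the forward pass runs again during recomputation, so its effect can be applied twice. Examples include BatchNorm running statistics and custom kernels that write to shared buffers.
- Checkpointing inside FSDP or ZeRO-3 interacts with parameter gathering. Recomputing a block during backward needs that block's full parameters, which these schemes shard between uses. The order in which checkpoint wrappers and sharding units are nested therefore matters.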
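The RNG pitfall can be illustrated without a framework. This minimal sketch uses Python's `random` module, with a hand-written inverted-dropout function as a hypothetical stand-in for a real dropout layer. Recomputation reproduces the original mask only if it replays the generator state captured before the original forward. That is the role of `preserve_rng_state=True` in PyTorch's `checkpoint`.

```python
import random
from typing import List, Tuple


def dropout(xs: List[float], p: float, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each element with probability p, rescale survivors."""
    return [0.0 if rng.random() < p else x / (1.0 - p) for x in xs]


def forward_with_stash(xs: List[float], p: float, rng: random.Random) -> Tuple[List[float], object]:
    """Original forward: record the RNG state before the stochastic op."""
    state = rng.getstate()
    return dropout(xs, p, rng), state


def recompute_forward(xs: List[float], p: float, state: object) -> List[float]:
    """Recomputed forward: replay the stashed state on a separate generator,
    leaving the live generator's stream untouched."""
    replay = random.Random()
    replay.setstate(state)
    return dropout(xs, p, replay)


rng = random.Random(42)
xs = [1.0] * 64
out, state = forward_with_stash(xs, 0.5, rng)
replayed = recompute_forward(xs, 0.5, state)  # same mask as the original forward
naive = dropout(xs, 0.5, rng)                 # continues the stream: almost surely a different mask
print(replayed == out, naive == out)
```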
Software
- PyTorch `torch.utils.checkpoint.checkpoint` and `torch.utils.checkpoint.checkpoint_sequential`.
- FSDP `apply_activation_checkpointing` for transformer-block-granular checkpointing.
- Megatron-LM `--recompute-granularity selective` (or the `--recompute-activations` shorthand) for the production selective-recomputation path.
- DeepSpeed activation-checkpointing integration with ZeRO.
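Here is a minimal sketch of block-granular checkpointing with `torch.utils.checkpoint.checkpoint`. The `Block` and `Model` classes, the sizes, and the loss are hypothetical. Only `checkpoint` comes from the library, called with `use_reentrant=False`, which recent PyTorch versions ask you to pass explicitly. The final comparison confirms that checkpointed and non-checkpointed runs give matching gradients, since the dropout RNG state is preserved by default.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block: pre-norm MLP with residual and dropout."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(),
            nn.Linear(4 * dim, dim), nn.Dropout(0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


class Model(nn.Module):
    def __init__(self, dim: int, depth: int, use_checkpoint: bool) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Store only the block input; its internals are recomputed in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def grads(use_checkpoint: bool) -> list:
    torch.manual_seed(0)  # identical weights, inputs and dropout masks in both runs
    model = Model(dim=32, depth=4, use_checkpoint=use_checkpoint)
    x = torch.randn(8, 16, 32)
    model(x).pow(2).mean().backward()
    return [p.grad.clone() for p in model.parameters()]


baseline = grads(use_checkpoint=False)
ckpt = grads(use_checkpoint=True)
print(all(torch.allclose(g1, g2, atol=1e-6) for g1, g2 in zip(baseline, ckpt)))
```

Wrapping at block granularity mirrors the "checkpoint every transformer block" heuristic described above. A coarser or finer split trades memory against recomputation in the same way as the segment length in Chen et al.'s analysis.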
References
- Training Deep Nets with Sublinear Memory Cost · arXiv (Chen et al., 2016)
- Reducing Activation Recomputation in Large Transformer Models · arXiv (Korthikanti et al., 2022)
- PyTorch checkpoint documentation · PyTorch