Distributed LLM Training 14 - What ZeRO Stage 1, 2, and 3 Each Remove
ZeRO is best understood as a staged system for removing different forms of replicated training state from data-parallel ranks.
Why the stages matter
ZeRO can feel complicated at first because it comes with numbered stages. But the core idea is straightforward: data parallel training wastes memory by fully replicating training state on every rank, and ZeRO removes that duplication in steps.
Stage 1: shard optimizer state
The first step is to shard optimizer state across data-parallel ranks. This alone often yields large savings because Adam-style optimizers keep extra per-parameter state (momentum and variance, typically in fp32), which in mixed-precision training can outweigh the model weights themselves.
Benefits:
- meaningful memory savings
- a relatively small conceptual change: the forward and backward passes are unchanged
Limit:
- parameters and gradients are still replicated
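The core mechanical question in Stage 1 is how to assign parameters to ranks so each rank's optimizer only holds its slice of the state. A minimal single-process sketch of the partitioning idea (`shard_params` is a hypothetical helper, not DeepSpeed's API; real implementations partition flattened buffers rather than whole tensors):

```python
def shard_params(param_sizes, world_size):
    """Greedily assign parameters to ranks so each rank's optimizer
    covers roughly 1/world_size of the total elements (ZeRO-1-style
    idea; real systems partition flat buffers, not whole parameters).

    Returns (shards, loads): shards[r] is the list of parameter indices
    owned by rank r, loads[r] the element count rank r is responsible for.
    """
    shards = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Largest-first greedy balancing keeps per-rank loads close.
    for idx, size in sorted(enumerate(param_sizes), key=lambda t: -t[1]):
        r = loads.index(min(loads))  # least-loaded rank owns this param
        shards[r].append(idx)
        loads[r] += size
    return shards, loads

# Five parameters of varying size, split across two ranks.
print(shard_params([8, 4, 4, 2, 2], 2))
```

After each step, every rank updates only its shard and the updated values are exchanged so all ranks again see consistent full parameters; that exchange is why Stage 1 leaves runtime communication mostly familiar.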
Stage 2: shard gradients too
The next step shards gradients as well, so each rank keeps only the reduced gradients for its own optimizer shard. That cuts another major source of duplication, but it makes runtime communication (a reduce-scatter rather than a plain all-reduce) and optimizer bookkeeping more involved.
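One way to picture Stage 2 is that the usual all-reduce of gradients becomes a reduce-scatter: every rank contributes its full local gradient, but each rank keeps only the summed slice its optimizer shard needs. A single-process simulation with plain lists standing in for per-rank gradient buffers (not a real collective API):

```python
def reduce_scatter(per_rank_grads):
    """Simulate reduce-scatter across len(per_rank_grads) ranks.

    per_rank_grads[r] is rank r's full local gradient (a flat list whose
    length divides evenly by the rank count). Each rank receives the
    element-wise sum restricted to its contiguous shard, so no rank ever
    stores the full reduced gradient -- the ZeRO-2 idea.
    """
    world_size = len(per_rank_grads)
    n = len(per_rank_grads[0])
    shard = n // world_size
    out = []
    for r in range(world_size):
        lo, hi = r * shard, (r + 1) * shard
        summed = [sum(g[i] for g in per_rank_grads) for i in range(lo, hi)]
        out.append(summed)
    return out

# Two ranks, gradient length 4: each rank keeps only its half of the sum.
grads = [[1, 2, 3, 4], [10, 20, 30, 40]]
print(reduce_scatter(grads))  # [[11, 22], [33, 44]]
```

A reduce-scatter moves the same data as the first half of a ring all-reduce, which is why Stage 2 saves memory without a large communication penalty in the common case.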
Stage 3: shard parameters too
The strongest stage shards parameters themselves. This directly attacks the biggest weakness of plain data parallelism: each rank holding the full model replica.
But it introduces more runtime complexity:
- parameter gathering (an all-gather before each layer is used) is now part of forward and backward execution
- communication volume grows, so overlapping it with compute matters more
- overall behavior becomes more sensitive to network topology and scheduling
The next post looks at FSDP, which shares some of the same goals but takes a different PyTorch-centered path.