Why the stages matter

ZeRO can feel complicated at first because it comes with numbered stages. But the core idea is straightforward: data parallel training wastes memory by fully replicating training state on every rank, and ZeRO removes that duplication in steps.
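
The duplication is easy to quantify. As a rough sketch, using the common mixed-precision Adam accounting (2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and 12 bytes of fp32 optimizer state per parameter; the helper name here is my own, not a library API):

```python
def plain_dp_bytes_per_rank(num_params: int) -> int:
    # Plain data parallelism: every rank holds a full copy of everything.
    # fp16 params (2 B) + fp16 grads (2 B) + fp32 optimizer state
    # (4 B master copy + 4 B momentum + 4 B variance = 12 B) per parameter.
    return num_params * (2 + 2 + 12)

# A 7B-parameter model needs about 112 GB per rank, before activations:
print(plain_dp_bytes_per_rank(7_000_000_000) / 1e9)  # → 112.0
```

The point is that this cost is identical on every rank, no matter how many ranks you add. ZeRO's stages each remove one slice of that redundancy.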

Stage 1: shard optimizer state

The first step is to shard optimizer state across the ranks. This alone often helps a lot, because Adam-style optimizers carry large extra state: in mixed-precision training, the fp32 master weights, momentum, and variance together take several times more memory than the fp16 parameters themselves.

Benefits:

  • meaningful memory savings
  • a relatively small conceptual change: the forward and backward passes are untouched, only the optimizer step changes

Limit:

  • parameters and gradients are still replicated
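
Continuing the accounting sketch from above (same hedged byte counts, helper name my own): only the 12 bytes per parameter of optimizer state are divided across the world size, while parameters and gradients stay replicated.

```python
def zero1_bytes_per_rank(num_params: int, world_size: int) -> float:
    # Stage 1: fp16 params (2 B) and fp16 grads (2 B) are still replicated
    # on every rank; the 12 B/param of fp32 optimizer state is sharded.
    return num_params * (2 + 2 + 12 / world_size)

# 7B params on 64 ranks: roughly 29.3 GB instead of 112 GB.
print(zero1_bytes_per_rank(7_000_000_000, 64) / 1e9)
```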

Stage 2: shard gradients too

The next step shards gradients as well. That cuts another major source of duplication, but it also makes the runtime more involved: instead of all-reducing full gradients, each gradient shard must be reduced to the rank that owns the matching slice of optimizer state.
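
In the same sketch (assumed byte counts as before, helper name my own), stage 2 moves the 2 bytes per parameter of gradients into the sharded term as well:

```python
def zero2_bytes_per_rank(num_params: int, world_size: int) -> float:
    # Stage 2: only the fp16 params (2 B) remain replicated; gradients
    # (2 B) and optimizer state (12 B) are both sharded across ranks.
    return num_params * (2 + (2 + 12) / world_size)

# 7B params on 64 ranks: roughly 15.5 GB per rank.
print(zero2_bytes_per_rank(7_000_000_000, 64) / 1e9)
```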

Stage 3: shard parameters too

The strongest stage shards parameters themselves. This directly attacks the biggest weakness of plain data parallelism: each rank holding the full model replica.

But it introduces more runtime complexity:

  • parameter gathering is now part of the forward and backward passes
  • total communication volume grows, so overlapping it with compute matters
  • overall behavior becomes more sensitive to topology and scheduling
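
Completing the accounting sketch (same assumed byte counts, helper name my own): with all three components sharded, persistent per-rank memory scales inversely with the world size. The caveat in the bullets above shows up as transient memory, since parameter shards must be gathered on demand during the forward and backward passes.

```python
def zero3_bytes_per_rank(num_params: int, world_size: int) -> float:
    # Stage 3: params, grads, and optimizer state are all sharded. Each
    # rank permanently stores only its 1/world_size slice; full parameters
    # exist only transiently, gathered layer by layer during execution.
    return num_params * (2 + 2 + 12) / world_size

# 7B params on 64 ranks: about 1.75 GB of persistent state per rank.
print(zero3_bytes_per_rank(7_000_000_000, 64) / 1e9)  # → 1.75
```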

The next post looks at FSDP, which shares some of the same goals but takes a different PyTorch-centered path.