FSDP is not just "DDP but bigger"

FSDP changes the memory story more fundamentally than DDP. Instead of keeping full parameter replicas on every rank, it keeps parameters sharded and gathers them only when a computation needs them.

Its goals are clear:

  • reduce parameter replication
  • shard gradients and optimizer state as well
  • lower per-rank memory enough to train larger models
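To see why sharding everything matters, here is a rough memory-accounting sketch. It uses the standard mixed-precision Adam figures (2 bytes for fp16 parameters, 2 for fp16 gradients, 12 for fp32 master weights plus momentum and variance, i.e. 16 bytes per parameter); the function and the model size are illustrative, not anything from PyTorch's API:

```python
def per_rank_bytes(num_params, world_size, sharded):
    """Rough per-rank training memory for mixed-precision Adam.

    Assumes 2 bytes (fp16 params) + 2 bytes (fp16 grads) +
    12 bytes (fp32 master params, momentum, variance) = 16 bytes/param.
    DDP replicates all of it; fully sharded training divides it by
    the number of ranks.
    """
    bytes_per_param = 2 + 2 + 12
    total = num_params * bytes_per_param
    return total // world_size if sharded else total

n = 7_000_000_000  # a hypothetical 7B-parameter model
ddp_gib = per_rank_bytes(n, world_size=8, sharded=False) // 2**30
fsdp_gib = per_rank_bytes(n, world_size=8, sharded=True) // 2**30
print(ddp_gib, fsdp_gib)  # per-rank GiB: replicated vs fully sharded
```

The replicated figure exceeds any single accelerator's memory, while the sharded figure fits comfortably; that gap is the entire motivation for FSDP. (Activations are excluded; they depend on batch size and checkpointing, not on the sharding strategy.)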

The high-level execution pattern

At a conceptual level:

  1. keep parameters sharded by module
  2. all-gather them before the relevant computation
  3. release or reshard them when possible
  4. use reduce-scatter or related collectives during backward
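The four steps above can be simulated in a single process with plain Python lists standing in for tensors and collectives. This is a toy model of the data flow only, assuming the simplest case (even shards, one module); none of these function names are PyTorch APIs:

```python
# Toy single-process simulation of the FSDP per-module cycle.
# Lists stand in for tensors; functions stand in for collectives.

def all_gather(shards):
    # Every rank reconstructs the full parameter from all shards.
    full = [x for shard in shards for x in shard]
    return [full[:] for _ in shards]  # one full copy per rank

def reduce_scatter(full_grads, world_size):
    # Sum gradients elementwise across ranks, then each rank
    # keeps only its own shard of the result.
    summed = [sum(g) for g in zip(*full_grads)]
    shard_len = len(summed) // world_size
    return [summed[i * shard_len:(i + 1) * shard_len]
            for i in range(world_size)]

world_size = 2
shards = [[1.0, 2.0], [3.0, 4.0]]           # step 1: params stay sharded
full_params = all_gather(shards)             # step 2: gather before compute
grads = [[p * 0.5 for p in full_params[r]]   # stand-in for forward/backward
         for r in range(world_size)]
full_params = None                           # step 3: release the full copy
grad_shards = reduce_scatter(grads, world_size)  # step 4: shard the reduced grads
print(grad_shards)
```

Note that after step 4 each rank holds exactly the gradient shard matching its parameter shard, which is what lets the optimizer state stay sharded too.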

Because only the modules currently computing need their full parameters materialized, this can dramatically reduce peak memory.

The tradeoff

Advantages:

  • much better memory efficiency for large models
  • strong integration inside PyTorch workflows

Costs:

  • more communication
  • sensitivity to wrapping granularity
  • more runtime complexity than plain DDP
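The granularity sensitivity is worth making concrete. How you group layers into wrapped units determines both how much gets gathered at once (peak memory) and how many collectives are issued (overhead and overlap opportunity). A minimal sketch of that tradeoff, with hypothetical layer sizes and a toy grouping function, not PyTorch's wrapping API:

```python
def gather_stats(layer_sizes, layers_per_unit):
    """For a given wrapping granularity, return the peak number of
    parameters gathered at once and the number of all-gather calls.
    Illustrative model only, not PyTorch's auto-wrap machinery."""
    units = [layer_sizes[i:i + layers_per_unit]
             for i in range(0, len(layer_sizes), layers_per_unit)]
    peak = max(sum(unit) for unit in units)  # largest unit gathered whole
    return peak, len(units)                  # (peak params, collective count)

layers = [10, 10, 10, 10, 10, 10, 10, 10]   # hypothetical per-layer param counts
print(gather_stats(layers, 1))  # fine-grained: low peak, many small collectives
print(gather_stats(layers, 8))  # coarse: one huge gather, minimal overhead
```

Fine-grained wrapping minimizes peak memory but issues many small, latency-bound collectives; coarse wrapping amortizes launch overhead but gathers more at once. Real configurations sit in between, which is exactly why FSDP is sensitive to this knob.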

So FSDP is often the right answer to memory pressure, but not a free performance win.

The next post focuses on communication overlap, because once memory is under control, reducing exposed communication time becomes a major optimization target.