FSDP를 DDP의 대체제로만 보면 부족하다

FSDP(Fully Sharded Data Parallel)는 이름만 보면 DDP의 변형처럼 보이지만, 메모리 관점에서는 꽤 다른 철학을 가진다. 핵심은 전체 파라미터를 shard한 상태로 두고, 각 구간 계산에 필요한 순간에만 full parameter를 materialize하는 것이다.

즉 목표는 분명하다.

  • parameter replication 줄이기
  • gradient와 optimizer state도 shard하기
  • 큰 모델을 더 적은 per-rank memory로 학습하기

어떤 흐름으로 동작하는가

고수준에서 보면 FSDP는 다음과 비슷한 흐름을 가진다.

  1. 모듈 단위로 파라미터를 shard 상태로 보관한다
  2. forward 직전에 필요한 파라미터를 all-gather한다
  3. 연산 후 가능하면 다시 shard 상태로 되돌린다
  4. backward에서도 필요한 시점에 gather/reduce-scatter를 수행한다

이 구조 덕분에 peak memory를 크게 낮출 수 있다.

장점과 비용

장점:

  • 큰 모델도 per-rank 메모리 한계 안에서 돌리기 쉬워진다
  • PyTorch ecosystem 안에서 통합하기 좋다

비용:

  • gather/scatter 통신이 증가한다
  • wrapping granularity가 성능에 영향을 준다
  • auto-wrap 정책을 잘못 잡으면 효율이 떨어진다

즉 FSDP는 메모리 문제를 잘 푸는 대신 runtime 설계가 더 섬세해진다.

어떤 상황에서 유리한가

  • 모델이 rank당 메모리에 빠듯하다
  • full replication이 너무 비싸다
  • PyTorch 중심 스택 안에서 운영하고 싶다

반대로 intra-node / inter-node 통신이 매우 불리하면 기대만큼 이득이 안 나올 수 있다.

다음 글에서는 compute와 communication을 겹치는 overlap 기법을 본다. 이 시점부터는 "메모리 절감"만으로는 충분하지 않고, step 시간을 어떻게 줄일지로 초점이 이동한다.