분산 LLM 학습 15 - FSDP는 DDP와 무엇이 다르고 언제 유리한가
FSDP는 전체 파라미터를 shard한 채 필요할 때만 모아 쓰는 방식으로 메모리 문제를 직접 겨냥한다
FSDP를 DDP의 대체제로만 보면 부족하다
FSDP(Fully Sharded Data Parallel)는 이름만 보면 DDP의 변형처럼 보이지만, 메모리 관점에서는 꽤 다른 철학을 가진다. 핵심은 전체 파라미터를 shard한 상태로 두고, 각 구간 계산에 필요한 순간에만 full parameter를 materialize하는 것이다.
즉 목표는 분명하다.
- parameter replication 줄이기
- gradient와 optimizer state도 shard하기
- 큰 모델을 더 적은 per-rank memory로 학습하기
어떤 흐름으로 동작하는가
고수준에서 보면 FSDP는 다음과 비슷한 흐름을 가진다.
- 모듈 단위로 파라미터를 shard 상태로 보관한다
- forward 직전에 필요한 파라미터를 all-gather한다
- 연산 후 가능하면 다시 shard 상태로 되돌린다
- backward에서도 필요한 시점에 gather/reduce-scatter를 수행한다
이 구조 덕분에 peak memory를 크게 낮출 수 있다.
장점과 비용
장점:
- 큰 모델도 per-rank 메모리 한계 안에서 돌리기 쉬워진다
- PyTorch ecosystem 안에서 통합하기 좋다
비용:
- gather/scatter 통신이 증가한다
- wrapping granularity가 성능에 영향을 준다
- auto-wrap 정책을 잘못 잡으면 효율이 떨어진다
즉 FSDP는 메모리 문제를 잘 푸는 대신 runtime 설계가 더 섬세해진다.
어떤 상황에서 유리한가
- 모델이 rank당 메모리에 빠듯하다
- full replication이 너무 비싸다
- PyTorch 중심 스택 안에서 운영하고 싶다
반대로 intra-node / inter-node 통신이 매우 불리하면 기대만큼 이득이 안 나올 수 있다.
다음 글에서는 compute와 communication을 겹치는 overlap 기법을 본다. 이 시점부터는 "메모리 절감"만으로는 충분하지 않고, step 시간을 어떻게 줄일지로 초점이 이동한다.