분산 LLM 학습 14 - ZeRO Stage 1, 2, 3는 각각 무엇을 없애는가
ZeRO는 하나의 기술이 아니라 어떤 메모리 복제를 줄일 것인지 단계적으로 선택하는 체계다
ZeRO를 단계별로 보는 이유
ZeRO를 처음 접하면 stage 수가 많아서 복잡해 보인다. 하지만 본질은 단순하다. data parallel에서 rank마다 중복 보유하던 상태를 어디까지 나눌 것인가를 단계적으로 확장한 것이다.
stage 1: optimizer state sharding
가장 먼저 나누는 것은 optimizer state다. Adam 계열에서는 이 항목이 특히 크기 때문에 효과가 빠르게 나타난다.
좋은 점:
- 구현 부담이 상대적으로 낮다
- optimizer state 메모리를 줄일 수 있다
한계:
- gradients와 parameters는 여전히 복제된다
stage 2: gradient sharding
다음 단계는 gradient까지 나누는 것이다. backward 뒤에 전체 gradient를 각 rank가 다 들고 있을 필요가 없으므로 추가 메모리 절감이 가능해진다.
하지만 이 시점부터 통신 패턴과 optimizer step 동작을 더 신중히 봐야 한다.
stage 3: parameter sharding
가장 강한 단계에서는 parameter까지 shard한다. 이때부터는 data parallel의 가장 큰 약점인 full parameter replication을 직접 줄일 수 있다.
대신 대가도 크다.
- forward/backward 중 parameter gather가 필요하다
- runtime이 훨씬 복잡해진다
- 통신과 메모리의 균형을 잘 맞춰야 한다
ZeRO를 선택할 때의 질문
- 현재 메모리 병목의 주범이 무엇인가
- stage를 올릴수록 통신 증가를 감당할 수 있는가
- 모델/optimizer/cluster 조합에서 운영 복잡성이 허용되는가
즉 ZeRO는 "무조건 stage 3가 최고"라는 식으로 보면 안 된다. 어떤 상태를 중복 보유하는 것이 더 싸고, 무엇을 줄이는 것이 더 큰 이익인지 봐야 한다.
왜 FSDP와 이어서 봐야 하는가
ZeRO stage 3와 FSDP는 parameter sharding이라는 큰 방향을 공유한다. 하지만 구현 감각과 integration 방식에는 차이가 있다. 다음 글에서 FSDP를 별도로 보는 이유가 여기에 있다.