ZeRO를 단계별로 보는 이유

ZeRO를 처음 접하면 stage 수가 많아서 복잡해 보인다. 하지만 본질은 단순하다. data parallel에서 rank마다 중복 보유하던 상태를 어디까지 나눌 것인가를 단계적으로 확장한 것이다.

stage 1: optimizer state sharding

가장 먼저 나누는 것은 optimizer state다. Adam 계열에서는 이 항목이 특히 크기 때문에 효과가 빠르게 나타난다.

좋은 점:

  • 구현 부담이 상대적으로 낮다
  • optimizer state 메모리를 줄일 수 있다

한계:

  • gradients와 parameters는 여전히 복제된다

stage 2: gradient sharding

다음 단계는 gradient까지 나누는 것이다. backward 뒤에 전체 gradient를 각 rank가 다 들고 있을 필요가 없으므로 추가 메모리 절감이 가능해진다.

하지만 이 시점부터 통신 패턴과 optimizer step 동작을 더 신중히 봐야 한다.

stage 3: parameter sharding

가장 강한 단계에서는 parameter까지 shard한다. 이때부터는 data parallel의 가장 큰 약점인 full parameter replication을 직접 줄일 수 있다.

대신 대가도 크다.

  • forward/backward 중 parameter gather가 필요하다
  • runtime이 훨씬 복잡해진다
  • 통신과 메모리의 균형을 잘 맞춰야 한다

ZeRO를 선택할 때의 질문

  • 현재 메모리 병목의 주범이 무엇인가
  • stage를 올릴수록 통신 증가를 감당할 수 있는가
  • 모델/optimizer/cluster 조합에서 운영 복잡성이 허용되는가

즉 ZeRO는 "무조건 stage 3가 최고"라는 식으로 보면 안 된다. 어떤 상태를 중복 보유하는 것이 더 싸고, 무엇을 줄이는 것이 더 큰 이익인지 봐야 한다.

왜 FSDP와 이어서 봐야 하는가

ZeRO stage 3와 FSDP는 parameter sharding이라는 큰 방향을 공유한다. 하지만 구현 감각과 integration 방식에는 차이가 있다. 다음 글에서 FSDP를 별도로 보는 이유가 여기에 있다.