분산 학습 디버깅은 왜 더 어려울까

단일 프로세스 디버깅과 달리 분산 환경에서는 한 rank의 문제로 전체가 멈춘다. 그리고 겉으로 보이는 증상은 비슷하다.

  • 모두 멈춘 것처럼 보인다
  • timeout이 발생한다
  • 어떤 rank만 OOM이 난다
  • 한 rank가 더 빨리 끝나 collective에서 desync가 난다

즉 증상보다 구조를 먼저 읽어야 한다.

가장 먼저 분리해야 할 축

  • correctness 문제인가
  • performance 문제인가
  • 시스템 환경 문제인가

예를 들어 timeout은 모델 코드 bug일 수도 있고, 통신 토폴로지 문제일 수도 있고, dataloader 병목 때문에 한 rank가 늦어진 결과일 수도 있다.

자주 보는 유형

1. collective mismatch

어떤 rank는 all-reduce에 들어갔는데 다른 rank는 아직 다른 연산 중이면 deadlock처럼 보이는 상태가 나온다.

2. uneven input / straggler

데이터 길이 분포가 고르지 않거나 I/O가 흔들리면 한 rank만 늦어지고 나머지가 collective에서 기다린다.

3. memory fragmentation / rank-local OOM

전체 평균 메모리는 괜찮아도 특정 rank에서만 peak가 다르게 나올 수 있다.

디버깅 순서

  1. 가장 작은 world size에서 재현되는지 본다
  2. deterministic하게 재현되는지 본다
  3. 어느 rank에서 먼저 멈추는지 로그와 timeline을 본다
  4. collective 직전/직후에 상태를 좁힌다
  5. data path와 communication path를 분리해서 본다

이 순서가 잡혀 있으면 문제를 훨씬 빠르게 줄일 수 있다.

다음 글에서는 Megatron-LM, DeepSpeed 같은 프레임워크를 구조적으로 읽는 법을 본다. 내부 추상화를 어떻게 나눴는지 알면 코드를 덜 두려워하게 된다.