분산 LLM 학습 16 - Communication Overlap은 어떻게 step 시간을 숨기는가
분산 학습 최적화의 핵심은 통신을 없애는 것이 아니라 계산 뒤에 숨어서 끝나게 만드는 데 있다
통신을 줄이기만 해서는 부족하다
large-scale training에서 통신을 완전히 없앨 수는 없다. 결국 현실적인 목표는 통신이 계산을 기다리게 하지 않도록 만드는 것이다. 이것이 overlap의 핵심이다.
어떤 겹침이 가능한가
대표적으로 다음 같은 조합이 있다.
- backward와 gradient all-reduce 겹치기
- parameter prefetch와 다음 연산 준비 겹치기
- reduce-scatter와 일부 optimizer 준비 작업 겹치기
DDP의 bucket도, FSDP의 prefetch도, 여러 프레임워크의 hook 기반 런타임도 결국 이 겹침을 만들기 위해 존재한다.
overlap이 잘 되지 않는 이유
말은 쉽지만 실제로는 다음 문제 때문에 overlap 효율이 낮다.
- bucket이 너무 늦게 준비된다
- kernel 길이가 너무 짧아 통신을 가릴 compute가 없다
- 작은 collective가 너무 많다
- straggler rank 때문에 일부 rank는 이미 기다리고 있다
즉 overlap은 런타임 스케줄링, 모델 구조, 네트워크 상태가 함께 맞아야 효과가 난다.
profiler에서 어떻게 보이나
좋은 overlap은 timeline에서 compute kernel과 NCCL kernel이 서로 빈틈 없이 섞여 보인다. 반대로 나쁜 overlap은 backward 뒤에 큰 통신 덩어리가 따로 붙어 있는 식으로 보인다.
이 차이는 step time에 매우 크게 반영된다.
실무적인 조정 포인트
- bucket size
- parameter ordering
- prefetch 시점
- accumulation step 수
- process placement
이 요소들은 전부 "언제 무엇이 ready 되고, 언제 통신을 시작할 수 있는가"와 연결된다.
다음 글에서는 학습이 길어질수록 중요해지는 checkpoint와 장애 복구를 본다. 큰 분산 학습은 잘 도는 것만큼 다시 살릴 수 있는지가 중요하다.