분산 LLM 학습 20 - 실제 LLM 학습 스택을 설계하는 순서
분산 학습 전략은 멋있는 기법을 조합하는 일이 아니라 현재 병목에 맞는 최소 구조를 선택하고 검증하는 일이다
마지막에는 구조를 조합해야 한다
이 시리즈에서 개별 기법을 하나씩 보았지만, 실제 시스템에서는 그것들이 한꺼번에 나타난다. 그래서 마지막에는 "무엇부터 도입할 것인가"라는 설계 순서가 중요하다.
좋은 설계 순서
1. 단일 GPU 기준선을 먼저 만든다
모델, optimizer, batch, sequence length, throughput, memory peak를 먼저 측정한다. 기준선이 없으면 분산 이후 무엇이 좋아지고 나빠졌는지 읽을 수 없다.
2. data parallel로 가장 먼저 확장한다
모델이 한 GPU에 들어간다면 먼저 DDP로 간다. 이 단계에서 dataloader, global batch, all-reduce 비용, scaling efficiency를 본다.
3. 메모리 병목을 항목별로 분해한다
OOM이나 낮은 batch의 원인이 activation인지, optimizer state인지, parameter replication인지 분해한다. 여기서 checkpointing, ZeRO, FSDP, tensor parallel 중 무엇이 필요한지 결정한다.
4. topology를 고려한 병렬화 조합을 고른다
intra-node가 빠르고 inter-node가 느리다면 tensor parallel은 node 내부에 묶고, data parallel을 node 간으로 두는 식의 설계가 자주 나온다. pipeline parallel도 stage 배치와 연결된다.
5. checkpoint와 운영 경로를 초기에 만든다
resume가 불안정한 상태에서 큰 학습을 시작하면 비용이 커진다.
포트폴리오 관점에서 무엇을 만들면 좋은가
이 시리즈를 공부했다면 다음 같은 결과물이 좋다.
- 작은 transformer에 DDP와 tensor parallel을 각각 붙여 차이를 기록한 실험
- activation checkpointing / ZeRO / FSDP 선택 기준을 정리한 메모리 분석 글
- 작은 multi-node run에서 NCCL timeline과 병목을 분석한 보고서
이런 결과물은 단순한 요약보다 훨씬 설득력이 있다.
이 시리즈의 핵심 정리
Distributed Training Engineer에게 중요한 것은 특정 프레임워크 사용법보다 더 근본적인 감각이다.
- 계산, 메모리, 통신 중 무엇이 병목인지 구분하는 능력
- 병렬화 전략이 어떤 비용을 없애고 어떤 비용을 새로 만드는지 읽는 능력
- 시스템이 중단되거나 느려질 때 어디부터 의심해야 하는지 아는 능력
이 감각이 생기면 프레임워크가 달라져도 훨씬 빨리 적응할 수 있다.
다음 단계로는 PyTorch internals, GPU systems, 그리고 실제 Megatron/DeepSpeed 코드 읽기를 함께 묶어 보는 편이 좋다. 분산 학습은 결국 커널, 런타임, 프레임워크, 인프라가 한 시스템 안에서 만나는 지점이기 때문이다.