분산 LLM 학습 01 - 왜 LLM 학습은 분산 시스템 문제가 되는가
여러 GPU를 붙이는 순간 학습 코드는 계산만의 문제가 아니라 메모리와 통신, 장애 복구까지 포함한 시스템 문제가 된다
data parallel부터 tensor parallel, FSDP, ZeRO, 그리고 현대 LLM 학습 프레임워크까지
단일 GPU 학습에서 멀티 GPU와 대규모 모델 학습 시스템으로 넘어가고 싶은 독자.
기본적인 딥러닝 학습 경험, GPU 사용 경험, 그리고 학습 루프가 어떻게 돌아가는지에 대한 일반적인 이해.
여러 GPU를 붙이는 순간 학습 코드는 계산만의 문제가 아니라 메모리와 통신, 장애 복구까지 포함한 시스템 문제가 된다
가장 기본적인 분산 학습 방식인 data parallel은 단순해 보이지만 gradient 동기화와 메모리 복제 비용을 함께 안고 있다
분산 학습에서 가장 자주 등장하는 collective인 all-reduce를 이해해야 gradient synchronization 비용을 제대로 읽을 수 있다
DDP는 단순 래퍼가 아니라 autograd hook, gradient bucket, process group을 사용해 동기화를 조직하는 런타임이다
GPU 수를 늘리는 일은 단순한 throughput 증가가 아니라 optimizer가 보는 batch 의미를 바꾸는 일이다
파라미터만 보는 순간 분산 학습 판단을 잘못하게 된다. activation, gradient, optimizer state를 함께 봐야 한다
분산 학습 성능은 GPU 개수보다 GPU들이 어떤 링크로 연결되어 있는지에 더 크게 흔들릴 때가 많다
모델이 한 GPU에 안 들어가기 시작하면 더 이상 데이터만 나누는 것으로는 부족하고 연산 자체를 분할해야 한다
tensor parallel은 추상 개념이 아니라 attention projection, output projection, MLP 같은 구체적인 지점에 들어간다
모델 크기만 커지는 것이 아니라 컨텍스트 길이도 길어지면 activation 메모리와 통신 패턴이 다시 달라진다
모델을 레이어 단위로 여러 stage에 나누는 순간 계산 분할뿐 아니라 idle time과 stage imbalance가 핵심 문제가 된다
pipeline parallel의 효율은 레이어 분할보다 schedule 선택에 더 크게 흔들릴 때가 많다
메모리를 아끼기 위해 계산을 다시 하는 전략은 단순한 옵션이 아니라 분산 학습 설계의 중심 선택지다
ZeRO는 하나의 기술이 아니라 어떤 메모리 복제를 줄일 것인지 단계적으로 선택하는 체계다
FSDP는 전체 파라미터를 shard한 채 필요할 때만 모아 쓰는 방식으로 메모리 문제를 직접 겨냥한다
분산 학습 최적화의 핵심은 통신을 없애는 것이 아니라 계산 뒤에 숨어서 끝나게 만드는 데 있다
긴 분산 학습에서는 빠른 step만큼이나 중단 이후 안전하게 이어가는 능력이 중요하다
분산 학습 디버깅은 에러 메시지 읽기보다 어느 rank가 어떤 collective 앞에서 멈췄는지 구조적으로 좁히는 일이다
프레임워크를 이름으로 기억하기보다 어떤 병렬화와 어떤 상태 관리를 추상화하는지로 읽어야 한다
분산 학습 전략은 멋있는 기법을 조합하는 일이 아니라 현재 병목에 맞는 최소 구조를 선택하고 검증하는 일이다