분산 LLM 학습 13 - Activation Checkpointing과 Recomputation의 트레이드오프
메모리를 아끼기 위해 계산을 다시 하는 전략은 단순한 옵션이 아니라 분산 학습 설계의 중심 선택지다
왜 다시 계산하는가
training에서는 backward를 위해 forward 중간 결과를 저장해야 한다. 그런데 activation이 너무 크면 메모리 한계에 먼저 걸린다. activation checkpointing은 일부 중간 결과를 저장하지 않고, backward 시점에 다시 계산해서 메모리를 아끼는 전략이다.
즉 시간을 더 쓰고 메모리를 덜 쓰는 선택이다.
어떤 감각으로 이해해야 할까
checkpointing을 단순히 "메모리 최적화 옵션"으로 보면 얕다. 실제로는 다음 셋 중 무엇이 병목인지 판단하는 문제다.
- memory capacity
- extra compute cost
- schedule complexity
메모리가 절대적으로 부족하면 recomputation이 거의 필수다. 반대로 compute가 이미 꽉 차 있고 메모리는 여유가 있다면 재계산은 손해일 수 있다.
transformer에서는 어디에 적용하나
대개 block 단위 또는 attention/MLP 묶음 단위로 적용한다. 너무 잘게 쪼개면 recomputation 오버헤드와 코드 복잡성이 커지고, 너무 크게 묶으면 메모리 절감이 제한적이다.
즉 checkpoint granularity가 중요하다.
분산 환경에서 더 까다로운 이유
checkpointing은 단일 GPU에서도 중요하지만, 분산 환경에서는 다음 문제가 겹친다.
- pipeline schedule과 activation lifetime이 얽힌다
- tensor parallel로 나뉜 activation을 다시 계산해야 한다
- communication과 recomputation이 함께 들어가면 step time 해석이 어려워진다
그래서 "메모리를 줄였다"만 보면 안 되고, 전체 throughput과 utilization을 함께 봐야 한다.
좋은 판단 기준
- OOM을 피하는 최소 수준의 checkpointing부터 시작한다
- recomputation으로 늘어난 시간이 통신 병목에 가려지는지 확인한다
- 어느 layer가 activation hotspot인지 profiler로 본다
특히 긴 context 학습에서는 checkpointing이 사실상 기본 전제가 되는 경우가 많다.
다음 글에서는 메모리 절감 전략의 대표 주자인 ZeRO를 본다. optimizer state, gradient, parameter를 단계적으로 나누는 방식이 왜 큰 전환점이 되었는지 정리한다.