분산 LLM 학습 09 - Transformer에서 Tensor Parallel이 실제로 들어가는 위치
tensor parallel은 추상 개념이 아니라 attention projection, output projection, MLP 같은 구체적인 지점에 들어간다
추상 개념에서 block 구조로 내려오기
tensor parallel을 실제로 이해하려면 transformer block을 펼쳐놓고 보는 편이 좋다. 분산은 결국 특정 연산을 어떻게 자를지 결정하는 일이다.
transformer block에서 큰 계산은 대체로 다음에 몰려 있다.
- QKV projection
- attention output projection
- MLP up projection
- MLP down projection
이 연산들은 큰 linear layer라서 tensor parallel 후보가 되기 쉽다.
attention 쪽에서 생기는 일
QKV projection을 분할하면 각 rank가 일부 head 또는 일부 hidden slice를 담당하게 된다. 이 구조는 head 단위 분할과도 잘 맞는다. 하지만 attention output을 다음 레이어에 넘길 때는 다시 합치거나 reduce하는 단계가 필요할 수 있다.
즉 attention에서 중요한 질문은:
- head를 어떤 단위로 나눌 것인가
- softmax 이전과 이후 어느 시점에 통신이 필요한가
- sequence 길이가 길어질수록 activation 이동량이 얼마나 커지는가
MLP 쪽에서 생기는 일
MLP는 대개 hidden dimension을 크게 확장했다가 다시 줄인다. 이 구조는 tensor parallel에 특히 잘 맞는다. 첫 번째 projection을 column parallel, 두 번째 projection을 row parallel로 두는 식의 전형적인 패턴이 자주 등장한다.
이 조합이 좋은 이유는 중간 activation을 모든 rank가 다 가질 필요 없이 흘려보낼 수 있는 구조가 나오기 때문이다.
"잘 맞는 연산"과 "그렇지 않은 연산"
tensor parallel은 큰 dense matmul에는 잘 맞지만, 모든 연산이 그런 것은 아니다.
- normalization
- residual add
- 일부 indexing / masking 작업
이런 연산은 오히려 전체 흐름에서 작은 glue 역할을 하며, 통신/shape 정렬 비용을 더 신경 써야 한다.
실전 감각
transformer에서 tensor parallel을 본다는 것은 단순히 메모리를 줄이는 일이 아니다. 각 block마다 다음을 함께 보는 일이다.
- 어느 연산이 로컬로 끝나는가
- 어디서 collective가 필요한가
- activation shape가 rank마다 어떻게 나뉘는가
- backward에서 gradient 흐름이 어떤 통신을 유발하는가
이 감각이 있어야 Megatron류 코드를 읽을 때도 구조가 보인다.
다음 글에서는 tensor parallel의 확장선에 있는 sequence parallel과 context 길이 문제를 본다. 긴 시퀀스에서 activation 메모리와 통신을 어떻게 다뤄야 하는지가 핵심이다.