추상 개념에서 block 구조로 내려오기

tensor parallel을 실제로 이해하려면 transformer block을 펼쳐놓고 보는 편이 좋다. 분산은 결국 특정 연산을 어떻게 자를지 결정하는 일이다.

transformer block에서 큰 계산은 대체로 다음에 몰려 있다.

  • QKV projection
  • attention output projection
  • MLP up projection
  • MLP down projection

이 연산들은 큰 linear layer라서 tensor parallel 후보가 되기 쉽다.

attention 쪽에서 생기는 일

QKV projection을 분할하면 각 rank가 일부 head 또는 일부 hidden slice를 담당하게 된다. 이 구조는 head 단위 분할과도 잘 맞는다. 하지만 attention output을 다음 레이어에 넘길 때는 다시 합치거나 reduce하는 단계가 필요할 수 있다.

즉 attention에서 중요한 질문은:

  • head를 어떤 단위로 나눌 것인가
  • softmax 이전과 이후 어느 시점에 통신이 필요한가
  • sequence 길이가 길어질수록 activation 이동량이 얼마나 커지는가

MLP 쪽에서 생기는 일

MLP는 대개 hidden dimension을 크게 확장했다가 다시 줄인다. 이 구조는 tensor parallel에 특히 잘 맞는다. 첫 번째 projection을 column parallel, 두 번째 projection을 row parallel로 두는 식의 전형적인 패턴이 자주 등장한다.

이 조합이 좋은 이유는 중간 activation을 모든 rank가 다 가질 필요 없이 흘려보낼 수 있는 구조가 나오기 때문이다.

"잘 맞는 연산"과 "그렇지 않은 연산"

tensor parallel은 큰 dense matmul에는 잘 맞지만, 모든 연산이 그런 것은 아니다.

  • normalization
  • residual add
  • 일부 indexing / masking 작업

이런 연산은 오히려 전체 흐름에서 작은 glue 역할을 하며, 통신/shape 정렬 비용을 더 신경 써야 한다.

실전 감각

transformer에서 tensor parallel을 본다는 것은 단순히 메모리를 줄이는 일이 아니다. 각 block마다 다음을 함께 보는 일이다.

  • 어느 연산이 로컬로 끝나는가
  • 어디서 collective가 필요한가
  • activation shape가 rank마다 어떻게 나뉘는가
  • backward에서 gradient 흐름이 어떤 통신을 유발하는가

이 감각이 있어야 Megatron류 코드를 읽을 때도 구조가 보인다.

다음 글에서는 tensor parallel의 확장선에 있는 sequence parallel과 context 길이 문제를 본다. 긴 시퀀스에서 activation 메모리와 통신을 어떻게 다뤄야 하는지가 핵심이다.