PyTorch 내부 구조 12 - Backward 구현 패턴과 저장 전략
backward는 forward의 덧붙임이 아니라 어떤 중간값을 저장하고 어떤 계산을 다시 할지 결정하는 설계 문제다
backward를 어떻게 설계할까
custom op를 만들면 결국 결정해야 한다.
- 어떤 값을 forward에서 저장할 것인가
- 어떤 값은 backward에서 다시 계산할 것인가
- numerical stability는 어디서 지킬 것인가
이 선택은 메모리와 성능을 동시에 좌우한다.
대표 패턴
- 입력/출력을 저장하고 backward에서 활용
- 최소한만 저장하고 일부를 recompute
- fused backward를 별도 kernel로 작성
이 셋은 각각 메모리, 구현 난이도, 성능에서 다른 트레이드오프를 가진다.
다음 글에서는 fused operator를 설계하는 관점으로 넘어간다. 여러 op를 합치는 이유는 단순히 kernel 수를 줄이기 위해서만은 아니다.