backward를 어떻게 설계할까

custom op를 만들면 결국 결정해야 한다.

  • 어떤 값을 forward에서 저장할 것인가
  • 어떤 값은 backward에서 다시 계산할 것인가
  • numerical stability는 어디서 지킬 것인가

이 선택은 메모리와 성능을 동시에 좌우한다.

대표 패턴

  • 입력/출력을 저장하고 backward에서 활용
  • 최소한만 저장하고 일부를 recompute
  • fused backward를 별도 kernel로 작성

이 셋은 각각 메모리, 구현 난이도, 성능에서 다른 트레이드오프를 가진다.

다음 글에서는 fused operator를 설계하는 관점으로 넘어간다. 여러 op를 합치는 이유는 단순히 kernel 수를 줄이기 위해서만은 아니다.