PyTorch 내부 구조 14 - AMP, Autocast, Numerical Stability
custom op가 실제 학습에 들어가려면 mixed precision 환경에서의 dtype 규칙과 안정성까지 고려해야 한다
빠른 operator가 곧 안전한 operator는 아니다
현실의 학습 코드는 mixed precision 위에서 돈다. 따라서 custom op도 다음을 함께 고려해야 한다.
- autocast 하에서 어떤 dtype으로 실행할지
- accumulation은 어떤 precision으로 할지
- overflow / underflow를 어떻게 피할지
자주 놓치는 부분
- fp16 입력에서 reduction stability
- normalization 계열 연산의 epsilon 처리
- forward는 괜찮지만 backward가 더 불안정한 경우
이런 지점에서 operator는 benchmark에서는 빨라도 실제 학습에서는 불안정해질 수 있다.
다음 글에서는 profiling으로 넘어간다. 지금까지 본 내부 구조를 성능 timeline 위에서 읽을 수 있어야 실제 최적화가 가능하다.