[PyTorch] 딥러닝 학습 오류 탐지 : torch.autograd.detect_anomaly
by Heejin Do
torch.autograd는 함수의 자동 미분을 구현하기 위한 클래스들과 함수들을 제공한다.
유용하게 사용 할 수 있는 여러 함수가 있지만, 이 포스트에서는, 이상(오류)이 발생했을 때 그 추적 log를 출력해주는 detect_anomaly 클래스에 대해 정리해보았다.
torch.autograd.detect_anomaly
autograd 엔진에 대한 오류를 감지하는 context manager로, 감지를 활성화 한 상태에서 정방향 패스(forward pass)를 실행했을 때, 역방향 패스(backward pass) 중 오류가 발생한 backward function을 생성한 forward operation의 traceback을 출력해준다.
즉, backward 패스 과정에서 오류가 발생 했을 때, 구체적으로 어떤 파일의 어떤 연산에서 발생했는지 그 traceback을 출력해준다.
단, 감지를 활성화하면 추가 테스트로 인해 프로그램 실행 속도가 느려지기 때문에 디버깅을 위해서만 사용하는 것이 좋다(NaN, inf 가 생성되는지를 검사하기 위해 모든 텐서를 확인하기 때문에 훈련 속도가 느려짐).
torch.autograd.set_detect_anomaly(mode)
mode에 따라 이상 감지를 활성화하거나 비활성화 할 수 있는 context manager. mode로 True를 지정하면 이상 감지를 설정하는 것이고, False를 지정하면 감지 설정을 해제하는 것이다. 즉 mode 인수를 통해 위의 torch.autograd.detect_anomaly
를 활성화할지 비활성화 할지를 설정할 수 있는 클래스이며, torch.autograd.set_detect_anomaly(True)
의 역할은 torch.autograd.detect_anomaly
와 같다.
사용 예시
1. 첫번째 방법 : 전체 코드에 적용
학습 시작 전, 스크립트 제일 위에 추가한다.
2. 두번째 방법 : 특정 부분에 적용
적용하려는 부분을 with autograd.detect_anomaly():
하위에 둔다.
Subscribe via RSS