torch.autograd는 함수의 자동 미분을 구현하기 위한 클래스들과 함수들을 제공한다.

유용하게 사용 할 수 있는 여러 함수가 있지만, 이 포스트에서는, 이상(오류)이 발생했을 때 그 추적 log를 출력해주는 detect_anomaly 클래스에 대해 정리해보았다.

torch.autograd.detect_anomaly

autograd 엔진에 대한 오류를 감지하는 context manager로, 감지를 활성화 한 상태에서 정방향 패스(forward pass)를 실행했을 때, 역방향 패스(backward pass) 중 오류가 발생한 backward function을 생성한 forward operation의 traceback을 출력해준다.

즉, backward 패스 과정에서 오류가 발생 했을 때, 구체적으로 어떤 파일의 어떤 연산에서 발생했는지 그 traceback을 출력해준다.

단, 감지를 활성화하면 추가 테스트로 인해 프로그램 실행 속도가 느려지기 때문에 디버깅을 위해서만 사용하는 것이 좋다(NaN, inf 가 생성되는지를 검사하기 위해 모든 텐서를 확인하기 때문에 훈련 속도가 느려짐).

torch.autograd.set_detect_anomaly(mode)

mode에 따라 이상 감지를 활성화하거나 비활성화 할 수 있는 context manager. mode로 True를 지정하면 이상 감지를 설정하는 것이고, False를 지정하면 감지 설정을 해제하는 것이다. 즉 mode 인수를 통해 위의 torch.autograd.detect_anomaly를 활성화할지 비활성화 할지를 설정할 수 있는 클래스이며, torch.autograd.set_detect_anomaly(True)의 역할은 torch.autograd.detect_anomaly와 같다.

사용 예시

1. 첫번째 방법 : 전체 코드에 적용

학습 시작 전, 스크립트 제일 위에 추가한다.

torch.autograd.set_detect_anomaly(True)

# 아래부턴 실행하려는 기존 코드~~

2. 두번째 방법 : 특정 부분에 적용

적용하려는 부분을 with autograd.detect_anomaly(): 하위에 둔다.

with torch.autograd.detect_anomaly():
        input = torch.rand(5, 10, requires_grad=True)
        output = function_A(input)
        output.backward()

# 아래부턴 실행하려는 기존 코드~~~