이 포스트에서는 Pytorch library에서 forward() 중에 NaN이 뜨는 경우와 loss.backward()를 하고나면 NaN이 발생하는 경우를 다룹니다.
Forward propagation 중에 NaN발생
nan은 Not a number의 준말이다.
nan이 뜨는 이유는 많이 있겠지만 다차원텐서 연산중에 날 수 있는 가장 많은 케이스는 바로 0으로 나누는 것이다.
A_norm = torch.unsqueeze(torch.norm(A_vec, dim=2), 2)
affinity_graph.append(torch.div(torch.div(num, A_norm).transpose(1, 2), A_norm).transpose(1,2)
나 같은 경우에 분석을 거치고 거치다 결국 위와 같은 코드에서 nan이 뜨는것을 발견했다.
num을 A_norm으로 행,열을 따라 각각 나눠주는 과정인데 문제는 A_norm의 원소중에 0이 포함되있는걸 몰랐다는 것이다.
따라서 해결책은?
A_norm = torch.unsqueeze(torch.norm(A_vec, dim=2), 2) + 1e-6
affinity_graph.append(torch.div(torch.div(num, A_norm).transpose(1, 2), A_norm).transpose(1,2)
다음과 같이 A_norm에 아주 작은 수를 더해서 0이 나올 수 없게 만들어 주는 방법이 있다.
여기서 더해주는 수는 1e-6보다 작아도 된다. 다만 A_norm의 값에 영향을 안줄정도로 작게 주는것이 좋다.
Backward propagation 중에 NaN발생
정확한 이유는 아직도 모르지만 나는 torch.norm()함수를 거칠때 NaN이 뜨기도하고 안뜨기도 했었다.
이해가 안가는 부분은 torch.norm()을 거쳐도 여태까지는 잘만 학습돼왔는데 갑자기 nan이 발생하는 경우가 있다는 것이다.
forward에서 nan이 뜨는 경우도 디버깅이 어렵겠지만 backward에서 nan이 뜨는 경우가 디버깅이 더 어렵다. 정확히 어느부분에서 에러가 발생하는지 알 수 없기 때문이다.
인터넷에서 검색해봤을 때는 크게 2가지가 backward중에 nan을 발생시켰다.
첫째는 exp함수를 거쳤을 때 발생할 수도 있다는 글을 몇개 보았으나 나 같은 경우에는 그게 문제가 아닌듯 했다.
둘째는 sqrt함수를 사용했을 때다. sqrt를 사용안하는데 왜 문제가 발생하지? 라고 생각하던차에 생각난것이 torch.norm()함수였다.
이 함수는 tensor의 l2 norm을 계산하는데 l2 norm계산 과정에는 sqrt사용이 필수적이기 때문이다.
수정하기 전 코드는 다음과 같다.
pred_norm = torch.norm(prediction,dim=(2,3),keepdim=True) + 1e-10 # [N,C,1,1]
assert pred_norm.shape == (nb, nc, 1, 1)
# 각 채널마다 정규화 시켜준다. 즉, i번 채널의 feature map은 해당 feature map의 norm으로 정규화 된다.
prediction = prediction / pred_norm
assert prediction.shape == (nb, nc, H,W)
tar_norm = torch.norm(target, dim=(2, 3),keepdim=True) + 1e-10 # [N,C,1,1]
assert tar_norm.shape == (nb, nc, 1, 1)
# 각 채널마다 정규화 시켜준다. 즉, i번 채널의 feature map은 해당 feature map의 norm으로 정규화 된다.
target = target / tar_norm
assert target.shape == (nb, nc, H, W)
혹시 될수도 있지 않을까 하는 마음에 l2 norm을 계산하는 부분을 직접 수정해서 다음과 같이 만들었다.
# L2 loss를 계산한다.
pred_l2_norm = torch.sqrt(prediction.pow(2).sum(dim=(2,3),keepdim=True) + 1e-6)
assert pred_l2_norm.shape == (nb, nc, 1, 1)
prediction = prediction / pred_l2_norm
assert prediction.shape == (nb, nc, H, W)
tar_l2_norm = torch.sqrt(target.pow(2).sum(dim=(2,3),keepdim=True) + 1e-6)
assert tar_l2_norm.shape == (nb, nc, 1, 1)
# 각 채널마다 정규화 시켜준다. 즉, i번 채널의 feature map은 해당 feature map의 norm으로 정규화 된다.
target = target / tar_l2_norm
assert target.shape == (nb, nc, H, W)
결과적으로는 학습이 아주 잘되는 이상한 현상을 발견했다(;;)
2021/03/13/토/새벽코딩...
back propagation중에 nan뜨는 에러가 다시 발생하여 또 다시 에러 케이스와 해결책을 첨부한다.
우선 back propagation중에 nan이 떴다는 것을 알 수 있는 방법은 가장 처음 epoch에서는 nan이 뜨지 않았다는 점이다. 만약 forward과정에서 문제가 있었다면 가장 첫 epoch부터 nan이 발생하여야 한다.
자 이제 바로 본론으로 넘어가서 어느 부분이 문제였는지 살펴보자.
이번에도 어김없이 sqrt가 문제였다!!!! (back propagation nan범인은 exp, sqrt 둘중 하나!!!)
문제의 코드
자, 위의 캡쳐만 보고도 어디를 어떻게 고쳐야하는지 감이 오는가? (만약 바로 알겠다면 당신은 적어도 중수이상의 실력자 ㅎㅎ)
sqrt를 해주는 대상이 0이 나오지 않게 하는것이 중요하다.
따라서 sqrt안에 매우매우 작은 값을 넣어줘서 0이 나오지 않게하고 동시에 계산에 큰 영향을 미치지 않도록 한다.
해결 코드