본문 바로가기

연구/pytorch

[pytorch] Loading pretrained weight 가 제대로 됐는지 확인하는 법

반응형

이미 이 블로그에는 pretrained weight 가 제대로 불렸는지 확인하는 예시코드를 올린적 있다.

[https://powerofsummary.tistory.com/277]

 

하지만 기존에 올렸던 포스팅에서는 weights를 그대로 출력해서 그냥 눈으로 숫자들을 하나하나 확인해야만 하는 코드였다.

근데 그것마저도 너무 귀찮아서 그냥 쪼금 더 세련된 코드를 올려서 내가 나중에 다시 써먹으려고 한다.

(매우 세련된 코드는 아님;)

 

2개의 code 를 추가해야하는데, 각각 pretrained weight 를 가져오는 코드의 앞 뒤에 붙여준다. 코드는 아래와 같다.

# For checking loaded pretrained weight1
    weight_dict = {}
    for i, (name, param) in enumerate(model.named_parameters()):
        weight_dict[name] = param.detach().clone()

# ** Load pretrained weights **
# ** Load pretrained weights **
# ** Load pretrained weights **

# For checking loaded pretrained weight2
    for i, (name, param) in enumerate(model.named_parameters()):
        if torch.sum(weight_dict[name] - param) == 0:
            print(f'{name} : Not changed')
        else:
            print(f'{name} : changed')

    set_trace()

정리하자면 weights를 부르기 전 model 의 weight 와 pretrained weights를 부른 이후의 model weight 를 비교하는 코드다. 전과 후의 값의 차이가 0이면 weights가 불리지 않은 것이고, 0이 아니면 weights 가 새로 업데이트 됐다는 사실을 이용한 코드이다.

 

반응형