반응형
이미 이 블로그에는 pretrained weight 가 제대로 불렸는지 확인하는 예시코드를 올린적 있다.
[https://powerofsummary.tistory.com/277]
하지만 기존에 올렸던 포스팅에서는 weights를 그대로 출력해서 그냥 눈으로 숫자들을 하나하나 확인해야만 하는 코드였다.
근데 그것마저도 너무 귀찮아서 그냥 쪼금 더 세련된 코드를 올려서 내가 나중에 다시 써먹으려고 한다.
(매우 세련된 코드는 아님;)
2개의 code 를 추가해야하는데, 각각 pretrained weight 를 가져오는 코드의 앞 뒤에 붙여준다. 코드는 아래와 같다.
# For checking loaded pretrained weight1
weight_dict = {}
for i, (name, param) in enumerate(model.named_parameters()):
weight_dict[name] = param.detach().clone()
# ** Load pretrained weights **
# ** Load pretrained weights **
# ** Load pretrained weights **
# For checking loaded pretrained weight2
for i, (name, param) in enumerate(model.named_parameters()):
if torch.sum(weight_dict[name] - param) == 0:
print(f'{name} : Not changed')
else:
print(f'{name} : changed')
set_trace()
정리하자면 weights를 부르기 전 model 의 weight 와 pretrained weights를 부른 이후의 model weight 를 비교하는 코드다. 전과 후의 값의 차이가 0이면 weights가 불리지 않은 것이고, 0이 아니면 weights 가 새로 업데이트 됐다는 사실을 이용한 코드이다.
반응형
'연구 > pytorch' 카테고리의 다른 글
[pytorch] model 학습 중 nan이 뜨는 원인2 (0) | 2023.09.21 |
---|---|
[pytorch] Pretrained model의 일부 weights만 가져오기 (0) | 2023.04.07 |
[Transformer] torch.nn.MultiheadAttention 모듈의 mask 인자 개념 (0) | 2023.01.20 |
[npz] npz 데이터에서 keys 확인하기 (0) | 2022.07.19 |
[model freeze] layer 의 일부만 freeze 하기 (2) | 2022.02.13 |