본문 바로가기

연구/pytorch

[pytorch] Pretrained model의 일부 weights만 가져오기

반응형

Pretrained model과 fine-tuning 할 모델의 구조가 다음과 같이 생겼다고 가정한다.

여기서 pretrained model의 backbone weights만 가져와서 학습하고 싶다면 코드를 어떻게 짜야할까?
별로 어렵지는 않지만 매번 새로 짜기 귀찮아서 과거에 내가 짰던 코드를 올려두고 계속 참고하고자 한다.

우선 pretrained model의 checkpoints를 불러와서 어떻게 생겼는지 확인해본다.

state_dict = torch.load(args.resume_checkpoint, map_location=cpu_device)
print(state_dict.keys())

결과화면

(['backbone.0.weight', 'backbone.1.weight', 'backbone.1.bias', 'backbone.1.running_mean', 'backbone.1.running_var', 'backbone.1.num_batches_tracked', 'backbone.4.0.conv1.weight', 'backbone.4.0.bn1.weight', 'backbone.4.0.bn1.bias', 
...
...
...
...
'transformer_2.decoder.layers.0.linear2.weight', 'transformer_2.decoder.layers.0.linear2.bias', 'transformer_2.decoder.layers.0.norm1.weight', 'transformer_2.decoder.layers.0.norm1.bias', 'transformer_2.decoder.layers.0.norm2.weight', 'transformer_2.decoder.layers.0.norm2.bias', 'transformer_2.decoder.layers.0.norm3.weight', 'transformer_2.decoder.layers.0.norm3.bias', 'dim_reduce_enc_cam.weight', 'dim_reduce_enc_cam.bias', 'dim_reduce_enc_img.weight', 'dim_reduce_enc_img.bias', 'dim_reduce_dec.weight', 'dim_reduce_dec.bias', 'cam_token_embed.weight', 'joint_token_embed.weight', 'joint2d_regressor.weight', 'joint2d_regressor.bias', 'cam_predictor.weight', 'cam_predictor.bias', 'conv_1x1.weight', 'conv_1x1.bias'])

이제 이 weight들 중에서 backbone. 으로 시작하는 weights 들만 fine-tuning할 모델에 넣을려고 한다.
weights는 아래의 코드로 넣을 수 있다.

model.backbone.load_state_dict(state_dict, strict=False)

근데 이 코드를 바로 돌리면 문제가 생기는게, 위의 state_dict 구성물들은 보면 'backbone.' 으로 파라미터 이름이 시작 되지만 model.backbone.load_state_dict 로 바로 넣으려면 'backbone.'이라는 이름을 없애줘야 한다.

따라서 state_dict의 key들의 이름을 아래의 코드로 먼저 정리해준다.

for key in list(state_dict.keys()):
    new_key = key.replace("backbone.", "")
    state_dict[new_key] = state_dict.pop(key)

아래는 정리된 실행 코드이다.

state_dict = torch.load(args.resume_checkpoint, map_location=cpu_device)

for key in list(state_dict.keys()):
    new_key = key.replace("backbone.", "")
    state_dict[new_key] = state_dict.pop(key)

model.backbone.load_state_dict(state_dict, strict=False)

근데 코드를 짜는 단계에서는 내 코드가 원하는대로 돌아가는지 확인하는게 중요하므로, 정말 model.backbone의 파라미터가 바뀌었는지까지 확인해주자.
아래는 weight initialize하기 앞뒤로 모델 파라미터의 일부를 출력까지 하는 코드다.

# 여기서 backbone 의 weight출력해보고,
for i, (name, param) in enumerate(model.backbone.named_parameters()):
    if i < 2:
        print(f'[{name}] : {param}')

set_trace()

for key in list(state_dict.keys()):
    new_key = key.replace("backbone.", "")
    state_dict[new_key] = state_dict.pop(key)

model.backbone.load_state_dict(state_dict, strict=False)

# 여기서 다시 backbone의 weight를 출력해본다.
for i, (name, param) in enumerate(model.backbone.named_parameters()):
    if i < 2:
        print(f'[{name}] : {param}')

set_trace()

참고 링크 :
https://powerofsummary.tistory.com/264

반응형