본문 바로가기

연구/pytorch

[model 저장] 훈련중인/훈련이 완료된 모델 저장하기

반응형

실험을 하다보면 훈련이 완료되거나 훈련중인 모델을 저장해야할 일이 생긴다. 

아래는 모델을 저장하는 예시코드이다.

res20 = ResNet(20)
res20.to("cuda:0")

train_net(res20,trainloader, testloader, n_iter=80 ,device="cuda:0", lr=0.1, train_err=res20.train_err, val_err=res20.test_err)

# res20 저장
torch.save(res20,'./model_res20')

 

아래는 저장한 모델을 불러오는 코드이다.

모델을 불러오면 바로 그 상태에서 training을 계속할 수 도 있고 평가할 수 도 있으며, 모델에 구현되어 있는 기능을 자유롭게 사용할 수 도 있다.

model = my_resnet.ResNet(20) # load 하기전에 먼저 선언을 해야한다.
model = torch.load('./res20_model') # load 한다.
print(model.test_err)
반응형