본문 바로가기

연구/pytorch

[transfer learning] 전이학습할 때 신경 쓸 부분

반응형

전이학습할 때 신경쓸 부분은 파라미터 업데이트할 부분과 업데이트하지 않을 부분을 지정하는 것이다.

신경써줄 부분은 크게 두 부분이다.

1. Network 설계할 때

2. 학습할 때

 

1. network 설계할 때 :

아래는 ResNet에서 FC만 파라미터 업데이트를 원할 때의 소스코드이다.

net = models.resnet18(pretrained=True)

# 모든 파라미터를 미분대상에서 제외
for p in net.parameters():
    p.requires_grad = False

fc_input_dim = net.fc.in_features
net.fc = nn.Linear(fc_input_dim,2)

 

 

2. 학습할 때

아래 코드에서는 fc만 업데이트를 원하기 때문에 아래처럼 optimizer를 통해서 fc의 파라미터만 업데이트한다고 지정해주어야 한다.

 


def train_net(net,train_loader,test_loader,only_fc = True,
              optimizer_cls = optim.Adam, loss_fn=nn.CrossEntropyLoss(), n_iter=10,device="cpu"):
    train_loss=[]
    train_acc=[]
    test_acc=[]

    if only_fc:
        optimizer = optimizer_cls(net.fc.parameters())
    else:
        optimizer = optimizer_cls(net.parameters())

    for epoch in range(n_iter):
        net.train()
        cnt = 0
        model_accuracy = 0.0
        running_loss = 0.0

        for i, (xx,yy) in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
            xx=xx.to(device)
            yy=yy.to(device)
            print(type(xx))
            h=net(xx)
            loss = loss_fn(h,yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, y_pred = h.max(1)
            cnt += len(xx)
            model_accuracy += (yy==y_pred).float().sum().item()
            running_loss += loss

        train_loss.append(running_loss/i)
        train_acc.append(model_accuracy/cnt)
        test_acc.append(eval_net(net,test_loader,device=device))

        print('epoch : ', epoch, 'train_loss[-1]: ', train_loss[-1],
              'train_acc[-1] : ', train_acc[-1], 'test_acc[-1] : ',test_acc[-1])
              

 

반응형