first push
This commit is contained in:
152
nts/train.py
Normal file
152
nts/train.py
Normal file
@ -0,0 +1,152 @@
|
||||
import os
|
||||
import torch.utils.data
|
||||
from torch.nn import DataParallel
|
||||
from datetime import datetime
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
from config import BATCH_SIZE, PROPOSAL_NUM, SAVE_FREQ, LR, WD, resume, save_dir
|
||||
from core import model, dataset
|
||||
from core.utils import init_log, progress_bar
|
||||
|
||||
# Run setup: epoch counter, visible GPUs, and a fresh output directory.
start_epoch = 1  # replaced below when training resumes from a checkpoint
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

# Each run writes into its own timestamped sub-directory of the configured
# save_dir; refuse to continue if that directory is somehow already there.
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
save_dir = os.path.join(save_dir, _run_stamp)
if os.path.exists(save_dir):
    raise NameError('model dir exists!')
os.makedirs(save_dir)

# init_log is given the run directory; the object it returns exposes `.info`,
# which is aliased to `_print` for the progress messages emitted below.
logging = init_log(save_dir)
_print = logging.info
|
||||
|
||||
# read dataset
# Both splits are read from the same root and share every loader option
# except shuffling.
_cub_root = './CUB_200_2011'
_loader_common = dict(batch_size=BATCH_SIZE, num_workers=8, drop_last=False)

# Training split: shuffled.
trainset = dataset.CUB(root=_cub_root, is_train=True, data_len=None)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_common)

# Test split: not shuffled.
testset = dataset.CUB(root=_cub_root, is_train=False, data_len=None)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_common)
|
||||
# define model
net = model.attention_net(topN=PROPOSAL_NUM)
if resume:
    # `resume` is a checkpoint path; it holds the weights under
    # 'net_state_dict' and the last finished epoch under 'epoch'.
    #
    # Load onto the CPU: `net` has not been moved to the GPU yet, and `ckpt`
    # stays referenced for the rest of the script. Without map_location the
    # tensors would be restored onto the device they were saved from, pinning
    # a second copy of the weights in GPU memory (or failing outright when
    # that device is not available).
    ckpt = torch.load(resume, map_location='cpu')
    net.load_state_dict(ckpt['net_state_dict'])
    start_epoch = ckpt['epoch'] + 1

# Cross-entropy criterion shared by the classification losses below.
# (The misspelt name is kept because later code refers to it.)
creterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# define optimizers
# Each sub-network gets its own SGD optimizer; all four use the same
# learning rate, momentum and weight decay.
raw_parameters = list(net.pretrained_model.parameters())
part_parameters = list(net.proposal_net.parameters())
concat_parameters = list(net.concat_net.parameters())
partcls_parameters = list(net.partcls_net.parameters())

raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)

# One step scheduler per optimizer: the learning rate is multiplied by 0.1 at
# milestones 60 and 100.
schedulers = [
    MultiStepLR(optimizer, milestones=[60, 100], gamma=0.1)
    for optimizer in (raw_optimizer, concat_optimizer, part_optimizer, partcls_optimizer)
]

# When resuming, start_epoch is larger than 1 but the schedulers above were
# just created from scratch. Replay the steps of the epochs that were already
# trained so the learning rate continues from where the schedule left off
# instead of restarting at LR. (PyTorch may emit a one-time warning about
# scheduler.step() being called before optimizer.step(); it is harmless here.)
# NOTE(review): optimizer momentum buffers are not stored in the checkpoint,
# so they still start empty after a resume.
for _ in range(start_epoch - 1):
    for scheduler in schedulers:
        scheduler.step()

# Move the model to the GPU and replicate it across the visible devices.
net = net.cuda()
net = DataParallel(net)
|
||||
|
||||
def _evaluate(loader, tag):
    """Evaluate ``net`` on ``loader`` using the concatenated-feature logits.

    Args:
        loader: iterable of batches whose first two items are the images and
            the labels.
        tag: text passed to ``progress_bar`` for every batch.

    Returns:
        Tuple ``(mean_loss, accuracy, total)``: the sample-weighted mean of
        ``creterion`` on the concat logits, the fraction of correct argmax
        predictions, and the number of samples seen.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    loss_sum = 0.0
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for i, data in enumerate(loader):
            img, label = data[0].cuda(), data[1].cuda()
            batch_size = img.size(0)
            _, concat_logits, _, _, _ = net(img)
            # calculate loss
            concat_loss = creterion(concat_logits, label)
            # calculate accuracy
            _, concat_predict = torch.max(concat_logits, 1)
            total += batch_size
            # .item() keeps the running counters as plain Python numbers
            correct += (concat_predict == label).sum().item()
            # creterion averages over the batch; re-weight by batch size so
            # the final mean is per sample even when the last batch is short
            loss_sum += concat_loss.item() * batch_size
            progress_bar(i, len(loader), tag)
    return loss_sum / total, float(correct) / total, total


# Order matters only for readability; each optimizer owns disjoint parameters.
_optimizers = (raw_optimizer, part_optimizer, concat_optimizer, partcls_optimizer)

for epoch in range(start_epoch, 500):
    # begin training
    _print('--' * 50)
    net.train()
    for i, data in enumerate(trainloader):
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        for optimizer in _optimizers:
            optimizer.zero_grad()

        raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)

        # Flatten the per-proposal logits to (batch_size * PROPOSAL_NUM, -1)
        # and repeat every label PROPOSAL_NUM times so they line up row by
        # row; both part-based losses below use the same pair.
        flat_part_logits = part_logits.view(batch_size * PROPOSAL_NUM, -1)
        flat_part_labels = label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)

        part_loss = model.list_loss(flat_part_logits, flat_part_labels).view(batch_size, PROPOSAL_NUM)
        raw_loss = creterion(raw_logits, label)
        concat_loss = creterion(concat_logits, label)
        rank_loss = model.ranking_loss(top_n_prob, part_loss)
        partcls_loss = creterion(flat_part_logits, flat_part_labels)

        total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
        total_loss.backward()
        for optimizer in _optimizers:
            optimizer.step()
        progress_bar(i, len(trainloader), 'train')

    # Advance the LR schedules only after this epoch's optimizer updates.
    # Stepping the scheduler before optimizer.step() (as PyTorch >= 1.1 warns)
    # skips the first value of the schedule.
    for scheduler in schedulers:
        scheduler.step()

    if epoch % SAVE_FREQ == 0:
        train_loss, train_acc, total = _evaluate(trainloader, 'eval train set')
        _print(
            'epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
                epoch,
                train_loss,
                train_acc,
                total))

        # evaluate on test set
        test_loss, test_acc, total = _evaluate(testloader, 'eval test set')
        _print(
            'epoch:{} - test loss: {:.3f} and test acc: {:.3f} total sample: {}'.format(
                epoch,
                test_loss,
                test_acc,
                total))

        # save model
        # net is wrapped in DataParallel; save the underlying module's weights
        # so the checkpoint can be loaded into an unwrapped model.
        net_state_dict = net.module.state_dict()
        os.makedirs(save_dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc,
            'net_state_dict': net_state_dict},
            os.path.join(save_dir, '%03d.ckpt' % epoch))

print('finishing training')
|
Reference in New Issue
Block a user