first push
This commit is contained in:
74
nts/test.py
Normal file
74
nts/test.py
Normal file
@ -0,0 +1,74 @@
|
||||
import os
|
||||
from torch.autograd import Variable
|
||||
import torch.utils.data
|
||||
from torch.nn import DataParallel
|
||||
from config import BATCH_SIZE, PROPOSAL_NUM, test_model
|
||||
from core import model, dataset
|
||||
from core.utils import progress_bar
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
|
||||
if not test_model:
|
||||
raise NameError('please set the test_model file to choose the checkpoint!')
|
||||
# read dataset
|
||||
trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
|
||||
shuffle=True, num_workers=8, drop_last=False)
|
||||
testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
|
||||
shuffle=False, num_workers=8, drop_last=False)
|
||||
# define model
|
||||
net = model.attention_net(topN=PROPOSAL_NUM)
|
||||
ckpt = torch.load(test_model)
|
||||
net.load_state_dict(ckpt['net_state_dict'])
|
||||
net = net.cuda()
|
||||
net = DataParallel(net)
|
||||
creterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# evaluate on train set
|
||||
train_loss = 0
|
||||
train_correct = 0
|
||||
total = 0
|
||||
net.eval()
|
||||
|
||||
for i, data in enumerate(trainloader):
|
||||
with torch.no_grad():
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
batch_size = img.size(0)
|
||||
_, concat_logits, _, _, _ = net(img)
|
||||
# calculate loss
|
||||
concat_loss = creterion(concat_logits, label)
|
||||
# calculate accuracy
|
||||
_, concat_predict = torch.max(concat_logits, 1)
|
||||
total += batch_size
|
||||
train_correct += torch.sum(concat_predict.data == label.data)
|
||||
train_loss += concat_loss.item() * batch_size
|
||||
progress_bar(i, len(trainloader), 'eval on train set')
|
||||
|
||||
train_acc = float(train_correct) / total
|
||||
train_loss = train_loss / total
|
||||
print('train set loss: {:.3f} and train set acc: {:.3f} total sample: {}'.format(train_loss, train_acc, total))
|
||||
|
||||
|
||||
# evaluate on test set
|
||||
test_loss = 0
|
||||
test_correct = 0
|
||||
total = 0
|
||||
for i, data in enumerate(testloader):
|
||||
with torch.no_grad():
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
batch_size = img.size(0)
|
||||
_, concat_logits, _, _, _ = net(img)
|
||||
# calculate loss
|
||||
concat_loss = creterion(concat_logits, label)
|
||||
# calculate accuracy
|
||||
_, concat_predict = torch.max(concat_logits, 1)
|
||||
total += batch_size
|
||||
test_correct += torch.sum(concat_predict.data == label.data)
|
||||
test_loss += concat_loss.item() * batch_size
|
||||
progress_bar(i, len(testloader), 'eval on test set')
|
||||
|
||||
test_acc = float(test_correct) / total
|
||||
test_loss = test_loss / total
|
||||
print('test set loss: {:.3f} and test set acc: {:.3f} total sample: {}'.format(test_loss, test_acc, total))
|
||||
|
||||
print('finishing testing')
|
Reference in New Issue
Block a user