first push
This commit is contained in:
33
nts/README.md
Normal file
33
nts/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# NTS-Net
|
||||
|
||||
This is a PyTorch implementation of the ECCV2018 paper "Learning to Navigate for Fine-grained Classification" (Ze Yang, Tiange Luo, Dong Wang, Zhiqiang Hu, Jun Gao, Liwei Wang).
|
||||
|
||||
## Requirements
|
||||
- python 3+
|
||||
- pytorch 0.4+
|
||||
- numpy
|
||||
- datetime
|
||||
|
||||
## Datasets
|
||||
Download the [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) datasets and put it in the root directory named **CUB_200_2011**, You can also try other fine-grained datasets.
|
||||
|
||||
## Train the model
|
||||
If you want to train the NTS-Net, just run ``python train.py``. You may need to change the configurations in ``config.py``. The parameter ``PROPOSAL_NUM`` is ``M`` in the original paper and the parameter ``CAT_NUM`` is ``K`` in the original paper. During training, the log file and checkpoint file will be saved in ``save_dir`` directory. You can change the parameter ``resume`` to choose the checkpoint model to resume.
|
||||
|
||||
## Test the model
|
||||
If you want to test the NTS-Net, just run ``python test.py``. You need to specify the ``test_model`` in ``config.py`` to choose the checkpoint model for testing.
|
||||
|
||||
## Model
|
||||
We also provide the checkpoint model trained by ourselves, you can download it from [here](https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing). If you test on our provided model, you will get a 87.6% test accuracy.
|
||||
|
||||
## Reference
|
||||
If you are interested in our work and want to cite it, please acknowledge the following paper:
|
||||
|
||||
```
|
||||
@inproceedings{Yang2018Learning,
|
||||
author = {Yang, Ze and Luo, Tiange and Wang, Dong and Hu, Zhiqiang and Gao, Jun and Wang, Liwei},
|
||||
title = {Learning to Navigate for Fine-grained Classification},
|
||||
booktitle = {ECCV},
|
||||
year = {2018}
|
||||
}
|
||||
```
|
10
nts/config.py
Normal file
10
nts/config.py
Normal file
@ -0,0 +1,10 @@
|
||||
BATCH_SIZE = 16
|
||||
PROPOSAL_NUM = 6
|
||||
CAT_NUM = 4
|
||||
INPUT_SIZE = (448, 448) # (w, h)
|
||||
LR = 0.001
|
||||
WD = 1e-4
|
||||
SAVE_FREQ = 1
|
||||
resume = ''
|
||||
test_model = 'model.ckpt'
|
||||
save_dir = '/data_4t/yangz/models/'
|
100
nts/core/anchors.py
Normal file
100
nts/core/anchors.py
Normal file
@ -0,0 +1,100 @@
|
||||
import numpy as np
|
||||
from config import INPUT_SIZE
|
||||
|
||||
_default_anchors_setting = (
|
||||
dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
|
||||
dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
|
||||
dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
|
||||
)
|
||||
|
||||
|
||||
def generate_default_anchor_maps(anchors_setting=None, input_shape=INPUT_SIZE):
|
||||
"""
|
||||
generate default anchor
|
||||
|
||||
:param anchors_setting: all informations of anchors
|
||||
:param input_shape: shape of input images, e.g. (h, w)
|
||||
:return: center_anchors: # anchors * 4 (oy, ox, h, w)
|
||||
edge_anchors: # anchors * 4 (y0, x0, y1, x1)
|
||||
anchor_area: # anchors * 1 (area)
|
||||
"""
|
||||
if anchors_setting is None:
|
||||
anchors_setting = _default_anchors_setting
|
||||
|
||||
center_anchors = np.zeros((0, 4), dtype=np.float32)
|
||||
edge_anchors = np.zeros((0, 4), dtype=np.float32)
|
||||
anchor_areas = np.zeros((0,), dtype=np.float32)
|
||||
input_shape = np.array(input_shape, dtype=int)
|
||||
|
||||
for anchor_info in anchors_setting:
|
||||
|
||||
stride = anchor_info['stride']
|
||||
size = anchor_info['size']
|
||||
scales = anchor_info['scale']
|
||||
aspect_ratios = anchor_info['aspect_ratio']
|
||||
|
||||
output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
|
||||
output_map_shape = output_map_shape.astype(np.int)
|
||||
output_shape = tuple(output_map_shape) + (4,)
|
||||
ostart = stride / 2.
|
||||
oy = np.arange(ostart, ostart + stride * output_shape[0], stride)
|
||||
oy = oy.reshape(output_shape[0], 1)
|
||||
ox = np.arange(ostart, ostart + stride * output_shape[1], stride)
|
||||
ox = ox.reshape(1, output_shape[1])
|
||||
center_anchor_map_template = np.zeros(output_shape, dtype=np.float32)
|
||||
center_anchor_map_template[:, :, 0] = oy
|
||||
center_anchor_map_template[:, :, 1] = ox
|
||||
for scale in scales:
|
||||
for aspect_ratio in aspect_ratios:
|
||||
center_anchor_map = center_anchor_map_template.copy()
|
||||
center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5
|
||||
center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5
|
||||
|
||||
edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2.,
|
||||
center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.),
|
||||
axis=-1)
|
||||
anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3]
|
||||
center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4)))
|
||||
edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))
|
||||
anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1)))
|
||||
|
||||
return center_anchors, edge_anchors, anchor_areas
|
||||
|
||||
|
||||
def hard_nms(cdds, topn=10, iou_thresh=0.25):
|
||||
if not (type(cdds).__module__ == 'numpy' and len(cdds.shape) == 2 and cdds.shape[1] >= 5):
|
||||
raise TypeError('edge_box_map should be N * 5+ ndarray')
|
||||
|
||||
cdds = cdds.copy()
|
||||
indices = np.argsort(cdds[:, 0])
|
||||
cdds = cdds[indices]
|
||||
cdd_results = []
|
||||
|
||||
res = cdds
|
||||
|
||||
while res.any():
|
||||
cdd = res[-1]
|
||||
cdd_results.append(cdd)
|
||||
if len(cdd_results) == topn:
|
||||
return np.array(cdd_results)
|
||||
res = res[:-1]
|
||||
|
||||
start_max = np.maximum(res[:, 1:3], cdd[1:3])
|
||||
end_min = np.minimum(res[:, 3:5], cdd[3:5])
|
||||
lengths = end_min - start_max
|
||||
intersec_map = lengths[:, 0] * lengths[:, 1]
|
||||
intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0
|
||||
iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * (
|
||||
cdd[4] - cdd[2]) - intersec_map)
|
||||
res = res[iou_map_cur < iou_thresh]
|
||||
|
||||
return np.array(cdd_results)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
a = hard_nms(np.array([
|
||||
[0.4, 1, 10, 12, 20],
|
||||
[0.5, 1, 11, 11, 20],
|
||||
[0.55, 20, 30, 40, 50]
|
||||
]), topn=100, iou_thresh=0.4)
|
||||
print(a)
|
77
nts/core/dataset.py
Normal file
77
nts/core/dataset.py
Normal file
@ -0,0 +1,77 @@
|
||||
import numpy as np
|
||||
import scipy.misc
|
||||
import os
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from config import INPUT_SIZE
|
||||
|
||||
|
||||
class CUB():
|
||||
def __init__(self, root, is_train=True, data_len=None):
|
||||
self.root = root
|
||||
self.is_train = is_train
|
||||
img_txt_file = open(os.path.join(self.root, 'images.txt'))
|
||||
label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
|
||||
train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
|
||||
img_name_list = []
|
||||
for line in img_txt_file:
|
||||
img_name_list.append(line[:-1].split(' ')[-1])
|
||||
label_list = []
|
||||
for line in label_txt_file:
|
||||
label_list.append(int(line[:-1].split(' ')[-1]) - 1)
|
||||
train_test_list = []
|
||||
for line in train_val_file:
|
||||
train_test_list.append(int(line[:-1].split(' ')[-1]))
|
||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||
if self.is_train:
|
||||
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
train_file_list[:data_len]]
|
||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||
if not self.is_train:
|
||||
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
test_file_list[:data_len]]
|
||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.is_train:
|
||||
img, target = self.train_img[index], self.train_label[index]
|
||||
if len(img.shape) == 2:
|
||||
img = np.stack([img] * 3, 2)
|
||||
img = Image.fromarray(img, mode='RGB')
|
||||
img = transforms.Resize((600, 600), Image.BILINEAR)(img)
|
||||
img = transforms.RandomCrop(INPUT_SIZE)(img)
|
||||
img = transforms.RandomHorizontalFlip()(img)
|
||||
img = transforms.ToTensor()(img)
|
||||
img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
|
||||
|
||||
else:
|
||||
img, target = self.test_img[index], self.test_label[index]
|
||||
if len(img.shape) == 2:
|
||||
img = np.stack([img] * 3, 2)
|
||||
img = Image.fromarray(img, mode='RGB')
|
||||
img = transforms.Resize((600, 600), Image.BILINEAR)(img)
|
||||
img = transforms.CenterCrop(INPUT_SIZE)(img)
|
||||
img = transforms.ToTensor()(img)
|
||||
img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
if self.is_train:
|
||||
return len(self.train_label)
|
||||
else:
|
||||
return len(self.test_label)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset = CUB(root='./CUB_200_2011')
|
||||
print(len(dataset.train_img))
|
||||
print(len(dataset.train_label))
|
||||
for data in dataset:
|
||||
print(data[0].size(), data[1])
|
||||
dataset = CUB(root='./CUB_200_2011', is_train=False)
|
||||
print(len(dataset.test_img))
|
||||
print(len(dataset.test_label))
|
||||
for data in dataset:
|
||||
print(data[0].size(), data[1])
|
96
nts/core/model.py
Normal file
96
nts/core/model.py
Normal file
@ -0,0 +1,96 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from core import resnet
|
||||
import numpy as np
|
||||
from core.anchors import generate_default_anchor_maps, hard_nms
|
||||
from config import CAT_NUM, PROPOSAL_NUM
|
||||
|
||||
|
||||
class ProposalNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(ProposalNet, self).__init__()
|
||||
self.down1 = nn.Conv2d(2048, 128, 3, 1, 1)
|
||||
self.down2 = nn.Conv2d(128, 128, 3, 2, 1)
|
||||
self.down3 = nn.Conv2d(128, 128, 3, 2, 1)
|
||||
self.ReLU = nn.ReLU()
|
||||
self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0)
|
||||
self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0)
|
||||
self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.size(0)
|
||||
d1 = self.ReLU(self.down1(x))
|
||||
d2 = self.ReLU(self.down2(d1))
|
||||
d3 = self.ReLU(self.down3(d2))
|
||||
t1 = self.tidy1(d1).view(batch_size, -1)
|
||||
t2 = self.tidy2(d2).view(batch_size, -1)
|
||||
t3 = self.tidy3(d3).view(batch_size, -1)
|
||||
return torch.cat((t1, t2, t3), dim=1)
|
||||
|
||||
|
||||
class attention_net(nn.Module):
|
||||
def __init__(self, topN=4):
|
||||
super(attention_net, self).__init__()
|
||||
self.pretrained_model = resnet.resnet50(pretrained=True)
|
||||
self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.pretrained_model.fc = nn.Linear(512 * 4, 200)
|
||||
self.proposal_net = ProposalNet()
|
||||
self.topN = topN
|
||||
self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200)
|
||||
self.partcls_net = nn.Linear(512 * 4, 200)
|
||||
_, edge_anchors, _ = generate_default_anchor_maps()
|
||||
self.pad_side = 224
|
||||
self.edge_anchors = (edge_anchors + 224).astype(np.int)
|
||||
|
||||
def forward(self, x):
|
||||
resnet_out, rpn_feature, feature = self.pretrained_model(x)
|
||||
x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0)
|
||||
batch = x.size(0)
|
||||
# we will reshape rpn to shape: batch * nb_anchor
|
||||
rpn_score = self.proposal_net(rpn_feature.detach())
|
||||
all_cdds = [
|
||||
np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1)
|
||||
for x in rpn_score.data.cpu().numpy()]
|
||||
top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds]
|
||||
top_n_cdds = np.array(top_n_cdds)
|
||||
top_n_index = top_n_cdds[:, :, -1].astype(np.int)
|
||||
top_n_index = torch.from_numpy(top_n_index).cuda()
|
||||
top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index)
|
||||
part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda()
|
||||
for i in range(batch):
|
||||
for j in range(self.topN):
|
||||
[y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(np.int)
|
||||
part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear',
|
||||
align_corners=True)
|
||||
part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224)
|
||||
_, _, part_features = self.pretrained_model(part_imgs.detach())
|
||||
part_feature = part_features.view(batch, self.topN, -1)
|
||||
part_feature = part_feature[:, :CAT_NUM, ...].contiguous()
|
||||
part_feature = part_feature.view(batch, -1)
|
||||
# concat_logits have the shape: B*200
|
||||
concat_out = torch.cat([part_feature, feature], dim=1)
|
||||
concat_logits = self.concat_net(concat_out)
|
||||
raw_logits = resnet_out
|
||||
# part_logits have the shape: B*N*200
|
||||
part_logits = self.partcls_net(part_features).view(batch, self.topN, -1)
|
||||
return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob]
|
||||
|
||||
|
||||
def list_loss(logits, targets):
|
||||
temp = F.log_softmax(logits, -1)
|
||||
loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))]
|
||||
return torch.stack(loss)
|
||||
|
||||
|
||||
def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM):
|
||||
loss = Variable(torch.zeros(1).cuda())
|
||||
batch_size = score.size(0)
|
||||
for i in range(proposal_num):
|
||||
targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor)
|
||||
pivot = score[:, i].unsqueeze(1)
|
||||
loss_p = (1 - pivot + score) * targets_p
|
||||
loss_p = torch.sum(F.relu(loss_p))
|
||||
loss += loss_p
|
||||
return loss / batch_size
|
212
nts/core/resnet.py
Normal file
212
nts/core/resnet.py
Normal file
@ -0,0 +1,212 @@
|
||||
import torch.nn as nn
|
||||
import math
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152']
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"3x3 convolution with padding"
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, layers, num_classes=1000):
|
||||
self.inplanes = 64
|
||||
super(ResNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(7)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
feature1 = x
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = nn.Dropout(p=0.5)(x)
|
||||
feature2 = x
|
||||
x = self.fc(x)
|
||||
|
||||
return x, feature1, feature2
|
||||
|
||||
|
||||
def resnet18(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-18 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
|
||||
return model
|
||||
|
||||
|
||||
def resnet34(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-34 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
|
||||
return model
|
||||
|
||||
|
||||
def resnet50(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
|
||||
return model
|
||||
|
||||
|
||||
def resnet101(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-101 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
|
||||
return model
|
||||
|
||||
|
||||
def resnet152(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-152 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
|
||||
return model
|
104
nts/core/utils.py
Normal file
104
nts/core/utils.py
Normal file
@ -0,0 +1,104 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
|
||||
_, term_width = os.popen('stty size', 'r').read().split()
|
||||
term_width = int(term_width)
|
||||
|
||||
TOTAL_BAR_LENGTH = 40.
|
||||
last_time = time.time()
|
||||
begin_time = last_time
|
||||
|
||||
|
||||
def progress_bar(current, total, msg=None):
|
||||
global last_time, begin_time
|
||||
if current == 0:
|
||||
begin_time = time.time() # Reset for new bar.
|
||||
|
||||
cur_len = int(TOTAL_BAR_LENGTH * current / total)
|
||||
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
|
||||
|
||||
sys.stdout.write(' [')
|
||||
for i in range(cur_len):
|
||||
sys.stdout.write('=')
|
||||
sys.stdout.write('>')
|
||||
for i in range(rest_len):
|
||||
sys.stdout.write('.')
|
||||
sys.stdout.write(']')
|
||||
|
||||
cur_time = time.time()
|
||||
step_time = cur_time - last_time
|
||||
last_time = cur_time
|
||||
tot_time = cur_time - begin_time
|
||||
|
||||
L = []
|
||||
L.append(' Step: %s' % format_time(step_time))
|
||||
L.append(' | Tot: %s' % format_time(tot_time))
|
||||
if msg:
|
||||
L.append(' | ' + msg)
|
||||
|
||||
msg = ''.join(L)
|
||||
sys.stdout.write(msg)
|
||||
for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
|
||||
sys.stdout.write(' ')
|
||||
|
||||
# Go back to the center of the bar.
|
||||
for i in range(term_width - int(TOTAL_BAR_LENGTH / 2)):
|
||||
sys.stdout.write('\b')
|
||||
sys.stdout.write(' %d/%d ' % (current + 1, total))
|
||||
|
||||
if current < total - 1:
|
||||
sys.stdout.write('\r')
|
||||
else:
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def format_time(seconds):
|
||||
days = int(seconds / 3600 / 24)
|
||||
seconds = seconds - days * 3600 * 24
|
||||
hours = int(seconds / 3600)
|
||||
seconds = seconds - hours * 3600
|
||||
minutes = int(seconds / 60)
|
||||
seconds = seconds - minutes * 60
|
||||
secondsf = int(seconds)
|
||||
seconds = seconds - secondsf
|
||||
millis = int(seconds * 1000)
|
||||
|
||||
f = ''
|
||||
i = 1
|
||||
if days > 0:
|
||||
f += str(days) + 'D'
|
||||
i += 1
|
||||
if hours > 0 and i <= 2:
|
||||
f += str(hours) + 'h'
|
||||
i += 1
|
||||
if minutes > 0 and i <= 2:
|
||||
f += str(minutes) + 'm'
|
||||
i += 1
|
||||
if secondsf > 0 and i <= 2:
|
||||
f += str(secondsf) + 's'
|
||||
i += 1
|
||||
if millis > 0 and i <= 2:
|
||||
f += str(millis) + 'ms'
|
||||
i += 1
|
||||
if f == '':
|
||||
f = '0ms'
|
||||
return f
|
||||
|
||||
|
||||
def init_log(output_dir):
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(message)s',
|
||||
datefmt='%Y%m%d-%H:%M:%S',
|
||||
filename=os.path.join(output_dir, 'log.log'),
|
||||
filemode='w')
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.INFO)
|
||||
logging.getLogger('').addHandler(console)
|
||||
return logging
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
74
nts/test.py
Normal file
74
nts/test.py
Normal file
@ -0,0 +1,74 @@
|
||||
import os
|
||||
from torch.autograd import Variable
|
||||
import torch.utils.data
|
||||
from torch.nn import DataParallel
|
||||
from config import BATCH_SIZE, PROPOSAL_NUM, test_model
|
||||
from core import model, dataset
|
||||
from core.utils import progress_bar
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
|
||||
if not test_model:
|
||||
raise NameError('please set the test_model file to choose the checkpoint!')
|
||||
# read dataset
|
||||
trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
|
||||
shuffle=True, num_workers=8, drop_last=False)
|
||||
testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
|
||||
shuffle=False, num_workers=8, drop_last=False)
|
||||
# define model
|
||||
net = model.attention_net(topN=PROPOSAL_NUM)
|
||||
ckpt = torch.load(test_model)
|
||||
net.load_state_dict(ckpt['net_state_dict'])
|
||||
net = net.cuda()
|
||||
net = DataParallel(net)
|
||||
creterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# evaluate on train set
|
||||
train_loss = 0
|
||||
train_correct = 0
|
||||
total = 0
|
||||
net.eval()
|
||||
|
||||
for i, data in enumerate(trainloader):
|
||||
with torch.no_grad():
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
batch_size = img.size(0)
|
||||
_, concat_logits, _, _, _ = net(img)
|
||||
# calculate loss
|
||||
concat_loss = creterion(concat_logits, label)
|
||||
# calculate accuracy
|
||||
_, concat_predict = torch.max(concat_logits, 1)
|
||||
total += batch_size
|
||||
train_correct += torch.sum(concat_predict.data == label.data)
|
||||
train_loss += concat_loss.item() * batch_size
|
||||
progress_bar(i, len(trainloader), 'eval on train set')
|
||||
|
||||
train_acc = float(train_correct) / total
|
||||
train_loss = train_loss / total
|
||||
print('train set loss: {:.3f} and train set acc: {:.3f} total sample: {}'.format(train_loss, train_acc, total))
|
||||
|
||||
|
||||
# evaluate on test set
|
||||
test_loss = 0
|
||||
test_correct = 0
|
||||
total = 0
|
||||
for i, data in enumerate(testloader):
|
||||
with torch.no_grad():
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
batch_size = img.size(0)
|
||||
_, concat_logits, _, _, _ = net(img)
|
||||
# calculate loss
|
||||
concat_loss = creterion(concat_logits, label)
|
||||
# calculate accuracy
|
||||
_, concat_predict = torch.max(concat_logits, 1)
|
||||
total += batch_size
|
||||
test_correct += torch.sum(concat_predict.data == label.data)
|
||||
test_loss += concat_loss.item() * batch_size
|
||||
progress_bar(i, len(testloader), 'eval on test set')
|
||||
|
||||
test_acc = float(test_correct) / total
|
||||
test_loss = test_loss / total
|
||||
print('test set loss: {:.3f} and test set acc: {:.3f} total sample: {}'.format(test_loss, test_acc, total))
|
||||
|
||||
print('finishing testing')
|
152
nts/train.py
Normal file
152
nts/train.py
Normal file
@ -0,0 +1,152 @@
|
||||
import os
|
||||
import torch.utils.data
|
||||
from torch.nn import DataParallel
|
||||
from datetime import datetime
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
from config import BATCH_SIZE, PROPOSAL_NUM, SAVE_FREQ, LR, WD, resume, save_dir
|
||||
from core import model, dataset
|
||||
from core.utils import init_log, progress_bar
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
|
||||
start_epoch = 1
|
||||
save_dir = os.path.join(save_dir, datetime.now().strftime('%Y%m%d_%H%M%S'))
|
||||
if os.path.exists(save_dir):
|
||||
raise NameError('model dir exists!')
|
||||
os.makedirs(save_dir)
|
||||
logging = init_log(save_dir)
|
||||
_print = logging.info
|
||||
|
||||
# read dataset
|
||||
trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
|
||||
shuffle=True, num_workers=8, drop_last=False)
|
||||
testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
|
||||
shuffle=False, num_workers=8, drop_last=False)
|
||||
# define model
|
||||
net = model.attention_net(topN=PROPOSAL_NUM)
|
||||
if resume:
|
||||
ckpt = torch.load(resume)
|
||||
net.load_state_dict(ckpt['net_state_dict'])
|
||||
start_epoch = ckpt['epoch'] + 1
|
||||
creterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# define optimizers
|
||||
raw_parameters = list(net.pretrained_model.parameters())
|
||||
part_parameters = list(net.proposal_net.parameters())
|
||||
concat_parameters = list(net.concat_net.parameters())
|
||||
partcls_parameters = list(net.partcls_net.parameters())
|
||||
|
||||
raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
|
||||
concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
|
||||
part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
|
||||
partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
|
||||
schedulers = [MultiStepLR(raw_optimizer, milestones=[60, 100], gamma=0.1),
|
||||
MultiStepLR(concat_optimizer, milestones=[60, 100], gamma=0.1),
|
||||
MultiStepLR(part_optimizer, milestones=[60, 100], gamma=0.1),
|
||||
MultiStepLR(partcls_optimizer, milestones=[60, 100], gamma=0.1)]
|
||||
net = net.cuda()
|
||||
net = DataParallel(net)
|
||||
|
||||
for epoch in range(start_epoch, 500):
|
||||
for scheduler in schedulers:
|
||||
scheduler.step()
|
||||
|
||||
# begin training
|
||||
_print('--' * 50)
|
||||
net.train()
|
||||
for i, data in enumerate(trainloader):
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
batch_size = img.size(0)
|
||||
raw_optimizer.zero_grad()
|
||||
part_optimizer.zero_grad()
|
||||
concat_optimizer.zero_grad()
|
||||
partcls_optimizer.zero_grad()
|
||||
|
||||
raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)
|
||||
part_loss = model.list_loss(part_logits.view(batch_size * PROPOSAL_NUM, -1),
|
||||
label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
|
||||
raw_loss = creterion(raw_logits, label)
|
||||
concat_loss = creterion(concat_logits, label)
|
||||
rank_loss = model.ranking_loss(top_n_prob, part_loss)
|
||||
partcls_loss = creterion(part_logits.view(batch_size * PROPOSAL_NUM, -1),
|
||||
label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))
|
||||
|
||||
total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
|
||||
total_loss.backward()
|
||||
raw_optimizer.step()
|
||||
part_optimizer.step()
|
||||
concat_optimizer.step()
|
||||
partcls_optimizer.step()
|
||||
progress_bar(i, len(trainloader), 'train')
|
||||
|
||||
if epoch % SAVE_FREQ == 0:
|
||||
train_loss = 0
|
||||
train_correct = 0
|
||||
total = 0
|
||||
net.eval()
|
||||
for i, data in enumerate(trainloader):
|
||||
with torch.no_grad():
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
batch_size = img.size(0)
|
||||
_, concat_logits, _, _, _ = net(img)
|
||||
# calculate loss
|
||||
concat_loss = creterion(concat_logits, label)
|
||||
# calculate accuracy
|
||||
_, concat_predict = torch.max(concat_logits, 1)
|
||||
total += batch_size
|
||||
train_correct += torch.sum(concat_predict.data == label.data)
|
||||
train_loss += concat_loss.item() * batch_size
|
||||
progress_bar(i, len(trainloader), 'eval train set')
|
||||
|
||||
train_acc = float(train_correct) / total
|
||||
train_loss = train_loss / total
|
||||
|
||||
_print(
|
||||
'epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
|
||||
epoch,
|
||||
train_loss,
|
||||
train_acc,
|
||||
total))
|
||||
|
||||
# evaluate on test set
|
||||
test_loss = 0
|
||||
test_correct = 0
|
||||
total = 0
|
||||
for i, data in enumerate(testloader):
|
||||
with torch.no_grad():
|
||||
img, label = data[0].cuda(), data[1].cuda()
|
||||
batch_size = img.size(0)
|
||||
_, concat_logits, _, _, _ = net(img)
|
||||
# calculate loss
|
||||
concat_loss = creterion(concat_logits, label)
|
||||
# calculate accuracy
|
||||
_, concat_predict = torch.max(concat_logits, 1)
|
||||
total += batch_size
|
||||
test_correct += torch.sum(concat_predict.data == label.data)
|
||||
test_loss += concat_loss.item() * batch_size
|
||||
progress_bar(i, len(testloader), 'eval test set')
|
||||
|
||||
test_acc = float(test_correct) / total
|
||||
test_loss = test_loss / total
|
||||
_print(
|
||||
'epoch:{} - test loss: {:.3f} and test acc: {:.3f} total sample: {}'.format(
|
||||
epoch,
|
||||
test_loss,
|
||||
test_acc,
|
||||
total))
|
||||
|
||||
# save model
|
||||
net_state_dict = net.module.state_dict()
|
||||
if not os.path.exists(save_dir):
|
||||
os.mkdir(save_dir)
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'train_loss': train_loss,
|
||||
'train_acc': train_acc,
|
||||
'test_loss': test_loss,
|
||||
'test_acc': test_acc,
|
||||
'net_state_dict': net_state_dict},
|
||||
os.path.join(save_dir, '%03d.ckpt' % epoch))
|
||||
|
||||
print('finishing training')
|
Reference in New Issue
Block a user