from torch import nn import torch import torch.nn.functional as F from torch.autograd import Variable from core import resnet import numpy as np from core.anchors import generate_default_anchor_maps, hard_nms from config import CAT_NUM, PROPOSAL_NUM class ProposalNet(nn.Module): def __init__(self): super(ProposalNet, self).__init__() self.down1 = nn.Conv2d(2048, 128, 3, 1, 1) self.down2 = nn.Conv2d(128, 128, 3, 2, 1) self.down3 = nn.Conv2d(128, 128, 3, 2, 1) self.ReLU = nn.ReLU() self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0) self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0) self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0) def forward(self, x): batch_size = x.size(0) d1 = self.ReLU(self.down1(x)) d2 = self.ReLU(self.down2(d1)) d3 = self.ReLU(self.down3(d2)) t1 = self.tidy1(d1).view(batch_size, -1) t2 = self.tidy2(d2).view(batch_size, -1) t3 = self.tidy3(d3).view(batch_size, -1) return torch.cat((t1, t2, t3), dim=1) class attention_net(nn.Module): def __init__(self, topN=4): super(attention_net, self).__init__() self.pretrained_model = resnet.resnet50(pretrained=True) self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1) self.pretrained_model.fc = nn.Linear(512 * 4, 200) self.proposal_net = ProposalNet() self.topN = topN self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200) self.partcls_net = nn.Linear(512 * 4, 200) _, edge_anchors, _ = generate_default_anchor_maps() self.pad_side = 224 self.edge_anchors = (edge_anchors + 224).astype(np.int) def forward(self, x): resnet_out, rpn_feature, feature = self.pretrained_model(x) x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0) batch = x.size(0) # we will reshape rpn to shape: batch * nb_anchor rpn_score = self.proposal_net(rpn_feature.detach()) all_cdds = [ np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1) for x in rpn_score.data.cpu().numpy()] top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds] top_n_cdds = np.array(top_n_cdds) top_n_index = top_n_cdds[:, :, -1].astype(np.int) top_n_index = torch.from_numpy(top_n_index).cuda() top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index) part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda() for i in range(batch): for j in range(self.topN): [y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(np.int) part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear', align_corners=True) part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224) _, _, part_features = self.pretrained_model(part_imgs.detach()) part_feature = part_features.view(batch, self.topN, -1) part_feature = part_feature[:, :CAT_NUM, ...].contiguous() part_feature = part_feature.view(batch, -1) # concat_logits have the shape: B*200 concat_out = torch.cat([part_feature, feature], dim=1) concat_logits = self.concat_net(concat_out) raw_logits = resnet_out # part_logits have the shape: B*N*200 part_logits = self.partcls_net(part_features).view(batch, self.topN, -1) return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob] def list_loss(logits, targets): temp = F.log_softmax(logits, -1) loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))] return torch.stack(loss) def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM): loss = Variable(torch.zeros(1).cuda()) batch_size = score.size(0) for i in range(proposal_num): targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor) pivot = score[:, i].unsqueeze(1) loss_p = (1 - pivot + score) * targets_p loss_p = torch.sum(F.relu(loss_p)) loss += loss_p return loss / batch_size