113 lines
3.0 KiB
Python
Executable File
113 lines
3.0 KiB
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
from torch.nn.parameter import Parameter
|
|
|
|
import cirtorch.layers.functional as LF
|
|
from cirtorch.layers.normalization import L2N
|
|
|
|
# --------------------------------------
|
|
# Pooling layers
|
|
# --------------------------------------
|
|
|
|
class MAC(nn.Module):
    """Parameter-free pooling module that delegates to ``LF.mac``.

    The layer holds no state; all of the work happens in the functional
    helper. NOTE(review): the name suggests global max pooling over the
    spatial dimensions -- confirm against ``cirtorch.layers.functional``.
    """

    def __init__(self):
        super(MAC, self).__init__()

    def forward(self, x):
        """Return ``LF.mac(x)`` for the input feature tensor ``x``."""
        pooled = LF.mac(x)
        return pooled

    def __repr__(self):
        # Rendered as "<ClassName>()" so subclasses print their own name.
        return '{}()'.format(type(self).__name__)
|
|
|
|
|
|
class SPoC(nn.Module):
    """Parameter-free pooling module that delegates to ``LF.spoc``.

    The layer holds no state; all of the work happens in the functional
    helper. NOTE(review): the name suggests global sum/average pooling over
    the spatial dimensions -- confirm against ``cirtorch.layers.functional``.
    """

    def __init__(self):
        super(SPoC, self).__init__()

    def forward(self, x):
        """Return ``LF.spoc(x)`` for the input feature tensor ``x``."""
        pooled = LF.spoc(x)
        return pooled

    def __repr__(self):
        # Rendered as "<ClassName>()" so subclasses print their own name.
        return '{}()'.format(type(self).__name__)
|
|
|
|
|
|
class GeM(nn.Module):
    """Pooling module with a single learnable exponent, backed by ``LF.gem``.

    Args:
        p: Initial value of the exponent. It is stored as a one-element
            ``Parameter`` and is therefore updated during training.
        eps: Constant forwarded unchanged to ``LF.gem``.
    """

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        # Multiplying a ones-tensor keeps the parameter floating point even
        # when ``p`` is given as an int.
        initial_p = torch.ones(1) * p
        self.p = Parameter(initial_p)
        self.eps = eps

    def forward(self, x):
        """Return ``LF.gem`` applied to ``x`` with the current ``p`` and ``eps``."""
        return LF.gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        # Show the current (possibly trained) exponent with four decimals.
        current_p = self.p.data.tolist()[0]
        return '{}(p={:.4f}, eps={!s})'.format(
            type(self).__name__, current_p, self.eps
        )
|
|
|
|
class GeMmp(nn.Module):
    """Pooling module with a vector of ``mp`` learnable exponents.

    Args:
        p: Initial value shared by every exponent.
        mp: Number of exponents held in the ``p`` parameter.
            NOTE(review): presumably one exponent per feature channel (or 1
            to share a single exponent) -- confirm against ``LF.gem``.
        eps: Constant forwarded unchanged to ``LF.gem``.
    """

    def __init__(self, p=3, mp=1, eps=1e-6):
        super(GeMmp, self).__init__()
        initial_p = torch.ones(mp) * p
        self.p = Parameter(initial_p)
        self.mp = mp
        self.eps = eps

    def forward(self, x):
        """Return ``LF.gem`` applied to ``x`` with ``p`` reshaped to (mp, 1, 1)."""
        # Two trailing singleton axes let the exponent vector broadcast over
        # the last two dimensions of the input.
        p_expanded = self.p.unsqueeze(-1).unsqueeze(-1)
        return LF.gem(x, p=p_expanded, eps=self.eps)

    def __repr__(self):
        # Only the number of exponents is printed, not their values.
        return '{}(p=[{}], eps={!s})'.format(
            type(self).__name__, self.mp, self.eps
        )
|
|
|
|
class RMAC(nn.Module):
    """Parameter-free pooling module that delegates to ``LF.rmac``.

    Args:
        L: Value forwarded to ``LF.rmac`` as ``L``.
            NOTE(review): presumably the number of region scales -- confirm
            against ``cirtorch.layers.functional``.
        eps: Constant forwarded unchanged to ``LF.rmac``.
    """

    def __init__(self, L=3, eps=1e-6):
        super(RMAC, self).__init__()
        self.eps = eps
        self.L = L

    def forward(self, x):
        """Return ``LF.rmac`` applied to ``x`` with the stored ``L`` and ``eps``."""
        return LF.rmac(x, L=self.L, eps=self.eps)

    def __repr__(self):
        # Only ``L`` is reported; ``eps`` is intentionally left out, as before.
        return '{}(L={})'.format(type(self).__name__, self.L)
|
|
|
|
|
|
class Rpool(nn.Module):
    """Regional pooling: pool each region, normalize, optionally whiten, aggregate.

    Args:
        rpool: Pooling callable/module handed to ``LF.roipool`` and applied
            per region.
        whiten: Optional module applied to the flattened regional vectors
            (``None`` disables whitening). It may change the descriptor
            dimension; the output then carries the whitened dimension.
        L: Value forwarded to ``LF.roipool`` as its third argument.
            NOTE(review): presumably the number of region scales -- confirm
            against ``cirtorch.layers.functional``.
        eps: Constant forwarded unchanged to ``LF.roipool``.
    """

    def __init__(self, rpool, whiten=None, L=3, eps=1e-6):
        super(Rpool, self).__init__()
        self.rpool = rpool
        self.L = L
        self.whiten = whiten
        self.norm = L2N()
        self.eps = eps

    def forward(self, x, aggregate=True):
        """Compute regional descriptors for ``x``.

        Args:
            x: Input feature tensor passed to ``LF.roipool``.
            aggregate: When true, sum the regional vectors of each image and
                re-normalize, giving one vector per image.

        Returns:
            Tensor of size (#im, D, 1, 1) when ``aggregate`` is true,
            otherwise (#im, #reg, D, 1, 1). ``D`` is the whitened dimension
            when ``whiten`` is set.
        """
        # features -> roipool
        o = LF.roipool(x, self.rpool, self.L, self.eps)  # size: #im, #reg, D, 1, 1

        # concatenate regions from all images in the batch
        s = o.size()
        o = o.view(s[0] * s[1], s[2], s[3], s[4])  # size: #im x #reg, D, 1, 1

        # rvecs -> norm
        o = self.norm(o)

        # rvecs -> whiten -> norm
        if self.whiten is not None:
            # Whitening works on flat vectors, so drop the two trailing
            # singleton axes first; the result is (#im x #reg, D').
            o = self.norm(self.whiten(o.squeeze(-1).squeeze(-1)))

        # reshape back to regions per image
        # The channel count is read from the current tensor rather than the
        # pre-whitening size, so a whitening layer that changes the
        # dimensionality (D -> D') no longer breaks the reshape. It equals
        # s[2] whenever the dimension is unchanged.
        o = o.view(s[0], s[1], o.size(1), s[3], s[4])  # size: #im, #reg, D, 1, 1

        # aggregate regions into a single global vector per image
        if aggregate:
            # rvecs -> sumpool -> norm
            o = self.norm(o.sum(1, keepdim=False))  # size: #im, D, 1, 1

        return o

    def __repr__(self):
        # Default nn.Module representation (including submodules) followed
        # by the L setting.
        return '{}(L={})'.format(super(Rpool, self).__repr__(), self.L)