import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

import cirtorch.layers.functional as LF
from cirtorch.layers.normalization import L2N

# --------------------------------------
# Pooling layers
# --------------------------------------


class MAC(nn.Module):
    """Parameter-free pooling layer that applies ``LF.mac`` to its input."""

    def __init__(self):
        super(MAC, self).__init__()

    def forward(self, x):
        """Return ``LF.mac(x)``."""
        return LF.mac(x)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class SPoC(nn.Module):
    """Parameter-free pooling layer that applies ``LF.spoc`` to its input."""

    def __init__(self):
        super(SPoC, self).__init__()

    def forward(self, x):
        """Return ``LF.spoc(x)``."""
        return LF.spoc(x)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class GeM(nn.Module):
    """Pooling layer that applies ``LF.gem`` with a single learnable exponent.

    Args:
        p: Initial value of the exponent; stored as a one-element
            ``Parameter`` so it is trained with the rest of the network.
        eps: Forwarded unchanged to ``LF.gem``.
    """

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Return ``LF.gem(x, p=self.p, eps=self.eps)``."""
        return LF.gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        return '{}(p={:.4f}, eps={})'.format(
            self.__class__.__name__, self.p.data.tolist()[0], self.eps)


class GeMmp(nn.Module):
    """Variant of ``GeM`` holding ``mp`` learnable exponents instead of one.

    Args:
        p: Initial value shared by all exponents.
        mp: Number of exponents (length of the ``p`` parameter vector).
        eps: Forwarded unchanged to ``LF.gem``.
    """

    def __init__(self, p=3, mp=1, eps=1e-6):
        super(GeMmp, self).__init__()
        self.p = Parameter(torch.ones(mp) * p)
        self.mp = mp
        self.eps = eps

    def forward(self, x):
        """Apply ``LF.gem`` with the exponents reshaped to ``(mp, 1, 1)``."""
        # Two trailing singleton axes are appended to the exponent vector.
        # NOTE(review): presumably this broadcasts one exponent per channel of
        # an (N, C, H, W) input, i.e. mp should be 1 or C -- confirm in LF.gem.
        exponents = self.p.unsqueeze(-1).unsqueeze(-1)
        return LF.gem(x, p=exponents, eps=self.eps)

    def __repr__(self):
        # Only the number of exponents is shown, not their values.
        return '{}(p=[{}], eps={})'.format(
            self.__class__.__name__, self.mp, self.eps)


class RMAC(nn.Module):
    """Pooling layer that applies ``LF.rmac`` to its input.

    Args:
        L: Forwarded to ``LF.rmac`` as ``L``.
        eps: Forwarded to ``LF.rmac`` as ``eps``.
    """

    def __init__(self, L=3, eps=1e-6):
        super(RMAC, self).__init__()
        self.L = L
        self.eps = eps

    def forward(self, x):
        """Return ``LF.rmac(x, L=self.L, eps=self.eps)``."""
        return LF.rmac(x, L=self.L, eps=self.eps)

    def __repr__(self):
        return '{}(L={})'.format(self.__class__.__name__, self.L)


class Rpool(nn.Module):
    """Regional pooling: pool regions, normalize, optionally whiten, aggregate.

    Args:
        rpool: Pooling module handed to ``LF.roipool`` for each region.
        whiten: Optional module applied to the flattened region vectors
            (shape ``(#im * #reg, D)``); ``None`` disables whitening.
        L: Forwarded to ``LF.roipool``.
        eps: Forwarded to ``LF.roipool``.
    """

    def __init__(self, rpool, whiten=None, L=3, eps=1e-6):
        super(Rpool, self).__init__()
        self.rpool = rpool
        self.L = L
        self.whiten = whiten
        self.norm = L2N()
        self.eps = eps

    def forward(self, x, aggregate=True):
        """Compute regional descriptors for ``x``.

        Args:
            x: Input feature tensor passed to ``LF.roipool``.
            aggregate: If true, sum the region vectors of each image and
                normalize the sum, yielding one vector per image; otherwise
                return the per-region vectors.

        Returns:
            Tensor of size ``(#im, D', 1, 1)`` when ``aggregate`` is true,
            else ``(#im, #reg, D', 1, 1)``. ``D'`` is the channel count
            after the optional whitening step.
        """
        # features -> roipool
        o = LF.roipool(x, self.rpool, self.L, self.eps)  # size: #im, #reg, D, 1, 1

        # concatenate regions from all images in the batch
        s = o.size()
        o = o.view(s[0] * s[1], s[2], s[3], s[4])  # size: #im x #reg, D, 1, 1

        # rvecs -> norm
        o = self.norm(o)

        # rvecs -> whiten -> norm
        if self.whiten is not None:
            o = self.norm(self.whiten(o.squeeze(-1).squeeze(-1)))

        # reshape back to regions per image; the channel count is read from
        # the tensor itself because whitening may change the descriptor
        # dimensionality, in which case the original s[2] no longer matches.
        o = o.view(s[0], s[1], o.size(1), s[3], s[4])  # size: #im, #reg, D', 1, 1

        # aggregate regions into a single global vector per image
        if aggregate:
            # rvecs -> sumpool -> norm
            o = self.norm(o.sum(1, keepdim=False))  # size: #im, D', 1, 1

        return o

    def __repr__(self):
        # Base nn.Module repr (which lists the child modules) followed by L.
        return super(Rpool, self).__repr__() + '(L={})'.format(self.L)