import torch
import torch.nn as nn
import torch.nn.functional as F


class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions of the input.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over the whole extent
    of the last two dims, where ``p`` is a learnable scalar parameter.
    With ``p = 1`` this is plain average pooling of the clamped input.
    """

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        # Learnable pooling exponent, stored as a 1-element tensor.
        self.p = nn.Parameter(torch.ones(1) * p)
        # Lower bound applied to the input before raising it to the power p.
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps, stride=2)

    def gem(self, x, p=3, eps=1e-6, stride=2):
        """Return the GeM-pooled version of ``x``.

        Args:
            x: input tensor; the pooling window spans its last two dims
                (presumably (N, C, H, W) — TODO confirm with callers).
            p: pooling exponent (a number or a tensor broadcastable with ``x``).
            eps: minimum value ``x`` is clamped to before ``pow(p)``.
            stride: stride forwarded to ``F.avg_pool2d``. Because the kernel
                covers the full extent of the last two dims, those dims are
                reduced to size 1 for any positive stride.

        Returns:
            Tensor with the last two dims reduced to size 1.
        """
        kernel = (x.size(-2), x.size(-1))
        # Clamp keeps values strictly positive so the fractional power 1/p is defined.
        pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), kernel, stride=stride)
        return pooled.pow(1. / p)

    def __repr__(self):
        return self.__class__.__name__ + \
            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
            ', ' + 'eps=' + str(self.eps) + ')'


class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    Per row: ``loss = max(0, ||a - p||^2 - ||a - n||^2 + margin)``, i.e. the
    anchor is pushed to be closer to the positive than to the negative by at
    least ``margin``.
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Compute the triplet loss.

        Args:
            anchor, positive, negative: embedding tensors of equal shape;
                distances are summed over dim 1 (presumably (batch, dim) —
                TODO confirm with callers).
            size_average: if True return the mean over rows, else the sum.

        Returns:
            Scalar loss tensor.
        """
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        # Penalize when the positive is not closer than the negative by `margin`.
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


if __name__ == '__main__':
    print('')