34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two (spatial) dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over the whole
    spatial extent, so every channel is reduced to a single 1x1 value.
    ``p`` is a learnable ``nn.Parameter``: p=1 is plain average pooling and
    larger exponents weight large activations more heavily.

    Args:
        p: Initial value of the learnable pooling exponent.
        eps: Lower clamp applied to the input before ``pow(p)`` so that the
            (possibly fractional) power is only taken of positive values.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # A 1-element learnable tensor, so the exponent is optimized together
        # with the rest of the network.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Apply GeM pooling to ``x`` using the learned exponent ``self.p``."""
        return self.gem(x, p=self.p, eps=self.eps, stride=2)

    def gem(self, x, p=3, eps=1e-6, stride=2):
        """Return the generalized mean of ``x`` over its last two dimensions.

        Args:
            x: Input accepted by ``F.avg_pool2d`` -- presumably (N, C, H, W)
                or (C, H, W); confirm against callers.
            p: Pooling exponent (float, or tensor broadcastable against x).
            eps: Minimum value ``x`` is clamped to before raising it to ``p``.
            stride: Stride forwarded to ``F.avg_pool2d``. The kernel covers the
                full spatial extent, so only one window fits and any positive
                stride gives the same result.

        Returns:
            Tensor with the leading dimensions of ``x`` and spatial size 1x1.
        """
        kernel = (x.size(-2), x.size(-1))
        pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), kernel, stride=stride)
        return pooled.pow(1. / p)

    def __repr__(self):
        return (f"{self.__class__.__name__}"
                f"(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})")
|
|
|
|
class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    Per triplet the loss is ``max(0, d(a, p) - d(a, n) + margin)`` where ``d``
    is the squared L2 distance: the anchor must be closer to the positive
    than to the negative by at least ``margin`` to incur zero loss.

    Args:
        margin: Required gap between the negative and positive distances.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Compute the triplet loss for a batch of triplets.

        Args:
            anchor, positive, negative: Embedding tensors of broadcast-
                compatible shapes; squared differences are summed over dim 1,
                so presumably (batch, dim) -- confirm against callers.
            size_average: If True return the mean of the per-triplet losses,
                otherwise their sum.

        Returns:
            Scalar loss tensor.
        """
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        # Penalize triplets whose positive is not closer than the negative by
        # at least ``margin``; well-separated triplets contribute zero.
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()
|
|
|
|
if __name__ == '__main__':
    # Executing this module directly does no work beyond emitting a blank line.
    print()
|