更新 detacttracking
This commit is contained in:
33
detecttracking/tracking/trackers/reid/model/Tool.py
Normal file
33
detecttracking/tracking/trackers/reid/model/Tool.py
Normal file
@ -0,0 +1,33 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
class GeM(nn.Module):
|
||||
def __init__(self, p=3, eps=1e-6):
|
||||
super(GeM, self).__init__()
|
||||
self.p = nn.Parameter(torch.ones(1) * p)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return self.gem(x, p=self.p, eps=self.eps, stride = 2)
|
||||
|
||||
def gem(self, x, p=3, eps=1e-6, stride = 2):
|
||||
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=2).pow(1. / p)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
'(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
|
||||
', ' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
def __init__(self, margin):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, anchor, positive, negative, size_average = True):
|
||||
distance_positive = (anchor-positive).pow(2).sum(1)
|
||||
distance_negative = (anchor-negative).pow(2).sum(1)
|
||||
losses = F.relu(distance_negative-distance_positive+self.margin)
|
||||
return losses.mean() if size_average else losses.sum()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('')
|
Reference in New Issue
Block a user