Files
ieemoo-ai-contrast/model/Tool.py
2025-06-11 15:23:50 +08:00

38 lines
1.1 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two (spatial) dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over a window that
    spans the whole ``(H, W)`` extent of the input, so every channel collapses
    to a single 1x1 value. ``p`` is a learnable ``nn.Parameter``: ``p = 1``
    reduces to average pooling, and larger ``p`` weights large activations
    more heavily (approaching max pooling as ``p`` grows).

    Args:
        p: Initial value of the learnable exponent.
        eps: Lower clamp applied to the input before raising it to ``p``, so
            the (possibly fractional) power is taken on strictly positive values.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Stored as a 1-element trainable tensor so the exponent is learned.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` using the module's learnable exponent.

        NOTE(review): presumably ``x`` is ``(N, C, H, W)`` (``F.avg_pool2d``
        also accepts ``(C, H, W)``) — confirm against callers. The last two
        dimensions are reduced to size 1.
        """
        return self.gem(x, p=self.p, eps=self.eps, stride=2)

    def gem(self, x, p=3, eps=1e-6, stride=2):
        """Functional GeM pooling with an explicit exponent ``p``.

        Args:
            x: Input tensor; its last two dimensions are pooled.
            p: Exponent (a number or a tensor broadcastable against ``x``).
            eps: Minimum value ``x`` is clamped to before ``pow``.
            stride: Passed to ``F.avg_pool2d``. The kernel covers the full
                spatial extent, so there is exactly one window and the result
                does not depend on this value.

        Returns:
            Tensor shaped like ``x`` but with the last two dimensions equal to 1.
        """
        kernel = (x.size(-2), x.size(-1))
        pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), kernel, stride=stride)
        return pooled.pow(1.0 / p)

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p.item():.4f}, eps={self.eps})"
class TripletLoss(nn.Module):
    """Margin-based triplet loss on squared Euclidean distances.

    For each row ``i`` (distances are summed over dim 1)::

        loss_i = max(0, ||a_i - p_i||^2 - ||a_i - n_i||^2 + margin)

    The loss is zero once the negative is at least ``margin`` farther (in
    squared distance) from the anchor than the positive is.

    Args:
        margin: Required gap between the negative and positive squared distances.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Return the triplet loss for a batch of embeddings.

        Args:
            anchor: Anchor embeddings.
            positive: Embeddings that should lie close to ``anchor``.
            negative: Embeddings that should lie far from ``anchor``.
                NOTE(review): all three are presumably ``(batch, dim)`` —
                squared differences are summed over dim 1; confirm with callers.
            size_average: If True, return the mean of the per-row losses;
                otherwise return their sum.

        Returns:
            A 0-dim tensor.
        """
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        # d_pos - d_neg: minimizing this pulls positives toward the anchor and
        # pushes negatives away; the margin keeps easy triplets at zero loss.
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()
if __name__ == "__main__":
    # Running this module directly performs no work beyond emitting one
    # empty line; importing it only defines the classes.
    print("")