"""Backbone factory plus wrappers for triplet training, feature extraction and pooling.

Defines ``initnet`` (backbone selection), ``L2N`` (L2 normalisation),
``TripletNet`` / ``extractNet`` (wrappers around a backbone) and ``GeM``
(generalized-mean pooling).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from Networks.BaseNet import *
from Networks.mobilevit import *


def _pretrained_resnet50():
    """Return torchvision's ResNet-50 loaded with ImageNet weights.

    Uses the ``weights=`` API on torchvision >= 0.13 (where ``pretrained=True``
    is deprecated). On older releases it falls back to ``pretrained=True``.
    Both select the IMAGENET1K_V1 checkpoint.
    """
    weights_enum = getattr(models, 'ResNet50_Weights', None)
    if weights_enum is None:
        return models.resnet50(pretrained=True)
    return models.resnet50(weights=weights_enum.IMAGENET1K_V1)


def initnet(flag='resnet50', mvit=128, *, freeze_backbone=False):
    """Build the backbone network selected by ``flag``.

    Args:
        flag: ``'resnet50'``, ``'resnet50_fpn'`` or ``'mobilevit'``.
        mvit: Forwarded to ``mobilevit_s`` when ``flag == 'mobilevit'``;
            ignored otherwise.
        freeze_backbone: Only used for ``'resnet50'``. When True, every
            pretrained parameter gets ``requires_grad = False`` and only the
            new ``fc`` head remains trainable. Defaults to False, which matches
            what this function always did in practice. The earlier freeze loop
            assigned the misspelled attribute ``require_grad`` and therefore
            froze nothing.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``flag`` is not one of the supported names.
    """
    if flag == 'resnet50':
        model_ft = _pretrained_resnet50()
        if freeze_backbone:
            for param in model_ft.parameters():
                param.requires_grad = False
        # Swap the pretrained classifier for a 2048-d linear head. It is created
        # after the optional freeze, so its parameters stay trainable.
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, 2048)
        return model_ft
    if flag == 'resnet50_fpn':
        return ResnetFpn()
    if flag == 'mobilevit':
        return mobilevit_s(mvit)
    raise ValueError("Please select the correct model .......")


class L2N(nn.Module):
    """L2-normalise the input along dim 1: ``x / (||x||_2 + eps)``."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps  # added to the norm so an all-zero row does not divide by zero

    def forward(self, x):
        # keepdim=True keeps the reduced dim, so the norm broadcasts against x.
        return x / (torch.norm(x, p=2, dim=1, keepdim=True) + self.eps)

    def __repr__(self):
        return f'{self.__class__.__name__}(eps={self.eps})'


class TripletNet(nn.Module):
    """Run one shared (weight-tied) network over the three inputs of a triplet."""

    def __init__(self, initnet):
        super().__init__()
        self.initnet = initnet

    def forward(self, x1, x2, x3):
        """Return ``(initnet(x1), initnet(x2), initnet(x3))``.

        The inputs are presumably (anchor, positive, negative). That is not
        enforced here; TODO confirm against the training loop.
        """
        return self.initnet(x1), self.initnet(x2), self.initnet(x3)

    def get_ininet(self, x):
        """Return the wrapped network's output for a single input ``x``."""
        return self.initnet(x)


class extractNet(nn.Module):
    """Feature extractor: wrapped network, then ``norm``, then trailing size-1 dims dropped."""

    def __init__(self, initnet, norm):
        super().__init__()
        self.initnet = initnet
        # Post-processing module applied to the network output (presumably
        # GeM and/or L2N from this file). NOTE(review): confirm against callers.
        self.norm = norm

    def forward(self, x):
        output = self.norm(self.initnet(x))
        # squeeze(-1) is a no-op when the last dim is not size 1. Together the two
        # calls drop up to two trailing singleton dims, e.g. (N, C, 1, 1) -> (N, C).
        return output.squeeze(-1).squeeze(-1)

    def get_ininet(self, x):
        """Return the wrapped network's raw output, without ``norm`` or squeezing."""
        return self.initnet(x)


class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two (spatial) dims.

    Computes ``mean(x.clamp(min=eps) ** p) ** (1 / p)`` over each H x W map.
    A (N, C, H, W) input therefore becomes (N, C, 1, 1). ``p`` is a learnable
    ``nn.Parameter`` initialised to the constructor value; p = 1 reduces to
    average pooling.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # Clamping keeps the base strictly positive, so fractional powers stay finite.
        # The kernel covers the full spatial extent, so each map pools to one value.
        pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)))
        return pooled.pow(1. / p)

    def __repr__(self):
        return f'{self.__class__.__name__}(p={self.p.item():.4f}, eps={self.eps})'


if __name__ == '__main__':
    print(initnet())