86 lines
2.4 KiB
Python
86 lines
2.4 KiB
Python
import torch
|
|
from torchvision import models
|
|
import torch.nn as nn
|
|
from network.BaseNet import *
|
|
from network.mobilevit import *
|
|
|
|
def initnet(flag='resnet50', mvit=128, freeze_backbone=False):
    """Build the backbone network selected by ``flag``.

    Args:
        flag: Which network to build:
            ``'resnet50'`` - torchvision ResNet-50 loaded with
            ``pretrained=True``, whose ``fc`` layer is replaced by
            ``nn.Linear(fc.in_features, 2048)``;
            ``'resnet50_fpn'`` - ``ResnetFpn()`` (from ``network.BaseNet``);
            ``'mobilevit'`` - ``mobilevit_s(mvit)`` (from ``network.mobilevit``).
        mvit: Argument forwarded to ``mobilevit_s``; only used when
            ``flag == 'mobilevit'``.
        freeze_backbone: Only used for ``'resnet50'``. When True, every
            pretrained parameter gets ``requires_grad = False`` before the new
            ``fc`` head is attached, so only the head stays trainable.
            Defaults to False, which is what the previous implementation
            effectively did (its freeze loop had no effect, see below).

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``flag`` is not one of the supported names.
    """
    if flag == 'resnet50':
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=``; kept here for older torchvision versions.
        model_ft = models.resnet50(pretrained=True)
        if freeze_backbone:
            # The previous loop assigned ``param.require_grad`` (missing "s"),
            # which merely created an unused attribute, so nothing was frozen.
            # The autograd flag is ``requires_grad``.
            for param in model_ft.parameters():
                param.requires_grad = False
        num_ftrs = model_ft.fc.in_features
        # The head is created after freezing, so its parameters stay trainable.
        model_ft.fc = nn.Linear(num_ftrs, 2048)
        return model_ft

    if flag == 'resnet50_fpn':
        return ResnetFpn()

    if flag == 'mobilevit':
        return mobilevit_s(mvit)

    raise ValueError(
        "Please select the correct model: unknown flag {!r}; expected one of "
        "'resnet50', 'resnet50_fpn', 'mobilevit'.".format(flag)
    )
|
|
|
|
class L2N(nn.Module):
    """L2-normalise a tensor along dimension 1.

    Every slice along dim 1 is divided by its Euclidean norm plus ``eps``.
    The ``eps`` term keeps the division finite when a slice is all zeros.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # keepdim=True leaves a size-1 dim 1, so the divisor broadcasts over x.
        denom = x.norm(p=2, dim=1, keepdim=True) + self.eps
        return x / denom

    def __repr__(self):
        return '{}(eps={})'.format(type(self).__name__, self.eps)
|
|
|
|
|
|
class TripletNet(nn.Module):
    """Apply one shared backbone to three inputs.

    The same ``initnet`` object embeds all three inputs, so the branches share
    weights. Presumably the inputs are an (anchor, positive, negative) triple
    for a triplet loss; confirm against the training code.
    """

    def __init__(self, initnet):
        super().__init__()
        self.initnet = initnet

    def forward(self, x1, x2, x3):
        # The inputs are embedded one after another, in x1, x2, x3 order.
        embed = self.initnet
        return tuple(embed(sample) for sample in (x1, x2, x3))

    def get_ininet(self, x):
        """Embed a single input with the shared backbone."""
        return self.initnet(x)
|
|
|
|
class extractNet(nn.Module):
    """Feature extractor: a backbone followed by a pooling/normalisation head.

    ``forward`` runs ``initnet``, then ``norm``, and finally drops up to two
    trailing size-1 dimensions (e.g. the 1x1 spatial map left by a global
    pooling head).
    """

    def __init__(self, initnet, norm):
        super().__init__()
        self.initnet = initnet
        self.norm = norm

    def forward(self, x):
        pooled = self.norm(self.initnet(x))
        # squeeze(-1) is a no-op when the last dim is not 1, so outputs that
        # are already flat pass through unchanged.
        for _ in range(2):
            pooled = pooled.squeeze(-1)
        return pooled

    def get_ininet(self, x):
        """Return the raw backbone output for ``x``, without ``norm``."""
        return self.initnet(x)
|
|
|
|
import torch.nn.functional as F
|
|
class GeM(nn.Module):
    """Generalised-mean (GeM) pooling over the last two dimensions.

    Computes ``avg_pool(clamp(x, min=eps) ** p) ** (1 / p)``. The pooling
    kernel spans the full extent of the last two dims, so each channel
    collapses to a 1x1 map. The exponent ``p`` is a learnable one-element
    ``nn.Parameter`` initialised from the ``p`` argument.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """GeM-pool ``x`` with exponent ``p`` (tensor or number) and floor ``eps``."""
        # Clamping to >= eps keeps values positive before the 1/p root.
        powered = x.clamp(min=eps).pow(p)
        window = (x.size(-2), x.size(-1))
        return F.avg_pool2d(powered, window).pow(1. / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return '{}(p={:.4f}, eps={})'.format(type(self).__name__, p_value, self.eps)
|
|
|
|
if __name__ == '__main__':
    # Smoke check: build the default ('resnet50') backbone and print its layers.
    model = initnet()
    print(model)
|
|
|