first commit
This commit is contained in:
85
network/createNet.py
Normal file
85
network/createNet.py
Normal file
@ -0,0 +1,85 @@
|
||||
import torch
|
||||
from torchvision import models
|
||||
import torch.nn as nn
|
||||
from network.BaseNet import *
|
||||
from network.mobilevit import *
|
||||
|
||||
def initnet(flag='resnet50', mvit = 128):
|
||||
if flag == 'resnet50':
|
||||
model_ft = models.resnet50(pretrained=True)
|
||||
for param in model_ft.parameters():
|
||||
param.require_grad = False
|
||||
num_ftrs = model_ft.fc.in_features
|
||||
model_ft.fc = nn.Linear(num_ftrs, 2048)
|
||||
return model_ft
|
||||
elif flag == 'resnet50_fpn':
|
||||
model_ft = ResnetFpn()
|
||||
return model_ft
|
||||
elif flag == 'mobilevit':
|
||||
model_ft = mobilevit_s(mvit)
|
||||
return model_ft
|
||||
else:
|
||||
raise ValueError("Please select the correct model .......")
|
||||
|
||||
class L2N(nn.Module):
|
||||
|
||||
def __init__(self, eps=1e-6):
|
||||
super(L2N,self).__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return x / (torch.norm(x, p=2, dim=1, keepdim=True) + self.eps).expand_as(x)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
|
||||
class TripletNet(nn.Module):
|
||||
def __init__(self, initnet):
|
||||
super(TripletNet, self).__init__()
|
||||
self.initnet =initnet
|
||||
|
||||
def forward(self, x1, x2, x3):
|
||||
output1 = self.initnet(x1)
|
||||
output2 = self.initnet(x2)
|
||||
output3 = self.initnet(x3)
|
||||
return output1, output2, output3
|
||||
|
||||
def get_ininet(self, x):
|
||||
return self.initnet(x)
|
||||
|
||||
class extractNet(nn.Module):
|
||||
def __init__(self, initnet, norm):
|
||||
super(extractNet, self).__init__()
|
||||
self.initnet =initnet
|
||||
self.norm = norm
|
||||
|
||||
def forward(self, x):
|
||||
output = self.initnet(x)
|
||||
output = self.norm(output).squeeze(-1).squeeze(-1)
|
||||
return output
|
||||
|
||||
def get_ininet(self, x):
|
||||
return self.initnet(x)
|
||||
|
||||
import torch.nn.functional as F
|
||||
class GeM(nn.Module):
|
||||
def __init__(self, p=3, eps=1e-6):
|
||||
super(GeM, self).__init__()
|
||||
self.p = nn.Parameter(torch.ones(1) * p)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return self.gem(x, p=self.p, eps=self.eps)
|
||||
|
||||
def gem(self, x, p=3, eps=1e-6):
|
||||
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
'(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
|
||||
', ' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(initnet())
|
||||
|
Reference in New Issue
Block a user