Files
ieemoo-ai-conpurchase/network/network.py
2023-06-25 13:55:22 +08:00

85 lines
2.5 KiB
Python

import torch
from torchvision import models
import torch.nn as nn
from Networks.BaseNet import *
from Networks.mobilevit import *
def initnet(flag='resnet50', mvit=128):
    """Build the backbone network selected by ``flag``.

    Args:
        flag: Name of the backbone. Supported values are ``'resnet50'``,
            ``'resnet50_fpn'`` and ``'mobilevit'``.
        mvit: Passed unchanged as the only argument to ``mobilevit_s``; it is
            ignored for every other ``flag``.
            NOTE(review): its meaning (image size? embedding size?) is
            defined by ``mobilevit_s`` and is not visible here -- confirm.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``flag`` is not one of the supported names.
    """
    if flag == 'resnet50':
        model_ft = models.resnet50(pretrained=True)
        # Freeze the pretrained weights. Autograd reads ``requires_grad``;
        # the earlier spelling ``require_grad`` merely attached an unused
        # attribute to each tensor and left the whole backbone trainable.
        for param in model_ft.parameters():
            param.requires_grad = False
        # The replacement head is created after the freeze, so its freshly
        # initialised parameters keep the default ``requires_grad=True``.
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, 2048)
        return model_ft
    if flag == 'resnet50_fpn':
        return ResnetFpn()
    if flag == 'mobilevit':
        return mobilevit_s(mvit)
    raise ValueError("Please select the correct model .......")
class L2N(nn.Module):
    """Scale the input to unit L2 norm along dimension 1.

    ``eps`` is added to the norm before dividing so that an all-zero slice
    does not cause a division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its (eps-stabilised) L2 norm over dim 1."""
        norms = torch.norm(x, p=2, dim=1, keepdim=True)
        denominator = (norms + self.eps).expand_as(x)
        return x / denominator

    def __repr__(self):
        return f"{self.__class__.__name__}(eps={self.eps!s})"
class TripletNet(nn.Module):
    """Run one shared backbone over the three inputs of a triplet."""

    def __init__(self, initnet):
        super().__init__()
        # The same module instance handles all three branches, so the
        # branches share their weights.
        self.initnet = initnet

    def forward(self, x1, x2, x3):
        """Return the backbone outputs for ``x1``, ``x2``, ``x3`` as a tuple."""
        backbone = self.initnet
        return backbone(x1), backbone(x2), backbone(x3)

    def get_ininet(self, x):
        """Apply the backbone to a single input."""
        return self.initnet(x)
class extractNet(nn.Module):
    """Backbone followed by a normalisation module, used to extract features."""

    def __init__(self, initnet, norm):
        super().__init__()
        self.initnet = initnet
        self.norm = norm

    def forward(self, x):
        """Return the normalised backbone output with trailing unit dims removed."""
        features = self.initnet(x)
        normalised = self.norm(features)
        # Two squeezes on the last axis turn e.g. (N, C, 1, 1) into (N, C);
        # a last axis whose size is not 1 is left untouched.
        return normalised.squeeze(-1).squeeze(-1)

    def get_ininet(self, x):
        """Apply only the backbone, skipping the normalisation step."""
        return self.initnet(x)
import torch.nn.functional as F
class GeM(nn.Module):
    """Generalised-mean pooling over the last two dimensions.

    The exponent ``p`` is stored as a one-element ``nn.Parameter``; ``eps``
    is the lower clamp applied to the input before it is raised to ``p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` using the module's own ``p`` and ``eps``."""
        return self.gem(x, self.p, self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Return ``mean(clamp(x, eps) ** p) ** (1 / p)`` over the last two dims."""
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        # A kernel covering the full window leaves a 1 x 1 spatial map.
        pooled = F.avg_pool2d(powered, window)
        return pooled.pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps!s})"
if __name__ == '__main__':
    # Smoke check: build the default backbone and print its repr.
    default_model = initnet()
    print(default_model)