Parse returned data, compatible with v5 and v10
88
contrast/feat_extract/model/BAM.py
Normal file
@@ -0,0 +1,88 @@
import torch.nn as nn
import torchvision
from torch.nn import init


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class ChannelAttention(nn.Module):
    def __init__(self, channel, reduction=16, num_layers=3):  # was misspelled __int__, so it never ran
        super(ChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        gate_channels = [channel]
        gate_channels += [channel // reduction] * num_layers  # was len(channel) // reduction, a TypeError for an int
        gate_channels += [channel]

        self.ca = nn.Sequential()
        self.ca.add_module('flatten', Flatten())
        for i in range(len(gate_channels) - 2):
            # unique names: add_module('') would silently overwrite the previous layer
            self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))
            self.ca.add_module('relu%d' % i, nn.ReLU())
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        res = self.avgpool(x)
        res = self.ca(res)
        res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
        return res


class SpatialAttention(nn.Module):
    def __init__(self, channel, reduction=16, num_lay=3, dilation=2):  # was __int__
        super(SpatialAttention, self).__init__()  # was super(SpatialAttention).__init__(), which skips Module init
        self.sa = nn.Sequential()
        # reduce to channel // reduction; the original's (channel // reduction) * 3 mismatched the BatchNorm below
        self.sa.add_module('conv_reduce', nn.Conv2d(kernel_size=1, in_channels=channel,
                                                    out_channels=channel // reduction))
        self.sa.add_module('bn_reduce', nn.BatchNorm2d(channel // reduction))
        self.sa.add_module('relu_reduce', nn.ReLU())
        for i in range(num_lay):
            # padding must equal the dilation so the spatial size is preserved and expand_as(x) works
            self.sa.add_module('conv%d' % i, nn.Conv2d(kernel_size=3,
                                                       in_channels=channel // reduction,
                                                       out_channels=channel // reduction,
                                                       padding=dilation,
                                                       dilation=dilation))
            self.sa.add_module('bn%d' % i, nn.BatchNorm2d(channel // reduction))
            self.sa.add_module('relu%d' % i, nn.ReLU())
        self.sa.add_module('last_conv', nn.Conv2d(channel // reduction, 1, kernel_size=1))

    def forward(self, x):
        res = self.sa(x)
        res = res.expand_as(x)
        return res


class BAMblock(nn.Module):
    def __init__(self, channel=512, reduction=16, dia_val=2):
        super(BAMblock, self).__init__()
        self.ca = ChannelAttention(channel, reduction)
        self.sa = SpatialAttention(channel, reduction, dilation=dia_val)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')  # kaiming_normal (no underscore) is deprecated
                if m.bias is not None:  # was m.bais
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        sa_out = self.sa(x)
        ca_out = self.ca(x)
        weight = self.sigmoid(sa_out + ca_out)
        out = (1 + weight) * x
        return out


if __name__ == "__main__":
    print(512 // 14)
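With the __int__ typos corrected above, BAMblock can be exercised directly. A minimal smoke test (a sketch with arbitrary shapes, not part of the commit) with the classes above in scope:

import torch

x = torch.randn(4, 512, 7, 7)               # a batch of 512-channel feature maps
bam = BAMblock(channel=512, reduction=16, dia_val=2)
bam.init_weights()
out = bam(x)
print(out.shape)                            # torch.Size([4, 512, 7, 7]); BAM is shape-preserving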
70
contrast/feat_extract/model/CBAM.py
Normal file
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.init as init


class channelAttention(nn.Module):
    def __init__(self, channel, reduction=16):
        super(channelAttention, self).__init__()
        self.Maxpooling = nn.AdaptiveMaxPool2d(1)
        self.Avepooling = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential()
        self.ca.add_module('conv1', nn.Conv2d(channel, channel // reduction, 1, bias=False))
        self.ca.add_module('Relu', nn.ReLU())
        self.ca.add_module('conv2', nn.Conv2d(channel // reduction, channel, 1, bias=False))
        self.sigmoid = nn.Sigmoid()  # was misspelled self.sigmod

    def forward(self, x):
        M_out = self.Maxpooling(x)
        A_out = self.Avepooling(x)
        M_out = self.ca(M_out)
        A_out = self.ca(A_out)
        out = self.sigmoid(M_out + A_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result, _ = torch.max(x, dim=1, keepdim=True)
        avg_result = torch.mean(x, dim=1, keepdim=True)
        result = torch.cat([max_result, avg_result], dim=1)
        output = self.conv(result)
        output = self.sigmoid(output)
        return output


class CBAM(nn.Module):
    def __init__(self, channel, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = channelAttention(channel, reduction)
        self.sa = SpatialAttention(kernel_size)

    def init_weights(self):
        for m in self.modules():  # weight initialization
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        # b, c, _, _ = x.size()
        # residual = x
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    kernel_size = input.shape[2]
    cbam = CBAM(channel=512, reduction=16, kernel_size=kernel_size)
    output = cbam(input)
    print(output.shape)
33
contrast/feat_extract/model/Tool.py
Normal file
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps, stride=2)

    def gem(self, x, p=3, eps=1e-6, stride=2):
        # use the stride argument instead of the hard-coded 2 it shadowed
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=stride).pow(1. / p)

    def __repr__(self):
        return self.__class__.__name__ + \
            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
            ', ' + 'eps=' + str(self.eps) + ')'


class TripletLoss(nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        # hinge on d(a, p) - d(a, n) + margin; the two operands were swapped in the original,
        # which would reward pulling negatives closer than positives
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


if __name__ == '__main__':
    print('')
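GeM pooling computes (mean of x^p)^(1/p) over the spatial grid, collapsing an (N, C, H, W) map to (N, C, 1, 1); with p learnable it interpolates between average pooling (p=1) and max pooling (p→∞). A minimal usage sketch (arbitrary shapes, not part of the commit), with the classes above in scope:

import torch

feat = torch.randn(2, 512, 7, 7)
pool = GeM(p=3)
print(pool(feat).shape)                      # torch.Size([2, 512, 1, 1])

a, p, n = torch.randn(8, 128), torch.randn(8, 128), torch.randn(8, 128)
print(TripletLoss(margin=0.3)(a, p, n))      # scalar hinge loss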
11
contrast/feat_extract/model/__init__.py
Normal file
@@ -0,0 +1,11 @@
from .fmobilenet import FaceMobileNet
from .resnet_face import ResIRSE
from .mobilevit import mobilevit_s
from .metric import ArcFace, CosFace
from .loss import FocalLoss
from .resbam import resnet
from .resnet_pre import resnet18, resnet34, resnet50, resnet14
from .mobilenet_v2 import mobilenet_v2
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
# from .mobilenet_v1 import mobilenet_v1
from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, PPLCNET_x2_5
BIN  contrast/feat_extract/model/__pycache__/BAM.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/CBAM.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/CBAM.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/Tool.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/Tool.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/__init__.cpython-310.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/__init__.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/__init__.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc  Normal file (binary file not shown)
BIN  (4 unnamed binary files, not shown)
BIN  contrast/feat_extract/model/__pycache__/lcnet.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/lcnet.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/loss.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/loss.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/metric.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/metric.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc  Normal file (binary file not shown)
BIN  (10 unnamed binary files, not shown)
BIN  contrast/feat_extract/model/__pycache__/mobilevit.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/resbam.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/resbam.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/resnet.cpython-310.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/resnet.cpython-38.pyc  Normal file (binary file not shown)
BIN  (6 unnamed binary files, not shown)
BIN  contrast/feat_extract/model/__pycache__/utils.cpython-312.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/utils.cpython-38.pyc  Normal file (binary file not shown)
BIN  contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc  Normal file (binary file not shown)
124
contrast/feat_extract/model/fmobilenet.py
Normal file
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)


class ConvBn(nn.Module):

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_c)
        )

    def forward(self, x):
        return self.net(x)


class ConvBnPrelu(nn.Module):

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
        super().__init__()
        self.net = nn.Sequential(
            ConvBn(in_c, out_c, kernel, stride, padding, groups),
            nn.PReLU(out_c)
        )

    def forward(self, x):
        return self.net(x)


class DepthWise(nn.Module):

    def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
        super().__init__()
        self.net = nn.Sequential(
            ConvBnPrelu(in_c, groups, kernel=(1, 1), stride=1, padding=0),
            ConvBnPrelu(groups, groups, kernel=kernel, stride=stride, padding=padding, groups=groups),
            ConvBn(groups, out_c, kernel=(1, 1), stride=1, padding=0),
        )

    def forward(self, x):
        return self.net(x)


class DepthWiseRes(nn.Module):
    """DepthWise with Residual"""

    def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
        super().__init__()
        self.net = DepthWise(in_c, out_c, kernel, stride, padding, groups)

    def forward(self, x):
        return self.net(x) + x


class MultiDepthWiseRes(nn.Module):

    def __init__(self, num_block, channels, kernel=(3, 3), stride=1, padding=1, groups=1):
        super().__init__()

        self.net = nn.Sequential(*[
            DepthWiseRes(channels, channels, kernel, stride, padding, groups)
            for _ in range(num_block)
        ])

    def forward(self, x):
        return self.net(x)


class FaceMobileNet(nn.Module):

    def __init__(self, embedding_size):
        super().__init__()
        self.conv1 = ConvBnPrelu(1, 64, kernel=(3, 3), stride=2, padding=1)
        self.conv2 = ConvBn(64, 64, kernel=(3, 3), stride=1, padding=1, groups=64)
        self.conv3 = DepthWise(64, 64, kernel=(3, 3), stride=2, padding=1, groups=128)
        self.conv4 = MultiDepthWiseRes(num_block=4, channels=64, kernel=3, stride=1, padding=1, groups=128)
        self.conv5 = DepthWise(64, 128, kernel=(3, 3), stride=2, padding=1, groups=256)
        self.conv6 = MultiDepthWiseRes(num_block=6, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
        self.conv7 = DepthWise(128, 128, kernel=(3, 3), stride=2, padding=1, groups=512)
        self.conv8 = MultiDepthWiseRes(num_block=2, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
        self.conv9 = ConvBnPrelu(128, 512, kernel=(1, 1))
        self.conv10 = ConvBn(512, 512, groups=512, kernel=(7, 7))
        self.flatten = Flatten()
        self.linear = nn.Linear(2048, embedding_size, bias=False)
        self.bn = nn.BatchNorm1d(embedding_size)

    def forward(self, x):
        # print('x', x.shape)
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = self.conv8(out)
        out = self.conv9(out)
        out = self.conv10(out)
        out = self.flatten(out)
        out = self.linear(out)
        out = self.bn(out)
        return out


if __name__ == "__main__":
    from PIL import Image
    import numpy as np

    x = Image.open("../samples/009.jpg").convert('L')
    x = x.resize((128, 128))
    x = np.asarray(x, dtype=np.float32)
    x = x[None, None, ...]
    x = torch.from_numpy(x)
    net = FaceMobileNet(512)
    net.eval()
    with torch.no_grad():
        out = net(x)
    print(out.shape)
233
contrast/feat_extract/model/lcnet.py
Normal file
@@ -0,0 +1,233 @@
import os
import torch
import torch.nn as nn
import thop

# try:
#     import softpool_cuda
#     from SoftPool import soft_pool2d, SoftPool2d
# except ImportError:
#     print('Please install SoftPool first: https://github.com/alexandrosstergiou/SoftPool')
#     exit(0)

NET_CONFIG = {
    # k, in_c, out_c, s, use_se
    "blocks2": [[3, 16, 32, 1, False]],
    "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
    "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
    "blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
                [5, 256, 256, 1, False], [5, 256, 256, 1, False],
                [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
    "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
}


def autopad(k, p=None):
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class HardSwish(nn.Module):
    def __init__(self, inplace=True):
        super(HardSwish, self).__init__()
        self.relu6 = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return x * self.relu6(x + 3) / 6


class HardSigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(HardSigmoid, self).__init__()
        self.relu6 = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return (self.relu6(x + 3)) / 6


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            HardSigmoid()
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class DepthwiseSeparable(nn.Module):
    def __init__(self, inp, oup, dw_size, stride, use_se=False):
        super(DepthwiseSeparable, self).__init__()
        self.use_se = use_se
        self.stride = stride
        self.inp = inp
        self.oup = oup
        self.dw_size = dw_size
        self.dw_sp = nn.Sequential(
            nn.Conv2d(self.inp, self.inp, kernel_size=self.dw_size, stride=self.stride,
                      padding=autopad(self.dw_size, None), groups=self.inp, bias=False),
            nn.BatchNorm2d(self.inp),
            HardSwish(),

            nn.Conv2d(self.inp, self.oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.oup),
            HardSwish(),
        )
        self.se = SELayer(self.oup)

    def forward(self, x):
        x = self.dw_sp(x)
        if self.use_se:
            x = self.se(x)
        return x


class PP_LCNet(nn.Module):
    def __init__(self, scale=1.0, class_num=10, class_expand=1280, dropout_prob=0.2):
        super(PP_LCNet, self).__init__()
        self.scale = scale
        self.conv1 = nn.Conv2d(3, out_channels=make_divisible(16 * self.scale),
                               kernel_size=3, stride=2, padding=1, bias=False)
        # k, in_c, out_c, s, use_se  ->  inp, oup, dw_size, stride, use_se=False
        self.blocks2 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks2"])
        ])

        self.blocks3 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks3"])
        ])

        self.blocks4 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks4"])
        ])
        # k, in_c, out_c, s, use_se  ->  inp, oup, dw_size, stride, use_se=False
        self.blocks5 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks5"])
        ])

        self.blocks6 = nn.Sequential(*[
            DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
                               oup=make_divisible(out_c * self.scale),
                               dw_size=k, stride=s, use_se=use_se)
            for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks6"])
        ])

        self.GAP = nn.AdaptiveAvgPool2d(1)

        self.last_conv = nn.Conv2d(in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
                                   out_channels=class_expand,
                                   kernel_size=1, stride=1, padding=0, bias=False)

        self.hardswish = HardSwish()
        self.dropout = nn.Dropout(p=dropout_prob)

        self.fc = nn.Linear(class_expand, class_num)

    def forward(self, x):
        x = self.conv1(x)
        print(x.shape)
        x = self.blocks2(x)
        print(x.shape)
        x = self.blocks3(x)
        print(x.shape)
        x = self.blocks4(x)
        print(x.shape)
        x = self.blocks5(x)
        print(x.shape)
        x = self.blocks6(x)
        print(x.shape)

        x = self.GAP(x)
        x = self.last_conv(x)
        x = self.hardswish(x)
        x = self.dropout(x)
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = self.fc(x)
        return x


def PPLCNET_x0_25(**kwargs):
    model = PP_LCNet(scale=0.25, **kwargs)
    return model


def PPLCNET_x0_35(**kwargs):
    model = PP_LCNet(scale=0.35, **kwargs)
    return model


def PPLCNET_x0_5(**kwargs):
    model = PP_LCNet(scale=0.5, **kwargs)
    return model


def PPLCNET_x0_75(**kwargs):
    model = PP_LCNet(scale=0.75, **kwargs)
    return model


def PPLCNET_x1_0(**kwargs):
    model = PP_LCNet(scale=1.0, **kwargs)
    return model


def PPLCNET_x1_5(**kwargs):
    model = PP_LCNet(scale=1.5, **kwargs)
    return model


def PPLCNET_x2_0(**kwargs):
    model = PP_LCNet(scale=2.0, **kwargs)
    return model


def PPLCNET_x2_5(**kwargs):
    model = PP_LCNet(scale=2.5, **kwargs)
    return model


if __name__ == '__main__':
    # input = torch.randn(1, 3, 640, 640)
    # model = PPLCNET_x2_5()
    # flops, params = thop.profile(model, inputs=(input,))
    # print('flops:', flops / 1000000000)
    # print('params:', params / 1000000)

    model = PPLCNET_x1_0()
    # model_1 = PW_Conv(3, 16)
    input = torch.randn(2, 3, 256, 256)
    print(input.shape)
    output = model(input)
    print(output.shape)  # [1, num_class]
18
contrast/feat_extract/model/loss.py
Normal file
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn


class FocalLoss(nn.Module):

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.ce = torch.nn.CrossEntropyLoss()

    def forward(self, input, target):
        # print(f'theta {input.shape, input[0]}, target {target.shape, target}')
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()
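One caveat: self.ce reduces to the batch-mean cross entropy before the focal term is applied, so the (1 - p) ** gamma weight is computed once per batch rather than per sample (and the final .mean() is a no-op on a scalar). A per-sample variant for comparison (a sketch, not part of the commit):

import torch
import torch.nn.functional as F

def focal_loss_per_sample(input, target, gamma=2):
    logp = F.cross_entropy(input, target, reduction='none')  # per-sample CE
    p = torch.exp(-logp)
    return ((1 - p) ** gamma * logp).mean()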
83
contrast/feat_extract/model/metric.py
Normal file
@@ -0,0 +1,83 @@
# Definition of ArcFace loss and CosFace loss

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFace(nn.Module):

    def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
        """ArcFace formula:
            cos(m + theta) = cos(m)cos(theta) - sin(m)sin(theta)
        Note that:
            0 <= m + theta <= Pi
        So if (m + theta) >= Pi, then theta >= Pi - m. In [0, Pi]
        we have:
            cos(theta) < cos(Pi - m)
        So we can use cos(Pi - m) as threshold to check whether
        (m + theta) goes out of [0, Pi]

        Args:
            embedding_size: usually 128, 256, 512 ...
            class_num: num of people when training
            s: scale, see normface https://arxiv.org/abs/1704.06369
            m: margin, see SphereFace, CosFace, and ArcFace paper
        """
        super().__init__()
        self.in_features = embedding_size
        self.out_features = class_num
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
        nn.init.xavier_uniform_(self.weight)

        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # print(f"embedding {self.in_features}, class_num {self.out_features}, input {len(input)}, label {len(label)}")
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # print('F.normalize(input)', input.shape)
        # print('F.normalize(self.weight)', F.normalize(self.weight).shape)
        sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)  # drop to CosFace
        # print(f'cosine {cosine.shape, cosine}, sine {sine.shape, sine}, phi {phi.shape, phi}')
        # update y_i by phi in cosine
        output = cosine * 1.0  # make backward work
        batch_size = len(output)
        output[range(batch_size), label] = phi[range(batch_size), label]
        # print(f'output {(output * self.s).shape}')
        # print(f'phi[range(batch_size), label] {phi[range(batch_size), label]}')
        return output * self.s


class CosFace(nn.Module):

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        """
        Args:
            in_features: embedding size, usually 128, 256, 512 ...
            out_features: num of people when training
            s: scale, see normface https://arxiv.org/abs/1704.06369
            m: margin, see SphereFace, CosFace, and ArcFace paper
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        phi = cosine - self.m
        output = cosine * 1.0  # make backward work
        batch_size = len(output)
        output[range(batch_size), label] = phi[range(batch_size), label]
        return output * self.s
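For reference, a minimal training-step sketch (hypothetical sizes; the random embeddings stand in for any backbone output) showing how the ArcFace head above combines with FocalLoss from loss.py, assuming both classes are in scope:

import torch

emb = torch.randn(8, 512, requires_grad=True)      # embeddings from a backbone
labels = torch.randint(0, 100, (8,))
head = ArcFace(embedding_size=512, class_num=100, s=30.0, m=0.5)
logits = head(emb, labels)                         # (8, 100); margin applied only at the target class
loss = FocalLoss(gamma=2)(logits, labels)
loss.backward()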
148
contrast/feat_extract/model/mobilenet_v1.py
Normal file
@@ -0,0 +1,148 @@
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, Any, Optional

import torch
from torch import Tensor
from torch import nn
from torchvision.ops.misc import Conv2dNormActivation
from config import config as conf

__all__ = [
    "MobileNetV1",
    "DepthWiseSeparableConv2d",
    "mobilenet_v1",
]


class MobileNetV1(nn.Module):

    def __init__(
            self,
            num_classes: int = conf.embedding_size,
    ) -> None:
        super(MobileNetV1, self).__init__()
        self.features = nn.Sequential(
            Conv2dNormActivation(3,
                                 32,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 norm_layer=nn.BatchNorm2d,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),

            DepthWiseSeparableConv2d(32, 64, 1),
            DepthWiseSeparableConv2d(64, 128, 2),
            DepthWiseSeparableConv2d(128, 128, 1),
            DepthWiseSeparableConv2d(128, 256, 2),
            DepthWiseSeparableConv2d(256, 256, 1),
            DepthWiseSeparableConv2d(256, 512, 2),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 1024, 2),
            DepthWiseSeparableConv2d(1024, 1024, 1),
        )

        self.avgpool = nn.AvgPool2d((7, 7))

        self.classifier = nn.Linear(1024, num_classes)

        # Initialize neural network weights
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        out = self._forward_impl(x)

        return out

    # Support torch.script function
    def _forward_impl(self, x: Tensor) -> Tensor:
        out = self.features(x)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)


class DepthWiseSeparableConv2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int,
            norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(DepthWiseSeparableConv2d, self).__init__()
        self.stride = stride
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.conv = nn.Sequential(
            Conv2dNormActivation(in_channels,
                                 in_channels,
                                 kernel_size=3,
                                 stride=stride,
                                 padding=1,
                                 groups=in_channels,
                                 norm_layer=norm_layer,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),
            Conv2dNormActivation(in_channels,
                                 out_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0,
                                 norm_layer=norm_layer,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)

        return out


def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
    model = MobileNetV1(**kwargs)

    return model
200
contrast/feat_extract/model/mobilenet_v2.py
Normal file
@@ -0,0 +1,200 @@
from torch import nn
from .utils import load_state_dict_from_url
from ..config import config as conf

__all__ = ['MobileNetV2', 'mobilenet_v2']


model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            norm_layer(out_planes),
            nn.ReLU6(inplace=True)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self,
                 num_classes=conf.embedding_size,
                 width_mult=1.0,
                 inverted_residual_setting=None,
                 round_nearest=8,
                 block=None,
                 norm_layer=None):
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use

        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)


def mobilenet_v2(pretrained=True, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        src_state_dict = state_dict
        target_state_dict = model.state_dict()
        skip_keys = []
        # skip mismatch size tensors in case of pretraining
        for k in src_state_dict.keys():
            if k not in target_state_dict:
                continue
            if src_state_dict[k].size() != target_state_dict[k].size():
                skip_keys.append(k)
        for k in skip_keys:
            del src_state_dict[k]
        missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
        # model.load_state_dict(state_dict)
    return model
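The pretrained branch deliberately deletes size-mismatched tensors before a non-strict load: the ImageNet checkpoint's 1000-way classifier does not match the embedding-sized head, so those weights are skipped and keep their fresh initialization while all backbone weights are restored. A quick way to confirm the head shape (a sketch, assuming the package imports resolve; pretrained=False avoids a download):

import torch

net = mobilenet_v2(pretrained=False, num_classes=256)   # hypothetical embedding size
assert net.classifier[1].weight.shape == torch.Size([256, 1280])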
200
contrast/feat_extract/model/mobilenet_v3.py
Normal file
@@ -0,0 +1,200 @@
'''MobileNetV3 in PyTorch.

See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from ..config import config as conf


class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)


class Block(nn.Module):
    '''expand + depthwise + pointwise'''
    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se is not None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=conf.embedding_size):
        super(MobileNetV3_Large, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, conf.img_size // 32)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


class MobileNetV3_Small(nn.Module):
    def __init__(self, num_classes=conf.embedding_size):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, conf.img_size // 32)
        out = out.view(out.size(0), -1)

        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


def test():
    net = MobileNetV3_Small()
    x = torch.randn(2, 3, 224, 224)
    y = net(x)
    print(y.size())

# test()
268
contrast/feat_extract/model/mobilevit.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
# import sys
|
||||
# sys.path.append(r"D:\DetectTracking")
|
||||
from ..config import config as conf
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = self.attend(dots)
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
            # dw
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw,
                      ph=self.ph, pw=self.pw)

        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x


class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        L = [2, 4, 3]  # transformer depths of the three MobileViT blocks

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))  # repeat
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        # print('x', x.shape)
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)  # repeat

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        # print('pool_before', x.shape)
        x = self.pool(x).view(-1, x.shape[1])
        # print('self_pool', self.pool)
        # print('pool_after', x.shape)
        x = self.fc(x)
        return x


def mobilevit_xxs():
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)


def mobilevit_xs():
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT((256, 256), dims, channels, num_classes=1000)


def mobilevit_s():
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    return MobileViT((conf.img_size, conf.img_size), dims, channels, num_classes=conf.embedding_size)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == '__main__':
    img = torch.randn(5, 3, 256, 256)

    vit = mobilevit_xxs()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_xs()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_s()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))
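The only non-obvious step in MobileViTBlock.forward is the pair of rearrange calls that unfold the feature map into per-patch-position token sequences for the transformer and then fold it back. A minimal, self-contained shape trace (the sizes b, d, h, w here are made up for illustration; it only assumes torch and einops are installed):

import torch
from einops import rearrange

b, d, h, w, ph, pw = 2, 96, 8, 8, 2, 2
x = torch.randn(b, d, h, w)
# Unfold: group pixels by their position inside each (ph, pw) patch, giving
# ph*pw token sequences of length (h/ph)*(w/pw), each token of width d.
tokens = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=ph, pw=pw)
assert tokens.shape == (b, ph * pw, (h // ph) * (w // pw), d)
# Fold back: the inverse rearrange is an exact bijection on positions,
# so the original feature map is recovered bit-for-bit.
y = rearrange(tokens, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
              h=h // ph, w=w // pw, ph=ph, pw=pw)
assert torch.equal(x, y)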
145
contrast/feat_extract/model/resbam.py
Normal file
@ -0,0 +1,145 @@
import torch
import torch.nn as nn

from .CBAM import CBAM
from .Tool import GeM as gem
# from model.CBAM import CBAM
# from model.Tool import GeM as gem


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inchannel, outchannel, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel, kernel_size=3, bias=False,
                               stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(outchannel)
        self.conv3 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel * self.expansion, stride=1, bias=False,
                               kernel_size=1)
        self.bn3 = nn.BatchNorm2d(outchannel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # keep the skip connection in a local variable rather than on self,
        # so repeated forward calls do not leak state between batches
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = out + identity
        out = self.relu(out)
        return out


class resnet(nn.Module):
    def __init__(self, block=Bottleneck, block_num=[3, 4, 6, 3], num_class=1000):
        super().__init__()
        self.in_channel = 64
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=self.in_channel,
                               stride=2,
                               kernel_size=7,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.cbam = CBAM(self.in_channel)
        self.cbam1 = CBAM(2048)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.gem = gem()
        self.fc = nn.Linear(512 * block.expansion, num_class)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))
        layer = []
        layer.append(block(self.in_channel, channel, stride, downsample))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layer.append(block(self.in_channel, channel))
        return nn.Sequential(*layer)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.cbam(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.cbam1(x)
        # x = self.avgpool(x)
        x = self.gem(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


class TripletNet(nn.Module):
    def __init__(self, num_class, flag=True):
        super(TripletNet, self).__init__()
        self.initnet = rescbam(num_class)
        self.flag = flag

    def forward(self, x1, x2=None, x3=None):
        if self.flag:
            output1 = self.initnet(x1)
            output2 = self.initnet(x2)
            output3 = self.initnet(x3)
            return output1, output2, output3
        else:
            output = self.initnet(x1)
            return output


def rescbam(num_class):
    return resnet(block=Bottleneck, block_num=[3, 4, 6, 3], num_class=num_class)


if __name__ == '__main__':
    input1 = torch.randn(4, 3, 640, 640)
    input2 = torch.randn(4, 3, 640, 640)
    input3 = torch.randn(4, 3, 640, 640)

    # rescbam test
    # Resnet50 = rescbam(512)
    # output = Resnet50.forward(input1)
    # print(Resnet50)

    # TripletNet test
    trnet = TripletNet(512)
    output = trnet(input1, input2, input3)
    print(output)
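For single-embedding inference, TripletNet can be built with flag=False so only one branch runs. A small sketch of that path (it assumes GeM global-pools the 2048-channel map to 1x1, which the fc layer above implies; the 640x640 input matches the file's own test block):

embedder = TripletNet(512, flag=False)
embedder.eval()
with torch.no_grad():
    emb = embedder(torch.randn(1, 3, 640, 640))
print(emb.shape)  # expected: torch.Size([1, 512])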
189
contrast/feat_extract/model/resnet.py
Normal file
@ -0,0 +1,189 @@
"""resnet in pytorch
|
||||
|
||||
|
||||
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
|
||||
|
||||
Deep Residual Learning for Image Recognition
|
||||
https://arxiv.org/abs/1512.03385v1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from config import config as conf
|
||||
from CBAM import CBAM
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""Basic Block for resnet 18 and resnet 34
|
||||
|
||||
"""
|
||||
|
||||
#BasicBlock and BottleNeck block
|
||||
#have different output size
|
||||
#we use class attribute expansion
|
||||
#to distinct
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
|
||||
#residual function
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
#shortcut
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
#the shortcut output dimension is not the same with residual function
|
||||
#use 1*1 convolution to match the dimension
|
||||
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
"""Residual block for resnet over 50 layers
|
||||
|
||||
"""
|
||||
expansion = 4
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
|
||||
)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, num_block, cbam = False, num_classes=conf.embedding_size):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = 64
|
||||
|
||||
# self.conv1 = nn.Sequential(
|
||||
# nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(64),
|
||||
# nn.ReLU(inplace=True))
|
||||
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(3, 64,stride=2,kernel_size=7,padding=3,bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
|
||||
self.cbam = CBAM(self.in_channels)
|
||||
|
||||
#we use a different inputsize than the original paper
|
||||
#so conv2_x's stride is 1
|
||||
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
|
||||
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
|
||||
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
|
||||
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
|
||||
self.cbam1 = CBAM(self.in_channels)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal(m.weight,mode = 'fan_out',
|
||||
nonlinearity='relu')
|
||||
if isinstance(m, (nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 1.0)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
"""make resnet layers(by layer i didnt mean this 'layer' was the
|
||||
same as a neuron netowork layer, ex. conv layer), one layer may
|
||||
contain more than one residual block
|
||||
|
||||
Args:
|
||||
block: block type, basic block or bottle neck block
|
||||
out_channels: output depth channel number of this layer
|
||||
num_blocks: how many blocks per layer
|
||||
stride: the stride of the first block of this layer
|
||||
|
||||
Return:
|
||||
return a resnet layer
|
||||
"""
|
||||
|
||||
# we have num_block blocks per layer, the first block
|
||||
# could be 1 or 2, other blocks would always be 1
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_channels, out_channels, stride))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.conv1(x)
|
||||
if cbam:
|
||||
output = self.cbam(x)
|
||||
output = self.conv2_x(output)
|
||||
output = self.conv3_x(output)
|
||||
output = self.conv4_x(output)
|
||||
output = self.conv5_x(output)
|
||||
if cbam:
|
||||
output = self.cbam1(x)
|
||||
print('pollBefore',output.shape)
|
||||
output = self.avg_pool(output)
|
||||
print('poolAfter',output.shape)
|
||||
output = output.view(output.size(0), -1)
|
||||
print('fcBefore',output.shape)
|
||||
output = self.fc(output)
|
||||
|
||||
return output
|
||||
|
||||
def resnet18(cbam = False):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], cbam)
|
||||
|
||||
def resnet34():
|
||||
""" return a ResNet 34 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
def resnet50():
|
||||
""" return a ResNet 50 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 6, 3])
|
||||
|
||||
def resnet101():
|
||||
""" return a ResNet 101 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 23, 3])
|
||||
|
||||
def resnet152():
|
||||
""" return a ResNet 152 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 8, 36, 3])
|
||||
|
||||
|
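A quick smoke test of the CBAM-enabled path (a minimal sketch: it assumes the config and local CBAM imports resolve; the 224x224 input is arbitrary since the pooling is adaptive):

net = resnet18(cbam=True)
net.eval()
with torch.no_grad():
    feat = net(torch.randn(2, 3, 224, 224))
print(feat.shape)  # (2, conf.embedding_size)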
121
contrast/feat_extract/model/resnet_face.py
Normal file
@ -0,0 +1,121 @@
""" Resnet_IR_SE in ArcFace """
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.reshape(x.shape[0], -1)
|
||||
|
||||
|
||||
class SEConv(nn.Module):
|
||||
"""Use Convolution instead of FullyConnection in SE"""
|
||||
|
||||
def __init__(self, channels, reduction):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) * x
|
||||
|
||||
|
||||
class SE(nn.Module):
|
||||
|
||||
def __init__(self, channels, reduction):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Linear(channels, channels // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channels // reduction, channels),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) * x
|
||||
|
||||
|
||||
class IRSE(nn.Module):
|
||||
|
||||
def __init__(self, channels, depth, stride):
|
||||
super().__init__()
|
||||
if channels == depth:
|
||||
self.shortcut = nn.MaxPool2d(kernel_size=1, stride=stride)
|
||||
else:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(channels, depth, (1, 1), stride, bias=False),
|
||||
nn.BatchNorm2d(depth),
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.Conv2d(channels, depth, (3, 3), 1, 1, bias=False),
|
||||
nn.PReLU(depth),
|
||||
nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
||||
nn.BatchNorm2d(depth),
|
||||
SEConv(depth, 16),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.shortcut(x) + self.residual(x)
|
||||
|
||||
|
||||
class ResIRSE(nn.Module):
|
||||
"""Resnet50-IRSE backbone"""
|
||||
|
||||
def __init__(self, ih, embedding_size, drop_ratio):
|
||||
super().__init__()
|
||||
ih_last = ih // 16
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.PReLU(64),
|
||||
)
|
||||
self.output_layer = nn.Sequential(
|
||||
nn.BatchNorm2d(512),
|
||||
nn.Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
nn.Linear(512 * ih_last * ih_last, embedding_size),
|
||||
nn.BatchNorm1d(embedding_size),
|
||||
)
|
||||
|
||||
# ["channels", "depth", "stride"],
|
||||
self.res50_arch = [
|
||||
[64, 64, 2], [64, 64, 1], [64, 64, 1],
|
||||
[64, 128, 2], [128, 128, 1], [128, 128, 1], [128, 128, 1],
|
||||
[128, 256, 2], [256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1],
|
||||
[256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1],
|
||||
[256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1],
|
||||
[256, 512, 2], [512, 512, 1], [512, 512, 1],
|
||||
]
|
||||
|
||||
self.body = nn.Sequential(*[IRSE(a, b, c) for (a, b, c) in self.res50_arch])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.body(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
x = Image.open("../samples/009.jpg").convert('L')
|
||||
x = x.resize((128, 128))
|
||||
x = np.asarray(x, dtype=np.float32)
|
||||
x = x[None, None, ...]
|
||||
x = torch.from_numpy(x)
|
||||
net = ResIRSE(512, 0.6)
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
out = net(x)
|
||||
print(out.shape)
|
462
contrast/feat_extract/model/resnet_pre.py
Normal file
@ -0,0 +1,462 @@
import torch
import torch.nn as nn
from ..config import config as conf

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
# from .utils import load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.cam = cam
        self.bam = bam
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        if self.cam:
            # channel attention: global average pooling followed by a two-layer MLP
            if planes == 64:
                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
            elif planes == 128:
                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
            elif planes == 256:
                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
            elif planes == 512:
                self.globalAvgPool = nn.AvgPool2d(7, stride=1)

            self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
            self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
            self.sigmoid = nn.Sigmoid()
        if self.bam:
            self.bam = SpatialAttention()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if self.cam:
            # keep the full feature map; the pooled branch only produces
            # per-channel weights that are broadcast back onto it
            ori_out = out
            out = self.globalAvgPool(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.sigmoid(out)
            out = out.view(out.size(0), out.size(-1), 1, 1)
            out = out * ori_out

        if self.bam:
            out = out * self.bam(out)

        out += identity
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.cam = cam
        self.bam = bam
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        if self.cam:
            if planes == 64:
                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
            elif planes == 128:
                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
            elif planes == 256:
                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
            elif planes == 512:
                self.globalAvgPool = nn.AvgPool2d(7, stride=1)

            self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
            self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
            self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if self.cam:
            # same channel-attention fix as in BasicBlock: pool a copy,
            # then rescale the original feature map with the learned weights
            ori_out = out
            out = self.globalAvgPool(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.sigmoid(out)
            out = out.view(out.size(0), out.size(-1), 1, 1)
            out = out * ori_out
        out += identity
        out = self.relu(out)
        return out
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, scale=0.75):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
        self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # print('poolBefore', x.shape)
        x = self.avgpool(x)
        # print('poolAfter', x.shape)
        x = torch.flatten(x, 1)
        # print('fcBefore', x.shape)
        x = self.fc(x)
        # print('fcAfter', x.shape)

        return x

    def forward(self, x):
        return self._forward_impl(x)


# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
#     model = ResNet(block, layers, **kwargs)
#     if pretrained:
#         state_dict = load_state_dict_from_url(model_urls[arch],
#                                               progress=progress)
#         model.load_state_dict(state_dict, strict=False)
#     return model
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)

        src_state_dict = state_dict
        target_state_dict = model.state_dict()
        skip_keys = []
        # skip mismatched-size tensors in case of pretraining
        for k in src_state_dict.keys():
            if k not in target_state_dict:
                continue
            if src_state_dict[k].size() != target_state_dict[k].size():
                skip_keys.append(k)
        for k in skip_keys:
            del src_state_dict[k]
        missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)

    return model


def resnet14(pretrained=True, progress=True, **kwargs):
    r"""ResNet-14 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
                   **kwargs)


def resnet18(pretrained=True, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
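Because _resnet deletes every pretrained tensor whose shape disagrees with the target model, the scaled layer widths (scale=0.75) and the conf.embedding_size head simply keep their random init while the still-matching tensors are loaded. A sketch of that behaviour (assumes network access for the weight download):

model = resnet18(pretrained=True)  # mismatched tensors are skipped, not an error
model.eval()
with torch.no_grad():
    emb = model(torch.randn(1, 3, 224, 224))
print(emb.shape)  # (1, conf.embedding_size)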
4
contrast/feat_extract/model/utils.py
Normal file
@ -0,0 +1,4 @@
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
137
contrast/feat_extract/model/vit.py
Normal file
@ -0,0 +1,137 @@
import torch
from torch import nn

from vit_pytorch.mobile_vit import MobileViT
from vit_pytorch import vit
from vit_pytorch import SimpleViT

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)


# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
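Note that the file defines its own ViT rather than using the vit_pytorch imports at the top. A minimal instantiation (the hyperparameters here are hypothetical, chosen only to satisfy the patch-divisibility assert):

v = ViT(image_size=256, patch_size=16, num_classes=512, dim=192,
        depth=6, heads=8, mlp_dim=384)
out = v(torch.randn(1, 3, 256, 256))  # 256 patches + 1 cls token internally
print(out.shape)  # torch.Size([1, 512])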