rebuild
This commit is contained in:
88
model/BAM.py
Normal file
88
model/BAM.py
Normal file
@ -0,0 +1,88 @@
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.view(x.shape[0], -1)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
def __int__(self, channel, reduction, num_layers):
|
||||
super(ChannelAttention, self).__init__()
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
gate_channels = [channel]
|
||||
gate_channels += [len(channel) // reduction] * num_layers
|
||||
gate_channels += [channel]
|
||||
|
||||
self.ca = nn.Sequential()
|
||||
self.ca.add_module('flatten', Flatten())
|
||||
for i in range(len(gate_channels) - 2):
|
||||
self.ca.add_module('', nn.Linear(gate_channels[i], gate_channels[i + 1]))
|
||||
self.ca.add_module('', nn.BatchNorm1d(gate_channels[i + 1]))
|
||||
self.ca.add_module('', nn.ReLU())
|
||||
self.ca.add_module('', nn.Linear(gate_channels[-2], gate_channels[-1]))
|
||||
|
||||
def forward(self, x):
|
||||
res = self.avgpool(x)
|
||||
res = self.ca(res)
|
||||
res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
|
||||
return res
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __int__(self, channel, reduction=16, num_lay=3, dilation=2):
|
||||
super(SpatialAttention).__init__()
|
||||
self.sa = nn.Sequential()
|
||||
self.sa.add_module('', nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=(channel // reduction) * 3))
|
||||
self.sa.add_module('', nn.BatchNorm2d(num_features=(channel // reduction)))
|
||||
self.sa.add_module('', nn.ReLU())
|
||||
for i in range(num_lay):
|
||||
self.sa.add_module('', nn.Conv2d(kernel_size=3,
|
||||
in_channels=(channel // reduction),
|
||||
out_channels=(channel // reduction),
|
||||
padding=1,
|
||||
dilation=2))
|
||||
self.sa.add_module('', nn.BatchNorm2d(channel // reduction))
|
||||
self.sa.add_module('', nn.ReLU())
|
||||
self.sa.add_module('', nn.Conv2d(channel // reduction, 1, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
res = self.sa(x)
|
||||
res = res.expand_as(x)
|
||||
return res
|
||||
|
||||
|
||||
class BAMblock(nn.Module):
|
||||
def __init__(self, channel=512, reduction=16, dia_val=2):
|
||||
super(BAMblock, self).__init__()
|
||||
self.ca = ChannelAttention(channel, reduction)
|
||||
self.sa = SpatialAttention(channel, reduction, dia_val)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal(m.weight, mode='fan_out')
|
||||
if m.bais is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
sa_out = self.sa(x)
|
||||
ca_out = self.ca(x)
|
||||
weight = self.sigmoid(sa_out + ca_out)
|
||||
out = (1 + weight) * x
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(512 // 14)
|
70
model/CBAM.py
Normal file
70
model/CBAM.py
Normal file
@ -0,0 +1,70 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
||||
class channelAttention(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(channelAttention, self).__init__()
|
||||
self.Maxpooling = nn.AdaptiveMaxPool2d(1)
|
||||
self.Avepooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.ca = nn.Sequential()
|
||||
self.ca.add_module('conv1',nn.Conv2d(channel, channel//reduction, 1, bias=False))
|
||||
self.ca.add_module('Relu', nn.ReLU())
|
||||
self.ca.add_module('conv2',nn.Conv2d(channel//reduction, channel, 1, bias=False))
|
||||
self.sigmod = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
M_out = self.Maxpooling(x)
|
||||
A_out = self.Avepooling(x)
|
||||
M_out = self.ca(M_out)
|
||||
A_out = self.ca(A_out)
|
||||
out = self.sigmod(M_out+A_out)
|
||||
return out
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __init__(self, kernel_size=7):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
max_result, _ = torch.max(x, dim=1, keepdim=True)
|
||||
avg_result = torch.mean(x, dim=1, keepdim=True)
|
||||
result = torch.cat([max_result, avg_result], dim=1)
|
||||
output = self.conv(result)
|
||||
output = self.sigmoid(output)
|
||||
return output
|
||||
|
||||
class CBAM(nn.Module):
|
||||
def __init__(self, channel, reduction=16, kernel_size=7):
|
||||
super().__init__()
|
||||
self.ca = channelAttention(channel, reduction)
|
||||
self.sa = SpatialAttention(kernel_size)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():#权重初始化
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
# b,c_,_ = x.size()
|
||||
# residual = x
|
||||
out = x*self.ca(x)
|
||||
out = out*self.sa(out)
|
||||
return out
|
||||
|
||||
if __name__ == '__main__':
|
||||
input=torch.randn(50,512,7,7)
|
||||
kernel_size=input.shape[2]
|
||||
cbam = CBAM(channel=512,reduction=16,kernel_size=kernel_size)
|
||||
output=cbam(input)
|
||||
print(output.shape)
|
37
model/Tool.py
Normal file
37
model/Tool.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GeM(nn.Module):
|
||||
def __init__(self, p=3, eps=1e-6):
|
||||
super(GeM, self).__init__()
|
||||
self.p = nn.Parameter(torch.ones(1) * p)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return self.gem(x, p=self.p, eps=self.eps, stride=2)
|
||||
|
||||
def gem(self, x, p=3, eps=1e-6, stride=2):
|
||||
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=2).pow(1. / p)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
'(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
|
||||
', ' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
def __init__(self, margin):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, anchor, positive, negative, size_average=True):
|
||||
distance_positive = (anchor - positive).pow(2).sum(1)
|
||||
distance_negative = (anchor - negative).pow(2).sum(1)
|
||||
losses = F.relu(distance_negative - distance_positive + self.margin)
|
||||
return losses.mean() if size_average else losses.sum()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('')
|
14
model/__init__.py
Normal file
14
model/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
from .fmobilenet import FaceMobileNet
|
||||
# from .resnet_face import ResIRSE
|
||||
from .mobilevit import mobilevit_s
|
||||
from .metric import ArcFace, CosFace
|
||||
from .loss import FocalLoss
|
||||
from .resbam import resnet
|
||||
from .resnet_pre import resnet18, resnet34, resnet50, resnet14, CustomResNet18
|
||||
from .mobilenet_v2 import mobilenet_v2
|
||||
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
|
||||
# from .mobilenet_v1 import mobilenet_v1
|
||||
from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, \
|
||||
PPLCNET_x2_5
|
||||
from .vit import vit_base
|
||||
from .mlp import MLP
|
BIN
model/__pycache__/CBAM.cpython-38.pyc
Normal file
BIN
model/__pycache__/CBAM.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/Tool.cpython-38.pyc
Normal file
BIN
model/__pycache__/Tool.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
model/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/fmobilenet.cpython-38.pyc
Normal file
BIN
model/__pycache__/fmobilenet.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/lcnet.cpython-38.pyc
Normal file
BIN
model/__pycache__/lcnet.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/loss.cpython-38.pyc
Normal file
BIN
model/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/metric.cpython-38.pyc
Normal file
BIN
model/__pycache__/metric.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mlp.cpython-38.pyc
Normal file
BIN
model/__pycache__/mlp.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v1.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v1.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v2.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilenet_v3.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilenet_v3.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/mobilevit.cpython-38.pyc
Normal file
BIN
model/__pycache__/mobilevit.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/resbam.cpython-38.pyc
Normal file
BIN
model/__pycache__/resbam.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/resnet_pre.cpython-38.pyc
Normal file
BIN
model/__pycache__/resnet_pre.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/utils.cpython-38.pyc
Normal file
BIN
model/__pycache__/utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
model/__pycache__/vit.cpython-38.pyc
Normal file
BIN
model/__pycache__/vit.cpython-38.pyc
Normal file
Binary file not shown.
142
model/benchmark.py
Normal file
142
model/benchmark.py
Normal file
@ -0,0 +1,142 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import numpy as np
|
||||
from resnet_attention import resnet18_cbam, resnet34_cbam, resnet50_cbam
|
||||
|
||||
# 设置随机种子以确保结果可复现
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
# 设备配置
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"测试设备: {device}")
|
||||
|
||||
# 测试参数
|
||||
batch_sizes = [1, 4, 8, 16]
|
||||
image_sizes = [224, 384, 512]
|
||||
num_runs = 100 # 每个配置运行的次数
|
||||
warmup_runs = 20 # 预热运行次数,排除启动开销
|
||||
|
||||
# 模型配置
|
||||
model_configs = {
|
||||
"resnet18": {
|
||||
"base_model": lambda: resnet18_cbam(use_cbam=False),
|
||||
"attention_model": lambda: resnet18_cbam(use_cbam=True)
|
||||
},
|
||||
"resnet34": {
|
||||
"base_model": lambda: resnet34_cbam(use_cbam=False),
|
||||
"attention_model": lambda: resnet34_cbam(use_cbam=True)
|
||||
},
|
||||
"resnet50": {
|
||||
"base_model": lambda: resnet50_cbam(use_cbam=False),
|
||||
"attention_model": lambda: resnet50_cbam(use_cbam=True)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 基准测试函数
|
||||
def benchmark_model(model, input_size, batch_size, num_runs, warmup_runs):
|
||||
"""
|
||||
测试模型的推理性能
|
||||
|
||||
参数:
|
||||
- model: 待测试的模型
|
||||
- input_size: 输入图像尺寸
|
||||
- batch_size: 批次大小
|
||||
- num_runs: 测试运行次数
|
||||
- warmup_runs: 预热运行次数
|
||||
|
||||
返回:
|
||||
- 平均推理时间(毫秒)
|
||||
- 吞吐量(样本/秒)
|
||||
"""
|
||||
# 设置为评估模式
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
# 创建随机输入
|
||||
input_tensor = torch.randn(batch_size, 3, input_size, input_size, device=device)
|
||||
|
||||
# 预热
|
||||
with torch.no_grad():
|
||||
for _ in range(warmup_runs):
|
||||
_ = model(input_tensor)
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.synchronize() # 同步GPU操作
|
||||
|
||||
# 测量推理时间
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
for _ in range(num_runs):
|
||||
_ = model(input_tensor)
|
||||
if device.type == 'cuda':
|
||||
torch.cuda.synchronize() # 同步GPU操作
|
||||
end_time = time.time()
|
||||
|
||||
# 计算指标
|
||||
total_time = end_time - start_time
|
||||
avg_time_per_batch = total_time / num_runs * 1000 # 毫秒
|
||||
throughput = batch_size * num_runs / total_time # 样本/秒
|
||||
|
||||
return avg_time_per_batch, throughput
|
||||
|
||||
|
||||
# 运行测试
|
||||
results = {}
|
||||
|
||||
for model_name, config in model_configs.items():
|
||||
results[model_name] = {}
|
||||
|
||||
# 创建模型
|
||||
base_model = config["base_model"]()
|
||||
attention_model = config["attention_model"]()
|
||||
|
||||
# 计算参数量
|
||||
base_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
|
||||
attention_params = sum(p.numel() for p in attention_model.parameters() if p.requires_grad)
|
||||
param_increase = (attention_params - base_params) / base_params * 100
|
||||
|
||||
print(f"\n测试模型: {model_name}")
|
||||
print(f" 基础参数量: {base_params / 1e6:.2f}M")
|
||||
print(f" 带注意力参数量: {attention_params / 1e6:.2f}M")
|
||||
print(f" 参数量增加: {param_increase:.2f}%")
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for image_size in image_sizes:
|
||||
key = f"batch_{batch_size}_size_{image_size}"
|
||||
results[model_name][key] = {}
|
||||
|
||||
# 测试基础模型
|
||||
base_time, base_throughput = benchmark_model(
|
||||
base_model, image_size, batch_size, num_runs, warmup_runs
|
||||
)
|
||||
|
||||
# 测试注意力模型
|
||||
attention_time, attention_throughput = benchmark_model(
|
||||
attention_model, image_size, batch_size, num_runs, warmup_runs
|
||||
)
|
||||
|
||||
# 计算增加的百分比
|
||||
time_increase = (attention_time - base_time) / base_time * 100
|
||||
throughput_decrease = (base_throughput - attention_throughput) / base_throughput * 100
|
||||
|
||||
results[model_name][key]["base_time"] = base_time
|
||||
results[model_name][key]["attention_time"] = attention_time
|
||||
results[model_name][key]["time_increase"] = time_increase
|
||||
results[model_name][key]["base_throughput"] = base_throughput
|
||||
results[model_name][key]["attention_throughput"] = attention_throughput
|
||||
results[model_name][key]["throughput_decrease"] = throughput_decrease
|
||||
|
||||
print(f" 配置: 批次大小={batch_size}, 图像尺寸={image_size}x{image_size}")
|
||||
print(f" 基础模型: 平均时间={base_time:.2f}ms, 吞吐量={base_throughput:.2f}样本/秒")
|
||||
print(f" 注意力模型: 平均时间={attention_time:.2f}ms, 吞吐量={attention_throughput:.2f}样本/秒")
|
||||
print(f" 时间增加: {time_increase:.2f}%, 吞吐量下降: {throughput_decrease:.2f}%")
|
||||
|
||||
# 保存结果
|
||||
import json
|
||||
|
||||
with open('benchmark_results.json', 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print("\n测试完成,结果已保存到 benchmark_results.json")
|
48
model/compare.py
Normal file
48
model/compare.py
Normal file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from config import config as conf
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from model.resnet_pre import resnet18, resnet50
|
||||
# from model.vit import vit_base_patch16_224, vit_base_patch32_224
|
||||
|
||||
|
||||
class ContrastiveModel(nn.Module):
|
||||
def __init__(self, projection_dim, model_name, contraposition=False):
|
||||
super(ContrastiveModel, self).__init__()
|
||||
self.contraposition = contraposition
|
||||
self.base_model = self._get_model(model_name)
|
||||
if not self.contraposition:
|
||||
if 'vit' in model_name:
|
||||
dim_mlp = self.base_model.head.weight.shape[1]
|
||||
self.base_model.head = self._get_projection_layer(dim_mlp, projection_dim)
|
||||
else:
|
||||
dim_mlp = self.base_model.fc.weight.shape[1]
|
||||
self.base_model.fc = self._get_projection_layer(dim_mlp, projection_dim)
|
||||
# # 冻结除 FC 层之外的所有层
|
||||
# for name, param in self.base_model.named_parameters():
|
||||
# if 'fc' not in name:
|
||||
# param.requires_grad = False
|
||||
|
||||
def _get_projection_layer(self, dim_mlp, projection_dim):
|
||||
return nn.Sequential(
|
||||
nn.Linear(dim_mlp, dim_mlp),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(dim_mlp, projection_dim)
|
||||
)
|
||||
|
||||
def _get_model(self, model_name):
|
||||
base_model = None
|
||||
if model_name == 'resnet18':
|
||||
base_model = resnet18(pretrained=True)
|
||||
elif model_name == 'resnet50':
|
||||
base_model = resnet50(pretrained=True)
|
||||
# elif model_name == 'vit':
|
||||
# base_model = vit_base_patch32_224()
|
||||
return base_model
|
||||
def forward(self, x):
|
||||
assert self.base_model is not None, 'base_model is none'
|
||||
x = self.base_model(x)
|
||||
return x
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
182
model/distill.py
Normal file
182
model/distill.py
Normal file
@ -0,0 +1,182 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
|
||||
from einops import repeat
|
||||
from config import config as conf
|
||||
# helpers
|
||||
# Data Setup
|
||||
from tools.dataset import load_data
|
||||
train_dataloader, class_num = load_data(conf, training=True)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
def forward(self, img, distill_token=None):
|
||||
distilling = exists(distill_token)
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b=b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b=b)
|
||||
x = torch.cat((x, distill_tokens), dim=1)
|
||||
|
||||
x = self._attend(x)
|
||||
|
||||
if distilling:
|
||||
x, distill_tokens = x[:, :-1], x[:, -1]
|
||||
|
||||
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
out = self.mlp_head(x)
|
||||
|
||||
if distilling:
|
||||
return out, distill_tokens
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DistillableViT(DistillMixin, ViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = ViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
|
||||
class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableT2TViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = T2TViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
return x
|
||||
|
||||
|
||||
class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
|
||||
def to_vit(self):
|
||||
v = EfficientViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x):
|
||||
return self.transformer(x)
|
||||
|
||||
|
||||
# knowledge distillation wrapper
|
||||
|
||||
class DistillWrapper(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
teacher,
|
||||
student,
|
||||
temperature=1.,
|
||||
alpha=0.5,
|
||||
hard=False,
|
||||
mlp_layernorm=False
|
||||
):
|
||||
super().__init__()
|
||||
# assert (isinstance(student, (
|
||||
# DistillableViT, DistillableT2TViT, DistillableEfficientViT))), 'student must be a vision transformer'
|
||||
if isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT)):
|
||||
pass
|
||||
|
||||
self.teacher = teacher
|
||||
self.student = student
|
||||
|
||||
dim = conf.embedding_size # student.dim
|
||||
num_classes = class_num # class_num # student.num_classes
|
||||
self.temperature = temperature
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
# student is vit
|
||||
# self.distill_mlp = nn.Sequential(
|
||||
# nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
# nn.Linear(dim, num_classes)
|
||||
# )
|
||||
|
||||
# student is resnet
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
nn.Linear(dim, num_classes).to(device)
|
||||
)
|
||||
|
||||
def forward(self, img, labels, temperature=None, alpha=None, **kwargs):
|
||||
|
||||
alpha = default(alpha, self.alpha)
|
||||
T = default(temperature, self.temperature)
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_logits = self.teacher(img)
|
||||
teacher_logits = self.distill_mlp(teacher_logits) # teach is vit 初始化
|
||||
# student is vit
|
||||
# student_logits, distill_tokens = self.student(img, distill_token=self.distillation_token, **kwargs)
|
||||
# distill_logits = self.distill_mlp(distill_tokens)
|
||||
|
||||
# student is resnet
|
||||
student_logits = self.student(img)
|
||||
distill_logits = self.distill_mlp(student_logits)
|
||||
loss = F.cross_entropy(distill_logits, labels)
|
||||
# pdb.set_trace()
|
||||
if not self.hard:
|
||||
distill_loss = F.kl_div(
|
||||
F.log_softmax(distill_logits / T, dim=-1),
|
||||
F.softmax(teacher_logits / T, dim=-1).detach(),
|
||||
reduction='batchmean')
|
||||
distill_loss *= T ** 2
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim=-1)
|
||||
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
|
||||
# pdb.set_trace()
|
||||
return loss * (1 - alpha) + distill_loss * alpha
|
124
model/fmobilenet.py
Normal file
124
model/fmobilenet.py
Normal file
@ -0,0 +1,124 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.view(x.shape[0], -1)
|
||||
|
||||
class ConvBn(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_c)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class ConvBnPrelu(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
ConvBn(in_c, out_c, kernel, stride, padding, groups),
|
||||
nn.PReLU(out_c)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class DepthWise(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
ConvBnPrelu(in_c, groups, kernel=(1, 1), stride=1, padding=0),
|
||||
ConvBnPrelu(groups, groups, kernel=kernel, stride=stride, padding=padding, groups=groups),
|
||||
ConvBn(groups, out_c, kernel=(1, 1), stride=1, padding=0),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class DepthWiseRes(nn.Module):
|
||||
"""DepthWise with Residual"""
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||
super().__init__()
|
||||
self.net = DepthWise(in_c, out_c, kernel, stride, padding, groups)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
|
||||
class MultiDepthWiseRes(nn.Module):
|
||||
|
||||
def __init__(self, num_block, channels, kernel=(3, 3), stride=1, padding=1, groups=1):
|
||||
super().__init__()
|
||||
|
||||
self.net = nn.Sequential(*[
|
||||
DepthWiseRes(channels, channels, kernel, stride, padding, groups)
|
||||
for _ in range(num_block)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class FaceMobileNet(nn.Module):
|
||||
|
||||
def __init__(self, embedding_size):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBnPrelu(1, 64, kernel=(3, 3), stride=2, padding=1)
|
||||
self.conv2 = ConvBn(64, 64, kernel=(3, 3), stride=1, padding=1, groups=64)
|
||||
self.conv3 = DepthWise(64, 64, kernel=(3, 3), stride=2, padding=1, groups=128)
|
||||
self.conv4 = MultiDepthWiseRes(num_block=4, channels=64, kernel=3, stride=1, padding=1, groups=128)
|
||||
self.conv5 = DepthWise(64, 128, kernel=(3, 3), stride=2, padding=1, groups=256)
|
||||
self.conv6 = MultiDepthWiseRes(num_block=6, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||
self.conv7 = DepthWise(128, 128, kernel=(3, 3), stride=2, padding=1, groups=512)
|
||||
self.conv8 = MultiDepthWiseRes(num_block=2, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||
self.conv9 = ConvBnPrelu(128, 512, kernel=(1, 1))
|
||||
self.conv10 = ConvBn(512, 512, groups=512, kernel=(7, 7))
|
||||
self.flatten = Flatten()
|
||||
self.linear = nn.Linear(2048, embedding_size, bias=False)
|
||||
self.bn = nn.BatchNorm1d(embedding_size)
|
||||
|
||||
def forward(self, x):
|
||||
#print('x',x.shape)
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
out = self.conv3(out)
|
||||
out = self.conv4(out)
|
||||
out = self.conv5(out)
|
||||
out = self.conv6(out)
|
||||
out = self.conv7(out)
|
||||
out = self.conv8(out)
|
||||
out = self.conv9(out)
|
||||
out = self.conv10(out)
|
||||
out = self.flatten(out)
|
||||
out = self.linear(out)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
x = Image.open("../samples/009.jpg").convert('L')
|
||||
x = x.resize((128, 128))
|
||||
x = np.asarray(x, dtype=np.float32)
|
||||
x = x[None, None, ...]
|
||||
x = torch.from_numpy(x)
|
||||
net = FaceMobileNet(512)
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
out = net(x)
|
||||
print(out.shape)
|
233
model/lcnet.py
Normal file
233
model/lcnet.py
Normal file
@ -0,0 +1,233 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import thop
|
||||
|
||||
# try:
|
||||
# import softpool_cuda
|
||||
# from SoftPool import soft_pool2d, SoftPool2d
|
||||
# except ImportError:
|
||||
# print('Please install SoftPool first: https://github.com/alexandrosstergiou/SoftPool')
|
||||
# exit(0)
|
||||
|
||||
NET_CONFIG = {
|
||||
# k, in_c, out_c, s, use_se
|
||||
"blocks2": [[3, 16, 32, 1, False]],
|
||||
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
|
||||
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
|
||||
"blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False]],
|
||||
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
|
||||
}
|
||||
|
||||
|
||||
def autopad(k, p=None):
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
||||
return p
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class HardSwish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(HardSwish, self).__init__()
|
||||
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.relu6(x+3) / 6
|
||||
|
||||
|
||||
class HardSigmoid(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(HardSigmoid, self).__init__()
|
||||
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.relu6(x+3)) / 6
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SELayer, self).__init__()
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel, bias=False),
|
||||
HardSigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
y = self.avgpool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Module):
|
||||
def __init__(self, inp, oup, dw_size, stride, use_se=False):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self.stride = stride
|
||||
self.inp = inp
|
||||
self.oup = oup
|
||||
self.dw_size = dw_size
|
||||
self.dw_sp = nn.Sequential(
|
||||
nn.Conv2d(self.inp, self.inp, kernel_size=self.dw_size, stride=self.stride,
|
||||
padding=autopad(self.dw_size, None), groups=self.inp, bias=False),
|
||||
nn.BatchNorm2d(self.inp),
|
||||
HardSwish(),
|
||||
|
||||
nn.Conv2d(self.inp, self.oup, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.oup),
|
||||
HardSwish(),
|
||||
)
|
||||
self.se = SELayer(self.oup)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dw_sp(x)
|
||||
if self.use_se:
|
||||
x = self.se(x)
|
||||
return x
|
||||
|
||||
|
||||
class PP_LCNet(nn.Module):
|
||||
def __init__(self, scale=1.0, class_num=256, class_expand=1280, dropout_prob=0.2):
|
||||
super(PP_LCNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.conv1 = nn.Conv2d(3, out_channels=make_divisible(16 * self.scale),
|
||||
kernel_size=3, stride=2, padding=1, bias=False)
|
||||
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||
self.blocks2 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks2"])
|
||||
])
|
||||
|
||||
self.blocks3 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks3"])
|
||||
])
|
||||
|
||||
self.blocks4 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks4"])
|
||||
])
|
||||
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||
self.blocks5 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks5"])
|
||||
])
|
||||
|
||||
self.blocks6 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks6"])
|
||||
])
|
||||
|
||||
self.GAP = nn.AdaptiveAvgPool2d(1)
|
||||
|
||||
self.last_conv = nn.Conv2d(in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
|
||||
out_channels=class_expand,
|
||||
kernel_size=1, stride=1, padding=0, bias=False)
|
||||
|
||||
self.hardswish = HardSwish()
|
||||
self.dropout = nn.Dropout(p=dropout_prob)
|
||||
|
||||
self.fc = nn.Linear(class_expand, class_num)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks2(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks3(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks4(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks5(x)
|
||||
# print(x.shape)
|
||||
x = self.blocks6(x)
|
||||
# print(x.shape)
|
||||
|
||||
x = self.GAP(x)
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
x = torch.flatten(x, start_dim=1, end_dim=-1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def PPLCNET_x0_25(**kwargs):
|
||||
model = PP_LCNet(scale=0.25, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_35(**kwargs):
|
||||
model = PP_LCNet(scale=0.35, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_5(**kwargs):
|
||||
model = PP_LCNet(scale=0.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_75(**kwargs):
|
||||
model = PP_LCNet(scale=0.75, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x1_0(**kwargs):
|
||||
model = PP_LCNet(scale=1.0, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x1_5(**kwargs):
|
||||
model = PP_LCNet(scale=1.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x2_0(**kwargs):
|
||||
model = PP_LCNet(scale=2.0, **kwargs)
|
||||
return model
|
||||
|
||||
def PPLCNET_x2_5(**kwargs):
|
||||
model = PP_LCNet(scale=2.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# input = torch.randn(1, 3, 640, 640)
|
||||
# model = PPLCNET_x2_5()
|
||||
# flops, params = thop.profile(model, inputs=(input,))
|
||||
# print('flops:', flops / 1000000000)
|
||||
# print('params:', params / 1000000)
|
||||
|
||||
model = PPLCNET_x1_0()
|
||||
# model_1 = PW_Conv(3, 16)
|
||||
input = torch.randn(2, 3, 256, 256)
|
||||
print(input.shape)
|
||||
output = model(input)
|
||||
print(output.shape) # [1, num_class]
|
||||
|
18
model/loss.py
Normal file
18
model/loss.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
|
||||
def __init__(self, gamma=2):
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.ce = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, input, target):
|
||||
|
||||
#print(f'theta {input.shape, input[0]}, target {target.shape, target}')
|
||||
logp = self.ce(input, target)
|
||||
p = torch.exp(-logp)
|
||||
loss = (1 - p) ** self.gamma * logp
|
||||
return loss.mean()
|
94
model/metric.py
Normal file
94
model/metric.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Definition of ArcFace loss and CosFace loss
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ArcFace(nn.Module):
|
||||
|
||||
def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
|
||||
"""ArcFace formula:
|
||||
cos(m + theta) = cos(m)cos(theta) - sin(m)sin(theta)
|
||||
Note that:
|
||||
0 <= m + theta <= Pi
|
||||
So if (m + theta) >= Pi, then theta >= Pi - m. In [0, Pi]
|
||||
we have:
|
||||
cos(theta) < cos(Pi - m)
|
||||
So we can use cos(Pi - m) as threshold to check whether
|
||||
(m + theta) go out of [0, Pi]
|
||||
|
||||
Args:
|
||||
embedding_size: usually 128, 256, 512 ...
|
||||
class_num: num of people when training
|
||||
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_features = embedding_size
|
||||
self.out_features = class_num
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
self.cos_m = math.cos(m)
|
||||
self.sin_m = math.sin(m)
|
||||
self.th = math.cos(math.pi - m)
|
||||
self.mm = math.sin(math.pi - m) * m
|
||||
|
||||
def forward(self, input, label):
|
||||
#print(f"embding {self.in_features}, class_num {self.out_features}, input {len(input)}, label {len(label)}")
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
# print('F.normalize(input)',input.shape)
|
||||
# print('F.normalize(self.weight)',F.normalize(self.weight).shape)
|
||||
sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
|
||||
phi = cosine * self.cos_m - sine * self.sin_m
|
||||
phi = torch.where(cosine > self.th, phi, cosine - self.mm) # drop to CosFace
|
||||
#print(f'consine {cosine.shape, cosine}, sine {sine.shape, sine}, phi {phi.shape, phi}')
|
||||
# update y_i by phi in cosine
|
||||
output = cosine * 1.0 # make backward works
|
||||
batch_size = len(output)
|
||||
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||
# print(f'output {(output * self.s).shape}')
|
||||
# print(f'phi[range(batch_size), label] {phi[range(batch_size), label]}')
|
||||
return output * self.s
|
||||
|
||||
|
||||
class CosFace(nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features, s=30.0, m=0.40):
|
||||
"""
|
||||
Args:
|
||||
embedding_size: usually 128, 256, 512 ...
|
||||
class_num: num of people when training
|
||||
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, input, label):
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
phi = cosine - self.m
|
||||
output = cosine * 1.0 # make backward works
|
||||
batch_size = len(output)
|
||||
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||
return output * self.s
|
||||
|
||||
class Distillation(nn.Module):
|
||||
def __init__(self, in_features, out_features, T=1.0):
|
||||
super(Distillation, self).__init__()
|
||||
self.T = T
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
def forward(self, input, labels):
|
||||
pass
|
274
model/mlp.py
Normal file
274
model/mlp.py
Normal file
@ -0,0 +1,274 @@
|
||||
import pdb
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
from model.resnet_pre import resnet18, conv1x1, BasicBlock, load_state_dict_from_url, model_urls
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, input_dim=256, output_dim=1):
|
||||
super(MLP, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.fc1 = nn.Linear(self.input_dim, 128) # 32
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 32)
|
||||
self.fc4 = nn.Linear(32, 16)
|
||||
self.fc5 = nn.Linear(16, self.output_dim)
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.bn1 = nn.BatchNorm1d(128)
|
||||
self.bn2 = nn.BatchNorm1d(64)
|
||||
self.bn3 = nn.BatchNorm1d(32)
|
||||
self.bn4 = nn.BatchNorm1d(16)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.relu(self.bn1(x))
|
||||
x = self.fc2(x)
|
||||
x = self.relu(self.bn2(x))
|
||||
x = self.fc3(x)
|
||||
x = self.relu(self.bn3(x))
|
||||
x = self.fc4(x)
|
||||
x = self.relu(self.bn4(x))
|
||||
x = self.sigmoid(self.fc5(x))
|
||||
return x
|
||||
|
||||
|
||||
class Net2(nn.Module): # 该网络部署有风险,dnn推理有障碍
|
||||
def __init__(self, input_dim=960, output_dim=1):
|
||||
super(Net2, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
|
||||
# self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
|
||||
# self.conv4 = nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=1)
|
||||
self.maxPool1 = nn.MaxPool1d(kernel_size=3, stride=2)
|
||||
self.conv5 = nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=1)
|
||||
self.maxPool2 = nn.MaxPool1d(kernel_size=3, stride=2)
|
||||
|
||||
self.avgPool = nn.AdaptiveAvgPool1d(1)
|
||||
self.MaxPool = nn.AdaptiveMaxPool1d(1)
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.flatten = nn.Flatten()
|
||||
# self.conv6 = nn.Conv1d(128, 128, kernel_size=5, stride=2, padding=1)
|
||||
self.fc1 = nn.Linear(960, 128)
|
||||
self.fc21 = nn.Linear(960, 32)
|
||||
self.fc22 = nn.Linear(32, 128)
|
||||
self.fc3 = nn.Linear(128, 1)
|
||||
self.bn1 = nn.BatchNorm1d(16)
|
||||
self.bn2 = nn.BatchNorm1d(32)
|
||||
self.bn3 = nn.BatchNorm1d(64)
|
||||
self.bn4 = nn.BatchNorm1d(128)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x) # 16
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x) # 32
|
||||
x = self.relu(x)
|
||||
# x = self.conv3(x)
|
||||
# x = self.relu(x)
|
||||
# x = self.conv4(x) # 64
|
||||
# x = self.relu(x)
|
||||
# x = self.maxPool1(x)
|
||||
|
||||
x = self.conv5(x)
|
||||
x = self.relu(x)
|
||||
# x = self.conv6(x)
|
||||
# x = self.relu(x)
|
||||
# x = self.maxPool2(x)
|
||||
# x = self.MaxPool(x)
|
||||
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.dropout(x)
|
||||
x = self.flatten(x)
|
||||
|
||||
# pdb.set_trace()
|
||||
x1 = self.fc1(x)
|
||||
x2 = self.fc22(self.fc21(x))
|
||||
x = self.fc3(x1 + x2)
|
||||
x = self.sigmoid(x)
|
||||
return x
|
||||
|
||||
class Net3(nn.Module): # 目前较合适的网络结构,相较于Net2,Net3的输出结果更加准确
|
||||
def __init__(self, pretrained=True, progress=True, num_classes=1, scale=0.75):
|
||||
super(Net3, self).__init__()
|
||||
self.resnet18 = resnet18(pretrained=pretrained, progress=progress)
|
||||
|
||||
# Remove the last three layers (layer3, layer4, avgpool, fc)
|
||||
# self.resnet18.layer3 = nn.Identity()
|
||||
# self.resnet18.layer4 = nn.Identity()
|
||||
self.resnet18.avgpool = nn.Identity()
|
||||
self.resnet18.fc = nn.Identity()
|
||||
self.flatten = nn.Flatten()
|
||||
# Calculate the output size after layer2
|
||||
# Assuming input size is 224x224, layer2 will have output size of 56x56
|
||||
# So, the flattened size will be 128 * scale * 56 * 56
|
||||
self.flattened_size = int(128 * (56 * 56) * scale * scale)
|
||||
|
||||
# Add new layers for classification
|
||||
self.classifier = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
nn.Flatten(),
|
||||
nn.Linear(384, num_classes), # layer1, layer2 in_features=96 # layer1 in_features=48 #layer3 in_features=192
|
||||
# nn.ReLU(),
|
||||
nn.Dropout(0.6),
|
||||
# nn.Linear(256, num_classes),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.resnet18.layer1(x)
|
||||
x = self.resnet18.layer2(x)
|
||||
x = self.resnet18.layer3(x)
|
||||
x = self.resnet18.layer4(x)
|
||||
|
||||
# Debugging: Print the shape of the tensor before flattening
|
||||
# print("Shape before flattening:", x.shape)
|
||||
|
||||
# Ensure the tensor is flattened correctly
|
||||
# x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, layers, num_classes=1, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, scale=0.75):
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
x = self.sigmoid(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def Net4(arch, pretrained, progress, **kwargs):
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
'''
|
||||
net2 = Net2()
|
||||
input_tensor = torch.randn(10, 1, 64)
|
||||
# 前向传播
|
||||
output_tensor = net2(input_tensor)
|
||||
# pdb.set_trace()
|
||||
print("输入张量形状:", input_tensor.shape)
|
||||
print("输出张量形状:", output_tensor.shape)
|
||||
'''
|
||||
|
||||
# model = Net3(pretrained=True, num_classes=1) # 预训练从resnet中间结果获取数据训练模型
|
||||
model = Net4('resnet18', True, True)
|
||||
input_tensor = torch.randn(1, 3, 224, 244) # Adjust batch size to 10
|
||||
output = model(input_tensor)
|
||||
print(output.shape) # Should be [10, 2]
|
148
model/mobilenet_v1.py
Normal file
148
model/mobilenet_v1.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torchvision.ops.misc import Conv2dNormActivation
|
||||
from config import config as conf
|
||||
|
||||
__all__ = [
|
||||
"MobileNetV1",
|
||||
"DepthWiseSeparableConv2d",
|
||||
"mobilenet_v1",
|
||||
]
|
||||
|
||||
|
||||
class MobileNetV1(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int = conf.embedding_size,
|
||||
) -> None:
|
||||
super(MobileNetV1, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
Conv2dNormActivation(3,
|
||||
32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
|
||||
DepthWiseSeparableConv2d(32, 64, 1),
|
||||
DepthWiseSeparableConv2d(64, 128, 2),
|
||||
DepthWiseSeparableConv2d(128, 128, 1),
|
||||
DepthWiseSeparableConv2d(128, 256, 2),
|
||||
DepthWiseSeparableConv2d(256, 256, 1),
|
||||
DepthWiseSeparableConv2d(256, 512, 2),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 1024, 2),
|
||||
DepthWiseSeparableConv2d(1024, 1024, 1),
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d((7, 7))
|
||||
|
||||
self.classifier = nn.Linear(1024, num_classes)
|
||||
|
||||
# Initialize neural network weights
|
||||
self._initialize_weights()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
out = self._forward_impl(x)
|
||||
|
||||
return out
|
||||
|
||||
# Support torch.script function
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
out = self.features(x)
|
||||
out = self.avgpool(out)
|
||||
out = torch.flatten(out, 1)
|
||||
out = self.classifier(out)
|
||||
|
||||
return out
|
||||
|
||||
def _initialize_weights(self) -> None:
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, 0, 0.01)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
class DepthWiseSeparableConv2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(DepthWiseSeparableConv2d, self).__init__()
|
||||
self.stride = stride
|
||||
if stride not in [1, 2]:
|
||||
raise ValueError(f"stride should be 1 or 2 instead of {stride}")
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
Conv2dNormActivation(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
Conv2dNormActivation(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
out = self.conv(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
|
||||
model = MobileNetV1(**kwargs)
|
||||
|
||||
return model
|
200
model/mobilenet_v2.py
Normal file
200
model/mobilenet_v2.py
Normal file
@ -0,0 +1,200 @@
|
||||
from torch import nn
|
||||
from .utils import load_state_dict_from_url
|
||||
from config import config as conf
|
||||
|
||||
__all__ = ['MobileNetV2', 'mobilenet_v2']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
|
||||
}
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
|
||||
padding = (kernel_size - 1) // 2
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
norm_layer(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
norm_layer(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes=conf.embedding_size,
|
||||
width_mult=1.0,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None,
|
||||
norm_layer=None):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
norm_layer: Module specifying the normalization layer to use
|
||||
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
x = self.features(x)
|
||||
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
|
||||
x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def mobilenet_v2(pretrained=True, progress=True, **kwargs):
|
||||
"""
|
||||
Constructs a MobileNetV2 architecture from
|
||||
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
model = MobileNetV2(**kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
|
||||
progress=progress)
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
#.load_state_dict(state_dict)
|
||||
return model
|
200
model/mobilenet_v3.py
Normal file
200
model/mobilenet_v3.py
Normal file
@ -0,0 +1,200 @@
|
||||
'''MobileNetV3 in PyTorch.
|
||||
|
||||
See the paper "Inverted Residuals and Linear Bottlenecks:
|
||||
Mobile Networks for Classification, Detection and Segmentation" for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import init
|
||||
from config import config as conf
|
||||
|
||||
|
||||
class hswish(nn.Module):
|
||||
def forward(self, x):
|
||||
out = x * F.relu6(x + 3, inplace=True) / 6
|
||||
return out
|
||||
|
||||
|
||||
class hsigmoid(nn.Module):
|
||||
def forward(self, x):
|
||||
out = F.relu6(x + 3, inplace=True) / 6
|
||||
return out
|
||||
|
||||
|
||||
class SeModule(nn.Module):
|
||||
def __init__(self, in_size, reduction=4):
|
||||
super(SeModule, self).__init__()
|
||||
self.se = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_size // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_size),
|
||||
hsigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.se(x)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
'''expand + depthwise + pointwise'''
|
||||
def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
|
||||
super(Block, self).__init__()
|
||||
self.stride = stride
|
||||
self.se = semodule
|
||||
|
||||
self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(expand_size)
|
||||
self.nolinear1 = nolinear
|
||||
self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(expand_size)
|
||||
self.nolinear2 = nolinear
|
||||
self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_size)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride == 1 and in_size != out_size:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(out_size),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.nolinear1(self.bn1(self.conv1(x)))
|
||||
out = self.nolinear2(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
if self.se != None:
|
||||
out = self.se(out)
|
||||
out = out + self.shortcut(x) if self.stride==1 else out
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV3_Large(nn.Module):
|
||||
def __init__(self, num_classes=conf.embedding_size):
|
||||
super(MobileNetV3_Large, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.hs1 = hswish()
|
||||
|
||||
self.bneck = nn.Sequential(
|
||||
Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
|
||||
Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
|
||||
Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
|
||||
Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
|
||||
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||
Block(3, 40, 240, 80, hswish(), None, 2),
|
||||
Block(3, 80, 200, 80, hswish(), None, 1),
|
||||
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||
Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
|
||||
Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
|
||||
Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
|
||||
Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
|
||||
Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
|
||||
)
|
||||
|
||||
|
||||
self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(960)
|
||||
self.hs2 = hswish()
|
||||
self.linear3 = nn.Linear(960, 1280)
|
||||
self.bn3 = nn.BatchNorm1d(1280)
|
||||
self.hs3 = hswish()
|
||||
self.linear4 = nn.Linear(1280, num_classes)
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.hs1(self.bn1(self.conv1(x)))
|
||||
out = self.bneck(out)
|
||||
out = self.hs2(self.bn2(self.conv2(out)))
|
||||
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.hs3(self.bn3(self.linear3(out)))
|
||||
out = self.linear4(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class MobileNetV3_Small(nn.Module):
|
||||
def __init__(self, num_classes=conf.embedding_size):
|
||||
super(MobileNetV3_Small, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.hs1 = hswish()
|
||||
|
||||
self.bneck = nn.Sequential(
|
||||
Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
|
||||
Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
|
||||
Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
|
||||
Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
|
||||
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||
Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
|
||||
Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
|
||||
Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
|
||||
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||
)
|
||||
|
||||
|
||||
self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(576)
|
||||
self.hs2 = hswish()
|
||||
self.linear3 = nn.Linear(576, 1280)
|
||||
self.bn3 = nn.BatchNorm1d(1280)
|
||||
self.hs3 = hswish()
|
||||
self.linear4 = nn.Linear(1280, num_classes)
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.hs1(self.bn1(self.conv1(x)))
|
||||
out = self.bneck(out)
|
||||
out = self.hs2(self.bn2(self.conv2(out)))
|
||||
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||
out = out.view(out.size(0), -1)
|
||||
|
||||
out = self.hs3(self.bn3(self.linear3(out)))
|
||||
out = self.linear4(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def test():
|
||||
net = MobileNetV3_Small()
|
||||
x = torch.randn(2,3,224,224)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test()
|
265
model/mobilevit.py
Normal file
265
model/mobilevit.py
Normal file
@ -0,0 +1,265 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
from config import config as conf
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = self.attend(dots)
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
|
||||
# Local representations
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
|
||||
pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
x = torch.cat((x, y), 1)
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
|
||||
super().__init__()
|
||||
ih, iw = image_size
|
||||
ph, pw = patch_size
|
||||
assert ih % ph == 0 and iw % pw == 0
|
||||
|
||||
L = [2, 4, 3]
|
||||
|
||||
self.conv1 = conv_nxn_bn(3, channels[0], stride=2)
|
||||
|
||||
self.mv2 = nn.ModuleList([])
|
||||
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) # Repeat
|
||||
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
|
||||
|
||||
self.mvit = nn.ModuleList([])
|
||||
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
|
||||
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
|
||||
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
|
||||
|
||||
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
|
||||
|
||||
self.pool = nn.AvgPool2d(ih // 32, 1)
|
||||
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
#print('x',x.shape)
|
||||
x = self.conv1(x)
|
||||
x = self.mv2[0](x)
|
||||
|
||||
x = self.mv2[1](x)
|
||||
x = self.mv2[2](x)
|
||||
x = self.mv2[3](x) # Repeat
|
||||
|
||||
x = self.mv2[4](x)
|
||||
x = self.mvit[0](x)
|
||||
|
||||
x = self.mv2[5](x)
|
||||
x = self.mvit[1](x)
|
||||
|
||||
x = self.mv2[6](x)
|
||||
x = self.mvit[2](x)
|
||||
x = self.conv2(x)
|
||||
|
||||
|
||||
#print('pool_before',x.shape)
|
||||
x = self.pool(x).view(-1, x.shape[1])
|
||||
#print('self_pool',self.pool)
|
||||
#print('pool_after',x.shape)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobilevit_xxs():
|
||||
dims = [64, 80, 96]
|
||||
channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
|
||||
return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)
|
||||
|
||||
|
||||
def mobilevit_xs():
|
||||
dims = [96, 120, 144]
|
||||
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
|
||||
return MobileViT((256, 256), dims, channels, num_classes=1000)
|
||||
|
||||
|
||||
def mobilevit_s():
|
||||
dims = [144, 192, 240]
|
||||
channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
|
||||
return MobileViT((conf.img_size, conf.img_size), dims, channels, num_classes=conf.embedding_size)
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = torch.randn(5, 3, 256, 256)
|
||||
|
||||
vit = mobilevit_xxs()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
||||
|
||||
vit = mobilevit_xs()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
||||
|
||||
vit = mobilevit_s()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
412
model/quant_test_resnet.py
Normal file
412
model/quant_test_resnet.py
Normal file
@ -0,0 +1,412 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from .utils import load_state_dict_from_url
|
||||
from typing import Type, Any, Callable, Union, List, Optional
|
||||
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class QuantizableBasicBlock(BasicBlock):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_relu = torch.nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.add_relu.add_relu(out, identity)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion: int = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
num_classes: int = 1000,
|
||||
zero_init_residual: bool = False,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
||||
|
||||
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
|
||||
stride: int = 1, dilate: bool = False) -> nn.Sequential:
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def _resnet(
|
||||
arch: str,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
pretrained: bool,
|
||||
progress: bool,
|
||||
**kwargs: Any
|
||||
) -> ResNet:
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
progress=progress)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
||||
return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
142
model/resbam.py
Normal file
142
model/resbam.py
Normal file
@ -0,0 +1,142 @@
|
||||
from model.CBAM import CBAM
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model.Tool import GeM as gem
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inchannel, outchannel, stride=1, dowsample=None):
|
||||
# super(Bottleneck, self).__init__()
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=1, stride=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(outchannel)
|
||||
self.conv2 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel, kernel_size=3, bias=False,
|
||||
stride=stride, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(outchannel)
|
||||
self.conv3 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel * self.expansion, stride=1, bias=False,
|
||||
kernel_size=1)
|
||||
self.bn3 = nn.BatchNorm2d(outchannel * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = dowsample
|
||||
|
||||
def forward(self, x):
|
||||
self.identity = x
|
||||
# print('>>>>>>>>',type(x))
|
||||
if self.downsample is not None:
|
||||
# print('>>>>downsample>>>>', type(self.downsample))
|
||||
self.identity = self.downsample(x)
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
# print('>>>>out>>>identity',out.size(),self.identity.size())
|
||||
out = out + self.identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class resnet(nn.Module):
|
||||
def __init__(self, block=Bottleneck, block_num=[3, 4, 6, 3], num_class=1000):
|
||||
super().__init__()
|
||||
self.in_channel = 64
|
||||
self.conv1 = nn.Conv2d(in_channels=3,
|
||||
out_channels=self.in_channel,
|
||||
stride=2,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.cbam = CBAM(self.in_channel)
|
||||
self.cbam1 = CBAM(2048)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.gem = gem()
|
||||
self.fc = nn.Linear(512 * block.expansion, num_class)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal(m.weight, mode='fan_out',
|
||||
nonlinearity='relu')
|
||||
if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 1.0)
|
||||
|
||||
def _make_layer(self, block, channel, block_num, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.in_channel != channel * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(channel * block.expansion))
|
||||
layer = []
|
||||
layer.append(block(self.in_channel, channel, stride, downsample))
|
||||
self.in_channel = channel * block.expansion
|
||||
for _ in range(1, block_num):
|
||||
layer.append(block(self.in_channel, channel))
|
||||
return nn.Sequential(*layer)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.cbam(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.cbam1(x)
|
||||
# x = self.avgpool(x)
|
||||
x = self.gem(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class TripletNet(nn.Module):
|
||||
def __init__(self, num_class, flag=True):
|
||||
super(TripletNet, self).__init__()
|
||||
self.initnet = rescbam(num_class)
|
||||
self.flag = flag
|
||||
|
||||
def forward(self, x1, x2=None, x3=None):
|
||||
if self.flag:
|
||||
output1 = self.initnet(x1)
|
||||
output2 = self.initnet(x2)
|
||||
output3 = self.initnet(x3)
|
||||
return output1, output2, output3
|
||||
else:
|
||||
output = self.initnet(x1)
|
||||
return output
|
||||
|
||||
|
||||
def rescbam(num_class):
|
||||
return resnet(block=Bottleneck, block_num=[3, 4, 6, 3], num_class=num_class)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
input1 = torch.randn(4, 3, 640, 640)
|
||||
input2 = torch.randn(4, 3, 640, 640)
|
||||
input3 = torch.randn(4, 3, 640, 640)
|
||||
|
||||
# rescbam测试
|
||||
# Resnet50 = rescbam(512)
|
||||
# output = Resnet50.forward(input1)
|
||||
# print(Resnet50)
|
||||
|
||||
# trnet测试
|
||||
trnet = TripletNet(512)
|
||||
output = trnet(input1, input2, input3)
|
||||
print(output)
|
189
model/resnet.py
Normal file
189
model/resnet.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""resnet in pytorch
|
||||
|
||||
|
||||
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
|
||||
|
||||
Deep Residual Learning for Image Recognition
|
||||
https://arxiv.org/abs/1512.03385v1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from config import config as conf
|
||||
from CBAM import CBAM
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""Basic Block for resnet 18 and resnet 34
|
||||
|
||||
"""
|
||||
|
||||
#BasicBlock and BottleNeck block
|
||||
#have different output size
|
||||
#we use class attribute expansion
|
||||
#to distinct
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
|
||||
#residual function
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
#shortcut
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
#the shortcut output dimension is not the same with residual function
|
||||
#use 1*1 convolution to match the dimension
|
||||
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
"""Residual block for resnet over 50 layers
|
||||
|
||||
"""
|
||||
expansion = 4
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
|
||||
)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, num_block, cbam = False, num_classes=conf.embedding_size):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = 64
|
||||
|
||||
# self.conv1 = nn.Sequential(
|
||||
# nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(64),
|
||||
# nn.ReLU(inplace=True))
|
||||
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(3, 64,stride=2,kernel_size=7,padding=3,bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
|
||||
self.cbam = CBAM(self.in_channels)
|
||||
|
||||
#we use a different inputsize than the original paper
|
||||
#so conv2_x's stride is 1
|
||||
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
|
||||
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
|
||||
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
|
||||
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
|
||||
self.cbam1 = CBAM(self.in_channels)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal(m.weight,mode = 'fan_out',
|
||||
nonlinearity='relu')
|
||||
if isinstance(m, (nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 1.0)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
"""make resnet layers(by layer i didnt mean this 'layer' was the
|
||||
same as a neuron netowork layer, ex. conv layer), one layer may
|
||||
contain more than one residual block
|
||||
|
||||
Args:
|
||||
block: block type, basic block or bottle neck block
|
||||
out_channels: output depth channel number of this layer
|
||||
num_blocks: how many blocks per layer
|
||||
stride: the stride of the first block of this layer
|
||||
|
||||
Return:
|
||||
return a resnet layer
|
||||
"""
|
||||
|
||||
# we have num_block blocks per layer, the first block
|
||||
# could be 1 or 2, other blocks would always be 1
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_channels, out_channels, stride))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.conv1(x)
|
||||
if cbam:
|
||||
output = self.cbam(x)
|
||||
output = self.conv2_x(output)
|
||||
output = self.conv3_x(output)
|
||||
output = self.conv4_x(output)
|
||||
output = self.conv5_x(output)
|
||||
if cbam:
|
||||
output = self.cbam1(x)
|
||||
print('pollBefore',output.shape)
|
||||
output = self.avg_pool(output)
|
||||
print('poolAfter',output.shape)
|
||||
output = output.view(output.size(0), -1)
|
||||
print('fcBefore',output.shape)
|
||||
output = self.fc(output)
|
||||
|
||||
return output
|
||||
|
||||
def resnet18(cbam = False):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], cbam)
|
||||
|
||||
def resnet34():
|
||||
""" return a ResNet 34 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
def resnet50():
|
||||
""" return a ResNet 50 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 6, 3])
|
||||
|
||||
def resnet101():
|
||||
""" return a ResNet 101 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 23, 3])
|
||||
|
||||
def resnet152():
|
||||
""" return a ResNet 152 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 8, 36, 3])
|
||||
|
||||
|
271
model/resnet_attention.py
Normal file
271
model/resnet_attention.py
Normal file
@ -0,0 +1,271 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
"""通道注意力模块,通过全局平均池化和最大池化提取特征,经过MLP生成通道权重"""
|
||||
|
||||
def __init__(self, in_channels, reduction_ratio=16):
|
||||
super(ChannelAttention, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
|
||||
# 共享的MLP层
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = self.fc(self.avg_pool(x))
|
||||
max_out = self.fc(self.max_pool(x))
|
||||
out = avg_out + max_out
|
||||
return torch.sigmoid(out)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
"""空间注意力模块,通过通道维度的平均和最大值操作,生成空间权重"""
|
||||
|
||||
def __init__(self, kernel_size=7):
|
||||
super(SpatialAttention, self).__init__()
|
||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||
out = torch.cat([avg_out, max_out], dim=1)
|
||||
out = self.conv(out)
|
||||
return torch.sigmoid(out)
|
||||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
"""CBAM注意力模块,串联通道注意力和空间注意力"""
|
||||
|
||||
def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
|
||||
super(CBAM, self).__init__()
|
||||
self.channel_att = ChannelAttention(in_channels, reduction_ratio)
|
||||
self.spatial_att = SpatialAttention(kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * self.channel_att(x)
|
||||
x = x * self.spatial_att(x)
|
||||
return x
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""ResNet基础残差块,适用于ResNet18和ResNet34"""
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
# 是否使用CBAM注意力机制
|
||||
self.use_cbam = use_cbam
|
||||
if use_cbam:
|
||||
self.cbam = CBAM(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
# # 如果使用注意力机制,应用CBAM
|
||||
if self.use_cbam:
|
||||
out = self.cbam(out)
|
||||
|
||||
# 如果有下采样,调整shortcut连接
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
# 残差连接
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""ResNet瓶颈残差块,适用于ResNet50及更深的网络"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
|
||||
super(Bottleneck, self).__init__()
|
||||
# 1x1卷积降维
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
# 3x3卷积
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
# 1x1卷积升维
|
||||
self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
# 是否使用CBAM注意力机制
|
||||
self.use_cbam = use_cbam
|
||||
if use_cbam:
|
||||
self.cbam = CBAM(out_channels * self.expansion)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
# # 如果使用注意力机制,应用CBAM
|
||||
if self.use_cbam:
|
||||
out = self.cbam(out)
|
||||
|
||||
# 如果有下采样,调整shortcut连接
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
# 残差连接
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
"""集成了CBAM注意力机制的ResNet模型"""
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, use_cbam=True):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_channels = 64
|
||||
self.use_cbam = use_cbam
|
||||
|
||||
# 初始卷积层
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.cbam1 = CBAM(64)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
# 残差块层
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
|
||||
self.cbam2 = CBAM(512)
|
||||
# 全局平均池化和分类器
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
# 初始化权重
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# 零初始化最后一个BN层的权重,使残差分支初始为0
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, out_channels, blocks, stride=1):
|
||||
downsample = None
|
||||
# 如果通道数不匹配或需要调整步长,创建下采样层
|
||||
if stride != 1 or self.in_channels != out_channels * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
# 第一个块可能需要下采样
|
||||
layers.append(block(self.in_channels, out_channels, stride, downsample, use_cbam=self.use_cbam))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
|
||||
# 添加剩余的块
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.in_channels, out_channels, use_cbam=self.use_cbam))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
# 特征提取
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
# if self.use_cbam:
|
||||
# x = self.cbam1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
# if self.use_cbam:
|
||||
# x = self.cbam2(x)
|
||||
# 分类
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# 工厂函数,创建不同深度的ResNet模型
|
||||
def resnet18_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
def resnet34_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet50_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet101_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet152_cbam(pretrained=False, **kwargs):
|
||||
return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
|
||||
|
||||
# 测试模型
|
||||
if __name__ == "__main__":
|
||||
# 创建一个带有CBAM注意力机制的ResNet50模型
|
||||
model = resnet50_cbam(num_classes=10)
|
||||
# 测试输入
|
||||
x = torch.randn(1, 3, 224, 224)
|
||||
y = model(x)
|
||||
print(f"输入形状: {x.shape}")
|
||||
print(f"输出形状: {y.shape}")
|
480
model/resnet_pre.py
Normal file
480
model/resnet_pre.py
Normal file
@ -0,0 +1,480 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from config import config as conf
|
||||
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
# from .utils import load_state_dict_from_url
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __init__(self, kernel_size=7):
|
||||
super(SpatialAttention, self).__init__()
|
||||
|
||||
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
|
||||
padding = 3 if kernel_size == 7 else 1
|
||||
|
||||
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||
x = torch.cat([avg_out, max_out], dim=1)
|
||||
x = self.conv1(x)
|
||||
return self.sigmoid(x)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
self.cam = cam
|
||||
self.bam = bam
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
if self.cam:
|
||||
if planes == 64:
|
||||
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||
elif planes == 128:
|
||||
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||
elif planes == 256:
|
||||
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||
elif planes == 512:
|
||||
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||
|
||||
self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
|
||||
self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
|
||||
self.sigmod = nn.Sigmoid()
|
||||
if self.bam:
|
||||
self.bam = SpatialAttention()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
if self.cam:
|
||||
ori_out = self.globalAvgPool(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.fc1(out)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
out = self.sigmod(out)
|
||||
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||
out = out * ori_out
|
||||
|
||||
if self.bam:
|
||||
out = out * self.bam(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
self.cam = cam
|
||||
self.bam = bam
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
if self.cam:
|
||||
if planes == 64:
|
||||
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||
elif planes == 128:
|
||||
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||
elif planes == 256:
|
||||
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||
elif planes == 512:
|
||||
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||
|
||||
self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
|
||||
self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
|
||||
self.sigmod = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
if self.cam:
|
||||
ori_out = self.globalAvgPool(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.fc1(out)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
out = self.sigmod(out)
|
||||
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||
out = out * ori_out
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, scale=conf.channel_ratio):
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
print("ResNet scale: >>>>>>>>>> ", scale)
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
|
||||
self.maxpool2 = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
|
||||
)
|
||||
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
# model = ResNet(block, layers, **kwargs)
|
||||
# if pretrained:
|
||||
# state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
# progress=progress)
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
# return model
|
||||
|
||||
class CustomResNet18(nn.Module):
|
||||
def __init__(self, model, num_classes=conf.custom_num_classes):
|
||||
super(CustomResNet18, self).__init__()
|
||||
self.custom_model = nn.Sequential(*list(model.children())[:-1])
|
||||
self.fc = nn.Linear(model.fc.in_features, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.custom_model(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
progress=progress)
|
||||
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def resnet14(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-14 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet18(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
**kwargs: Additional arguments passed to ResNet, including:
|
||||
scale (float): Channel scaling ratio (default: conf.channel_ratio)
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
4
model/utils.py
Normal file
4
model/utils.py
Normal file
@ -0,0 +1,4 @@
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
42
model/vit.py
Normal file
42
model/vit.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial, reduce
|
||||
from operator import mul
|
||||
|
||||
from timm.models.vision_transformer import VisionTransformer, _cfg
|
||||
|
||||
__all__ = [
|
||||
'vit_small',
|
||||
'vit_base',
|
||||
]
|
||||
|
||||
|
||||
def vit_small(**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
# model.default_cfg = _cfg()
|
||||
return model
|
||||
|
||||
|
||||
def vit_base(**kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, num_classes=256,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = _cfg(num_classes=256)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = torch.randn(8, 3, 224, 224)
|
||||
vit = vit_base()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
# print(count_parameters(vit))
|
Reference in New Issue
Block a user