ieemoo-ai-detecttracking/contrast/feat_extract/model/mobilenet_v1.py
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, Any, Optional

import torch
from torch import Tensor
from torch import nn
from torchvision.ops.misc import Conv2dNormActivation

from config import config as conf

__all__ = [
    "MobileNetV1",
    "DepthWiseSeparableConv2d",
    "mobilenet_v1",
]


class MobileNetV1(nn.Module):
    def __init__(
            self,
            num_classes: int = conf.embedding_size,
    ) -> None:
        super(MobileNetV1, self).__init__()
        self.features = nn.Sequential(
            # Stem: standard 3x3 convolution, stride 2.
            Conv2dNormActivation(3,
                                 32,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 norm_layer=nn.BatchNorm2d,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),
            # Stack of depthwise separable blocks (3x3 depthwise + 1x1 pointwise).
            DepthWiseSeparableConv2d(32, 64, 1),
            DepthWiseSeparableConv2d(64, 128, 2),
            DepthWiseSeparableConv2d(128, 128, 1),
            DepthWiseSeparableConv2d(128, 256, 2),
            DepthWiseSeparableConv2d(256, 256, 1),
            DepthWiseSeparableConv2d(256, 512, 2),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 512, 1),
            DepthWiseSeparableConv2d(512, 1024, 2),
            DepthWiseSeparableConv2d(1024, 1024, 1),
        )

        # 7x7 average pooling matches the 7x7 feature map produced by a 224x224 input.
        self.avgpool = nn.AvgPool2d((7, 7))

        self.classifier = nn.Linear(1024, num_classes)

        # Initialize neural network weights
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        out = self._forward_impl(x)

        return out

    # Support torch.script function
    def _forward_impl(self, x: Tensor) -> Tensor:
        out = self.features(x)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)


class DepthWiseSeparableConv2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int,
            norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(DepthWiseSeparableConv2d, self).__init__()
        self.stride = stride
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.conv = nn.Sequential(
            # Depthwise 3x3 convolution: one filter per input channel (groups=in_channels).
            Conv2dNormActivation(in_channels,
                                 in_channels,
                                 kernel_size=3,
                                 stride=stride,
                                 padding=1,
                                 groups=in_channels,
                                 norm_layer=norm_layer,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),
            # Pointwise 1x1 convolution: mixes channels and changes the channel count.
            Conv2dNormActivation(in_channels,
                                 out_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0,
                                 norm_layer=norm_layer,
                                 activation_layer=nn.ReLU,
                                 inplace=True,
                                 bias=False,
                                 ),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)

        return out


def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
    model = MobileNetV1(**kwargs)

    return model
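

# Minimal usage sketch: builds the model via the factory and runs a dummy forward pass.
# It assumes a 224x224 RGB input, which the strided layers reduce to the 7x7 map expected
# by nn.AvgPool2d((7, 7)), and that `conf.embedding_size` is defined by the project's
# config module; adjust both for other setups.
if __name__ == "__main__":
    model = mobilenet_v1()
    dummy = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
    with torch.no_grad():
        embedding = model(dummy)
    print(embedding.shape)  # expected: torch.Size([1, conf.embedding_size])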