# ieemoo-ai-contrast/model/mlp.py
import pdb
import torch
import torch.nn as nn
import torch.nn.init as init
from model.resnet_pre import resnet18, conv1x1, BasicBlock, load_state_dict_from_url, model_urls

class MLP(nn.Module):
    def __init__(self, input_dim=256, output_dim=1):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.fc1 = nn.Linear(self.input_dim, 128)  # 32
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 16)
        self.fc5 = nn.Linear(16, self.output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.5)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(32)
        self.bn4 = nn.BatchNorm1d(16)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(self.bn1(x))
        x = self.fc2(x)
        x = self.relu(self.bn2(x))
        x = self.fc3(x)
        x = self.relu(self.bn3(x))
        x = self.fc4(x)
        x = self.relu(self.bn4(x))
        x = self.sigmoid(self.fc5(x))
        return x
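
# A minimal smoke-test sketch for MLP. BatchNorm1d needs a batch of at least 2
# in training mode, so the demo switches to eval() first.
def _demo_mlp():
    net = MLP(input_dim=256, output_dim=1)
    net.eval()
    x = torch.randn(4, 256)
    y = net(x)
    print(y.shape)  # torch.Size([4, 1]); values lie in (0, 1) from the sigmoid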

class Net2(nn.Module):  # risky to deploy: DNN inference of this network is problematic
    def __init__(self, input_dim=960, output_dim=1):
        super(Net2, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
        # self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
        # self.conv4 = nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=1)
        self.maxPool1 = nn.MaxPool1d(kernel_size=3, stride=2)
        self.conv5 = nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=1)
        self.maxPool2 = nn.MaxPool1d(kernel_size=3, stride=2)
        self.avgPool = nn.AdaptiveAvgPool1d(1)
        self.MaxPool = nn.AdaptiveMaxPool1d(1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()
        # self.conv6 = nn.Conv1d(128, 128, kernel_size=5, stride=2, padding=1)
        self.fc1 = nn.Linear(960, 128)
        self.fc21 = nn.Linear(960, 32)
        self.fc22 = nn.Linear(32, 128)
        self.fc3 = nn.Linear(128, 1)
        self.bn1 = nn.BatchNorm1d(16)
        self.bn2 = nn.BatchNorm1d(32)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def conv1x1(in_planes, out_planes, stride=1):
        """1x1 convolution (unused here; duplicates conv1x1 imported from model.resnet_pre)."""
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        x = self.conv1(x)  # -> 16 channels
        x = self.relu(x)
        x = self.conv2(x)  # -> 32 channels
        x = self.relu(x)
        # x = self.conv3(x)
        # x = self.relu(x)
        # x = self.conv4(x)  # -> 64 channels
        # x = self.relu(x)
        # x = self.maxPool1(x)
        x = self.conv5(x)
        x = self.relu(x)
        # x = self.conv6(x)
        # x = self.relu(x)
        # x = self.maxPool2(x)
        # x = self.MaxPool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.flatten(x)  # no-op after the view above
        # pdb.set_trace()
        x1 = self.fc1(x)
        x2 = self.fc22(self.fc21(x))  # bottleneck branch: 960 -> 32 -> 128
        x = self.fc3(x1 + x2)
        x = self.sigmoid(x)
        return x
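
# A shape-tracing sketch for Net2. Assuming a length-64 input (as in the
# commented-out test in __main__ below), conv1 keeps L=64, conv2 (k=3, s=2,
# p=1) halves it to 32, and conv5 (k=5, s=2, p=1) gives
# floor((32 + 2 - 5) / 2) + 1 = 15, so the flattened feature is
# 64 * 15 = 960, matching fc1/fc21's in_features.
def _demo_net2():
    net = Net2()
    net.eval()
    x = torch.randn(4, 1, 64)  # (batch, channels, length)
    y = net(x)
    print(y.shape)  # torch.Size([4, 1])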

class Net3(nn.Module):  # currently the better-suited architecture: compared with Net2, Net3's outputs are more accurate
    def __init__(self, pretrained=True, progress=True, num_classes=1, scale=0.75):
        super(Net3, self).__init__()
        self.resnet18 = resnet18(pretrained=pretrained, progress=progress)
        # Strip the original head so only the convolutional stages are used
        # self.resnet18.layer3 = nn.Identity()
        # self.resnet18.layer4 = nn.Identity()
        self.resnet18.avgpool = nn.Identity()
        self.resnet18.fc = nn.Identity()
        self.flatten = nn.Flatten()
        # Output size after layer2: for a 224x224 input, layer2 yields 56x56
        # feature maps, so the flattened size would be 128 * scale^2 * 56 * 56
        self.flattened_size = int(128 * (56 * 56) * scale * scale)  # not used below
        # New classification head
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(384, num_classes),  # layer4 output at scale=0.75; layer1: 48, layer2: 96, layer3: 192
            # nn.ReLU(),
            nn.Dropout(0.6),
            # nn.Linear(256, num_classes),
            nn.Sigmoid()
        )

    def forward(self, x):
        # The stem (conv1/bn1/maxpool) is skipped: x is expected to already be
        # stem-level feature maps, not a raw image.
        x = self.resnet18.layer1(x)
        x = self.resnet18.layer2(x)
        x = self.resnet18.layer3(x)
        x = self.resnet18.layer4(x)
        # Debugging: print the tensor shape before the head
        # print("Shape before flattening:", x.shape)
        # x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
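
# A hypothetical smoke test for Net3. Since forward() starts at layer1, the
# input must already be stem-level features, i.e. (N, 64, 56, 56) for a
# 224x224 image if model.resnet_pre keeps an unscaled 64-channel stem like
# the ResNet class below; adjust the channel count if its stem differs.
def _demo_net3():
    net = Net3(pretrained=False, num_classes=1)
    net.eval()
    x = torch.randn(2, 64, 56, 56)
    y = net(x)
    print(y.shape)  # torch.Size([2, 1]) when layer4 outputs 384 channels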

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, scale=0.75):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
        self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        self.sigmoid = nn.Sigmoid()

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.sigmoid(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)
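
# A sketch showing how `scale` thins the network (assuming model.resnet_pre's
# BasicBlock matches torchvision's): every stage width is multiplied by scale,
# so at scale=0.75 the stages carry 48/96/192/384 channels instead of
# ResNet-18's stock 64/128/256/512.
def _demo_scaled_resnet():
    net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=1, scale=0.75)
    print(net.fc.in_features)  # 384 == int(512 * 1 * 0.75)
    net.eval()
    y = net(torch.randn(2, 3, 224, 224))
    print(y.shape)  # torch.Size([2, 1])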

def Net4(arch, pretrained, progress, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        src_state_dict = state_dict
        target_state_dict = model.state_dict()
        skip_keys = []
        # skip size-mismatched tensors when loading pretrained weights
        for k in src_state_dict.keys():
            if k not in target_state_dict:
                continue
            if src_state_dict[k].size() != target_state_dict[k].size():
                skip_keys.append(k)
        for k in skip_keys:
            del src_state_dict[k]
        missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
    return model
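
# A quick check of the strict=False loading above (a sketch; pretrained=True
# downloads the torchvision checkpoint, so network access is required).
# Because scale=0.75 changes most tensor shapes, the bulk of the pretrained
# keys land in skip_keys and only shape-compatible ones (e.g. the conv1/bn1
# stem, which stays at 64 channels) are actually loaded.
def _demo_net4_loading():
    net = Net4('resnet18', pretrained=True, progress=False)
    print(net.fc)  # Linear(in_features=384, out_features=1, bias=True)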

if __name__ == '__main__':
    '''
    net2 = Net2()
    input_tensor = torch.randn(10, 1, 64)
    # forward pass
    output_tensor = net2(input_tensor)
    # pdb.set_trace()
    print("input tensor shape:", input_tensor.shape)
    print("output tensor shape:", output_tensor.shape)
    '''
    # model = Net3(pretrained=True, num_classes=1)  # pretrained: train the model on features taken from intermediate ResNet outputs
    model = Net4('resnet18', True, True)
    input_tensor = torch.randn(1, 3, 224, 224)  # a single RGB 224x224 image
    output = model(input_tensor)
    print(output.shape)  # torch.Size([1, 1])