"""Bottleneck ResNet with the project's CBAM and GeM modules, plus a triplet wrapper.

NOTE(review): ``CBAM`` and ``GeM`` are defined in ``model.CBAM`` and
``model.Tool``; their exact behaviour is not visible from this file.
"""
from typing import Optional, Sequence, Tuple, Type, Union

import torch
import torch.nn as nn

from model.CBAM import CBAM
from model.Tool import GeM as gem


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

    Each convolution is followed by batch normalisation; ReLU is applied after
    the first two and again after the skip connection has been added.

    Args:
        inchannel: Number of channels of the block input.
        outchannel: Width of the first two convolutions. The block emits
            ``outchannel * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        dowsample: Optional module applied to the input on the skip path so
            that it can be added to the main-path output. The misspelt name
            is kept so existing keyword callers keep working.
    """

    # Ratio between the block's output channels and ``outchannel``.
    expansion = 4

    def __init__(
        self,
        inchannel: int,
        outchannel: int,
        stride: int = 1,
        dowsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=outchannel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel,
                               kernel_size=3, bias=False, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(outchannel)
        self.conv3 = nn.Conv2d(in_channels=outchannel,
                               out_channels=outchannel * self.expansion,
                               stride=1, bias=False, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(outchannel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = dowsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + skip(x))``."""
        # Keep the skip tensor in a local variable: storing it on ``self``
        # would keep the activation (and its autograd graph) referenced after
        # the call returns and would share state between calls.
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out = out + identity
        return self.relu(out)


class resnet(nn.Module):
    """ResNet-style classifier with CBAM after the stem and after the last stage.

    The forward pass is: 7x7 stem convolution -> BN -> ReLU -> max-pool ->
    ``CBAM`` -> four residual stages -> ``CBAM`` -> ``GeM`` -> flatten ->
    linear layer.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_channels, channels, stride, downsample)``.
        block_num: Number of blocks in each of the four stages.
        num_class: Size of the final linear layer's output.
    """

    def __init__(
        self,
        block: Type[nn.Module] = Bottleneck,
        block_num: Sequence[int] = (3, 4, 6, 3),
        num_class: int = 1000,
    ) -> None:
        super().__init__()
        # Running channel count; updated by ``_make_layer`` as stages are built.
        self.in_channel = 64

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channel,
                               stride=2, kernel_size=7, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.cbam = CBAM(self.in_channel)
        # Channel count of the last stage (2048 for the default Bottleneck).
        self.cbam1 = CBAM(512 * block.expansion)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)

        # ``avgpool`` is not used by ``forward`` (``gem`` is used instead); it
        # is kept so the module's attributes stay unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.gem = gem()
        self.fc = nn.Linear(512 * block.expansion, num_class)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # ``kaiming_normal`` (no trailing underscore) is deprecated.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Start normalisation layers as identity: scale 1, shift 0.
                # (The shift was previously initialised to 1.0.)
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks and update ``self.in_channel``.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1 conv + BN projection for its skip
        path. The remaining blocks use the defaults.
        """
        out_channel = channel * block.expansion

        downsample = None
        if stride != 1 or self.in_channel != out_channel:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, out_channel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )

        layer = [block(self.in_channel, channel, stride, downsample)]
        self.in_channel = out_channel
        for _ in range(1, block_num):
            layer.append(block(self.in_channel, channel))

        return nn.Sequential(*layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the output of the linear layer."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.cbam(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.cbam1(x)

        # GeM is used here in place of ``self.avgpool``.
        x = self.gem(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


class TripletNet(nn.Module):
    """Apply one shared ``rescbam`` network to one or three inputs.

    Args:
        num_class: Output size passed to ``rescbam``.
        flag: When true, ``forward`` expects three inputs and returns three
            outputs; when false, only ``x1`` is used.
    """

    def __init__(self, num_class: int, flag: bool = True) -> None:
        super().__init__()
        self.initnet = rescbam(num_class)
        self.flag = flag

    def forward(
        self,
        x1: torch.Tensor,
        x2: Optional[torch.Tensor] = None,
        x3: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Return the network output(s) for the given input(s).

        Returns:
            A 3-tuple of outputs for ``(x1, x2, x3)`` when ``self.flag`` is
            true, otherwise the single output for ``x1``.

        Raises:
            ValueError: If ``self.flag`` is true and ``x2`` or ``x3`` is missing.
        """
        if not self.flag:
            return self.initnet(x1)

        if x2 is None or x3 is None:
            # Fail early with a clear message instead of passing ``None`` into
            # the convolution stack.
            raise ValueError("TripletNet with flag=True requires x1, x2 and x3")

        output1 = self.initnet(x1)
        output2 = self.initnet(x2)
        output3 = self.initnet(x3)
        return output1, output2, output3


def rescbam(num_class):
    """Return a ``resnet`` built from ``Bottleneck`` blocks in a 3-4-6-3 layout."""
    return resnet(block=Bottleneck, block_num=[3, 4, 6, 3], num_class=num_class)


if __name__ == '__main__':
    # Smoke test: run the triplet wrapper on three random batches.
    input1 = torch.randn(4, 3, 640, 640)
    input2 = torch.randn(4, 3, 640, 640)
    input3 = torch.randn(4, 3, 640, 640)

    trnet = TripletNet(512)
    output = trnet(input1, input2, input3)
    print(output)