# ieemoo-ai-contrast/model/resnet_attention.py
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel attention module: global average pooling and max pooling extract
    per-channel statistics, and a shared MLP turns them into channel weights."""
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP, implemented with 1x1 convolutions
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        )

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return torch.sigmoid(out)
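
# Usage sketch (illustrative values): ChannelAttention returns a (B, C, 1, 1)
# weight map that is broadcast-multiplied against the input feature map.
#   >>> ca = ChannelAttention(64)
#   >>> ca(torch.randn(2, 64, 32, 32)).shape
#   torch.Size([2, 64, 1, 1])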


class SpatialAttention(nn.Module):
    """Spatial attention module: channel-wise mean and max maps are concatenated
    and convolved to produce a spatial weight map."""
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return torch.sigmoid(out)
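
# Usage sketch (illustrative values): SpatialAttention returns a (B, 1, H, W)
# map, broadcast-multiplied against every channel of the input.
#   >>> sa = SpatialAttention()
#   >>> sa(torch.randn(2, 64, 32, 32)).shape
#   torch.Size([2, 1, 32, 32])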


class CBAM(nn.Module):
    """CBAM attention module: applies channel attention, then spatial attention,
    in sequence (Woo et al., 2018)."""
    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_att = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel_att(x)
        x = x * self.spatial_att(x)
        return x
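
# Usage sketch: CBAM is shape-preserving, so it can be dropped in after any
# convolutional stage without changing downstream layer sizes.
#   >>> cbam = CBAM(64)
#   >>> cbam(torch.randn(2, 64, 32, 32)).shape
#   torch.Size([2, 64, 32, 32])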


class BasicBlock(nn.Module):
    """Basic residual block, as used in ResNet-18 and ResNet-34."""
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride
        # Optionally attach a CBAM attention module
        self.use_cbam = use_cbam
        if use_cbam:
            self.cbam = CBAM(out_channels)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Apply CBAM before the residual addition, if enabled
        if self.use_cbam:
            out = self.cbam(out)
        # Adjust the shortcut when the shape changes
        if self.downsample is not None:
            identity = self.downsample(x)
        # Residual connection
        out += identity
        out = self.relu(out)
        return out
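
# Usage sketch: with matching channels and stride 1, no downsample is needed
# and the block preserves the input shape.
#   >>> blk = BasicBlock(64, 64, use_cbam=True)
#   >>> blk(torch.randn(2, 64, 56, 56)).shape
#   torch.Size([2, 64, 56, 56])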


class Bottleneck(nn.Module):
    """Bottleneck residual block, as used in ResNet-50 and deeper networks."""
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_cbam=False):
        super(Bottleneck, self).__init__()
        # 1x1 convolution to reduce the channel count
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 3x3 convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 convolution to expand the channel count
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        # Optionally attach a CBAM attention module
        self.use_cbam = use_cbam
        if use_cbam:
            self.cbam = CBAM(out_channels * self.expansion)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Apply CBAM before the residual addition, if enabled
        if self.use_cbam:
            out = self.cbam(out)
        # Adjust the shortcut when the shape changes
        if self.downsample is not None:
            identity = self.downsample(x)
        # Residual connection
        out += identity
        out = self.relu(out)
        return out
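
# Usage sketch: the block expands channels by 4x, so the shortcut must already
# match (here in_channels = 64 * expansion = 256) or a downsample is required.
#   >>> blk = Bottleneck(256, 64, use_cbam=True)
#   >>> blk(torch.randn(2, 256, 56, 56)).shape
#   torch.Size([2, 256, 56, 56])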


class ResNet(nn.Module):
    """ResNet backbone with optional CBAM attention in every residual block."""
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, use_cbam=True):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.use_cbam = use_cbam
        # Stem: initial convolution
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.cbam1 = CBAM(64)  # optional stem attention (currently disabled in forward)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Residual stages
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Optional tail attention (currently disabled in forward); sized with
        # block.expansion so it also fits Bottleneck-based networks
        self.cbam2 = CBAM(512 * block.expansion)
        # Global average pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each block so every residual branch
        # starts out as an identity mapping
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        # Create a downsample path when the stride or channel count changes
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = []
        # The first block may downsample
        layers.append(block(self.in_channels, out_channels, stride, downsample, use_cbam=self.use_cbam))
        self.in_channels = out_channels * block.expansion
        # Remaining blocks keep the resolution and channel count
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels, use_cbam=self.use_cbam))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Feature extraction
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # if self.use_cbam:
        #     x = self.cbam1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # if self.use_cbam:
        #     x = self.cbam2(x)
        # Classification head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
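
# Shape walkthrough (sketch, assuming a 3x224x224 input to resnet50_cbam):
#   conv1 (stride 2)   -> (B,   64, 112, 112)
#   maxpool (stride 2) -> (B,   64,  56,  56)
#   layer1             -> (B,  256,  56,  56)
#   layer2 (stride 2)  -> (B,  512,  28,  28)
#   layer3 (stride 2)  -> (B, 1024,  14,  14)
#   layer4 (stride 2)  -> (B, 2048,   7,   7)
#   avgpool + fc       -> (B, num_classes)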


# Factory functions for ResNet variants of different depths. Note that the
# `pretrained` flag is accepted for API compatibility but is not used here.
def resnet18_cbam(pretrained=False, **kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34_cbam(pretrained=False, **kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50_cbam(pretrained=False, **kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101_cbam(pretrained=False, **kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152_cbam(pretrained=False, **kwargs):
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
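
# Sketch of one way to honor the unused `pretrained` flag (an assumption, not
# part of the original code): because the attribute names above mirror the
# stock torchvision ResNets, a non-strict load copies the backbone weights
# while the CBAM modules (and a differently sized classifier) keep their
# random init. Assumes torchvision >= 0.13 for the string `weights` argument.
def _load_backbone_from_torchvision(model):
    """Hypothetical helper: partially initialize from torchvision's ResNet-50."""
    from torchvision.models import resnet50
    state = resnet50(weights="IMAGENET1K_V1").state_dict()
    state.pop("fc.weight", None)  # head size may differ (e.g. num_classes=10)
    state.pop("fc.bias", None)
    # `missing` lists the CBAM and fc parameters that stay randomly initialized
    missing, unexpected = model.load_state_dict(state, strict=False)
    return model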


# Quick smoke test
if __name__ == "__main__":
    # Build a ResNet-50 with CBAM attention
    model = resnet50_cbam(num_classes=10)
    # Dummy input
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")