import torch
from config import config as conf
import torch.nn as nn
import torchvision.models as models
from model.resnet_pre import resnet18, resnet50
# from model.vit import vit_base_patch16_224, vit_base_patch32_224


class ContrastiveModel(nn.Module):
    """Backbone network with an optional MLP projection head.

    The backbone is built by ``_get_model``. Unless ``contraposition`` is True,
    the backbone's final linear layer (``head`` for names containing ``'vit'``,
    ``fc`` otherwise) is replaced by a two-layer MLP that maps the layer's input
    width to ``projection_dim``. When ``contraposition`` is True the backbone is
    used unmodified.
    """

    def __init__(self, projection_dim, model_name, contraposition=False):
        """Build the backbone and, optionally, attach the projection head.

        Args:
            projection_dim: Output width of the projection head. Ignored when
                ``contraposition`` is True.
            model_name: Backbone identifier; currently ``'resnet18'`` or
                ``'resnet50'``.
            contraposition: If True, keep the backbone's original output layer
                instead of replacing it with the projection head.

        Raises:
            ValueError: If ``model_name`` is not a supported backbone.
        """
        super().__init__()
        self.contraposition = contraposition
        self.base_model = self._get_model(model_name)

        if not self.contraposition:
            # nn.Linear.weight has shape (out_features, in_features), so
            # shape[1] is the feature width entering the layer being replaced.
            if 'vit' in model_name:
                dim_mlp = self.base_model.head.weight.shape[1]
                self.base_model.head = self._get_projection_layer(dim_mlp, projection_dim)
            else:
                dim_mlp = self.base_model.fc.weight.shape[1]
                self.base_model.fc = self._get_projection_layer(dim_mlp, projection_dim)

        # # Freeze every layer except the FC layer
        # for name, param in self.base_model.named_parameters():
        #     if 'fc' not in name:
        #         param.requires_grad = False

    def _get_projection_layer(self, dim_mlp, projection_dim):
        """Return a Linear -> ReLU -> Linear head mapping ``dim_mlp`` to ``projection_dim``.

        The hidden layer keeps the input width ``dim_mlp``.
        """
        return nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp),
            nn.ReLU(inplace=True),
            nn.Linear(dim_mlp, projection_dim),
        )

    def _get_model(self, model_name):
        """Construct the backbone named ``model_name``.

        Both supported backbones are created with ``pretrained=True``
        (presumably loading pretrained weights inside ``model.resnet_pre`` --
        confirm there).

        Raises:
            ValueError: If ``model_name`` is not supported. Failing here gives a
                clear message instead of a later ``AttributeError`` on ``None``.
        """
        if model_name == 'resnet18':
            return resnet18(pretrained=True)
        if model_name == 'resnet50':
            return resnet50(pretrained=True)
        # ViT support is disabled; re-enable together with the model.vit import.
        # if model_name == 'vit':
        #     return vit_base_patch32_224()
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of ['resnet18', 'resnet50']"
        )

    def forward(self, x):
        """Run ``x`` through the backbone (and the projection head, if attached).

        ``x`` is presumably an image batch of shape (N, C, H, W) -- confirm
        against callers.
        """
        # Invariant check: _get_model never returns None, but base_model is a
        # public attribute that callers could overwrite.
        assert self.base_model is not None, 'base_model is none'
        return self.base_model(x)


if __name__ == '__main__':
    pass