48 lines
1.8 KiB
Python
48 lines
1.8 KiB
Python
import torch
|
|
from config import config as conf
|
|
import torch.nn as nn
|
|
import torchvision.models as models
|
|
from model.resnet_pre import resnet18, resnet50
|
|
# from model.vit import vit_base_patch16_224, vit_base_patch32_224
|
|
|
|
|
|
class ContrastiveModel(nn.Module):
    """Backbone network with an optional MLP projection head for contrastive training.

    The backbone is built by ``_get_model``. Unless ``contraposition`` is True,
    the backbone's final layer (``fc`` for ResNets, ``head`` for ViT-style
    models) is replaced by the two-layer MLP from ``_get_projection_layer``,
    whose output width is ``projection_dim``.

    Args:
        projection_dim: Output width of the projection head. Unused when
            ``contraposition`` is True.
        model_name: Backbone identifier; currently ``'resnet18'`` or
            ``'resnet50'``.
        contraposition: If True, keep the backbone's original final layer
            instead of installing the projection head.
        pretrained: Forwarded as ``pretrained=`` to the backbone constructor.
            Defaults to True, which was previously hard-coded.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone.
    """

    def __init__(self, projection_dim, model_name, contraposition=False, pretrained=True):
        super().__init__()

        self.contraposition = contraposition
        self.base_model = self._get_model(model_name, pretrained=pretrained)

        if not self.contraposition:
            # ViT-style backbones expose their final layer as ``head``, ResNets as ``fc``.
            # The 'vit' case only becomes reachable once a ViT backbone is re-enabled
            # in _get_model.
            head_attr = 'head' if 'vit' in model_name else 'fc'
            # nn.Linear stores weight as (out_features, in_features), so shape[1] is the
            # input width of the existing final layer (assumes it is an nn.Linear).
            dim_mlp = getattr(self.base_model, head_attr).weight.shape[1]
            setattr(self.base_model, head_attr,
                    self._get_projection_layer(dim_mlp, projection_dim))

        # Optionally freeze every layer except the FC layer:
        # for name, param in self.base_model.named_parameters():
        #     if 'fc' not in name:
        #         param.requires_grad = False

    def _get_projection_layer(self, dim_mlp, projection_dim):
        """Build the projection MLP: Linear(dim_mlp, dim_mlp) -> ReLU -> Linear(dim_mlp, projection_dim).

        Args:
            dim_mlp: Input (and hidden) width of the MLP.
            projection_dim: Output width of the MLP.

        Returns:
            An ``nn.Sequential`` with the three layers above.
        """
        return nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp),
            nn.ReLU(inplace=True),
            nn.Linear(dim_mlp, projection_dim),
        )

    def _get_model(self, model_name, pretrained=True):
        """Instantiate the backbone named ``model_name``.

        Args:
            model_name: ``'resnet18'`` or ``'resnet50'``.
            pretrained: Passed through as ``pretrained=`` to the project's
                ``resnet18`` / ``resnet50`` constructors.

        Returns:
            The backbone module.

        Raises:
            ValueError: If ``model_name`` is not supported. Previously this
                returned None, which later failed with an opaque AttributeError.
        """
        if model_name == 'resnet18':
            return resnet18(pretrained=pretrained)
        if model_name == 'resnet50':
            return resnet50(pretrained=pretrained)
        # if model_name == 'vit':
        #     return vit_base_patch32_224()
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'resnet18' or 'resnet50'"
        )

    def forward(self, x):
        """Run ``x`` through the backbone (including the projection head, if installed).

        Raises:
            RuntimeError: If ``base_model`` is None. An explicit raise is used
                rather than ``assert`` so the check survives ``python -O``.
        """
        if self.base_model is None:
            raise RuntimeError('base_model is none')
        return self.base_model(x)
|
|
|
|
if __name__ == "__main__":
    # This module only defines ContrastiveModel; running it directly does nothing.
    ...