rebuild
This commit is contained in:
48
model/compare.py
Normal file
48
model/compare.py
Normal file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from config import config as conf
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from model.resnet_pre import resnet18, resnet50
|
||||
# from model.vit import vit_base_patch16_224, vit_base_patch32_224
|
||||
|
||||
|
||||
class ContrastiveModel(nn.Module):
|
||||
def __init__(self, projection_dim, model_name, contraposition=False):
|
||||
super(ContrastiveModel, self).__init__()
|
||||
self.contraposition = contraposition
|
||||
self.base_model = self._get_model(model_name)
|
||||
if not self.contraposition:
|
||||
if 'vit' in model_name:
|
||||
dim_mlp = self.base_model.head.weight.shape[1]
|
||||
self.base_model.head = self._get_projection_layer(dim_mlp, projection_dim)
|
||||
else:
|
||||
dim_mlp = self.base_model.fc.weight.shape[1]
|
||||
self.base_model.fc = self._get_projection_layer(dim_mlp, projection_dim)
|
||||
# # 冻结除 FC 层之外的所有层
|
||||
# for name, param in self.base_model.named_parameters():
|
||||
# if 'fc' not in name:
|
||||
# param.requires_grad = False
|
||||
|
||||
def _get_projection_layer(self, dim_mlp, projection_dim):
|
||||
return nn.Sequential(
|
||||
nn.Linear(dim_mlp, dim_mlp),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(dim_mlp, projection_dim)
|
||||
)
|
||||
|
||||
def _get_model(self, model_name):
|
||||
base_model = None
|
||||
if model_name == 'resnet18':
|
||||
base_model = resnet18(pretrained=True)
|
||||
elif model_name == 'resnet50':
|
||||
base_model = resnet50(pretrained=True)
|
||||
# elif model_name == 'vit':
|
||||
# base_model = vit_base_patch32_224()
|
||||
return base_model
|
||||
def forward(self, x):
|
||||
assert self.base_model is not None, 'base_model is none'
|
||||
x = self.base_model(x)
|
||||
return x
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
Reference in New Issue
Block a user