Files
ieemoo-ai-contrast/model/compare.py
2025-06-11 15:23:50 +08:00

48 lines
1.8 KiB
Python

import torch
from config import config as conf
import torch.nn as nn
import torchvision.models as models
from model.resnet_pre import resnet18, resnet50
# from model.vit import vit_base_patch16_224, vit_base_patch32_224
class ContrastiveModel(nn.Module):
    """Backbone network with an optional two-layer MLP projection head.

    The backbone is built by ``_get_model`` from ``model_name``. Unless
    ``contraposition`` is true, the backbone's final layer is replaced by an
    MLP that maps to ``projection_dim`` features:
    - ``fc`` is replaced for the ResNet backbones.
    - ``head`` is replaced for names containing ``'vit'``.

    Args:
        projection_dim: Output width of the projection head.
        model_name: Backbone identifier; ``'resnet18'`` or ``'resnet50'``.
        contraposition: If true, keep the backbone's original final layer and
            do not attach a projection head. NOTE(review): the intended
            semantics of this flag are not visible here — confirm with callers.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone.
    """

    def __init__(self, projection_dim, model_name, contraposition=False):
        super().__init__()
        self.contraposition = contraposition
        # Raises ValueError for unsupported names, so base_model is never None.
        self.base_model = self._get_model(model_name)
        if not self.contraposition:
            # Swap the backbone's last layer for the projection head. The old
            # layer's input width (weight.shape[1] == in_features for
            # nn.Linear) becomes the MLP's hidden width.
            if 'vit' in model_name:
                # NOTE(review): unreachable until a ViT backbone is added back
                # to _get_model; kept for when the ViT import is restored.
                dim_mlp = self.base_model.head.weight.shape[1]
                self.base_model.head = self._get_projection_layer(dim_mlp, projection_dim)
            else:
                dim_mlp = self.base_model.fc.weight.shape[1]
                self.base_model.fc = self._get_projection_layer(dim_mlp, projection_dim)
        # Optional: freeze all layers except the FC layer.
        # for name, param in self.base_model.named_parameters():
        #     if 'fc' not in name:
        #         param.requires_grad = False

    def _get_projection_layer(self, dim_mlp, projection_dim):
        """Build the head: Linear(dim_mlp, dim_mlp) -> ReLU -> Linear(dim_mlp, projection_dim).

        Args:
            dim_mlp: Input (and hidden) width, taken from the replaced layer.
            projection_dim: Output width of the head.

        Returns:
            An ``nn.Sequential`` implementing the head.
        """
        return nn.Sequential(
            nn.Linear(dim_mlp, dim_mlp),
            nn.ReLU(inplace=True),
            nn.Linear(dim_mlp, projection_dim),
        )

    def _get_model(self, model_name):
        """Build the backbone named ``model_name``, calling its builder with ``pretrained=True``.

        Args:
            model_name: ``'resnet18'`` or ``'resnet50'``.

        Returns:
            The backbone module produced by the matching builder.

        Raises:
            ValueError: If ``model_name`` is not one of the supported names.
        """
        if model_name == 'resnet18':
            return resnet18(pretrained=True)
        if model_name == 'resnet50':
            return resnet50(pretrained=True)
        # ViT support is currently disabled; add a branch here (e.g. calling
        # vit_base_patch32_224()) once the model.vit import is restored.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'resnet18' or 'resnet50'"
        )

    def forward(self, x):
        """Run ``x`` through the backbone, including the projection head if attached.

        NOTE(review): presumably ``x`` is an image batch shaped (N, C, H, W)
        as expected by the ResNet backbones — confirm against callers.
        """
        assert self.base_model is not None, 'base_model is none'
        return self.base_model(x)
if __name__ == '__main__':
    # This module only defines the model; running it directly does nothing.
    pass