contrast performance evaluatation have done!

This commit is contained in:
王庆刚
2024-09-05 19:01:49 +08:00
parent f978d4174f
commit 7309dec166
85 changed files with 3941 additions and 248 deletions

View File

@ -0,0 +1,103 @@
import os
import os.path as osp
import torch
import numpy as np
from model import resnet18
from PIL import Image
from torch.nn.functional import softmax
from config import config as conf
import time
embedding_size = conf.embedding_size
img_size = conf.img_size
device = conf.device
def load_contrast_model():
model = resnet18().to(conf.device)
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
model.eval()
print('load model {} '.format(conf.testbackbone))
return model
def group_image(imageDirs, batch) -> list:
images = []
"""Group image paths by batch size"""
with os.scandir(imageDirs) as entries:
for imgpth in entries:
print(imgpth)
images.append(os.sep.join([imageDirs, imgpth.name]))
print(f"{len(images)} images in {imageDirs}")
size = len(images)
res = []
for i in range(0, size, batch):
end = min(batch + i, size)
res.append(images[i: end])
return res
def test_preprocess(images: list, transform) -> torch.Tensor:
res = []
for img in images:
# print(img)
im = Image.open(img)
im = transform(im)
res.append(im)
# data = torch.cat(res, dim=0) # shape: (batch, 128, 128)
# data = data[:, None, :, :] # shape: (batch, 1, 128, 128)
data = torch.stack(res)
return data
def featurize(images: list, transform, net, device) -> dict:
"""featurize each image and save into a dictionary
Args:
images: image paths
transform: test transform
net: pretrained model
device: cpu or cuda
Returns:
Dict (key: imagePath, value: feature)
"""
data = test_preprocess(images, transform)
data = data.to(device)
net = net.to(device)
with torch.no_grad():
features = net(data)
# res = {img: feature for (img, feature) in zip(images, features)}
return features
if __name__ == '__main__':
# Network Setup
if conf.testbackbone == 'resnet18':
model = resnet18().to(device)
else:
raise ValueError('Have not model {}'.format(conf.backbone))
print('load model {} '.format(conf.testbackbone))
# model = nn.DataParallel(model).to(conf.device)
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
model.eval()
# images = unique_image(conf.test_list)
# images = [osp.join(conf.test_val, img) for img in images]
# print('images', images)
# images = ['./data/2250_train/val/6920616313186/6920616313186_6920616313186_20240220-124502_53d2e103-ae3a-4689-b745-9d8723b770fe_front_returnGood_70f75407b7ae_31_01.jpg']
# groups = group_image(conf.test_val, conf.test_batch_size) ##根据batch_size取图片
groups = group_image('img_test', 1) ##根据batch_size取图片, 默认batch_size = 8
feature_dict = dict()
for group in groups:
s = time.time()
features = featurize(group, conf.test_transform, model, conf.device)
e = time.time()
print('time: {}'.format(e - s))
# out = softmax(features, dim=1).argmax(dim=1)
# print('d >>> {}'. format(out))
# feature_dict.update(d)