104 lines
3.1 KiB
Python
104 lines
3.1 KiB
Python
import os
|
|
import os.path as osp
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
from model import resnet18
|
|
from PIL import Image
|
|
|
|
from torch.nn.functional import softmax
|
|
from config import config as conf
|
|
import time
|
|
|
|
# Frequently used settings, read from the shared config object once at
# import time so the rest of the module can refer to them directly.
embedding_size, img_size, device = (
    conf.embedding_size,
    conf.img_size,
    conf.device,
)
|
|
|
|
def load_contrast_model():
    """Build the configured test backbone and load its trained weights.

    Reads ``conf.testbackbone``, ``conf.test_model`` and ``conf.device``.

    Returns:
        The network, moved to ``conf.device`` and switched to eval mode.

    Raises:
        ValueError: if ``conf.testbackbone`` is not ``'resnet18'``, the only
            backbone this module can construct.
    """
    # Previously any other backbone name was ignored and resnet18 was built
    # anyway, so the log line below could name a model that was never built.
    if conf.testbackbone != 'resnet18':
        raise ValueError('Have not model {}'.format(conf.testbackbone))

    model = resnet18().to(conf.device)
    # map_location remaps the checkpoint's tensors onto conf.device, so a
    # checkpoint saved on another device (e.g. GPU) can still be loaded.
    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
    model.eval()
    print('load model {} '.format(conf.testbackbone))
    return model
|
|
|
|
|
|
def group_image(imageDirs, batch) -> list:
    """Group the image files of a directory into batches of paths.

    Args:
        imageDirs: directory to scan (non-recursive).
        batch: maximum number of paths per group; must be >= 1.

    Returns:
        A list of lists of file paths, ordered by file name. Every group has
        ``batch`` paths except possibly the last, which holds the remainder.
        An empty directory yields ``[]``.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    if batch < 1:
        # range() would raise on 0 and silently yield nothing for negatives.
        raise ValueError(f"batch must be >= 1, got {batch}")

    with os.scandir(imageDirs) as entries:
        # Keep regular files only: every path is later opened as an image, so
        # a sub-directory would crash the pipeline. Sort because os.scandir
        # yields entries in arbitrary order and batches should be reproducible.
        images = sorted(entry.path for entry in entries if entry.is_file())
    print(f"{len(images)} images in {imageDirs}")

    return [images[start:start + batch] for start in range(0, len(images), batch)]
|
|
|
|
def test_preprocess(images: list, transform) -> torch.Tensor:
    """Open image files, apply ``transform`` to each, and stack the results.

    Args:
        images: image file paths.
        transform: callable applied to each opened PIL image; it must return
            tensors of one common shape so they can be stacked (presumably a
            torchvision transform pipeline — confirm against conf.test_transform).

    Returns:
        A tensor with a new leading dimension of size ``len(images)``, in the
        same order as ``images``. An empty ``images`` list makes
        ``torch.stack`` raise.
    """
    tensors = []
    for path in images:
        # Close each file deterministically: Image.open is lazy and Pillow only
        # auto-closes single-frame images after loading, so without the
        # context manager handles can stay open across many batches.
        with Image.open(path) as im:
            tensors.append(transform(im))
    return torch.stack(tensors)
|
|
|
|
def featurize(images: list, transform, net, device) -> torch.Tensor:
    """Run ``net`` on one batch of images and return its raw output.

    Args:
        images: image file paths, processed together as a single batch.
        transform: per-image transform handed to ``test_preprocess``.
        net: pretrained model; moved to ``device`` before inference.
        device: target device for both the data and the model (cpu or cuda).

    Returns:
        The tensor produced by ``net`` for the stacked batch. Rows are
        presumably in the same order as ``images`` (one output per image);
        confirm against the model's output layout.
    """
    batch = test_preprocess(images, transform).to(device)
    net = net.to(device)
    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        return net(batch)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Directory scanned for test images and number of images per forward pass.
    # The commented-out config-driven variant used conf.test_val and
    # conf.test_batch_size instead of these hard-coded values.
    image_dir = 'img_test'
    batch_size = 1

    # Network setup. Report the attribute that was actually checked
    # (the old message read conf.backbone, a different attribute).
    if conf.testbackbone != 'resnet18':
        raise ValueError('Have not model {}'.format(conf.testbackbone))
    model = load_contrast_model()

    groups = group_image(image_dir, batch_size)

    for group in groups:
        start = time.perf_counter()
        features = featurize(group, conf.test_transform, model, conf.device)
        # CUDA kernels run asynchronously; wait for them so the measured time
        # covers the forward pass itself, not just the kernel launches.
        if torch.cuda.is_available() and str(conf.device).startswith('cuda'):
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print('time: {}'.format(elapsed))