Files
detecttracking/contrast/feat_extract/resnet_vit/inference.py
王庆刚 8bbee310ba bakeup
2024-11-25 18:05:08 +08:00

104 lines
3.1 KiB
Python

import os
import os.path as osp
import torch
import numpy as np
from model import resnet18
from PIL import Image
from torch.nn.functional import softmax
from config import config as conf
import time
# Module-level shortcuts for configuration values used in this script.
embedding_size, img_size, device = (
    conf.embedding_size,
    conf.img_size,
    conf.device,
)
def load_contrast_model():
    """Build the ResNet-18 model, load weights from ``conf.test_model`` and set eval mode.

    The weights are mapped onto ``conf.device`` while loading.

    Returns:
        The model, placed on ``conf.device`` and switched to eval mode.
    """
    net = resnet18().to(conf.device)
    weights = torch.load(conf.test_model, map_location=conf.device)
    net.load_state_dict(weights)
    net.eval()
    # NOTE(review): the model is always resnet18; conf.testbackbone is only
    # printed here, never validated.
    print('load model {} '.format(conf.testbackbone))
    return net
def group_image(imageDirs, batch) -> list:
    """Group image paths by batch size.

    Lists the regular files directly inside ``imageDirs`` (sub-directories
    are skipped because they cannot be opened as images). The paths are
    sorted so the grouping is deterministic, then split into consecutive
    chunks of at most ``batch`` paths.

    Args:
        imageDirs: directory containing the images (str or path-like).
        batch: maximum number of paths per group; must be >= 1.

    Returns:
        List of lists of file paths. Only the last group may be shorter than
        ``batch``. An empty directory yields ``[]``.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    with os.scandir(imageDirs) as entries:
        # entry.path is os.path.join(imageDirs, entry.name), which copes with a
        # trailing separator on imageDirs and with path-like inputs.
        images = sorted(entry.path for entry in entries if entry.is_file())
    print(f"{len(images)} images in {imageDirs}")
    # Slicing clamps at the end of the list, so the last group is simply shorter.
    return [images[i:i + batch] for i in range(0, len(images), batch)]
def test_preprocess(images: list, transform) -> torch.Tensor:
    """Load images from disk, apply ``transform`` to each, and stack them.

    Args:
        images: image file paths.
        transform: callable applied to each opened PIL image. It must return
            tensors of identical shape so that they can be stacked.

    Returns:
        A tensor produced by ``torch.stack`` whose first dimension has length
        ``len(images)``. ``torch.stack`` raises if ``images`` is empty.
    """
    res = []
    for img in images:
        # The context manager closes the underlying file once the transform
        # has produced its output. A bare Image.open() keeps the handle open
        # until garbage collection.
        with Image.open(img) as im:
            res.append(transform(im))
    return torch.stack(res)
def featurize(images: list, transform, net, device) -> torch.Tensor:
    """Run ``net`` on a batch of images and return its raw output.

    Args:
        images: image file paths, preprocessed via ``test_preprocess``.
        transform: per-image transform forwarded to ``test_preprocess``.
        net: model to evaluate. It is moved to ``device`` and run under
            ``torch.no_grad()``; putting it in eval mode is the caller's job.
        device: target device for both the batch and the model
            (e.g. ``'cpu'`` or ``'cuda'``).

    Returns:
        The output of ``net`` for the stacked batch. Row ``i`` presumably
        corresponds to ``images[i]``; confirm against the model definition.
    """
    data = test_preprocess(images, transform).to(device)
    net = net.to(device)
    with torch.no_grad():
        features = net(data)
    return features
if __name__ == '__main__':
    # Network setup: only the ResNet-18 backbone is supported.
    if conf.testbackbone == 'resnet18':
        model = resnet18().to(device)
    else:
        # Report the attribute that was actually checked.
        raise ValueError('Have not model {}'.format(conf.testbackbone))
    print('load model {} '.format(conf.testbackbone))
    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
    model.eval()

    # Group image paths by batch size. To evaluate the configured validation
    # set instead, use:
    #   groups = group_image(conf.test_val, conf.test_batch_size)
    groups = group_image('img_test', 1)

    # CUDA kernels run asynchronously, so synchronize before reading the clock;
    # otherwise the measured time only covers kernel launch.
    use_cuda = torch.device(conf.device).type == 'cuda'
    for group in groups:
        s = time.perf_counter()
        features = featurize(group, conf.test_transform, model, conf.device)
        if use_cuda:
            torch.cuda.synchronize()
        e = time.perf_counter()
        print('time: {}'.format(e - s))