"""Extract features for the images in a directory with a pretrained backbone."""
import os
import os.path as osp
import time

import numpy as np
import torch
from PIL import Image
from torch.nn.functional import softmax

from model import resnet18
from config import config as conf

embedding_size = conf.embedding_size
img_size = conf.img_size
device = conf.device


def load_contrast_model():
    """Build a resnet18, load the weights from ``conf.test_model`` and set eval mode.

    Returns:
        The model, placed on ``conf.device`` and switched to eval mode.
    """
    model = resnet18().to(conf.device)
    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
    model.eval()
    print('load model {} '.format(conf.testbackbone))
    return model


def group_image(imageDirs, batch) -> list:
    """Collect the files directly inside ``imageDirs`` and split them into batches.

    Args:
        imageDirs: directory to scan (not recursive).
        batch: maximum number of paths per group; must be >= 1.

    Returns:
        A list of lists of file paths. Every group holds ``batch`` paths except
        possibly the last one. Paths are sorted, so the grouping is reproducible.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    with os.scandir(imageDirs) as entries:
        # Skip sub-directories: they cannot be opened as images downstream.
        # os.scandir yields entries in arbitrary order, so sort them.
        images = sorted(entry.path for entry in entries if entry.is_file())
    print(f"{len(images)} images in {imageDirs}")
    return [images[i:i + batch] for i in range(0, len(images), batch)]


def test_preprocess(images: list, transform) -> torch.Tensor:
    """Load each image, apply ``transform`` and stack the results into one batch.

    Args:
        images: image file paths. At least one is required, because
            ``torch.stack`` rejects an empty list.
        transform: callable mapping a PIL image to a tensor. All outputs must
            share the same shape.

    Returns:
        A tensor whose first dimension is ``len(images)``, in input order.
    """
    tensors = []
    for path in images:
        # Release the file handle as soon as the transform has read the pixels.
        with Image.open(path) as im:
            tensors.append(transform(im))
    return torch.stack(tensors)


def featurize(images: list, transform, net, device) -> torch.Tensor:
    """Run ``net`` on a batch of images without tracking gradients.

    Args:
        images: image file paths.
        transform: test-time transform applied to each image.
        net: pretrained model.
        device: device to run on (e.g. cpu or cuda).

    Returns:
        The output of ``net`` for the stacked batch, one row per entry of
        ``images``, in the same order.
    """
    data = test_preprocess(images, transform).to(device)
    net = net.to(device)
    with torch.no_grad():
        features = net(data)
    return features


if __name__ == '__main__':
    # Network setup: only the resnet18 backbone is supported.
    if conf.testbackbone != 'resnet18':
        raise ValueError('Have not model {}'.format(conf.testbackbone))
    model = load_contrast_model()

    # Group the image paths by batch size (a batch size of 1 here).
    groups = group_image('img_test', 1)
    for group in groups:
        start = time.perf_counter()
        features = featurize(group, conf.test_transform, model, conf.device)
        # CUDA kernels run asynchronously; wait for them so the timing is real.
        if features.is_cuda:
            torch.cuda.synchronize()
        print('time: {}'.format(time.perf_counter() - start))