Files
ieemoo-ai-imageassessment/contrast/test_logic.py
2024-11-27 15:37:10 +08:00

318 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import pdb
import random
import json
import time
import torch
from PIL import Image
from contrast.model import resnet18, MobileNetV3_Large
# import pymilvus
# from pymilvus import (
# connections,
# utility,
# FieldSchema, CollectionSchema, DataType,
# Collection,
# Milvus
# )
# from config import config as conf
from contrast.search import ImgSearch
from contrast.img_data import queueImgs_add
import sys
from threading import Thread
sys.path.append('../tools')
from tools.config import cfg as conf
from tools.config import gvalue
def test_preprocess(images: list, actionModel) -> torch.Tensor:
    """Apply ``conf.test_transform`` to every image and stack the results.

    Args:
        images: Items to preprocess. When ``actionModel`` is truthy each item is
            passed to ``conf.test_transform`` unchanged (presumably an
            already-decoded image — TODO confirm); otherwise each item is opened
            with ``PIL.Image.open`` first (a path or file object).
        actionModel: Selects between the two input modes described above.

    Returns:
        ``torch.stack`` of the transformed images. Images that fail to open or
        transform are skipped (best effort) and reported on stdout.

    Raises:
        RuntimeError: raised by ``torch.stack`` when no image survived
            preprocessing (nothing to stack).
    """
    res = []
    for idx, img in enumerate(images):
        try:
            if actionModel:
                im = conf.test_transform(img)
            else:
                # PIL decodes lazily and keeps the file open; the context manager
                # releases the handle once the transform has consumed the pixels.
                with Image.open(img) as pil_img:
                    im = conf.test_transform(pil_img)
        except Exception as exc:
            # Best effort: one unreadable image must not abort the whole batch.
            # Unlike the former bare ``except``, this lets KeyboardInterrupt /
            # SystemExit propagate and reports what was dropped.
            print(f"test_preprocess: skipping image #{idx}: {exc}")
            continue
        res.append(im)
    return torch.stack(res)
def inference(images, model, actionModel):
    """Preprocess ``images`` and run them through ``model`` in one forward pass.

    Args:
        images: Forwarded to ``test_preprocess`` (decoded images or paths,
            depending on ``actionModel``).
        model: Callable applied to the stacked batch.
        actionModel: Forwarded to ``test_preprocess``.

    Returns:
        Whatever ``model(data)`` returns (presumably a (batch, dim) feature
        tensor — TODO confirm), computed with autograd disabled.
    """
    data = test_preprocess(images, actionModel)
    # The batch is moved to ``conf.device`` only when CUDA exists; otherwise it
    # stays on the CPU. NOTE(review): assumes the model lives on conf.device.
    if torch.cuda.is_available():
        data = data.to(conf.device)
    # Pure inference: skip building the autograd graph. Callers only detach the
    # result, so keeping gradients just wastes memory and time.
    with torch.no_grad():
        features = model(data)
    return features
def group_image(images, batch=64) -> list:
    """Split ``images`` into consecutive chunks of at most ``batch`` items.

    Every chunk except possibly the last has exactly ``batch`` items; an empty
    input yields an empty list. Chunks are slices of ``images``, so they have
    the same sequence type as the input.
    """
    # Slicing past the end is clamped by Python, so no explicit min() is needed.
    starts = range(0, len(images), batch)
    return [images[start:start + batch] for start in starts]
def barcode_state(barcodeIDList, json_path='contrast/main_barcodes.json'):
    """Partition barcode IDs by whether their barcode is in the main library.

    Args:
        barcodeIDList: IDs of the form ``<prefix>_<barcode>[_...]``; the second
            ``_``-separated field is taken as the barcode.
        json_path: JSON file whose first value (in file order) holds the known
            barcodes. Defaults to the previously hard-coded path, which is
            resolved relative to the current working directory.

    Returns:
        ``(in_library, not_in_library)``: two lists of the original IDs, each
        preserving input order.

    Raises:
        IndexError: if an ID contains no ``_`` or the JSON object is empty.
    """
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    main_barcode = list(data.values())[0]
    # Hash lookups instead of a linear scan per ID. Only a list is converted:
    # for a str, ``in`` means substring match and must be left as-is.
    lookup = main_barcode
    if isinstance(main_barcode, list):
        try:
            lookup = set(main_barcode)
        except TypeError:  # unhashable entries: fall back to the list scan
            lookup = main_barcode
    barIdList_true = []
    barIdList_false = []
    for barId in barcodeIDList:
        bar = barId.split('_')[1]
        if bar in lookup:
            barIdList_true.append(barId)
        else:
            barIdList_false.append(barId)
    return barIdList_true, barIdList_false
def getFeatureList(barList, imgList, model, actionModel):
    """Extract per-image feature vectors for each barcode's image list.

    Args:
        barList: Barcode IDs; only its length is used, to size the result.
        imgList: One image list per barcode, in the same order as ``barList``.
        model: Forwarded to ``inference``.
        actionModel: Forwarded to ``inference``.

    Returns:
        A list with ``len(barList)`` entries. Entry ``i`` is a list of NumPy
        arrays, one per image of ``imgList[i]`` that survived preprocessing,
        each produced by ``squeeze()`` on that image's model output row.
    """
    featList = [[] for _ in range(len(barList))]
    for index, images in enumerate(imgList):
        for group in group_image(images):
            feat_tensor = inference(group, model, actionModel)
            for fe in feat_tensor:
                # ``.cpu()`` is a no-op for CPU tensors, so one path covers every
                # device. The former ``fe.device == 'cpu'`` test compared a
                # torch.device with a str and never matched.
                featList[index].append(fe.squeeze().detach().cpu().numpy())
    return featList
def img2feature(imgs_dict, model, actionModel, barcode_flag):
    """Turn a ``{barcode_id: images}`` mapping into parallel ID / feature lists.

    Args:
        imgs_dict: Maps barcode IDs to image lists; must be non-empty. It is
            not modified.
        model: Forwarded to ``getFeatureList``.
        actionModel: Forwarded to ``getFeatureList``.
        barcode_flag: When truthy, IDs whose barcode is not in the main library
            (according to ``barcode_state``) are dropped before extraction.

    Returns:
        ``(bar_ids, features)`` where ``features[i]`` holds the feature arrays
        for ``bar_ids[i]``. Both are empty lists when ``barcode_flag`` is set
        and no ID is in the main library.

    Raises:
        ValueError: if ``imgs_dict`` is empty.
    """
    if not imgs_dict:
        raise ValueError("Tracking fail no images files provided")
    queBarIdList = list(imgs_dict.keys())

    if not barcode_flag:
        queImgsList = list(imgs_dict.values())
        queFeatList = getFeatureList(queBarIdList, queImgsList, model, actionModel)
        return queBarIdList, queFeatList

    # ======== Check whether each barcode is in the main feature library ========
    queBarIdList_t, barIdList_f = barcode_state(queBarIdList)
    if not queBarIdList_t:
        print(f"All barcodes are not in the main_library: {barIdList_f}")
        return queBarIdList_t, []
    if barIdList_f:
        print(f"These barcodes are not in the main_library: {barIdList_f}")
    # Keep only the images of in-library barcodes. Selecting them (instead of
    # deleting keys from imgs_dict) leaves the caller's dict untouched; since
    # barcode_state keeps input order, images stay aligned with the IDs.
    queImgList_t = [imgs_dict[bar_id] for bar_id in queBarIdList_t]
    queFeatList_t = getFeatureList(queBarIdList_t, queImgList_t, model, actionModel)
    return queBarIdList_t, queFeatList_t
# def create_milvus(collection_name, host, port, barcode_list, features):
# # 1. connect to Milvus
# fmt = "\n=== {:30} ===\n"
# connections.connect('default', host=host, port=port) # 连接到 Milvus 服务器
# has = utility.has_collection(collection_name) ##检查collection_name是否存在milvus中
# print(f"Does collection {collection_name} exist in Milvus: {has}")
# # if has: ## 删除collection_name的库
# # utility.drop_collection(collection_name)
#
# # 2. create colllection
# fields = [
# FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100), ###图片路径
# FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=256)
# ]
# schema = CollectionSchema(fields)
# print(fmt.format(f"Create collection {collection_name}"))
# hello_milvus = Collection(collection_name, schema, consistency_level="Strong")
# # 3. insert data
# for i in range(len(features)):
# entities = [
# # provide the pk field because `auto_id` is set to False
# [barcode_list[i]] * len(features[i]), ## 图片维度和向量维度需匹配 每个向量都生成一个barcode
# features[i],
# ]
# print(fmt.format("Start inserting entities"))
# insert_result = hello_milvus.insert(entities)
# hello_milvus.flush()
# print(f"Number of entities in {collection_name}: {hello_milvus.num_entities}") # check the num_entities
# return hello_milvus
# def load_collection(collection_name):
# collection = Collection(collection_name)
# # collection.release() ### 将collection从加载状态变成未加载
# # collection.drop_index() ### 删除索引
#
# index_params = {
# "index_type": "IVF_FLAT",
# # "index_type": "IVF_SQ8",
# # "index_type": "GPU_IVF_FLAT",
# "metric_type": "COSINE",
# "params": {
# "nlist": 10000
# }
# }
# #### 准确率低
# # index_params = {
# # "index_type": "IVF_PQ",
# # "metric_type": "COSINE",
# # "params": {
# # "nlist": 99,
# # "m": 2,
# # "nbits": 8
# # }
# # }
# collection.create_index(
# field_name="embeddings",
# index_params=index_params,
# index_name="SQ8"
# )
# collection.load()
# return collection
# def similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel):
# searchImg = ImgSearch() ## 相似度比较
# # 将输入图片加入临时库
# if add_flag:
# if actionModel:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-2]), model, actionModel, barcode_flag)
# else:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-2]), model, actionModel, barcode_flag)
#
# if barcode_flag: ### 加购 有barcode -> 输出top10和top1
# if len(queBarIdList) == 0:
# top10, top1 = {}, {}
# else:
# for bar in queBarIdList:
# # gvalue.tempLibList.append(bar) ## 临时特征库key值为macID_barcode_trackID
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
# gvalue.tempLibLists[gvalue.mac_id] += [bar] ## 临时特征库key值为macID_barcode_trackID
# else:
# gvalue.tempLibLists[gvalue.mac_id] = [bar]
# # 存入临时特征库
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList, queBarIdFeatures)
#
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
# 'host': conf.host,
# 'port': conf.port,
# 'barcode_list': queBarIdList,
# 'features': queBarIdFeatures})
# thread.start()
# start1 = time.time()
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# start2 = time.time()
# print('search top10 time>>>> {}'.format(start2-start1))
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
# start3 = time.time()
# print('search top1 time>>>>> {}'.format(start3-start2))
# return top10, top1, gvalue.tempLibLists
# else: # 加购 无barcode -> 输出top10
# # 无barcode时生成随机数作为字典key值
# queBarIdList_rand = []
# for i in range(len(queBarIdList)):
# random_number = ''.join(random.choices('0123456789', k=10))
# queBarIdList_rand.append(str(random_number))
# # gvalue.tempLibList.append(str(random_number))
# if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
# gvalue.tempLibLists[gvalue.mac_id] += [str(random_number)] ## 临时特征库key值为macID_barcode_trackID
# else:
# gvalue.tempLibLists[gvalue.mac_id] = [str(random_number)]
# # create_milvus('temp_features', conf.host, conf.port, queBarIdList_rand, queBarIdFeatures)
# thread = Thread(target=create_milvus, kwargs={'collection_name': 'temp_features',
# 'host': conf.host,
# 'port': conf.port,
# 'barcode_list': queBarIdList_rand,
# 'features': queBarIdFeatures})
# thread.start()
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# # print(f'top10: {top10}')
# return top10, gvalue.tempLibLists
# else: # 退购 -> 输出top10和topn
# if gvalue.tempLibLists.get(gvalue.mac_id) is None:
# gvalue.tempLibList = []
# else:
# gvalue.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
# ## 加载临时特征库
# tempMilvusName = "temp_features"
# has = utility.has_collection(tempMilvusName)
# print(f"Does collection {tempMilvusName} exist in Milvus: {has}")
# tempMilvus = load_collection(tempMilvusName)
# print(f"Number of entities in {tempMilvusName}: {tempMilvus.num_entities}")
# if actionModel:
# barcode_list = barcode_list
# else:
# barcode_list = queueImgs_add['barcode_list']
# if actionModel:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[2:-1]), model, actionModel, barcode_flag)
# else:
# queBarIdList, queBarIdFeatures = img2feature(dict(list(queImgsDict.items())[:-3]), model, actionModel, barcode_flag)
# if barcode_flag:
# if len(queBarIdList) == 0:
# top10, top1, top_n = {}, {}, {}
# else:
# start1 = time.time()
# top1 = searchImg.mainSearch1(main_milvus, queBarIdList, queBarIdFeatures)
# start2 = time.time()
# print('search top1 time>>>> {}'.format(start2 - start1))
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# start3 = time.time()
# print('search top10 time>>>> {}'.format(start3 - start2))
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
# # print(f'top10: {top10}, top1: {top1}, topn: {top_n}')
# return top10, top1, top_n
# else:
# top10 = searchImg.mainSearch10(main_milvus, queBarIdList, queBarIdFeatures)
# top_n = searchImg.tempSearch(tempMilvus, queBarIdList, queBarIdFeatures, barcode_list, gvalue.tempLibList)
# # print(f'top10: {top10}, topn: {top_n}')
# return top10, top_n
def similarity_interface(dataCollection):
    """Unpack the request fields from ``dataCollection``; currently a stub.

    The fields are read in a fixed order, so a missing attribute still raises
    ``AttributeError``. The similarity search itself is disabled, and the
    function always returns ``0``.
    """
    dc = dataCollection
    queImgsDict, add_flag, barcode_flag = dc.queImgsDict, dc.add_flag, dc.barcode_flag
    main_milvus = dc.mainMilvus
    model, actionModel, barcode_list = dc.model, dc.actionModel, dc.barcode_list
    # NOTE(review): to re-enable, match the commented-out `similarity` signature
    # (it takes no tempLibList argument):
    # return similarity(queImgsDict, add_flag, barcode_flag, main_milvus, model, barcode_list, actionModel)
    return 0
if __name__ == '__main__':
    # Running this module directly does nothing: the standalone driver below is
    # commented out.
    # NOTE(review): the commented-out driver is stale and would not run as-is.
    # `initModel` and `queueImgs_back` do not appear to be imported here. The
    # `similarity(...)` calls also pass `tempLibList` in the `model` position
    # and omit `barcode_list`/`actionModel` relative to the commented-out
    # `similarity` signature.
    pass
    # connections.connect('default', host=conf.host, port=conf.port)
    # # Load the main feature library
    # mainMilvusName = "main_features"
    # has = utility.has_collection(mainMilvusName)
    # print(f"Does collection {mainMilvusName} exist in Milvus: {has}")
    # mainMilvus = Collection(mainMilvusName)
    # mainMilvus.load()
    # model = initModel()
    # # queueImgs_add / queueImgs_back are the inputs for add-to-cart and return (remove-from-cart) respectively
    # add_flag = queueImgs_add['add_flag']
    # barcode_flag = queueImgs_add['barcode_flag']
    # tempLibList = [] # barcodeId list of the temporary feature library
    # # tempLibList = ['3500610085338_01', '4260290263776_01'] ##test
    # if add_flag:
    # if barcode_flag: # add with barcode -> output top10 and top1
    # top10, top1, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
    # print(f"top10: {top10}\ntop1: {top1}")
    # else: # add without barcode -> output top10
    # top10, tempLibList = similarity(queueImgs_add, add_flag, barcode_flag, mainMilvus, tempLibList, model)
    # else: # return (remove from cart) -> output top10 and topn
    # top10, topn = similarity(queueImgs_back, add_flag, barcode_flag, mainMilvus, tempLibList, model)