update
This commit is contained in:
BIN
contrast/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
209
contrast/genfeats.py
Normal file
209
contrast/genfeats.py
Normal file
@ -0,0 +1,209 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Sun Nov 3 12:05:19 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
from config import config as conf
|
||||
from model import resnet18 as resnet18
|
||||
from feat_inference import inference_image
|
||||
|
||||
|
||||
IMG_FORMAT = ['.bmp', '.jpg', '.jpeg', '.png']
|
||||
|
||||
'''======= 0. 配置特征提取模型地址 ======='''
|
||||
model_path = conf.test_model
|
||||
model_path = r"D:\exhibition\ckpt\zhanting.pth"
|
||||
|
||||
##============ load resnet mdoel
|
||||
model = resnet18().to(conf.device)
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
model.load_state_dict(torch.load(model_path, map_location=conf.device))
|
||||
model.eval()
|
||||
print('load model {} '.format(conf.testbackbone))
|
||||
|
||||
def get_std_barcodeDict(bcdpath, savepath):
|
||||
'''
|
||||
inputs:
|
||||
bcdpath: 已清洗的barcode样本图像,如果barcode下有'base'文件夹,只选用该文件夹下图像
|
||||
(default = r'\\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_1771')
|
||||
功能:
|
||||
生成并保存只有一个key值的字典 {barcode: [imgpath1, imgpath1, ...]},
|
||||
savepath: 字典存储地址,文件名格式:barcode.pickle
|
||||
'''
|
||||
|
||||
# savepath = r'\\192.168.1.28\share\测试_202406\contrast\std_barcodes'
|
||||
|
||||
'''读取数据集中 barcode 列表'''
|
||||
stdBarcodeList = []
|
||||
for filename in os.listdir(bcdpath):
|
||||
filepath = os.path.join(bcdpath, filename)
|
||||
# if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
|
||||
# continue
|
||||
stdBarcodeList.append(filename)
|
||||
|
||||
bcdPaths = [(barcode, os.path.join(bcdpath, barcode)) for barcode in stdBarcodeList]
|
||||
|
||||
'''遍历数据集,针对每一个barcode,生成并保存字典{barcode: [imgpath1, imgpath1, ...]}'''
|
||||
k = 0
|
||||
errbarcodes = []
|
||||
for barcode, bpath in bcdPaths:
|
||||
pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
if os.path.isfile(pickpath):
|
||||
continue
|
||||
|
||||
stdBarcodeDict = {}
|
||||
stdBarcodeDict[barcode] = []
|
||||
for root, dirs, files in os.walk(bpath):
|
||||
imgpaths = []
|
||||
if "base" in dirs:
|
||||
broot = os.path.join(root, "base")
|
||||
for imgname in os.listdir(broot):
|
||||
imgpath = os.path.join(broot, imgname)
|
||||
file, ext = os.path.splitext(imgpath)
|
||||
|
||||
if ext not in IMG_FORMAT:
|
||||
continue
|
||||
imgpaths.append(imgpath)
|
||||
|
||||
stdBarcodeDict[barcode].extend(imgpaths)
|
||||
break
|
||||
|
||||
else:
|
||||
for imgname in files:
|
||||
imgpath = os.path.join(root, imgname)
|
||||
_, ext = os.path.splitext(imgpath)
|
||||
if ext not in IMG_FORMAT: continue
|
||||
imgpaths.append(imgpath)
|
||||
stdBarcodeDict[barcode].extend(imgpaths)
|
||||
|
||||
pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
with open(pickpath, 'wb') as f:
|
||||
pickle.dump(stdBarcodeDict, f)
|
||||
print(f"Barcode: {barcode}")
|
||||
|
||||
# k += 1
|
||||
# if k == 10:
|
||||
# break
|
||||
print(f"Len of errbarcodes: {len(errbarcodes)}")
|
||||
return
|
||||
|
||||
|
||||
|
||||
def stdfeat_infer(imgPath, featPath, bcdSet=None):
|
||||
'''
|
||||
inputs:
|
||||
imgPath: 该文件夹下的 pickle 文件格式 {barcode: [imgpath1, imgpath1, ...]}
|
||||
featPath: imgPath图像对应特征的存储地址
|
||||
功能:
|
||||
对 imgPath中图像进行特征提取,生成只有一个key值的字典,
|
||||
{barcode: features},features.shape=(nsample, 256),并保存至 featPath 中
|
||||
|
||||
'''
|
||||
|
||||
# imgPath = r"\\192.168.1.28\share\测试_202406\contrast\std_barcodes"
|
||||
# featPath = r"\\192.168.1.28\share\测试_202406\contrast\std_features"
|
||||
stdBarcodeDict = {}
|
||||
stdBarcodeDict_ft16 = {}
|
||||
|
||||
|
||||
'''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名、该pickle文件中字典的key值'''
|
||||
|
||||
k = 0
|
||||
for filename in os.listdir(imgPath):
|
||||
bcd, ext = os.path.splitext(filename)
|
||||
pkpath = os.path.join(featPath, f"{bcd}.pickle")
|
||||
|
||||
if os.path.isfile(pkpath): continue
|
||||
if bcdSet is not None and bcd not in bcdSet:
|
||||
continue
|
||||
|
||||
filepath = os.path.join(imgPath, filename)
|
||||
|
||||
stdbDict = {}
|
||||
stdbDict_ft16 = {}
|
||||
stdbDict_uint8 = {}
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
try:
|
||||
with open(filepath, 'rb') as f:
|
||||
bpDict = pickle.load(f)
|
||||
for barcode, imgpaths in bpDict.items():
|
||||
# feature = batch_inference(imgpaths, 8) #from vit distilled model of LiChen
|
||||
feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
|
||||
feature /= np.linalg.norm(feature, axis=1)[:, None]
|
||||
|
||||
# float16
|
||||
feature_ft16 = feature.astype(np.float16)
|
||||
feature_ft16 /= np.linalg.norm(feature_ft16, axis=1)[:, None]
|
||||
|
||||
# uint8, 两种策略,1) 精度损失小, 2) 计算复杂度小
|
||||
# feature_uint8, _ = ft16_to_uint8(feature_ft16)
|
||||
feature_uint8 = (feature_ft16*128).astype(np.int8)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error accured at: {filename}, with Exception is: {e}")
|
||||
|
||||
'''================ 保存单个barcode特征 ================'''
|
||||
##================== float32
|
||||
stdbDict["barcode"] = barcode
|
||||
stdbDict["imgpaths"] = imgpaths
|
||||
stdbDict["feats_ft32"] = feature
|
||||
stdbDict["feats_ft16"] = feature_ft16
|
||||
stdbDict["feats_uint8"] = feature_uint8
|
||||
|
||||
with open(pkpath, 'wb') as f:
|
||||
pickle.dump(stdbDict, f)
|
||||
|
||||
stdBarcodeDict[barcode] = feature
|
||||
stdBarcodeDict_ft16[barcode] = feature_ft16
|
||||
|
||||
t2 = time.time()
|
||||
print(f"Barcode: {barcode}, need time: {t2-t1:.1f} secs")
|
||||
# k += 1
|
||||
# if k == 10:
|
||||
# break
|
||||
|
||||
##================== float32
|
||||
# pickpath = os.path.join(featPath, f"barcode_features_{k}.pickle")
|
||||
# with open(pickpath, 'wb') as f:
|
||||
# pickle.dump(stdBarcodeDict, f)
|
||||
|
||||
##================== float16
|
||||
# pickpath_ft16 = os.path.join(featPath, f"barcode_features_ft16_{k}.pickle")
|
||||
# with open(pickpath_ft16, 'wb') as f:
|
||||
# pickle.dump(stdBarcodeDict_ft16, f)
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
def genfeatures(imgpath, bcdpath, featpath):
|
||||
|
||||
get_std_barcodeDict(imgpath, bcdpath)
|
||||
stdfeat_infer(bcdpath, featpath, bcdSet=None)
|
||||
|
||||
print(f"Features have generated, saved in: {featpath}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
imgpath = r"\\192.168.1.28\share\展厅barcode数据\整理\zhantingBase"
|
||||
bcdpath = r"D:\exhibition\dataset\bcdpath"
|
||||
featpath = r"D:\exhibition\dataset\feats"
|
||||
|
||||
genfeatures(imgpath, bcdpath, featpath)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -16,60 +16,13 @@ import shutil
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
sys.path.append(r"D:\DetectTracking")
|
||||
from tracking.utils.plotting import Annotator, colors
|
||||
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output
|
||||
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output, read_returnGoods_file
|
||||
from tracking.utils.plotting import draw_tracking_boxes
|
||||
|
||||
|
||||
|
||||
def showHist(err, correct):
|
||||
err = np.array(err)
|
||||
correct = np.array(correct)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(err, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([0, 1])
|
||||
axs[0].set_title('err')
|
||||
|
||||
axs[1].hist(correct, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([0, 1])
|
||||
axs[1].set_title('correct')
|
||||
# plt.show()
|
||||
|
||||
return plt
|
||||
|
||||
def show_recall_prec(recall, prec, ths):
|
||||
# x = np.linspace(start=-0, stop=1, num=11, endpoint=True).tolist()
|
||||
fig = plt.figure(figsize=(10, 6))
|
||||
plt.plot(ths, recall, color='red', label='recall')
|
||||
plt.plot(ths, prec, color='blue', label='PrecisePos')
|
||||
plt.legend()
|
||||
plt.xlabel(f'threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
# plt.savefig('accuracy_recall_grid.png')
|
||||
# plt.show()
|
||||
# plt.close()
|
||||
|
||||
return plt
|
||||
|
||||
|
||||
def compute_recall_precision(err_similarity, correct_similarity):
|
||||
ths = np.linspace(0, 1, 51)
|
||||
recall, prec = [], []
|
||||
for th in ths:
|
||||
TP = len([num for num in correct_similarity if num >= th])
|
||||
FP = len([num for num in err_similarity if num >= th])
|
||||
if (TP+FP) == 0:
|
||||
prec.append(1)
|
||||
recall.append(0)
|
||||
else:
|
||||
prec.append(TP / (TP + FP))
|
||||
recall.append(TP / (len(err_similarity) + len(correct_similarity)))
|
||||
return recall, prec, ths
|
||||
from contrast.utils.tools import showHist, show_recall_prec, compute_recall_precision
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@ -129,7 +82,7 @@ def read_tracking_imgs(imgspath):
|
||||
|
||||
return imgs_0, imgs_1
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# def draw_tracking_boxes(imgs, tracks):
|
||||
# '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index]
|
||||
@ -287,27 +240,9 @@ def save_tracking_imgpairs(pair, basepath, savepath):
|
||||
cv2.imwrite(imgpath, img)
|
||||
|
||||
|
||||
# def performance_evaluate(all_list, isshow=False):
|
||||
|
||||
# corrpairs, correct_barcode_list, correct_similarity, errpairs, err_barcode_list, err_similarity = [], [], [], [], [], []
|
||||
# for s_list in all_list:
|
||||
# seqdir = s_list['SeqDir'].strip()
|
||||
# delete = s_list['Deleted'].strip()
|
||||
# barcodes = [s.strip() for s in s_list['barcode']]
|
||||
# similarity = [float(s.strip()) for s in s_list['similarity']]
|
||||
|
||||
# if delete in barcodes[:1]:
|
||||
# corrpairs.append((seqdir, delete))
|
||||
# correct_barcode_list.append(delete)
|
||||
# correct_similarity.append(similarity[0])
|
||||
# else:
|
||||
# errpairs.append((seqdir, delete, barcodes[0]))
|
||||
# err_barcode_list.append(delete)
|
||||
# err_similarity.append(similarity[0])
|
||||
|
||||
def performance_evaluate(all_list, isshow=False):
|
||||
|
||||
corrpairs, correct_barcode_list, correct_similarity, errpairs, err_barcode_list, err_similarity = [], [], [], [], [], []
|
||||
def one2n_old(all_list):
|
||||
corrpairs, errpairs, correct_similarity, err_similarity = [], [], [], []
|
||||
for s_list in all_list:
|
||||
seqdir = s_list['SeqDir'].strip()
|
||||
delete = s_list['Deleted'].strip()
|
||||
@ -332,70 +267,136 @@ def performance_evaluate(all_list, isshow=False):
|
||||
matched_barcode = barcodes[index]
|
||||
if matched_barcode == delete:
|
||||
corrpairs.append((seqdir, delete))
|
||||
correct_barcode_list.append(delete)
|
||||
correct_similarity.append(max(similarity))
|
||||
else:
|
||||
errpairs.append((seqdir, delete, matched_barcode))
|
||||
err_barcode_list.append(delete)
|
||||
err_similarity.append(max(similarity))
|
||||
|
||||
'''3. 计算比对性能 '''
|
||||
if isshow:
|
||||
recall, prec, ths = compute_recall_precision(err_similarity, correct_similarity)
|
||||
show_recall_prec(recall, prec, ths)
|
||||
showHist(err_similarity, correct_similarity)
|
||||
|
||||
return errpairs, corrpairs, err_similarity, correct_similarity
|
||||
|
||||
return corrpairs, errpairs, correct_similarity, err_similarity
|
||||
|
||||
|
||||
def contrast_analysis(del_barcode_file, basepath, savepath, saveimgs=False):
|
||||
'''
|
||||
del_barcode_file: 测试数据文件,利用该文件进行算法性能分析
|
||||
|
||||
def one2n_new(all_list):
|
||||
corrpairs, correct_similarity, errpairs, err_similarity = [], [], [], []
|
||||
for s_list in all_list:
|
||||
seqdir = s_list['SeqDir'].strip()
|
||||
delete = s_list['Deleted'].strip()
|
||||
barcodes = [s.strip() for s in s_list['barcode']]
|
||||
events = [s.strip() for s in s_list['event']]
|
||||
types = [s.strip() for s in s_list['type']]
|
||||
|
||||
## =================== 读入相似度值
|
||||
similarity_comp, similarity_front = [], []
|
||||
for simil in s_list['similarity']:
|
||||
ss = [float(s.strip()) for s in simil.split(',')]
|
||||
|
||||
similarity_comp.append(ss[0])
|
||||
if len(ss)==3:
|
||||
similarity_front.append(ss[2])
|
||||
|
||||
if len(similarity_front):
|
||||
similarity = [s for s in similarity_front]
|
||||
else:
|
||||
similarity = [s for s in similarity_comp]
|
||||
|
||||
|
||||
index = similarity.index(max(similarity))
|
||||
matched_barcode = barcodes[index]
|
||||
if matched_barcode == delete:
|
||||
corrpairs.append((seqdir, events[index]))
|
||||
correct_similarity.append(max(similarity))
|
||||
else:
|
||||
idx = [i for i, name in enumerate(events) if name.split('_')[-1] == delete]
|
||||
idxmax, simimax = -1, -1
|
||||
# idxmax, simimax = k, similarity[k] for k in idx if similarity[k] > simimax
|
||||
for k in idx:
|
||||
if similarity[k] > simimax:
|
||||
idxmax = k
|
||||
simimax = similarity[k]
|
||||
|
||||
errpairs.append((seqdir, events[idxmax], events[index]))
|
||||
err_similarity.append(max(similarity))
|
||||
|
||||
|
||||
return errpairs, corrpairs, err_similarity, correct_similarity
|
||||
|
||||
|
||||
# def contrast_analysis(del_barcode_file, basepath, savepath, saveimgs=False):
|
||||
def get_relative_paths(del_barcode_file, basepath, savepath, saveimgs=False):
|
||||
'''
|
||||
del_barcode_file:
|
||||
deletedBarcode.txt 格式的 1:n 数据结果文件
|
||||
returnGoods.txt格式数据文件不需要调用该函数,one2n_old() 函数返回的 errpairs
|
||||
中元素为三元元组(取出,放入, 错误匹配)
|
||||
'''
|
||||
relative_paths = []
|
||||
|
||||
'''1. 读取 deletedBarcode 文件 '''
|
||||
all_list = read_deletedBarcode_file(del_barcode_file)
|
||||
|
||||
|
||||
'''2. 算法性能评估,并输出 (取出,删除, 错误匹配) 对 '''
|
||||
errpairs, corrpairs, _, _ = performance_evaluate(all_list)
|
||||
|
||||
'''3. 获取 (取出,删除, 错误匹配) 对应路径,保存相应轨迹图像'''
|
||||
relative_paths = []
|
||||
errpairs, corrpairs, _, _ = one2n_old(all_list)
|
||||
|
||||
'''3. 构造事件组合(取出,放入并删除, 错误匹配) 对应路径 '''
|
||||
for errpair in errpairs:
|
||||
GetoutPath, InputPath, ErrorPath = get_contrast_paths(errpair, basepath)
|
||||
relative_paths.append((GetoutPath, InputPath, ErrorPath))
|
||||
|
||||
|
||||
'''3. 获取 (取出,放入并删除, 错误匹配) 对应路径,保存相应轨迹图像'''
|
||||
if saveimgs:
|
||||
save_tracking_imgpairs(errpair, basepath, savepath)
|
||||
|
||||
return relative_paths
|
||||
|
||||
|
||||
def contrast_loop(fpath):
|
||||
def one2n_test():
|
||||
fpath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\other'
|
||||
fpath = r'\\192.168.1.28\share\测试_202406\1030\images'
|
||||
|
||||
savepath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\illustration'
|
||||
# savepath = r'D:\contrast\dataset\1_to_n\illustration'
|
||||
if not os.path.exists(savepath):
|
||||
os.mkdir(savepath)
|
||||
|
||||
if os.path.isfile(fpath):
|
||||
fpath, filename = os.path.split(fpath)
|
||||
|
||||
if os.path.isdir(fpath):
|
||||
filepaths = [os.path.join(fpath, f) for f in os.listdir(fpath)
|
||||
if f.find('.txt')>0
|
||||
and (f.find('deletedBarcode')>=0 or f.find('returnGoods')>=0)]
|
||||
elif os.path.isfile(fpath):
|
||||
filepaths = [fpath]
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
FileFormat = {}
|
||||
|
||||
BarLists, blists = {}, []
|
||||
for filename in os.listdir(fpath):
|
||||
file = os.path.splitext(filename)[0][15:]
|
||||
|
||||
filepath = os.path.join(fpath, filename)
|
||||
blist = read_deletedBarcode_file(filepath)
|
||||
for pth in filepaths:
|
||||
file = str(Path(pth).stem)
|
||||
if file.find('deletedBarcode')>=0:
|
||||
FileFormat[file] = 'deletedBarcode'
|
||||
blist = read_deletedBarcode_file(pth)
|
||||
elif file.find('returnGoods')>=0:
|
||||
FileFormat[file] = 'returnGoods'
|
||||
blist = read_returnGoods_file(pth)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
BarLists.update({file: blist})
|
||||
blists.extend(blist)
|
||||
|
||||
BarLists.update({file: blist})
|
||||
BarLists.update({"Total": blists})
|
||||
for file, blist in BarLists.items():
|
||||
errpairs, corrpairs, err_similarity, correct_similarity = performance_evaluate(blist)
|
||||
|
||||
for file, blist in BarLists.items():
|
||||
if FileFormat[file] == 'deletedBarcode':
|
||||
_, _, err_similarity, correct_similarity = one2n_old(blist)
|
||||
elif FileFormat[file] == 'returnGoods':
|
||||
_, _, err_similarity, correct_similarity = one2n_new(blist)
|
||||
else:
|
||||
_, _, err_similarity, correct_similarity = one2n_old(blist)
|
||||
|
||||
|
||||
recall, prec, ths = compute_recall_precision(err_similarity, correct_similarity)
|
||||
|
||||
@ -411,25 +412,33 @@ def contrast_loop(fpath):
|
||||
# plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
fpath = r'\\192.168.1.28\share\测试_202406\deletedBarcode\other'
|
||||
contrast_loop(fpath)
|
||||
|
||||
def main1():
|
||||
|
||||
def test_getreltpath():
|
||||
'''
|
||||
适用于:deletedBarcode.txt,不适用于:returnGoods.txt
|
||||
'''
|
||||
|
||||
del_barcode_file = r'\\192.168.1.28\share\测试_202406\709\deletedBarcode.txt'
|
||||
basepath = r'\\192.168.1.28\share\测试_202406\709'
|
||||
savepath = r'D:\contrast\dataset\result'
|
||||
|
||||
# del_barcode_file = r'\\192.168.1.28\share\测试_202406\1030\images\returnGoods.txt'
|
||||
# basepath = r'\\192.168.1.28\share\测试_202406\1030\images'
|
||||
|
||||
savepath = r'D:\contrast\dataset\result'
|
||||
saveimgs = True
|
||||
try:
|
||||
relative_path = contrast_analysis(del_barcode_file, basepath, savepath)
|
||||
relative_path = get_relative_paths(del_barcode_file, basepath, savepath, saveimgs)
|
||||
except Exception as e:
|
||||
print(f'Error Type: {e}')
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
# main1()
|
||||
one2n_test()
|
||||
|
||||
# test_getreltpath()
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -239,71 +239,71 @@ def creat_shopping_event(eventPath, subimgPath=False):
|
||||
|
||||
return event
|
||||
|
||||
def get_std_barcodeDict(bcdpath, savepath):
|
||||
'''
|
||||
inputs:
|
||||
bcdpath: 已清洗的barcode样本图像,如果barcode下有'base'文件夹,只选用该文件夹下图像
|
||||
(default = r'\\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_1771')
|
||||
功能:
|
||||
生成并保存只有一个key值的字典 {barcode: [imgpath1, imgpath1, ...]},
|
||||
savepath: 字典存储地址,文件名格式:barcode.pickle
|
||||
'''
|
||||
# def get_std_barcodeDict(bcdpath, savepath):
|
||||
# '''
|
||||
# inputs:
|
||||
# bcdpath: 已清洗的barcode样本图像,如果barcode下有'base'文件夹,只选用该文件夹下图像
|
||||
# (default = r'\\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_1771')
|
||||
# 功能:
|
||||
# 生成并保存只有一个key值的字典 {barcode: [imgpath1, imgpath1, ...]},
|
||||
# savepath: 字典存储地址,文件名格式:barcode.pickle
|
||||
# '''
|
||||
|
||||
# savepath = r'\\192.168.1.28\share\测试_202406\contrast\std_barcodes'
|
||||
# # savepath = r'\\192.168.1.28\share\测试_202406\contrast\std_barcodes'
|
||||
|
||||
'''读取数据集中 barcode 列表'''
|
||||
stdBarcodeList = []
|
||||
for filename in os.listdir(bcdpath):
|
||||
filepath = os.path.join(bcdpath, filename)
|
||||
# if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
|
||||
# continue
|
||||
stdBarcodeList.append(filename)
|
||||
# '''读取数据集中 barcode 列表'''
|
||||
# stdBarcodeList = []
|
||||
# for filename in os.listdir(bcdpath):
|
||||
# filepath = os.path.join(bcdpath, filename)
|
||||
# # if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
|
||||
# # continue
|
||||
# stdBarcodeList.append(filename)
|
||||
|
||||
bcdPaths = [(barcode, os.path.join(bcdpath, barcode)) for barcode in stdBarcodeList]
|
||||
# bcdPaths = [(barcode, os.path.join(bcdpath, barcode)) for barcode in stdBarcodeList]
|
||||
|
||||
'''遍历数据集,针对每一个barcode,生成并保存字典{barcode: [imgpath1, imgpath1, ...]}'''
|
||||
k = 0
|
||||
errbarcodes = []
|
||||
for barcode, bpath in bcdPaths:
|
||||
pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
if os.path.isfile(pickpath):
|
||||
continue
|
||||
# '''遍历数据集,针对每一个barcode,生成并保存字典{barcode: [imgpath1, imgpath1, ...]}'''
|
||||
# k = 0
|
||||
# errbarcodes = []
|
||||
# for barcode, bpath in bcdPaths:
|
||||
# pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
# if os.path.isfile(pickpath):
|
||||
# continue
|
||||
|
||||
stdBarcodeDict = {}
|
||||
stdBarcodeDict[barcode] = []
|
||||
for root, dirs, files in os.walk(bpath):
|
||||
imgpaths = []
|
||||
if "base" in dirs:
|
||||
broot = os.path.join(root, "base")
|
||||
for imgname in os.listdir(broot):
|
||||
imgpath = os.path.join(broot, imgname)
|
||||
file, ext = os.path.splitext(imgpath)
|
||||
# stdBarcodeDict = {}
|
||||
# stdBarcodeDict[barcode] = []
|
||||
# for root, dirs, files in os.walk(bpath):
|
||||
# imgpaths = []
|
||||
# if "base" in dirs:
|
||||
# broot = os.path.join(root, "base")
|
||||
# for imgname in os.listdir(broot):
|
||||
# imgpath = os.path.join(broot, imgname)
|
||||
# file, ext = os.path.splitext(imgpath)
|
||||
|
||||
if ext not in IMG_FORMAT:
|
||||
continue
|
||||
imgpaths.append(imgpath)
|
||||
# if ext not in IMG_FORMAT:
|
||||
# continue
|
||||
# imgpaths.append(imgpath)
|
||||
|
||||
stdBarcodeDict[barcode].extend(imgpaths)
|
||||
break
|
||||
# stdBarcodeDict[barcode].extend(imgpaths)
|
||||
# break
|
||||
|
||||
else:
|
||||
for imgname in files:
|
||||
imgpath = os.path.join(root, imgname)
|
||||
_, ext = os.path.splitext(imgpath)
|
||||
if ext not in IMG_FORMAT: continue
|
||||
imgpaths.append(imgpath)
|
||||
stdBarcodeDict[barcode].extend(imgpaths)
|
||||
# else:
|
||||
# for imgname in files:
|
||||
# imgpath = os.path.join(root, imgname)
|
||||
# _, ext = os.path.splitext(imgpath)
|
||||
# if ext not in IMG_FORMAT: continue
|
||||
# imgpaths.append(imgpath)
|
||||
# stdBarcodeDict[barcode].extend(imgpaths)
|
||||
|
||||
pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
with open(pickpath, 'wb') as f:
|
||||
pickle.dump(stdBarcodeDict, f)
|
||||
print(f"Barcode: {barcode}")
|
||||
# pickpath = os.path.join(savepath, f"{barcode}.pickle")
|
||||
# with open(pickpath, 'wb') as f:
|
||||
# pickle.dump(stdBarcodeDict, f)
|
||||
# print(f"Barcode: {barcode}")
|
||||
|
||||
# k += 1
|
||||
# if k == 10:
|
||||
# break
|
||||
print(f"Len of errbarcodes: {len(errbarcodes)}")
|
||||
return
|
||||
# # k += 1
|
||||
# # if k == 10:
|
||||
# # break
|
||||
# print(f"Len of errbarcodes: {len(errbarcodes)}")
|
||||
# return
|
||||
|
||||
def save_event_subimg(event, savepath):
|
||||
'''
|
||||
@ -355,92 +355,92 @@ def batch_inference(imgpaths, batch):
|
||||
features = np.concatenate(features, axis=0)
|
||||
return features
|
||||
|
||||
def stdfeat_infer(imgPath, featPath, bcdSet=None):
|
||||
'''
|
||||
inputs:
|
||||
imgPath: 该文件夹下的 pickle 文件格式 {barcode: [imgpath1, imgpath1, ...]}
|
||||
featPath: imgPath图像对应特征的存储地址
|
||||
功能:
|
||||
对 imgPath中图像进行特征提取,生成只有一个key值的字典,
|
||||
{barcode: features},features.shape=(nsample, 256),并保存至 featPath 中
|
||||
# def stdfeat_infer(imgPath, featPath, bcdSet=None):
|
||||
# '''
|
||||
# inputs:
|
||||
# imgPath: 该文件夹下的 pickle 文件格式 {barcode: [imgpath1, imgpath1, ...]}
|
||||
# featPath: imgPath图像对应特征的存储地址
|
||||
# 功能:
|
||||
# 对 imgPath中图像进行特征提取,生成只有一个key值的字典,
|
||||
# {barcode: features},features.shape=(nsample, 256),并保存至 featPath 中
|
||||
|
||||
'''
|
||||
# '''
|
||||
|
||||
# imgPath = r"\\192.168.1.28\share\测试_202406\contrast\std_barcodes"
|
||||
# featPath = r"\\192.168.1.28\share\测试_202406\contrast\std_features"
|
||||
stdBarcodeDict = {}
|
||||
stdBarcodeDict_ft16 = {}
|
||||
# # imgPath = r"\\192.168.1.28\share\测试_202406\contrast\std_barcodes"
|
||||
# # featPath = r"\\192.168.1.28\share\测试_202406\contrast\std_features"
|
||||
# stdBarcodeDict = {}
|
||||
# stdBarcodeDict_ft16 = {}
|
||||
|
||||
|
||||
'''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名、该pickle文件中字典的key值'''
|
||||
# '''4处同名: (1)barcode原始图像文件夹; (2)imgPath中的 .pickle 文件名、该pickle文件中字典的key值'''
|
||||
|
||||
k = 0
|
||||
for filename in os.listdir(imgPath):
|
||||
bcd, ext = os.path.splitext(filename)
|
||||
pkpath = os.path.join(featPath, f"{bcd}.pickle")
|
||||
# k = 0
|
||||
# for filename in os.listdir(imgPath):
|
||||
# bcd, ext = os.path.splitext(filename)
|
||||
# pkpath = os.path.join(featPath, f"{bcd}.pickle")
|
||||
|
||||
if os.path.isfile(pkpath): continue
|
||||
if bcdSet is not None and bcd not in bcdSet:
|
||||
continue
|
||||
# if os.path.isfile(pkpath): continue
|
||||
# if bcdSet is not None and bcd not in bcdSet:
|
||||
# continue
|
||||
|
||||
filepath = os.path.join(imgPath, filename)
|
||||
# filepath = os.path.join(imgPath, filename)
|
||||
|
||||
stdbDict = {}
|
||||
stdbDict_ft16 = {}
|
||||
stdbDict_uint8 = {}
|
||||
# stdbDict = {}
|
||||
# stdbDict_ft16 = {}
|
||||
# stdbDict_uint8 = {}
|
||||
|
||||
t1 = time.time()
|
||||
# t1 = time.time()
|
||||
|
||||
try:
|
||||
with open(filepath, 'rb') as f:
|
||||
bpDict = pickle.load(f)
|
||||
for barcode, imgpaths in bpDict.items():
|
||||
# feature = batch_inference(imgpaths, 8) #from vit distilled model of LiChen
|
||||
feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
|
||||
feature /= np.linalg.norm(feature, axis=1)[:, None]
|
||||
# try:
|
||||
# with open(filepath, 'rb') as f:
|
||||
# bpDict = pickle.load(f)
|
||||
# for barcode, imgpaths in bpDict.items():
|
||||
# # feature = batch_inference(imgpaths, 8) #from vit distilled model of LiChen
|
||||
# feature = inference_image(imgpaths, conf.test_transform, model, conf.device)
|
||||
# feature /= np.linalg.norm(feature, axis=1)[:, None]
|
||||
|
||||
# float16
|
||||
feature_ft16 = feature.astype(np.float16)
|
||||
feature_ft16 /= np.linalg.norm(feature_ft16, axis=1)[:, None]
|
||||
# # float16
|
||||
# feature_ft16 = feature.astype(np.float16)
|
||||
# feature_ft16 /= np.linalg.norm(feature_ft16, axis=1)[:, None]
|
||||
|
||||
# uint8, 两种策略,1) 精度损失小, 2) 计算复杂度小
|
||||
# feature_uint8, _ = ft16_to_uint8(feature_ft16)
|
||||
feature_uint8 = (feature_ft16*128).astype(np.int8)
|
||||
# # uint8, 两种策略,1) 精度损失小, 2) 计算复杂度小
|
||||
# # feature_uint8, _ = ft16_to_uint8(feature_ft16)
|
||||
# feature_uint8 = (feature_ft16*128).astype(np.int8)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error accured at: {filename}, with Exception is: {e}")
|
||||
# except Exception as e:
|
||||
# print(f"Error accured at: {filename}, with Exception is: {e}")
|
||||
|
||||
'''================ 保存单个barcode特征 ================'''
|
||||
##================== float32
|
||||
stdbDict["barcode"] = barcode
|
||||
stdbDict["imgpaths"] = imgpaths
|
||||
stdbDict["feats_ft32"] = feature
|
||||
stdbDict["feats_ft16"] = feature_ft16
|
||||
stdbDict["feats_uint8"] = feature_uint8
|
||||
# '''================ 保存单个barcode特征 ================'''
|
||||
# ##================== float32
|
||||
# stdbDict["barcode"] = barcode
|
||||
# stdbDict["imgpaths"] = imgpaths
|
||||
# stdbDict["feats_ft32"] = feature
|
||||
# stdbDict["feats_ft16"] = feature_ft16
|
||||
# stdbDict["feats_uint8"] = feature_uint8
|
||||
|
||||
with open(pkpath, 'wb') as f:
|
||||
pickle.dump(stdbDict, f)
|
||||
# with open(pkpath, 'wb') as f:
|
||||
# pickle.dump(stdbDict, f)
|
||||
|
||||
stdBarcodeDict[barcode] = feature
|
||||
stdBarcodeDict_ft16[barcode] = feature_ft16
|
||||
# stdBarcodeDict[barcode] = feature
|
||||
# stdBarcodeDict_ft16[barcode] = feature_ft16
|
||||
|
||||
t2 = time.time()
|
||||
print(f"Barcode: {barcode}, need time: {t2-t1:.1f} secs")
|
||||
# k += 1
|
||||
# if k == 10:
|
||||
# break
|
||||
# t2 = time.time()
|
||||
# print(f"Barcode: {barcode}, need time: {t2-t1:.1f} secs")
|
||||
# # k += 1
|
||||
# # if k == 10:
|
||||
# # break
|
||||
|
||||
##================== float32
|
||||
# pickpath = os.path.join(featPath, f"barcode_features_{k}.pickle")
|
||||
# with open(pickpath, 'wb') as f:
|
||||
# pickle.dump(stdBarcodeDict, f)
|
||||
# ##================== float32
|
||||
# # pickpath = os.path.join(featPath, f"barcode_features_{k}.pickle")
|
||||
# # with open(pickpath, 'wb') as f:
|
||||
# # pickle.dump(stdBarcodeDict, f)
|
||||
|
||||
##================== float16
|
||||
# pickpath_ft16 = os.path.join(featPath, f"barcode_features_ft16_{k}.pickle")
|
||||
# with open(pickpath_ft16, 'wb') as f:
|
||||
# pickle.dump(stdBarcodeDict_ft16, f)
|
||||
# ##================== float16
|
||||
# # pickpath_ft16 = os.path.join(featPath, f"barcode_features_ft16_{k}.pickle")
|
||||
# # with open(pickpath_ft16, 'wb') as f:
|
||||
# # pickle.dump(stdBarcodeDict_ft16, f)
|
||||
|
||||
return
|
||||
# return
|
||||
|
||||
|
||||
def contrast_performance_evaluate(resultPath):
|
||||
@ -789,30 +789,28 @@ def main():
|
||||
compute_precise_recall(pickpath)
|
||||
|
||||
|
||||
def main_std():
|
||||
std_sample_path = r"\\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_500_2192_已清洗"
|
||||
std_barcode_path = r"\\192.168.1.28\share\测试_202406\contrast\std_barcodes_2192"
|
||||
std_feature_path = r"\\192.168.1.28\share\测试_202406\contrast\std_features_2192_ft32vsft16"
|
||||
# def main_std():
|
||||
# std_sample_path = r"\\192.168.1.28\share\已标注数据备份\对比数据\barcode\barcode_500_2192_已清洗"
|
||||
# std_barcode_path = r"\\192.168.1.28\share\测试_202406\contrast\std_barcodes_2192"
|
||||
# std_feature_path = r"\\192.168.1.28\share\测试_202406\contrast\std_features_2192_ft32vsft16"
|
||||
|
||||
get_std_barcodeDict(std_sample_path, std_barcode_path)
|
||||
|
||||
stdfeat_infer(std_barcode_path, std_feature_path, bcdSet=None)
|
||||
# get_std_barcodeDict(std_sample_path, std_barcode_path)
|
||||
# stdfeat_infer(std_barcode_path, std_feature_path, bcdSet=None)
|
||||
|
||||
# fileList = []
|
||||
# for filename in os.listdir(std_barcode_path):
|
||||
# filepath = os.path.join(std_barcode_path, filename)
|
||||
# with open(filepath, 'rb') as f:
|
||||
# bpDict = pickle.load(f)
|
||||
# # fileList = []
|
||||
# # for filename in os.listdir(std_barcode_path):
|
||||
# # filepath = os.path.join(std_barcode_path, filename)
|
||||
# # with open(filepath, 'rb') as f:
|
||||
# # bpDict = pickle.load(f)
|
||||
|
||||
# for v in bpDict.values():
|
||||
# fileList.append(len(v))
|
||||
# print("done")
|
||||
# # for v in bpDict.values():
|
||||
# # fileList.append(len(v))
|
||||
# # print("done")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# main()
|
||||
|
||||
|
||||
main_std()
|
||||
main()
|
||||
# main_std()
|
||||
|
||||
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Wed Sep 11 11:57:30 2024
|
||||
|
||||
永辉现场 1:1 比对测试
|
||||
|
||||
永辉现场试验输出数据的 1:1 性能评估
|
||||
适用于202410前数据保存版本的,需调用 OneToOneCompare.txt
|
||||
@author: ym
|
||||
"""
|
||||
import os
|
||||
@ -65,14 +64,14 @@ def plot_pr_curve(matrix):
|
||||
axs[1].set_title(f'Cross Barcode, Num: {TPFN_mean}')
|
||||
# plt.savefig(f'./result/{file}_hist.png') # svg, png, pdf
|
||||
|
||||
Recall_Pos = []
|
||||
Recall_Neg = []
|
||||
Thresh = np.linspace(-0.2, 1, 100)
|
||||
for th in Thresh:
|
||||
TN = np.sum(simimax < th)
|
||||
Recall_Pos.append(TN/TPFN_max)
|
||||
Recall_Neg.append(TN/TPFN_max)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
|
||||
ax.plot(Thresh, Recall_Neg, 'b', label='Recall_Pos: TP/TPFN')
|
||||
ax.set_xlim([0, 1])
|
||||
ax.set_ylim([0, 1])
|
||||
ax.grid(True)
|
||||
@ -96,9 +95,7 @@ def main():
|
||||
simiList = []
|
||||
for fp in filepaths:
|
||||
slist = read_one2one_data(fp)
|
||||
|
||||
simiList.extend(slist)
|
||||
|
||||
|
||||
plot_pr_curve(simiList)
|
||||
|
||||
|
BIN
contrast/utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
contrast/utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/utils/__pycache__/tools.cpython-39.pyc
Normal file
BIN
contrast/utils/__pycache__/tools.cpython-39.pyc
Normal file
Binary file not shown.
56
contrast/utils/tools.py
Normal file
56
contrast/utils/tools.py
Normal file
@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Thu Oct 31 15:17:01 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
||||
def showHist(err, correct):
|
||||
err = np.array(err)
|
||||
correct = np.array(correct)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(err, bins=50, edgecolor='black')
|
||||
axs[0].set_xlim([0, 1])
|
||||
axs[0].set_title('err')
|
||||
|
||||
axs[1].hist(correct, bins=50, edgecolor='black')
|
||||
axs[1].set_xlim([0, 1])
|
||||
axs[1].set_title('correct')
|
||||
# plt.show()
|
||||
|
||||
return plt
|
||||
|
||||
def show_recall_prec(recall, prec, ths):
|
||||
# x = np.linspace(start=-0, stop=1, num=11, endpoint=True).tolist()
|
||||
fig = plt.figure(figsize=(10, 6))
|
||||
plt.plot(ths, recall, color='red', label='recall')
|
||||
plt.plot(ths, prec, color='blue', label='PrecisePos')
|
||||
plt.legend()
|
||||
plt.xlabel(f'threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
# plt.savefig('accuracy_recall_grid.png')
|
||||
# plt.show()
|
||||
# plt.close()
|
||||
|
||||
return plt
|
||||
|
||||
|
||||
def compute_recall_precision(err_similarity, correct_similarity):
|
||||
ths = np.linspace(0, 1, 51)
|
||||
recall, prec = [], []
|
||||
for th in ths:
|
||||
TP = len([num for num in correct_similarity if num >= th])
|
||||
FP = len([num for num in err_similarity if num >= th])
|
||||
if (TP+FP) == 0:
|
||||
prec.append(1)
|
||||
recall.append(0)
|
||||
else:
|
||||
prec.append(TP / (TP + FP))
|
||||
recall.append(TP / (len(err_similarity) + len(correct_similarity)))
|
||||
return recall, prec, ths
|
Reference in New Issue
Block a user