import os.path
import shutil
import numpy as np
import matplotlib.pyplot as plt
import cv2
from utils.plotting import Annotator, colors
import sys
sys.path.append(r"D:\DetectTracking")
from tracking.utils.read_data import extract_data, read_deletedBarcode_file, read_tracking_output
from tracking.utils.plotting import draw_tracking_boxes


def showHist(err, correct):
    '''Display two stacked 50-bin histograms, one per input sequence.

    err:     values drawn in the upper panel (titled 'err')
    correct: values drawn in the lower panel (titled 'correct')

    Both panels have their x axis limited to [0, 1]. The figure is shown
    with plt.show(); nothing is returned.
    '''
    panels = (('err', np.array(err)), ('correct', np.array(correct)))

    _, axs = plt.subplots(2, 1)
    for ax, (title, values) in zip(axs, panels):
        ax.hist(values, bins=50, edgecolor='black')
        ax.set_xlim([0, 1])
        ax.set_title(title)
    plt.show()


def showgrid(recall, prec, ths):
    '''Plot the recall and precision curves against the thresholds.

    recall: y values of the red curve (legend label 'recall')
    prec:   y values of the blue curve (legend label 'PrecisePos')
    ths:    x values shared by both curves

    The figure is written to 'accuracy_recall_grid.png' in the current
    working directory and then displayed.
    '''
    # x = np.linspace(start=-0, stop=1, num=11, endpoint=True).tolist()
    plt.figure(figsize=(10, 6))
    curves = ((recall, 'red', 'recall'), (prec, 'blue', 'PrecisePos'))
    for values, colour, name in curves:
        plt.plot(ths, values, color=colour, label=name)
    plt.legend()
    plt.xlabel('threshold')
    # plt.ylabel('Similarity')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.savefig('accuracy_recall_grid.png')
    plt.show()
    # plt.close()


def compute_recall_precision(err_similarity, correct_similarity):
    '''Sweep 11 thresholds over [0, 1] and compute a recall/precision pair for each.

    err_similarity:     similarity values of the erroneous samples
    correct_similarity: similarity values of the correct samples

    For each threshold, TP / FP are the numbers of correct / erroneous
    values that are >= the threshold. When nothing passes the threshold,
    precision is recorded as 1 and recall as 0.

    The curves are plotted through showgrid() before returning.

    return: (recall, prec), two lists with one entry per threshold
    '''
    ths = np.linspace(0, 1, 11)
    recall, prec = [], []
    for th in ths:
        TP = sum(1 for score in correct_similarity if score >= th)
        FP = sum(1 for score in err_similarity if score >= th)
        accepted = TP + FP
        if accepted == 0:
            prec.append(1)
            recall.append(0)
            continue

        prec.append(TP / accepted)
        # NOTE(review): the denominator is the combined size of both lists,
        # not only the correct ones -- confirm this is the intended metric.
        recall.append(TP / (len(err_similarity) + len(correct_similarity)))

    showgrid(recall, prec, ths)
    return recall, prec


# =============================================================================
# def read_tracking_output(filepath):
#     boxes = []
#     feats = []
#     with open(filepath, 'r', encoding='utf-8') as file:
#         for line in file:
#             line = line.strip()  # strip the trailing newline and any surrounding whitespace
#
#             if not line:
#                 continue
#
#             if line.endswith(','):
#                 line = line[:-1]
#
#             data = np.array([float(x) for x in line.split(",")])
#             if data.size == 9:
#                 boxes.append(data)
#             if data.size == 256:
#                 feats.append(data)
#
#     return
np.array(boxes), np.array(feats) # ============================================================================= def read_tracking_imgs(imgspath): ''' input: imgspath:该路径中的图像为Yolo算法的输入图像,640x512 output: imgs_0:后摄图像,根据 frameId 进行了排序 imgs_1:前摄图像,根据 frameId 进行了排序 ''' imgs_0, frmIDs_0, imgs_1, frmIDs_1 = [], [], [], [] for filename in os.listdir(imgspath): file, ext = os.path.splitext(filename) flist = file.split('_') if len(flist)==4 and ext==".jpg": camID, frmID = flist[0], int(flist[-1]) imgpath = os.path.join(imgspath, filename) img = cv2.imread(imgpath) if camID=='0': imgs_0.append(img) frmIDs_0.append(frmID) if camID=='1': imgs_1.append(img) frmIDs_1.append(frmID) if len(frmIDs_0): indice = np.argsort(np.array(frmIDs_0)) imgs_0 = [imgs_0[i] for i in indice ] if len(frmIDs_1): indice = np.argsort(np.array(frmIDs_1)) imgs_1 = [imgs_1[i] for i in indice ] return imgs_0, imgs_1 # ============================================================================= # def draw_tracking_boxes(imgs, tracks): # '''tracks: [x1, y1, x2, y2, track_id, score, cls, frame_index, box_index] # 0 1 2 3 4 5 6 7 8 # 关键:imgs中的次序和 track 中的 fid 对应 # ''' # subimgs = [] # for *xyxy, tid, conf, cls, fid, bid in tracks: # label = f'id:{int(tid)}_{int(cls)}_{conf:.2f}' # # annotator = Annotator(imgs[int(fid-1)].copy()) # if cls==0: # color = colors(int(cls), True) # elif tid>0 and cls!=0: # color = colors(int(tid), True) # else: # color = colors(19, True) # 19为调色板的最后一个元素 # # pt2 = [p/2 for p in xyxy] # annotator.box_label(pt2, label, color=color) # img0 = annotator.result() # # subimgs.append(img0) # # return subimgs # ============================================================================= def get_contrast_paths(pair, basepath): assert(len(pair)==2 or len(pair)==3), "pair: seqdir, delete, barcodes" getout_fold = pair[0] # 取出操作对应的文件夹 relvt_barcode = pair[1] # 取出操作对应放入操作的 Barcode if len(pair)==3: error_match = pair[2] # 取出操作错误匹配的 Barcode else: error_match = '' getoutpath, inputpath, errorpath = 
'', '', '' day, hms = getout_fold.strip('_').split('-') input_folds, times = [], [] errmatch_folds, errmatch_times = [], [] for pathname in os.listdir(basepath): if pathname.endswith('_'): continue if os.path.isfile(os.path.join(basepath, pathname)):continue infold = pathname.split('_') if len(infold)!=2: continue day1, hms1 = infold[0].split('-') if day1==day and infold[1]==relvt_barcode and int(hms1)