From 96a94457614a3deea72383d6ef9806d3e2b33653 Mon Sep 17 00:00:00 2001 From: lee <770918727@qq.com> Date: Wed, 25 Jun 2025 16:23:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95=E7=BB=B4?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gift_demo.py | 16 ++-- pipline_test/read_trackdata_filter.py | 43 +++++----- trace_demo.py | 2 +- train.py | 2 +- ultralytics/utils/show_pr.py | 109 ++++++++++++++++++++++++++ ultralytics/utils/show_trace_pr.py | 2 +- 6 files changed, 141 insertions(+), 33 deletions(-) create mode 100644 ultralytics/utils/show_pr.py diff --git a/gift_demo.py b/gift_demo.py index c16f3f6..a5a4edd 100644 --- a/gift_demo.py +++ b/gift_demo.py @@ -5,7 +5,7 @@ from ultralytics import YOLOv10 import cv2 import torch -from ultralytics.utils.show_trace_pr import ShowPR +from ultralytics.utils.show_pr import ShowPR # from trace_detect import run, _init_model import numpy as np image_ext = [".jpg", ".jpeg", ".webp", ".bmp", ".png"] @@ -43,7 +43,7 @@ def get_image_list(path): def _init(): - model = YOLOv10('runs/detect/train/weights/best_gift_v10n.pt') + model = YOLOv10('ckpts/20250620/best_gift_v10n.pt') return model @@ -72,6 +72,7 @@ def get_trace_event(model, path): def main(path): model = _init() + tags_tmp, result_all_tmp = [], [] tags, result_all = [], [] classify = ['commodity', 'gift'] for cla in classify: @@ -79,12 +80,15 @@ def main(path): for root, dirs, files in os.walk(pre_pth): if not dirs: if cla == 'commodity': - tags.append(0) + tags_tmp.append(0) else: - tags.append(1) + tags_tmp.append(1) res_single = get_trace_event(model, root) - result_all.append(res_single) - spr = ShowPR(tags, result_all, title_name='yolov10n') + result_all_tmp.append(res_single) + for tag, result in zip(tags_tmp, result_all_tmp): + tags += [tag]*len(result) + result_all += result + spr = ShowPR(tags, result_all, title_name='yolov10n', ) # spr.change_precValue() spr.get_pic() diff --git a/pipline_test/read_trackdata_filter.py b/pipline_test/read_trackdata_filter.py index 02ece43..628bca9 100644 --- a/pipline_test/read_trackdata_filter.py +++ b/pipline_test/read_trackdata_filter.py @@ -49,10 +49,14 @@ def read_tracking_output(filepath): break if start_idx != -1 and end_idx != -1: + content = [] for i in range(start_idx, end_idx): line = lines[i].strip() if line: - gift_data.append(line) + content.append(line) + # 将所有内容合并成一行字符串 + if content: + gift_data.append(' '.join(content)) except Exception as e: print(f"Error extracting gift data: {e}") @@ -63,28 +67,13 @@ def read_tracking_output(filepath): def extract_data_realtime(datapath): - boxes, feats = [], [] - tracker_feats = {} - with open(datapath, 'r', encoding='utf-8') as lines: - for line in lines: - line = line.strip() # 去除行尾的换行符和可能的空白字符 - if not line: # 跳过空行 - continue - - if line.endswith(','): - line = line[:-1] - ftlist = [float(x) for x in line.split()] - - if len(ftlist) != 265: - continue - boxes.append(ftlist[:9]) - feats.append(ftlist[9:]) - - trackerboxes = np.array(boxes) - trackerfeats = np.array(feats) - - if len(trackerboxes) == 0 or len(trackerboxes) != len(trackerfeats): - return np.array([]), {} + boxes, feats, gift_data = read_tracking_output(datapath) + + if not boxes or not feats: + return np.array([]), {}, [] + + trackerboxes = boxes[0] # 因为read_tracking_output返回的是list中的numpy数组 + trackerfeats = feats[0] frmIDs = np.sort(np.unique(trackerboxes[:, 7].astype(int))) for fid in frmIDs: @@ -373,9 +362,15 @@ for event in os.listdir(video_path): # print('imgfile_list', imgfile_list) for track_data in track_list: track_path = os.path.join(event_path, track_data) - boxes, feat = extract_data_realtime(track_path) + boxes, feat, gift_data = extract_data_realtime(track_path) camera_id = track_data.split('_')[0] imgfile = [x for x in imgfile_list if x.split('_')[0] == camera_id][0] + + # 打印gift数据 + if gift_data: + print(f"\nGift data for {event}/{track_data}:") + print(gift_data[0]) # 现在gift_data只包含一个元素,即合并后的字符串 + if len(boxes) > 0: if del_staticBox: ##根据距离删除box boxes_ = compute_box_dist(boxes) diff --git a/trace_demo.py b/trace_demo.py index 9350458..60eda4c 100644 --- a/trace_demo.py +++ b/trace_demo.py @@ -43,7 +43,7 @@ def get_image_list(path): def _init(): - model = YOLOv10('ckpts/20250514/best_gift_v10n.pt') + model = YOLOv10('ckpts/20250620/best_gift_v10n.pt') return model diff --git a/train.py b/train.py index 9c5a7fb..fa3eddc 100644 --- a/train.py +++ b/train.py @@ -17,5 +17,5 @@ model = YOLOv10('ckpts/weights/yolov10n.pt') #model.train(data='coco.yaml', epochs=1, batch=64, imgsz=640) #model.train(data='coco128_cls10_0924.yaml', epochs=300, batch=64, imgsz=640, resume=False) #model.train(data='coco128_cls10_1010.yaml', epochs=300, batch=128, imgsz=640, resume=False) -model.train(data='gift.yaml', epochs=400, batch=32, imgsz=224, resume=False, save_dir='/ckpts') +model.train(data='gift.yaml', epochs=600, batch=32, imgsz=224, resume=False, save_dir='/ckpts') #model.train(data='coco128_cls10_1010_1205.yaml', epochs=300, batch=32, imgsz=640, resume=True) \ No newline at end of file diff --git a/ultralytics/utils/show_pr.py b/ultralytics/utils/show_pr.py new file mode 100644 index 0000000..003a78e --- /dev/null +++ b/ultralytics/utils/show_pr.py @@ -0,0 +1,109 @@ +import os.path +import os +import matplotlib.pyplot as plt +import numpy as np + + +# [tag, bandage, null] +# [0, 1, 2] +class ShowPR: + def __init__(self, tags, prec_value, title_name=None): + self.tags = tags + self.prec_value = prec_value + self.thres = [i * 0.01 for i in range(101)] + self.title_name = title_name + + def change_precValue(self, thre=0.5): + values = [] + for i in range(len(self.prec_value)): + value = [] + for j in range(len(self.prec_value[i])): + if self.prec_value[i][j] > thre: + value.append(1) + else: + value.append(0) + values.append(value) + return values + + def _calculate_pr(self, prec_value): + FN, FP, TN, TP = 0, 0, 0, 0 + for output, target in zip(prec_value, self.tags): + # print("output >> {} , target >> {}".format(output, target)) + if output != target: + if target == 0: + FP += 1 + elif target == 1: + FN += 1 + else: + if target == 0: + TN += 1 + elif target == 1: + TP += 1 + if TP == 0: + prec, recall = 0, 0 + else: + prec = TP / (TP + FP) + recall = TP / (TP + FN) + if TN == 0: + tn_prec, tn_recall = 0, 0 + else: + tn_prec = TN / (TN + FN) + tn_recall = TN / (TN + FP) + # print("TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(TP, FP, TN, FN)) + return prec, recall, tn_prec, tn_recall + + def calculate_multiple(self): + recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], [] + for thre in self.thres: + # prec_value = [] + # if self.prec_value >= thre: + # prec_value.append(1) + # else: + # prec_value.append(0) + prec_value = [1 if num >= thre else 0 for num in self.prec_value] + prec, recall_pos, tn_prec, tn_recall = self._calculate_pr(prec_value) + print( + f"thre>>{thre:.2f}, recall>>{recall_pos:.4f}, precise_pos>>{prec:.4f}, recall_tn>>{tn_recall:.4f}, precise_neg>>{tn_prec:4f}") + PrecisePos.append(prec) + recall.append(recall_pos) + PreciseNeg.append(tn_prec) + recall_TN.append(tn_recall) + return recall, recall_TN, PrecisePos, PreciseNeg + + def write_results_to_file(self, recall, recall_TN, PrecisePos, PreciseNeg): + file_path = os.sep.join(['./ckpts/tracePR', self.title_name + '.txt']) + with open(file_path, 'w') as file: + file.write("threshold, recall, recall_TN, PrecisePos, PreciseNeg\n") + for thre, rec, rec_tn, prec_pos, prec_neg in zip(self.thres, recall, recall_TN, PrecisePos, PreciseNeg): + file.write( + f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:4f}\n") + + def show_pr(self, recall, recall_TN, PrecisePos, PreciseNeg): + # self.calculate_multiple() + x = self.thres + plt.figure(figsize=(10, 6)) + plt.plot(x, recall, color='red', label='recall:TP/TPFN') + plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP') + plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN') + plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP') + plt.legend() + plt.xlabel('threshold') + # if self.title_name is not None: + # plt.title(f"PrecisePos & Recall ratio:{ratio:.2f}", fontdict={'fontsize': 12, 'fontweight': 'black'}) + # plt.grid(True, linestyle='--', alpha=0.5) + # 启用次刻度 + # plt.minorticks_on() + + # 设置主刻度的网格线 + plt.grid(which='major', linestyle='-', alpha=0.5, color='gray') + + # 设置次刻度的网格线 + plt.grid(which='minor', linestyle=':', alpha=0.3, color='gray') + plt.savefig(os.sep.join(['./ckpts/tracePR', self.title_name + '.png'])) + plt.show() + plt.close() + self.write_results_to_file(recall, recall_TN, PrecisePos, PreciseNeg) + + def get_pic(self): + recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple() + self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg) diff --git a/ultralytics/utils/show_trace_pr.py b/ultralytics/utils/show_trace_pr.py index 8dcdd57..83e4dc4 100644 --- a/ultralytics/utils/show_trace_pr.py +++ b/ultralytics/utils/show_trace_pr.py @@ -110,7 +110,7 @@ class ShowPR: def get_pic(self): for ratio in self.ratios: # ratio = 0.5 - if ratio < 0.2 or ratio > 0.95: + if ratio < 0.1 or ratio > 0.95: continue recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple(ratio) self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg, ratio)