mirror of
https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git
synced 2025-08-18 13:20:25 +00:00
增加测试维度
This commit is contained in:
16
gift_demo.py
16
gift_demo.py
@ -5,7 +5,7 @@ from ultralytics import YOLOv10
|
|||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.utils.show_trace_pr import ShowPR
|
from ultralytics.utils.show_pr import ShowPR
|
||||||
# from trace_detect import run, _init_model
|
# from trace_detect import run, _init_model
|
||||||
import numpy as np
|
import numpy as np
|
||||||
image_ext = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
|
image_ext = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
|
||||||
@ -43,7 +43,7 @@ def get_image_list(path):
|
|||||||
|
|
||||||
|
|
||||||
def _init():
|
def _init():
|
||||||
model = YOLOv10('runs/detect/train/weights/best_gift_v10n.pt')
|
model = YOLOv10('ckpts/20250620/best_gift_v10n.pt')
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -72,6 +72,7 @@ def get_trace_event(model, path):
|
|||||||
|
|
||||||
def main(path):
|
def main(path):
|
||||||
model = _init()
|
model = _init()
|
||||||
|
tags_tmp, result_all_tmp = [], []
|
||||||
tags, result_all = [], []
|
tags, result_all = [], []
|
||||||
classify = ['commodity', 'gift']
|
classify = ['commodity', 'gift']
|
||||||
for cla in classify:
|
for cla in classify:
|
||||||
@ -79,12 +80,15 @@ def main(path):
|
|||||||
for root, dirs, files in os.walk(pre_pth):
|
for root, dirs, files in os.walk(pre_pth):
|
||||||
if not dirs:
|
if not dirs:
|
||||||
if cla == 'commodity':
|
if cla == 'commodity':
|
||||||
tags.append(0)
|
tags_tmp.append(0)
|
||||||
else:
|
else:
|
||||||
tags.append(1)
|
tags_tmp.append(1)
|
||||||
res_single = get_trace_event(model, root)
|
res_single = get_trace_event(model, root)
|
||||||
result_all.append(res_single)
|
result_all_tmp.append(res_single)
|
||||||
spr = ShowPR(tags, result_all, title_name='yolov10n')
|
for tag, result in zip(tags_tmp, result_all_tmp):
|
||||||
|
tags += [tag]*len(result)
|
||||||
|
result_all += result
|
||||||
|
spr = ShowPR(tags, result_all, title_name='yolov10n', )
|
||||||
# spr.change_precValue()
|
# spr.change_precValue()
|
||||||
spr.get_pic()
|
spr.get_pic()
|
||||||
|
|
||||||
|
@ -49,10 +49,14 @@ def read_tracking_output(filepath):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if start_idx != -1 and end_idx != -1:
|
if start_idx != -1 and end_idx != -1:
|
||||||
|
content = []
|
||||||
for i in range(start_idx, end_idx):
|
for i in range(start_idx, end_idx):
|
||||||
line = lines[i].strip()
|
line = lines[i].strip()
|
||||||
if line:
|
if line:
|
||||||
gift_data.append(line)
|
content.append(line)
|
||||||
|
# 将所有内容合并成一行字符串
|
||||||
|
if content:
|
||||||
|
gift_data.append(' '.join(content))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error extracting gift data: {e}")
|
print(f"Error extracting gift data: {e}")
|
||||||
|
|
||||||
@ -63,28 +67,13 @@ def read_tracking_output(filepath):
|
|||||||
|
|
||||||
|
|
||||||
def extract_data_realtime(datapath):
|
def extract_data_realtime(datapath):
|
||||||
boxes, feats = [], []
|
boxes, feats, gift_data = read_tracking_output(datapath)
|
||||||
tracker_feats = {}
|
|
||||||
with open(datapath, 'r', encoding='utf-8') as lines:
|
if not boxes or not feats:
|
||||||
for line in lines:
|
return np.array([]), {}, []
|
||||||
line = line.strip() # 去除行尾的换行符和可能的空白字符
|
|
||||||
if not line: # 跳过空行
|
trackerboxes = boxes[0] # 因为read_tracking_output返回的是list中的numpy数组
|
||||||
continue
|
trackerfeats = feats[0]
|
||||||
|
|
||||||
if line.endswith(','):
|
|
||||||
line = line[:-1]
|
|
||||||
ftlist = [float(x) for x in line.split()]
|
|
||||||
|
|
||||||
if len(ftlist) != 265:
|
|
||||||
continue
|
|
||||||
boxes.append(ftlist[:9])
|
|
||||||
feats.append(ftlist[9:])
|
|
||||||
|
|
||||||
trackerboxes = np.array(boxes)
|
|
||||||
trackerfeats = np.array(feats)
|
|
||||||
|
|
||||||
if len(trackerboxes) == 0 or len(trackerboxes) != len(trackerfeats):
|
|
||||||
return np.array([]), {}
|
|
||||||
|
|
||||||
frmIDs = np.sort(np.unique(trackerboxes[:, 7].astype(int)))
|
frmIDs = np.sort(np.unique(trackerboxes[:, 7].astype(int)))
|
||||||
for fid in frmIDs:
|
for fid in frmIDs:
|
||||||
@ -373,9 +362,15 @@ for event in os.listdir(video_path):
|
|||||||
# print('imgfile_list', imgfile_list)
|
# print('imgfile_list', imgfile_list)
|
||||||
for track_data in track_list:
|
for track_data in track_list:
|
||||||
track_path = os.path.join(event_path, track_data)
|
track_path = os.path.join(event_path, track_data)
|
||||||
boxes, feat = extract_data_realtime(track_path)
|
boxes, feat, gift_data = extract_data_realtime(track_path)
|
||||||
camera_id = track_data.split('_')[0]
|
camera_id = track_data.split('_')[0]
|
||||||
imgfile = [x for x in imgfile_list if x.split('_')[0] == camera_id][0]
|
imgfile = [x for x in imgfile_list if x.split('_')[0] == camera_id][0]
|
||||||
|
|
||||||
|
# 打印gift数据
|
||||||
|
if gift_data:
|
||||||
|
print(f"\nGift data for {event}/{track_data}:")
|
||||||
|
print(gift_data[0]) # 现在gift_data只包含一个元素,即合并后的字符串
|
||||||
|
|
||||||
if len(boxes) > 0:
|
if len(boxes) > 0:
|
||||||
if del_staticBox: ##根据距离删除box
|
if del_staticBox: ##根据距离删除box
|
||||||
boxes_ = compute_box_dist(boxes)
|
boxes_ = compute_box_dist(boxes)
|
||||||
|
@ -43,7 +43,7 @@ def get_image_list(path):
|
|||||||
|
|
||||||
|
|
||||||
def _init():
|
def _init():
|
||||||
model = YOLOv10('ckpts/20250514/best_gift_v10n.pt')
|
model = YOLOv10('ckpts/20250620/best_gift_v10n.pt')
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
2
train.py
2
train.py
@ -17,5 +17,5 @@ model = YOLOv10('ckpts/weights/yolov10n.pt')
|
|||||||
#model.train(data='coco.yaml', epochs=1, batch=64, imgsz=640)
|
#model.train(data='coco.yaml', epochs=1, batch=64, imgsz=640)
|
||||||
#model.train(data='coco128_cls10_0924.yaml', epochs=300, batch=64, imgsz=640, resume=False)
|
#model.train(data='coco128_cls10_0924.yaml', epochs=300, batch=64, imgsz=640, resume=False)
|
||||||
#model.train(data='coco128_cls10_1010.yaml', epochs=300, batch=128, imgsz=640, resume=False)
|
#model.train(data='coco128_cls10_1010.yaml', epochs=300, batch=128, imgsz=640, resume=False)
|
||||||
model.train(data='gift.yaml', epochs=400, batch=32, imgsz=224, resume=False, save_dir='/ckpts')
|
model.train(data='gift.yaml', epochs=600, batch=32, imgsz=224, resume=False, save_dir='/ckpts')
|
||||||
#model.train(data='coco128_cls10_1010_1205.yaml', epochs=300, batch=32, imgsz=640, resume=True)
|
#model.train(data='coco128_cls10_1010_1205.yaml', epochs=300, batch=32, imgsz=640, resume=True)
|
109
ultralytics/utils/show_pr.py
Normal file
109
ultralytics/utils/show_pr.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import os.path
|
||||||
|
import os
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# [tag, bandage, null]
|
||||||
|
# [0, 1, 2]
|
||||||
|
class ShowPR:
|
||||||
|
def __init__(self, tags, prec_value, title_name=None):
|
||||||
|
self.tags = tags
|
||||||
|
self.prec_value = prec_value
|
||||||
|
self.thres = [i * 0.01 for i in range(101)]
|
||||||
|
self.title_name = title_name
|
||||||
|
|
||||||
|
def change_precValue(self, thre=0.5):
|
||||||
|
values = []
|
||||||
|
for i in range(len(self.prec_value)):
|
||||||
|
value = []
|
||||||
|
for j in range(len(self.prec_value[i])):
|
||||||
|
if self.prec_value[i][j] > thre:
|
||||||
|
value.append(1)
|
||||||
|
else:
|
||||||
|
value.append(0)
|
||||||
|
values.append(value)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _calculate_pr(self, prec_value):
|
||||||
|
FN, FP, TN, TP = 0, 0, 0, 0
|
||||||
|
for output, target in zip(prec_value, self.tags):
|
||||||
|
# print("output >> {} , target >> {}".format(output, target))
|
||||||
|
if output != target:
|
||||||
|
if target == 0:
|
||||||
|
FP += 1
|
||||||
|
elif target == 1:
|
||||||
|
FN += 1
|
||||||
|
else:
|
||||||
|
if target == 0:
|
||||||
|
TN += 1
|
||||||
|
elif target == 1:
|
||||||
|
TP += 1
|
||||||
|
if TP == 0:
|
||||||
|
prec, recall = 0, 0
|
||||||
|
else:
|
||||||
|
prec = TP / (TP + FP)
|
||||||
|
recall = TP / (TP + FN)
|
||||||
|
if TN == 0:
|
||||||
|
tn_prec, tn_recall = 0, 0
|
||||||
|
else:
|
||||||
|
tn_prec = TN / (TN + FN)
|
||||||
|
tn_recall = TN / (TN + FP)
|
||||||
|
# print("TP>>{}, FP>>{}, TN>>{}, FN>>{}".format(TP, FP, TN, FN))
|
||||||
|
return prec, recall, tn_prec, tn_recall
|
||||||
|
|
||||||
|
def calculate_multiple(self):
|
||||||
|
recall, recall_TN, PrecisePos, PreciseNeg = [], [], [], []
|
||||||
|
for thre in self.thres:
|
||||||
|
# prec_value = []
|
||||||
|
# if self.prec_value >= thre:
|
||||||
|
# prec_value.append(1)
|
||||||
|
# else:
|
||||||
|
# prec_value.append(0)
|
||||||
|
prec_value = [1 if num >= thre else 0 for num in self.prec_value]
|
||||||
|
prec, recall_pos, tn_prec, tn_recall = self._calculate_pr(prec_value)
|
||||||
|
print(
|
||||||
|
f"thre>>{thre:.2f}, recall>>{recall_pos:.4f}, precise_pos>>{prec:.4f}, recall_tn>>{tn_recall:.4f}, precise_neg>>{tn_prec:4f}")
|
||||||
|
PrecisePos.append(prec)
|
||||||
|
recall.append(recall_pos)
|
||||||
|
PreciseNeg.append(tn_prec)
|
||||||
|
recall_TN.append(tn_recall)
|
||||||
|
return recall, recall_TN, PrecisePos, PreciseNeg
|
||||||
|
|
||||||
|
def write_results_to_file(self, recall, recall_TN, PrecisePos, PreciseNeg):
|
||||||
|
file_path = os.sep.join(['./ckpts/tracePR', self.title_name + '.txt'])
|
||||||
|
with open(file_path, 'w') as file:
|
||||||
|
file.write("threshold, recall, recall_TN, PrecisePos, PreciseNeg\n")
|
||||||
|
for thre, rec, rec_tn, prec_pos, prec_neg in zip(self.thres, recall, recall_TN, PrecisePos, PreciseNeg):
|
||||||
|
file.write(
|
||||||
|
f"thre>>{thre:.2f}, recall>>{rec:.4f}, precise_pos>>{prec_pos:.4f}, recall_tn>>{rec_tn:.4f}, precise_neg>>{prec_neg:4f}\n")
|
||||||
|
|
||||||
|
def show_pr(self, recall, recall_TN, PrecisePos, PreciseNeg):
|
||||||
|
# self.calculate_multiple()
|
||||||
|
x = self.thres
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(x, recall, color='red', label='recall:TP/TPFN')
|
||||||
|
plt.plot(x, recall_TN, color='black', label='recall_TN:TN/TNFP')
|
||||||
|
plt.plot(x, PrecisePos, color='blue', label='PrecisePos:TP/TPFN')
|
||||||
|
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg:TN/TNFP')
|
||||||
|
plt.legend()
|
||||||
|
plt.xlabel('threshold')
|
||||||
|
# if self.title_name is not None:
|
||||||
|
# plt.title(f"PrecisePos & Recall ratio:{ratio:.2f}", fontdict={'fontsize': 12, 'fontweight': 'black'})
|
||||||
|
# plt.grid(True, linestyle='--', alpha=0.5)
|
||||||
|
# 启用次刻度
|
||||||
|
# plt.minorticks_on()
|
||||||
|
|
||||||
|
# 设置主刻度的网格线
|
||||||
|
plt.grid(which='major', linestyle='-', alpha=0.5, color='gray')
|
||||||
|
|
||||||
|
# 设置次刻度的网格线
|
||||||
|
plt.grid(which='minor', linestyle=':', alpha=0.3, color='gray')
|
||||||
|
plt.savefig(os.sep.join(['./ckpts/tracePR', self.title_name + '.png']))
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
self.write_results_to_file(recall, recall_TN, PrecisePos, PreciseNeg)
|
||||||
|
|
||||||
|
def get_pic(self):
|
||||||
|
recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple()
|
||||||
|
self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg)
|
@ -110,7 +110,7 @@ class ShowPR:
|
|||||||
def get_pic(self):
|
def get_pic(self):
|
||||||
for ratio in self.ratios:
|
for ratio in self.ratios:
|
||||||
# ratio = 0.5
|
# ratio = 0.5
|
||||||
if ratio < 0.2 or ratio > 0.95:
|
if ratio < 0.1 or ratio > 0.95:
|
||||||
continue
|
continue
|
||||||
recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple(ratio)
|
recall, recall_TN, PrecisePos, PreciseNeg = self.calculate_multiple(ratio)
|
||||||
self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg, ratio)
|
self.show_pr(recall, recall_TN, PrecisePos, PreciseNeg, ratio)
|
||||||
|
Reference in New Issue
Block a user