mirror of
https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git
synced 2025-08-18 21:30:25 +00:00
增加测试维度
This commit is contained in:
16
gift_demo.py
16
gift_demo.py
@ -5,7 +5,7 @@ from ultralytics import YOLOv10
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from ultralytics.utils.show_trace_pr import ShowPR
|
||||
from ultralytics.utils.show_pr import ShowPR
|
||||
# from trace_detect import run, _init_model
|
||||
import numpy as np
|
||||
image_ext = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
|
||||
@ -43,7 +43,7 @@ def get_image_list(path):
|
||||
|
||||
|
||||
def _init():
|
||||
model = YOLOv10('runs/detect/train/weights/best_gift_v10n.pt')
|
||||
model = YOLOv10('ckpts/20250620/best_gift_v10n.pt')
|
||||
return model
|
||||
|
||||
|
||||
@ -72,6 +72,7 @@ def get_trace_event(model, path):
|
||||
|
||||
def main(path):
|
||||
model = _init()
|
||||
tags_tmp, result_all_tmp = [], []
|
||||
tags, result_all = [], []
|
||||
classify = ['commodity', 'gift']
|
||||
for cla in classify:
|
||||
@ -79,12 +80,15 @@ def main(path):
|
||||
for root, dirs, files in os.walk(pre_pth):
|
||||
if not dirs:
|
||||
if cla == 'commodity':
|
||||
tags.append(0)
|
||||
tags_tmp.append(0)
|
||||
else:
|
||||
tags.append(1)
|
||||
tags_tmp.append(1)
|
||||
res_single = get_trace_event(model, root)
|
||||
result_all.append(res_single)
|
||||
spr = ShowPR(tags, result_all, title_name='yolov10n')
|
||||
result_all_tmp.append(res_single)
|
||||
for tag, result in zip(tags_tmp, result_all_tmp):
|
||||
tags += [tag]*len(result)
|
||||
result_all += result
|
||||
spr = ShowPR(tags, result_all, title_name='yolov10n', )
|
||||
# spr.change_precValue()
|
||||
spr.get_pic()
|
||||
|
||||
|
Reference in New Issue
Block a user