mirror of
https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git
synced 2025-08-23 23:50:25 +00:00
测试代码
This commit is contained in:
97
gift_demo.py
Normal file
97
gift_demo.py
Normal file
@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from ultralytics import YOLOv10
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from ultralytics.utils.show_trace_pr import ShowPR
|
||||
# from trace_detect import run, _init_model
|
||||
import numpy as np
|
||||
image_ext = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
|
||||
video_ext = ["mp4", "mov", "avi", "mkv"]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--demo", default="image", help="demo type, eg. image, video and webcam"
|
||||
)
|
||||
parser.add_argument("--config", default='../config/gift-1.5x.yml', help="model config file path")
|
||||
parser.add_argument("--model", default='../ckpts/nanodet_m_1.5x/model_best/nanodet_model_best.pth',
|
||||
help="model file path")
|
||||
parser.add_argument("--path", default="../../data_center/gift/objectdet/test/images",
|
||||
help="path to images or video")
|
||||
parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
|
||||
parser.add_argument(
|
||||
"--save_result",
|
||||
default="../../data_center/gift/objectdet/test/result",
|
||||
help="whether to save the inference result of image/video",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def get_image_list(path):
|
||||
image_names = []
|
||||
for maindir, subdir, file_name_list in os.walk(path):
|
||||
for filename in file_name_list:
|
||||
apath = os.path.join(maindir, filename)
|
||||
ext = os.path.splitext(apath)[1]
|
||||
if ext in image_ext:
|
||||
image_names.append(apath)
|
||||
return image_names
|
||||
|
||||
|
||||
def _init():
|
||||
model = YOLOv10('runs/detect/train/weights/best_gift_v10n.pt')
|
||||
return model
|
||||
|
||||
|
||||
def get_trace_event(model, path):
|
||||
res_single = []
|
||||
if os.path.isdir(path):
|
||||
files = get_image_list(path)
|
||||
else:
|
||||
files = [path]
|
||||
files.sort()
|
||||
for image_name in files:
|
||||
# all_box = run(model=model, stride=stride, pt=pt, source=image_name)
|
||||
# print(image_name)
|
||||
all_box = model.predict(image_name, save=False, imgsz=[224, 224], conf=0.1)
|
||||
# print(all_box[0].boxes.conf)
|
||||
all_box = np.array(all_box[0].boxes.conf.cpu())
|
||||
if len(all_box) == 0:
|
||||
res_single.append(0)
|
||||
else:
|
||||
res_single.append(all_box[-1])
|
||||
# if sum(res_single) == 0 and (not "commodity" in path):
|
||||
# with open('err.txt', 'w') as f:
|
||||
# f.write(path+'\n')
|
||||
return res_single
|
||||
|
||||
|
||||
def main(path):
|
||||
model = _init()
|
||||
tags, result_all = [], []
|
||||
classify = ['commodity', 'gift']
|
||||
for cla in classify:
|
||||
pre_pth = os.sep.join([path, cla])
|
||||
for root, dirs, files in os.walk(pre_pth):
|
||||
if not dirs:
|
||||
if cla == 'commodity':
|
||||
tags.append(0)
|
||||
else:
|
||||
tags.append(1)
|
||||
res_single = get_trace_event(model, root)
|
||||
result_all.append(res_single)
|
||||
spr = ShowPR(tags, result_all, title_name='yolov10n')
|
||||
# spr.change_precValue()
|
||||
spr.get_pic()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# path = '../data_center/gift/trace_subimgs/d50' # 间距为50时
|
||||
# path = '../data_center/gift/trace_subimgs/actual_test' # 永辉超市实测
|
||||
path = '../data_center/gift/gift_test' #yolov10单图测试
|
||||
# path = '../data_center/gift/trace_subimgs/tracluster' # tracluster方法过滤
|
||||
main(path)
|
Reference in New Issue
Block a user