"""Evaluate a YOLOv10 gift detector on per-sample image folders and plot a PR curve.

Every leaf directory under ``<path>/commodity`` or ``<path>/gift`` is one sample.
Each image in a sample is scored by the detector, and the per-image scores plus
the sample's ground-truth tag (0 = commodity, 1 = gift) are passed to ``ShowPR``.
"""
import argparse
import os
import time

import cv2
import numpy as np
import torch
from ultralytics import YOLOv10
from ultralytics.utils.show_trace_pr import ShowPR

# Extensions are compared against ``os.path.splitext`` output, which keeps the
# leading dot, so every entry here must include it (and be lower-case).
image_ext = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
video_ext = [".mp4", ".mov", ".avi", ".mkv"]

# Default detector checkpoint; relative paths resolve against the current
# working directory.
DEFAULT_WEIGHTS = 'ckpts/20250620/best_gift_v10n.pt'


def parse_args():
    """Parse the demo's command-line options.

    NOTE(review): nothing in this script calls ``parse_args``; ``main`` receives
    its path from the ``__main__`` block instead. The defaults (NanoDet config
    and checkpoint) look carried over from another demo — confirm before use.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("--config", default='../config/gift-1.5x.yml', help="model config file path")
    parser.add_argument("--model", default='../ckpts/nanodet_m_1.5x/model_best/nanodet_model_best.pth', help="model file path")
    parser.add_argument("--path", default="../../data_center/gift/objectdet/test/images", help="path to images or video")
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        default="../../data_center/gift/objectdet/test/result",
        help="whether to save the inference result of image/video",
    )
    return parser.parse_args()


def get_image_list(path):
    """Return the paths of all image files found recursively under ``path``.

    The extension check is case-insensitive, so ``photo.JPG`` is included.
    """
    image_names = []
    for maindir, _subdirs, file_names in os.walk(path):
        for filename in file_names:
            if os.path.splitext(filename)[1].lower() in image_ext:
                image_names.append(os.path.join(maindir, filename))
    return image_names


def _init(weights=DEFAULT_WEIGHTS):
    """Load and return the YOLOv10 detector from the ``weights`` checkpoint."""
    return YOLOv10(weights)


def get_trace_event(model, path, imgsz=(224, 224), conf=0.1):
    """Score every image of one sample with ``model``.

    Args:
        model: a loaded YOLOv10 model (see ``_init``).
        path: a directory (searched recursively for images) or one image file.
        imgsz: inference size forwarded to ``model.predict``.
        conf: confidence threshold forwarded to ``model.predict``; boxes below
            it are not returned.

    Returns:
        A list of floats, one per image in sorted path order. Each value is the
        highest box confidence in that image, or 0.0 when no box passes ``conf``.
    """
    files = get_image_list(path) if os.path.isdir(path) else [path]
    scores = []
    for image_name in sorted(files):
        results = model.predict(image_name, save=False, imgsz=list(imgsz), conf=conf)
        confs = np.asarray(results[0].boxes.conf.cpu())
        # Take the max instead of the last element: the image score must not
        # depend on box order. Ultralytics typically returns boxes sorted by
        # descending confidence, which made ``confs[-1]`` the weakest detection.
        scores.append(float(confs.max()) if confs.size else 0.0)
    return scores


def main(path):
    """Score every leaf sample folder under ``path`` and plot a PR curve.

    Expects ``path/commodity`` and ``path/gift`` subtrees. Every directory in
    them with no subdirectories is one sample, tagged 0 (commodity) or 1 (gift).
    """
    model = _init()
    tags, result_all = [], []
    for cla in ('commodity', 'gift'):
        for root, dirs, _files in os.walk(os.path.join(path, cla)):
            if dirs:
                continue  # only leaf directories hold a sample's images
            tags.append(0 if cla == 'commodity' else 1)
            result_all.append(get_trace_event(model, root))
    spr = ShowPR(tags, result_all, title_name='yolov10n')
    spr.get_pic()


if __name__ == "__main__":
    # path = '../data_center/gift/trace_subimgs/d50'  # samples taken at a spacing of 50
    path = '../data_center/gift/trace_subimgs/actual_test'  # field test at a Yonghui supermarket
    # path = '../data_center/gift/trace_subimgs/tracluster'  # filtered with the tracluster method
    main(path)