# Mirror of https://gitee.com/nanjing-yimao-information/ieemoo-ai-gift.git
# (synced 2025-08-18 13:20:25 +00:00; 97 lines, 3.2 KiB, Python)
import argparse
|
|
import os
|
|
import time
|
|
from ultralytics import YOLOv10
|
|
import cv2
|
|
import torch
|
|
|
|
from ultralytics.utils.show_trace_pr import ShowPR
|
|
# from trace_detect import run, _init_model
|
|
import numpy as np
|
|
# File extensions (including the leading dot, as returned by os.path.splitext)
# that get_image_list treats as images.
image_ext = [
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".png",
]
# Video container extensions.
# NOTE(review): unlike image_ext these entries have no leading dot, so they do
# not match os.path.splitext() output directly; nothing in this script reads
# this list.
video_ext = [
    "mp4",
    "mov",
    "avi",
    "mkv",
]
|
|
|
|
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``demo``, ``config``, ``model``, ``path``,
        ``camid`` and ``save_result`` attributes.

    NOTE(review): ``main`` in this script does not call this function, and the
    ``--config`` / ``--model`` defaults point at NanoDet files rather than the
    YOLOv10 checkpoint loaded by ``_init``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("--config", default='../config/gift-1.5x.yml', help="model config file path")
    parser.add_argument("--model", default='../ckpts/nanodet_m_1.5x/model_best/nanodet_model_best.pth',
                        help="model file path")
    parser.add_argument("--path", default="../../data_center/gift/objectdet/test/images",
                        help="path to images or video")
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        default="../../data_center/gift/objectdet/test/result",
        # The value is a directory path, not a boolean flag.
        help="directory to save the inference result of image/video",
    )
    return parser.parse_args(argv)
|
|
def get_image_list(path, exts=None):
    """Recursively collect image files below ``path``.

    Args:
        path: Directory to walk. A path that is missing or is not a directory
            yields an empty list, because ``os.walk`` ignores errors by default.
        exts: Iterable of accepted extensions including the leading dot
            (e.g. ``".jpg"``). Defaults to the module-level ``image_ext``.
            Matching is case-insensitive.

    Returns:
        List of ``os.path.join(dirpath, filename)`` paths in ``os.walk`` order.
    """
    if exts is None:
        exts = image_ext
    # Compare case-insensitively so files such as "IMG_001.JPG" are not
    # silently skipped.
    wanted = {ext.lower() for ext in exts}
    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            if os.path.splitext(filename)[1].lower() in wanted:
                image_names.append(os.path.join(maindir, filename))
    return image_names
|
|
|
|
def _init(weights='ckpts/20250620/best_gift_v10n.pt'):
    """Load the YOLOv10 gift detector.

    Args:
        weights: Checkpoint path passed to ``YOLOv10``. Defaults to the
            checkpoint this script was originally hard-wired to, so existing
            callers (``_init()``) are unaffected.

    Returns:
        The constructed ``YOLOv10`` model.
    """
    return YOLOv10(weights)
|
|
|
|
def get_trace_event(model, path):
    """Score every image of one trace with the detector.

    Args:
        model: Detector exposing ``predict`` (e.g. the object returned by
            ``_init``).
        path: Either a directory, whose images are gathered recursively with
            ``get_image_list``, or the path of a single image.

    Returns:
        One entry per image, in sorted path order: ``0`` when the prediction
        has no boxes, otherwise the last value of ``boxes.conf`` (a numpy
        scalar).
    """
    candidates = get_image_list(path) if os.path.isdir(path) else [path]
    scores = []
    for img_path in sorted(candidates):
        # Inference at 224x224 with a 0.1 confidence threshold; nothing saved.
        preds = model.predict(img_path, save=False, imgsz=[224, 224], conf=0.1)
        confs = np.array(preds[0].boxes.conf.cpu())
        # NOTE(review): [-1] takes the last confidence; if boxes.conf is
        # ordered by descending confidence this is the weakest detection
        # rather than the strongest - confirm [-1] (vs. max) is intended.
        scores.append(confs[-1] if len(confs) else 0)
    return scores
|
|
|
|
def main(path):
    """Evaluate the gift detector on a labelled dataset and draw a PR plot.

    ``path`` is expected to contain ``commodity`` (label 0) and ``gift``
    (label 1) sub-folders. Every leaf directory (one without sub-directories)
    beneath them is scored with ``get_trace_event``. Its label is appended to
    ``tags`` and its per-image scores to ``result_all``, at the same index.
    Both lists are then passed to ``ShowPR``, and its ``get_pic()`` is called
    (presumably to render the PR figure).

    NOTE(review): a missing category folder is skipped silently (``os.walk``
    ignores errors by default), and a leaf folder without images contributes
    an empty score list.
    """
    detector = _init()
    # Ground-truth label per category folder; iterated in insertion order.
    label_of = {'commodity': 0, 'gift': 1}
    tags, result_all = [], []
    for category, label in label_of.items():
        category_dir = os.sep.join([path, category])
        for root, subdirs, _files in os.walk(category_dir):
            if subdirs:
                continue  # only leaf directories hold a single trace
            tags.append(label)
            result_all.append(get_trace_event(detector, root))
    ShowPR(tags, result_all, title_name='yolov10n').get_pic()
|
|
|
|
if __name__ == "__main__":
    # Alternative evaluation sets:
    #   '../data_center/gift/trace_subimgs/d50'         - spacing of 50
    #   '../data_center/gift/trace_subimgs/tracluster'  - filtered with the tracluster method
    # Active set: real-world test at Yonghui Supermarket.
    dataset_root = '../data_center/gift/trace_subimgs/actual_test'
    main(dataset_root)