Files
ieemoo-ai-gift/gift_demo.py
2025-06-24 16:57:16 +08:00

97 lines
3.3 KiB
Python

import argparse
import os
import time
from ultralytics import YOLOv10
import cv2
import torch
from ultralytics.utils.show_trace_pr import ShowPR
# from trace_detect import run, _init_model
import numpy as np
# File extensions (lowercase, with leading dot) that get_image_list accepts
# as images.
image_ext = [
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".png",
]
# Video container extensions.
# NOTE(review): unlike image_ext these have no leading dot, so they will not
# match the output of os.path.splitext() directly — confirm against any user
# of this list before relying on it.
video_ext = [
    "mp4",
    "mov",
    "avi",
    "mkv",
]
def parse_args():
    """Build the demo's command-line parser and parse ``sys.argv``.

    Options (all optional, each with a default):
        --demo         demo type: image, video or webcam
        --config       model config file path
        --model        model checkpoint path
        --path         path to images or a video
        --camid        integer camera id for the webcam demo
        --save_result  location for saved inference results

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--demo",
        default="image",
        help="demo type, eg. image, video and webcam",
    )
    cli.add_argument(
        "--config",
        default="../config/gift-1.5x.yml",
        help="model config file path",
    )
    cli.add_argument(
        "--model",
        default="../ckpts/nanodet_m_1.5x/model_best/nanodet_model_best.pth",
        help="model file path",
    )
    cli.add_argument(
        "--path",
        default="../../data_center/gift/objectdet/test/images",
        help="path to images or video",
    )
    cli.add_argument(
        "--camid",
        type=int,
        default=0,
        help="webcam demo camera id",
    )
    cli.add_argument(
        "--save_result",
        default="../../data_center/gift/objectdet/test/result",
        help="whether to save the inference result of image/video",
    )
    return cli.parse_args()
def get_image_list(path, extensions=None):
    """Recursively collect the image files found under ``path``.

    Args:
        path: Root directory to traverse with ``os.walk``. A path that does
            not exist yields an empty list.
        extensions: Optional iterable of accepted file extensions, each
            including the leading dot (e.g. ``".jpg"``). When omitted, the
            module-level ``image_ext`` list is used.

    Returns:
        A list of file paths (directory joined with file name) whose
        extension is accepted, in ``os.walk`` traversal order.
    """
    # Resolved lazily so that an explicit ``extensions`` never touches the
    # module-level default.
    wanted = image_ext if extensions is None else extensions
    allowed = {ext.lower() for ext in wanted}

    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            # Compare case-insensitively: files named "x.JPG" or "x.Png" are
            # images too and were previously skipped without any warning.
            if os.path.splitext(filename)[1].lower() in allowed:
                image_names.append(os.path.join(maindir, filename))
    return image_names
def _init():
    """Construct and return the YOLOv10 model from the fixed weights path."""
    return YOLOv10('runs/detect/train/weights/best_gift_v10n.pt')
def get_trace_event(model, path):
    """Run the detector on one image, or on every image under a directory.

    Args:
        model: Detector exposing ``predict(source, save=, imgsz=, conf=)``;
            the first element of its return value must provide
            ``boxes.conf``.
        path: Either a single image path or a directory. A directory is
            expanded with ``get_image_list``.

    Returns:
        A list with one entry per processed image, in sorted path order:
        ``0`` when the prediction contains no boxes, otherwise the last
        element of that image's confidence array.
    """
    targets = sorted(get_image_list(path)) if os.path.isdir(path) else [path]

    scores_per_image = []
    for target in targets:
        prediction = model.predict(target, save=False, imgsz=[224, 224], conf=0.1)
        confidences = np.array(prediction[0].boxes.conf.cpu())
        # NOTE(review): this keeps the LAST confidence, not the maximum.
        # Whether that is the weakest or strongest detection depends on how
        # the model orders its boxes — confirm this is intended.
        scores_per_image.append(0 if len(confidences) == 0 else confidences[-1])
    return scores_per_image
def main(path):
    """Score every leaf directory of the test set and hand the result to ShowPR.

    ``path`` is expected to contain the sub-directories ``commodity`` and
    ``gift``. Each directory below them that has no sub-directories of its
    own is treated as one sample: it receives tag ``0`` (commodity) or ``1``
    (gift), and its per-image scores come from ``get_trace_event``.

    Args:
        path: Root directory of the test set.
    """
    model = _init()
    tags, result_all = [], []

    for cla in ('commodity', 'gift'):
        label = 0 if cla == 'commodity' else 1
        class_root = os.sep.join([path, cla])
        for root, dirs, _files in os.walk(class_root):
            if dirs:
                # Not a leaf directory; its children are visited later.
                continue
            tags.append(label)
            result_all.append(get_trace_event(model, root))

    # One tag and one score list per leaf directory.
    spr = ShowPR(tags, result_all, title_name='yolov10n')
    spr.get_pic()
if __name__ == "__main__":
    # Alternative datasets used during development:
    #   '../data_center/gift/trace_subimgs/d50'          - spacing of 50
    #   '../data_center/gift/trace_subimgs/actual_test'  - field test at a Yonghui supermarket
    #   '../data_center/gift/trace_subimgs/tracluster'   - filtered with the tracluster method
    # Active dataset: YOLOv10 single-image test.
    path = '../data_center/gift/gift_test'
    main(path)