add yolo v10 and modify pipeline

This commit is contained in:
王庆刚
2025-03-28 13:19:54 +08:00
parent 183299c06b
commit 798c596acc
471 changed files with 19109 additions and 7342 deletions

View File

@ -64,7 +64,10 @@ from hands.hand_inference import hand_pose
from contrast.feat_extract.config import config as conf
from contrast.feat_extract.inference import FeatsInterface
from ultralytics import YOLOv10
ReIDEncoder = FeatsInterface(conf)
print(f'load model {conf.testbackbone} in {Path(__file__).stem}')
IMG_FORMATS = '.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff', '.webp', '.pfm' # include image suffixes
VID_FORMATS = '.asf', '.avi', '.gif', '.m4v', '.mkv', '.mov', '.mp4', '.mpeg', '.mpg', '.ts', '.wmv' # include video suffixes
@ -131,12 +134,158 @@ def init_trackers(tracker_yaml = None, bs=1):
trackers = []
for _ in range(bs):
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
if cfg.with_reid:
tracker.encoder = ReIDEncoder
trackers.append(tracker)
return trackers
'''=============== used in pipeline.py for Yolov10 =================='''
def yolov10_resnet_tracker(
weights = ROOT / 'ckpts/best_v10s_width0375_1205.pt', # model path or triton URL
source = '', # file/dir/URL/glob/screen/0(webcam)
save_dir = '',
is_save_img = True,
is_save_video = True,
tracker_yaml = "./tracking/trackers/cfg/botsort.yaml",
line_thickness=3, # bounding box thickness (pixels)
hide_labels=False, # hide labels
):
## load a custom model
model = YOLOv10(weights)
'''=============== used in pipeline.py =================='''
custom = {"conf": 0.25, "batch": 1, "save": False, "mode": "predict"}
kwargs = {"save": True, "imgsz": 640, "conf": 0.1}
args = {**model.overrides, **custom, **kwargs}
predictor = model.task_map[model.task]["predictor"](overrides=args, _callbacks=model.callbacks)
vid_path, vid_writer = None, None
tracker = init_trackers(tracker_yaml)[0]
yoloResnetTracker = []
for i, result in enumerate(predictor.stream_inference(source)):
datamode = predictor.dataset.mode
det = result.boxes.data.cpu().numpy()
im0 = result.orig_img
names = result.names
path = result.path
im_array = result.plot()
## to do tracker.update()
det_tracking = Boxes(det, im0.shape)
tracks, outfeats = tracker.update(det_tracking, im0)
if datamode == "video":
frameId = predictor.dataset.frame
elif datamode == "image":
frameId = predictor.dataset.count
annotator = Annotator(im0.copy(), line_width=line_thickness, example=str(names))
simdict, simdict1 = {}, {}
for fid, bid, mfeat, cfeat, features in outfeats:
if mfeat is not None and cfeat is not None:
simi = 1 - np.maximum(0.0, cdist(mfeat[None, :], cfeat[None, :], "cosine"))[0][0]
simdict.update({f"{int(frameId)}_{int(bid)}":simi})
if cfeat is not None and len(features)>=2:
mfeat = features[-2]
simi = 1 - np.maximum(0.0, cdist(mfeat[None, :], cfeat[None, :], "cosine"))[0][0]
simdict1.update({f"{int(frameId)}_{int(bid)}":simi})
if len(tracks) > 0:
tracks[:, 7] = frameId
# trackerBoxes = np.concatenate([trackerBoxes, tracks], axis=0)
'''================== 1. 存储 dets/subimgs/features Dict ============='''
imgs, features = ReIDEncoder.inference(im0, tracks)
imgdict, featdict = {}, {}
for ii, bid in enumerate(tracks[:, 8]):
featdict.update({f"{int(frameId)}_{int(bid)}": features[ii, :]}) # [f"feat_{int(bid)}"] = features[i, :]
imgdict.update({f"{int(frameId)}_{int(bid)}": imgs[ii]})
frameDict = {"path": path,
"fid": int(frameId),
"bboxes": det,
"tboxes": tracks,
"imgs": imgdict,
"feats": featdict,
"featsimi": simdict, # 当前 box 特征和该轨迹 smooth_feat 特征的相似度
"featsimi1": simdict1 # 当前 box 特征和该轨迹前一个 box 特征的相似度
}
yoloResnetTracker.append(frameDict)
# imgs, features = inference_image(im0, tracks)
# TrackerFeats = np.concatenate([TrackerFeats, features], axis=0)
'''================== 2. 提取手势位置 ==================='''
for *xyxy, id, conf, cls, fid, bid in reversed(tracks):
name = ('' if id==-1 else f'id:{int(id)} ') + names[int(cls)]
if f"{int(frameId)}_{int(bid)}" in simdict.keys():
sim = simdict[f"{int(frameId)}_{int(bid)}"]
label = f"{name} {sim:.2f}"
else:
label = None if hide_labels else name
# label = None if hide_labels else (name if hide_conf else f'{name} {conf:.1f}')
if id >=0 and cls==0:
color = colors(int(cls), True)
elif id >=0 and cls!=0:
color = colors(int(id), True)
else:
color = colors(19, True) # 19为调色板的最后一个元素
annotator.box_label(xyxy, label, color=color)
'''====== Save results (image and video) ======'''
# save_path = str(save_dir / Path(path).name) # 带有后缀名
im0 = annotator.result()
if is_save_img:
save_path_img = str(save_dir / Path(path).stem)
if datamode == 'image':
imgpath = save_path_img + ".png"
if datamode == 'video' :
imgpath = save_path_img + f"_{frameId}.png"
cv2.imwrite(Path(imgpath), im0)
# if dataset.mode == 'video' and is_save_video:
if is_save_video:
if datamode == 'video':
video_path = str(save_dir / Path(path).stem) + '.mp4' # 带有后缀名
else:
videoname = str(Path(path).stem).split('_')[0] + '.mp4'
video_path = str(save_dir / videoname)
if vid_path != video_path: # new video
vid_path = video_path
vid_cap = predictor.dataset.cap
if isinstance(vid_writer, cv2.VideoWriter):
vid_writer.release() # release previous video writer
if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 25, im0.shape[1], im0.shape[0]
## for image rotating in dataloader.LoadImages.__next__()
w, h = im0.shape[1], im0.shape[0]
video_path = str(Path(video_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
vid_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer.write(im0)
return yoloResnetTracker
'''=============== used in pipeline.py for Yolov5 =================='''
@smart_inference_mode()
def yolo_resnet_tracker(
weights=ROOT / 'yolov5s.pt', # model path or triton URL
@ -660,8 +809,6 @@ def run(
def parse_opt():
modelpath = ROOT / 'ckpts/best_cls10_0906.pt' # 'ckpts/best_15000_0908.pt', 'ckpts/yolov5s.pt', 'ckpts/best_20000_cls30.pt, best_yolov5m_250000'
'''datapath为视频文件目录或视频文件'''
datapath = r"D:/datasets/ym/videos/标记视频/" # ROOT/'data/videos', ROOT/'data/images' images
# datapath = r"D:\datasets\ym\highvalue\videos"
@ -714,7 +861,7 @@ def find_video_imgs(root_dir):
def main():
def main_v5():
'''
run(): 单张图像或单个视频文件的推理,不支持图像序列,
'''
@ -733,10 +880,10 @@ def main():
# p = r"D:\exhibition\images\153112511_0_seek_105.mp4"
# p = r"D:\exhibition\images\image"
p = r"D:\全实时\202502\tracker\1_1740891284792.mp4"
optdict["project"] = r"D:\全实时\202502\tracker"
# optdict["project"] = r"D:\exhibition\result"
p = r"D:\datasets\ym\后台数据\unzip\20250310-175352-741"
optdict["project"] = r"D:\work\result"
optdict["weights"] = ROOT / 'ckpts/best_cls10_0906.pt'
if os.path.isdir(p):
files = find_video_imgs(p)
k = 0
@ -745,17 +892,39 @@ def main():
run(**optdict)
k += 1
if k == 1:
if k == 2:
break
elif os.path.isfile(p):
optdict["source"] = p
run(**optdict)
def main_v10():
datapath = r'D:\datasets\ym\后台数据\unzip\20250310-175352-741\0.mp4'
savepath = r'D:\work\result'
savepath = savepath / Path(str(Path(datapath).stem))
if not savepath.exists():
savepath.mkdir(parents=True, exist_ok=True)
weightpath = ROOT / 'ckpts/best_v10s_width0375_1205.pt'
optdict = {}
optdict["weights"] = weightpath
optdict["source"] = datapath
optdict["save_dir"] = savepath
optdict["is_save_img"] = True
optdict["is_save_video"] = True
yrtOut = yolov10_resnet_tracker(**optdict)
if __name__ == '__main__':
main()
# main_v5()
main_v10()