This commit is contained in:
王庆刚
2024-11-04 18:06:52 +08:00
parent dfb2272a15
commit 5ecc1285d4
41 changed files with 2552 additions and 440 deletions

View File

@ -150,6 +150,8 @@ def yolo_resnet_tracker(
save_crop=False, # save cropped prediction boxes
nosave=False, # do not save images/videos
is_save_img = False,
is_save_video = True,
classes=None, # filter by class: --class 0, or --class 0 2 3
@ -166,9 +168,7 @@ def yolo_resnet_tracker(
vid_stride=1, # video frame-rate stride
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
# source = str(source)
# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
@ -209,7 +209,6 @@ def yolo_resnet_tracker(
for i, det in enumerate(pred): # per image
im0 = im0s.copy()
save_path = str(save_dir / Path(path).name) # im.jpg
s += '%gx%g ' % im.shape[2:] # print string
annotator = Annotator(im0.copy(), line_width=line_thickness, example=str(names))
@ -228,7 +227,14 @@ def yolo_resnet_tracker(
tracks = tracker.update(det_tracking, im0)
if len(tracks) == 0:
continue
tracks[:, 7] = dataset.frame
if dataset.mode == "video":
frameId = dataset.frame
else:
frameId = dataset.count
tracks[:, 7] = frameId
'''================== 1. 存储 dets/subimgs/features Dict ============='''
imgs, features = inference_image(im0, tracks)
@ -242,7 +248,7 @@ def yolo_resnet_tracker(
imgdict.update({int(bid): imgs[ii]}) # [f"img_{int(bid)}"] = imgs[i]
boxdict.update({int(bid): tracks[ii, :]}) # [f"box_{int(bid)}"] = tracks[i, :]
featdict.update({int(bid): features[ii, :]}) # [f"feat_{int(bid)}"] = features[i, :]
TracksDict[f"frame_{int(dataset.frame)}"] = {"imgs":imgdict, "boxes":boxdict, "feats":featdict}
TracksDict[f"frame_{int(frameId)}"] = {"imgs":imgdict, "boxes":boxdict, "feats":featdict}
track_boxes = np.concatenate([track_boxes, tracks], axis=0)
@ -256,20 +262,21 @@ def yolo_resnet_tracker(
elif id >=0 and cls!=0:
color = colors(int(id), True)
else:
color = colors(19, True) # 19为调色板的最后一个元素
color = colors(19, True) # 19为调色板的最后一个元素
annotator.box_label(xyxy, label, color=color)
# Save results (image and video with tracking)
'''====== Save results (image and video) ======'''
save_path = str(save_dir / Path(path).name) # 带有后缀名
im0 = annotator.result()
save_path_img, ext = os.path.splitext(save_path)
if save_img:
# if dataset.mode == 'image':
# imgpath = save_path_img + f"_{dataset}.png"
# else:
# imgpath = save_path_img + f"_{dataset.frame}.png"
# cv2.imwrite(Path(imgpath), im0)
if is_save_img:
save_path_img, ext = os.path.splitext(save_path)
if dataset.mode == 'image':
imgpath = save_path_img + ".png"
else:
imgpath = save_path_img + f"_{frameId}.png"
cv2.imwrite(Path(imgpath), im0)
if dataset.mode == 'video' and is_save_video:
if vid_path[i] != save_path: # new video
vid_path[i] = save_path
if isinstance(vid_writer[i], cv2.VideoWriter):
@ -396,8 +403,8 @@ def run(
imgshow = im0s.copy()
## ============================= tracking 功能只处理视频writed by WQG
if dataset.mode == 'image':
continue
# if dataset.mode == 'image':
# continue
with dt[0]:
im = torch.from_numpy(im).to(model.device)
@ -482,7 +489,14 @@ def run(
tracks = tracker.update(det_tracking, im0)
if len(tracks) == 0:
continue
tracks[:, 7] = dataset.frame
if dataset.mode == "video":
frameId = dataset.frame
else:
frameId = dataset.count
tracks[:, 7] = frameId
tracks[:, 7] = frameId
'''================== 1. 存储 dets/subimgs/features Dict ============='''
imgs, features = inference_image(im0, tracks)
@ -496,7 +510,7 @@ def run(
imgdict.update({int(bid): imgs[ii]}) # [f"img_{int(bid)}"] = imgs[i]
boxdict.update({int(bid): tracks[ii, :]}) # [f"box_{int(bid)}"] = tracks[i, :]
featdict.update({int(bid): features[ii, :]}) # [f"feat_{int(bid)}"] = features[i, :]
TracksDict[f"frame_{int(dataset.frame)}"] = {"imgs":imgdict, "boxes":boxdict, "feats":featdict}
TracksDict[f"frame_{int(frameId)}"] = {"imgs":imgdict, "boxes":boxdict, "feats":featdict}
track_boxes = np.concatenate([track_boxes, tracks], axis=0)
@ -535,7 +549,7 @@ def run(
if dataset.mode == 'image':
imgpath = save_path_img + f"_{dataset}.png"
else:
imgpath = save_path_img + f"_{dataset.frame}.png"
imgpath = save_path_img + f"_{frameId}.png"
cv2.imwrite(Path(imgpath), im0)
@ -664,23 +678,37 @@ print('=======')
def main(opt):
check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
p = r"D:\datasets\ym\永辉测试数据_202404\20240402"
optdict = vars(opt)
p = r"D:\datasets\ym"
p = r"D:\datasets\ym\exhibition\153112511_0_seek_105.mp4"
files = []
k = 0
if os.path.isdir(p):
files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))
for file in files:
optdict["source"] = file
run(**optdict)
k += 1
if k == 2:
if k == 1:
break
elif os.path.isfile(p):
optdict["source"] = p
run(**vars(opt))
def main_imgdir(opt):
check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
optdict = vars(opt)
optdict["project"] = r"\\192.168.1.28\share\realtime"
optdict["source"] = r"\\192.168.1.28\share\realtime\addReturn\add\1728978052624"
run(**optdict)
def main_loop(opt):
check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
@ -725,8 +753,9 @@ def main_loop(opt):
if __name__ == '__main__':
opt = parse_opt()
# main(opt)
main_loop(opt)
main(opt)
# main_imgdir(opt)
# main_loop(opt)