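# Web-view header captured together with this file (not part of the program):
# Files: detecttracking/pipeline.py
# Last modified: 2025-02-28 17:55:40 +08:00
# Size: 304 lines, 11 KiB; language: Python
# Page actions: Raw / Blame / History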

# Notice from the hosting page: this file contains ambiguous Unicode characters.
#
# Some of the Unicode characters in this file might be confused with other
# characters. If that is intentional, the warning can safely be ignored; the
# page's Escape button reveals them.

# -*- coding: utf-8 -*-
"""
Created on Sun Sep 29 08:59:21 2024
@author: ym
"""
import os
# import sys
import cv2
import pickle
import numpy as np
from pathlib import Path
from scipy.spatial.distance import cdist
from track_reid import yolo_resnet_tracker
from tracking.dotrack.dotracks_back import doBackTracks
from tracking.dotrack.dotracks_front import doFrontTracks
from tracking.utils.drawtracks import plot_frameID_y2, draw_all_trajectories
from utils.getsource import get_image_pairs, get_video_pairs
from tracking.utils.read_data import read_similar
def save_subimgs(imgdict, boxes, spath, ctype, featdict = None):
    """Write one image per box of a track, tagging file names with a similarity.

    The similarity is the cosine similarity between the feature of the current
    box and the feature of the previous box of the same track (boxes ordered by
    frame ID). It can be compared with the similarity recorded in the tracking
    sequence.

    Args:
        imgdict: mapping "{fid}_{bid}" -> image passed to cv2.imwrite.
        boxes: 2-D array with one row per box. Column 4 is read as the track
            ID, column 7 as the frame ID and column 8 as the box ID.
        spath: output directory; must support the ``/`` path operator.
        ctype: tag used as file-name prefix (pipeline() passes 0 for the back
            camera and 1 for the front camera).
        featdict: optional mapping "{fid}_{bid}" -> 1-D feature vector. When it
            is None, or lacks the key of either box of a pair, the file is
            written without the ``_sim`` suffix.

    Raises:
        KeyError: if a box has no entry in ``imgdict``.
    """
    # The default used to be dereferenced unconditionally for every box after
    # the first; treat "no features" as an empty mapping instead.
    featdict = {} if featdict is None else featdict

    # Order by frame ID so that "previous box" means the previous frame.
    boxes = boxes[np.argsort(boxes[:, 7])]

    prev_key = None
    for box in boxes:
        tid, fid, bid = int(box[4]), int(box[7]), int(box[8])
        key = f"{fid}_{bid}"
        stem = f"{ctype}_tid{tid}-{fid}-{bid}"

        if prev_key is not None and prev_key in featdict and key in featdict:
            dist = cdist(featdict[prev_key][None, :], featdict[key][None, :], "cosine")
            # Negative distances are clamped to 0, so the similarity never exceeds 1.
            simi = 1 - np.maximum(0.0, dist)[0][0]
            stem = f"{stem}_sim{simi:.2f}"

        img = imgdict[key]
        # Some OpenCV builds only accept a plain str file name, not a Path.
        cv2.imwrite(str(spath / f"{stem}.png"), img)
        prev_key = key
def save_subimgs_1(imgdict, boxes, spath, ctype, simidict = None):
    """Write one image per box, tagging file names with a precomputed similarity.

    Here the similarity is the one between the current box feature and the
    smooth_feat feature of its track; per the original note, this is how
    yolo_resnet_tracker records feature similarity.

    Args:
        imgdict: mapping "{fid}_{bid}" -> image passed to cv2.imwrite.
        boxes: 2-D array with one row per box. Column 4 is read as the track
            ID, column 7 as the frame ID and column 8 as the box ID.
        spath: output directory; must support the ``/`` path operator.
        ctype: tag used as file-name prefix.
        simidict: optional mapping "{fid}_{bid}" -> similarity value. Boxes
            without an entry are written without the ``_sim`` suffix.

    Raises:
        KeyError: if a box has no entry in ``imgdict``.
    """
    for box in boxes:
        tid, fid, bid = int(box[4]), int(box[7]), int(box[8])
        key = f"{fid}_{bid}"
        img = imgdict[key]

        stem = f"{ctype}_tid{tid}-{fid}-{bid}"
        if simidict is not None and key in simidict:
            stem = f"{stem}_sim{simidict[key]:.2f}"

        # Some OpenCV builds only accept a plain str file name, not a Path.
        cv2.imwrite(str(spath / f"{stem}.png"), img)
def _event_barcode(evtname):
    """Return the trailing barcode field of an event name, or '' if absent.

    The last '_'-separated field counts as a barcode only when the name has at
    least two fields and that field is an all-digit string of length >= 8.
    """
    fields = evtname.split('_')
    tail = fields[-1]
    if len(fields) >= 2 and len(tail) >= 8 and tail.isdigit():
        return tail
    return ''


def _gather_tracker_inputs(yrtOut):
    """Stack the per-frame tracker boxes and index their features by "{fid}_{bid}"."""
    # Seed with an empty (0, 9) array so the result keeps that shape when no
    # frame contributes any box; concatenate once instead of once per frame.
    stacked = [np.empty((0, 9), dtype=np.float64)]
    feats = {}
    for frameDict in yrtOut:
        tboxes = frameDict["tboxes"]
        ffeats = frameDict["feats"]
        stacked.append(np.array(tboxes))
        for i in range(len(tboxes)):
            fid, bid = int(tboxes[i, 7]), int(tboxes[i, 8])
            key = f"{fid}_{bid}"
            feats[key] = ffeats[key]
    return np.concatenate(stacked, axis=0), feats


def _load_template(imgpath):
    """Read a background template with cv2.imread, failing loudly if unreadable."""
    template = cv2.imread(imgpath)
    if template is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"cannot read template image: {imgpath}")
    return template


def pipeline(
        eventpath,
        savepath,
        SourceType,
        weights
        ):
    """Detect, track and export the results of a single shopping event.

    For every camera source of the event, yolo_resnet_tracker is run and its
    boxes/features are handed to the back- or front-camera track analysis.
    Afterwards the assembled ShoppingDict is pickled, one sub-image is saved
    per box of every entry in ``vts.Residual``, and trajectory illustrations
    are written.

    Args:
        eventpath: storage path of one event. Its stem is used as the event
            name; a trailing '_'-separated all-digit field of length >= 8 is
            taken as the barcode.
        savepath: root output directory. When falsy, "events_result" next to
            this file is used.
        SourceType: "video" or "image"; selects get_video_pairs() or
            get_image_pairs() to enumerate the sources.
        weights: forwarded to yolo_resnet_tracker() as ``weights``.

    Raises:
        ValueError: if ``SourceType`` is neither "video" nor "image".
        FileNotFoundError: if a trajectory background template cannot be read.
    """
    # Reject unknown source types before anything is created on disk.
    if SourceType == "video":
        vpaths = get_video_pairs(eventpath)
    elif SourceType == "image":
        vpaths = get_image_pairs(eventpath)
    else:
        raise ValueError(
            f"SourceType must be 'video' or 'image', got {SourceType!r}")

    optdict = {"weights": weights}
    event_tracks = []

    # ---- Event identity ----
    evtname = Path(eventpath).stem
    barcode = _event_barcode(evtname)

    # ---- Output folders for this event ----
    if not savepath:
        savepath = Path(__file__).resolve().parents[0] / "events_result"
    savepath_pipeline = Path(savepath) / Path("Yolos_Tracking") / evtname
    savepath_pipeline_subimgs = savepath_pipeline / Path("subimgs")

    # ---- Location of the ShoppingDict pickle file ----
    savepath_spdict = Path(savepath) / "ShoppingDict_pkfile"
    savepath_spdict.mkdir(parents=True, exist_ok=True)
    pf_path = savepath_spdict / Path(str(evtname) + ".pickle")

    # ====================== Build ShoppingDict ======================
    ShoppingDict = {"eventPath": eventpath,
                    "eventName": evtname,
                    "barcode": barcode,
                    "eventType": '',  # "input", "output", "other"
                    "frontCamera": {},
                    "backCamera": {},
                    "one2n": [],
                    }

    procpath = Path(eventpath).joinpath('process.data')
    if procpath.is_file():
        SimiDict = read_similar(procpath)
        ShoppingDict["one2n"] = SimiDict['one2n']

    for vpath in vpaths:
        # ================= 1. Build the camera event dict =================
        CameraEvent = {"cameraType": '',  # "front", "back"
                       "videoPath": '',
                       "imagePaths": [],
                       "yoloResnetTracker": [],
                       "tracking": [],
                       }
        # A list is a sequence of image paths; anything else is a video path.
        from_images = isinstance(vpath, list)
        if from_images:
            CameraEvent["imagePaths"] = vpath
            bname = os.path.basename(vpath[0])
        else:
            CameraEvent["videoPath"] = vpath
            bname = os.path.basename(vpath).split('.')[0]

        # Two independent checks: when both match, "front" wins.
        prefix = bname.split('_')[0]
        if prefix == "0" or 'back' in bname:
            CameraEvent["cameraType"] = "back"
        if prefix == "1" or 'front' in bname:
            CameraEvent["cameraType"] = "front"

        # ================= 2. Result folders for this source =================
        if from_images:
            savepath_pipeline_imgs = savepath_pipeline / Path("images")
        else:
            savepath_pipeline_imgs = savepath_pipeline / Path(str(Path(vpath).stem))
        savepath_pipeline_imgs.mkdir(parents=True, exist_ok=True)
        savepath_pipeline_subimgs.mkdir(parents=True, exist_ok=True)

        # ================= 3. Yolo + Resnet + Tracker =================
        optdict["source"] = vpath
        optdict["save_dir"] = savepath_pipeline_imgs
        yrtOut = yolo_resnet_tracker(**optdict)
        CameraEvent["yoloResnetTracker"] = yrtOut

        # ================= 4. Tracking =================
        # (1) Boxes and features consumed by the tracking module.
        trackerboxes, trackefeats = _gather_tracker_inputs(yrtOut)

        # (2) Tracking, back camera.
        if CameraEvent["cameraType"] == "back":
            vts = doBackTracks(trackerboxes, trackefeats)
            vts.classify()
            event_tracks.append(("back", vts))
            CameraEvent["tracking"] = vts
            ShoppingDict["backCamera"] = CameraEvent

        # (2) Tracking, front camera.
        if CameraEvent["cameraType"] == "front":
            vts = doFrontTracks(trackerboxes, trackefeats)
            vts.classify()
            event_tracks.append(("front", vts))
            CameraEvent["tracking"] = vts
            ShoppingDict["frontCamera"] = CameraEvent

    # ========================== Saving ==========================
    # (1) Save the ShoppingDict of the event.
    with open(str(pf_path), 'wb') as f:
        pickle.dump(ShoppingDict, f)

    # (2) Save the sub-images of the tracks output by tracking, recording
    #     the feature similarity in the file names.
    for CamerType, vts in event_tracks:
        if len(vts.tracks) == 0:
            continue
        if CamerType == 'front':
            yolos = ShoppingDict["frontCamera"]["yoloResnetTracker"]
            ctype = 1
        if CamerType == 'back':
            yolos = ShoppingDict["backCamera"]["yoloResnetTracker"]
            ctype = 0

        imgdict, featdict = {}, {}
        for y in yolos:
            imgdict.update(y["imgs"])
            featdict.update(y["feats"])

        for track in vts.Residual:
            # Residual entries are either raw box arrays or objects exposing slt_boxes.
            boxes = track if isinstance(track, np.ndarray) else track.slt_boxes
            save_subimgs(imgdict, boxes, savepath_pipeline_subimgs, ctype, featdict)

    # (3) Draw and save the trajectories: front goes left, back goes right.
    illus = [None, None]
    for CamerType, vts in event_tracks:
        if len(vts.tracks) == 0:
            continue

        if CamerType == 'front':
            edgeline = _load_template("./tracking/shopcart/cart_tempt/board_ftmp_line.png")
            img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
            illus[0] = img_tracking

            plt = plot_frameID_y2(vts)
            plt.savefig(os.path.join(savepath_pipeline, "front_y2.png"))

        if CamerType == 'back':
            edgeline = _load_template("./tracking/shopcart/cart_tempt/edgeline.png")
            img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
            illus[1] = img_tracking

    illus = [im for im in illus if im is not None]
    if illus:
        img_cat = np.concatenate(illus, axis = 1)
        if len(illus) == 2:
            # Vertical separator between the two camera views.
            H, W = img_cat.shape[:2]
            cv2.line(img_cat, (int(W/2), 0), (int(W/2), int(H)), (128, 128, 255), 3)

        trajpath = os.path.join(savepath_pipeline, "trajectory.png")
        cv2.imwrite(trajpath, img_cat)
def main():
    """Run pipeline() over the event folders of a hard-coded root directory.

    The source type ("video" or "image"), the output directory and the model
    weights are configured below. Events for which pipeline() raises are
    collected and written, one path per line, to "error_events.txt" inside the
    output directory.
    """
    parmDict = {}
    evtdir = r"D:\全实时\202502"
    parmDict["SourceType"] = "video"  # video, image
    parmDict["savepath"] = r"D:\全实时\202502\result"
    parmDict["weights"] = r'D:\DetectTracking\ckpts\best_cls10_0906.pt'

    evtdir = Path(evtdir)
    k, errEvents = 0, []
    for item in evtdir.iterdir():
        if item.is_dir():
            # NOTE(review): the iterated folder is overridden by one fixed
            # event and the loop stops after the first event; this looks like
            # a debugging setup - confirm before using it for batch runs.
            item = evtdir/Path("20250228-160049-188_6921168558018_6921168558018")
            parmDict["eventpath"] = item

            try:
                pipeline(**parmDict)
            except Exception as e:
                # Best effort: remember the failed event and keep going, but
                # do not lose the reason for the failure.
                errEvents.append(str(item))
                print(f"pipeline failed for {item}: {e!r}")
            k += 1
            if k == 1:
                break

    # pipeline() may have failed before creating the output directory, so make
    # sure it exists before writing the error report into it.
    os.makedirs(parmDict["savepath"], exist_ok=True)
    errfile = os.path.join(parmDict["savepath"], 'error_events.txt')
    with open(errfile, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in errEvents)
# Run the batch driver only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()