Files
ieemoo-ai-review/detecttracking/pipeline.py
2025-01-22 13:16:44 +08:00

270 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Sun Sep 29 08:59:21 2024
@author: ym
"""
import os
# import sys
import cv2
import pickle
import numpy as np
from pathlib import Path
from track_reid import yolo_resnet_tracker
from tracking.dotrack.dotracks_back import doBackTracks
from tracking.dotrack.dotracks_front import doFrontTracks
from tracking.utils.drawtracks import plot_frameID_y2, draw_all_trajectories
from utils.getsource import get_image_pairs, get_video_pairs
from tracking.utils.read_data import read_similar
def save_subimgs(imgdict, boxes, spath, ctype):
    """Write the cropped sub-image of every box that has one to ``spath``.

    Args:
        imgdict: Mapping from ``"{fid}_{bid}"`` keys to images.
        boxes: 2-D array-like of box rows; columns 7 and 8 of each row are
            read as the integers ``fid`` and ``bid`` that form the key.
        spath: Output directory (``str`` or ``pathlib.Path``).
        ctype: Prefix used in the output file name
            (``"{ctype}_{fid}_{bid}.png"``).

    Boxes whose key is missing from ``imgdict`` are skipped silently.
    """
    outdir = Path(spath)
    for box in boxes:
        fid, bid = int(box[7]), int(box[8])
        key = f"{fid}_{bid}"
        if key not in imgdict:
            continue
        imgpath = outdir / f"{ctype}_{fid}_{bid}.png"
        # Pass a plain string: not every OpenCV build accepts a Path object
        # as the file name argument.
        cv2.imwrite(str(imgpath), imgdict[key])
def _parse_barcode(evtname):
    """Return the trailing ``_``-separated field of ``evtname`` if it is an
    all-digit string of at least 8 characters, otherwise ``''``."""
    parts = evtname.split('_')
    tail = parts[-1]
    if len(parts) >= 2 and len(tail) >= 8 and tail.isdigit():
        return tail
    return ''


def _read_template(imgpath):
    """Read a background image with ``cv2.imread`` and fail loudly when it
    cannot be loaded (``cv2.imread`` returns ``None`` instead of raising)."""
    img = cv2.imread(imgpath)
    if img is None:
        raise FileNotFoundError(f"cannot read template image: {imgpath}")
    return img


def pipeline(
        eventpath,
        savepath,
        SourceType,
        weights
        ):
    """Run detection + tracking for one shopping event and store the results.

    Args:
        eventpath: Storage path of a single event.
        savepath: Root folder for all outputs. When falsy, the folder
            ``events_result`` next to this file is used.
        SourceType: ``"video"`` or ``"image"``; selects how the event sources
            are collected.
        weights: Weights path forwarded to ``yolo_resnet_tracker``.

    Side effects:
        * ``<savepath>/ShoppingDict_pkfile/<event name>.pickle`` receives the
          pickled ``ShoppingDict``.
        * ``<savepath>/Yolos_Tracking/<event name>/`` receives the tracker
          output, the sub-images of the residual tracks, ``front_y2.png``
          (front camera only) and ``trajectory.png``.

    Raises:
        ValueError: If ``SourceType`` is neither ``"video"`` nor ``"image"``.
        FileNotFoundError: If a trajectory background template cannot be read.
    """
    if SourceType == "video":
        vpaths = get_video_pairs(eventpath)
    elif SourceType == "image":
        vpaths = get_image_pairs(eventpath)
    else:
        # Previously this fell through and crashed later on an unbound name.
        raise ValueError(
            f"SourceType must be 'video' or 'image', got {SourceType!r}")

    optdict = {"weights": weights}
    event_tracks = []

    # Build the shopping event dictionary.
    evtname = Path(eventpath).stem
    barcode = _parse_barcode(evtname)

    # Output folder of this event.
    if not savepath:
        savepath = Path(__file__).resolve().parents[0] / "events_result"
    savepath_pipeline = Path(savepath) / "Yolos_Tracking" / evtname
    savepath_pipeline_subimgs = savepath_pipeline / "subimgs"

    # Location of the pickled ShoppingDict.
    savepath_spdict = Path(savepath) / "ShoppingDict_pkfile"
    savepath_spdict.mkdir(parents=True, exist_ok=True)
    pf_path = savepath_spdict / (str(evtname) + ".pickle")

    ShoppingDict = {"eventPath": eventpath,
                    "eventName": evtname,
                    "barcode": barcode,
                    "eventType": '',  # "input", "output", "other"
                    "frontCamera": {},
                    "backCamera": {},
                    "one2n": []
                    }

    procpath = Path(eventpath).joinpath('process.data')
    if procpath.is_file():
        SimiDict = read_similar(procpath)
        ShoppingDict["one2n"] = SimiDict['one2n']

    for vpath in vpaths:
        # Build the per-camera event dictionary.
        CameraEvent = {"cameraType": '',  # "front", "back"
                       "videoPath": '',
                       "imagePaths": [],
                       "yoloResnetTracker": [],
                       "tracking": [],
                       }

        # A list is a sequence of image paths, anything else a video path.
        if isinstance(vpath, list):
            CameraEvent["imagePaths"] = vpath
            bname = os.path.basename(vpath[0])
            savepath_pipeline_imgs = savepath_pipeline / "images"
        else:
            CameraEvent["videoPath"] = vpath
            bname = os.path.basename(vpath)
            savepath_pipeline_imgs = savepath_pipeline / str(Path(vpath).stem)

        # Two independent checks on purpose: when a name matches both rules
        # the later ("front") assignment wins, as it always did.
        if bname.split('_')[0] == "0" or bname.find('back') >= 0:
            CameraEvent["cameraType"] = "back"
        if bname.split('_')[0] == "1" or bname.find('front') >= 0:
            CameraEvent["cameraType"] = "front"

        # Output folders of this source.
        savepath_pipeline_imgs.mkdir(parents=True, exist_ok=True)
        savepath_pipeline_subimgs.mkdir(parents=True, exist_ok=True)

        # Yolo + Resnet + Tracker
        optdict["source"] = vpath
        optdict["save_dir"] = savepath_pipeline_imgs
        yrtOut = yolo_resnet_tracker(**optdict)
        CameraEvent["yoloResnetTracker"] = yrtOut

        # Gather the boxes of all frames and concatenate once instead of
        # re-copying the growing array for every frame. The empty (0, 9)
        # seed keeps the result shape/dtype when there are no frames.
        frame_boxes = [np.empty((0, 9), dtype=np.float64)]
        trackefeats = {}
        for frameDict in yrtOut:
            tboxes = frameDict["tboxes"]
            ffeats = frameDict["feats"]
            frame_boxes.append(np.array(tboxes))
            for box in tboxes:
                fid, bid = int(box[7]), int(box[8])
                key = f"{fid}_{bid}"
                trackefeats[key] = ffeats[key]
        trackerboxes = np.concatenate(frame_boxes, axis=0)

        # Tracking
        camera = CameraEvent["cameraType"]
        if camera == "back":
            vts = doBackTracks(trackerboxes, trackefeats)
            camkey = "backCamera"
        elif camera == "front":
            vts = doFrontTracks(trackerboxes, trackefeats)
            camkey = "frontCamera"
        else:
            # Camera type could not be derived from the file name: the
            # source is not tracked and not stored.
            continue
        vts.classify()
        event_tracks.append((camera, vts))
        CameraEvent["tracking"] = vts
        ShoppingDict[camkey] = CameraEvent

    with open(str(pf_path), 'wb') as f:
        pickle.dump(ShoppingDict, f)

    # Save the sub-images that belong to the residual tracks.
    for CamerType, vts in event_tracks:
        if len(vts.tracks) == 0:
            continue
        if CamerType == 'front':
            yolos = ShoppingDict["frontCamera"]["yoloResnetTracker"]
            ctype = 1
        else:
            yolos = ShoppingDict["backCamera"]["yoloResnetTracker"]
            ctype = 0

        imgdict = {}
        for y in yolos:
            imgdict.update(y["imgs"])

        for track in vts.Residual:
            # Residual items are either raw box arrays or objects exposing
            # their boxes through ``.boxes``.
            boxes = track if isinstance(track, np.ndarray) else track.boxes
            save_subimgs(imgdict, boxes, savepath_pipeline_subimgs, ctype)

    # Trajectory display module.
    # NOTE: the template paths are relative to the current working directory.
    illus = [None, None]
    for CamerType, vts in event_tracks:
        if len(vts.tracks) == 0:
            continue

        if CamerType == 'front':
            edgeline = _read_template(
                "./tracking/shopcart/cart_tempt/board_ftmp_line.png")
            img_tracking = draw_all_trajectories(
                vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
            illus[0] = img_tracking

            plt = plot_frameID_y2(vts)
            plt.savefig(os.path.join(savepath_pipeline, "front_y2.png"))
        elif CamerType == 'back':
            edgeline = _read_template(
                "./tracking/shopcart/cart_tempt/edgeline.png")
            img_tracking = draw_all_trajectories(
                vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
            illus[1] = img_tracking

    illus = [im for im in illus if im is not None]
    if illus:
        # Front image on the left, back image on the right.
        img_cat = np.concatenate(illus, axis=1)
        if len(illus) == 2:
            # Vertical separator between the two camera views.
            H, W = img_cat.shape[:2]
            cv2.line(img_cat, (W // 2, 0), (W // 2, int(H)),
                     (128, 128, 255), 3)

        trajpath = os.path.join(savepath_pipeline, "trajectory.png")
        cv2.imwrite(trajpath, img_cat)
def main():
    """Run ``pipeline()`` over the event folders found in ``evtdir``.

    The source type (``"video"`` or ``"image"``), the output folder and the
    weights file are configured below. Events whose processing raised are
    listed, one path per line, in ``error_events.txt`` inside the output
    folder.
    """
    parmDict = {}
    evtdir = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\images"
    parmDict["SourceType"] = "video"  # video, image
    parmDict["savepath"] = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\result"
    parmDict["weights"] = r'D:\DetectTracking\ckpts\best_cls10_0906.pt'

    # Only this many event folders are processed per run.
    max_events = 1

    evtdir = Path(evtdir)
    processed, errEvents = 0, []
    for item in evtdir.iterdir():
        if not item.is_dir():
            continue
        parmDict["eventpath"] = item
        try:
            pipeline(**parmDict)
        except Exception as e:
            # Best effort: keep going with the next event, but do not hide
            # why this one failed.
            errEvents.append(str(item))
            print(f"pipeline failed for {item}: {e!r}")
        processed += 1
        if processed == max_events:
            break

    # The output folder may not exist yet when pipeline() failed early.
    os.makedirs(parmDict["savepath"], exist_ok=True)
    errfile = os.path.join(parmDict["savepath"], 'error_events.txt')
    with open(errfile, 'w', encoding='utf-8') as f:
        for line in errEvents:
            f.write(line + '\n')
# Run the batch driver only when this file is executed as a script.
if __name__ == "__main__":
    main()