更新 detacttracking
This commit is contained in:
147
detecttracking/stream_pipeline.py
Normal file
147
detecttracking/stream_pipeline.py
Normal file
@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Tuesday Jan 14 2025
|
||||
|
||||
@author: liujiawei
|
||||
|
||||
@description: 读取网络图片,并优化轨迹,截取子图
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# from pipeline import pipeline
|
||||
from detecttracking.tracking import traclus as tr
|
||||
# from track_reid import parse_opt
|
||||
from detecttracking.track_reid import yolo_resnet_tracker
|
||||
from detecttracking.tracking.dotrack.dotracks_back import doBackTracks
|
||||
from PIL import Image
|
||||
|
||||
def save_event_subimgs(imgs, bboxes):
|
||||
img_list = {}
|
||||
for i, box in enumerate(bboxes):
|
||||
x1, y1, x2, y2, tid, score, cls, fid, bid = box
|
||||
|
||||
img_list[int(fid)] = imgs[fid][int(y1):int(y2), int(x1):int(x2), :]
|
||||
|
||||
return img_list
|
||||
|
||||
|
||||
def get_optimized_bboxes(event_tracks):
|
||||
vts_back = event_tracks
|
||||
points = []
|
||||
labels = []
|
||||
for track in vts_back.Residual:
|
||||
for ele in track.boxes:
|
||||
points.append([int(ele[2]), int(ele[3])])
|
||||
labels.append(int(ele[4])) # track_id
|
||||
points = np.array(points)
|
||||
|
||||
partitions, indices = tr.partition(points, progress_bar=False, w_perpendicular=100, w_angular=10)
|
||||
|
||||
bboxes_opt = []
|
||||
for track in vts_back.Residual:
|
||||
for i in indices:
|
||||
if i >= len(track.boxes): continue
|
||||
if labels[i] == track.boxes[i][4]:
|
||||
bboxes_opt.append(track.boxes[i])
|
||||
|
||||
return bboxes_opt
|
||||
|
||||
def get_tracking_info(
|
||||
vpath,
|
||||
resnetModel,
|
||||
yoloModel,
|
||||
SourceType = "video", # video
|
||||
stdfeat_path = None
|
||||
):
|
||||
optdict = {}
|
||||
|
||||
optdict["weights"] = './detecttracking/tracking/ckpts/best_cls10_0906.pt'
|
||||
optdict["yoloModel"] = yoloModel
|
||||
optdict["resnetModel"] = resnetModel
|
||||
optdict["is_save_img"] = False
|
||||
optdict["is_save_video"] = False
|
||||
|
||||
event_tracks = []
|
||||
video_frames = {}
|
||||
|
||||
'''Yolo + Resnet + Tracker'''
|
||||
optdict["source"] = vpath
|
||||
optdict["video_frames"] = video_frames
|
||||
optdict["is_annotate"] = False
|
||||
|
||||
yrtOut = yolo_resnet_tracker(**optdict)
|
||||
|
||||
trackerboxes = np.empty((0, 9), dtype=np.float64)
|
||||
trackefeats = {}
|
||||
for frameDict in yrtOut:
|
||||
tboxes = frameDict["tboxes"]
|
||||
ffeats = frameDict["feats"]
|
||||
|
||||
trackerboxes = np.concatenate((trackerboxes, np.array(tboxes)), axis=0)
|
||||
for i in range(len(tboxes)):
|
||||
fid, bid = int(tboxes[i, 7]), int(tboxes[i, 8])
|
||||
trackefeats.update({f"{fid}_{bid}": ffeats[f"{fid}_{bid}"]})
|
||||
|
||||
|
||||
vts = doBackTracks(trackerboxes, trackefeats)
|
||||
vts.classify()
|
||||
event_tracks.append(("back", vts))
|
||||
|
||||
return event_tracks, video_frames
|
||||
|
||||
def stream_pipeline(stream_dict, resnetModel, yoloModel):
|
||||
parmDict = {}
|
||||
parmDict["vpath"] = stream_dict["video"]
|
||||
|
||||
# parmDict["savepath"] = os.path.join('pipeline_output', info_dict["barcode"])
|
||||
parmDict["SourceType"] = "video" # video, image
|
||||
parmDict["stdfeat_path"] = None
|
||||
|
||||
event_tracks, video_frames = get_tracking_info(**parmDict, resnetModel=resnetModel, yoloModel=yoloModel)
|
||||
bboxes_opt = get_optimized_bboxes(event_tracks[0][1])
|
||||
subimg_dict = save_event_subimgs(video_frames, bboxes_opt)
|
||||
|
||||
sub_images = []
|
||||
for fid, img in subimg_dict.items():
|
||||
pil_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
sub_images.append(pil_image)
|
||||
|
||||
return sub_images
|
||||
|
||||
def main():
|
||||
'''
|
||||
sample stream_dict:
|
||||
'''
|
||||
stream_dict = {
|
||||
"goodsName" : "优诺优丝黄桃果粒风味发酵乳",
|
||||
"measureProperty" : 0,
|
||||
"qty" : 1,
|
||||
"price" : 25.9,
|
||||
"weight": 560, # 单位克
|
||||
"barcode": "6931806801024",
|
||||
"video" : "https://ieemoo-ai.obs.cn-east-3.myhuaweicloud.com/videos/20231009/04/04_20231009-082149_21f2ca35-f2c2-4386-8497-3e7a3b407f03_4901872831197.mp4",
|
||||
"goodsPic" : "https://ieemoo-storage.obs.cn-east-3.myhuaweicloud.com/lhpic/6931806801024.jpg",
|
||||
"measureUnit" : "组",
|
||||
"goodsSpec" : "405g"
|
||||
}
|
||||
subimg_list = stream_pipeline(stream_dict)
|
||||
save_path = os.path.join('subimg', stream_dict["barcode"])
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
else:
|
||||
for filename in os.listdir(save_path):
|
||||
file_path = os.path.join(save_path, filename)
|
||||
if os.path.isfile(file_path):
|
||||
os.unlink(file_path)
|
||||
|
||||
for i, img in enumerate(subimg_list):
|
||||
img.save(f'{save_path}/frame_{i}.jpg')
|
||||
|
||||
print(f'Finish crop subimages {stream_dict["barcode"]}!')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user