Files
detecttracking/tracking/dotrack/dotracks_back.py
2025-04-11 17:02:39 +08:00

277 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Mon Mar 4 18:36:31 2024
@author: ym
"""
import numpy as np
import cv2
import copy
import sys
from pathlib import Path
# Absolute path of this module and the directory two levels above its parent
# package (``parents[2]``). That directory is appended to ``sys.path`` once,
# so absolute imports such as ``tracking.utils.mergetrack`` resolve.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]
if str(ROOT) not in sys.path:  # avoid duplicate entries on re-import
    sys.path.append(str(ROOT))
from tracking.utils.mergetrack import track_equal_track
from scipy.spatial.distance import cdist
# Directory that contains this module, and its parent directory.
# ``parpath`` is the base for locating the shopping-cart template images.
# (``.parent`` already returns a ``Path``, so no re-wrapping is needed.)
curpath = Path(__file__).resolve().parent
parpath = curpath.parent
from .dotracks import doTracks, ShoppingCart
from .track_back import backTrack
class doBackTracks(doTracks):
    """``doTracks`` variant that wraps every track in a ``backTrack``.

    After construction, ``classify()`` splits the tracks into hands, kids,
    static, free-moving, merged, residual and confirmed groups.
    """

    def __init__(self, bboxes, trackefeats):
        super().__init__(bboxes, trackefeats)
        # ``self.lboxes`` / ``self.lfeats`` are provided by the base class;
        # presumably one (boxes, features) pair per track — TODO confirm.
        self.tracks = [backTrack(b, f) for b, f in zip(self.lboxes, self.lfeats)]
        # Binary mask built from the cart template images.
        self.incart = self.getincart()

    def getincart(self):
        """Load the cart template images and return their combined binary mask.

        ``incart.png`` and ``cartedge.png`` (under
        ``parpath/shopcart/cart_tempt``) are read as grayscale and thresholded
        with ``THRESH_BINARY`` at 250: pixels > 250 become 255, all others 0.
        The two masks are then OR-ed together.

        Raises:
            FileNotFoundError: if a template image cannot be read.
                ``cv2.imread`` returns ``None`` instead of raising.
        """
        masks = []
        for relpath in ('shopcart/cart_tempt/incart.png',
                        'shopcart/cart_tempt/cartedge.png'):
            path = parpath / relpath
            img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
            if img is None:
                raise FileNotFoundError(f"Cannot read cart template image: {path}")
            _, binary = cv2.threshold(img, 250, 255, cv2.THRESH_BINARY)
            masks.append(binary)
        return cv2.bitwise_or(masks[0], masks[1])

    def classify(self):
        """Classify the elements of ``self.tracks``.

        The steps run in order. Each step removes its tracks from the working
        set before the next step:

        1. Hand tracks (``cls == 0``) are appended to ``self.Hands``.
        2. Kid tracks (``cls == 9``) are stored in ``self.Kids`` as
           ``(track, state)`` tuples, where ``state`` comes from ``kid_state``.
        3. Tracks whose ``isWholeOutCart`` is truthy are dropped.
        4. Multi-frame static tracks are appended to ``self.Static``.
        5. Multi-frame free-moving tracks are appended to ``self.FreeMove``.
        6. The remaining tracks are merged with ``merge_tracks_loop``. Hands are
           associated with goods, and the merged result goes to
           ``self.merged_tracks``.
        7. Merged tracks with more than one frame get a second static check.
           Survivors become ``self.Residual``. ``self.Confirmed`` is then
           computed by ``confirm_track``.
        """
        tracks = self.tracks

        # Hand tracks; their moving frames are later associated with goods tracks.
        hand_tracks = [t for t in tracks if t.cls == 0]
        self.Hands.extend(hand_tracks)
        tracks = self.sub_tracks(tracks, hand_tracks)

        # Kid tracks, each paired with its state: "left", "right" or "incart".
        kid_tracks = [t for t in tracks if t.cls == 9]
        self.Kids = [(t, self.kid_state(t)) for t in kid_tracks]
        tracks = self.sub_tracks(tracks, kid_tracks)

        # Tracks flagged as wholly outside the cart are discarded.
        out_tracks = [t for t in tracks if t.isWholeOutCart]
        tracks = self.sub_tracks(tracks, out_tracks)

        # Static targets.
        static_tracks = [t for t in tracks if t.frnum > 1 and t.is_static()]
        self.Static.extend(static_tracks)
        tracks = self.sub_tracks(tracks, static_tracks)

        # Freely moving targets.
        free_tracks = [t for t in tracks if t.frnum > 1 and t.is_freemove()]
        self.FreeMove.extend(free_tracks)
        tracks = self.sub_tracks(tracks, free_tracks)

        # Iterative track merging.
        merged_tracks = self.merge_tracks_loop(tracks)

        # NOTE(review): association runs on the pre-merge ``tracks``, so
        # ``backTrack`` objects newly created during merging get no ``Hands``
        # entries. Confirm this is intended.
        for htrack in hand_tracks:
            for gtrack in tracks:
                self.associate_with_hand(htrack, gtrack)

        tracks = [t for t in merged_tracks if t.frnum > 1]
        self.merged_tracks = merged_tracks

        # Second static pass on the merged tracks. ``frnum > 1`` already holds here.
        static_tracks = [t for t in tracks if t.is_static()]
        self.Static.extend(static_tracks)
        tracks = self.sub_tracks(tracks, static_tracks)

        self.Residual = tracks
        self.Confirmed = self.confirm_track()

    def confirm_track(self):
        """Pick the residual track with the largest ``min(track.trajrects_wh)``.

        Only values strictly greater than 0 qualify. If several tracks tie, the
        first one in ``self.Residual`` wins.

        Returns:
            ``[deepcopy(best_track)]``, or ``[]`` if no track qualifies.
        """
        best_track, best_value = None, 0
        for track in self.Residual:
            value = min(track.trajrects_wh)
            if value > best_value:
                best_track, best_value = track, value
        if best_track is None:
            return []
        # Copy once, for the winner only, rather than on every improvement.
        return [copy.deepcopy(best_track)]

    @staticmethod
    def _stack_moving_boxes(track):
        """Return the rows of ``track.boxes`` covered by ``track.moving_index``.

        ``moving_index`` holds inclusive ``(start, end)`` row indices, hence
        the ``end + 1`` slice bound. The float64 empty seed preserves the
        result dtype and width when there are no moving segments.
        """
        seed = np.empty(shape=(0, track.boxes.shape[1]), dtype=np.float64)
        segments = [track.boxes[start:end + 1, :] for start, end in track.moving_index]
        return np.concatenate([seed] + segments, axis=0)

    def associate_with_hand(self, htrack, gtrack):
        """Record frames where a hand track overlaps a goods track.

        The original note says this logic is migrated to the base class. It
        also states the criteria for linking a hand track to a goods track:
        (a) their moving-frame indices intersect, and (b) the IoU over the
        intersecting frames is > 0.

        The code does the following. For every frame id (box column 7) present
        in the moving boxes of both tracks, it computes the IoU of the boxes
        (columns 0:4 as x1, y1, x2, y2). When ``iou >= 0.01`` it appends
        ``(htrack.tid, frame_id, iou)`` to ``gtrack.Hands``. If a frame id
        occurs several times, the first occurrence in each track is used.

        Returns:
            None if the two tracks share no moving frame. Otherwise the whole
            (accumulated) ``gtrack.Hands`` list.

        Raises:
            AssertionError: if ``htrack`` is not a hand (cls 0), or ``gtrack``
                is a hand (cls 0) or kid (cls 9).
        """
        assert htrack.cls == 0 and gtrack.cls != 0 and gtrack.cls != 9, 'Track cls is Error!'

        hboxes = self._stack_moving_boxes(htrack)
        gboxes = self._stack_moving_boxes(gtrack)

        # Sorted common frame ids, plus the index of each one's first occurrence.
        fids, h_idx, g_idx = np.intersect1d(hboxes[:, 7], gboxes[:, 7],
                                            return_indices=True)
        if len(fids) == 0:
            return None

        hb = hboxes[h_idx, 0:4]
        gb = gboxes[g_idx, 0:4]
        x1 = np.maximum(hb[:, 0], gb[:, 0])
        y1 = np.maximum(hb[:, 1], gb[:, 1])
        x2 = np.minimum(hb[:, 2], gb[:, 2])
        y2 = np.minimum(hb[:, 3], gb[:, 3])
        inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
        area_h = (hb[:, 2] - hb[:, 0]) * (hb[:, 3] - hb[:, 1])
        area_g = (gb[:, 2] - gb[:, 0]) * (gb[:, 3] - gb[:, 1])
        ious = inter / (area_h + area_g - inter + 1e-6)  # 1e-6 guards against 0/0

        for f, iou in zip(fids, ious):
            if iou >= 0.01:
                gtrack.Hands.append((htrack.tid, f, iou))
        return gtrack.Hands

    @staticmethod
    def _merge_track_group(tracklist):
        """Build one ``backTrack`` from several tracks.

        The boxes and features of all tracks are concatenated, and the rows
        are sorted by frame id (box column 7). A stable sort keeps rows that
        share a frame id in concatenation order, so the result is
        deterministic. The float32 empty seeds keep the original dtype
        promotion.
        """
        first = tracklist[0]
        boxes = np.concatenate(
            [np.empty((0, first.boxes.shape[1]), dtype=np.float32)]
            + [t.boxes for t in tracklist], axis=0)
        feats = np.concatenate(
            [np.empty((0, first.features.shape[1]), dtype=np.float32)]
            + [t.features for t in tracklist], axis=0)
        order = np.argsort(boxes[:, 7], kind='stable')
        return backTrack(boxes[order], feats[order])

    def merge_tracks(self, Residual):
        """Merge tracks with different ids that may be the same goods item.

        ``base_merge_tracks`` groups the tracks. Groups of two or more are
        replaced by a single merged ``backTrack``. Single-track groups are kept
        as they are. The original note says this matches the function in
        ``dotrack_front.py`` and could be moved into the base class.

        Returns:
            ``Residual`` minus every grouped track, joined with the new tracks
            via ``join_tracks``.
        """
        mergedTracks = self.base_merge_tracks(Residual)

        oldtracks, newtracks = [], []
        for tracklist in mergedTracks:
            if len(tracklist) > 1:
                oldtracks.extend(tracklist)
                newtracks.append(self._merge_track_group(tracklist))
            elif len(tracklist) == 1:
                oldtracks.append(tracklist[0])
                newtracks.append(tracklist[0])

        redu = self.sub_tracks(Residual, oldtracks)
        return self.join_tracks(redu, newtracks)

    def kid_state(self, track):
        """Return ``"left"``, ``"right"`` or ``"incart"`` for a kid track.

        Let ``left = cornpoints[:, 2]`` and ``right = 1024 - cornpoints[:, 4]``.
        Each ratio below is a count of frames divided by ``track.frnum``.

        * ``"left"``: more than 80% of frames have ``left < 30`` and more than
          70% have ``right > 512``.
        * ``"right"``: more than 70% of frames have ``left > 512`` and more
          than 80% have ``right < 30``.
        * otherwise ``"incart"``.

        NOTE(review): 1024 is presumably the frame width in pixels, and
        columns 2/4 are presumably x coordinates — confirm.
        """
        left_dist = track.cornpoints[:, 2]
        right_dist = 1024 - track.cornpoints[:, 4]

        def frac(mask):
            return np.sum(mask) / track.frnum

        if frac(left_dist < 30) > 0.8 and frac(right_dist > 512) > 0.7:
            return "left"
        if frac(left_dist > 512) > 0.7 and frac(right_dist < 30) > 0.8:
            return "right"
        return "incart"

    def isuptrack(self, track):
        """Placeholder: always returns False."""
        return False

    def isdowntrack(self, track):
        """Placeholder: always returns False."""
        return False

    def isfreetrack(self, track):
        """Placeholder: always returns False."""
        return False