# -*- coding: utf-8 -*-
"""
Created on Mon Mar 4 18:38:20 2024

@author: ym
"""
import numpy as np

from utils.mergetrack import track_equal_track
from .dotracks import doTracks
from .track_front import frontTrack


class doFrontTracks(doTracks):
    """Front-camera track post-processing.

    Wraps each entry of ``self.lboxes`` (built by the base class) in a
    ``frontTrack`` and sorts the tracks into hand tracks (``self.Hands``) and
    remaining goods tracks (``self.Residual``).
    """

    def __init__(self, bboxes, TracksDict):
        super().__init__(bboxes, TracksDict)
        # NOTE(review): self.lboxes is produced by doTracks; presumably one
        # box array per track id — confirm against the base class.
        self.tracks = [frontTrack(b) for b in self.lboxes]

    def classify(self):
        """Classify the elements of ``self.tracks``.

        Steps:
          1. hand tracks (``cls == 0``) are appended to ``self.Hands`` and
             removed from the working list;
          2. kid tracks (``cls == 9``) are removed;
          3. static tracks (``frnum > 1`` and ``is_static()``) are removed;
          4. the remaining tracks are merged with ``merge_tracks_loop`` and
             tracks spanning a single frame are dropped;
          5. every hand track associated with a goods track (see
             ``is_associate_with_hand``) is appended to ``gtrack.hands``;
          6. tracks for which ``is_free_move()`` is true are removed.

        The surviving tracks are stored in ``self.Residual``.
        """
        tracks = self.tracks

        # Extract hand tracks.
        hand_tracks = [t for t in tracks if t.cls == 0]
        self.Hands.extend(hand_tracks)
        tracks = self.sub_tracks(tracks, hand_tracks)

        # Extract (and discard) kid tracks.
        kid_tracks = [t for t in tracks if t.cls == 9]
        tracks = self.sub_tracks(tracks, kid_tracks)

        # Static tracks; frnum is checked first so is_static() only runs on
        # multi-frame tracks.
        static_tracks = [t for t in tracks if t.frnum > 1 and t.is_static()]
        tracks = self.sub_tracks(tracks, static_tracks)

        # Iteratively merge tracks that may belong to the same item, then
        # keep only tracks spanning more than one frame.
        merged_tracks = self.merge_tracks_loop(tracks)
        tracks = [t for t in merged_tracks if t.frnum > 1]

        for gtrack in tracks:
            for htrack in hand_tracks:
                if self.is_associate_with_hand(htrack, gtrack):
                    gtrack.hands.append(htrack)

        freemoved_tracks = [t for t in tracks if t.is_free_move()]
        tracks = self.sub_tracks(tracks, freemoved_tracks)

        self.Residual = tracks

    @staticmethod
    def _gather_segments(boxes, segments):
        """Stack ``boxes[start:end+1]`` for every ``(start, end)`` in *segments*.

        ``start`` and ``end`` are inclusive row indices, hence the ``+1``.
        Returns an empty ``(0, 9)`` float64 array when *segments* is empty.
        """
        parts = [np.empty(shape=(0, 9), dtype=np.float64)]
        parts.extend(boxes[start:end + 1, :] for start, end in segments)
        return np.concatenate(parts, axis=0)

    def is_associate_with_hand(self, htrack, gtrack):
        """Decide whether a hand track and a goods track are associated.

        Criteria:
          a. the hand's motion segments (``htrack.dynamic_y2``) and the goods'
             motion segments (``gtrack.dynamic_y1``) share at least one frame
             index (box column 7);
          b. on those shared frames the two boxes overlap (IoU > 0).

        Returns:
            ``False`` when the segments share no frame index; otherwise the
            number of shared frames whose IoU is > 0. The count is truthy iff
            at least one shared frame overlaps (i.e. "any", not "all").

        Raises:
            AssertionError: if ``htrack`` is not a hand track or ``gtrack`` is
            a hand or kid track.
        """
        assert htrack.cls==0 and gtrack.cls!=0 and gtrack.cls!=9, 'Track cls is Error!'

        # np.float64 replaces the removed np.float alias (NumPy >= 1.24).
        hboxes = self._gather_segments(htrack.boxes, htrack.dynamic_y2)
        gboxes = self._gather_segments(gtrack.boxes, gtrack.dynamic_y1)

        hfids, gfids = hboxes[:, 7], gboxes[:, 7]
        fids = set(hfids).intersection(set(gfids))
        if len(fids) == 0:
            return False

        ious = []
        for f in fids:
            # First box of each track at frame f.
            h = np.where(hfids == f)[0][0]
            g = np.where(gfids == f)[0][0]

            x11, y11, x12, y12 = hboxes[h, 0:4]
            x21, y21, x22, y22 = gboxes[g, 0:4]

            x1, y1 = max((x11, x21)), max((y11, y21))
            x2, y2 = min((x12, x22)), min((y12, y22))

            # Intersection area, clamped to 0 for disjoint boxes.
            inter = (x2 - x1).clip(0) * (y2 - y1).clip(0)
            area1 = (x12 - x11) * (y12 - y11)
            area2 = (x22 - x21) * (y22 - y21)
            iou = inter / (area1 + area2 - inter + 1e-6)
            if iou > 0:
                ious.append(iou)

        return len(ious)

    def merge_tracks(self, Residual):
        """Merge tracks with different ids that may be the same item.

        ``self.base_merge_tracks`` groups ``Residual`` into lists of tracks.
        Each group with more than one member is fused into a new
        ``frontTrack``: all boxes take the track id (column 4) and class
        (column 6) of the group's first box, and are ordered by frame index
        (column 7). The fused members are removed from ``Residual`` and the
        new tracks are joined in via ``join_tracks``.
        """
        mergedTracks = self.base_merge_tracks(Residual)

        oldtracks, newtracks = [], []
        for tracklist in mergedTracks:
            if len(tracklist) > 1:
                first = tracklist[0]
                ntid, ncls = first.boxes[0, 4], first.boxes[0, 6]

                parts = [np.empty((0, 9), dtype=np.float32)]
                for track in tracklist:
                    iboxes = track.boxes.copy()
                    iboxes[:, 4], iboxes[:, 6] = ntid, ncls
                    parts.append(iboxes)
                    oldtracks.append(track)
                boxes = np.concatenate(parts, axis=0)

                fid_indices = np.argsort(boxes[:, 7])
                newtracks.append(frontTrack(boxes[fid_indices]))
            elif len(tracklist) == 1:
                oldtracks.append(tracklist[0])
                newtracks.append(tracklist[0])

        redu = self.sub_tracks(Residual, oldtracks)
        merged = self.join_tracks(redu, newtracks)
        return merged