modified for site test
This commit is contained in:
Binary file not shown.
Binary file not shown.
@ -119,12 +119,14 @@ class BOTSORT(BYTETracker):
|
||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
||||
return KalmanFilterXYWH()
|
||||
|
||||
def init_track(self, dets, scores, cls, imgs):
|
||||
def init_track(self, dets, scores, cls, imgs, features_keep):
|
||||
"""Initialize track with detections, scores, and classes."""
|
||||
if len(dets) == 0:
|
||||
return []
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
features_keep = self.encoder.inference(imgs, dets)
|
||||
if features_keep is None:
|
||||
features_keep = self.encoder.inference(imgs, dets)
|
||||
|
||||
return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
|
||||
else:
|
||||
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
||||
|
@ -18,7 +18,12 @@ def dists_update(dists, strack_pool, detections):
|
||||
blabel = np.array([int(stack.cls) for stack in detections])
|
||||
amlabel = np.expand_dims(alabel, axis=1).repeat(len(detections),axis=1)
|
||||
bmlabel = np.expand_dims(blabel, axis=0).repeat(len(strack_pool),axis=0)
|
||||
dist_label = 1 - (bmlabel == amlabel)
|
||||
|
||||
mlabel = bmlabel == amlabel
|
||||
iou_dist = matching.iou_distance(strack_pool, detections) > 0.1 #boxes iou>0.9时,可以不考虑类别
|
||||
dist_label = (1 - mlabel) & iou_dist # 不同类,且不是严格重叠,需考虑类别距离
|
||||
|
||||
dist_label = 1 - mlabel
|
||||
dists = np.where(dists > dist_label, dists, dist_label)
|
||||
return dists
|
||||
|
||||
@ -103,6 +108,7 @@ class STrack(BaseTrack):
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
self.first_find = False
|
||||
self.frame_id = frame_id
|
||||
if new_id:
|
||||
self.track_id = self.next_id()
|
||||
@ -127,6 +133,7 @@ class STrack(BaseTrack):
|
||||
self.convert_coords(new_tlwh))
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
self.first_find = False
|
||||
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
@ -207,7 +214,7 @@ class BYTETracker:
|
||||
self.args.new_track_thresh = 0.5
|
||||
|
||||
|
||||
def update(self, results, img=None):
|
||||
def update(self, results, img=None, features=None):
|
||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
||||
self.frame_id += 1
|
||||
activated_stracks = []
|
||||
@ -240,7 +247,7 @@ class BYTETracker:
|
||||
cls_keep = cls[remain_inds]
|
||||
cls_second = cls[inds_second]
|
||||
|
||||
detections = self.init_track(dets, scores_keep, cls_keep, img)
|
||||
detections = self.init_track(dets, scores_keep, cls_keep, img, features)
|
||||
|
||||
# Add newly detected tracklets to tracked_stracks
|
||||
unconfirmed = []
|
||||
@ -283,7 +290,7 @@ class BYTETracker:
|
||||
|
||||
# Step 3: Second association, with low score detection boxes
|
||||
# association the untrack to the low score detections
|
||||
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
|
||||
detections_second = self.init_track(dets_second, scores_second, cls_second, img, features)
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
||||
|
||||
# TODO
|
||||
@ -366,7 +373,7 @@ class BYTETracker:
|
||||
output2 = [x.tlwh_to_tlbr(x._tlwh).tolist() + [x.track_id, x.score, x.cls, x.frame_id, x.idx]
|
||||
for x in first_finded if x.first_find]
|
||||
|
||||
output = np.asarray(output1+output2, dtype=np.float32)
|
||||
output = np.asarray(output1 + output2, dtype=np.float32)
|
||||
|
||||
return output
|
||||
|
||||
@ -382,7 +389,7 @@ class BYTETracker:
|
||||
tracks = []
|
||||
feats = []
|
||||
for t in self.tracked_stracks:
|
||||
if t.is_activated:
|
||||
if t.is_activated or t.first_find:
|
||||
track = t.tlbr.tolist() + [t.track_id, t.score, t.cls, t.idx]
|
||||
feat = t.curr_feature
|
||||
|
||||
@ -398,7 +405,7 @@ class BYTETracker:
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
return KalmanFilterXYAH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
def init_track(self, dets, scores, cls, img=None, feats=None):
|
||||
"""Initialize object tracking with detections and scores using STrack algorithm."""
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
@ -455,7 +462,22 @@ class BYTETracker:
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""Remove duplicate stracks with non-maximum IOU distance."""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
|
||||
#### ===================================== written by WQG
|
||||
mlabel = []
|
||||
if len(stracksa) and len(stracksb):
|
||||
alabel = np.array([int(stack.cls) for stack in stracksa])
|
||||
blabel = np.array([int(stack.cls) for stack in stracksb])
|
||||
amlabel = np.expand_dims(alabel, axis=1).repeat(len(stracksb),axis=1)
|
||||
bmlabel = np.expand_dims(blabel, axis=0).repeat(len(stracksa),axis=0)
|
||||
mlabel = bmlabel == amlabel
|
||||
if len(mlabel):
|
||||
condt = (pdist<0.15) & mlabel # 需满足iou足够小,且类别相同,才予以排除
|
||||
else:
|
||||
condt = pdist<0.15
|
||||
|
||||
|
||||
pairs = np.where(condt)
|
||||
dupa, dupb = [], []
|
||||
for p, q in zip(*pairs):
|
||||
timep = stracksa[p].frame_id - stracksa[p].start_frame
|
||||
|
Binary file not shown.
Binary file not shown.
@ -45,8 +45,7 @@ class ReIDInterface:
|
||||
])
|
||||
|
||||
|
||||
self.model = nn.DataParallel(model).to(self.device)
|
||||
|
||||
# self.model = nn.DataParallel(model).to(self.device)
|
||||
self.model = model
|
||||
self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
self.model.eval()
|
||||
|
Reference in New Issue
Block a user