modified for site test

This commit is contained in:
王庆刚
2024-07-18 17:52:12 +08:00
parent f90ef72cbf
commit e986ec060b
39 changed files with 2279 additions and 375 deletions

View File

@ -119,12 +119,14 @@ class BOTSORT(BYTETracker):
"""Returns an instance of KalmanFilterXYWH for object tracking."""
return KalmanFilterXYWH()
def init_track(self, dets, scores, cls, imgs):
def init_track(self, dets, scores, cls, imgs, features_keep):
"""Initialize track with detections, scores, and classes."""
if len(dets) == 0:
return []
if self.args.with_reid and self.encoder is not None:
features_keep = self.encoder.inference(imgs, dets)
if features_keep is None:
features_keep = self.encoder.inference(imgs, dets)
return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
else:
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections

View File

@ -18,7 +18,12 @@ def dists_update(dists, strack_pool, detections):
blabel = np.array([int(stack.cls) for stack in detections])
amlabel = np.expand_dims(alabel, axis=1).repeat(len(detections),axis=1)
bmlabel = np.expand_dims(blabel, axis=0).repeat(len(strack_pool),axis=0)
dist_label = 1 - (bmlabel == amlabel)
mlabel = bmlabel == amlabel
iou_dist = matching.iou_distance(strack_pool, detections) > 0.1 #boxes iou>0.9时,可以不考虑类别
dist_label = (1 - mlabel) & iou_dist # 不同类,且不是严格重叠,需考虑类别距离
dist_label = 1 - mlabel
dists = np.where(dists > dist_label, dists, dist_label)
return dists
@ -103,6 +108,7 @@ class STrack(BaseTrack):
self.tracklet_len = 0
self.state = TrackState.Tracked
self.is_activated = True
self.first_find = False
self.frame_id = frame_id
if new_id:
self.track_id = self.next_id()
@ -127,6 +133,7 @@ class STrack(BaseTrack):
self.convert_coords(new_tlwh))
self.state = TrackState.Tracked
self.is_activated = True
self.first_find = False
self.score = new_track.score
self.cls = new_track.cls
@ -207,7 +214,7 @@ class BYTETracker:
self.args.new_track_thresh = 0.5
def update(self, results, img=None):
def update(self, results, img=None, features=None):
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
self.frame_id += 1
activated_stracks = []
@ -240,7 +247,7 @@ class BYTETracker:
cls_keep = cls[remain_inds]
cls_second = cls[inds_second]
detections = self.init_track(dets, scores_keep, cls_keep, img)
detections = self.init_track(dets, scores_keep, cls_keep, img, features)
# Add newly detected tracklets to tracked_stracks
unconfirmed = []
@ -283,7 +290,7 @@ class BYTETracker:
# Step 3: Second association, with low score detection boxes
# association the untrack to the low score detections
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
detections_second = self.init_track(dets_second, scores_second, cls_second, img, features)
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
# TODO
@ -366,7 +373,7 @@ class BYTETracker:
output2 = [x.tlwh_to_tlbr(x._tlwh).tolist() + [x.track_id, x.score, x.cls, x.frame_id, x.idx]
for x in first_finded if x.first_find]
output = np.asarray(output1+output2, dtype=np.float32)
output = np.asarray(output1 + output2, dtype=np.float32)
return output
@ -382,7 +389,7 @@ class BYTETracker:
tracks = []
feats = []
for t in self.tracked_stracks:
if t.is_activated:
if t.is_activated or t.first_find:
track = t.tlbr.tolist() + [t.track_id, t.score, t.cls, t.idx]
feat = t.curr_feature
@ -398,7 +405,7 @@ class BYTETracker:
"""Returns a Kalman filter object for tracking bounding boxes."""
return KalmanFilterXYAH()
def init_track(self, dets, scores, cls, img=None):
def init_track(self, dets, scores, cls, img=None, feats=None):
"""Initialize object tracking with detections and scores using STrack algorithm."""
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
@ -455,7 +462,22 @@ class BYTETracker:
def remove_duplicate_stracks(stracksa, stracksb):
"""Remove duplicate stracks with non-maximum IOU distance."""
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
#### ===================================== written by WQG
mlabel = []
if len(stracksa) and len(stracksb):
alabel = np.array([int(stack.cls) for stack in stracksa])
blabel = np.array([int(stack.cls) for stack in stracksb])
amlabel = np.expand_dims(alabel, axis=1).repeat(len(stracksb),axis=1)
bmlabel = np.expand_dims(blabel, axis=0).repeat(len(stracksa),axis=0)
mlabel = bmlabel == amlabel
if len(mlabel):
condt = (pdist<0.15) & mlabel # 需满足iou足够小且类别相同才予以排除
else:
condt = pdist<0.15
pairs = np.where(condt)
dupa, dupb = [], []
for p, q in zip(*pairs):
timep = stracksa[p].frame_id - stracksa[p].start_frame

View File

@ -45,8 +45,7 @@ class ReIDInterface:
])
self.model = nn.DataParallel(model).to(self.device)
# self.model = nn.DataParallel(model).to(self.device)
self.model = model
self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
self.model.eval()