This commit is contained in:
王庆刚
2024-11-25 18:05:08 +08:00
parent c47894ddc0
commit 8bbee310ba
109 changed files with 1003 additions and 305 deletions

View File

@ -60,53 +60,57 @@ from tracking.trackers import BOTSORT, BYTETracker
from tracking.utils.showtrack import drawtracks
from hands.hand_inference import hand_pose
from tracking.trackers.reid.reid_interface import ReIDInterface
from tracking.trackers.reid.config import config as ReIDConfig
ReIDEncoder = ReIDInterface(ReIDConfig)
from contrast.feat_extract.config import config as conf
from contrast.feat_extract.inference import FeatsInterface
ReIDEncoder = FeatsInterface(conf)
# from tracking.trackers.reid.reid_interface import ReIDInterface
# from tracking.trackers.reid.config import config as ReIDConfig
# ReIDEncoder = ReIDInterface(ReIDConfig)
# tracker_yaml = r"./tracking/trackers/cfg/botsort.yaml"
def inference_image(image, detections):
H, W, _ = np.shape(image)
imgs = []
batch_patches = []
patches = []
for d in range(np.size(detections, 0)):
tlbr = detections[d, :4].astype(np.int_)
tlbr[0] = max(0, tlbr[0])
tlbr[1] = max(0, tlbr[1])
tlbr[2] = min(W - 1, tlbr[2])
tlbr[3] = min(H - 1, tlbr[3])
img1 = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]
# def inference_image(image, detections):
# H, W, _ = np.shape(image)
# imgs = []
# batch_patches = []
# patches = []
# for d in range(np.size(detections, 0)):
# tlbr = detections[d, :4].astype(np.int_)
# tlbr[0] = max(0, tlbr[0])
# tlbr[1] = max(0, tlbr[1])
# tlbr[2] = min(W - 1, tlbr[2])
# tlbr[3] = min(H - 1, tlbr[3])
# img1 = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]
img = img1[:, :, ::-1].copy() # the model expects RGB inputs
patch = ReIDEncoder.transform(img)
# img = img1[:, :, ::-1].copy() # the model expects RGB inputs
# patch = ReIDEncoder.transform(img)
imgs.append(img1)
# patch = patch.to(device=self.device).half()
if str(ReIDEncoder.device) != "cpu":
patch = patch.to(device=ReIDEncoder.device).half()
else:
patch = patch.to(device=ReIDEncoder.device)
# imgs.append(img1)
# # patch = patch.to(device=self.device).half()
# if str(ReIDEncoder.device) != "cpu":
# patch = patch.to(device=ReIDEncoder.device).half()
# else:
# patch = patch.to(device=ReIDEncoder.device)
patches.append(patch)
if (d + 1) % ReIDEncoder.batch_size == 0:
patches = torch.stack(patches, dim=0)
batch_patches.append(patches)
patches = []
# patches.append(patch)
# if (d + 1) % ReIDEncoder.batch_size == 0:
# patches = torch.stack(patches, dim=0)
# batch_patches.append(patches)
# patches = []
if len(patches):
patches = torch.stack(patches, dim=0)
batch_patches.append(patches)
# if len(patches):
# patches = torch.stack(patches, dim=0)
# batch_patches.append(patches)
features = np.zeros((0, ReIDEncoder.embedding_size))
for patches in batch_patches:
pred = ReIDEncoder.model(patches)
pred[torch.isinf(pred)] = 1.0
feat = pred.cpu().data.numpy()
features = np.vstack((features, feat))
# features = np.zeros((0, ReIDEncoder.embedding_size))
# for patches in batch_patches:
# pred = ReIDEncoder.model(patches)
# pred[torch.isinf(pred)] = 1.0
# feat = pred.cpu().data.numpy()
# features = np.vstack((features, feat))
return imgs, features
# return imgs, features
@ -127,6 +131,7 @@ def init_trackers(tracker_yaml = None, bs=1):
return trackers
'''=============== used in pipeline.py =================='''
@smart_inference_mode()
def yolo_resnet_tracker(
weights=ROOT / 'yolov5s.pt', # model path or triton URL
@ -237,7 +242,9 @@ def yolo_resnet_tracker(
'''================== 1. 存储 dets/subimgs/features Dict ============='''
imgs, features = inference_image(im0, tracks)
imgs, features = ReIDEncoder.inference(im0, tracks)
# imgs, features = inference_image(im0, tracks)
# TrackerFeats = np.concatenate([TrackerFeats, features], axis=0)
@ -499,7 +506,8 @@ def run(
tracks[:, 7] = frameId
'''================== 1. 存储 dets/subimgs/features Dict ============='''
imgs, features = inference_image(im0, tracks)
# imgs, features = inference_image(im0, tracks)
imgs, features = ReIDEncoder.inference(im0, tracks)
TrackerFeats = np.concatenate([TrackerFeats, features], axis=0)
@ -681,32 +689,17 @@ def main(opt):
optdict = vars(opt)
p = r"D:\datasets\ym"
p = r"D:\datasets\ym\exhibition\153112511_0_seek_105.mp4"
p = r"D:\exhibition\images\153112511_0_seek_105.mp4"
optdict["project"] = r"D:\exhibition\result"
files = []
k = 0
if os.path.isdir(p):
files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))
for file in files:
optdict["source"] = file
run(**optdict)
k += 1
if k == 1:
break
optdict["source"] = files
elif os.path.isfile(p):
optdict["source"] = p
run(**vars(opt))
def main_imgdir(opt):
check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
optdict = vars(opt)
optdict["project"] = r"\\192.168.1.28\share\realtime"
optdict["source"] = r"\\192.168.1.28\share\realtime\addReturn\add\1728978052624"
run(**optdict)
@ -745,7 +738,7 @@ def main_loop(opt):
# break
elif os.path.isfile(p):
optdict["source"] = p
run(**vars(opt))
run(**optdict)
@ -754,7 +747,6 @@ if __name__ == '__main__':
opt = parse_opt()
main(opt)
# main_imgdir(opt)
# main_loop(opt)