176 lines
6.0 KiB
Python
176 lines
6.0 KiB
Python
import abc
|
||
# import os
|
||
# import pdb
|
||
# import pickle
|
||
import sys
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from tools.config import gvalue
|
||
|
||
sys.path.append('./ytracking')
|
||
sys.path.append('./contrast')
|
||
# from ytracking.tracking.dotrack import init_tracker, VideoTracks, boxes_add_fid
|
||
from ytracking.tracking.have_tracking import have_tracked
|
||
from ytracking.track_ import *
|
||
from contrast.logic import datacollection, similarityResult, similarity
|
||
from PIL import Image
|
||
from tools.config import gvalue
|
||
|
||
class AiInterface(metaclass=abc.ABCMeta):
    """Abstract interface every AI back-end class must implement.

    Both methods are declared abstract, so a subclass can only be
    instantiated once it overrides each of them.
    """

    @abc.abstractmethod
    def getTrackingBox(self):
        """Return tracking-box results; the concrete signature is set by subclasses."""

    @abc.abstractmethod
    def getSimilarity(self):
        """Return similarity results; the concrete signature is set by subclasses."""
|
||
|
||
|
||
class AiClass(AiInterface):
    """Concrete AiInterface: crops tracked boxes and formats similarity results."""

    def __init__(self):
        pass

    def get_xyxy_coordinates(self, box, frame_id_img):
        """Return the box corners ``(x1, y1, x2, y2)`` clamped to the image.

        Args:
            box: Sequence whose items 0..3 are read as x1, y1, x2, y2 and
                truncated with ``int``.
            frame_id_img: Image array; ``shape[1]`` bounds x2 and
                ``shape[0]`` bounds y2.

        Returns:
            ``(x1, y1, x2, y2)`` as ints, with x1/y1 clamped to >= 0 and
            x2/y2 clamped to the image width/height.

        Raises:
            ValueError: if ``box`` or ``frame_id_img.shape`` is too short to
                index (the original ``IndexError`` is chained as the cause).
        """
        try:
            x1 = max(0, int(box[0]))
            x2 = min(frame_id_img.shape[1], int(box[2]))
            y1 = max(0, int(box[1]))
            y2 = min(frame_id_img.shape[0], int(box[3]))
        except IndexError as e:
            raise ValueError("边界框坐标超出图像尺寸") from e
        return x1, y1, x2, y2

    @staticmethod
    def _report_box_error(box, frame_id_img, err):
        """Print diagnostics for a box that failed; never raises.

        The box itself or its frame id (``box[7]``) may be what caused the
        failure, so each diagnostic lookup is best-effort.
        """
        try:
            print("x1: {}, y1: {}, x2: {}, y2: {}".format(box[0], box[1], box[2], box[3]))
            shape = frame_id_img[box[7]].shape
            print("x:{}, y:{}".format(shape[1], shape[0]))
        except (IndexError, KeyError, TypeError, AttributeError):
            pass  # best-effort diagnostics only; the real error is printed below
        print(f"处理边界框时发生错误: {err}")

    def getTrackingBox(self, bboxes, features_dict, camera_id, frame_id_img, save_imgs_dir):
        """Run tracking, crop each residual box and group the crops by track ID.

        For every box that is processed successfully:

        - its clamped ``y1`` is appended to ``gvalue.track_y_lists``;
        - the crop is saved as ``<n>.jpg`` in ``save_imgs_dir``, where ``n``
          counts successful saves starting at 0.

        A box that raises is reported via ``_report_box_error`` and skipped.

        Args:
            bboxes, features_dict, camera_id: Forwarded to ``have_tracked``.
            frame_id_img: Indexable by frame id (``box[7]``), yielding image
                arrays. The last axis is reversed before the crop is built as
                an RGB PIL image (presumably BGR input -- TODO confirm).
            save_imgs_dir: Directory the crops are written to.

        Returns:
            ``(all_image_list, trackIdList)``. ``all_image_list[i]`` is the
            list of PIL crops for track ID ``trackIdList[i]``. Track IDs are
            ``str(box[4])``, in first-seen order.
        """
        import os  # explicit: the module-level `import os` is commented out

        image_lists = {}

        with Profile():
            vts = have_tracked(bboxes, features_dict, camera_id)

        nn = 0
        for res in vts.Residual:
            for box in res.boxes:
                try:
                    box = [int(i) for i in box.tolist()]
                    frame = frame_id_img[box[7]]
                    x1, y1, x2, y2 = self.get_xyxy_coordinates(box, frame)
                    gvalue.track_y_lists.append(y1)
                    # Reverse the channel axis before handing the crop to PIL as RGB.
                    c_img = frame[y1:y2, x1:x2][:, :, ::-1]
                    img_pil = Image.fromarray(c_img.astype('uint8'), 'RGB')
                    img_pil.save(os.path.join(save_imgs_dir, str(nn) + '.jpg'))
                    nn += 1
                    image_lists.setdefault(str(box[4]), []).append(img_pil)
                except Exception as e:
                    self._report_box_error(box, frame_id_img, e)
                    continue

        # Take IDs from the same dict as the images so both lists stay
        # index-aligned and deterministic (a set has no stable order).
        all_image_list = list(image_lists.values())
        trackIdList = list(image_lists)
        return all_image_list, trackIdList

    @staticmethod
    def process_topn_data(source_data):
        """Flatten a ``{category: {car_id: similarity}}`` mapping for output.

        Returns:
            ``None`` for ``None`` input, ``{}`` for an empty dict. Otherwise a
            dict with two keys:

            - ``'carId_barcode_trackId'``: the list of categories;
            - ``'data'``: a list of
              ``{'carId_barcode_trackId_n': car_id, 'similarity': score}``.

            Both lists follow the input's iteration order.

        Raises:
            ValueError: if ``source_data`` is neither ``None`` nor a dict.
        """
        if source_data is None:
            return None
        if not isinstance(source_data, dict):
            raise ValueError("输入数据必须是字典类型")
        if not source_data:
            return {}

        return {
            'carId_barcode_trackId': list(source_data),
            'data': [
                {'carId_barcode_trackId_n': car_id, 'similarity': score}
                for category_data in source_data.values()
                for car_id, score in category_data.items()
            ],
        }

    @staticmethod
    def process_top10_data(source_data):
        """Flatten a ``{category: {key: similarity}}`` mapping for top-10 output.

        Each category key is split on ``'_'``:

        - the last part is the track ID attached to each row;
        - the second-to-last part is the barcode.

        The returned ``'barcode'`` is taken from the last category iterated.

        Returns:
            ``None`` for ``None`` input, ``{}`` for an empty dict. Otherwise
            ``{'barcode': ..., 'data': [...]}``, where each row is
            ``{'barcode': key, 'similarity': score, 'trackid': trackid}``.

        Raises:
            ValueError: if ``source_data`` is neither ``None`` nor a dict.
            IndexError: if a category key contains no ``'_'``.
        """
        if source_data is None:
            return None
        if not isinstance(source_data, dict):
            raise ValueError("输入数据必须是字典类型")
        if not source_data:
            return {}

        data_category = []
        for category, category_data in source_data.items():
            parts = category.split('_')
            trackid = parts[-1]
            barcode = parts[-2]
            data_category.extend(
                {'barcode': car_id, 'similarity': score, 'trackid': trackid}
                for car_id, score in category_data.items()
            )

        return {'barcode': barcode, 'data': data_category}

    def getSimilarity(self, model, queueImgs):
        """Run the similarity comparison for one queued request.

        Args:
            model: Passed through to ``similarity().getSimilarity``.
            queueImgs: Request dict. The following keys are read:

                - ``'barcode_flag'``
                - ``'add_flag'``
                - ``'barcode_list'``: a comma-separated string, possibly
                  wrapped in single quotes
                - ``'sequenceId'``

                The whole dict is also attached as ``queImgsDict``.

        Returns:
            Dict with these keys:

            - ``'top10'``: processed by ``process_top10_data``;
            - ``'top1'``: ``{'barcode', 'similarity'}`` built from the first
              entry of ``top1``, when ``top1`` is non-empty;
            - ``'topn'``: processed by ``process_topn_data``;
            - ``'tempLibList'``: ``gvalue.tempLibLists[gvalue.mac_id]``, or
              ``[]`` when that entry is missing or ``None``;
            - ``'sequenceId'``: copied from the request.
        """
        data_collection = datacollection()
        similarityRes = similarityResult()

        data_collection.barcode_flag = queueImgs['barcode_flag']
        data_collection.add_flag = queueImgs['add_flag']
        data_collection.barcode_list = queueImgs['barcode_list'].strip("'").split(',')
        data_collection.queImgsDict = queueImgs

        similarityRes = similarity().getSimilarity(model, data_collection, similarityRes)

        if similarityRes.top1:
            top_barcode, top_score = next(iter(similarityRes.top1.items()))
            similarityRes.top1 = {"barcode": top_barcode, "similarity": top_score}

        temp_lib = gvalue.tempLibLists.get(gvalue.mac_id)
        similarityRes.tempLibList = temp_lib if temp_lib is not None else []

        return {
            'top10': AiClass.process_top10_data(similarityRes.top10),
            'top1': similarityRes.top1,
            'topn': AiClass.process_topn_data(similarityRes.topn),
            'tempLibList': similarityRes.tempLibList,
            'sequenceId': queueImgs['sequenceId'],
        }
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: only constructs the class. The tracking and
    # similarity calls below are kept disabled because they need inputs
    # (run(), cfg.queueImgs) that this module does not provide.
    AI = AiClass()
    # track_boxes, frame_id_img = run()
    # AI.getTrackingBox(track_boxes, frame_id_img)
    # AI.getSimilarity(cfg.queueImgs)
|