update
This commit is contained in:
175
tools/Interface.py
Normal file
175
tools/Interface.py
Normal file
@ -0,0 +1,175 @@
|
||||
import abc
|
||||
# import os
|
||||
# import pdb
|
||||
# import pickle
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tools.config import gvalue
|
||||
|
||||
sys.path.append('./ytracking')
|
||||
sys.path.append('./contrast')
|
||||
# from ytracking.tracking.dotrack import init_tracker, VideoTracks, boxes_add_fid
|
||||
from ytracking.tracking.have_tracking import have_tracked
|
||||
from ytracking.track_ import *
|
||||
from contrast.logic import datacollection, similarityResult, similarity
|
||||
from PIL import Image
|
||||
from tools.config import gvalue
|
||||
|
||||
class AiInterface(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def getTrackingBox(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def getSimilarity(self):
|
||||
pass
|
||||
|
||||
|
||||
class AiClass(AiInterface):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_xyxy_coordinates(self, box, frame_id_img):
|
||||
"""
|
||||
计算并返回边界框的坐标。
|
||||
"""
|
||||
try:
|
||||
x1 = max(0, int(box[0]))
|
||||
x2 = min(frame_id_img.shape[1], int(box[2]))
|
||||
y1 = max(0, int(box[1]))
|
||||
y2 = min(frame_id_img.shape[0], int(box[3]))
|
||||
return x1, y1, x2, y2
|
||||
except IndexError as e:
|
||||
raise ValueError("边界框坐标超出图像尺寸") from e
|
||||
|
||||
def getTrackingBox(self, bboxes, features_dict, camera_id, frame_id_img, save_imgs_dir):
|
||||
"""
|
||||
根据提供的边界框和帧图像,返回图像列表和轨迹ID列表。
|
||||
"""
|
||||
|
||||
image_lists = {}
|
||||
track_id_list = []
|
||||
|
||||
gt = Profile()
|
||||
with gt:
|
||||
vts = have_tracked(bboxes, features_dict, camera_id)
|
||||
|
||||
nn = 0
|
||||
for res in vts.Residual:
|
||||
for box in res.boxes:
|
||||
try:
|
||||
box = [int(i) for i in box.tolist()]
|
||||
print('box[7] >>>> {}'.format(box[7]))
|
||||
x1, y1, x2, y2 = self.get_xyxy_coordinates(box, frame_id_img[box[7]])
|
||||
gvalue.track_y_lists.append(y1)
|
||||
c_img = frame_id_img[box[7]][y1:y2, x1:x2][:, :, ::-1]
|
||||
|
||||
# c_img = frame_id_img[box[7]][box[1]:box[3], box[0]:box[2]][:, :, ::-1]
|
||||
img_pil = Image.fromarray(c_img.astype('uint8'), 'RGB')
|
||||
|
||||
img_pil.save(os.sep.join([save_imgs_dir, str(nn) + '.jpg']))
|
||||
nn += 1
|
||||
|
||||
track_id = str(box[4])
|
||||
track_id_list.append(track_id)
|
||||
if track_id not in image_lists:
|
||||
image_lists[track_id] = []
|
||||
image_lists[track_id].append(img_pil)
|
||||
except Exception as e:
|
||||
print("y1: {}, y2: {}, x1:{} x2:{}".format(box[2], box[3], box[0], box[1]))
|
||||
print("x:{}, y:{}".format(frame_id_img[box[7]].shape[1], frame_id_img[box[7]].shape[0]))
|
||||
print(f"处理边界框时发生错误: {e}")
|
||||
continue
|
||||
|
||||
all_image_list = list(image_lists.values())
|
||||
trackIdList = list(set(track_id_list))
|
||||
|
||||
return all_image_list, trackIdList
|
||||
|
||||
@staticmethod
|
||||
def process_topn_data(source_data):
|
||||
if source_data is None:
|
||||
return None
|
||||
|
||||
if not isinstance(source_data, dict):
|
||||
raise ValueError("输入数据必须是字典类型")
|
||||
|
||||
if not source_data:
|
||||
return {}
|
||||
|
||||
total = {}
|
||||
carId_barcode_trackId_list = []
|
||||
data_category = []
|
||||
|
||||
for category, category_data in source_data.items():
|
||||
carId_barcode_trackId_list.append(category)
|
||||
for car_id, similarity in category_data.items():
|
||||
data_category.append({'carId_barcode_trackId_n': car_id, 'similarity': similarity})
|
||||
|
||||
total['carId_barcode_trackId'] = carId_barcode_trackId_list
|
||||
total['data'] = data_category
|
||||
|
||||
return total
|
||||
|
||||
@staticmethod
|
||||
def process_top10_data(source_data):
|
||||
if source_data is None:
|
||||
return None
|
||||
|
||||
if not isinstance(source_data, dict):
|
||||
raise ValueError("输入数据必须是字典类型")
|
||||
|
||||
if not source_data:
|
||||
return {}
|
||||
|
||||
total = {}
|
||||
data_category = []
|
||||
|
||||
for category, category_data in source_data.items():
|
||||
trackid = category.split('_')[-1]
|
||||
barcode = category.split('_')[-2]
|
||||
for car_id, similarity in category_data.items():
|
||||
data_category.append({'barcode': car_id, 'similarity': similarity, 'trackid': trackid})
|
||||
|
||||
total['barcode'] = barcode
|
||||
total['data'] = data_category
|
||||
return total
|
||||
|
||||
def getSimilarity(self, model, queueImgs):
|
||||
data_collection = datacollection()
|
||||
similarityRes = similarityResult()
|
||||
data_collection.barcode_flag = queueImgs['barcode_flag']
|
||||
data_collection.add_flag = queueImgs['add_flag']
|
||||
data_collection.barcode_list = queueImgs['barcode_list'].strip("'").split(',')
|
||||
data_collection.queImgsDict = queueImgs
|
||||
|
||||
similarityRes = similarity().getSimilarity(model, data_collection, similarityRes)
|
||||
# print('similarityRes.top10: ------------------ {}'.format(similarityRes.top10))
|
||||
if similarityRes.top1:
|
||||
similarityRes.top1 = {"barcode": list(similarityRes.top1.keys())[0],
|
||||
"similarity": list(similarityRes.top1.values())[0]}
|
||||
# similarityRes.tempLibList = gvalue.tempLibList
|
||||
# print('-------------------------', gvalue.tempLibLists)
|
||||
if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
|
||||
similarityRes.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
|
||||
else:
|
||||
similarityRes.tempLibList = []
|
||||
similarityresult = {
|
||||
'top10': AiClass.process_top10_data(similarityRes.top10),
|
||||
'top1': similarityRes.top1,
|
||||
'topn': AiClass.process_topn_data(similarityRes.topn),
|
||||
'tempLibList': similarityRes.tempLibList,
|
||||
'sequenceId': queueImgs['sequenceId'],
|
||||
}
|
||||
return similarityresult
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
AI = AiClass()
|
||||
|
||||
# track_boxes, frame_id_img = run()
|
||||
# AI.getTrackingBox(track_boxes, frame_id_img)
|
||||
# print('=== test ===')
|
||||
# AI.getSimilarity(cfg.queueImgs)
|
Reference in New Issue
Block a user