This commit is contained in:
lee
2024-11-27 15:37:10 +08:00
commit 3a5214c796
696 changed files with 56947 additions and 0 deletions

175
tools/Interface.py Normal file
View File

@ -0,0 +1,175 @@
import abc
# import os
# import pdb
# import pickle
import sys
import cv2
import numpy as np
from tools.config import gvalue
sys.path.append('./ytracking')
sys.path.append('./contrast')
# from ytracking.tracking.dotrack import init_tracker, VideoTracks, boxes_add_fid
from ytracking.tracking.have_tracking import have_tracked
from ytracking.track_ import *
from contrast.logic import datacollection, similarityResult, similarity
from PIL import Image
from tools.config import gvalue
class AiInterface(metaclass=abc.ABCMeta):
@abc.abstractmethod
def getTrackingBox(self):
pass
@abc.abstractmethod
def getSimilarity(self):
pass
class AiClass(AiInterface):
def __init__(self):
pass
def get_xyxy_coordinates(self, box, frame_id_img):
"""
计算并返回边界框的坐标。
"""
try:
x1 = max(0, int(box[0]))
x2 = min(frame_id_img.shape[1], int(box[2]))
y1 = max(0, int(box[1]))
y2 = min(frame_id_img.shape[0], int(box[3]))
return x1, y1, x2, y2
except IndexError as e:
raise ValueError("边界框坐标超出图像尺寸") from e
def getTrackingBox(self, bboxes, features_dict, camera_id, frame_id_img, save_imgs_dir):
"""
根据提供的边界框和帧图像返回图像列表和轨迹ID列表。
"""
image_lists = {}
track_id_list = []
gt = Profile()
with gt:
vts = have_tracked(bboxes, features_dict, camera_id)
nn = 0
for res in vts.Residual:
for box in res.boxes:
try:
box = [int(i) for i in box.tolist()]
print('box[7] >>>> {}'.format(box[7]))
x1, y1, x2, y2 = self.get_xyxy_coordinates(box, frame_id_img[box[7]])
gvalue.track_y_lists.append(y1)
c_img = frame_id_img[box[7]][y1:y2, x1:x2][:, :, ::-1]
# c_img = frame_id_img[box[7]][box[1]:box[3], box[0]:box[2]][:, :, ::-1]
img_pil = Image.fromarray(c_img.astype('uint8'), 'RGB')
img_pil.save(os.sep.join([save_imgs_dir, str(nn) + '.jpg']))
nn += 1
track_id = str(box[4])
track_id_list.append(track_id)
if track_id not in image_lists:
image_lists[track_id] = []
image_lists[track_id].append(img_pil)
except Exception as e:
print("y1: {}, y2: {}, x1:{} x2:{}".format(box[2], box[3], box[0], box[1]))
print("x:{}, y:{}".format(frame_id_img[box[7]].shape[1], frame_id_img[box[7]].shape[0]))
print(f"处理边界框时发生错误: {e}")
continue
all_image_list = list(image_lists.values())
trackIdList = list(set(track_id_list))
return all_image_list, trackIdList
@staticmethod
def process_topn_data(source_data):
if source_data is None:
return None
if not isinstance(source_data, dict):
raise ValueError("输入数据必须是字典类型")
if not source_data:
return {}
total = {}
carId_barcode_trackId_list = []
data_category = []
for category, category_data in source_data.items():
carId_barcode_trackId_list.append(category)
for car_id, similarity in category_data.items():
data_category.append({'carId_barcode_trackId_n': car_id, 'similarity': similarity})
total['carId_barcode_trackId'] = carId_barcode_trackId_list
total['data'] = data_category
return total
@staticmethod
def process_top10_data(source_data):
if source_data is None:
return None
if not isinstance(source_data, dict):
raise ValueError("输入数据必须是字典类型")
if not source_data:
return {}
total = {}
data_category = []
for category, category_data in source_data.items():
trackid = category.split('_')[-1]
barcode = category.split('_')[-2]
for car_id, similarity in category_data.items():
data_category.append({'barcode': car_id, 'similarity': similarity, 'trackid': trackid})
total['barcode'] = barcode
total['data'] = data_category
return total
def getSimilarity(self, model, queueImgs):
data_collection = datacollection()
similarityRes = similarityResult()
data_collection.barcode_flag = queueImgs['barcode_flag']
data_collection.add_flag = queueImgs['add_flag']
data_collection.barcode_list = queueImgs['barcode_list'].strip("'").split(',')
data_collection.queImgsDict = queueImgs
similarityRes = similarity().getSimilarity(model, data_collection, similarityRes)
# print('similarityRes.top10: ------------------ {}'.format(similarityRes.top10))
if similarityRes.top1:
similarityRes.top1 = {"barcode": list(similarityRes.top1.keys())[0],
"similarity": list(similarityRes.top1.values())[0]}
# similarityRes.tempLibList = gvalue.tempLibList
# print('-------------------------', gvalue.tempLibLists)
if gvalue.tempLibLists.get(gvalue.mac_id) is not None:
similarityRes.tempLibList = gvalue.tempLibLists[gvalue.mac_id]
else:
similarityRes.tempLibList = []
similarityresult = {
'top10': AiClass.process_top10_data(similarityRes.top10),
'top1': similarityRes.top1,
'topn': AiClass.process_topn_data(similarityRes.topn),
'tempLibList': similarityRes.tempLibList,
'sequenceId': queueImgs['sequenceId'],
}
return similarityresult
if __name__ == '__main__':
AI = AiClass()
# track_boxes, frame_id_img = run()
# AI.getTrackingBox(track_boxes, frame_id_img)
# print('=== test ===')
# AI.getSimilarity(cfg.queueImgs)