Files
ieemoo-ai-imageassessment/tools/Interface.py
2024-11-27 15:37:10 +08:00

176 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import abc
# import os
# import pdb
# import pickle
import sys
import cv2
import numpy as np
from tools.config import gvalue
sys.path.append('./ytracking')
sys.path.append('./contrast')
# from ytracking.tracking.dotrack import init_tracker, VideoTracks, boxes_add_fid
from ytracking.tracking.have_tracking import have_tracked
from ytracking.track_ import *
from contrast.logic import datacollection, similarityResult, similarity
from PIL import Image
from tools.config import gvalue
class AiInterface(metaclass=abc.ABCMeta):
    """Abstract contract for the AI facade.

    Concrete subclasses must override both methods below; ``abc.ABCMeta``
    refuses to instantiate any class that leaves either one abstract.
    Note that the abstract signatures only take ``self`` — the concrete
    ``AiClass`` in this module widens them with extra parameters.
    """

    @abc.abstractmethod
    def getTrackingBox(self):
        """Return tracking-box results; must be implemented by subclasses."""

    @abc.abstractmethod
    def getSimilarity(self):
        """Return similarity results; must be implemented by subclasses."""
class AiClass(AiInterface):
    """Concrete AI facade: crops tracked boxes and runs similarity lookups."""

    def __init__(self):
        pass

    def get_xyxy_coordinates(self, box, frame_id_img):
        """Return the box corners clamped against the image bounds.

        Args:
            box: Indexable whose first four entries are ``x1, y1, x2, y2``
                (the layout ``getTrackingBox`` slices with).
            frame_id_img: Image array; ``shape[0]`` is used as the height
                and ``shape[1]`` as the width.

        Returns:
            Tuple ``(x1, y1, x2, y2)`` of ints. ``x1``/``y1`` are clamped
            below at 0 and ``x2``/``y2`` above at the width/height. The
            result is NOT guaranteed to be non-empty (``x1`` may still be
            ``>= x2`` for a box lying outside the image).

        Raises:
            ValueError: if indexing ``box`` or ``frame_id_img.shape`` raises
                ``IndexError`` (e.g. fewer than 4 box entries).
        """
        try:
            x1 = max(0, int(box[0]))
            x2 = min(frame_id_img.shape[1], int(box[2]))
            y1 = max(0, int(box[1]))
            y2 = min(frame_id_img.shape[0], int(box[3]))
            return x1, y1, x2, y2
        except IndexError as e:
            # NOTE(review): the message says "bbox coordinates exceed image
            # size", but this branch is really hit on malformed input.
            raise ValueError("边界框坐标超出图像尺寸") from e

    @staticmethod
    def _report_box_error(box, frame_id_img, err):
        """Print best-effort diagnostics for a box that failed to crop/save.

        The diagnostics themselves index ``box`` and ``frame_id_img`` and may
        fail on exactly the inputs that caused ``err``; such failures are
        suppressed so one bad box never aborts the caller's loop.
        """
        try:
            print("y1: {}, y2: {}, x1:{} x2:{}".format(box[1], box[3], box[0], box[2]))
            frame = frame_id_img[box[7]]
            print("x:{}, y:{}".format(frame.shape[1], frame.shape[0]))
        except Exception:  # diagnostics only; the real error is printed below
            pass
        print(f"处理边界框时发生错误: {err}")

    def getTrackingBox(self, bboxes, features_dict, camera_id, frame_id_img, save_imgs_dir):
        """Crop every tracked box, save the crops, and group them by track id.

        Args:
            bboxes, features_dict, camera_id: Forwarded unchanged to
                ``have_tracked``.
            frame_id_img: Indexed by the frame id in ``box[7]``; each entry
                is an image array whose channel axis is reversed before being
                handed to PIL as RGB (presumably BGR OpenCV frames — confirm).
            save_imgs_dir: Directory where each successful crop is written
                as ``<n>.jpg`` with a running counter ``n``.

        Returns:
            ``(all_image_list, trackIdList)``: ``all_image_list[i]`` is the
            list of PIL crops for track id ``trackIdList[i]`` (ids are the
            string form of ``box[4]``, in first-seen order).

        Side effects:
            Appends each clamped ``y1`` to ``gvalue.track_y_lists`` and
            prints per-box diagnostics. Boxes that fail are reported and
            skipped.
        """
        import os  # the module-level ``import os`` is commented out

        image_lists = {}
        with Profile():
            vts = have_tracked(bboxes, features_dict, camera_id)
        nn = 0
        for res in vts.Residual:
            for box in res.boxes:
                try:
                    box = [int(i) for i in box.tolist()]
                    print('box[7] >>>> {}'.format(box[7]))
                    frame = frame_id_img[box[7]]
                    x1, y1, x2, y2 = self.get_xyxy_coordinates(box, frame)
                    gvalue.track_y_lists.append(y1)
                    # Reverse the channel axis so PIL receives RGB order.
                    c_img = frame[y1:y2, x1:x2][:, :, ::-1]
                    img_pil = Image.fromarray(c_img.astype('uint8'), 'RGB')
                    img_pil.save(os.path.join(save_imgs_dir, '{}.jpg'.format(nn)))
                    nn += 1
                    image_lists.setdefault(str(box[4]), []).append(img_pil)
                except Exception as e:
                    self._report_box_error(box, frame_id_img, e)
        # Derive the ids from the same dict as the images so both lists share
        # one deterministic order (list(set(...)) scrambled the pairing).
        all_image_list = list(image_lists.values())
        trackIdList = list(image_lists.keys())
        return all_image_list, trackIdList

    @staticmethod
    def process_topn_data(source_data):
        """Flatten a ``{category: {key: score}}`` mapping for the top-n payload.

        Returns:
            ``None`` for ``None`` input, ``{}`` for an empty dict, otherwise
            ``{'carId_barcode_trackId': [category, ...],
            'data': [{'carId_barcode_trackId_n': key, 'similarity': score}, ...]}``
            with entries in iteration order.

        Raises:
            ValueError: if ``source_data`` is not a dict.
        """
        if source_data is None:
            return None
        if not isinstance(source_data, dict):
            raise ValueError("输入数据必须是字典类型")
        if not source_data:
            return {}
        data_category = [
            {'carId_barcode_trackId_n': car_id, 'similarity': score}
            for category_data in source_data.values()
            for car_id, score in category_data.items()
        ]
        return {
            'carId_barcode_trackId': list(source_data),
            'data': data_category,
        }

    @staticmethod
    def process_top10_data(source_data):
        """Flatten a ``{category: {key: score}}`` mapping for the top-10 payload.

        Each category key is split on ``'_'``; its last field is used as the
        track id and the second-to-last as the barcode.

        Returns:
            ``None`` for ``None`` input, ``{}`` for an empty dict, otherwise
            ``{'barcode': <barcode of the LAST category>,
            'data': [{'barcode': key, 'similarity': score, 'trackid': trackid}, ...]}``.
            NOTE(review): only the last category's barcode survives; confirm
            all categories share one barcode.

        Raises:
            ValueError: if ``source_data`` is not a dict.
            IndexError: if a category key contains no ``'_'``.
        """
        if source_data is None:
            return None
        if not isinstance(source_data, dict):
            raise ValueError("输入数据必须是字典类型")
        if not source_data:
            return {}
        data_category = []
        barcode = None
        for category, category_data in source_data.items():
            fields = category.split('_')
            trackid = fields[-1]
            barcode = fields[-2]
            for car_id, score in category_data.items():
                data_category.append({'barcode': car_id, 'similarity': score, 'trackid': trackid})
        return {'barcode': barcode, 'data': data_category}

    def getSimilarity(self, model, queueImgs):
        """Run the similarity search for one queued request and shape the reply.

        Args:
            model: Forwarded to ``similarity().getSimilarity``.
            queueImgs: Request dict. Reads ``barcode_flag``, ``add_flag``,
                ``barcode_list`` (a comma-separated string, optionally
                wrapped in single quotes) and ``sequenceId``. The whole dict
                is also stored on the data collection as ``queImgsDict``.

        Returns:
            dict with keys ``top10``, ``top1``, ``topn``, ``tempLibList`` and
            ``sequenceId``.
        """
        data_collection = datacollection()
        similarityRes = similarityResult()
        data_collection.barcode_flag = queueImgs['barcode_flag']
        data_collection.add_flag = queueImgs['add_flag']
        data_collection.barcode_list = queueImgs['barcode_list'].strip("'").split(',')
        data_collection.queImgsDict = queueImgs
        similarityRes = similarity().getSimilarity(model, data_collection, similarityRes)
        if similarityRes.top1:
            # Collapse the {barcode: score} mapping to its first entry.
            barcode, score = next(iter(similarityRes.top1.items()))
            similarityRes.top1 = {"barcode": barcode, "similarity": score}
        # Temporary library for this device (keyed by mac id), or [] if none.
        temp_lib = gvalue.tempLibLists.get(gvalue.mac_id)
        similarityRes.tempLibList = temp_lib if temp_lib is not None else []
        return {
            'top10': AiClass.process_top10_data(similarityRes.top10),
            'top1': similarityRes.top1,
            'topn': AiClass.process_topn_data(similarityRes.topn),
            'tempLibList': similarityRes.tempLibList,
            'sequenceId': queueImgs['sequenceId'],
        }
if __name__ == "__main__":
    # Manual smoke test: only constructs the facade. Exercising
    # getTrackingBox / getSimilarity needs real tracking boxes, frames and a
    # loaded model, so those calls are intentionally not made here.
    AI = AiClass()