This commit is contained in:
王庆刚
2025-01-02 18:23:16 +08:00
parent 7e13e0f5b4
commit 661489120b
5 changed files with 83 additions and 40 deletions

View File

@ -61,7 +61,7 @@ class Config:
test_val = "D:/比对/cl"
# test_val = "./data/test_data_100"
test_model = "checkpoints/best_resnet18_v11.pth"
test_model = "checkpoints/best_resnet18_v12.pth"
# test_model = "checkpoints/zhanting_res_801.pth"

View File

@ -194,6 +194,9 @@ def simi_calc(event, stdfeat):
if len(evtfeat)==0 or len(stdfeat)==0:
return None, None, None
evtfeat /= np.linalg.norm(evtfeat, axis=1)[:, None]
stdfeat /= np.linalg.norm(stdfeat, axis=1)[:, None]
matrix = 1 - cdist(evtfeat, stdfeat, 'cosine')
matrix[matrix < 0] = 0
@ -212,19 +215,19 @@ def build_std_evt_dict():
eventDataPath: Event对象地址
'''
# stdBarcode = [p.stem for p in Path(stdFeaturePath).iterdir() if p.is_file() and p.suffix=='.json']
stdBarcode = [p.stem for p in Path(stdFeaturePath).iterdir() if p.is_file() and (p.suffix=='.json' or p.suffix=='.pickle')]
'''*********** USearch ***********'''
stdFeaturePath = r"D:\contrast\stdlib\v11_test.json"
stdBarcode = []
stdlib = {}
with open(stdFeaturePath, 'r', encoding='utf-8') as f:
data = json.load(f)
for dic in data['total']:
barcode = dic['key']
feature = np.array(dic['value'])
stdBarcode.append(barcode)
stdlib[barcode] = feature
# stdFeaturePath = r"D:\contrast\stdlib\v11_test.json"
# stdBarcode = []
# stdlib = {}
# with open(stdFeaturePath, 'r', encoding='utf-8') as f:
# data = json.load(f)
# for dic in data['total']:
# barcode = dic['key']
# feature = np.array(dic['value'])
# stdBarcode.append(barcode)
# stdlib[barcode] = feature
'''======1. 购物事件列表,该列表中的 Barcode 存在于标准的 stdBarcode 内 ==='''
evtList = [(p.stem, p.stem.split('_')[-1]) for p in Path(eventDataPath).iterdir()
@ -237,18 +240,31 @@ def build_std_evt_dict():
barcodes = set([bcd for _, bcd in evtList])
'''======2. 构建用于比对的标准特征字典 ============='''
# stdDict = {}
# for barcode in barcodes:
# stdpath = os.path.join(stdFeaturePath, barcode+'.json')
# with open(stdpath, 'r', encoding='utf-8') as f:
# stddata = json.load(f)
# feat = np.array(stddata["value"])
# stdDict[barcode] = feat
stdDict = {}
for stdfile in os.listdir(stdFeaturePath):
barcode, ext = os.path.splitext(stdfile)
if barcode not in barcodes:
continue
stdpath = os.path.join(stdFeaturePath, stdfile)
if ext == ".json":
with open(stdpath, 'r', encoding='utf-8') as f:
stddata = json.load(f)
feat = np.array(stddata["value"])
stdDict[barcode] = feat
if ext == ".pickle":
with open(stdpath, 'rb') as f:
stddata = pickle.load(f)
feat = stddata["feats_ft32"]
stdDict[barcode] = feat
'''*********** USearch ***********'''
stdDict = {}
for barcode in barcodes:
stdDict[barcode] = stdlib[barcode]
# stdDict = {}
# for barcode in barcodes:
# stdDict[barcode] = stdlib[barcode]
'''======3. 构建用于比对的操作事件字典 ============='''
evtDict = {}
@ -390,7 +406,7 @@ def one2one_simi(evtList, evtDict, stdDict):
evtname, stdbcd, label = mergePairs[i]
event = evtDict[evtname]
if len(event.feats_compose)==0: continue
stdfeat = stdDict[stdbcd] # float32
simi_mean, simi_max, simi_mfeat = simi_calc(event, stdfeat)
@ -402,7 +418,6 @@ def one2one_simi(evtList, evtDict, stdDict):
'''================ float32、16、int8 精度比较与存储 ============='''
# data_precision_compare(stdfeat, evtfeat, mergePairs[i], save=True)
return rltdata
@ -488,14 +503,28 @@ def gen_eventdict(sourcePath, saveimg=True):
for source_path in sourcePath:
evtpath, bname = os.path.split(source_path)
# bname = r"20241126-135911-bdf91cf9-3e9a-426d-94e8-ddf92238e175_6923555210479"
source_path = os.path.join(evtpath, bname)
## 兼容事件的两种情况:文件夹 和 Yolo-Resnet-Tracker 的输出
if os.path.isfile(source_path):
bname, ext = os.path.splitext(bname)
evt = bname.split("_")
evt = bname.split('_')
condt = len(evt)>=2 and evt[-1].isdigit() and len(evt[-1])>=10
if not condt: continue
# bname = r"20241126-135911-bdf91cf9-3e9a-426d-94e8-ddf92238e175_6923555210479"
# source_path = os.path.join(evtpath, bname)
# 如果已完成事件生成,则不执行
pickpath = os.path.join(eventDataPath, f"{bname}.pickle")
if os.path.isfile(pickpath): continue
# event = ShoppingEvent(source_path, stype="data")
# with open(pickpath, 'wb') as f:
# pickle.dump(event, f)
try:
event = ShoppingEvent(source_path, stype="data")
event = ShoppingEvent(source_path, stype="source")
# save_data(event, resultPath)
with open(pickpath, 'wb') as f:
@ -521,9 +550,18 @@ def init_std_evt_dict():
bcdList, event_spath = [], []
for evtpath in eventSourcePath:
for evtname in os.listdir(evtpath):
evt = evtname.split('_')
dirpath = os.path.join(evtpath, evtname)
if os.path.isfile(dirpath): continue
bname, ext = os.path.splitext(evtname)
## 处理事件的两种情况:文件夹 和 Yolo-Resnet-Tracker 的输出
fpath = os.path.join(evtpath, evtname)
if os.path.isfile(fpath) and (ext==".pkl" or ext==".pickle"):
evt = bname.split('_')
elif os.path.isdir(fpath):
evt = evtname.split('_')
else:
continue
if len(evt)>=2 and evt[-1].isdigit() and len(evt[-1])>=10:
bcdList.append(evt[-1])
event_spath.append(os.path.join(evtpath, evtname))
@ -579,15 +617,20 @@ if __name__ == '__main__':
(7) similPath: 1:1比对结果存储地址(事件级)在resultPath下
'''
stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase"
stdBarcodePath = r"D:\exhibition\dataset\bcdpath"
stdFeaturePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\features_json\v11_barcode_11592"
# stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase"
# stdBarcodePath = r"D:\exhibition\dataset\bcdpath"
# stdFeaturePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\features_json\v11_barcode_11592"
# eventSourcePath = [r'D:\exhibition\images\20241202']
# eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\各模块测试记录\展厅测试\1129_展厅模型v801测试组测试"]
eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\images"]
resultPath = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\result"
stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\totalBarcode"
stdBarcodePath = r"D:\全实时\source_data\bcdpath"
stdFeaturePath = r"D:\全实时\source_data\stdfeats"
eventSourcePath = [r"D:\全实时\result\pipeline\pipeline"]
resultPath = r"D:\全实时\result\pipeline"
eventDataPath = os.path.join(resultPath, "evtobjs")

View File

@ -254,11 +254,10 @@ def main():
'''
函数pipeline(),遍历事件文件夹,选择类型 image 或 video,
'''
evtdir = r"D:\全实时\source_data\2024122416"
evtdir = Path(evtdir)
evtdir = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\images"
evtdir = Path(evtdir)
parmDict = {}
parmDict["savepath"] = r"D:\全实时\result\pipeline"
parmDict["savepath"] = r"D:\contrast\202412测试"
parmDict["SourceType"] = "video" # video, image
parmDict["stdfeat_path"] = None

View File

@ -278,7 +278,8 @@ def read_tracking_output(filepath):
line = line.strip() # 去除行尾的换行符和可能的空白字符
if not line:
continue
continue
if line.find("gift")>0: continue
if line.endswith(','):
line = line[:-1]