20240102
This commit is contained in:
@ -61,7 +61,7 @@ class Config:
|
||||
test_val = "D:/比对/cl"
|
||||
# test_val = "./data/test_data_100"
|
||||
|
||||
test_model = "checkpoints/best_resnet18_v11.pth"
|
||||
test_model = "checkpoints/best_resnet18_v12.pth"
|
||||
# test_model = "checkpoints/zhanting_res_801.pth"
|
||||
|
||||
|
||||
|
@ -194,6 +194,9 @@ def simi_calc(event, stdfeat):
|
||||
if len(evtfeat)==0 or len(stdfeat)==0:
|
||||
return None, None, None
|
||||
|
||||
evtfeat /= np.linalg.norm(evtfeat, axis=1)[:, None]
|
||||
stdfeat /= np.linalg.norm(stdfeat, axis=1)[:, None]
|
||||
|
||||
matrix = 1 - cdist(evtfeat, stdfeat, 'cosine')
|
||||
matrix[matrix < 0] = 0
|
||||
|
||||
@ -212,19 +215,19 @@ def build_std_evt_dict():
|
||||
eventDataPath: Event对象地址
|
||||
'''
|
||||
|
||||
# stdBarcode = [p.stem for p in Path(stdFeaturePath).iterdir() if p.is_file() and p.suffix=='.json']
|
||||
stdBarcode = [p.stem for p in Path(stdFeaturePath).iterdir() if p.is_file() and (p.suffix=='.json' or p.suffix=='.pickle')]
|
||||
|
||||
'''*********** USearch ***********'''
|
||||
stdFeaturePath = r"D:\contrast\stdlib\v11_test.json"
|
||||
stdBarcode = []
|
||||
stdlib = {}
|
||||
with open(stdFeaturePath, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
for dic in data['total']:
|
||||
barcode = dic['key']
|
||||
feature = np.array(dic['value'])
|
||||
stdBarcode.append(barcode)
|
||||
stdlib[barcode] = feature
|
||||
# stdFeaturePath = r"D:\contrast\stdlib\v11_test.json"
|
||||
# stdBarcode = []
|
||||
# stdlib = {}
|
||||
# with open(stdFeaturePath, 'r', encoding='utf-8') as f:
|
||||
# data = json.load(f)
|
||||
# for dic in data['total']:
|
||||
# barcode = dic['key']
|
||||
# feature = np.array(dic['value'])
|
||||
# stdBarcode.append(barcode)
|
||||
# stdlib[barcode] = feature
|
||||
|
||||
'''======1. 购物事件列表,该列表中的 Barcode 存在于标准的 stdBarcode 内 ==='''
|
||||
evtList = [(p.stem, p.stem.split('_')[-1]) for p in Path(eventDataPath).iterdir()
|
||||
@ -237,18 +240,31 @@ def build_std_evt_dict():
|
||||
barcodes = set([bcd for _, bcd in evtList])
|
||||
|
||||
'''======2. 构建用于比对的标准特征字典 ============='''
|
||||
# stdDict = {}
|
||||
# for barcode in barcodes:
|
||||
# stdpath = os.path.join(stdFeaturePath, barcode+'.json')
|
||||
# with open(stdpath, 'r', encoding='utf-8') as f:
|
||||
# stddata = json.load(f)
|
||||
# feat = np.array(stddata["value"])
|
||||
# stdDict[barcode] = feat
|
||||
stdDict = {}
|
||||
for stdfile in os.listdir(stdFeaturePath):
|
||||
barcode, ext = os.path.splitext(stdfile)
|
||||
if barcode not in barcodes:
|
||||
continue
|
||||
stdpath = os.path.join(stdFeaturePath, stdfile)
|
||||
|
||||
if ext == ".json":
|
||||
with open(stdpath, 'r', encoding='utf-8') as f:
|
||||
stddata = json.load(f)
|
||||
feat = np.array(stddata["value"])
|
||||
stdDict[barcode] = feat
|
||||
if ext == ".pickle":
|
||||
with open(stdpath, 'rb') as f:
|
||||
stddata = pickle.load(f)
|
||||
feat = stddata["feats_ft32"]
|
||||
stdDict[barcode] = feat
|
||||
|
||||
|
||||
|
||||
|
||||
'''*********** USearch ***********'''
|
||||
stdDict = {}
|
||||
for barcode in barcodes:
|
||||
stdDict[barcode] = stdlib[barcode]
|
||||
# stdDict = {}
|
||||
# for barcode in barcodes:
|
||||
# stdDict[barcode] = stdlib[barcode]
|
||||
|
||||
'''======3. 构建用于比对的操作事件字典 ============='''
|
||||
evtDict = {}
|
||||
@ -390,7 +406,7 @@ def one2one_simi(evtList, evtDict, stdDict):
|
||||
evtname, stdbcd, label = mergePairs[i]
|
||||
event = evtDict[evtname]
|
||||
if len(event.feats_compose)==0: continue
|
||||
|
||||
|
||||
stdfeat = stdDict[stdbcd] # float32
|
||||
|
||||
simi_mean, simi_max, simi_mfeat = simi_calc(event, stdfeat)
|
||||
@ -402,7 +418,6 @@ def one2one_simi(evtList, evtDict, stdDict):
|
||||
'''================ float32、16、int8 精度比较与存储 ============='''
|
||||
# data_precision_compare(stdfeat, evtfeat, mergePairs[i], save=True)
|
||||
|
||||
|
||||
|
||||
return rltdata
|
||||
|
||||
@ -488,14 +503,28 @@ def gen_eventdict(sourcePath, saveimg=True):
|
||||
for source_path in sourcePath:
|
||||
evtpath, bname = os.path.split(source_path)
|
||||
|
||||
# bname = r"20241126-135911-bdf91cf9-3e9a-426d-94e8-ddf92238e175_6923555210479"
|
||||
source_path = os.path.join(evtpath, bname)
|
||||
## 兼容事件的两种情况:文件夹 和 Yolo-Resnet-Tracker 的输出
|
||||
if os.path.isfile(source_path):
|
||||
bname, ext = os.path.splitext(bname)
|
||||
evt = bname.split("_")
|
||||
|
||||
evt = bname.split('_')
|
||||
condt = len(evt)>=2 and evt[-1].isdigit() and len(evt[-1])>=10
|
||||
if not condt: continue
|
||||
|
||||
# bname = r"20241126-135911-bdf91cf9-3e9a-426d-94e8-ddf92238e175_6923555210479"
|
||||
# source_path = os.path.join(evtpath, bname)
|
||||
|
||||
# 如果已完成事件生成,则不执行
|
||||
pickpath = os.path.join(eventDataPath, f"{bname}.pickle")
|
||||
if os.path.isfile(pickpath): continue
|
||||
|
||||
# event = ShoppingEvent(source_path, stype="data")
|
||||
# with open(pickpath, 'wb') as f:
|
||||
# pickle.dump(event, f)
|
||||
|
||||
try:
|
||||
event = ShoppingEvent(source_path, stype="data")
|
||||
event = ShoppingEvent(source_path, stype="source")
|
||||
# save_data(event, resultPath)
|
||||
|
||||
with open(pickpath, 'wb') as f:
|
||||
@ -521,9 +550,18 @@ def init_std_evt_dict():
|
||||
bcdList, event_spath = [], []
|
||||
for evtpath in eventSourcePath:
|
||||
for evtname in os.listdir(evtpath):
|
||||
evt = evtname.split('_')
|
||||
dirpath = os.path.join(evtpath, evtname)
|
||||
if os.path.isfile(dirpath): continue
|
||||
bname, ext = os.path.splitext(evtname)
|
||||
|
||||
## 处理事件的两种情况:文件夹 和 Yolo-Resnet-Tracker 的输出
|
||||
fpath = os.path.join(evtpath, evtname)
|
||||
if os.path.isfile(fpath) and (ext==".pkl" or ext==".pickle"):
|
||||
evt = bname.split('_')
|
||||
elif os.path.isdir(fpath):
|
||||
evt = evtname.split('_')
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
if len(evt)>=2 and evt[-1].isdigit() and len(evt[-1])>=10:
|
||||
bcdList.append(evt[-1])
|
||||
event_spath.append(os.path.join(evtpath, evtname))
|
||||
@ -579,15 +617,20 @@ if __name__ == '__main__':
|
||||
(7) similPath: 1:1比对结果存储地址(事件级),在resultPath下
|
||||
'''
|
||||
|
||||
stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase"
|
||||
stdBarcodePath = r"D:\exhibition\dataset\bcdpath"
|
||||
stdFeaturePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\features_json\v11_barcode_11592"
|
||||
# stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase"
|
||||
# stdBarcodePath = r"D:\exhibition\dataset\bcdpath"
|
||||
# stdFeaturePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\features_json\v11_barcode_11592"
|
||||
|
||||
# eventSourcePath = [r'D:\exhibition\images\20241202']
|
||||
# eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\各模块测试记录\展厅测试\1129_展厅模型v801测试组测试"]
|
||||
|
||||
eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\images"]
|
||||
resultPath = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\result"
|
||||
|
||||
stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\totalBarcode"
|
||||
stdBarcodePath = r"D:\全实时\source_data\bcdpath"
|
||||
stdFeaturePath = r"D:\全实时\source_data\stdfeats"
|
||||
|
||||
eventSourcePath = [r"D:\全实时\result\pipeline\pipeline"]
|
||||
resultPath = r"D:\全实时\result\pipeline"
|
||||
|
||||
|
||||
eventDataPath = os.path.join(resultPath, "evtobjs")
|
||||
|
Reference in New Issue
Block a user