diff --git a/contrast/feat_extract/config.py b/contrast/feat_extract/config.py index 463ec75..f0cc387 100644 --- a/contrast/feat_extract/config.py +++ b/contrast/feat_extract/config.py @@ -61,7 +61,7 @@ class Config: test_val = "D:/比对/cl" # test_val = "./data/test_data_100" - test_model = "checkpoints/best_resnet18_v11.pth" + test_model = "checkpoints/best_resnet18_v12.pth" # test_model = "checkpoints/zhanting_res_801.pth" diff --git a/contrast/one2one_contrast.py b/contrast/one2one_contrast.py index 3ee7eb9..f2b2ba3 100644 --- a/contrast/one2one_contrast.py +++ b/contrast/one2one_contrast.py @@ -194,6 +194,9 @@ def simi_calc(event, stdfeat): if len(evtfeat)==0 or len(stdfeat)==0: return None, None, None + evtfeat /= np.linalg.norm(evtfeat, axis=1)[:, None] + stdfeat /= np.linalg.norm(stdfeat, axis=1)[:, None] + matrix = 1 - cdist(evtfeat, stdfeat, 'cosine') matrix[matrix < 0] = 0 @@ -212,19 +215,19 @@ def build_std_evt_dict(): eventDataPath: Event对象地址 ''' - # stdBarcode = [p.stem for p in Path(stdFeaturePath).iterdir() if p.is_file() and p.suffix=='.json'] + stdBarcode = [p.stem for p in Path(stdFeaturePath).iterdir() if p.is_file() and (p.suffix=='.json' or p.suffix=='.pickle')] '''*********** USearch ***********''' - stdFeaturePath = r"D:\contrast\stdlib\v11_test.json" - stdBarcode = [] - stdlib = {} - with open(stdFeaturePath, 'r', encoding='utf-8') as f: - data = json.load(f) - for dic in data['total']: - barcode = dic['key'] - feature = np.array(dic['value']) - stdBarcode.append(barcode) - stdlib[barcode] = feature + # stdFeaturePath = r"D:\contrast\stdlib\v11_test.json" + # stdBarcode = [] + # stdlib = {} + # with open(stdFeaturePath, 'r', encoding='utf-8') as f: + # data = json.load(f) + # for dic in data['total']: + # barcode = dic['key'] + # feature = np.array(dic['value']) + # stdBarcode.append(barcode) + # stdlib[barcode] = feature '''======1. 购物事件列表,该列表中的 Barcode 存在于标准的 stdBarcode 内 ===''' evtList = [(p.stem, p.stem.split('_')[-1]) for p in Path(eventDataPath).iterdir() @@ -237,18 +240,31 @@ def build_std_evt_dict(): barcodes = set([bcd for _, bcd in evtList]) '''======2. 构建用于比对的标准特征字典 =============''' - # stdDict = {} - # for barcode in barcodes: - # stdpath = os.path.join(stdFeaturePath, barcode+'.json') - # with open(stdpath, 'r', encoding='utf-8') as f: - # stddata = json.load(f) - # feat = np.array(stddata["value"]) - # stdDict[barcode] = feat + stdDict = {} + for stdfile in os.listdir(stdFeaturePath): + barcode, ext = os.path.splitext(stdfile) + if barcode not in barcodes: + continue + stdpath = os.path.join(stdFeaturePath, stdfile) + + if ext == ".json": + with open(stdpath, 'r', encoding='utf-8') as f: + stddata = json.load(f) + feat = np.array(stddata["value"]) + stdDict[barcode] = feat + if ext == ".pickle": + with open(stdpath, 'rb') as f: + stddata = pickle.load(f) + feat = stddata["feats_ft32"] + stdDict[barcode] = feat + + + '''*********** USearch ***********''' - stdDict = {} - for barcode in barcodes: - stdDict[barcode] = stdlib[barcode] + # stdDict = {} + # for barcode in barcodes: + # stdDict[barcode] = stdlib[barcode] '''======3. 构建用于比对的操作事件字典 =============''' evtDict = {} @@ -390,7 +406,7 @@ def one2one_simi(evtList, evtDict, stdDict): evtname, stdbcd, label = mergePairs[i] event = evtDict[evtname] if len(event.feats_compose)==0: continue - + stdfeat = stdDict[stdbcd] # float32 simi_mean, simi_max, simi_mfeat = simi_calc(event, stdfeat) @@ -402,7 +418,6 @@ def one2one_simi(evtList, evtDict, stdDict): '''================ float32、16、int8 精度比较与存储 =============''' # data_precision_compare(stdfeat, evtfeat, mergePairs[i], save=True) - return rltdata @@ -488,14 +503,28 @@ def gen_eventdict(sourcePath, saveimg=True): for source_path in sourcePath: evtpath, bname = os.path.split(source_path) - # bname = r"20241126-135911-bdf91cf9-3e9a-426d-94e8-ddf92238e175_6923555210479" - source_path = os.path.join(evtpath, bname) + ## 兼容事件的两种情况:文件夹 和 Yolo-Resnet-Tracker 的输出 + if os.path.isfile(source_path): + bname, ext = os.path.splitext(bname) + evt = bname.split("_") + + evt = bname.split('_') + condt = len(evt)>=2 and evt[-1].isdigit() and len(evt[-1])>=10 + if not condt: continue + # bname = r"20241126-135911-bdf91cf9-3e9a-426d-94e8-ddf92238e175_6923555210479" + # source_path = os.path.join(evtpath, bname) + + # 如果已完成事件生成,则不执行 pickpath = os.path.join(eventDataPath, f"{bname}.pickle") if os.path.isfile(pickpath): continue + # event = ShoppingEvent(source_path, stype="data") + # with open(pickpath, 'wb') as f: + # pickle.dump(event, f) + try: - event = ShoppingEvent(source_path, stype="data") + event = ShoppingEvent(source_path, stype="source") # save_data(event, resultPath) with open(pickpath, 'wb') as f: @@ -521,9 +550,18 @@ def init_std_evt_dict(): bcdList, event_spath = [], [] for evtpath in eventSourcePath: for evtname in os.listdir(evtpath): - evt = evtname.split('_') - dirpath = os.path.join(evtpath, evtname) - if os.path.isfile(dirpath): continue + bname, ext = os.path.splitext(evtname) + + ## 处理事件的两种情况:文件夹 和 Yolo-Resnet-Tracker 的输出 + fpath = os.path.join(evtpath, evtname) + if os.path.isfile(fpath) and (ext==".pkl" or ext==".pickle"): + evt = bname.split('_') + elif os.path.isdir(fpath): + evt = evtname.split('_') + else: + continue + + if len(evt)>=2 and evt[-1].isdigit() and len(evt[-1])>=10: bcdList.append(evt[-1]) event_spath.append(os.path.join(evtpath, evtname)) @@ -579,15 +617,20 @@ if __name__ == '__main__': (7) similPath: 1:1比对结果存储地址(事件级),在resultPath下 ''' - stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase" - stdBarcodePath = r"D:\exhibition\dataset\bcdpath" - stdFeaturePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\features_json\v11_barcode_11592" + # stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\展厅数据\v1.0\比对数据\整理\zhantingBase" + # stdBarcodePath = r"D:\exhibition\dataset\bcdpath" + # stdFeaturePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\features_json\v11_barcode_11592" # eventSourcePath = [r'D:\exhibition\images\20241202'] # eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\各模块测试记录\展厅测试\1129_展厅模型v801测试组测试"] - eventSourcePath = [r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\images"] - resultPath = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\result" + + stdSamplePath = r"\\192.168.1.28\share\数据\已完成数据\比对数据\barcode\all_totalBarocde\totalBarcode" + stdBarcodePath = r"D:\全实时\source_data\bcdpath" + stdFeaturePath = r"D:\全实时\source_data\stdfeats" + + eventSourcePath = [r"D:\全实时\result\pipeline\pipeline"] + resultPath = r"D:\全实时\result\pipeline" eventDataPath = os.path.join(resultPath, "evtobjs") diff --git a/pipeline.py b/pipeline.py index cf9adad..4cbafa3 100644 --- a/pipeline.py +++ b/pipeline.py @@ -254,11 +254,10 @@ def main(): ''' 函数:pipeline(),遍历事件文件夹,选择类型 image 或 video, ''' - evtdir = r"D:\全实时\source_data\2024122416" - evtdir = Path(evtdir) - + evtdir = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\images" + evtdir = Path(evtdir) parmDict = {} - parmDict["savepath"] = r"D:\全实时\result\pipeline" + parmDict["savepath"] = r"D:\contrast\202412测试" parmDict["SourceType"] = "video" # video, image parmDict["stdfeat_path"] = None diff --git a/tracking/utils/__pycache__/read_data.cpython-39.pyc b/tracking/utils/__pycache__/read_data.cpython-39.pyc index b60e5fe..f9eace9 100644 Binary files a/tracking/utils/__pycache__/read_data.cpython-39.pyc and b/tracking/utils/__pycache__/read_data.cpython-39.pyc differ diff --git a/tracking/utils/read_data.py b/tracking/utils/read_data.py index a30d2fe..bbcd2da 100644 --- a/tracking/utils/read_data.py +++ b/tracking/utils/read_data.py @@ -278,7 +278,8 @@ def read_tracking_output(filepath): line = line.strip() # 去除行尾的换行符和可能的空白字符 if not line: - continue + continue + if line.find("gift")>0: continue if line.endswith(','): line = line[:-1]