Files
detecttracking/contrast/trail2trail.py
2025-04-11 17:02:39 +08:00

172 lines
5.0 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
取出再放回场景下商品轨迹特征比对方式与性能分析
Created on Tue Apr 1 17:17:47 2025
@author: wqg
"""
import os
import pickle
import random
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from utils.calsimi import calsiml, calsimi_vs_evts
def read_eventdict(evtpaths):
    """Load every ``*.pickle`` file found directly inside ``evtpaths``.

    Args:
        evtpaths: directory to scan (entries are not searched recursively).

    Returns:
        dict mapping each file name with its ``.pickle`` extension removed to
        the unpickled object. Files with any other extension are ignored.

    Note:
        ``pickle.load`` can execute arbitrary code; only use this on
        directories whose contents you trust.
    """
    def _unpickle(path):
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    loaded = {}
    for entry in os.listdir(evtpaths):
        stem, suffix = os.path.splitext(entry)
        if suffix == ".pickle":
            loaded[stem] = _unpickle(os.path.join(evtpaths, entry))
    return loaded
def _pr_curves(Same, Cross, Thresh):
    """Compute per-threshold precision/recall/accuracy for a same/cross score split.

    A pair is predicted "same" when its score is >= the threshold.

    Args:
        Same: 1-D float array of scores for pairs that should match (positives).
        Cross: 1-D float array of scores for pairs that should not match (negatives).
        Thresh: 1-D float array of thresholds.

    Returns:
        dict mapping metric name to a float array with one value per threshold.
    """
    eps = 1e-6  # keeps the ratios finite when a denominator is zero
    th = Thresh[:, None]  # (n_thresh, 1) so each comparison broadcasts over all scores
    TP = np.sum(Same[None, :] >= th, axis=1)
    FN = np.sum(Same[None, :] < th, axis=1)
    TN = np.sum(Cross[None, :] < th, axis=1)
    FP = np.sum(Cross[None, :] >= th, axis=1)
    total = Same.size + Cross.size
    if total:
        correct = (TN + TP) / total
    else:
        # No scores at all: report 0 rather than 0/0 = nan (plus a RuntimeWarning).
        correct = np.zeros(Thresh.shape, dtype=float)
    return {
        "Precision_Pos": TP / (TP + FP + eps),
        "Precision_Neg": TN / (TN + FN + eps),
        "Recall_Pos": TP / (TP + FN + eps),
        "Recall_Neg": TN / (TN + FP + eps),
        "Correct": correct,
    }


def compute_show_pr(Same, Cross):
    """Plot threshold curves and score histograms for same/cross pairs.

    Args:
        Same: similarity scores of pairs that should match; any array-like of
            numbers (lists are accepted as well as ndarrays).
        Cross: similarity scores of pairs that should not match.

    Shows two matplotlib figures, calling ``plt.show()`` after each one.
    The first plots precision, recall and overall accuracy against a
    threshold swept over [-0.2, 1]. The second shows histograms of both
    score sets. Returns None.
    """
    Same = np.asarray(Same, dtype=float).ravel()
    Cross = np.asarray(Cross, dtype=float).ravel()
    TPFN = Same.size
    TNFP = Cross.size

    Thresh = np.linspace(-0.2, 1, 100)
    curves = _pr_curves(Same, Cross, Thresh)

    fig, ax = plt.subplots()
    ax.plot(Thresh, curves["Precision_Pos"], 'r', label='Precision_Pos: TP/(TP+FP)')
    ax.plot(Thresh, curves["Recall_Pos"], 'b', label='Recall_Pos: TP/TPFN')
    ax.plot(Thresh, curves["Recall_Neg"], 'g', label='Recall_Neg: TN/TNFP')
    ax.plot(Thresh, curves["Correct"], 'c', label='Correct: (TN+TP)/(TPFN+TNFP)')
    ax.plot(Thresh, curves["Precision_Neg"], 'm', label='Precision_Neg: TN/(TN+FN)')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ticks = np.linspace(0, 1, 11)  # 0.0 .. 1.0 inclusive; arange(0, 1, 0.1) dropped 1.0
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.grid(True, linestyle='--')
    ax.set_title('PrecisePos & PreciseNeg')
    ax.set_xlabel(f"Same Num: {TPFN}, Cross Num: {TNFP}")
    ax.legend()
    plt.show()

    fig, axes = plt.subplots(2, 1)
    axes[0].hist(Same, bins=60, range=(-0.2, 1), edgecolor='black')
    axes[0].set_xlim([-0.2, 1])
    axes[0].set_title(f'TP({len(Same)})')
    axes[1].hist(Cross, bins=60, range=(-0.2, 1), edgecolor='black')
    axes[1].set_xlim([-0.2, 1])
    axes[1].set_title(f'TN({len(Cross)})')
    plt.show()
def _group_by_barcode(evtDicts):
    """Group event objects by the barcode encoded at the end of their names.

    A name qualifies when, split on '_', it has at least two parts and the
    last part consists of digits and is at least 10 characters long. That last
    part is taken as the barcode. Names that do not qualify are ignored.

    Returns:
        dict mapping barcode -> list of event objects, in ``evtDicts`` order.
    """
    evtpairDict = {}
    for evtname, evtobj in evtDicts.items():
        parts = evtname.split('_')
        barcode = parts[-1]
        if len(parts) < 2 or not (barcode.isdigit() and len(barcode) >= 10):
            continue
        evtpairDict.setdefault(barcode, []).append(evtobj)
    return evtpairDict


def _sample_pairs(evtpairDict):
    """Draw one "same" and one "diff" pair for every barcode with 2+ events.

    The "same" pair is two distinct events of that barcode. The "diff" pair is
    one of its events plus a random event of another barcode; it is skipped
    when no other barcode exists.

    Returns:
        list of (event_a, event_b, label) tuples, all "same" pairs first.
    """
    # Sorted so that a seeded `random` module yields reproducible pairs
    # (iterating a set depends on per-process string hashing).
    barcodes = sorted(evtpairDict)
    same_pairs, diff_pairs = [], []
    for barcode in barcodes:
        events = evtpairDict[barcode]
        if len(events) < 2:
            continue
        evta, evtb = random.sample(events, 2)
        same_pairs.append((evta, evtb, "same"))
        others = [b for b in barcodes if b != barcode]
        if not others:
            # Only one barcode: there is no negative partner to draw from.
            continue
        evtc = random.choice(events)
        evtd = random.choice(evtpairDict[random.choice(others)])
        diff_pairs.append((evtc, evtd, "diff"))
    return same_pairs + diff_pairs


def trail_to_trail(evtpaths, rltpaths, simType=2):
    """Score same- vs cross-barcode event pairs and plot how well they separate.

    Args:
        evtpaths: directory of pickled ShoppingEvent objects, loaded with
            ``read_eventdict``. File names must end in ``_<barcode>``.
        rltpaths: result directory; currently not used by this function.
        simType: trail-feature similarity method forwarded to
            ``calsimi_vs_evts``. Defaults to 2, the previously hard-coded value.
    """
    ##1. read all the ShoppingEvent objects in the dir 'evtpaths'
    evtDicts = read_eventdict(evtpaths)

    ##2. group events by barcode and sample labelled pairs
    mergePairs = _sample_pairs(_group_by_barcode(evtDicts))

    ##3. similarity of each pair; pairs for which calsimi_vs_evts returns None are skipped
    new_pairs = []
    for evta, evtb, label in mergePairs:
        similar = calsimi_vs_evts(evta, evtb, simType)
        if similar is None:
            continue
        new_pairs.append((label, round(similar, 3), evta.evtname[:15], evtb.evtname[:15]))

    ##4. compute PR and show the plots
    Same = np.array([s for label, s, _, _ in new_pairs if label == "same"])
    Cross = np.array([s for label, s, _, _ in new_pairs if label == "diff"])
    compute_show_pr(Same, Cross)
def main():
    """Run ``trail_to_trail`` once for each configured event-set directory."""
    # To evaluate only one set, shrink this list (e.g. ["single_event_V10"]).
    evttypes = ["single_event_V10", "single_event_V5", "performence_V10", "performence_V5"]
    evt_template = "/home/wqg/dataset/pipeline/contrast/{}/evtobjs/"
    rlt_template = "/home/wqg/dataset/pipeline/yrt/{}/yolos_tracking"
    for name in evttypes:
        trail_to_trail(evt_template.format(name), rlt_template.format(name))
if __name__ == "__main__":
    # Entry point when executed as a script (not on import).
    main()