abc
This commit is contained in:
@ -27,7 +27,7 @@ model.load_state_dict(torch.load(model_path, map_location=conf.device))
|
||||
model.eval()
|
||||
print('load model {} '.format(conf.testbackbone))
|
||||
|
||||
def get_std_barcodeDict(bcdpath, savepath):
|
||||
def get_std_barcodeDict(bcdpath, savepath, bcdSet):
|
||||
'''
|
||||
inputs:
|
||||
bcdpath: 已清洗的barcode样本图像,如果barcode下有'base'文件夹,只选用该文件夹下图像
|
||||
@ -42,10 +42,14 @@ def get_std_barcodeDict(bcdpath, savepath):
|
||||
'''读取数据集中 barcode 列表'''
|
||||
stdBarcodeList = []
|
||||
for filename in os.listdir(bcdpath):
|
||||
filepath = os.path.join(bcdpath, filename)
|
||||
# filepath = os.path.join(bcdpath, filename)
|
||||
# if not os.path.isdir(filepath) or not filename.isdigit() or len(filename)<8:
|
||||
# continue
|
||||
stdBarcodeList.append(filename)
|
||||
if bcdSet is None:
|
||||
stdBarcodeList.append(filename)
|
||||
elif filename in bcdSet:
|
||||
stdBarcodeList.append(filename)
|
||||
|
||||
|
||||
bcdPaths = [(barcode, os.path.join(bcdpath, barcode)) for barcode in stdBarcodeList]
|
||||
|
||||
@ -184,18 +188,15 @@ def stdfeat_infer(imgPath, featPath, bcdSet=None):
|
||||
|
||||
|
||||
|
||||
def genfeatures(imgpath, bcdpath, featpath):
|
||||
def genfeatures(imgpath, bcdpath, featpath, bcdSet=None):
|
||||
''' 生成标准特征集 '''
|
||||
'''1. 提取 imgpath 中样本地址,生成字典{barcode: [imgpath1, imgpath1, ...]}
|
||||
并存储于: bcdpath, 格式为 barcode.pickle'''
|
||||
get_std_barcodeDict(imgpath, bcdpath, bcdSet)
|
||||
|
||||
get_std_barcodeDict(imgpath, bcdpath)
|
||||
stdfeat_infer(bcdpath, featpath, bcdSet=None)
|
||||
'''2. 特征提取,并保存至文件夹 featpath 中,也根据 bcdSet 交集执行'''
|
||||
stdfeat_infer(bcdpath, featpath, bcdSet)
|
||||
|
||||
print(f"Features have generated, saved in: {featpath}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
imgpath = r"\\192.168.1.28\share\展厅barcode数据\整理\zhantingBase"
|
||||
bcdpath = r"D:\exhibition\dataset\bcdpath"
|
||||
|
Reference in New Issue
Block a user