多机并行计算
This commit is contained in:
23
tools/dataset_electornic.txt
Normal file
23
tools/dataset_electornic.txt
Normal file
@ -0,0 +1,23 @@
|
||||
../electronic_imgs/0.jpg
|
||||
../electronic_imgs/1.jpg
|
||||
../electronic_imgs/2.jpg
|
||||
../electronic_imgs/3.jpg
|
||||
../electronic_imgs/4.jpg
|
||||
../electronic_imgs/5.jpg
|
||||
../electronic_imgs/6.jpg
|
||||
../electronic_imgs/7.jpg
|
||||
../electronic_imgs/8.jpg
|
||||
../electronic_imgs/9.jpg
|
||||
../electronic_imgs/10.jpg
|
||||
../electronic_imgs/11.jpg
|
||||
../electronic_imgs/12.jpg
|
||||
../electronic_imgs/13.jpg
|
||||
../electronic_imgs/14.jpg
|
||||
../electronic_imgs/15.jpg
|
||||
../electronic_imgs/16.jpg
|
||||
../electronic_imgs/17.jpg
|
||||
../electronic_imgs/18.jpg
|
||||
../electronic_imgs/19.jpg
|
||||
../electronic_imgs/20.jpg
|
||||
../electronic_imgs/21.jpg
|
||||
../electronic_imgs/22.jpg
|
@ -203,11 +203,11 @@ class PairGenerator:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
original_path = '/home/lc/data_center/contrast_data/v1/extra'
|
||||
original_path = '/home/lc/data_center/electornic/v1/val'
|
||||
parent_dir = str(Path(original_path).parent)
|
||||
generator = PairGenerator(original_path)
|
||||
|
||||
# Example usage:
|
||||
pairs = generator.get_pairs(original_path,
|
||||
output_txt=os.sep.join([parent_dir, 'extra_cross_same.txt'])) # Individual pairs
|
||||
output_txt=os.sep.join([parent_dir, 'cross_same.txt'])) # Individual pairs
|
||||
# groups = generator.get_group_pairs('val') # Group pairs
|
||||
|
@ -122,8 +122,9 @@ if __name__ == '__main__':
|
||||
|
||||
# Build model
|
||||
print('--> Building model')
|
||||
ret = rknn.build(do_quantization=False, # True
|
||||
dataset='./dataset.txt',
|
||||
ret = rknn.build(do_quantization=True, # True
|
||||
# dataset='./dataset.txt',
|
||||
dataset=conf['base']['dataset'],
|
||||
rknn_batch_size=conf['models']['rknn_batch_size'])
|
||||
# ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
|
||||
if ret != 0:
|
||||
|
@ -237,6 +237,6 @@ def get_histogram(data, label=None):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# picTopic_matrix = picDirSimilarAnalysis()
|
||||
# picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
|
||||
read_result_txt()
|
||||
picTopic_matrix = picDirSimilarAnalysis()
|
||||
picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
|
||||
# read_result_txt()
|
||||
|
@ -22,7 +22,7 @@ class SimilarAnalysis:
|
||||
"""初始化模型和度量方法"""
|
||||
tr_tools = trainer_tools(conf)
|
||||
backbone_mapping = tr_tools.get_backbone()
|
||||
|
||||
print('model_path {}'.format(conf['models']['model_path']))
|
||||
if conf['models']['backbone'] in backbone_mapping:
|
||||
model = backbone_mapping[conf['models']['backbone']]()
|
||||
else:
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tools.dataset import get_transform
|
||||
from model import resnet18
|
||||
from model import resnet18, resnet34, resnet50, resnet101
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
@ -50,7 +50,16 @@ class FeatureExtractor:
|
||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||
|
||||
# Initialize model
|
||||
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
if conf['models']['backbone'] == 'resnet18':
|
||||
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet34':
|
||||
model = resnet34(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet50':
|
||||
model = resnet50(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet101':
|
||||
model = resnet101(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
else:
|
||||
print("不支持的模型: {}".format(conf['models']['backbone']))
|
||||
|
||||
# Handle multi-GPU case
|
||||
if conf['base']['distributed']:
|
||||
@ -168,7 +177,7 @@ class FeatureExtractor:
|
||||
# Validate input directory
|
||||
if not os.path.isdir(folder):
|
||||
raise ValueError(f"Invalid directory: {folder}")
|
||||
|
||||
i = 0
|
||||
# Process each barcode directory
|
||||
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
|
||||
if not dirs: # Leaf directory (contains images)
|
||||
@ -180,14 +189,16 @@ class FeatureExtractor:
|
||||
ori_barcode = basename
|
||||
barcode = basename
|
||||
# Apply filter if provided
|
||||
i += 1
|
||||
print(ori_barcode, i)
|
||||
if filter and ori_barcode not in filter:
|
||||
continue
|
||||
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
|
||||
logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
f.write(ori_barcode + '\n')
|
||||
f.close()
|
||||
continue
|
||||
# elif len(ori_barcode) > 13 or len(ori_barcode) < 8: # barcode筛选长度
|
||||
# logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
# with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
# f.write(ori_barcode + '\n')
|
||||
# f.close()
|
||||
# continue
|
||||
|
||||
# Process image files
|
||||
if files:
|
||||
@ -299,7 +310,8 @@ class FeatureExtractor:
|
||||
dicts['value'] = truncated_imgs_list
|
||||
if create_single_json:
|
||||
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'],
|
||||
str(barcode_list[i]) + '.json')
|
||||
with open(json_path, 'w') as json_file:
|
||||
json.dump(dicts, json_file)
|
||||
else:
|
||||
@ -317,8 +329,10 @@ class FeatureExtractor:
|
||||
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
|
||||
for barcode in os.listdir(pth):
|
||||
print("barcode length >> {}".format(len(barcode)))
|
||||
if len(barcode) > 13 or len(barcode) < 8:
|
||||
continue
|
||||
|
||||
# if len(barcode) > 13 or len(barcode) < 8: # barcode筛选长度
|
||||
# continue
|
||||
|
||||
if filter is not None:
|
||||
f.writelines(barcode + '\n')
|
||||
if barcode in filter:
|
||||
|
Reference in New Issue
Block a user