多机并行计算

This commit is contained in:
lee
2025-08-14 10:09:54 +08:00
parent bc896fc688
commit 99a204ee22
18 changed files with 105 additions and 55 deletions

View File

@ -0,0 +1,23 @@
../electronic_imgs/0.jpg
../electronic_imgs/1.jpg
../electronic_imgs/2.jpg
../electronic_imgs/3.jpg
../electronic_imgs/4.jpg
../electronic_imgs/5.jpg
../electronic_imgs/6.jpg
../electronic_imgs/7.jpg
../electronic_imgs/8.jpg
../electronic_imgs/9.jpg
../electronic_imgs/10.jpg
../electronic_imgs/11.jpg
../electronic_imgs/12.jpg
../electronic_imgs/13.jpg
../electronic_imgs/14.jpg
../electronic_imgs/15.jpg
../electronic_imgs/16.jpg
../electronic_imgs/17.jpg
../electronic_imgs/18.jpg
../electronic_imgs/19.jpg
../electronic_imgs/20.jpg
../electronic_imgs/21.jpg
../electronic_imgs/22.jpg

View File

@ -203,11 +203,11 @@ class PairGenerator:
if __name__ == "__main__":
original_path = '/home/lc/data_center/contrast_data/v1/extra'
original_path = '/home/lc/data_center/electornic/v1/val'
parent_dir = str(Path(original_path).parent)
generator = PairGenerator(original_path)
# Example usage:
pairs = generator.get_pairs(original_path,
output_txt=os.sep.join([parent_dir, 'extra_cross_same.txt'])) # Individual pairs
output_txt=os.sep.join([parent_dir, 'cross_same.txt'])) # Individual pairs
# groups = generator.get_group_pairs('val') # Group pairs

View File

@ -122,8 +122,9 @@ if __name__ == '__main__':
# Build model
print('--> Building model')
ret = rknn.build(do_quantization=False, # True
dataset='./dataset.txt',
ret = rknn.build(do_quantization=True, # True
# dataset='./dataset.txt',
dataset=conf['base']['dataset'],
rknn_batch_size=conf['models']['rknn_batch_size'])
# ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
if ret != 0:

View File

@ -237,6 +237,6 @@ def get_histogram(data, label=None):
if __name__ == '__main__':
# picTopic_matrix = picDirSimilarAnalysis()
# picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
read_result_txt()
picTopic_matrix = picDirSimilarAnalysis()
picTopic_matrix.get_group_similarity_matrix('/home/lc/data_center/image_analysis/pic_pic_similar_maxtrix')
# read_result_txt()

View File

@ -22,7 +22,7 @@ class SimilarAnalysis:
"""初始化模型和度量方法"""
tr_tools = trainer_tools(conf)
backbone_mapping = tr_tools.get_backbone()
print('model_path {}'.format(conf['models']['model_path']))
if conf['models']['backbone'] in backbone_mapping:
model = backbone_mapping[conf['models']['backbone']]()
else:

View File

@ -4,7 +4,7 @@ import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
from tools.dataset import get_transform
from model import resnet18
from model import resnet18, resnet34, resnet50, resnet101
import torch
from PIL import Image
import pandas as pd
@ -50,7 +50,16 @@ class FeatureExtractor:
raise FileNotFoundError(f"Model weights file not found: {model_path}")
# Initialize model
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
if conf['models']['backbone'] == 'resnet18':
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
elif conf['models']['backbone'] == 'resnet34':
model = resnet34(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
elif conf['models']['backbone'] == 'resnet50':
model = resnet50(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
elif conf['models']['backbone'] == 'resnet101':
model = resnet101(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
else:
print("不支持的模型: {}".format(conf['models']['backbone']))
# Handle multi-GPU case
if conf['base']['distributed']:
@ -168,7 +177,7 @@ class FeatureExtractor:
# Validate input directory
if not os.path.isdir(folder):
raise ValueError(f"Invalid directory: {folder}")
i = 0
# Process each barcode directory
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
if not dirs: # Leaf directory (contains images)
@ -180,14 +189,16 @@ class FeatureExtractor:
ori_barcode = basename
barcode = basename
# Apply filter if provided
i += 1
print(ori_barcode, i)
if filter and ori_barcode not in filter:
continue
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
logger.warning(f"Skipping invalid barcode {ori_barcode}")
with open(conf['save']['error_barcodes'], 'a') as f:
f.write(ori_barcode + '\n')
f.close()
continue
# elif len(ori_barcode) > 13 or len(ori_barcode) < 8: # barcode筛选长度
# logger.warning(f"Skipping invalid barcode {ori_barcode}")
# with open(conf['save']['error_barcodes'], 'a') as f:
# f.write(ori_barcode + '\n')
# f.close()
# continue
# Process image files
if files:
@ -299,7 +310,8 @@ class FeatureExtractor:
dicts['value'] = truncated_imgs_list
if create_single_json:
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
json_path = os.path.join(self.conf['save']['json_path'],
str(barcode_list[i]) + '.json')
with open(json_path, 'w') as json_file:
json.dump(dicts, json_file)
else:
@ -317,8 +329,10 @@ class FeatureExtractor:
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
for barcode in os.listdir(pth):
print("barcode length >> {}".format(len(barcode)))
if len(barcode) > 13 or len(barcode) < 8:
continue
# if len(barcode) > 13 or len(barcode) < 8: # barcode筛选长度
# continue
if filter is not None:
f.writelines(barcode + '\n')
if barcode in filter: