多机并行计算
This commit is contained in:
@ -4,7 +4,7 @@ import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tools.dataset import get_transform
|
||||
from model import resnet18
|
||||
from model import resnet18, resnet34, resnet50, resnet101
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
@ -50,7 +50,16 @@ class FeatureExtractor:
|
||||
raise FileNotFoundError(f"Model weights file not found: {model_path}")
|
||||
|
||||
# Initialize model
|
||||
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
if conf['models']['backbone'] == 'resnet18':
|
||||
model = resnet18(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet34':
|
||||
model = resnet34(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet50':
|
||||
model = resnet50(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
elif conf['models']['backbone'] == 'resnet101':
|
||||
model = resnet101(scale=self.conf['models']['channel_ratio']).to(self.conf['base']['device'])
|
||||
else:
|
||||
print("不支持的模型: {}".format(conf['models']['backbone']))
|
||||
|
||||
# Handle multi-GPU case
|
||||
if conf['base']['distributed']:
|
||||
@ -168,7 +177,7 @@ class FeatureExtractor:
|
||||
# Validate input directory
|
||||
if not os.path.isdir(folder):
|
||||
raise ValueError(f"Invalid directory: {folder}")
|
||||
|
||||
i = 0
|
||||
# Process each barcode directory
|
||||
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
|
||||
if not dirs: # Leaf directory (contains images)
|
||||
@ -180,14 +189,16 @@ class FeatureExtractor:
|
||||
ori_barcode = basename
|
||||
barcode = basename
|
||||
# Apply filter if provided
|
||||
i += 1
|
||||
print(ori_barcode, i)
|
||||
if filter and ori_barcode not in filter:
|
||||
continue
|
||||
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
|
||||
logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
f.write(ori_barcode + '\n')
|
||||
f.close()
|
||||
continue
|
||||
# elif len(ori_barcode) > 13 or len(ori_barcode) < 8: # barcode筛选长度
|
||||
# logger.warning(f"Skipping invalid barcode {ori_barcode}")
|
||||
# with open(conf['save']['error_barcodes'], 'a') as f:
|
||||
# f.write(ori_barcode + '\n')
|
||||
# f.close()
|
||||
# continue
|
||||
|
||||
# Process image files
|
||||
if files:
|
||||
@ -299,7 +310,8 @@ class FeatureExtractor:
|
||||
dicts['value'] = truncated_imgs_list
|
||||
if create_single_json:
|
||||
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
|
||||
json_path = os.path.join(self.conf['save']['json_path'],
|
||||
str(barcode_list[i]) + '.json')
|
||||
with open(json_path, 'w') as json_file:
|
||||
json.dump(dicts, json_file)
|
||||
else:
|
||||
@ -317,8 +329,10 @@ class FeatureExtractor:
|
||||
with open(conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
|
||||
for barcode in os.listdir(pth):
|
||||
print("barcode length >> {}".format(len(barcode)))
|
||||
if len(barcode) > 13 or len(barcode) < 8:
|
||||
continue
|
||||
|
||||
# if len(barcode) > 13 or len(barcode) < 8: # barcode筛选长度
|
||||
# continue
|
||||
|
||||
if filter is not None:
|
||||
f.writelines(barcode + '\n')
|
||||
if barcode in filter:
|
||||
|
Reference in New Issue
Block a user