rebuild

tools/write_feature_json.py (new file, 411 lines)
@@ -0,0 +1,411 @@
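"""Export per-barcode image features: run a ResNet18 model over directories of
barcode images and write the resulting vectors to JSON files and a compact
float16 binary index."""
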
import json
import logging
import os
import shutil
import struct
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from tqdm import tqdm

from model import resnet18
from tools.dataset import get_transform

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class FeatureExtractor:
    """Extracts image features with a ResNet18 backbone and writes them to
    JSON and binary search-library files."""

    def __init__(self, conf):
        self.conf = conf
        self.model = self.initModel()
        _, self.test_transform = get_transform(self.conf)

    def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
        """
        Initialize and load the ResNet18 model for inference.

        Args:
            inference_model: Optional path to model weights. Uses
                conf['models']['checkpoints'] if None.

        Returns:
            Loaded and configured PyTorch model in evaluation mode.

        Raises:
            FileNotFoundError: If the model weights file is not found.
            RuntimeError: If model loading fails.
        """
        model_path = inference_model if inference_model else self.conf['models']['checkpoints']

        try:
            # Verify model file exists
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model weights file not found: {model_path}")

            # Initialize model
            model = resnet18().to(self.conf['base']['device'])

            # Handle multi-GPU case
            if self.conf['base']['distributed']:
                model = torch.nn.DataParallel(model)

            # Load weights
            state_dict = torch.load(model_path, map_location=self.conf['base']['device'])
            model.load_state_dict(state_dict)

            model.eval()
            logger.info(f"Successfully loaded model from {model_path}")
            return model

        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            raise

    def convert_rgba_to_rgb(self, image_path):
        """Convert an RGBA image to RGB in place (drops the alpha channel)."""
        img = Image.open(image_path)
        if img.mode == 'RGBA':
            # .convert('RGB') discards the alpha channel
            img_rgb = img.convert('RGB')
            # Overwrite the original file with the converted image
            img_rgb.save(image_path)
            print(f"Image converted from RGBA to RGB and saved to {image_path}")

    def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
        res = []
        for img in images:
            try:
                # `images` holds image objects when actionModel is True, file paths otherwise
                im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
                res.append(im)
            except Exception as e:
                logger.warning(f"Skipping unreadable image {img}: {str(e)}")
                continue
        data = torch.stack(res)
        return data

    def inference(self, images, model, actionModel=False):
        data = self.test_preprocess(images, actionModel)
        if torch.cuda.is_available():
            data = data.to(self.conf['base']['device'])
        features = model(data)
        if self.conf['data']['half']:
            features = features.half()
        return features

    def group_image(self, images, batch=64) -> list:
        """Split a list of image paths into batches of at most `batch` items."""
        size = len(images)
        res = []
        for i in range(0, size, batch):
            end = min(batch + i, size)
            res.append(images[i:end])
        return res

    def getFeatureList(self, barList, imgList):
        featList = [[] for _ in range(len(barList))]

        for index, image_paths in enumerate(imgList):
            try:
                # Process images in batches
                for batch in self.group_image(image_paths):
                    # Get features for batch
                    features = self.inference(batch, self.model)

                    # Process each feature in batch
                    for feat in features:
                        # Move to CPU and convert to numpy
                        feat_np = feat.squeeze().detach().cpu().numpy()

                        # Normalize first 256 dimensions
                        normalized = self.normalize_256(feat_np[:256])

                        # Combine with remaining dimensions
                        combined = np.concatenate([normalized, feat_np[256:]], axis=0)

                        featList[index].append(combined)

            except Exception as e:
                logger.error(f"Error processing images for index {index}: {str(e)}")
                continue
        return featList

    def get_files(
            self,
            folder: str,
            filter: Optional[List[str]] = None,
            create_single_json: bool = False
    ) -> Dict[str, List[str]]:
        """
        Recursively collect image files from the directory structure.

        Args:
            folder: Root directory to scan
            filter: Optional list of barcodes to include
            create_single_json: Whether to create individual JSON files per barcode

        Returns:
            Dictionary mapping barcode names to lists of image paths

        Example:
            {
                "barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
                "barcode2": ["path/to/img3.jpg"]
            }
        """
        file_dicts = {}
        total_files = 0
        feature_counts = []
        barcode_count = 0
        subclass = [str(i) for i in range(100)]

        # Validate input directory
        if not os.path.isdir(folder):
            raise ValueError(f"Invalid directory: {folder}")

        # Process each barcode directory
        for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
            if not dirs:  # Leaf directory (contains images)
                basename = os.path.basename(root)
                if basename in subclass:
                    ori_barcode = root.split('/')[-2]
                    barcode = root.split('/')[-2] + '_' + basename
                else:
                    ori_barcode = basename
                    barcode = basename

                # Apply filter if provided
                if filter and ori_barcode not in filter:
                    continue
                elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
                    logger.warning(f"Skipping invalid barcode {ori_barcode}")
                    with open(self.conf['save']['error_barcodes'], 'a') as f:
                        f.write(ori_barcode + '\n')
                    continue

                # Process image files
                if files:
                    image_paths = self._process_image_files(root, files)
                    if not image_paths:
                        continue

                    # Update counters
                    barcode_count += 1
                    file_count = len(image_paths)
                    total_files += file_count
                    feature_counts.append(file_count)

                    # Handle output mode
                    if create_single_json:
                        self._process_single_barcode(barcode, image_paths)
                    else:
                        if barcode.split('_')[-1] == '0':
                            barcode = barcode.split('_')[0]
                        file_dicts[barcode] = image_paths

        # Log summary
        logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
        logger.debug(f"Image counts per barcode: {feature_counts}")

        # Batch process if not creating individual JSONs
        if not create_single_json and file_dicts:
            self.createFeatureDict(
                file_dicts,
                create_single_json=False,
            )
        return file_dicts

    def _process_image_files(self, root: str, files: List[str]) -> List[str]:
        """Process and validate image files in a directory."""
        valid_paths = []
        for filename in files:
            file_path = os.path.join(root, filename)
            try:
                # Convert RGBA to RGB if needed
                self.convert_rgba_to_rgb(file_path)
                valid_paths.append(file_path)
            except Exception as e:
                logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
        return valid_paths

    def _process_single_barcode(self, barcode: str, image_paths: List[str]):
        """Process a single barcode and create individual JSON file."""
        temp_dict = {barcode: image_paths}
        self.createFeatureDict(
            temp_dict,
            create_single_json=True,
        )

    def normalize_256(self, queFeatList):
        """L2-normalize a 256-dimensional feature vector."""
        queFeatList = queFeatList / np.linalg.norm(queFeatList)
        return queFeatList

    def img2feature(
            self,
            imgs_dict: Dict[str, List[str]]
    ) -> Tuple[List[str], List[List[np.ndarray]]]:
        """
        Extract features for all images in the dictionary.

        Args:
            imgs_dict: Dictionary mapping barcodes to image paths

        Returns:
            Tuple containing:
                - List of barcode IDs
                - List of feature lists (one per barcode)

        Raises:
            ValueError: If the input dictionary is empty
            RuntimeError: If feature extraction fails
        """
        if not imgs_dict:
            raise ValueError("No images provided for feature extraction")

        try:
            barcode_list = list(imgs_dict.keys())
            image_list = list(imgs_dict.values())
            feature_list = self.getFeatureList(barcode_list, image_list)

            logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
            return barcode_list, feature_list

        except Exception as e:
            logger.error(f"Feature extraction failed: {str(e)}")
            raise RuntimeError(f"Feature extraction failed: {str(e)}")

    def createFeatureDict(self, imgs_dict,
                          create_single_json=False):  # imgs_dict -> {barcode1: [img1_1...img1_n], barcode2: [img2_1...img2_n]}
        """Serialize extracted features to JSON, either one file per barcode or a single combined file."""
        dicts_all = {}
        value_list = []
        barcode_list, imgs_list = self.img2feature(imgs_dict)
        for i in range(len(barcode_list)):
            dicts = {}

            imgs_list_ = []
            for j in range(len(imgs_list[i])):
                imgs_list_.append(imgs_list[i][j].tolist())

            dicts['key'] = barcode_list[i]
            # Keep only the first 256 (normalized) dimensions of each feature
            truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
            dicts['value'] = truncated_imgs_list
            if create_single_json:
                # json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
                json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
                with open(json_path, 'w') as json_file:
                    json.dump(dicts, json_file)
            else:
                value_list.append(dicts)
        if not create_single_json:
            dicts_all['total'] = value_list
            with open(self.conf['save']['json_bin'], 'w') as json_file:
                json.dump(dicts_all, json_file)
            self.create_binary_files(self.conf['save']['json_bin'])

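    # For reference, the combined JSON written above has the shape (barcodes and
    # values below are illustrative; each inner list holds 256 floats):
    #   {"total": [{"key": "6901234567890", "value": [[0.01, ...], [0.02, ...]]},
    #              {"key": "6909876543210_1", "value": [[...]]}]}
    # Per-barcode files (create_single_json=True) contain a single {"key": ..., "value": ...} object.
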
    def statisticsBarcodes(self, pth, filter=None):
        """Count images per barcode directory and write the barcode list to a statistics file."""
        feature_num = 0
        feature_num_lists = []
        nn = 0
        with open(self.conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
            for barcode in os.listdir(pth):
                print("barcode length >> {}".format(len(barcode)))
                if len(barcode) > 13 or len(barcode) < 8:
                    continue
                if filter is not None:
                    f.writelines(barcode + '\n')
                    if barcode in filter:
                        print(barcode)
                        feature_num += len(os.listdir(os.path.join(pth, barcode)))
                        nn += 1
                else:
                    print('barcode name >> {}'.format(barcode))
                    f.writelines(barcode + '\n')
                    feature_num += len(os.listdir(os.path.join(pth, barcode)))
                feature_num_lists.append(feature_num)
            print("Total number of features: {}".format(feature_num))
            print("Total number of barcodes: {}".format(nn))

    def get_shop_barcodes(self, file_path):
        """Read shop barcodes from an Excel sheet; returns None when no path is given."""
        if file_path:
            df = pd.read_excel(file_path)
            # Barcodes are expected in the 7th column of the sheet
            column_values = list(df.iloc[:, 6].values)
            column_values = list(map(str, column_values))
            return column_values
        else:
            return None

    def del_base_dir(self, pth):
        """Delete leftover 'base' subdirectories under the given path."""
        for root, dirs, files in os.walk(pth):
            if len(dirs) == 1:
                if dirs[0] == 'base':
                    shutil.rmtree(os.path.join(root, dirs[0]))

    def write_binary_file(self, filename, datas):
        """Write features to a binary index: a uint32 key count, then per key a
        1-byte key length, the UTF-8 key, a uint32 float count, and float16 values."""
        with open(filename, 'wb') as f:
            # Write the number of keys first (makes reading from C++ easier)
            key_count = len(datas)
            f.write(struct.pack('I', key_count))  # 'I' is an unsigned 4-byte integer
            for data in datas:
                key = data['key']
                feats = data['value']
                key_bytes = key.encode('utf-8')
                key_len = len(key_bytes)
                length_byte = struct.pack('<B', key_len)
                f.write(length_byte)
                f.write(key_bytes)
                value_count = len(feats)
                # Total number of floats that follow (each vector holds 256 values)
                f.write(struct.pack('I', (value_count * 256)))
                # Write every feature vector for this key
                for values in feats:
                    for value in values:
                        # Store each value as float16 (2 bytes)
                        value_half = np.float16(value)
                        f.write(value_half.tobytes())

    def create_binary_file(self, json_path, flag=True):
        # Open and parse the JSON file
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        if flag:
            # Combined file: {"total": [...]} -> write all records into one .bin
            for key, values in data.items():
                self.write_binary_file(self.conf['save']['json_bin'].replace('.json', '.bin'), values)
        else:
            # Per-barcode file: write a single record
            self.write_binary_file(json_path.replace('.json', '.bin'), [data])

    def create_binary_files(self, index_file_pth):
        """Convert a combined JSON file, or a directory of per-barcode JSON files, to .bin."""
        if os.path.isfile(index_file_pth):
            self.create_binary_file(index_file_pth)
        else:
            for name in os.listdir(index_file_pth):
                jsonpth = os.sep.join([index_file_pth, name])
                self.create_binary_file(jsonpth, False)

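
# Minimal sketch (not called by this script) of how the .bin layout produced by
# FeatureExtractor.write_binary_file could be read back; the function name and
# return shape are illustrative assumptions, not part of the original tool.
def read_binary_file(filename):
    result = {}
    with open(filename, 'rb') as f:
        key_count = struct.unpack('I', f.read(4))[0]
        for _ in range(key_count):
            key_len = struct.unpack('<B', f.read(1))[0]
            key = f.read(key_len).decode('utf-8')
            float_count = struct.unpack('I', f.read(4))[0]  # n_vectors * 256
            values = np.frombuffer(f.read(float_count * 2), dtype=np.float16)
            result[key] = values.reshape(-1, 256)
    return result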


if __name__ == "__main__":
    with open('../configs/write_feature.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)

    ### Save the dictionary of image names and model feature vectors as JSON files
    # xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
    # xlsx_pth = None
    # del_base_dir(mg_path)

    extractor = FeatureExtractor(conf)
    column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
    imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
                                    filter=column_values,
                                    create_single_json=False)  # False
    extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
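
# For reference, a minimal ../configs/write_feature.yml covering the keys read
# above might look like this (paths and values are placeholders, not the
# project's actual configuration):
#
#   base:
#     device: cuda:0
#     distributed: false
#   models:
#     checkpoints: ./checkpoints/resnet18_best.pth
#   data:
#     half: true
#     xlsx_pth: null
#     img_dirs_path: ./data/barcodes
#   save:
#     json_path: ./search_library/jsons
#     json_bin: ./search_library/total.json
#     error_barcodes: ./search_library/error_barcodes.txt
#     barcodes_statistics: ./search_library/barcodes_statistics.txt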