This commit is contained in:
lee
2025-06-11 15:23:50 +08:00
commit 37ecef40f7
79 changed files with 26981 additions and 0 deletions

tools/write_feature_json.py Normal file

@@ -0,0 +1,411 @@
import json
import os
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
from tools.dataset import get_transform
from model import resnet18
import torch
from PIL import Image
import pandas as pd
from tqdm import tqdm
import yaml
import shutil
import struct
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class FeatureExtractor:
    def __init__(self, conf):
        self.conf = conf
        self.model = self.initModel()
        _, self.test_transform = get_transform(self.conf)
def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
"""
Initialize and load the ResNet18 model for inference.
Args:
            inference_model: Optional path to model weights. Uses conf['models']['checkpoints'] if None.
Returns:
Loaded and configured PyTorch model in evaluation mode.
Raises:
FileNotFoundError: If model weights file is not found
RuntimeError: If model loading fails
"""
model_path = inference_model if inference_model else self.conf['models']['checkpoints']
try:
# Verify model file exists
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model weights file not found: {model_path}")
# Initialize model
model = resnet18().to(self.conf['base']['device'])
# Handle multi-GPU case
            if self.conf['base']['distributed']:
model = torch.nn.DataParallel(model)
# Load weights
            state_dict = torch.load(model_path, map_location=self.conf['base']['device'])
model.load_state_dict(state_dict)
model.eval()
logger.info(f"Successfully loaded model from {model_path}")
return model
except Exception as e:
logger.error(f"Failed to initialize model: {str(e)}")
raise
    def convert_rgba_to_rgb(self, image_path):
        """Convert an RGBA image to RGB in place, overwriting the original file."""
        img = Image.open(image_path)
        # .convert('RGB') discards the alpha channel and yields a pure RGB image
        if img.mode == 'RGBA':
            img_rgb = img.convert('RGB')
            # Save the converted image back to the same path
            img_rgb.save(image_path)
            logger.info(f"Image converted from RGBA to RGB and saved to {image_path}")
    def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
        res = []
        for img in images:
            try:
                # actionModel passes PIL images directly; otherwise images are file paths
                im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
                res.append(im)
            except Exception as e:
                logger.warning(f"Skipping unreadable image {img}: {str(e)}")
                continue
        data = torch.stack(res)
        return data
    def inference(self, images, model, actionModel=False):
        data = self.test_preprocess(images, actionModel)
        if torch.cuda.is_available():
            data = data.to(self.conf['base']['device'])
        features = model(data)
        if self.conf['data']['half']:
            features = features.half()
        return features
def group_image(self, images, batch=64) -> list:
"""Group image paths by batch size"""
size = len(images)
res = []
for i in range(0, size, batch):
end = min(batch + i, size)
res.append(images[i:end])
return res
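    # e.g. group_image(["a.jpg", "b.jpg", "c.jpg"], batch=2) -> [["a.jpg", "b.jpg"], ["c.jpg"]]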
def getFeatureList(self, barList, imgList):
featList = [[] for _ in range(len(barList))]
for index, image_paths in enumerate(imgList):
try:
# Process images in batches
for batch in self.group_image(image_paths):
# Get features for batch
features = self.inference(batch, self.model)
# Process each feature in batch
for feat in features:
# Move to CPU and convert to numpy
feat_np = feat.squeeze().detach().cpu().numpy()
# Normalize first 256 dimensions
normalized = self.normalize_256(feat_np[:256])
# Combine with remaining dimensions
combined = np.concatenate([normalized, feat_np[256:]], axis=0)
featList[index].append(combined)
except Exception as e:
logger.error(f"Error processing images for index {index}: {str(e)}")
continue
return featList
def get_files(
self,
folder: str,
filter: Optional[List[str]] = None,
create_single_json: bool = False
) -> Dict[str, List[str]]:
"""
Recursively collect image files from directory structure.
Args:
folder: Root directory to scan
filter: Optional list of barcodes to include
create_single_json: Whether to create individual JSON files per barcode
Returns:
Dictionary mapping barcode names to lists of image paths
Example:
{
"barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
"barcode2": ["path/to/img3.jpg"]
}
"""
file_dicts = {}
total_files = 0
feature_counts = []
barcode_count = 0
        subclass = [str(i) for i in range(100)]  # numeric sub-class directory names "0".."99"
# Validate input directory
if not os.path.isdir(folder):
raise ValueError(f"Invalid directory: {folder}")
# Process each barcode directory
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
if not dirs: # Leaf directory (contains images)
basename = os.path.basename(root)
                if basename in subclass:
                    ori_barcode = os.path.basename(os.path.dirname(root))
                    barcode = ori_barcode + '_' + basename
                else:
                    ori_barcode = basename
                    barcode = basename
# Apply filter if provided
if filter and ori_barcode not in filter:
continue
elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
logger.warning(f"Skipping invalid barcode {ori_barcode}")
                    with open(self.conf['save']['error_barcodes'], 'a') as f:
                        f.write(ori_barcode + '\n')
continue
# Process image files
if files:
image_paths = self._process_image_files(root, files)
if not image_paths:
continue
# Update counters
barcode_count += 1
file_count = len(image_paths)
total_files += file_count
feature_counts.append(file_count)
# Handle output mode
if create_single_json:
self._process_single_barcode(barcode, image_paths)
else:
if barcode.split('_')[-1] == '0':
barcode = barcode.split('_')[0]
file_dicts[barcode] = image_paths
# # Log summary
# logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
# logger.debug(f"Image counts per barcode: {feature_counts}")
# Batch process if not creating individual JSONs
if not create_single_json and file_dicts:
self.createFeatureDict(
file_dicts,
create_single_json=False,
)
return file_dicts
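    # Expected on-disk layout (an assumption inferred from the parsing above):
    #
    #   img_dirs_path/
    #       6901234567890/          -> barcode "6901234567890"
    #           img1.jpg ...
    #       6909876543210/
    #           0/                  -> "6909876543210_0" (a trailing "_0" is dropped)
    #           1/                  -> "6909876543210_1"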
def _process_image_files(self, root: str, files: List[str]) -> List[str]:
"""Process and validate image files in a directory."""
valid_paths = []
for filename in files:
file_path = os.path.join(root, filename)
try:
# Convert RGBA to RGB if needed
self.convert_rgba_to_rgb(file_path)
valid_paths.append(file_path)
except Exception as e:
logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
return valid_paths
def _process_single_barcode(self, barcode: str, image_paths: List[str]):
"""Process a single barcode and create individual JSON file."""
temp_dict = {barcode: image_paths}
self.createFeatureDict(
temp_dict,
create_single_json=True,
)
    def normalize_256(self, queFeatList):
        """L2-normalize a feature vector (callers pass the first 256 dimensions)."""
        queFeatList = queFeatList / np.linalg.norm(queFeatList)
        return queFeatList
def img2feature(
self,
imgs_dict: Dict[str, List[str]]
) -> Tuple[List[str], List[List[np.ndarray]]]:
"""
Extract features for all images in the dictionary.
        Args:
            imgs_dict: Dictionary mapping barcodes to image paths
Returns:
Tuple containing:
- List of barcode IDs
- List of feature lists (one per barcode)
Raises:
ValueError: If input dictionary is empty
RuntimeError: If feature extraction fails
"""
if not imgs_dict:
raise ValueError("No images provided for feature extraction")
try:
barcode_list = list(imgs_dict.keys())
image_list = list(imgs_dict.values())
feature_list = self.getFeatureList(barcode_list, image_list)
logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
return barcode_list, feature_list
except Exception as e:
logger.error(f"Feature extraction failed: {str(e)}")
            raise RuntimeError(f"Feature extraction failed: {str(e)}") from e
def createFeatureDict(self, imgs_dict,
create_single_json=False): # imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
dicts_all = {}
value_list = []
barcode_list, imgs_list = self.img2feature(imgs_dict)
for i in range(len(barcode_list)):
dicts = {}
imgs_list_ = []
for j in range(len(imgs_list[i])):
imgs_list_.append(imgs_list[i][j].tolist())
dicts['key'] = barcode_list[i]
truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
dicts['value'] = truncated_imgs_list
if create_single_json:
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
with open(json_path, 'w') as json_file:
json.dump(dicts, json_file)
else:
value_list.append(dicts)
if not create_single_json:
dicts_all['total'] = value_list
with open(self.conf['save']['json_bin'], 'w') as json_file:
json.dump(dicts_all, json_file)
self.create_binary_files(self.conf['save']['json_bin'])
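    # Resulting JSON shape (derived from createFeatureDict above):
    #   per-barcode mode: {"key": "<barcode>", "value": [[256 floats], ...]}
    #   batch mode:       {"total": [{"key": ..., "value": ...}, ...]}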
    def statisticsBarcodes(self, pth, filter=None):
        """Count images per barcode directory and write the barcode list to a file."""
        feature_num = 0
        feature_num_lists = []
        nn = 0
        with open(self.conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
            for barcode in os.listdir(pth):
                print("barcode length >> {}".format(len(barcode)))
                if len(barcode) > 13 or len(barcode) < 8:
                    continue
                if filter is not None:
                    f.write(barcode + '\n')
                    if barcode in filter:
                        print(barcode)
                        feature_num += len(os.listdir(os.path.join(pth, barcode)))
                        nn += 1
                else:
                    print('barcode name >> {}'.format(barcode))
                    f.write(barcode + '\n')
                    feature_num += len(os.listdir(os.path.join(pth, barcode)))
                feature_num_lists.append(feature_num)
        print("Total feature count: {}".format(feature_num))
        print("Total barcode count: {}".format(nn))
    def get_shop_barcodes(self, file_path):
        """Read the shop spreadsheet and return its barcode column (column 7) as strings."""
        if file_path:
            df = pd.read_excel(file_path)
            column_values = list(df.iloc[:, 6].values)
            column_values = list(map(str, column_values))
            return column_values
        else:
            return None
    def del_base_dir(self, pth):
        """Remove leftover 'base' subdirectories from the image tree."""
        for root, dirs, files in os.walk(pth):
            if len(dirs) == 1:
                if dirs[0] == 'base':
                    shutil.rmtree(os.path.join(root, dirs[0]))
    def write_binary_file(self, filename, datas):
        with open(filename, 'wb') as f:
            # Write the number of keys first (makes reading from C++ easier)
            key_count = len(datas)
            f.write(struct.pack('I', key_count))  # 'I' = unsigned int, 4 bytes
            for data in datas:
                key = data['key']
                feats = data['value']
                key_bytes = key.encode('utf-8')
                key_len = len(key_bytes)
                length_byte = struct.pack('<B', key_len)
                f.write(length_byte)
                # f.write(struct.pack('Q', len(key_bytes)))
                f.write(key_bytes)
                value_count = len(feats)
                f.write(struct.pack('I', (value_count * 256)))
                # Write each key's feature vectors
                for values in feats:
                    for value in values:
                        # Store as half precision (float16, 2 bytes) to halve file size
                        value_half = np.float16(value)
                        # print(value_half.tobytes())
                        f.write(value_half.tobytes())
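    # A minimal reader sketch for the format above (illustrative only; not part
    # of the original pipeline). It assumes the layout written by
    # write_binary_file: uint32 key count, then per key a uint8 key length,
    # UTF-8 key bytes, a uint32 total float count (n_vectors * 256), and that
    # many float16 values.
    def read_binary_file(self, filename):
        records = {}
        with open(filename, 'rb') as f:
            key_count = struct.unpack('I', f.read(4))[0]
            for _ in range(key_count):
                key_len = struct.unpack('<B', f.read(1))[0]
                key = f.read(key_len).decode('utf-8')
                value_count = struct.unpack('I', f.read(4))[0]
                # Each float16 occupies 2 bytes; reshape into 256-d rows
                values = np.frombuffer(f.read(value_count * 2), dtype=np.float16)
                records[key] = values.reshape(-1, 256)
        return records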
    def create_binary_file(self, json_path, flag=True):
        # Load and parse the JSON file
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        if flag:
            # Batch file: one top-level key ('total') mapping to the record list
            for key, values in data.items():
                self.write_binary_file(self.conf['save']['json_bin'].replace('.json', '.bin'), values)
        else:
            # Single-barcode file: the whole dict is one record
            self.write_binary_file(json_path.replace('.json', '.bin'), [data])
def create_binary_files(self, index_file_pth):
if os.path.isfile(index_file_pth):
self.create_binary_file(index_file_pth)
else:
for name in os.listdir(index_file_pth):
jsonpth = os.sep.join([index_file_pth, name])
self.create_binary_file(jsonpth, False)
if __name__ == "__main__":
with open('../configs/write_feature.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
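    # Assumed shape of write_feature.yml (keys inferred from the lookups in
    # this script; the values below are placeholders, not the real config):
    #
    # base:
    #   device: cuda:0
    #   distributed: false
    # models:
    #   checkpoints: ./checkpoints/resnet18.pth
    # data:
    #   half: false
    #   xlsx_pth: null
    #   img_dirs_path: ./images
    # save:
    #   json_path: ./search_library
    #   json_bin: ./search_library/features.json
    #   error_barcodes: ./error_barcodes.txt
    #   barcodes_statistics: ./barcodes_statistics.txt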
    ### Store dicts of image names and model-inferred feature vectors as JSON files
    # xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
    # xlsx_pth = None
    # del_base_dir(mg_path)
extractor = FeatureExtractor(conf)
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
filter=column_values,
create_single_json=False) # False
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)