ieemoo-ai-contrast/tools/write_feature_json.py
import json
import os
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
from tools.dataset import get_transform
from model import resnet18, resnet34, resnet50, resnet101
import torch
from PIL import Image
import pandas as pd
from tqdm import tqdm
import yaml
import shutil
import struct
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class FeatureExtractor:
    def __init__(self, conf):
        self.conf = conf
        self.model = self.initModel()
        _, self.test_transform = get_transform(self.conf)
    def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
        """
        Initialize and load the backbone model (ResNet-18/34/50/101) for inference.
        Args:
            inference_model: Optional path to model weights. Uses
                conf['models']['checkpoints'] if None.
        Returns:
            Loaded and configured PyTorch model in evaluation mode.
        Raises:
            FileNotFoundError: If the model weights file is not found.
            RuntimeError: If model loading fails.
        """
        model_path = inference_model if inference_model else self.conf['models']['checkpoints']
        try:
            # Verify model file exists
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model weights file not found: {model_path}")
            # Initialize the requested backbone
            backbone = self.conf['models']['backbone']
            scale = self.conf['models']['channel_ratio']
            device = self.conf['base']['device']
            if backbone == 'resnet18':
                model = resnet18(scale=scale).to(device)
            elif backbone == 'resnet34':
                model = resnet34(scale=scale).to(device)
            elif backbone == 'resnet50':
                model = resnet50(scale=scale).to(device)
            elif backbone == 'resnet101':
                model = resnet101(scale=scale).to(device)
            else:
                raise ValueError(f"Unsupported backbone: {backbone}")
            # Handle multi-GPU case
            if self.conf['base']['distributed']:
                model = torch.nn.DataParallel(model)
            # Load weights
            state_dict = torch.load(model_path, map_location=device)
            model.load_state_dict(state_dict)
            model.eval()
            logger.info(f"Successfully loaded model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            raise
    def convert_rgba_to_rgb(self, image_path):
        # Open the image
        img = Image.open(image_path)
        # .convert('RGB') drops the alpha channel and yields a plain RGB image
        if img.mode == 'RGBA':
            img_rgb = img.convert('RGB')
            # Save the converted image in place
            img_rgb.save(image_path)
            logger.info(f"Image converted from RGBA to RGB and saved to {image_path}")
    def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
        """Apply the test transform to a batch of image paths (or PIL images if actionModel)."""
        res = []
        for img in images:
            try:
                im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
                res.append(im)
            except Exception as e:
                logger.warning(f"Skipping unreadable image {img}: {str(e)}")
                continue
        data = torch.stack(res)
        return data
    def inference(self, images, model, actionModel=False):
        data = self.test_preprocess(images, actionModel)
        if torch.cuda.is_available():
            data = data.to(self.conf['base']['device'])
        with torch.no_grad():
            features = model(data)
        if self.conf['data']['half']:
            features = features.half()
        return features
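    # Illustrative usage (paths are placeholders, not from the repo):
    #   feats = extractor.inference(['img1.jpg', 'img2.jpg'], extractor.model)
    # yields a (2, D) tensor on conf['base']['device'], half-precision if
    # conf['data']['half'] is set.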
def group_image(self, images, batch=64) -> list:
"""Group image paths by batch size"""
size = len(images)
res = []
for i in range(0, size, batch):
end = min(batch + i, size)
res.append(images[i:end])
return res
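    # Example (illustrative): with batch=2,
    #   group_image(['a.jpg', 'b.jpg', 'c.jpg'], batch=2)
    # returns [['a.jpg', 'b.jpg'], ['c.jpg']] -- the last group may be short.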
    def getFeatureList(self, barList, imgList):
        """Extract per-barcode feature lists; the first 256 dims of each feature are L2-normalized."""
        featList = [[] for _ in range(len(barList))]
for index, image_paths in enumerate(imgList):
try:
# Process images in batches
for batch in self.group_image(image_paths):
# Get features for batch
features = self.inference(batch, self.model)
# Process each feature in batch
for feat in features:
# Move to CPU and convert to numpy
feat_np = feat.squeeze().detach().cpu().numpy()
# Normalize first 256 dimensions
normalized = self.normalize_256(feat_np[:256])
# Combine with remaining dimensions
combined = np.concatenate([normalized, feat_np[256:]], axis=0)
featList[index].append(combined)
except Exception as e:
logger.error(f"Error processing images for index {index}: {str(e)}")
continue
return featList
def get_files(
self,
folder: str,
filter: Optional[List[str]] = None,
create_single_json: bool = False
) -> Dict[str, List[str]]:
"""
Recursively collect image files from directory structure.
Args:
folder: Root directory to scan
filter: Optional list of barcodes to include
create_single_json: Whether to create individual JSON files per barcode
Returns:
Dictionary mapping barcode names to lists of image paths
Example:
{
"barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
"barcode2": ["path/to/img3.jpg"]
}
"""
file_dicts = {}
total_files = 0
feature_counts = []
barcode_count = 0
subclass = [str(i) for i in range(100)]
# Validate input directory
if not os.path.isdir(folder):
raise ValueError(f"Invalid directory: {folder}")
i = 0
# Process each barcode directory
for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
if not dirs: # Leaf directory (contains images)
basename = os.path.basename(root)
if basename in subclass:
                    ori_barcode = os.path.basename(os.path.dirname(root))
                    barcode = ori_barcode + '_' + basename
                else:
                    ori_barcode = basename
                    barcode = basename
                # Apply filter if provided
                i += 1
                logger.debug(f"Scanning barcode {ori_barcode} ({i})")
if filter and ori_barcode not in filter:
continue
                # elif len(ori_barcode) > 13 or len(ori_barcode) < 8:  # filter by barcode length
                #     logger.warning(f"Skipping invalid barcode {ori_barcode}")
                #     with open(conf['save']['error_barcodes'], 'a') as f:
                #         f.write(ori_barcode + '\n')
                #     continue
# Process image files
if files:
image_paths = self._process_image_files(root, files)
if not image_paths:
continue
# Update counters
barcode_count += 1
file_count = len(image_paths)
total_files += file_count
feature_counts.append(file_count)
# Handle output mode
if create_single_json:
self._process_single_barcode(barcode, image_paths)
else:
if barcode.split('_')[-1] == '0':
barcode = barcode.split('_')[0]
file_dicts[barcode] = image_paths
# # Log summary
# logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
# logger.debug(f"Image counts per barcode: {feature_counts}")
# Batch process if not creating individual JSONs
if not create_single_json and file_dicts:
self.createFeatureDict(
file_dicts,
create_single_json=False,
)
return file_dicts
def _process_image_files(self, root: str, files: List[str]) -> List[str]:
"""Process and validate image files in a directory."""
valid_paths = []
for filename in files:
file_path = os.path.join(root, filename)
try:
# Convert RGBA to RGB if needed
self.convert_rgba_to_rgb(file_path)
valid_paths.append(file_path)
except Exception as e:
logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
return valid_paths
def _process_single_barcode(self, barcode: str, image_paths: List[str]):
"""Process a single barcode and create individual JSON file."""
temp_dict = {barcode: image_paths}
self.createFeatureDict(
temp_dict,
create_single_json=True,
)
    def normalize_256(self, queFeatList):
        """L2-normalize a feature vector (applied to the first 256 dimensions)."""
        queFeatList = queFeatList / np.linalg.norm(queFeatList)
        return queFeatList
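    # Example (illustrative): normalize_256(np.array([3.0, 4.0]))
    # returns array([0.6, 0.8]) -- the input scaled to unit L2 norm.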
def img2feature(
self,
imgs_dict: Dict[str, List[str]]
) -> Tuple[List[str], List[List[np.ndarray]]]:
"""
Extract features for all images in the dictionary.
Args:
imgs_dict: Dictionary mapping barcodes to image paths
model: Pretrained feature extraction model
barcode_flag: Whether to include barcode info (unused)
Returns:
Tuple containing:
- List of barcode IDs
- List of feature lists (one per barcode)
Raises:
ValueError: If input dictionary is empty
RuntimeError: If feature extraction fails
"""
if not imgs_dict:
raise ValueError("No images provided for feature extraction")
try:
barcode_list = list(imgs_dict.keys())
image_list = list(imgs_dict.values())
feature_list = self.getFeatureList(barcode_list, image_list)
logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
return barcode_list, feature_list
except Exception as e:
logger.error(f"Feature extraction failed: {str(e)}")
raise RuntimeError(f"Feature extraction failed: {str(e)}")
def createFeatureDict(self, imgs_dict,
create_single_json=False): # imgs->{barcode1:[img1_1...img1_n], barcode2:[img2_1...img2_n]}
dicts_all = {}
value_list = []
barcode_list, imgs_list = self.img2feature(imgs_dict)
for i in range(len(barcode_list)):
dicts = {}
imgs_list_ = []
for j in range(len(imgs_list[i])):
imgs_list_.append(imgs_list[i][j].tolist())
dicts['key'] = barcode_list[i]
truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
dicts['value'] = truncated_imgs_list
if create_single_json:
# json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
json_path = os.path.join(self.conf['save']['json_path'],
str(barcode_list[i]) + '.json')
with open(json_path, 'w') as json_file:
json.dump(dicts, json_file)
else:
value_list.append(dicts)
if not create_single_json:
dicts_all['total'] = value_list
with open(self.conf['save']['json_bin'], 'w') as json_file:
json.dump(dicts_all, json_file)
self.create_binary_files(self.conf['save']['json_bin'])
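    # Output shape note (illustrative values): the combined JSON written above
    # looks like
    #   {"total": [{"key": "<barcode>", "value": [[256 floats], ...]}, ...]}
    # whereas create_single_json=True writes one {"key": ..., "value": ...}
    # file per barcode under conf['save']['json_path'].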
    def statisticsBarcodes(self, pth, filter=None):
        feature_num = 0
        feature_num_lists = []
        nn = 0
        with open(self.conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
            for barcode in os.listdir(pth):
                print("barcode length >> {}".format(len(barcode)))
                # if len(barcode) > 13 or len(barcode) < 8:  # filter by barcode length
                #     continue
                if filter is not None:
                    f.writelines(barcode + '\n')
                    if barcode in filter:
                        print(barcode)
                        feature_num += len(os.listdir(os.path.join(pth, barcode)))
                        nn += 1
                else:
                    print('barcode name >> {}'.format(barcode))
                    f.writelines(barcode + '\n')
                    feature_num += len(os.listdir(os.path.join(pth, barcode)))
                feature_num_lists.append(feature_num)
            print("Total feature count: {}".format(feature_num))
            print("Total barcode count: {}".format(nn))
def get_shop_barcodes(self, file_path):
if file_path:
df = pd.read_excel(file_path)
column_values = list(df.iloc[:, 6].values)
column_values = list(map(str, column_values))
return column_values
else:
return None
def del_base_dir(self, pth):
for root, dirs, files in os.walk(pth):
if len(dirs) == 1:
if dirs[0] == 'base':
shutil.rmtree(os.path.join(root, dirs[0]))
    def write_binary_file(self, filename, datas):
        with open(filename, 'wb') as f:
            # Write the number of keys first, to make C++-side reading easier
            key_count = len(datas)
            f.write(struct.pack('I', key_count))  # 'I' = unsigned int, 4 bytes
            for data in datas:
                key = data['key']
                feats = data['value']
                key_bytes = key.encode('utf-8')
                # One byte for the encoded key length, then the key itself
                f.write(struct.pack('<B', len(key_bytes)))
                f.write(key_bytes)
                value_count = len(feats)
                # Total number of scalars for this key: 256 values per feature
                f.write(struct.pack('I', (value_count * 256)))
                # Write every feature value for this key
                for values in feats:
                    for value in values:
                        # Store each value as float16 (2 bytes) to halve file size
                        value_half = np.float16(value)
                        f.write(value_half.tobytes())
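    # Resulting binary layout (derived from the writes above):
    #   uint32  key_count
    #   repeated key_count times:
    #     uint8   key length in bytes
    #     bytes   UTF-8 key
    #     uint32  total scalar count (num_features * 256)
    #     float16 values, 256 per feature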
    def create_binary_file(self, json_path, flag=True):
        # Load and parse the JSON file
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        if flag:
            # Combined file: {"total": [...]} -> write all records into one .bin
            for _, values in data.items():
                self.write_binary_file(self.conf['save']['json_bin'].replace('json', 'bin'), values)
        else:
            # Single-barcode file: wrap the lone record in a list
            self.write_binary_file(json_path.replace('.json', '.bin'), [data])
def create_binary_files(self, index_file_pth):
if os.path.isfile(index_file_pth):
self.create_binary_file(index_file_pth)
else:
for name in os.listdir(index_file_pth):
jsonpth = os.sep.join([index_file_pth, name])
self.create_binary_file(jsonpth, False)
if __name__ == "__main__":
with open('../configs/write_feature.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
###将图片名称和模型推理特征向量字典存为json文件
# xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
# xlsx_pth = None
# del_base_dir(mg_path)
extractor = FeatureExtractor(conf)
column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
filter=column_values,
create_single_json=conf['save']['create_single_json']) # False
extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
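    # A minimal write_feature.yml sketch, inferred from the keys this script
    # reads; paths and values below are placeholders, not the real config:
    #   base:   {device: cuda:0, distributed: false}
    #   models: {backbone: resnet18, channel_ratio: 1.0, checkpoints: ./checkpoints/best.pth}
    #   data:   {half: false, img_dirs_path: ./images, xlsx_pth: null}
    #   save:   {json_path: ./search_library, json_bin: ./search_library/total.json,
    #            create_single_json: false, barcodes_statistics: ./barcodes.txt,
    #            error_barcodes: ./error_barcodes.txt}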