import json
import logging
import os
import shutil
import struct
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from tqdm import tqdm

from model import resnet18
from tools.dataset import get_transform

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class FeatureExtractor:
    def __init__(self, conf):
        self.conf = conf
        self.model = self.initModel()
        _, self.test_transform = get_transform(self.conf)

    def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
        """
        Initialize and load the ResNet18 model for inference.

        Args:
            inference_model: Optional path to model weights. Uses conf['models']['checkpoints'] if None.

        Returns:
            Loaded and configured PyTorch model in evaluation mode.

        Raises:
            FileNotFoundError: If the model weights file is not found.
            RuntimeError: If model loading fails.
        """
        model_path = inference_model if inference_model else self.conf['models']['checkpoints']

        try:
            # Verify model file exists
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model weights file not found: {model_path}")

            # Initialize model
            model = resnet18().to(self.conf['base']['device'])

            # Handle multi-GPU case
            if self.conf['base']['distributed']:
                model = torch.nn.DataParallel(model)

            # Load weights
            state_dict = torch.load(model_path, map_location=self.conf['base']['device'])
            model.load_state_dict(state_dict)

            model.eval()
            logger.info(f"Successfully loaded model from {model_path}")
            return model

        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            raise

    def convert_rgba_to_rgb(self, image_path):
        """Convert an RGBA image file to RGB in place."""
        img = Image.open(image_path)
        # .convert('RGB') discards the alpha channel and yields a pure RGB image
        if img.mode == 'RGBA':
            img_rgb = img.convert('RGB')
            # Overwrite the original file with the converted image
            img_rgb.save(image_path)
            logger.info(f"Image converted from RGBA to RGB and saved to {image_path}")

    def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
        """Load and transform images into a stacked tensor, skipping unreadable files."""
        res = []
        for img in images:
            try:
                # When actionModel is set, `images` already holds PIL images rather than paths
                im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
                res.append(im)
            except Exception as e:
                logger.warning(f"Skipping image {img}: {str(e)}")
                continue
        data = torch.stack(res)
        return data

    def inference(self, images, model, actionModel=False):
        """Run the model on a batch of images and return the feature tensor."""
        data = self.test_preprocess(images, actionModel)
        if torch.cuda.is_available():
            data = data.to(self.conf['base']['device'])
        # Disable gradient tracking during inference
        with torch.no_grad():
            features = model(data)
        if self.conf['data']['half']:
            features = features.half()
        return features

    def group_image(self, images, batch=64) -> list:
        """Group image paths into batches of at most `batch` items."""
        size = len(images)
        res = []
        for i in range(0, size, batch):
            end = min(batch + i, size)
            res.append(images[i:end])
        return res

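    # Usage sketch (hypothetical paths): group_image(["a.jpg", "b.jpg", "c.jpg"], batch=2)
    # returns [["a.jpg", "b.jpg"], ["c.jpg"]]; the final batch may be smaller than `batch`.
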
    def getFeatureList(self, barList, imgList):
        """Extract one feature list per barcode; the first 256 dims are L2-normalized."""
        featList = [[] for _ in range(len(barList))]

        for index, image_paths in enumerate(imgList):
            try:
                # Process images in batches
                for batch in self.group_image(image_paths):
                    # Get features for batch
                    features = self.inference(batch, self.model)

                    # Process each feature in batch
                    for feat in features:
                        # Move to CPU and convert to numpy
                        feat_np = feat.squeeze().detach().cpu().numpy()

                        # Normalize first 256 dimensions
                        normalized = self.normalize_256(feat_np[:256])

                        # Combine with remaining dimensions
                        combined = np.concatenate([normalized, feat_np[256:]], axis=0)

                        featList[index].append(combined)

            except Exception as e:
                logger.error(f"Error processing images for index {index}: {str(e)}")
                continue
        return featList

    def get_files(
            self,
            folder: str,
            filter: Optional[List[str]] = None,
            create_single_json: bool = False
    ) -> Dict[str, List[str]]:
        """
        Recursively collect image files from a directory structure.

        Args:
            folder: Root directory to scan
            filter: Optional list of barcodes to include
            create_single_json: Whether to create individual JSON files per barcode

        Returns:
            Dictionary mapping barcode names to lists of image paths

        Example:
            {
                "barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
                "barcode2": ["path/to/img3.jpg"]
            }
        """
        file_dicts = {}
        total_files = 0
        feature_counts = []
        barcode_count = 0
        subclass = [str(i) for i in range(100)]
        # Validate input directory
        if not os.path.isdir(folder):
            raise ValueError(f"Invalid directory: {folder}")

        # Process each barcode directory
        for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
            if not dirs:  # Leaf directory (contains images)
                basename = os.path.basename(root)
                if basename in subclass:
                    # Numbered leaf: the parent directory name is the barcode
                    ori_barcode = os.path.basename(os.path.dirname(root))
                    barcode = ori_barcode + '_' + basename
                else:
                    ori_barcode = basename
                    barcode = basename
                # Apply filter if provided
                if filter and ori_barcode not in filter:
                    continue
                elif len(ori_barcode) > 13 or len(ori_barcode) < 8:
                    logger.warning(f"Skipping invalid barcode {ori_barcode}")
                    with open(self.conf['save']['error_barcodes'], 'a') as f:
                        f.write(ori_barcode + '\n')
                    continue

                # Process image files
                if files:
                    image_paths = self._process_image_files(root, files)
                    if not image_paths:
                        continue

                    # Update counters
                    barcode_count += 1
                    file_count = len(image_paths)
                    total_files += file_count
                    feature_counts.append(file_count)

                    # Handle output mode
                    if create_single_json:
                        self._process_single_barcode(barcode, image_paths)
                    else:
                        if barcode.split('_')[-1] == '0':
                            barcode = barcode.split('_')[0]
                        file_dicts[barcode] = image_paths

        # Log summary
        logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
        logger.debug(f"Image counts per barcode: {feature_counts}")

        # Batch process if not creating individual JSONs
        if not create_single_json and file_dicts:
            self.createFeatureDict(
                file_dicts,
                create_single_json=False,
            )
        return file_dicts

    def _process_image_files(self, root: str, files: List[str]) -> List[str]:
        """Process and validate image files in a directory."""
        valid_paths = []
        for filename in files:
            file_path = os.path.join(root, filename)
            try:
                # Convert RGBA to RGB if needed
                self.convert_rgba_to_rgb(file_path)
                valid_paths.append(file_path)
            except Exception as e:
                logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
        return valid_paths

    def _process_single_barcode(self, barcode: str, image_paths: List[str]):
        """Process a single barcode and create an individual JSON file."""
        temp_dict = {barcode: image_paths}
        self.createFeatureDict(
            temp_dict,
            create_single_json=True,
        )

    def normalize_256(self, queFeatList):
        """L2-normalize a feature vector (applied to the first 256 dimensions)."""
        queFeatList = queFeatList / np.linalg.norm(queFeatList)
        return queFeatList

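    # Worked example: normalize_256(np.array([3.0, 4.0])) returns array([0.6, 0.8]),
    # since the L2 norm of [3, 4] is 5. Note this assumes a non-zero input vector.
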
    def img2feature(
            self,
            imgs_dict: Dict[str, List[str]]
    ) -> Tuple[List[str], List[List[np.ndarray]]]:
        """
        Extract features for all images in the dictionary.

        Args:
            imgs_dict: Dictionary mapping barcodes to image paths

        Returns:
            Tuple containing:
                - List of barcode IDs
                - List of feature lists (one per barcode)

        Raises:
            ValueError: If the input dictionary is empty
            RuntimeError: If feature extraction fails
        """
        if not imgs_dict:
            raise ValueError("No images provided for feature extraction")

        try:
            barcode_list = list(imgs_dict.keys())
            image_list = list(imgs_dict.values())
            feature_list = self.getFeatureList(barcode_list, image_list)

            logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
            return barcode_list, feature_list

        except Exception as e:
            logger.error(f"Feature extraction failed: {str(e)}")
            raise RuntimeError(f"Feature extraction failed: {str(e)}") from e

    def createFeatureDict(self, imgs_dict,
                          create_single_json=False):
        """Serialize features to JSON; imgs_dict maps {barcode1: [img1_1, ...], barcode2: [img2_1, ...]}."""
        dicts_all = {}
        value_list = []
        barcode_list, imgs_list = self.img2feature(imgs_dict)
        for i in range(len(barcode_list)):
            dicts = {}

            imgs_list_ = []
            for j in range(len(imgs_list[i])):
                imgs_list_.append(imgs_list[i][j].tolist())

            dicts['key'] = barcode_list[i]
            # Keep only the normalized first 256 dimensions of each feature
            truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
            dicts['value'] = truncated_imgs_list
            if create_single_json:
                json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
                with open(json_path, 'w') as json_file:
                    json.dump(dicts, json_file)
            else:
                value_list.append(dicts)
        if not create_single_json:
            dicts_all['total'] = value_list
            with open(self.conf['save']['json_bin'], 'w') as json_file:
                json.dump(dicts_all, json_file)
            self.create_binary_files(self.conf['save']['json_bin'])

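    # Resulting batch-JSON shape (sketch; the barcode value is hypothetical):
    # {"total": [{"key": "6901234567890", "value": [[0.01, ..., 0.42], ...]}, ...]}
    # where each row of "value" is the truncated 256-dim feature of one image.
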
    def statisticsBarcodes(self, pth, filter=None):
        """Count images per barcode directory and record barcode names to a statistics file."""
        feature_num = 0
        feature_num_lists = []
        nn = 0
        with open(self.conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
            for barcode in os.listdir(pth):
                logger.debug("barcode length >> {}".format(len(barcode)))
                if len(barcode) > 13 or len(barcode) < 8:
                    continue
                if filter is not None:
                    f.write(barcode + '\n')
                    if barcode in filter:
                        logger.debug(barcode)
                        feature_num += len(os.listdir(os.path.join(pth, barcode)))
                        nn += 1
                else:
                    logger.debug('barcode name >> {}'.format(barcode))
                    f.write(barcode + '\n')
                    feature_num += len(os.listdir(os.path.join(pth, barcode)))
                feature_num_lists.append(feature_num)
            logger.info("Total features: {}".format(feature_num))
            logger.info("Total barcodes: {}".format(nn))

    def get_shop_barcodes(self, file_path):
        """Read the barcode column (7th column) from a shop spreadsheet, or return None."""
        if file_path:
            df = pd.read_excel(file_path)
            column_values = list(df.iloc[:, 6].values)
            column_values = list(map(str, column_values))
            return column_values
        else:
            return None

    def del_base_dir(self, pth):
        """Delete any 'base' subdirectory that is the sole child of its parent."""
        for root, dirs, files in os.walk(pth):
            if len(dirs) == 1:
                if dirs[0] == 'base':
                    shutil.rmtree(os.path.join(root, dirs[0]))

    def write_binary_file(self, filename, datas):
        """Write features to a binary file in the layout consumed by the C++ reader."""
        with open(filename, 'wb') as f:
            # Write the number of keys first (makes reading from C++ easier)
            key_count = len(datas)
            f.write(struct.pack('I', key_count))  # 'I' is a 4-byte unsigned int
            for data in datas:
                key = data['key']
                feats = data['value']
                key_bytes = key.encode('utf-8')
                key_len = len(key_bytes)
                length_byte = struct.pack('<B', key_len)  # 1-byte key length
                f.write(length_byte)
                f.write(key_bytes)
                value_count = len(feats)
                # Total float count for this key: number of features * 256 dimensions
                f.write(struct.pack('I', (value_count * 256)))
                # Write each value as a half-precision (2-byte) float
                for values in feats:
                    for value in values:
                        value_half = np.float16(value)
                        f.write(value_half.tobytes())

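    # A minimal reader sketch documenting the binary layout written above
    # (uint32 key count; per key: uint8 key length, UTF-8 key bytes, uint32 float
    # count, then that many native-endian float16 values). This helper is
    # illustrative and not part of the original pipeline.
    def read_binary_file(self, filename):
        """Parse a .bin file written by write_binary_file back into {key: ndarray}."""
        result = {}
        with open(filename, 'rb') as f:
            key_count = struct.unpack('I', f.read(4))[0]
            for _ in range(key_count):
                key_len = struct.unpack('<B', f.read(1))[0]
                key = f.read(key_len).decode('utf-8')
                float_count = struct.unpack('I', f.read(4))[0]
                feats = np.frombuffer(f.read(float_count * 2), dtype=np.float16)
                # Reshape back into one 256-dim row per image feature
                result[key] = feats.reshape(-1, 256)
        return result
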
    def create_binary_file(self, json_path, flag=True):
        """Convert a feature JSON file into the binary format."""
        # Open and parse the JSON file
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        if flag:
            # Batch JSON: {"total": [{"key": ..., "value": ...}, ...]}
            for _, values in data.items():
                self.write_binary_file(self.conf['save']['json_bin'].replace('.json', '.bin'), values)
        else:
            # Single-barcode JSON: one {"key": ..., "value": ...} dict
            self.write_binary_file(json_path.replace('.json', '.bin'), [data])

    def create_binary_files(self, index_file_pth):
        """Convert one JSON index file, or every JSON file in a directory, to binary."""
        if os.path.isfile(index_file_pth):
            self.create_binary_file(index_file_pth)
        else:
            for name in os.listdir(index_file_pth):
                jsonpth = os.sep.join([index_file_pth, name])
                self.create_binary_file(jsonpth, False)


if __name__ == "__main__":
    with open('../configs/write_feature.yml', 'r') as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    # Save the mapping of image names to model-inferred feature vectors as a JSON file
    # xlsx_pth = './shop_xlsx/曹家桥门店在售商品表.xlsx'
    # xlsx_pth = None
    # del_base_dir(mg_path)

    extractor = FeatureExtractor(conf)
    column_values = extractor.get_shop_barcodes(conf['data']['xlsx_pth'])
    imgs_dict = extractor.get_files(conf['data']['img_dirs_path'],
                                    filter=column_values,
                                    create_single_json=False)
    extractor.statisticsBarcodes(conf['data']['img_dirs_path'], column_values)
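
# The config keys accessed in this module suggest a write_feature.yml roughly
# shaped like the sketch below. This is an assumption reconstructed from the
# lookups in this file, not the actual shipped config; all paths are placeholders.
#
# base:
#   device: cuda:0
#   distributed: false
# models:
#   checkpoints: ./checkpoints/resnet18.pth
# data:
#   half: true
#   xlsx_pth: null
#   img_dirs_path: ./images
# save:
#   json_path: ./search_library/
#   json_bin: ./search_library/total.json
#   error_barcodes: ./error_barcodes.txt
#   barcodes_statistics: ./barcodes_statistics.txt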