import json
import os
import logging
import shutil
import struct
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from tqdm import tqdm

from tools.dataset import get_transform
from model import resnet18, resnet34, resnet50, resnet101

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class FeatureExtractor:
    def __init__(self, conf):
        self.conf = conf
        self.model = self.initModel()
        _, self.test_transform = get_transform(self.conf)

    def initModel(self, inference_model: Optional[str] = None) -> torch.nn.Module:
        """
        Initialize and load a ResNet backbone for inference.

        Args:
            inference_model: Optional path to model weights. Uses the checkpoint
                path from the config if None.

        Returns:
            Loaded and configured PyTorch model in evaluation mode.

        Raises:
            FileNotFoundError: If the model weights file is not found.
            ValueError: If the configured backbone is not supported.
            RuntimeError: If model loading fails.
        """
        model_path = inference_model if inference_model else self.conf['models']['checkpoints']
        try:
            # Verify model file exists
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model weights file not found: {model_path}")

            # Initialize the requested backbone
            backbone = self.conf['models']['backbone']
            device = self.conf['base']['device']
            scale = self.conf['models']['channel_ratio']
            if backbone == 'resnet18':
                model = resnet18(scale=scale).to(device)
            elif backbone == 'resnet34':
                model = resnet34(scale=scale).to(device)
            elif backbone == 'resnet50':
                model = resnet50(scale=scale).to(device)
            elif backbone == 'resnet101':
                model = resnet101(scale=scale).to(device)
            else:
                raise ValueError(f"Unsupported backbone: {backbone}")

            # Handle multi-GPU case
            if self.conf['base']['distributed']:
                model = torch.nn.DataParallel(model)

            # Load weights
            state_dict = torch.load(model_path, map_location=device)
            model.load_state_dict(state_dict)
            model.eval()

            logger.info(f"Successfully loaded model from {model_path}")
            return model

        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            raise

    def convert_rgba_to_rgb(self, image_path):
        """Convert an RGBA image to RGB in place, discarding the alpha channel."""
        img = Image.open(image_path)
        if img.mode == 'RGBA':
            # .convert('RGB') drops the alpha channel and yields a plain RGB image
            img_rgb = img.convert('RGB')
            img_rgb.save(image_path)
            logger.info(f"Image converted from RGBA to RGB and saved to {image_path}")

    def test_preprocess(self, images: list, actionModel=False) -> torch.Tensor:
        res = []
        for img in images:
            try:
                im = self.test_transform(img) if actionModel else self.test_transform(Image.open(img))
                res.append(im)
            except Exception as e:
                logger.warning(f"Skipping unreadable image {img}: {str(e)}")
                continue
        data = torch.stack(res)
        return data

    def inference(self, images, model, actionModel=False):
        data = self.test_preprocess(images, actionModel)
        if torch.cuda.is_available():
            data = data.to(self.conf['base']['device'])
        features = model(data)
        if self.conf['data']['half']:
            features = features.half()
        return features

    def group_image(self, images, batch=64) -> list:
        """Group image paths into batches of at most `batch` items."""
        size = len(images)
        res = []
        for i in range(0, size, batch):
            end = min(batch + i, size)
            res.append(images[i:end])
        return res
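
    # ------------------------------------------------------------------
    # Configuration sketch. The methods in this class read the config keys
    # shown below; this YAML fragment is an illustrative assumption about
    # the expected shape, not the project's actual config file:
    #
    #   base:
    #     device: cuda:0
    #     distributed: false
    #   models:
    #     backbone: resnet18
    #     channel_ratio: 1.0
    #     checkpoints: checkpoints/best.pth
    #   data:
    #     half: false
    #   save:
    #     json_path: ./search_library/
    #     json_bin: ./search_library/features.json
    #     error_barcodes: ./error_barcodes.txt
    #     barcodes_statistics: ./barcodes_statistics.txt
    # ------------------------------------------------------------------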

    def getFeatureList(self, barList, imgList):
        featList = [[] for _ in range(len(barList))]
        for index, image_paths in enumerate(imgList):
            try:
                # Process images in batches
                for batch in self.group_image(image_paths):
                    # Get features for the batch
                    features = self.inference(batch, self.model)
                    # Process each feature in the batch
                    for feat in features:
                        # Move to CPU and convert to numpy
                        feat_np = feat.squeeze().detach().cpu().numpy()
                        # Normalize the first 256 dimensions
                        normalized = self.normalize_256(feat_np[:256])
                        # Recombine with the remaining dimensions
                        combined = np.concatenate([normalized, feat_np[256:]], axis=0)
                        featList[index].append(combined)
            except Exception as e:
                logger.error(f"Error processing images for index {index}: {str(e)}")
                continue
        return featList

    def get_files(
            self,
            folder: str,
            filter: Optional[List[str]] = None,
            create_single_json: bool = False
    ) -> Dict[str, List[str]]:
        """
        Recursively collect image files from a directory structure.

        Args:
            folder: Root directory to scan.
            filter: Optional list of barcodes to include.
            create_single_json: Whether to create individual JSON files per barcode.

        Returns:
            Dictionary mapping barcode names to lists of image paths.

        Example:
            {
                "barcode1": ["path/to/img1.jpg", "path/to/img2.jpg"],
                "barcode2": ["path/to/img3.jpg"]
            }
        """
        file_dicts = {}
        total_files = 0
        feature_counts = []
        barcode_count = 0
        subclass = [str(i) for i in range(100)]

        # Validate input directory
        if not os.path.isdir(folder):
            raise ValueError(f"Invalid directory: {folder}")

        i = 0
        # Process each barcode directory
        for root, dirs, files in tqdm(os.walk(folder), desc="Scanning directories"):
            if not dirs:  # Leaf directory (contains images)
                basename = os.path.basename(root)
                if basename in subclass:
                    ori_barcode = os.path.basename(os.path.dirname(root))
                    barcode = ori_barcode + '_' + basename
                else:
                    ori_barcode = basename
                    barcode = basename

                # Apply filter if provided
                i += 1
                logger.debug(f"{ori_barcode} {i}")
                if filter and ori_barcode not in filter:
                    continue
                # elif len(ori_barcode) > 13 or len(ori_barcode) < 8:  # filter by barcode length
                #     logger.warning(f"Skipping invalid barcode {ori_barcode}")
                #     with open(self.conf['save']['error_barcodes'], 'a') as f:
                #         f.write(ori_barcode + '\n')
                #     continue

                # Process image files
                if files:
                    image_paths = self._process_image_files(root, files)
                    if not image_paths:
                        continue

                    # Update counters
                    barcode_count += 1
                    file_count = len(image_paths)
                    total_files += file_count
                    feature_counts.append(file_count)

                    # Handle output mode
                    if create_single_json:
                        self._process_single_barcode(barcode, image_paths)
                    else:
                        if barcode.split('_')[-1] == '0':
                            barcode = barcode.split('_')[0]
                        file_dicts[barcode] = image_paths

        # logger.info(f"Processed {barcode_count} barcodes with {total_files} total images")
        # logger.debug(f"Image counts per barcode: {feature_counts}")

        # Batch process if not creating individual JSONs
        if not create_single_json and file_dicts:
            self.createFeatureDict(
                file_dicts,
                create_single_json=False,
            )
        return file_dicts

    def _process_image_files(self, root: str, files: List[str]) -> List[str]:
        """Process and validate image files in a directory."""
        valid_paths = []
        for filename in files:
            file_path = os.path.join(root, filename)
            try:
                # Convert RGBA to RGB if needed
                self.convert_rgba_to_rgb(file_path)
                valid_paths.append(file_path)
            except Exception as e:
                logger.warning(f"Skipping invalid image {file_path}: {str(e)}")
        return valid_paths

    def _process_single_barcode(self, barcode: str, image_paths: List[str]):
        """Process a single barcode and create an individual JSON file."""
        temp_dict = {barcode: image_paths}
        self.createFeatureDict(
            temp_dict,
            create_single_json=True,
        )

    def normalize_256(self, queFeatList):
        """L2-normalize a feature vector (applied to the first 256 dimensions)."""
        queFeatList = queFeatList / np.linalg.norm(queFeatList)
        return queFeatList
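
    # Expected on-disk layout for get_files (a sketch; directory names are
    # illustrative assumptions):
    #
    #   dataset_root/
    #     6901234567890/          <- barcode directory
    #       0/                    <- optional numeric subclass ("0".."99")
    #         img_001.jpg
    #         img_002.jpg
    #     6907654321098/          <- images directly under the barcode
    #       img_003.jpg
    #
    # Subclass leaves produce "<barcode>_<subclass>" keys, except subclass
    # "0", which collapses back to the bare barcode.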

    def img2feature(
            self,
            imgs_dict: Dict[str, List[str]]
    ) -> Tuple[List[str], List[List[np.ndarray]]]:
        """
        Extract features for all images in the dictionary.

        Args:
            imgs_dict: Dictionary mapping barcodes to image paths.

        Returns:
            Tuple containing:
            - List of barcode IDs
            - List of feature lists (one per barcode)

        Raises:
            ValueError: If the input dictionary is empty.
            RuntimeError: If feature extraction fails.
        """
        if not imgs_dict:
            raise ValueError("No images provided for feature extraction")

        try:
            barcode_list = list(imgs_dict.keys())
            image_list = list(imgs_dict.values())
            feature_list = self.getFeatureList(barcode_list, image_list)
            logger.info(f"Successfully extracted features for {len(barcode_list)} barcodes")
            return barcode_list, feature_list
        except Exception as e:
            logger.error(f"Feature extraction failed: {str(e)}")
            raise RuntimeError(f"Feature extraction failed: {str(e)}")

    def createFeatureDict(self, imgs_dict, create_single_json=False):
        """Serialize features to JSON: one file per barcode, or a single combined file.

        imgs_dict -> {barcode1: [img1_1...img1_n], barcode2: [img2_1...img2_n]}
        """
        dicts_all = {}
        value_list = []
        barcode_list, imgs_list = self.img2feature(imgs_dict)
        for i in range(len(barcode_list)):
            dicts = {}
            imgs_list_ = [imgs_list[i][j].tolist() for j in range(len(imgs_list[i]))]
            dicts['key'] = barcode_list[i]
            # Keep only the (normalized) first 256 dimensions of each feature
            truncated_imgs_list = [subarray[:256] for subarray in imgs_list_]
            dicts['value'] = truncated_imgs_list
            if create_single_json:
                # json_path = os.path.join("./search_library/v8021_overseas/", str(barcode_list[i]) + '.json')
                json_path = os.path.join(self.conf['save']['json_path'], str(barcode_list[i]) + '.json')
                with open(json_path, 'w') as json_file:
                    json.dump(dicts, json_file)
            else:
                value_list.append(dicts)
        if not create_single_json:
            dicts_all['total'] = value_list
            with open(self.conf['save']['json_bin'], 'w') as json_file:
                json.dump(dicts_all, json_file)
            self.create_binary_files(self.conf['save']['json_bin'])

    def statisticsBarcodes(self, pth, filter=None):
        """Count features (images) per barcode and write the barcode list to a file."""
        feature_num = 0
        feature_num_lists = []
        nn = 0
        with open(self.conf['save']['barcodes_statistics'], 'w', encoding='utf-8') as f:
            for barcode in os.listdir(pth):
                logger.debug("barcode length >> {}".format(len(barcode)))
                # if len(barcode) > 13 or len(barcode) < 8:  # filter by barcode length
                #     continue
                if filter is not None:
                    f.writelines(barcode + '\n')
                    if barcode in filter:
                        logger.debug(barcode)
                        feature_num += len(os.listdir(os.path.join(pth, barcode)))
                        nn += 1
                else:
                    logger.debug('barcode name >> {}'.format(barcode))
                    f.writelines(barcode + '\n')
                    feature_num += len(os.listdir(os.path.join(pth, barcode)))
                feature_num_lists.append(feature_num)
            logger.info("Total feature count: {}".format(feature_num))
            logger.info("Total barcode count: {}".format(nn))

    def get_shop_barcodes(self, file_path):
        """Read the barcode column (7th column) from a shop Excel sheet."""
        if file_path:
            df = pd.read_excel(file_path)
            column_values = list(df.iloc[:, 6].values)
            column_values = list(map(str, column_values))
            return column_values
        else:
            return None

    def del_base_dir(self, pth):
        """Delete 'base' subdirectories found anywhere under pth."""
        for root, dirs, files in os.walk(pth):
            if len(dirs) == 1 and dirs[0] == 'base':
                shutil.rmtree(os.path.join(root, dirs[0]))

    def write_binary_file(self, filename, datas):
        with open(filename, 'wb') as f:
            # Write the number of keys first (convenient for the C++ reader)
            key_count = len(datas)
            f.write(struct.pack('I', key_count))  # 'I' = unsigned 4-byte integer
            for data in datas:
                key = data['key']
                feats = data['value']
                key_bytes = key.encode('utf-8')
                key_len = len(key)
                # The record layout from here on is a reconstruction (the
                # original source is truncated at this point) and is an
                # assumption: 1-byte key length, UTF-8 key bytes, 4-byte
                # feature count, then each feature as float32 values.
                length_byte = struct.pack('B', key_len)  # 'B' = unsigned char (1 byte)
                f.write(length_byte)
                f.write(key_bytes)
                f.write(struct.pack('I', len(feats)))
                for feat in feats:
                    f.write(struct.pack(f'{len(feat)}f', *feat))
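

# Minimal usage sketch (assumed entry point; the config path and dataset
# directory are illustrative placeholders, not the project's actual values):
if __name__ == '__main__':
    with open('configs/feature_extract.yaml', 'r', encoding='utf-8') as f:
        conf = yaml.safe_load(f)

    extractor = FeatureExtractor(conf)
    # Walk the dataset tree, extract features, and write the combined JSON
    # (and, via createFeatureDict, the binary file for the C++ reader).
    extractor.get_files('/path/to/dataset_root', create_single_json=False)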