Files
ieemoo-ai-contrast/tools/gift_data_pretreatment.py
2025-06-11 15:23:50 +08:00

93 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from config import config as conf
from PIL import Image
import numpy as np
def convert_rgba_to_rgb(image_path, output_path=None):
    """
    Convert a 4-channel (RGBA) image to 3-channel RGB and save it.

    Images whose mode is not 'RGBA' are left untouched and nothing is written.

    :param image_path: path of the input image
    :param output_path: path to save the converted image to; when None the
        input file is overwritten in place
    """
    # Open through a context manager so the file handle is released before the
    # destination (possibly the very same path) is written.
    with Image.open(image_path) as img:
        if img.mode != 'RGBA':
            return
        # .convert('RGB') drops the alpha channel and returns a new image, so
        # the result remains usable after the source file has been closed.
        img_rgb = img.convert('RGB')
    target_path = image_path if output_path is None else output_path
    img_rgb.save(target_path)
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
    """
    Apply ``conf.test_transform`` to every image and stack the results.

    Items that fail to load or transform are skipped (best effort) and a
    warning is logged, so the result may hold fewer entries than ``images``.

    :param images: items to preprocess; image paths when ``actionModel`` is
        False, otherwise objects handed to the transform unchanged
    :param actionModel: when True each item is passed straight to
        ``conf.test_transform``; otherwise it is first opened with
        ``Image.open``
    :return: the transformed items stacked along a new first dimension
    :raises RuntimeError: from ``torch.stack`` when every item was skipped
        (it needs a non-empty sequence)
    """
    import logging

    logger = logging.getLogger(__name__)
    res = []
    for img in images:
        try:
            if actionModel:
                im = conf.test_transform(img)
            else:
                # Close the file handle as soon as the transform is done.
                with Image.open(img) as pil_img:
                    im = conf.test_transform(pil_img)
        except Exception as exc:
            # Best effort: one bad image must not abort the whole batch, but
            # do not swallow KeyboardInterrupt/SystemExit and do leave a trace.
            logger.warning("Skipping image %r: %s", img, exc)
            continue
        res.append(im)
    return torch.stack(res)
def inference(images, model, actionModel=False):
    """
    Preprocess ``images`` and run ``model`` on the resulting batch.

    :param images: items forwarded to ``test_preprocess``
    :param model: callable applied to the stacked batch tensor
    :param actionModel: forwarded to ``test_preprocess``
    :return: whatever ``model`` returns for the batch
    """
    data = test_preprocess(images, actionModel)
    if torch.cuda.is_available():
        data = data.to(conf.device)
    # Inference only: disable autograd so no computation graph is built or
    # kept alive by the returned features.
    with torch.no_grad():
        features = model(data)
    return features
def group_image(images, batch=64) -> list:
    """
    Split ``images`` into consecutive chunks of at most ``batch`` items.

    :param images: sliceable sequence (e.g. a list of image paths)
    :param batch: maximum number of items per chunk; must be positive
    :return: list of slices of ``images``; the last one may be shorter, and
        the list is empty when ``images`` is empty
    :raises ValueError: if ``batch`` is zero or negative
    """
    # A negative step would make range() empty and silently drop every image.
    if batch <= 0:
        raise ValueError(f"batch must be a positive integer, got {batch}")
    # Slicing clips at the end of the sequence, so no explicit min() is needed.
    return [images[start:start + batch] for start in range(0, len(images), batch)]
def normalize(queFeatList):
    """
    Scale every feature vector in a nested collection to unit L2 norm, in place.

    :param queFeatList: indexable collection of indexable collections of
        feature vectors; each vector is divided by ``np.linalg.norm`` of itself
    :return: the same ``queFeatList`` object, after modification

    Vectors whose norm is zero are left unchanged instead of being turned into
    NaN/inf by a division by zero.
    """
    for feats in queFeatList:
        for idx, feat in enumerate(feats):
            norm = np.linalg.norm(feat)
            if norm == 0:
                continue
            feats[idx] = feat / norm
    return queFeatList
def getFeatureList(barList, imgList, model):
    """
    Extract features for ``imgList`` batch by batch.

    :param barList: not used by this function; kept so existing callers
        keep working
    :param imgList: images, split into batches with ``group_image`` and fed
        to ``inference``
    :param model: model handed through to ``inference``
    :return: list with one NumPy array per batch, in batch order
    """
    fe_nps = []
    for group in group_image(imgList):
        feat_tensor = inference(group, model)
        # .cpu() leaves a tensor that already lives on the CPU as it is, so a
        # single path serves both devices. The earlier branch compared a
        # torch.device with the str 'cpu' and both arms produced the same array.
        # NOTE(review): squeeze() removes every size-1 axis, so a batch that
        # yields a single feature also loses its batch axis — confirm callers
        # expect that.
        fe_np = feat_tensor.squeeze().detach().cpu().numpy()
        fe_nps.append(fe_np)
    return fe_nps