rebuild
This commit is contained in:
92
tools/gift_data_pretreatment.py
Normal file
92
tools/gift_data_pretreatment.py
Normal file
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from config import config as conf
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def convert_rgba_to_rgb(image_path, output_path=None):
|
||||
"""
|
||||
将给定路径的4通道PNG图像转换为3通道,并保存到指定输出路径。
|
||||
|
||||
:param image_path: 输入图像的路径
|
||||
:param output_path: 转换后的图像保存路径
|
||||
"""
|
||||
# 打开图像
|
||||
img = Image.open(image_path)
|
||||
# 转换图像模式从RGBA到RGB
|
||||
# .convert('RGB')会丢弃Alpha通道并转换为纯RGB图像
|
||||
if img.mode == 'RGBA':
|
||||
# 转换为RGB模式
|
||||
img_rgb = img.convert('RGB')
|
||||
# 保存转换后的图像
|
||||
img_rgb.save(image_path)
|
||||
# print(f"Image converted from RGBA to RGB and saved to {image_path}")
|
||||
# else:
|
||||
# # 如果已经是RGB或其他模式,直接保存
|
||||
# img.save(image_path)
|
||||
# print(f"Image already in {img.mode} mode, saved to {image_path}")
|
||||
|
||||
|
||||
def test_preprocess(images: list, actionModel=False) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
try:
|
||||
# print(img)
|
||||
im = conf.test_transform(img) if actionModel else conf.test_transform(Image.open(img))
|
||||
res.append(im)
|
||||
except:
|
||||
continue
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def inference(images, model, actionModel=False):
|
||||
data = test_preprocess(images, actionModel)
|
||||
if torch.cuda.is_available():
|
||||
data = data.to(conf.device)
|
||||
features = model(data)
|
||||
return features
|
||||
|
||||
|
||||
def group_image(images, batch=64) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i:end])
|
||||
return res
|
||||
|
||||
def normalize(queFeatList):
|
||||
for num1 in range(len(queFeatList)):
|
||||
for num2 in range(len(queFeatList[num1])):
|
||||
queFeatList[num1][num2] = queFeatList[num1][num2] / np.linalg.norm(queFeatList[num1][num2])
|
||||
return queFeatList
|
||||
|
||||
def getFeatureList(barList, imgList, model):
|
||||
# featList = [[] for i in range(len(barList))]
|
||||
# for index, feat in enumerate(imgList):
|
||||
fe_nps = []
|
||||
groups = group_image(imgList)
|
||||
for group in groups:
|
||||
feat_tensor = inference(group, model)
|
||||
# for fe in feat_tensor:
|
||||
if feat_tensor.device == 'cpu':
|
||||
fe_np = feat_tensor.squeeze().detach().numpy()
|
||||
# fe_np = fe_np[:, 256:]
|
||||
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||
else:
|
||||
fe_np = feat_tensor.squeeze().detach().cpu().numpy()
|
||||
# fe_np = fe_np[:, 256:]
|
||||
# fe_np = fe_np[256:]
|
||||
# fe_np = fe_np.reshape(fe_np.shape[0], fe_np.shape[1], 1, 1)
|
||||
# fe_np = fe_np.reshape(1, fe_np.shape[0], 1, 1)
|
||||
# print(fe_np)
|
||||
|
||||
fe_nps.append(fe_np)
|
||||
# if fe_nps:
|
||||
# merged_fe_np = np.concatenate(fe_nps, axis=0)
|
||||
# else:
|
||||
# merged_fe_np = np.array([]) #
|
||||
# fe_list = normalize(fe_nps)
|
||||
return fe_nps
|
Reference in New Issue
Block a user