训练数据前置处理与提升训练效率
This commit is contained in:
@ -1,33 +1,50 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from tools.getHeatMap import cal_cam
|
||||
import os
|
||||
|
||||
|
||||
def merge_imgs(img1_path, img2_path, save_path, similar=None, label=None):
|
||||
position = (50, 50) # 文字的左上角坐标
|
||||
color = (255, 0, 0) # 红色文字,格式为 RGB
|
||||
if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||
os.makedirs(os.sep.join([save_path, str(label)]))
|
||||
save_path = os.sep.join([save_path, str(label)])
|
||||
img_name = os.path.basename(img1_path).split('.')[0]+'_'+os.path.basename(img2_path).split('.')[0]+'.png'
|
||||
def merge_imgs(img1_path, img2_path, conf, similar=None, label=None, cam=None):
|
||||
save = True
|
||||
position = (50, 50) # 文字的左上角坐标
|
||||
color = (255, 0, 0) # 红色文字,格式为 RGB
|
||||
# if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||
# os.makedirs(os.sep.join([save_path, str(label)]))
|
||||
# save_path = os.sep.join([save_path, str(label)])
|
||||
# img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
|
||||
if not conf['heatmap']['show_heatmap']:
|
||||
img1 = Image.open(img1_path)
|
||||
img2 = Image.open(img2_path)
|
||||
img1 = img1.resize((224,224))
|
||||
img2 = img2.resize((224,224))
|
||||
print('img1_path', img1)
|
||||
print('img2_path', img2)
|
||||
assert img1.height == img2.height
|
||||
img1 = img1.resize((224, 224))
|
||||
img2 = img2.resize((224, 224))
|
||||
save_path = conf['data']['image_joint_pth']
|
||||
else:
|
||||
assert cam is not None, 'cam is None'
|
||||
img1 = cam.get_hot_map(img1_path)
|
||||
img2 = cam.get_hot_map(img2_path)
|
||||
save_path = conf['heatmap']['image_joint_pth']
|
||||
# print('img1_path', img1)
|
||||
# print('img2_path', img2)
|
||||
if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||
os.makedirs(os.sep.join([save_path, str(label)]))
|
||||
save_path = os.sep.join([save_path, str(label)])
|
||||
img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
|
||||
assert img1.height == img2.height
|
||||
|
||||
new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
|
||||
new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
|
||||
|
||||
# print('new_img', new_img)
|
||||
new_img.paste(img1, (0, 0))
|
||||
new_img.paste(img2, (img1.width + 10, 0))
|
||||
# print('new_img', new_img)
|
||||
new_img.paste(img1, (0, 0))
|
||||
new_img.paste(img2, (img1.width + 10, 0))
|
||||
|
||||
if similar is not None:
|
||||
similar = str(similar)+'_'+str(label)
|
||||
draw = ImageDraw.Draw(new_img)
|
||||
draw.text(position, str(similar), color, font_size=36)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
img_save = os.path.join(save_path, img_name)
|
||||
if similar is not None:
|
||||
if label == '1' and similar > 0.5:
|
||||
save = False
|
||||
elif label == '0' and similar < 0.5:
|
||||
save = False
|
||||
similar = str(similar) + '_' + str(label)
|
||||
draw = ImageDraw.Draw(new_img)
|
||||
draw.text(position, str(similar), color, font_size=36)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
img_save = os.path.join(save_path, img_name)
|
||||
if save:
|
||||
new_img.save(img_save)
|
||||
|
||||
|
Reference in New Issue
Block a user