训练数据前置处理与提升训练效率
This commit is contained in:
12
.idea/CopilotChatHistory.xml
generated
12
.idea/CopilotChatHistory.xml
generated
@ -3,6 +3,18 @@
|
|||||||
<component name="CopilotChatHistory">
|
<component name="CopilotChatHistory">
|
||||||
<option name="conversations">
|
<option name="conversations">
|
||||||
<list>
|
<list>
|
||||||
|
<Conversation>
|
||||||
|
<option name="createTime" value="1752114061266" />
|
||||||
|
<option name="id" value="0197f222dfd27515a3dbfea638532ee5" />
|
||||||
|
<option name="title" value="新对话 2025年7月10日 10:21:01" />
|
||||||
|
<option name="updateTime" value="1752114061266" />
|
||||||
|
</Conversation>
|
||||||
|
<Conversation>
|
||||||
|
<option name="createTime" value="1751970991660" />
|
||||||
|
<option name="id" value="0197e99bce2c7a569dee594fb9b6e152" />
|
||||||
|
<option name="title" value="新对话 2025年7月08日 18:36:31" />
|
||||||
|
<option name="updateTime" value="1751970991660" />
|
||||||
|
</Conversation>
|
||||||
<Conversation>
|
<Conversation>
|
||||||
<option name="createTime" value="1751441743239" />
|
<option name="createTime" value="1751441743239" />
|
||||||
<option name="id" value="0197ca101d8771bd80f2bc4aaf1a8f19" />
|
<option name="id" value="0197ca101d8771bd80f2bc4aaf1a8f19" />
|
||||||
|
26
configs/sub_data.yml
Normal file
26
configs/sub_data.yml
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# configs/sub_data.yml
|
||||||
|
# 专为对比模型训练的数据集设计的配置文件
|
||||||
|
# 支持对比不同训练策略(如蒸馏vs独立训练)
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
data:
|
||||||
|
source_dir: "../../data_center/contrast_data/total" # 数据集名称(示例用,可替换为实际数据集)
|
||||||
|
train_dir: "../../data_center/contrast_data/v1/train" # 训练数据集根目录
|
||||||
|
val_dir: "../../data_center/contrast_data/v1/val" # 验证数据集根目录
|
||||||
|
data_extra_dir: "../../data_center/contrast_data/v1/extra"
|
||||||
|
max_files_ratio: 0.1
|
||||||
|
min_files: 10
|
||||||
|
split_ratio: 0.9
|
||||||
|
|
||||||
|
|
||||||
|
extend:
|
||||||
|
extend_same_dir: true
|
||||||
|
extend_extra: true
|
||||||
|
extend_extra_dir: "../../data_center/contrast_data/v1/extra"
|
||||||
|
extend_train: true
|
||||||
|
extend_train_dir: "../../data_center/contrast_data/v1/train"
|
||||||
|
|
||||||
|
limit:
|
||||||
|
count_limit: true
|
||||||
|
limit_count: 200
|
||||||
|
limit_dir: "../../data_center/contrast_data/v1/train"
|
@ -8,13 +8,13 @@ base:
|
|||||||
log_level: "info" # 日志级别(debug/info/warning/error)
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
embedding_size: 256 # 特征维度
|
embedding_size: 256 # 特征维度
|
||||||
pin_memory: true # 是否启用pin_memory
|
pin_memory: true # 是否启用pin_memory
|
||||||
distributed: true # 是否启用分布式训练
|
distributed: false # 是否启用分布式训练
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
models:
|
models:
|
||||||
backbone: 'resnet18'
|
backbone: 'resnet18'
|
||||||
channel_ratio: 1.0
|
channel_ratio: 1.0
|
||||||
model_path: "checkpoints/resnet18_scatter_6.26/best.pth"
|
model_path: "checkpoints/resnet18_scatter_7.3/best.pth"
|
||||||
half: false # 是否启用半精度测试(fp16)
|
half: false # 是否启用半精度测试(fp16)
|
||||||
contrast_learning: false
|
contrast_learning: false
|
||||||
|
|
||||||
@ -22,9 +22,9 @@ models:
|
|||||||
data:
|
data:
|
||||||
test_batch_size: 128 # 训练批次大小
|
test_batch_size: 128 # 训练批次大小
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
test_dir: "../data_center/scatter/v2/val_extar" # 验证数据集根目录
|
test_dir: "../data_center/scatter/v4/val" # 验证数据集根目录
|
||||||
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
test_group_json: "../data_center/contrast_learning/model_test_data/test/inner_group_pairs.json"
|
||||||
test_list: "../data_center/scatter/val_extar_cross_same.txt"
|
test_list: "../data_center/scatter/v4/standard_cross_same.txt"
|
||||||
group_test: false
|
group_test: false
|
||||||
save_image_joint: true
|
save_image_joint: true
|
||||||
image_joint_pth: "./joint_images"
|
image_joint_pth: "./joint_images"
|
||||||
@ -37,6 +37,11 @@ transform:
|
|||||||
RandomRotation: 180 # 随机旋转角度
|
RandomRotation: 180 # 随机旋转角度
|
||||||
ColorJitter: 0.5 # 随机颜色抖动强度
|
ColorJitter: 0.5 # 随机颜色抖动强度
|
||||||
|
|
||||||
|
heatmap:
|
||||||
|
image_joint_pth: "./heatmap_joint_images"
|
||||||
|
feature_layer: "layer4"
|
||||||
|
show_heatmap: true
|
||||||
|
|
||||||
save:
|
save:
|
||||||
save_dir: ""
|
save_dir: ""
|
||||||
save_name: ""
|
save_name: ""
|
||||||
|
@ -15,11 +15,11 @@ base:
|
|||||||
# 模型配置
|
# 模型配置
|
||||||
models:
|
models:
|
||||||
backbone: 'resnet18'
|
backbone: 'resnet18'
|
||||||
channel_ratio: 0.75
|
channel_ratio: 1.0
|
||||||
model_path: "../checkpoints/resnet18_1009/best.pth"
|
model_path: "../checkpoints/resnet18_1009/best.pth"
|
||||||
onnx_model: "../checkpoints/resnet18_1009/best.onnx"
|
onnx_model: "../checkpoints/resnet18_3399_sancheng/best.onnx"
|
||||||
rknn_model: "../checkpoints/resnet18_1009/best_rknn2.3.2_batch16.rknn"
|
rknn_model: "../checkpoints/resnet18_3399_sancheng/best_rknn2.3.2_RK3566.rknn"
|
||||||
rknn_batch_size: 16
|
rknn_batch_size: 1
|
||||||
|
|
||||||
# 日志与监控
|
# 日志与监控
|
||||||
logging:
|
logging:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import yaml
|
||||||
def count_files(directory):
|
def count_files(directory):
|
||||||
"""统计目录中的文件数量"""
|
"""统计目录中的文件数量"""
|
||||||
try:
|
try:
|
||||||
@ -26,6 +26,20 @@ def clear_empty_dirs(path):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e.strerror}")
|
print(f"Error: {e.strerror}")
|
||||||
|
|
||||||
|
def get_max_files(conf):
|
||||||
|
max_files_ratio = conf['data']['max_files_ratio']
|
||||||
|
files_number = []
|
||||||
|
for root, dirs, files in os.walk(conf['data']['source_dir']):
|
||||||
|
if len(dirs) == 0:
|
||||||
|
if len(files) == 0:
|
||||||
|
print(root, dirs,files)
|
||||||
|
files_number.append(len(files))
|
||||||
|
files_number = sorted(files_number, reverse=False)
|
||||||
|
max_files = files_number[int(max_files_ratio * len(files_number))]
|
||||||
|
print(f"max_files: {max_files}")
|
||||||
|
if max_files < conf['data']['min_files']:
|
||||||
|
max_files = conf['data']['min_files']
|
||||||
|
return max_files
|
||||||
def megre_subdirs(pth):
|
def megre_subdirs(pth):
|
||||||
for roots, dir_names, files in os.walk(pth):
|
for roots, dir_names, files in os.walk(pth):
|
||||||
print(f"image {dir_names}")
|
print(f"image {dir_names}")
|
||||||
@ -41,19 +55,24 @@ def megre_subdirs(pth):
|
|||||||
clear_empty_dirs(pth)
|
clear_empty_dirs(pth)
|
||||||
|
|
||||||
|
|
||||||
def split_subdirs(source_dir, target_dir, max_files=10):
|
# def split_subdirs(source_dir, target_dir, max_files=10):
|
||||||
|
def split_subdirs(conf):
|
||||||
"""
|
"""
|
||||||
复制文件数≤max_files的子目录到目标目录
|
复制文件数≤max_files的子目录到目标目录
|
||||||
:param source_dir: 源目录路径
|
:param source_dir: 源目录路径
|
||||||
:param target_dir: 目标目录路径
|
:param target_dir: 目标目录路径
|
||||||
:param max_files: 最大文件数阈值
|
:param max_files: 最大文件数阈值
|
||||||
"""
|
"""
|
||||||
|
source_dir = conf['data']['source_dir']
|
||||||
|
target_extra_dir = conf['data']['data_extra_dir']
|
||||||
|
train_dir = conf['data']['train_dir']
|
||||||
|
max_files = get_max_files(conf)
|
||||||
megre_subdirs(source_dir) # 合并子目录,删除上级目录
|
megre_subdirs(source_dir) # 合并子目录,删除上级目录
|
||||||
# 创建目标目录
|
# 创建目标目录
|
||||||
Path(target_dir).mkdir(parents=True, exist_ok=True)
|
Path(target_extra_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print(f"开始处理目录: {source_dir}")
|
print(f"开始处理目录: {source_dir}")
|
||||||
print(f"目标目录: {target_dir}")
|
print(f"目标目录: {target_extra_dir}")
|
||||||
print(f"筛选条件: 文件数 ≤ {max_files}\n")
|
print(f"筛选条件: 文件数 ≤ {max_files}\n")
|
||||||
|
|
||||||
# 遍历源目录
|
# 遍历源目录
|
||||||
@ -65,18 +84,18 @@ def split_subdirs(source_dir, target_dir, max_files=10):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
file_count = count_files(subdir_path)
|
file_count = count_files(subdir_path)
|
||||||
|
print(f"复制 {subdir} (包含 {file_count} 个文件)")
|
||||||
if file_count <= max_files:
|
if file_count <= max_files:
|
||||||
print(f"复制 {subdir} (包含 {file_count} 个文件)")
|
dest_path = os.path.join(target_extra_dir, subdir)
|
||||||
dest_path = os.path.join(target_dir, subdir)
|
else:
|
||||||
|
dest_path = os.path.join(train_dir, subdir)
|
||||||
# 如果目标目录已存在则跳过
|
# 如果目标目录已存在则跳过
|
||||||
if os.path.exists(dest_path):
|
if os.path.exists(dest_path):
|
||||||
print(f"目录已存在,跳过: {dest_path}")
|
print(f"目录已存在,跳过: {dest_path}")
|
||||||
continue
|
continue
|
||||||
|
print(f"复制 {subdir} (包含 {file_count} 个文件) 至 {dest_path}")
|
||||||
# shutil.copytree(subdir_path, dest_path)
|
shutil.copytree(subdir_path, dest_path)
|
||||||
shutil.move(subdir_path, dest_path)
|
# shutil.move(subdir_path, dest_path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"处理目录 {subdir} 时出错: {e}")
|
print(f"处理目录 {subdir} 时出错: {e}")
|
||||||
@ -85,8 +104,8 @@ def split_subdirs(source_dir, target_dir, max_files=10):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 配置路径
|
# 配置路径
|
||||||
SOURCE_DIR = r"C:\Users\123\Desktop\test1\scatter_sub_class"
|
with open('../configs/sub_data.yml', 'r') as f:
|
||||||
TARGET_DIR = "scatter_mini"
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
# 执行复制操作
|
# 执行复制操作
|
||||||
split_subdirs(SOURCE_DIR, TARGET_DIR)
|
split_subdirs(conf)
|
||||||
|
@ -8,7 +8,9 @@ def is_image_file(filename):
|
|||||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
||||||
return filename.lower().endswith(image_extensions)
|
return filename.lower().endswith(image_extensions)
|
||||||
|
|
||||||
def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
|
|
||||||
|
|
||||||
|
def split_directory(conf):
|
||||||
"""
|
"""
|
||||||
分割目录中的图像文件到train和val目录
|
分割目录中的图像文件到train和val目录
|
||||||
:param src_dir: 源目录路径
|
:param src_dir: 源目录路径
|
||||||
@ -17,22 +19,22 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
|
|||||||
:param split_ratio: 训练集比例(默认0.9)
|
:param split_ratio: 训练集比例(默认0.9)
|
||||||
"""
|
"""
|
||||||
# 创建目标目录
|
# 创建目标目录
|
||||||
Path(train_dir).mkdir(parents=True, exist_ok=True)
|
train_dir = conf['data']['train_dir']
|
||||||
|
val_dir = conf['data']['val_dir']
|
||||||
|
split_ratio = conf['data']['split_ratio']
|
||||||
Path(val_dir).mkdir(parents=True, exist_ok=True)
|
Path(val_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# 遍历源目录
|
# 遍历源目录
|
||||||
for root, dirs, files in os.walk(src_dir):
|
for root, dirs, files in os.walk(train_dir):
|
||||||
# 获取相对路径(相对于src_dir)
|
# 获取相对路径(train_dir)
|
||||||
rel_path = os.path.relpath(root, src_dir)
|
rel_path = os.path.relpath(root, train_dir)
|
||||||
|
|
||||||
# 跳过当前目录(.)
|
# 跳过当前目录(.)
|
||||||
if rel_path == '.':
|
if rel_path == '.':
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 创建对应的目标子目录
|
# 创建对应的目标子目录
|
||||||
train_subdir = os.path.join(train_dir, rel_path)
|
|
||||||
val_subdir = os.path.join(val_dir, rel_path)
|
val_subdir = os.path.join(val_dir, rel_path)
|
||||||
os.makedirs(train_subdir, exist_ok=True)
|
|
||||||
os.makedirs(val_subdir, exist_ok=True)
|
os.makedirs(val_subdir, exist_ok=True)
|
||||||
|
|
||||||
# 筛选图像文件
|
# 筛选图像文件
|
||||||
@ -46,13 +48,6 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
|
|||||||
# 计算分割点
|
# 计算分割点
|
||||||
split_point = int(len(image_files) * split_ratio)
|
split_point = int(len(image_files) * split_ratio)
|
||||||
|
|
||||||
# 复制文件到训练集
|
|
||||||
for file in image_files[:split_point]:
|
|
||||||
src = os.path.join(root, file)
|
|
||||||
dst = os.path.join(train_subdir, file)
|
|
||||||
# shutil.copy2(src, dst)
|
|
||||||
shutil.move(src, dst)
|
|
||||||
|
|
||||||
# 复制文件到验证集
|
# 复制文件到验证集
|
||||||
for file in image_files[split_point:]:
|
for file in image_files[split_point:]:
|
||||||
src = os.path.join(root, file)
|
src = os.path.join(root, file)
|
||||||
@ -62,12 +57,14 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
|
|||||||
|
|
||||||
print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")
|
print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")
|
||||||
|
|
||||||
|
def control_train_number():
|
||||||
|
pass
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 设置目录路径
|
# 设置目录路径
|
||||||
SOURCE_DIR = "scatter_add"
|
|
||||||
TRAIN_DIR = "scatter_data/train"
|
TRAIN_DIR = "scatter_data/train"
|
||||||
VAL_DIR = "scatter_data/val"
|
VAL_DIR = "scatter_data/val"
|
||||||
|
|
||||||
print("开始分割数据集...")
|
print("开始分割数据集...")
|
||||||
split_directory(SOURCE_DIR, TRAIN_DIR, VAL_DIR)
|
split_directory(TRAIN_DIR, VAL_DIR)
|
||||||
print("数据集分割完成")
|
print("数据集分割完成")
|
||||||
|
@ -5,6 +5,9 @@ from PIL import Image, ImageEnhance
|
|||||||
|
|
||||||
|
|
||||||
class ImageExtendProcessor:
|
class ImageExtendProcessor:
|
||||||
|
def __init__(self, conf):
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
def is_image_file(self, filename):
|
def is_image_file(self, filename):
|
||||||
"""检查文件是否为图像文件"""
|
"""检查文件是否为图像文件"""
|
||||||
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
||||||
@ -80,7 +83,7 @@ class ImageExtendProcessor:
|
|||||||
print(f"处理图像 {image_path} 时出错: {e}")
|
print(f"处理图像 {image_path} 时出错: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def process_image_directory(self, src_dir, dst_dir, same_directory, **kwargs):
|
def process_extra_directory(self, src_dir, dst_dir, same_directory, dir_name):
|
||||||
"""
|
"""
|
||||||
处理单个目录中的图像文件
|
处理单个目录中的图像文件
|
||||||
:param src_dir: 源目录路径
|
:param src_dir: 源目录路径
|
||||||
@ -99,19 +102,24 @@ class ImageExtendProcessor:
|
|||||||
if not same_directory:
|
if not same_directory:
|
||||||
# 复制原始文件 (另存文件夹时启用)
|
# 复制原始文件 (另存文件夹时启用)
|
||||||
shutil.copy2(src_path, os.path.join(dst_dir, img_file))
|
shutil.copy2(src_path, os.path.join(dst_dir, img_file))
|
||||||
|
if dir_name == 'extra':
|
||||||
|
# 生成并保存旋转后的图像
|
||||||
|
for angle in [90, 180, 270]:
|
||||||
|
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
|
||||||
|
self.rotate_image(src_path, dst_path, angle)
|
||||||
|
for ratio in [0.8, 0.85, 0.9]:
|
||||||
|
dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
|
||||||
|
self.random_cute_image(src_path, dst_path, ratio)
|
||||||
|
for brightness_factor in [0.8, 0.9, 1.0]:
|
||||||
|
dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
|
||||||
|
self.random_brightness(src_path, dst_path, brightness_factor)
|
||||||
|
elif dir_name == 'train':
|
||||||
|
# 生成并保存旋转后的图像
|
||||||
|
for angle in [90, 180, 270]:
|
||||||
|
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
|
||||||
|
self.rotate_image(src_path, dst_path, angle)
|
||||||
|
|
||||||
# 生成并保存旋转后的图像
|
def image_extend(self, src_dir, dst_dir, same_directory=False, dir_name=None):
|
||||||
for angle in [90, 180, 270]:
|
|
||||||
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
|
|
||||||
self.rotate_image(src_path, dst_path, angle)
|
|
||||||
for ratio in [0.8, 0.85, 0.9]:
|
|
||||||
dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
|
|
||||||
self.random_cute_image(src_path, dst_path, ratio)
|
|
||||||
for brightness_factor in [0.8, 0.9, 1.0]:
|
|
||||||
dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
|
|
||||||
self.random_brightness(src_path, dst_path, brightness_factor)
|
|
||||||
|
|
||||||
def image_extend(self, src_dir, dst_dir, same_directory=False, **kwargs):
|
|
||||||
if same_directory:
|
if same_directory:
|
||||||
n_dst_dir = src_dir
|
n_dst_dir = src_dir
|
||||||
print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下")
|
print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下")
|
||||||
@ -121,7 +129,60 @@ class ImageExtendProcessor:
|
|||||||
for src_subdir in os.listdir(src_dir):
|
for src_subdir in os.listdir(src_dir):
|
||||||
src_subdir_path = os.path.join(src_dir, src_subdir)
|
src_subdir_path = os.path.join(src_dir, src_subdir)
|
||||||
dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
|
dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
|
||||||
self.process_image_directory(src_subdir_path, dst_subdir_path, same_directory)
|
if dir_name == 'extra':
|
||||||
|
self.process_extra_directory(src_subdir_path,
|
||||||
|
dst_subdir_path,
|
||||||
|
same_directory,
|
||||||
|
dir_name)
|
||||||
|
if dir_name == 'train':
|
||||||
|
if len(os.listdir(src_subdir_path)) < 50:
|
||||||
|
self.process_extra_directory(src_subdir_path,
|
||||||
|
dst_subdir_path,
|
||||||
|
same_directory,
|
||||||
|
dir_name)
|
||||||
|
|
||||||
|
def random_remove_image(self, subdir_path, max_count=200):
|
||||||
|
"""
|
||||||
|
随机删除子目录中的图像文件,直到数量不超过max_count
|
||||||
|
:param subdir_path: 子目录路径
|
||||||
|
:param max_count: 最大允许的图像数量
|
||||||
|
"""
|
||||||
|
# 统计图像文件数量
|
||||||
|
image_files = [f for f in os.listdir(subdir_path)
|
||||||
|
if self.is_image_file(f) and os.path.isfile(os.path.join(subdir_path, f))]
|
||||||
|
current_count = len(image_files)
|
||||||
|
|
||||||
|
# 如果图像数量不超过max_count,则无需删除
|
||||||
|
if current_count <= max_count:
|
||||||
|
print(f"无需处理 {subdir_path} (包含 {current_count} 个图像)")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 计算需要删除的文件数
|
||||||
|
remove_count = current_count - max_count
|
||||||
|
|
||||||
|
# 随机选择要删除的文件
|
||||||
|
files_to_remove = random.sample(image_files, remove_count)
|
||||||
|
|
||||||
|
# 删除选中的文件
|
||||||
|
for file in files_to_remove:
|
||||||
|
file_path = os.path.join(subdir_path, file)
|
||||||
|
os.remove(file_path)
|
||||||
|
print(f"已删除 {file_path}")
|
||||||
|
|
||||||
|
def control_number(self):
|
||||||
|
if self.conf['extend']['extend_extra']:
|
||||||
|
self.image_extend(self.conf['extend']['extend_extra_dir'],
|
||||||
|
'',
|
||||||
|
same_directory=self.conf['extend']['extend_same_dir'],
|
||||||
|
dir_name='extra')
|
||||||
|
if self.conf['extend']['extend_train']:
|
||||||
|
self.image_extend(self.conf['extend']['extend_train_dir'],
|
||||||
|
'',
|
||||||
|
same_directory=self.conf['extend']['extend_same_dir'],
|
||||||
|
dir_name='train')
|
||||||
|
if self.conf['limit']['count_limit']:
|
||||||
|
self.random_remove_image(self.conf['limit']['limit_dir'],
|
||||||
|
max_count=self.conf['limit']['limit_count'])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
from create_extra import split_subdirs
|
|
||||||
from data_split import split_directory
|
|
||||||
from extend import ImageExtendProcessor
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
def data_preprocessing(conf):
|
|
||||||
split_subdirs(conf['data']['source_dir'], conf['data']['data_extra_dir'], conf['data']['max_files'])
|
|
||||||
split_directory(conf['data']['source_dir'], conf['data']['train_dir'],
|
|
||||||
conf['data']['val_dir'], conf['data']['split_ratio'])
|
|
||||||
image_ex = ImageExtendProcessor()
|
|
||||||
image_ex.image_extend(conf['data']['extra_dir'],
|
|
||||||
'',
|
|
||||||
same_directory=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
with open('../configs/scatter_data.yml', 'r') as f:
|
|
||||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
data_preprocessing(conf)
|
|
17
data_preprocessing/sub_data_preprocessing.py
Normal file
17
data_preprocessing/sub_data_preprocessing.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from create_extra import split_subdirs
|
||||||
|
from data_split import split_directory
|
||||||
|
from extend import ImageExtendProcessor
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def data_preprocessing(conf):
|
||||||
|
# split_subdirs(conf)
|
||||||
|
# image_ex = ImageExtendProcessor(conf)
|
||||||
|
# image_ex.control_number()
|
||||||
|
split_directory(conf)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with open('../configs/sub_data.yml', 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
data_preprocessing(conf)
|
@ -222,13 +222,13 @@ class ResNet(nn.Module):
|
|||||||
self.bn1 = norm_layer(self.inplanes)
|
self.bn1 = norm_layer(self.inplanes)
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
|
# self.adaptiveMaxPool = nn.AdaptiveMaxPool2d((1, 1))
|
||||||
self.maxpool2 = nn.Sequential(
|
# self.maxpool2 = nn.Sequential(
|
||||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||||
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
# nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
|
||||||
nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
|
# nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
|
||||||
)
|
# )
|
||||||
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
self.layer1 = self._make_layer(block, int(64 * scale), layers[0])
|
||||||
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2,
|
||||||
dilate=replace_stride_with_dilation[0])
|
dilate=replace_stride_with_dilation[0])
|
||||||
|
32
test_ori.py
32
test_ori.py
@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
|||||||
# from config import config as conf
|
# from config import config as conf
|
||||||
from tools.dataset import get_transform
|
from tools.dataset import get_transform
|
||||||
from tools.image_joint import merge_imgs
|
from tools.image_joint import merge_imgs
|
||||||
|
from tools.getHeatMap import cal_cam
|
||||||
from configs import trainer_tools
|
from configs import trainer_tools
|
||||||
import yaml
|
import yaml
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -201,10 +202,11 @@ def compute_accuracy_recall(score, labels):
|
|||||||
|
|
||||||
def compute_accuracy(
|
def compute_accuracy(
|
||||||
feature_dict: Dict[str, torch.Tensor],
|
feature_dict: Dict[str, torch.Tensor],
|
||||||
pair_list: str,
|
cam: cal_cam,
|
||||||
test_root: str
|
|
||||||
) -> Tuple[float, float]:
|
) -> Tuple[float, float]:
|
||||||
try:
|
try:
|
||||||
|
pair_list = conf['data']['test_list']
|
||||||
|
test_root = conf['data']['test_dir']
|
||||||
with open(pair_list, 'r') as f:
|
with open(pair_list, 'r') as f:
|
||||||
pairs = f.readlines()
|
pairs = f.readlines()
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
@ -220,6 +222,7 @@ def compute_accuracy(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
|
print(f"Processing pair: {pair}")
|
||||||
img1, img2, label = pair.split()
|
img1, img2, label = pair.split()
|
||||||
img1_path = osp.join(test_root, img1)
|
img1_path = osp.join(test_root, img1)
|
||||||
img2_path = osp.join(test_root, img2)
|
img2_path = osp.join(test_root, img2)
|
||||||
@ -236,9 +239,10 @@ def compute_accuracy(
|
|||||||
if conf['data']['save_image_joint']:
|
if conf['data']['save_image_joint']:
|
||||||
merge_imgs(img1_path,
|
merge_imgs(img1_path,
|
||||||
img2_path,
|
img2_path,
|
||||||
conf['data']['image_joint_pth'],
|
conf,
|
||||||
similarity,
|
similarity,
|
||||||
label)
|
label,
|
||||||
|
cam)
|
||||||
similarities.append(similarity)
|
similarities.append(similarity)
|
||||||
labels.append(int(label))
|
labels.append(int(label))
|
||||||
|
|
||||||
@ -306,7 +310,8 @@ def init_model():
|
|||||||
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
if torch.cuda.device_count() > 1 and conf['base']['distributed']:
|
||||||
model = nn.DataParallel(model).to(conf['base']['device'])
|
model = nn.DataParallel(model).to(conf['base']['device'])
|
||||||
###############正常模型加载################
|
###############正常模型加载################
|
||||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
model.load_state_dict(torch.load(conf['models']['model_path'],
|
||||||
|
map_location=conf['base']['device']))
|
||||||
#######################################
|
#######################################
|
||||||
####### 对比学习模型临时运用###
|
####### 对比学习模型临时运用###
|
||||||
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
|
# state_dict = torch.load(conf['models']['model_path'], map_location=conf['base']['device'])
|
||||||
@ -321,7 +326,18 @@ def init_model():
|
|||||||
first_param_dtype = next(model.parameters()).dtype
|
first_param_dtype = next(model.parameters()).dtype
|
||||||
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
print("模型的第一个参数的数据类型: {}".format(first_param_dtype))
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch.load(conf['models']['model_path'], map_location=conf['base']['device']))
|
try:
|
||||||
|
model.load_state_dict(torch.load(conf['models']['model_path'],
|
||||||
|
map_location=conf['base']['device']))
|
||||||
|
except:
|
||||||
|
state_dict = torch.load(conf['models']['model_path'],
|
||||||
|
map_location=conf['base']['device'])
|
||||||
|
new_state_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
new_key = k.replace("module.", "")
|
||||||
|
new_state_dict[new_key] = v
|
||||||
|
model.load_state_dict(new_state_dict, strict=False)
|
||||||
|
|
||||||
if conf['models']['half']:
|
if conf['models']['half']:
|
||||||
model.half()
|
model.half()
|
||||||
first_param_dtype = next(model.parameters()).dtype
|
first_param_dtype = next(model.parameters()).dtype
|
||||||
@ -332,7 +348,7 @@ def init_model():
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = init_model()
|
model = init_model()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
cam = cal_cam(model, conf)
|
||||||
if not conf['data']['group_test']:
|
if not conf['data']['group_test']:
|
||||||
images = unique_image(conf['data']['test_list'])
|
images = unique_image(conf['data']['test_list'])
|
||||||
images = [osp.join(conf['data']['test_dir'], img) for img in images]
|
images = [osp.join(conf['data']['test_dir'], img) for img in images]
|
||||||
@ -342,7 +358,7 @@ if __name__ == '__main__':
|
|||||||
for group in groups:
|
for group in groups:
|
||||||
d = featurize(group, test_transform, model, conf['base']['device'])
|
d = featurize(group, test_transform, model, conf['base']['device'])
|
||||||
feature_dict.update(d)
|
feature_dict.update(d)
|
||||||
accuracy, threshold = compute_accuracy(feature_dict, conf['data']['test_list'], conf['data']['test_dir'])
|
accuracy, threshold = compute_accuracy(feature_dict, cam)
|
||||||
print(
|
print(
|
||||||
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
|
"Test Model: {} Accuracy: {} Threshold: {}".format(conf['models']['model_path'], accuracy, threshold)
|
||||||
)
|
)
|
||||||
|
164
tools/getHeatMap.py
Normal file
164
tools/getHeatMap.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision import models
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.transforms as tfs
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
# from tools.config import cfg
|
||||||
|
# from comparative.tools.initmodel import initSimilarityModel
|
||||||
|
import yaml
|
||||||
|
from dataset import get_transform
|
||||||
|
|
||||||
|
|
||||||
|
class cal_cam(nn.Module):
|
||||||
|
def __init__(self, model, conf):
|
||||||
|
super(cal_cam, self).__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.device = self.conf['base']['device']
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
# 要求梯度的层
|
||||||
|
self.feature_layer = conf['heatmap']['feature_layer']
|
||||||
|
# 记录梯度
|
||||||
|
self.gradient = []
|
||||||
|
# 记录输出的特征图
|
||||||
|
self.output = []
|
||||||
|
_, self.transform = get_transform(self.conf)
|
||||||
|
|
||||||
|
def get_conf(self, yaml_pth):
|
||||||
|
with open(yaml_pth, 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
return conf
|
||||||
|
|
||||||
|
def save_grad(self, grad):
|
||||||
|
self.gradient.append(grad)
|
||||||
|
|
||||||
|
def get_grad(self):
|
||||||
|
return self.gradient[-1].cpu().data
|
||||||
|
|
||||||
|
def get_feature(self):
|
||||||
|
return self.output[-1][0]
|
||||||
|
|
||||||
|
def process_img(self, input):
|
||||||
|
input = self.transform(input)
|
||||||
|
input = input.unsqueeze(0)
|
||||||
|
return input
|
||||||
|
|
||||||
|
# 计算最后一个卷积层的梯度,输出梯度和最后一个卷积层的特征图
|
||||||
|
def getGrad(self, input_):
|
||||||
|
self.gradient = [] # 清除之前的梯度
|
||||||
|
self.output = [] # 清除之前的特征图
|
||||||
|
# print(f"cuda.memory_allocated 1 {torch.cuda.memory_allocated()/ (1024 ** 3)}G")
|
||||||
|
input_ = input_.to(self.device).requires_grad_(True)
|
||||||
|
num = 1
|
||||||
|
for name, module in self.model._modules.items():
|
||||||
|
# print(f'module_name: {name}')
|
||||||
|
# print(f'module: {module}')
|
||||||
|
if (num == 1):
|
||||||
|
input = module(input_)
|
||||||
|
num = num + 1
|
||||||
|
continue
|
||||||
|
# 是待提取特征图的层
|
||||||
|
if (name == self.feature_layer):
|
||||||
|
input = module(input)
|
||||||
|
input.register_hook(self.save_grad)
|
||||||
|
self.output.append([input])
|
||||||
|
# 马上要到全连接层了
|
||||||
|
elif (name == "avgpool"):
|
||||||
|
input = module(input)
|
||||||
|
input = input.reshape(input.shape[0], -1)
|
||||||
|
# 普通的层
|
||||||
|
else:
|
||||||
|
input = module(input)
|
||||||
|
|
||||||
|
# print(f"cuda.memory_allocated 2 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||||
|
# 到这里input就是最后全连接层的输出了
|
||||||
|
index = torch.max(input, dim=-1)[1]
|
||||||
|
one_hot = torch.zeros((1, input.shape[-1]), dtype=torch.float32)
|
||||||
|
one_hot[0][index] = 1
|
||||||
|
confidenct = one_hot * input.cpu()
|
||||||
|
confidenct = torch.sum(confidenct, dim=-1).requires_grad_(True)
|
||||||
|
|
||||||
|
# print(f"cuda.memory_allocated 3 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||||
|
# 清除之前的所有梯度
|
||||||
|
self.model.zero_grad()
|
||||||
|
# 反向传播获取梯度
|
||||||
|
grad_output = torch.ones_like(confidenct)
|
||||||
|
confidenct.backward(grad_output)
|
||||||
|
# 获取特征图的梯度
|
||||||
|
grad_val = self.get_grad()
|
||||||
|
feature = self.get_feature()
|
||||||
|
|
||||||
|
# print(f"cuda.memory_allocated 4 {torch.cuda.memory_allocated() / (1024 ** 3)}G")
|
||||||
|
return grad_val, feature, input_.grad
|
||||||
|
|
||||||
|
# 计算CAM
|
||||||
|
def getCam(self, grad_val, feature):
|
||||||
|
# 对特征图的每个通道进行全局池化
|
||||||
|
alpha = torch.mean(grad_val, dim=(2, 3)).cpu()
|
||||||
|
feature = feature.cpu()
|
||||||
|
# 将池化后的结果和相应通道特征图相乘
|
||||||
|
cam = torch.zeros((feature.shape[2], feature.shape[3]), dtype=torch.float32)
|
||||||
|
for idx in range(alpha.shape[1]):
|
||||||
|
cam = cam + alpha[0][idx] * feature[0][idx]
|
||||||
|
# 进行ReLU操作
|
||||||
|
cam = np.maximum(cam.detach().numpy(), 0)
|
||||||
|
|
||||||
|
# plt.imshow(cam)
|
||||||
|
# plt.colorbar()
|
||||||
|
# plt.savefig("cam.jpg")
|
||||||
|
|
||||||
|
# 将cam区域放大到输入图片大小
|
||||||
|
cam_ = cv2.resize(cam, (224, 224))
|
||||||
|
cam_ = cam_ - np.min(cam_)
|
||||||
|
cam_ = cam_ / np.max(cam_)
|
||||||
|
# plt.imshow(cam_)
|
||||||
|
# plt.savefig("cam_.jpg")
|
||||||
|
cam = torch.from_numpy(cam)
|
||||||
|
|
||||||
|
return cam, cam_
|
||||||
|
|
||||||
|
def show_img(self, cam_, img, heatmap_save_pth, imgname):
|
||||||
|
heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
|
||||||
|
cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
|
||||||
|
# cv2.imwrite("img.jpg", cam_img)
|
||||||
|
cv2.imwrite(os.sep.join([heatmap_save_pth, imgname]), cam_img)
|
||||||
|
|
||||||
|
def get_hot_map(self, img_pth):
|
||||||
|
img = Image.open(img_pth)
|
||||||
|
img = img.resize((224, 224))
|
||||||
|
input = self.process_img(img)
|
||||||
|
grad_val, feature, input_grad = self.getGrad(input)
|
||||||
|
cam, cam_ = self.getCam(grad_val, feature)
|
||||||
|
heatmap = cv2.applyColorMap(np.uint8(255 * cam_), cv2.COLORMAP_JET)
|
||||||
|
cam_img = 0.3 * heatmap + 0.7 * np.float32(img)
|
||||||
|
cam_img = Image.fromarray(np.uint8(cam_img))
|
||||||
|
return cam_img
|
||||||
|
|
||||||
|
# def __call__(self, img_root, heatmap_save_pth):
|
||||||
|
# for imgname in os.listdir(img_root):
|
||||||
|
# img = Image.open(os.sep.join([img_root, imgname]))
|
||||||
|
# img = img.resize((224, 224))
|
||||||
|
# # plt.imshow(img)
|
||||||
|
# # plt.savefig("airplane.jpg")
|
||||||
|
# input = self.process_img(img)
|
||||||
|
# grad_val, feature, input_grad = self.getGrad(input)
|
||||||
|
# cam, cam_ = self.getCam(grad_val, feature)
|
||||||
|
# self.show_img(cam_, img, heatmap_save_pth, imgname)
|
||||||
|
# return cam
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cam = cal_cam()
|
||||||
|
img_root = "test_img/"
|
||||||
|
heatmap_save_pth = "heatmap_result"
|
||||||
|
cam(img_root, heatmap_save_pth)
|
@ -188,7 +188,7 @@ class PairGenerator:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
original_path = '/home/lc/data_center/scatter/val_extar'
|
original_path = '/home/lc/data_center/scatter/v4/val'
|
||||||
parent_dir = str(Path(original_path).parent)
|
parent_dir = str(Path(original_path).parent)
|
||||||
generator = PairGenerator()
|
generator = PairGenerator()
|
||||||
|
|
||||||
|
@ -1,33 +1,50 @@
|
|||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
from tools.getHeatMap import cal_cam
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def merge_imgs(img1_path, img2_path, save_path, similar=None, label=None):
|
def merge_imgs(img1_path, img2_path, conf, similar=None, label=None, cam=None):
|
||||||
position = (50, 50) # 文字的左上角坐标
|
save = True
|
||||||
color = (255, 0, 0) # 红色文字,格式为 RGB
|
position = (50, 50) # 文字的左上角坐标
|
||||||
if not os.path.exists(os.sep.join([save_path, str(label)])):
|
color = (255, 0, 0) # 红色文字,格式为 RGB
|
||||||
os.makedirs(os.sep.join([save_path, str(label)]))
|
# if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||||
save_path = os.sep.join([save_path, str(label)])
|
# os.makedirs(os.sep.join([save_path, str(label)]))
|
||||||
img_name = os.path.basename(img1_path).split('.')[0]+'_'+os.path.basename(img2_path).split('.')[0]+'.png'
|
# save_path = os.sep.join([save_path, str(label)])
|
||||||
|
# img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
|
||||||
|
if not conf['heatmap']['show_heatmap']:
|
||||||
img1 = Image.open(img1_path)
|
img1 = Image.open(img1_path)
|
||||||
img2 = Image.open(img2_path)
|
img2 = Image.open(img2_path)
|
||||||
img1 = img1.resize((224,224))
|
img1 = img1.resize((224, 224))
|
||||||
img2 = img2.resize((224,224))
|
img2 = img2.resize((224, 224))
|
||||||
print('img1_path', img1)
|
save_path = conf['data']['image_joint_pth']
|
||||||
print('img2_path', img2)
|
else:
|
||||||
assert img1.height == img2.height
|
assert cam is not None, 'cam is None'
|
||||||
|
img1 = cam.get_hot_map(img1_path)
|
||||||
|
img2 = cam.get_hot_map(img2_path)
|
||||||
|
save_path = conf['heatmap']['image_joint_pth']
|
||||||
|
# print('img1_path', img1)
|
||||||
|
# print('img2_path', img2)
|
||||||
|
if not os.path.exists(os.sep.join([save_path, str(label)])):
|
||||||
|
os.makedirs(os.sep.join([save_path, str(label)]))
|
||||||
|
save_path = os.sep.join([save_path, str(label)])
|
||||||
|
img_name = os.path.basename(img1_path).split('.')[0] + '_' + os.path.basename(img2_path).split('.')[0] + '.png'
|
||||||
|
assert img1.height == img2.height
|
||||||
|
|
||||||
new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
|
new_img = Image.new('RGB', (img1.width + img2.width + 10, img1.height))
|
||||||
|
|
||||||
# print('new_img', new_img)
|
# print('new_img', new_img)
|
||||||
new_img.paste(img1, (0, 0))
|
new_img.paste(img1, (0, 0))
|
||||||
new_img.paste(img2, (img1.width + 10, 0))
|
new_img.paste(img2, (img1.width + 10, 0))
|
||||||
|
|
||||||
if similar is not None:
|
if similar is not None:
|
||||||
similar = str(similar)+'_'+str(label)
|
if label == '1' and similar > 0.5:
|
||||||
draw = ImageDraw.Draw(new_img)
|
save = False
|
||||||
draw.text(position, str(similar), color, font_size=36)
|
elif label == '0' and similar < 0.5:
|
||||||
os.makedirs(save_path, exist_ok=True)
|
save = False
|
||||||
img_save = os.path.join(save_path, img_name)
|
similar = str(similar) + '_' + str(label)
|
||||||
|
draw = ImageDraw.Draw(new_img)
|
||||||
|
draw.text(position, str(similar), color, font_size=36)
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
img_save = os.path.join(save_path, img_name)
|
||||||
|
if save:
|
||||||
new_img.save(img_save)
|
new_img.save(img_save)
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ if __name__ == '__main__':
|
|||||||
rknn.config(
|
rknn.config(
|
||||||
mean_values=[[127.5, 127.5, 127.5]],
|
mean_values=[[127.5, 127.5, 127.5]],
|
||||||
std_values=[[127.5, 127.5, 127.5]],
|
std_values=[[127.5, 127.5, 127.5]],
|
||||||
target_platform='rk3588',
|
target_platform='rk3566',
|
||||||
model_pruning=False,
|
model_pruning=False,
|
||||||
compress_weight=False,
|
compress_weight=False,
|
||||||
single_core_mode=True,
|
single_core_mode=True,
|
||||||
|
Reference in New Issue
Block a user