训练数据前置处理与提升训练效率

This commit is contained in:
lee
2025-07-10 14:24:05 +08:00
parent 0701538a73
commit 09f41f6289
15 changed files with 430 additions and 116 deletions

View File

@ -1,7 +1,7 @@
import os
import shutil
from pathlib import Path
import yaml
def count_files(directory):
"""统计目录中的文件数量"""
try:
@ -26,6 +26,20 @@ def clear_empty_dirs(path):
except Exception as e:
print(f"Error: {e.strerror}")
def get_max_files(conf):
    """Compute the per-directory file-count threshold from the source tree.

    Walks ``conf['data']['source_dir']``, records the number of files in every
    leaf directory (a directory with no sub-directories), sorts those counts in
    ascending order and picks the value located at the
    ``conf['data']['max_files_ratio']`` position of that sorted list. The
    result is never lower than ``conf['data']['min_files']``.

    :param conf: configuration mapping; reads ``conf['data']['source_dir']``,
        ``conf['data']['max_files_ratio']`` and ``conf['data']['min_files']``.
    :return: the selected file count, floored at ``min_files``. When the walk
        finds no leaf directory at all (e.g. ``source_dir`` does not exist),
        ``min_files`` is returned.
    """
    data_conf = conf['data']
    max_files_ratio = data_conf['max_files_ratio']
    min_files = data_conf['min_files']

    leaf_counts = []
    for root, dirs, files in os.walk(data_conf['source_dir']):
        if dirs:
            # Only leaf directories contribute a count.
            continue
        if not files:
            # Report empty leaf directories so they can be inspected.
            print(root, dirs, files)
        leaf_counts.append(len(files))

    if leaf_counts:
        leaf_counts.sort()
        # A ratio of 1.0 would index one past the end and a negative ratio
        # would wrap around, so clamp the position into the valid range.
        position = int(max_files_ratio * len(leaf_counts))
        position = min(max(position, 0), len(leaf_counts) - 1)
        max_files = leaf_counts[position]
    else:
        # os.walk yielded nothing; fall back to the configured floor.
        max_files = min_files

    print(f"max_files: {max_files}")
    return max(max_files, min_files)
def megre_subdirs(pth):
for roots, dir_names, files in os.walk(pth):
print(f"image {dir_names}")
@ -41,19 +55,24 @@ def megre_subdirs(pth):
clear_empty_dirs(pth)
def split_subdirs(source_dir, target_dir, max_files=10):
# def split_subdirs(source_dir, target_dir, max_files=10):
def split_subdirs(conf):
"""
复制文件数≤max_files的子目录到目标目录
:param source_dir: 源目录路径
:param target_dir: 目标目录路径
:param max_files: 最大文件数阈值
"""
source_dir = conf['data']['source_dir']
target_extra_dir = conf['data']['data_extra_dir']
train_dir = conf['data']['train_dir']
max_files = get_max_files(conf)
megre_subdirs(source_dir) # 合并子目录,删除上级目录
# 创建目标目录
Path(target_dir).mkdir(parents=True, exist_ok=True)
Path(target_extra_dir).mkdir(parents=True, exist_ok=True)
print(f"开始处理目录: {source_dir}")
print(f"目标目录: {target_dir}")
print(f"目标目录: {target_extra_dir}")
print(f"筛选条件: 文件数 ≤ {max_files}\n")
# 遍历源目录
@ -65,18 +84,18 @@ def split_subdirs(source_dir, target_dir, max_files=10):
try:
file_count = count_files(subdir_path)
print(f"复制 {subdir} (包含 {file_count} 个文件)")
if file_count <= max_files:
print(f"复制 {subdir} (包含 {file_count} 个文件)")
dest_path = os.path.join(target_dir, subdir)
# 如果目标目录已存在则跳过
if os.path.exists(dest_path):
print(f"目录已存在,跳过: {dest_path}")
continue
# shutil.copytree(subdir_path, dest_path)
shutil.move(subdir_path, dest_path)
dest_path = os.path.join(target_extra_dir, subdir)
else:
dest_path = os.path.join(train_dir, subdir)
# 如果目标目录已存在则跳过
if os.path.exists(dest_path):
print(f"目录已存在,跳过: {dest_path}")
continue
print(f"复制 {subdir} (包含 {file_count} 个文件) 至 {dest_path}")
shutil.copytree(subdir_path, dest_path)
# shutil.move(subdir_path, dest_path)
except Exception as e:
print(f"处理目录 {subdir} 时出错: {e}")
@ -85,8 +104,8 @@ def split_subdirs(source_dir, target_dir, max_files=10):
if __name__ == "__main__":
# 配置路径
SOURCE_DIR = r"C:\Users\123\Desktop\test1\scatter_sub_class"
TARGET_DIR = "scatter_mini"
with open('../configs/sub_data.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# 执行复制操作
split_subdirs(SOURCE_DIR, TARGET_DIR)
split_subdirs(conf)