训练数据前置处理与提升训练效率

This commit is contained in:
lee
2025-07-10 14:24:05 +08:00
parent 0701538a73
commit 09f41f6289
15 changed files with 430 additions and 116 deletions

View File

@ -1,7 +1,7 @@
import os
import shutil
from pathlib import Path
import yaml
def count_files(directory):
"""统计目录中的文件数量"""
try:
@ -26,6 +26,20 @@ def clear_empty_dirs(path):
except Exception as e:
print(f"Error: {e.strerror}")
def get_max_files(conf):
max_files_ratio = conf['data']['max_files_ratio']
files_number = []
for root, dirs, files in os.walk(conf['data']['source_dir']):
if len(dirs) == 0:
if len(files) == 0:
print(root, dirs,files)
files_number.append(len(files))
files_number = sorted(files_number, reverse=False)
max_files = files_number[int(max_files_ratio * len(files_number))]
print(f"max_files: {max_files}")
if max_files < conf['data']['min_files']:
max_files = conf['data']['min_files']
return max_files
def megre_subdirs(pth):
for roots, dir_names, files in os.walk(pth):
print(f"image {dir_names}")
@ -41,19 +55,24 @@ def megre_subdirs(pth):
clear_empty_dirs(pth)
def split_subdirs(source_dir, target_dir, max_files=10):
# def split_subdirs(source_dir, target_dir, max_files=10):
def split_subdirs(conf):
"""
复制文件数≤max_files的子目录到目标目录
:param source_dir: 源目录路径
:param target_dir: 目标目录路径
:param max_files: 最大文件数阈值
"""
source_dir = conf['data']['source_dir']
target_extra_dir = conf['data']['data_extra_dir']
train_dir = conf['data']['train_dir']
max_files = get_max_files(conf)
megre_subdirs(source_dir) # 合并子目录,删除上级目录
# 创建目标目录
Path(target_dir).mkdir(parents=True, exist_ok=True)
Path(target_extra_dir).mkdir(parents=True, exist_ok=True)
print(f"开始处理目录: {source_dir}")
print(f"目标目录: {target_dir}")
print(f"目标目录: {target_extra_dir}")
print(f"筛选条件: 文件数 ≤ {max_files}\n")
# 遍历源目录
@ -65,18 +84,18 @@ def split_subdirs(source_dir, target_dir, max_files=10):
try:
file_count = count_files(subdir_path)
print(f"复制 {subdir} (包含 {file_count} 个文件)")
if file_count <= max_files:
print(f"复制 {subdir} (包含 {file_count} 个文件)")
dest_path = os.path.join(target_dir, subdir)
# 如果目标目录已存在则跳过
if os.path.exists(dest_path):
print(f"目录已存在,跳过: {dest_path}")
continue
# shutil.copytree(subdir_path, dest_path)
shutil.move(subdir_path, dest_path)
dest_path = os.path.join(target_extra_dir, subdir)
else:
dest_path = os.path.join(train_dir, subdir)
# 如果目标目录已存在则跳过
if os.path.exists(dest_path):
print(f"目录已存在,跳过: {dest_path}")
continue
print(f"复制 {subdir} (包含 {file_count} 个文件) 至 {dest_path}")
shutil.copytree(subdir_path, dest_path)
# shutil.move(subdir_path, dest_path)
except Exception as e:
print(f"处理目录 {subdir} 时出错: {e}")
@ -85,8 +104,8 @@ def split_subdirs(source_dir, target_dir, max_files=10):
if __name__ == "__main__":
# 配置路径
SOURCE_DIR = r"C:\Users\123\Desktop\test1\scatter_sub_class"
TARGET_DIR = "scatter_mini"
with open('../configs/sub_data.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
# 执行复制操作
split_subdirs(SOURCE_DIR, TARGET_DIR)
split_subdirs(conf)

View File

@ -8,7 +8,9 @@ def is_image_file(filename):
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
return filename.lower().endswith(image_extensions)
def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
def split_directory(conf):
"""
分割目录中的图像文件到train和val目录
:param src_dir: 源目录路径
@ -17,22 +19,22 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
:param split_ratio: 训练集比例(默认0.9)
"""
# 创建目标目录
Path(train_dir).mkdir(parents=True, exist_ok=True)
train_dir = conf['data']['train_dir']
val_dir = conf['data']['val_dir']
split_ratio = conf['data']['split_ratio']
Path(val_dir).mkdir(parents=True, exist_ok=True)
# 遍历源目录
for root, dirs, files in os.walk(src_dir):
# 获取相对路径(相对于src_dir)
rel_path = os.path.relpath(root, src_dir)
for root, dirs, files in os.walk(train_dir):
# 获取相对路径(train_dir)
rel_path = os.path.relpath(root, train_dir)
# 跳过当前目录(.)
if rel_path == '.':
continue
# 创建对应的目标子目录
train_subdir = os.path.join(train_dir, rel_path)
val_subdir = os.path.join(val_dir, rel_path)
os.makedirs(train_subdir, exist_ok=True)
os.makedirs(val_subdir, exist_ok=True)
# 筛选图像文件
@ -46,13 +48,6 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
# 计算分割点
split_point = int(len(image_files) * split_ratio)
# 复制文件到训练集
for file in image_files[:split_point]:
src = os.path.join(root, file)
dst = os.path.join(train_subdir, file)
# shutil.copy2(src, dst)
shutil.move(src, dst)
# 复制文件到验证集
for file in image_files[split_point:]:
src = os.path.join(root, file)
@ -62,12 +57,14 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")
def control_train_number():
pass
if __name__ == "__main__":
# 设置目录路径
SOURCE_DIR = "scatter_add"
TRAIN_DIR = "scatter_data/train"
VAL_DIR = "scatter_data/val"
print("开始分割数据集...")
split_directory(SOURCE_DIR, TRAIN_DIR, VAL_DIR)
split_directory(TRAIN_DIR, VAL_DIR)
print("数据集分割完成")

View File

@ -5,6 +5,9 @@ from PIL import Image, ImageEnhance
class ImageExtendProcessor:
def __init__(self, conf):
self.conf = conf
def is_image_file(self, filename):
"""检查文件是否为图像文件"""
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
@ -80,7 +83,7 @@ class ImageExtendProcessor:
print(f"处理图像 {image_path} 时出错: {e}")
return False
def process_image_directory(self, src_dir, dst_dir, same_directory, **kwargs):
def process_extra_directory(self, src_dir, dst_dir, same_directory, dir_name):
"""
处理单个目录中的图像文件
:param src_dir: 源目录路径
@ -99,19 +102,24 @@ class ImageExtendProcessor:
if not same_directory:
# 复制原始文件 (另存文件夹时启用)
shutil.copy2(src_path, os.path.join(dst_dir, img_file))
if dir_name == 'extra':
# 生成并保存旋转后的图像
for angle in [90, 180, 270]:
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
self.rotate_image(src_path, dst_path, angle)
for ratio in [0.8, 0.85, 0.9]:
dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
self.random_cute_image(src_path, dst_path, ratio)
for brightness_factor in [0.8, 0.9, 1.0]:
dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
self.random_brightness(src_path, dst_path, brightness_factor)
elif dir_name == 'train':
# 生成并保存旋转后的图像
for angle in [90, 180, 270]:
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
self.rotate_image(src_path, dst_path, angle)
# 生成并保存旋转后的图像
for angle in [90, 180, 270]:
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
self.rotate_image(src_path, dst_path, angle)
for ratio in [0.8, 0.85, 0.9]:
dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
self.random_cute_image(src_path, dst_path, ratio)
for brightness_factor in [0.8, 0.9, 1.0]:
dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
self.random_brightness(src_path, dst_path, brightness_factor)
def image_extend(self, src_dir, dst_dir, same_directory=False, **kwargs):
def image_extend(self, src_dir, dst_dir, same_directory=False, dir_name=None):
if same_directory:
n_dst_dir = src_dir
print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下")
@ -121,7 +129,60 @@ class ImageExtendProcessor:
for src_subdir in os.listdir(src_dir):
src_subdir_path = os.path.join(src_dir, src_subdir)
dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
self.process_image_directory(src_subdir_path, dst_subdir_path, same_directory)
if dir_name == 'extra':
self.process_extra_directory(src_subdir_path,
dst_subdir_path,
same_directory,
dir_name)
if dir_name == 'train':
if len(os.listdir(src_subdir_path)) < 50:
self.process_extra_directory(src_subdir_path,
dst_subdir_path,
same_directory,
dir_name)
def random_remove_image(self, subdir_path, max_count=200):
"""
随机删除子目录中的图像文件直到数量不超过max_count
:param subdir_path: 子目录路径
:param max_count: 最大允许的图像数量
"""
# 统计图像文件数量
image_files = [f for f in os.listdir(subdir_path)
if self.is_image_file(f) and os.path.isfile(os.path.join(subdir_path, f))]
current_count = len(image_files)
# 如果图像数量不超过max_count则无需删除
if current_count <= max_count:
print(f"无需处理 {subdir_path} (包含 {current_count} 个图像)")
return
# 计算需要删除的文件数
remove_count = current_count - max_count
# 随机选择要删除的文件
files_to_remove = random.sample(image_files, remove_count)
# 删除选中的文件
for file in files_to_remove:
file_path = os.path.join(subdir_path, file)
os.remove(file_path)
print(f"已删除 {file_path}")
def control_number(self):
if self.conf['extend']['extend_extra']:
self.image_extend(self.conf['extend']['extend_extra_dir'],
'',
same_directory=self.conf['extend']['extend_same_dir'],
dir_name='extra')
if self.conf['extend']['extend_train']:
self.image_extend(self.conf['extend']['extend_train_dir'],
'',
same_directory=self.conf['extend']['extend_same_dir'],
dir_name='train')
if self.conf['limit']['count_limit']:
self.random_remove_image(self.conf['limit']['limit_dir'],
max_count=self.conf['limit']['limit_count'])
if __name__ == "__main__":

View File

@ -1,20 +0,0 @@
from create_extra import split_subdirs
from data_split import split_directory
from extend import ImageExtendProcessor
import yaml
def data_preprocessing(conf):
split_subdirs(conf['data']['source_dir'], conf['data']['data_extra_dir'], conf['data']['max_files'])
split_directory(conf['data']['source_dir'], conf['data']['train_dir'],
conf['data']['val_dir'], conf['data']['split_ratio'])
image_ex = ImageExtendProcessor()
image_ex.image_extend(conf['data']['extra_dir'],
'',
same_directory=True)
if __name__ == '__main__':
with open('../configs/scatter_data.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
data_preprocessing(conf)

View File

@ -0,0 +1,17 @@
from create_extra import split_subdirs
from data_split import split_directory
from extend import ImageExtendProcessor
import yaml
def data_preprocessing(conf):
# split_subdirs(conf)
# image_ex = ImageExtendProcessor(conf)
# image_ex.control_number()
split_directory(conf)
if __name__ == '__main__':
with open('../configs/sub_data.yml', 'r') as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
data_preprocessing(conf)