训练数据前置处理与提升训练效率

This commit is contained in:
lee
2025-07-10 14:24:05 +08:00
parent 0701538a73
commit 09f41f6289
15 changed files with 430 additions and 116 deletions

View File

@ -8,7 +8,9 @@ def is_image_file(filename):
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
return filename.lower().endswith(image_extensions)
def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
def split_directory(conf):
"""
分割目录中的图像文件到train和val目录
:param src_dir: 源目录路径
@ -17,22 +19,22 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
:param split_ratio: 训练集比例(默认0.9)
"""
# 创建目标目录
Path(train_dir).mkdir(parents=True, exist_ok=True)
train_dir = conf['data']['train_dir']
val_dir = conf['data']['val_dir']
split_ratio = conf['data']['split_ratio']
Path(val_dir).mkdir(parents=True, exist_ok=True)
# 遍历源目录
for root, dirs, files in os.walk(src_dir):
# 获取相对路径(相对于src_dir)
rel_path = os.path.relpath(root, src_dir)
for root, dirs, files in os.walk(train_dir):
# 获取相对路径(train_dir)
rel_path = os.path.relpath(root, train_dir)
# 跳过当前目录(.)
if rel_path == '.':
continue
# 创建对应的目标子目录
train_subdir = os.path.join(train_dir, rel_path)
val_subdir = os.path.join(val_dir, rel_path)
os.makedirs(train_subdir, exist_ok=True)
os.makedirs(val_subdir, exist_ok=True)
# 筛选图像文件
@ -46,13 +48,6 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
# 计算分割点
split_point = int(len(image_files) * split_ratio)
# 复制文件到训练集
for file in image_files[:split_point]:
src = os.path.join(root, file)
dst = os.path.join(train_subdir, file)
# shutil.copy2(src, dst)
shutil.move(src, dst)
# 复制文件到验证集
for file in image_files[split_point:]:
src = os.path.join(root, file)
@ -62,12 +57,14 @@ def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")
def control_train_number():
pass
if __name__ == "__main__":
# 设置目录路径
SOURCE_DIR = "scatter_add"
TRAIN_DIR = "scatter_data/train"
VAL_DIR = "scatter_data/val"
print("开始分割数据集...")
split_directory(SOURCE_DIR, TRAIN_DIR, VAL_DIR)
split_directory(TRAIN_DIR, VAL_DIR)
print("数据集分割完成")