73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import os
|
|
import shutil
|
|
import random
|
|
from pathlib import Path
|
|
import yaml
|
|
def is_image_file(filename):
    """Return True if *filename* ends with a known image extension.

    The check is case-insensitive and looks only at the end of the name;
    the file itself is never opened.

    :param filename: file name or path, given as ``str`` or as an
        ``os.PathLike`` object that resolves to ``str`` (for example
        ``pathlib.Path``).
    :return: ``True`` when the name ends with ``.jpg``, ``.jpeg``, ``.png``,
        ``.bmp``, ``.gif`` or ``.tiff``; ``False`` otherwise.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
    # os.fspath() returns str unchanged and unwraps Path-like objects, which
    # have no .lower() method of their own.
    return os.fspath(filename).lower().endswith(image_extensions)
|
|
|
|
|
|
|
|
def split_directory(conf):
    """Move a share of each sub-directory's images from train to val.

    Every sub-directory found under the training directory is mirrored under
    the validation directory. Inside each sub-directory the image files are
    shuffled and everything after the split point is moved (not copied) into
    the mirrored validation folder. Files lying directly in the top level of
    the training directory are left untouched.

    :param conf: configuration mapping. ``conf['data']`` must provide
        ``train_dir`` (directory to split; files are moved out of it),
        ``val_dir`` (destination; created if missing) and ``split_ratio``
        (fraction of each folder's images that stays in ``train_dir``).
    :raises KeyError: if one of the keys above is missing.
    :raises ValueError: if ``split_ratio`` is not within ``[0, 1]``.
    """
    data_conf = conf['data']
    train_dir = data_conf['train_dir']
    val_dir = data_conf['val_dir']
    split_ratio = data_conf['split_ratio']

    # Outside [0, 1] the slice below misbehaves silently: a negative split
    # point moves the wrong files and a ratio above 1 reports a negative
    # validation count.
    if not 0 <= split_ratio <= 1:
        raise ValueError(
            f"split_ratio must be between 0 and 1, got {split_ratio!r}"
        )

    Path(val_dir).mkdir(parents=True, exist_ok=True)
    val_root = os.path.abspath(val_dir)

    for root, dirs, files in os.walk(train_dir):
        # If val_dir is nested inside train_dir, prune it from the walk so it
        # is not mirrored into itself and freshly moved files are not split
        # a second time.
        dirs[:] = [
            d for d in dirs
            if os.path.abspath(os.path.join(root, d)) != val_root
        ]

        rel_path = os.path.relpath(root, train_dir)

        # Only sub-directories are split; the top level itself is skipped.
        if rel_path == '.':
            continue

        # Mirror the sub-directory under val_dir, even if it holds no images.
        val_subdir = os.path.join(val_dir, rel_path)
        os.makedirs(val_subdir, exist_ok=True)

        image_files = [f for f in files if is_image_file(f)]
        if not image_files:
            continue

        random.shuffle(image_files)

        # The first split_point files stay in train, the rest go to val.
        # NOTE(review): int() truncates, so a folder where
        # len * split_ratio < 1 sends all of its images to val.
        total = len(image_files)
        split_point = int(total * split_ratio)

        for file in image_files[split_point:]:
            src = os.path.join(root, file)
            dst = os.path.join(val_subdir, file)
            shutil.move(src, dst)

        print(f"处理完成: {rel_path} (共 {total} 个图像, 训练集: {split_point}, 验证集: {total-split_point})")
|
|
|
|
def control_train_number():
    """Placeholder with no implementation yet; does nothing and returns None."""
    return None
|
|
|
|
if __name__ == "__main__":
    # The config path is relative to the current working directory, so the
    # script has to be started from a directory that sits next to configs/.
    # safe_load builds only plain YAML types (mappings, lists, scalars),
    # which is all split_directory() reads from the config.
    with open('../configs/scatter_data.yml', 'r', encoding='utf-8') as f:
        conf = yaml.safe_load(f)

    print("开始分割数据集...")
    split_directory(conf)
    print("数据集分割完成")
|