import os
import random
import shutil
from pathlib import Path


def is_image_file(filename):
    """Return True if *filename* has a (case-insensitive) image extension."""
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
    return filename.lower().endswith(image_extensions)


def _transfer_files(filenames, src_root, dst_dir, transfer):
    """Apply *transfer* (shutil.move or shutil.copy2) to each named file from src_root into dst_dir."""
    for name in filenames:
        transfer(os.path.join(src_root, name), os.path.join(dst_dir, name))


def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9, *, seed=None, move=True):
    """Split the image files of every sub-directory of *src_dir* into train/val sets.

    The sub-directory tree below *src_dir* is mirrored under *train_dir* and
    *val_dir*. Within each sub-directory the image files are shuffled, the
    first ``int(n * split_ratio)`` go to the training set and the rest to the
    validation set. Files directly inside *src_dir* itself are ignored, as are
    non-image files (see ``is_image_file``).

    :param src_dir: source directory path
    :param train_dir: training-set destination directory
    :param val_dir: validation-set destination directory
    :param split_ratio: fraction of each sub-directory's images sent to the
        training set (default 0.9); must lie in [0, 1]
    :param seed: optional seed for a private RNG, making the split
        reproducible; with the default ``None`` the global ``random`` module
        is used (so ``random.seed`` set by the caller still applies)
    :param move: if True (default) files are moved out of *src_dir*; if False
        they are copied with ``shutil.copy2`` and the source is left intact
    :raises ValueError: if *split_ratio* is outside [0, 1]
    """
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be within [0, 1], got {split_ratio!r}")

    rng = random if seed is None else random.Random(seed)
    transfer = shutil.move if move else shutil.copy2

    # Create the destination roots.
    Path(train_dir).mkdir(parents=True, exist_ok=True)
    Path(val_dir).mkdir(parents=True, exist_ok=True)

    src_root = os.path.abspath(src_dir)
    # Output directories nested inside src_dir must not be walked, otherwise
    # they would be mirrored into themselves (e.g. train/train).
    excluded = {os.path.abspath(train_dir), os.path.abspath(val_dir)}

    for root, dirs, files in os.walk(src_root):
        # Prune in place (top-down walk) and sort so traversal order, and thus
        # RNG consumption, does not depend on the filesystem's listing order.
        dirs[:] = sorted(
            d for d in dirs
            if os.path.abspath(os.path.join(root, d)) not in excluded
        )

        rel_path = os.path.relpath(root, src_root)
        # Skip the source root itself; only sub-directories are split.
        if rel_path == '.':
            continue

        train_subdir = os.path.join(train_dir, rel_path)
        val_subdir = os.path.join(val_dir, rel_path)
        os.makedirs(train_subdir, exist_ok=True)
        os.makedirs(val_subdir, exist_ok=True)

        # Sorted first so that a given seed always yields the same split.
        image_files = sorted(f for f in files if is_image_file(f))
        if not image_files:
            continue

        rng.shuffle(image_files)
        split_point = int(len(image_files) * split_ratio)

        _transfer_files(image_files[:split_point], root, train_subdir, transfer)
        _transfer_files(image_files[split_point:], root, val_subdir, transfer)

        print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")


if __name__ == "__main__":
    # Directory paths
    SOURCE_DIR = "scatter_add"
    TRAIN_DIR = "scatter_data/train"
    VAL_DIR = "scatter_data/val"

    print("开始分割数据集...")
    split_directory(SOURCE_DIR, TRAIN_DIR, VAL_DIR)
    print("数据集分割完成")