From 0701538a73f5b95382724eddf796379cf3a61cc8 Mon Sep 17 00:00:00 2001 From: lee <770918727@qq.com> Date: Mon, 7 Jul 2025 15:19:22 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=A3=E7=A7=B0=E8=AE=AD=E7=BB=83=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=89=8D=E7=BD=AE=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/scatter.yml | 14 +- configs/scatter_data.yml | 19 +++ data_preprocessing/__init__.py | 0 data_preprocessing/create_extra.py | 92 ++++++++++++ data_preprocessing/data_split.py | 73 ++++++++++ data_preprocessing/extend.py | 131 ++++++++++++++++++ .../scatter_data_preprocessing.py | 20 +++ 7 files changed, 342 insertions(+), 7 deletions(-) create mode 100644 configs/scatter_data.yml create mode 100644 data_preprocessing/__init__.py create mode 100644 data_preprocessing/create_extra.py create mode 100644 data_preprocessing/data_split.py create mode 100644 data_preprocessing/extend.py create mode 100644 data_preprocessing/scatter_data_preprocessing.py diff --git a/configs/scatter.yml b/configs/scatter.yml index 848c622..40d2ed3 100644 --- a/configs/scatter.yml +++ b/configs/scatter.yml @@ -18,9 +18,9 @@ models: # 训练参数 training: - epochs: 600 # 总训练轮次 + epochs: 800 # 总训练轮次 batch_size: 64 # 批次大小 - lr: 0.0004 # 初始学习率 + lr: 0.01 # 初始学习率 optimizer: "sgd" # 优化器类型 metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax) loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax) @@ -29,8 +29,8 @@ training: weight_decay: 0.0005 # 权重衰减 scheduler: "step" # 学习率调度器(可选:cosine_annealing/step/none) num_workers: 32 # 数据加载线程数 - checkpoints: "./checkpoints/resnet18_scatter_6.26/" # 模型保存目录 - restore: True + checkpoints: "./checkpoints/resnet18_scatter_7.4/" # 模型保存目录 + restore: false restore_model: "checkpoints/resnet18_scatter_6.25/best.pth" # 模型恢复路径 @@ -46,8 +46,8 @@ data: train_batch_size: 128 # 训练批次大小 val_batch_size: 100 # 验证批次大小 num_workers: 32 # 数据加载线程数 - data_train_dir: "../data_center/scatter/v2/train" # 训练数据集根目录 - data_val_dir: "../data_center/scatter/v2/val" # 验证数据集根目录 + data_train_dir: "../data_center/scatter/v4/train" # 训练数据集根目录 + data_val_dir: "../data_center/scatter/v4/val" # 验证数据集根目录 transform: img_size: 224 # 图像尺寸 @@ -59,7 +59,7 @@ transform: # 日志与监控 logging: - logging_dir: "./log/2025.6.25-scatter.txt" # 日志保存目录 + logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录 tensorboard: true # 是否启用TensorBoard checkpoint_interval: 30 # 检查点保存间隔(epoch) diff --git a/configs/scatter_data.yml b/configs/scatter_data.yml new file mode 100644 index 0000000..2fe7e14 --- /dev/null +++ b/configs/scatter_data.yml @@ -0,0 +1,19 @@ +# configs/scatter_data.yml +# 专为散称前处理的配置文件 + +# 数据配置 +data: + dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集) + source_dir: "../../data_center/scatter/v5/source" # 原始数据 + train_dir: "../../data_center/scatter/v5/train" # 训练数据集根目录 + val_dir: "../../data_center/scatter/v5/val" # 验证数据集根目录 + extra_dir: "../../data_center/scatter/v5/extra" # 验证数据集根目录 + split_ratio: 0.9 + max_files: 10 # 数据集小于该阈值则归纳至extra + + +# 日志与监控 +logging: + logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录 + log_level: "info" # 日志级别(debug/info/warning/error) + diff --git a/data_preprocessing/__init__.py b/data_preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_preprocessing/create_extra.py b/data_preprocessing/create_extra.py new file mode 100644 index 0000000..55686a2 --- /dev/null +++ b/data_preprocessing/create_extra.py @@ -0,0 +1,92 @@ +import os +import shutil +from pathlib import Path + +def count_files(directory): + """统计目录中的文件数量""" + try: + return len([f for f in os.listdir(directory) + if os.path.isfile(os.path.join(directory, f))]) + except Exception as e: + print(f"无法统计目录 {directory}: {e}") + return 0 + +def clear_empty_dirs(path): + """ + 删除空目录 + :param path: 目录路径 + """ + for root, dirs, files in os.walk(path, topdown=False): + for dir_name in dirs: + dir_path = os.path.join(root, dir_name) + try: + if not os.listdir(dir_path): + os.rmdir(dir_path) + print(f"Deleted empty directory: {dir_path}") + except Exception as e: + print(f"Error: {e.strerror}") + +def megre_subdirs(pth): + for roots, dir_names, files in os.walk(pth): + print(f"image {dir_names}") + for image in dir_names: + inner_dir_path = os.path.join(pth, image) + for inner_roots, inner_dirs, inner_files in os.walk(inner_dir_path): + for inner_dir in inner_dirs: + src_dir = os.path.join(inner_roots, inner_dir) + dest_dir = os.path.join(pth, inner_dir) + # shutil.copytree(src_dir, dest_dir) + shutil.move(src_dir, dest_dir) + print(f"Copied {inner_dir} to {pth}") + clear_empty_dirs(pth) + + +def split_subdirs(source_dir, target_dir, max_files=10): + """ + 复制文件数≤max_files的子目录到目标目录 + :param source_dir: 源目录路径 + :param target_dir: 目标目录路径 + :param max_files: 最大文件数阈值 + """ + megre_subdirs(source_dir) # 合并子目录,删除上级目录 + # 创建目标目录 + Path(target_dir).mkdir(parents=True, exist_ok=True) + + print(f"开始处理目录: {source_dir}") + print(f"目标目录: {target_dir}") + print(f"筛选条件: 文件数 ≤ {max_files}\n") + + # 遍历源目录 + for subdir in os.listdir(source_dir): + subdir_path = os.path.join(source_dir, subdir) + + if not os.path.isdir(subdir_path): + continue + + try: + file_count = count_files(subdir_path) + + if file_count <= max_files: + print(f"复制 {subdir} (包含 {file_count} 个文件)") + dest_path = os.path.join(target_dir, subdir) + + # 如果目标目录已存在则跳过 + if os.path.exists(dest_path): + print(f"目录已存在,跳过: {dest_path}") + continue + + # shutil.copytree(subdir_path, dest_path) + shutil.move(subdir_path, dest_path) + + except Exception as e: + print(f"处理目录 {subdir} 时出错: {e}") + + print("\n处理完成") + +if __name__ == "__main__": + # 配置路径 + SOURCE_DIR = r"C:\Users\123\Desktop\test1\scatter_sub_class" + TARGET_DIR = "scatter_mini" + + # 执行复制操作 + split_subdirs(SOURCE_DIR, TARGET_DIR) diff --git a/data_preprocessing/data_split.py b/data_preprocessing/data_split.py new file mode 100644 index 0000000..12d27a0 --- /dev/null +++ b/data_preprocessing/data_split.py @@ -0,0 +1,73 @@ +import os +import shutil +import random +from pathlib import Path + +def is_image_file(filename): + """检查文件是否为图像文件""" + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff') + return filename.lower().endswith(image_extensions) + +def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9): + """ + 分割目录中的图像文件到train和val目录 + :param src_dir: 源目录路径 + :param train_dir: 训练集目录路径 + :param val_dir: 验证集目录路径 + :param split_ratio: 训练集比例(默认0.9) + """ + # 创建目标目录 + Path(train_dir).mkdir(parents=True, exist_ok=True) + Path(val_dir).mkdir(parents=True, exist_ok=True) + + # 遍历源目录 + for root, dirs, files in os.walk(src_dir): + # 获取相对路径(相对于src_dir) + rel_path = os.path.relpath(root, src_dir) + + # 跳过当前目录(.) + if rel_path == '.': + continue + + # 创建对应的目标子目录 + train_subdir = os.path.join(train_dir, rel_path) + val_subdir = os.path.join(val_dir, rel_path) + os.makedirs(train_subdir, exist_ok=True) + os.makedirs(val_subdir, exist_ok=True) + + # 筛选图像文件 + image_files = [f for f in files if is_image_file(f)] + if not image_files: + continue + + # 随机打乱文件列表 + random.shuffle(image_files) + + # 计算分割点 + split_point = int(len(image_files) * split_ratio) + + # 复制文件到训练集 + for file in image_files[:split_point]: + src = os.path.join(root, file) + dst = os.path.join(train_subdir, file) + # shutil.copy2(src, dst) + shutil.move(src, dst) + + # 复制文件到验证集 + for file in image_files[split_point:]: + src = os.path.join(root, file) + dst = os.path.join(val_subdir, file) + # shutil.copy2(src, dst) + shutil.move(src, dst) + + print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})") + +if __name__ == "__main__": + # 设置目录路径 + SOURCE_DIR = "scatter_add" + TRAIN_DIR = "scatter_data/train" + VAL_DIR = "scatter_data/val" + + print("开始分割数据集...") + split_directory(SOURCE_DIR, TRAIN_DIR, VAL_DIR) + print("数据集分割完成") diff --git a/data_preprocessing/extend.py b/data_preprocessing/extend.py new file mode 100644 index 0000000..bdb66f6 --- /dev/null +++ b/data_preprocessing/extend.py @@ -0,0 +1,131 @@ +import os +import random +import shutil +from PIL import Image, ImageEnhance + + +class ImageExtendProcessor: + def is_image_file(self, filename): + """检查文件是否为图像文件""" + image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff') + return filename.lower().endswith(image_extensions) + + def random_cute_image(self, image_path, output_path, ratio=0.8): + """ + 对图像进行随机裁剪 + :param image_path: 输入图像路径 + :param output_path: 输出图像路径 + :param ratio: 裁剪比例,决定裁剪区域的大小,默认为0.8 + """ + try: + with Image.open(image_path) as img: + # 获取图像尺寸 + width, height = img.size + + # 计算裁剪后的尺寸 + new_width = int(width * ratio) + new_height = int(height * ratio) + + # 随机生成裁剪起始点 + left = random.randint(0, width - new_width) + top = random.randint(0, height - new_height) + right = left + new_width + bottom = top + new_height + + # 执行裁剪 + cropped_img = img.crop((left, top, right, bottom)) + + # 保存裁剪后的图像 + cropped_img.save(output_path) + return True + except Exception as e: + print(f"处理图像 {image_path} 时出错: {e}") + return False + + def random_brightness(self, image_path, output_path, brightness_factor=None): + """ + 对图像进行随机亮度调整 + :param image_path: 输入图像路径 + :param output_path: 输出图像路径 + :param brightness_factor: 亮度调整因子,默认为随机值 + """ + try: + with Image.open(image_path) as img: + # 创建一个ImageEnhance.Brightness对象 + enhancer = ImageEnhance.Brightness(img) + + # 如果没有指定亮度因子,则随机生成 + if brightness_factor is None: + brightness_factor = random.uniform(0.5, 1.5) + + # 应用亮度调整 + brightened_img = enhancer.enhance(brightness_factor) + + # 保存调整后的图像 + brightened_img.save(output_path) + return True + except Exception as e: + print(f"处理图像 {image_path} 时出错: {e}") + return False + + def rotate_image(self, image_path, output_path, degrees): + """旋转图像并保存到指定路径""" + try: + with Image.open(image_path) as img: + # 旋转图像并自动调整画布大小 + rotated = img.rotate(degrees, expand=True) + rotated.save(output_path) + return True + except Exception as e: + print(f"处理图像 {image_path} 时出错: {e}") + return False + + def process_image_directory(self, src_dir, dst_dir, same_directory, **kwargs): + """ + 处理单个目录中的图像文件 + :param src_dir: 源目录路径 + :param dst_dir: 目标目录路径 + """ + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + # 获取目录中所有图像文件 + image_files = [f for f in os.listdir(src_dir) + if self.is_image_file(f) and os.path.isfile(os.path.join(src_dir, f))] + + # 处理每个图像文件 + for img_file in image_files: + src_path = os.path.join(src_dir, img_file) + base_name, ext = os.path.splitext(img_file) + if not same_directory: + # 复制原始文件 (另存文件夹时启用) + shutil.copy2(src_path, os.path.join(dst_dir, img_file)) + + # 生成并保存旋转后的图像 + for angle in [90, 180, 270]: + dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}") + self.rotate_image(src_path, dst_path, angle) + for ratio in [0.8, 0.85, 0.9]: + dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}") + self.random_cute_image(src_path, dst_path, ratio) + for brightness_factor in [0.8, 0.9, 1.0]: + dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}") + self.random_brightness(src_path, dst_path, brightness_factor) + + def image_extend(self, src_dir, dst_dir, same_directory=False, **kwargs): + if same_directory: + n_dst_dir = src_dir + print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下") + else: + n_dst_dir = dst_dir + print(f"处理目录 {src_dir} 中的图像文件 保存至不同目录下") + for src_subdir in os.listdir(src_dir): + src_subdir_path = os.path.join(src_dir, src_subdir) + dst_subdir_path = os.path.join(n_dst_dir, src_subdir) + self.process_image_directory(src_subdir_path, dst_subdir_path, same_directory) + + +if __name__ == "__main__": + src_dir = "./scatter_mini" + dst_dir = "./scatter_add" + image_ex = ImageExtendProcessor() + image_ex.image_extend(src_dir, dst_dir, same_directory=False) diff --git a/data_preprocessing/scatter_data_preprocessing.py b/data_preprocessing/scatter_data_preprocessing.py new file mode 100644 index 0000000..5300189 --- /dev/null +++ b/data_preprocessing/scatter_data_preprocessing.py @@ -0,0 +1,20 @@ +from create_extra import split_subdirs +from data_split import split_directory +from extend import ImageExtendProcessor +import yaml + + +def data_preprocessing(conf): + split_subdirs(conf['data']['source_dir'], conf['data']['data_extra_dir'], conf['data']['max_files']) + split_directory(conf['data']['source_dir'], conf['data']['train_dir'], + conf['data']['val_dir'], conf['data']['split_ratio']) + image_ex = ImageExtendProcessor() + image_ex.image_extend(conf['data']['extra_dir'], + '', + same_directory=True) + + +if __name__ == '__main__': + with open('../configs/scatter_data.yml', 'r') as f: + conf = yaml.load(f, Loader=yaml.FullLoader) + data_preprocessing(conf)