散称训练数据前置处理
This commit is contained in:
@ -18,9 +18,9 @@ models:
|
|||||||
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
training:
|
training:
|
||||||
epochs: 600 # 总训练轮次
|
epochs: 800 # 总训练轮次
|
||||||
batch_size: 64 # 批次大小
|
batch_size: 64 # 批次大小
|
||||||
lr: 0.0004 # 初始学习率
|
lr: 0.01 # 初始学习率
|
||||||
optimizer: "sgd" # 优化器类型
|
optimizer: "sgd" # 优化器类型
|
||||||
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
metric: 'arcface' # 损失函数类型(可选:arcface/cosface/sphereface/softmax)
|
||||||
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
loss: "cross_entropy" # 损失函数类型(可选:cross_entropy/cross_entropy_smooth/center_loss/center_loss_smooth/arcface/cosface/sphereface/softmax)
|
||||||
@ -29,8 +29,8 @@ training:
|
|||||||
weight_decay: 0.0005 # 权重衰减
|
weight_decay: 0.0005 # 权重衰减
|
||||||
scheduler: "step" # 学习率调度器(可选:cosine_annealing/step/none)
|
scheduler: "step" # 学习率调度器(可选:cosine_annealing/step/none)
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
checkpoints: "./checkpoints/resnet18_scatter_6.26/" # 模型保存目录
|
checkpoints: "./checkpoints/resnet18_scatter_7.4/" # 模型保存目录
|
||||||
restore: True
|
restore: false
|
||||||
restore_model: "checkpoints/resnet18_scatter_6.25/best.pth" # 模型恢复路径
|
restore_model: "checkpoints/resnet18_scatter_6.25/best.pth" # 模型恢复路径
|
||||||
|
|
||||||
|
|
||||||
@ -46,8 +46,8 @@ data:
|
|||||||
train_batch_size: 128 # 训练批次大小
|
train_batch_size: 128 # 训练批次大小
|
||||||
val_batch_size: 100 # 验证批次大小
|
val_batch_size: 100 # 验证批次大小
|
||||||
num_workers: 32 # 数据加载线程数
|
num_workers: 32 # 数据加载线程数
|
||||||
data_train_dir: "../data_center/scatter/v2/train" # 训练数据集根目录
|
data_train_dir: "../data_center/scatter/v4/train" # 训练数据集根目录
|
||||||
data_val_dir: "../data_center/scatter/v2/val" # 验证数据集根目录
|
data_val_dir: "../data_center/scatter/v4/val" # 验证数据集根目录
|
||||||
|
|
||||||
transform:
|
transform:
|
||||||
img_size: 224 # 图像尺寸
|
img_size: 224 # 图像尺寸
|
||||||
@ -59,7 +59,7 @@ transform:
|
|||||||
|
|
||||||
# 日志与监控
|
# 日志与监控
|
||||||
logging:
|
logging:
|
||||||
logging_dir: "./log/2025.6.25-scatter.txt" # 日志保存目录
|
logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录
|
||||||
tensorboard: true # 是否启用TensorBoard
|
tensorboard: true # 是否启用TensorBoard
|
||||||
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
checkpoint_interval: 30 # 检查点保存间隔(epoch)
|
||||||
|
|
||||||
|
19
configs/scatter_data.yml
Normal file
19
configs/scatter_data.yml
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# configs/scatter_data.yml
|
||||||
|
# 专为散称前处理的配置文件
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
data:
|
||||||
|
dataset: "imagenet" # 数据集名称(示例用,可替换为实际数据集)
|
||||||
|
source_dir: "../../data_center/scatter/v5/source" # 原始数据
|
||||||
|
train_dir: "../../data_center/scatter/v5/train" # 训练数据集根目录
|
||||||
|
val_dir: "../../data_center/scatter/v5/val" # 验证数据集根目录
|
||||||
|
extra_dir: "../../data_center/scatter/v5/extra" # 验证数据集根目录
|
||||||
|
split_ratio: 0.9
|
||||||
|
max_files: 10 # 数据集小于该阈值则归纳至extra
|
||||||
|
|
||||||
|
|
||||||
|
# 日志与监控
|
||||||
|
logging:
|
||||||
|
logging_dir: "./log/2025.7.4-scatter.txt" # 日志保存目录
|
||||||
|
log_level: "info" # 日志级别(debug/info/warning/error)
|
||||||
|
|
0
data_preprocessing/__init__.py
Normal file
0
data_preprocessing/__init__.py
Normal file
92
data_preprocessing/create_extra.py
Normal file
92
data_preprocessing/create_extra.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def count_files(directory):
|
||||||
|
"""统计目录中的文件数量"""
|
||||||
|
try:
|
||||||
|
return len([f for f in os.listdir(directory)
|
||||||
|
if os.path.isfile(os.path.join(directory, f))])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"无法统计目录 {directory}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def clear_empty_dirs(path):
|
||||||
|
"""
|
||||||
|
删除空目录
|
||||||
|
:param path: 目录路径
|
||||||
|
"""
|
||||||
|
for root, dirs, files in os.walk(path, topdown=False):
|
||||||
|
for dir_name in dirs:
|
||||||
|
dir_path = os.path.join(root, dir_name)
|
||||||
|
try:
|
||||||
|
if not os.listdir(dir_path):
|
||||||
|
os.rmdir(dir_path)
|
||||||
|
print(f"Deleted empty directory: {dir_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e.strerror}")
|
||||||
|
|
||||||
|
def megre_subdirs(pth):
|
||||||
|
for roots, dir_names, files in os.walk(pth):
|
||||||
|
print(f"image {dir_names}")
|
||||||
|
for image in dir_names:
|
||||||
|
inner_dir_path = os.path.join(pth, image)
|
||||||
|
for inner_roots, inner_dirs, inner_files in os.walk(inner_dir_path):
|
||||||
|
for inner_dir in inner_dirs:
|
||||||
|
src_dir = os.path.join(inner_roots, inner_dir)
|
||||||
|
dest_dir = os.path.join(pth, inner_dir)
|
||||||
|
# shutil.copytree(src_dir, dest_dir)
|
||||||
|
shutil.move(src_dir, dest_dir)
|
||||||
|
print(f"Copied {inner_dir} to {pth}")
|
||||||
|
clear_empty_dirs(pth)
|
||||||
|
|
||||||
|
|
||||||
|
def split_subdirs(source_dir, target_dir, max_files=10):
|
||||||
|
"""
|
||||||
|
复制文件数≤max_files的子目录到目标目录
|
||||||
|
:param source_dir: 源目录路径
|
||||||
|
:param target_dir: 目标目录路径
|
||||||
|
:param max_files: 最大文件数阈值
|
||||||
|
"""
|
||||||
|
megre_subdirs(source_dir) # 合并子目录,删除上级目录
|
||||||
|
# 创建目标目录
|
||||||
|
Path(target_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"开始处理目录: {source_dir}")
|
||||||
|
print(f"目标目录: {target_dir}")
|
||||||
|
print(f"筛选条件: 文件数 ≤ {max_files}\n")
|
||||||
|
|
||||||
|
# 遍历源目录
|
||||||
|
for subdir in os.listdir(source_dir):
|
||||||
|
subdir_path = os.path.join(source_dir, subdir)
|
||||||
|
|
||||||
|
if not os.path.isdir(subdir_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_count = count_files(subdir_path)
|
||||||
|
|
||||||
|
if file_count <= max_files:
|
||||||
|
print(f"复制 {subdir} (包含 {file_count} 个文件)")
|
||||||
|
dest_path = os.path.join(target_dir, subdir)
|
||||||
|
|
||||||
|
# 如果目标目录已存在则跳过
|
||||||
|
if os.path.exists(dest_path):
|
||||||
|
print(f"目录已存在,跳过: {dest_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# shutil.copytree(subdir_path, dest_path)
|
||||||
|
shutil.move(subdir_path, dest_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理目录 {subdir} 时出错: {e}")
|
||||||
|
|
||||||
|
print("\n处理完成")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 配置路径
|
||||||
|
SOURCE_DIR = r"C:\Users\123\Desktop\test1\scatter_sub_class"
|
||||||
|
TARGET_DIR = "scatter_mini"
|
||||||
|
|
||||||
|
# 执行复制操作
|
||||||
|
split_subdirs(SOURCE_DIR, TARGET_DIR)
|
73
data_preprocessing/data_split.py
Normal file
73
data_preprocessing/data_split.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def is_image_file(filename):
|
||||||
|
"""检查文件是否为图像文件"""
|
||||||
|
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
||||||
|
return filename.lower().endswith(image_extensions)
|
||||||
|
|
||||||
|
def split_directory(src_dir, train_dir, val_dir, split_ratio=0.9):
|
||||||
|
"""
|
||||||
|
分割目录中的图像文件到train和val目录
|
||||||
|
:param src_dir: 源目录路径
|
||||||
|
:param train_dir: 训练集目录路径
|
||||||
|
:param val_dir: 验证集目录路径
|
||||||
|
:param split_ratio: 训练集比例(默认0.9)
|
||||||
|
"""
|
||||||
|
# 创建目标目录
|
||||||
|
Path(train_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
Path(val_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 遍历源目录
|
||||||
|
for root, dirs, files in os.walk(src_dir):
|
||||||
|
# 获取相对路径(相对于src_dir)
|
||||||
|
rel_path = os.path.relpath(root, src_dir)
|
||||||
|
|
||||||
|
# 跳过当前目录(.)
|
||||||
|
if rel_path == '.':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建对应的目标子目录
|
||||||
|
train_subdir = os.path.join(train_dir, rel_path)
|
||||||
|
val_subdir = os.path.join(val_dir, rel_path)
|
||||||
|
os.makedirs(train_subdir, exist_ok=True)
|
||||||
|
os.makedirs(val_subdir, exist_ok=True)
|
||||||
|
|
||||||
|
# 筛选图像文件
|
||||||
|
image_files = [f for f in files if is_image_file(f)]
|
||||||
|
if not image_files:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 随机打乱文件列表
|
||||||
|
random.shuffle(image_files)
|
||||||
|
|
||||||
|
# 计算分割点
|
||||||
|
split_point = int(len(image_files) * split_ratio)
|
||||||
|
|
||||||
|
# 复制文件到训练集
|
||||||
|
for file in image_files[:split_point]:
|
||||||
|
src = os.path.join(root, file)
|
||||||
|
dst = os.path.join(train_subdir, file)
|
||||||
|
# shutil.copy2(src, dst)
|
||||||
|
shutil.move(src, dst)
|
||||||
|
|
||||||
|
# 复制文件到验证集
|
||||||
|
for file in image_files[split_point:]:
|
||||||
|
src = os.path.join(root, file)
|
||||||
|
dst = os.path.join(val_subdir, file)
|
||||||
|
# shutil.copy2(src, dst)
|
||||||
|
shutil.move(src, dst)
|
||||||
|
|
||||||
|
print(f"处理完成: {rel_path} (共 {len(image_files)} 个图像, 训练集: {split_point}, 验证集: {len(image_files)-split_point})")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 设置目录路径
|
||||||
|
SOURCE_DIR = "scatter_add"
|
||||||
|
TRAIN_DIR = "scatter_data/train"
|
||||||
|
VAL_DIR = "scatter_data/val"
|
||||||
|
|
||||||
|
print("开始分割数据集...")
|
||||||
|
split_directory(SOURCE_DIR, TRAIN_DIR, VAL_DIR)
|
||||||
|
print("数据集分割完成")
|
131
data_preprocessing/extend.py
Normal file
131
data_preprocessing/extend.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
from PIL import Image, ImageEnhance
|
||||||
|
|
||||||
|
|
||||||
|
class ImageExtendProcessor:
|
||||||
|
def is_image_file(self, filename):
|
||||||
|
"""检查文件是否为图像文件"""
|
||||||
|
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
|
||||||
|
return filename.lower().endswith(image_extensions)
|
||||||
|
|
||||||
|
def random_cute_image(self, image_path, output_path, ratio=0.8):
|
||||||
|
"""
|
||||||
|
对图像进行随机裁剪
|
||||||
|
:param image_path: 输入图像路径
|
||||||
|
:param output_path: 输出图像路径
|
||||||
|
:param ratio: 裁剪比例,决定裁剪区域的大小,默认为0.8
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
# 获取图像尺寸
|
||||||
|
width, height = img.size
|
||||||
|
|
||||||
|
# 计算裁剪后的尺寸
|
||||||
|
new_width = int(width * ratio)
|
||||||
|
new_height = int(height * ratio)
|
||||||
|
|
||||||
|
# 随机生成裁剪起始点
|
||||||
|
left = random.randint(0, width - new_width)
|
||||||
|
top = random.randint(0, height - new_height)
|
||||||
|
right = left + new_width
|
||||||
|
bottom = top + new_height
|
||||||
|
|
||||||
|
# 执行裁剪
|
||||||
|
cropped_img = img.crop((left, top, right, bottom))
|
||||||
|
|
||||||
|
# 保存裁剪后的图像
|
||||||
|
cropped_img.save(output_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理图像 {image_path} 时出错: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def random_brightness(self, image_path, output_path, brightness_factor=None):
|
||||||
|
"""
|
||||||
|
对图像进行随机亮度调整
|
||||||
|
:param image_path: 输入图像路径
|
||||||
|
:param output_path: 输出图像路径
|
||||||
|
:param brightness_factor: 亮度调整因子,默认为随机值
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
# 创建一个ImageEnhance.Brightness对象
|
||||||
|
enhancer = ImageEnhance.Brightness(img)
|
||||||
|
|
||||||
|
# 如果没有指定亮度因子,则随机生成
|
||||||
|
if brightness_factor is None:
|
||||||
|
brightness_factor = random.uniform(0.5, 1.5)
|
||||||
|
|
||||||
|
# 应用亮度调整
|
||||||
|
brightened_img = enhancer.enhance(brightness_factor)
|
||||||
|
|
||||||
|
# 保存调整后的图像
|
||||||
|
brightened_img.save(output_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理图像 {image_path} 时出错: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def rotate_image(self, image_path, output_path, degrees):
|
||||||
|
"""旋转图像并保存到指定路径"""
|
||||||
|
try:
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
# 旋转图像并自动调整画布大小
|
||||||
|
rotated = img.rotate(degrees, expand=True)
|
||||||
|
rotated.save(output_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理图像 {image_path} 时出错: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def process_image_directory(self, src_dir, dst_dir, same_directory, **kwargs):
|
||||||
|
"""
|
||||||
|
处理单个目录中的图像文件
|
||||||
|
:param src_dir: 源目录路径
|
||||||
|
:param dst_dir: 目标目录路径
|
||||||
|
"""
|
||||||
|
if not os.path.exists(dst_dir):
|
||||||
|
os.makedirs(dst_dir)
|
||||||
|
# 获取目录中所有图像文件
|
||||||
|
image_files = [f for f in os.listdir(src_dir)
|
||||||
|
if self.is_image_file(f) and os.path.isfile(os.path.join(src_dir, f))]
|
||||||
|
|
||||||
|
# 处理每个图像文件
|
||||||
|
for img_file in image_files:
|
||||||
|
src_path = os.path.join(src_dir, img_file)
|
||||||
|
base_name, ext = os.path.splitext(img_file)
|
||||||
|
if not same_directory:
|
||||||
|
# 复制原始文件 (另存文件夹时启用)
|
||||||
|
shutil.copy2(src_path, os.path.join(dst_dir, img_file))
|
||||||
|
|
||||||
|
# 生成并保存旋转后的图像
|
||||||
|
for angle in [90, 180, 270]:
|
||||||
|
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
|
||||||
|
self.rotate_image(src_path, dst_path, angle)
|
||||||
|
for ratio in [0.8, 0.85, 0.9]:
|
||||||
|
dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
|
||||||
|
self.random_cute_image(src_path, dst_path, ratio)
|
||||||
|
for brightness_factor in [0.8, 0.9, 1.0]:
|
||||||
|
dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
|
||||||
|
self.random_brightness(src_path, dst_path, brightness_factor)
|
||||||
|
|
||||||
|
def image_extend(self, src_dir, dst_dir, same_directory=False, **kwargs):
|
||||||
|
if same_directory:
|
||||||
|
n_dst_dir = src_dir
|
||||||
|
print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下")
|
||||||
|
else:
|
||||||
|
n_dst_dir = dst_dir
|
||||||
|
print(f"处理目录 {src_dir} 中的图像文件 保存至不同目录下")
|
||||||
|
for src_subdir in os.listdir(src_dir):
|
||||||
|
src_subdir_path = os.path.join(src_dir, src_subdir)
|
||||||
|
dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
|
||||||
|
self.process_image_directory(src_subdir_path, dst_subdir_path, same_directory)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
src_dir = "./scatter_mini"
|
||||||
|
dst_dir = "./scatter_add"
|
||||||
|
image_ex = ImageExtendProcessor()
|
||||||
|
image_ex.image_extend(src_dir, dst_dir, same_directory=False)
|
20
data_preprocessing/scatter_data_preprocessing.py
Normal file
20
data_preprocessing/scatter_data_preprocessing.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from create_extra import split_subdirs
|
||||||
|
from data_split import split_directory
|
||||||
|
from extend import ImageExtendProcessor
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def data_preprocessing(conf):
|
||||||
|
split_subdirs(conf['data']['source_dir'], conf['data']['data_extra_dir'], conf['data']['max_files'])
|
||||||
|
split_directory(conf['data']['source_dir'], conf['data']['train_dir'],
|
||||||
|
conf['data']['val_dir'], conf['data']['split_ratio'])
|
||||||
|
image_ex = ImageExtendProcessor()
|
||||||
|
image_ex.image_extend(conf['data']['extra_dir'],
|
||||||
|
'',
|
||||||
|
same_directory=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with open('../configs/scatter_data.yml', 'r') as f:
|
||||||
|
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
data_preprocessing(conf)
|
Reference in New Issue
Block a user