Files
2025-08-06 17:03:28 +08:00

193 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import random
import shutil
from PIL import Image, ImageEnhance
class ImageExtendProcessor:
def __init__(self, conf):
self.conf = conf
def is_image_file(self, filename):
"""检查文件是否为图像文件"""
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
return filename.lower().endswith(image_extensions)
def random_cute_image(self, image_path, output_path, ratio=0.8):
"""
对图像进行随机裁剪
:param image_path: 输入图像路径
:param output_path: 输出图像路径
:param ratio: 裁剪比例决定裁剪区域的大小默认为0.8
"""
try:
with Image.open(image_path) as img:
# 获取图像尺寸
width, height = img.size
# 计算裁剪后的尺寸
new_width = int(width * ratio)
new_height = int(height * ratio)
# 随机生成裁剪起始点
left = random.randint(0, width - new_width)
top = random.randint(0, height - new_height)
right = left + new_width
bottom = top + new_height
# 执行裁剪
cropped_img = img.crop((left, top, right, bottom))
# 保存裁剪后的图像
cropped_img.save(output_path)
return True
except Exception as e:
print(f"处理图像 {image_path} 时出错: {e}")
return False
def random_brightness(self, image_path, output_path, brightness_factor=None):
"""
对图像进行随机亮度调整
:param image_path: 输入图像路径
:param output_path: 输出图像路径
:param brightness_factor: 亮度调整因子,默认为随机值
"""
try:
with Image.open(image_path) as img:
# 创建一个ImageEnhance.Brightness对象
enhancer = ImageEnhance.Brightness(img)
# 如果没有指定亮度因子,则随机生成
if brightness_factor is None:
brightness_factor = random.uniform(0.5, 1.5)
# 应用亮度调整
brightened_img = enhancer.enhance(brightness_factor)
# 保存调整后的图像
brightened_img.save(output_path)
return True
except Exception as e:
print(f"处理图像 {image_path} 时出错: {e}")
return False
def rotate_image(self, image_path, output_path, degrees):
"""旋转图像并保存到指定路径"""
try:
with Image.open(image_path) as img:
# 旋转图像并自动调整画布大小
rotated = img.rotate(degrees, expand=True)
rotated.save(output_path)
return True
except Exception as e:
print(f"处理图像 {image_path} 时出错: {e}")
return False
def process_extra_directory(self, src_dir, dst_dir, same_directory, dir_name):
"""
处理单个目录中的图像文件
:param src_dir: 源目录路径
:param dst_dir: 目标目录路径
"""
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
# 获取目录中所有图像文件
image_files = [f for f in os.listdir(src_dir)
if self.is_image_file(f) and os.path.isfile(os.path.join(src_dir, f))]
# 处理每个图像文件
for img_file in image_files:
src_path = os.path.join(src_dir, img_file)
base_name, ext = os.path.splitext(img_file)
if not same_directory:
# 复制原始文件 (另存文件夹时启用)
shutil.copy2(src_path, os.path.join(dst_dir, img_file))
if dir_name == 'extra':
# 生成并保存旋转后的图像
for angle in [90, 180, 270]:
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
self.rotate_image(src_path, dst_path, angle)
for ratio in [0.8, 0.85, 0.9]:
dst_path = os.path.join(dst_dir, f"{base_name}_cute_{ratio}{ext}")
self.random_cute_image(src_path, dst_path, ratio)
for brightness_factor in [0.8, 0.9, 1.0]:
dst_path = os.path.join(dst_dir, f"{base_name}_brightness_{brightness_factor}{ext}")
self.random_brightness(src_path, dst_path, brightness_factor)
elif dir_name == 'train':
# 生成并保存旋转后的图像
for angle in [90, 180, 270]:
dst_path = os.path.join(dst_dir, f"{base_name}_rotated_{angle}{ext}")
self.rotate_image(src_path, dst_path, angle)
def image_extend(self, src_dir, dst_dir, same_directory=False, dir_name=None):
if same_directory:
n_dst_dir = src_dir
print(f"处理目录 {src_dir} 中的图像文件 保存至同一目录下")
else:
n_dst_dir = dst_dir
print(f"处理目录 {src_dir} 中的图像文件 保存至不同目录下")
for src_subdir in os.listdir(src_dir):
src_subdir_path = os.path.join(src_dir, src_subdir)
dst_subdir_path = os.path.join(n_dst_dir, src_subdir)
if dir_name == 'extra':
self.process_extra_directory(src_subdir_path,
dst_subdir_path,
same_directory,
dir_name)
if dir_name == 'train':
if len(os.listdir(src_subdir_path)) < 50:
self.process_extra_directory(src_subdir_path,
dst_subdir_path,
same_directory,
dir_name)
def random_remove_image(self, subdir_path, max_count=1000):
"""
随机删除子目录中的图像文件直到数量不超过max_count
:param subdir_path: 子目录路径
:param max_count: 最大允许的图像数量
"""
# 统计图像文件数量
image_files = [f for f in os.listdir(subdir_path)
if self.is_image_file(f) and os.path.isfile(os.path.join(subdir_path, f))]
current_count = len(image_files)
# 如果图像数量不超过max_count则无需删除
if current_count <= max_count:
print(f"无需处理 {subdir_path} (包含 {current_count} 个图像)")
return
# 计算需要删除的文件数
remove_count = current_count - max_count
# 随机选择要删除的文件
files_to_remove = random.sample(image_files, remove_count)
# 删除选中的文件
for file in files_to_remove:
file_path = os.path.join(subdir_path, file)
os.remove(file_path)
print(f"已删除 {file_path}")
def control_number(self):
if self.conf['extend']['extend_extra']:
self.image_extend(self.conf['extend']['extend_extra_dir'],
'',
same_directory=self.conf['extend']['extend_same_dir'],
dir_name='extra')
if self.conf['extend']['extend_train']:
self.image_extend(self.conf['extend']['extend_train_dir'],
'',
same_directory=self.conf['extend']['extend_same_dir'],
dir_name='train')
if self.conf['limit']['count_limit']:
self.random_remove_image(self.conf['limit']['limit_dir'],
max_count=self.conf['limit']['limit_count'])
if __name__ == "__main__":
src_dir = "./scatter_mini"
dst_dir = "./scatter_add"
image_ex = ImageExtendProcessor()
image_ex.image_extend(src_dir, dst_dir, same_directory=False)