目标检测、分割的数据增强策略
下面是模型训练时使用到的数据增强策略,包括几何变换,可以模拟低光照、恶劣天气场景,可以将每个数据增强方法的p都设置为1,逐步观察该方法的增强效果。照一张想要数据增强的图片,放到dataaug文件夹下即可。
import warningswarnings.filterwarnings('ignore') import os, shutil, cv2, tqdm import numpy as npnp.random.seed(0) import albumentations as A from PIL import ImageIMAGE_PATH = 'dataaug' AUG_IMAGE_PATH = 'results' SHOW_SAVE_PATH = 'results'ENHANCEMENT_LOOP = 1 #数据增强策略 ENHANCEMENT_STRATEGY = A.Compose([#几何翻转A.Compose([#仿射变换包括缩放、平移、旋转、剪切等,同时可以保持图像中的平行线和平行关系A.Affine(scale=[0.5, 1.5], translate_percent=[0.0, 0.3], rotate=[-360, 360], shear=[-45, 45], keep_ratio=True,cval_mask=0, p=1),#-scale在x,y轴上缩放#-translate_percent平移的百分比(相对于图像的宽度和高度)#-rotate旋转角度范围,单位是度#-shear剪切变换会使图像在水平或垂直方向上发生倾斜(类似斜切效果)#keep_ratio如果设置为`True`,则在缩放时保持图像的宽高比不变(即x和y轴使用相同的缩放因子)#cval_mask=0当进行变换时,图像边界外区域的填充值(对于掩模maskA.BBoxSafeRandomCrop(erosion_rate=0.2, p=0.1), # 随机裁剪图像的同时确保边界框(Bounding Boxes)不会被过度裁剪或丢失A.D4(p=0.1),#在数学中,**D4群**(二面体群)是指正方形对称的变换群,包含以下8种操作:# 1. 恒等变换(0°旋转)2. 90°旋转3. 180°旋转4. 270°旋转5. 水平翻转6. 垂直翻7. 主对角线翻转8. 反对角线翻转#`A.D4`操作会从这8种变换中**随机选择一种**应用于图像。A.ElasticTransform(p=1), #随机地扭曲图像像素,使其在图像上产生弹性变形的效果A.HorizontalFlip(p=0.05), # 水平翻转A.VerticalFlip(p=0.05), # 垂直翻转A.GridDistortion(p=1),#以网格状的方式扭曲图像A.Perspective(p=1),#该变换会使图像产生透视失真,就像从不同的角度观察图像一样,直线可能会弯曲], p=1.0),A.Compose([A.GaussNoise(p=1), # 添加高斯噪声后,图像会出现颗粒状的随机点,类似于老式相机在暗光条件下拍摄的照片或电视静态噪声。# 高斯噪声应用场景:# - 低光照条件下的图像处理(如监控、自动驾驶夜视系统)。# - 医学影像(如X光、MRI)中模拟噪声以提高模型鲁棒性。# - 任何需要模型对噪声不敏感的任务。A.ISONoise(p=1), # 模拟相机高感光度(ISO)设置产生的图像噪声的数据增强操作#-ISO100-400光线充足,800-1600轻度 彩色噪点室内环境,3200-6400黄昏弱光,明显颗粒噪点;12800+极暗环境,严重噪点 +色偏A.ImageCompression(quality_lower=0, quality_upper=50, p=1),#用于模拟JPEG图像压缩效果,降低图片质量A.RandomBrightnessContrast(p=1), #用来随机调整图像的亮度和对比度的A.RandomFog(p=1), # 专门用于在图像上模拟雾天效果A.RandomRain(p=1), # 模拟雨A.RandomSnow(p=1), # 模拟雪A.RandomShadow(p=1), # 模拟自然光线产生的阴影A.RandomSunFlare(p=1), # 图像中模拟太阳光晕(镜头光晕)效果,通常出现在逆光拍摄时太阳直射镜头的情况。A.ToGray(p=1), # 使用加权平均法将RGB图像转换为灰度图像 gray = 0.299 * R + 0.587 * G + 0.114 * B, 这个权重基于人眼对不同颜色的敏感度(绿色最高,红色次之,蓝色最低)], p=1.0) ], is_check_shapes=False)def data_aug_single(images_name):file_heads, postfix = os.path.splitext(images_name)images_path = os.path.join(IMAGE_PATH, images_name)if os.path.exists(images_path):# 读取原始图片并转换为OpenCV格式original_pil = Image.open(images_path)original_cv = cv2.cvtColor(np.array(original_pil), cv2.COLOR_RGB2BGR)for i in range(ENHANCEMENT_LOOP):# 增强后图片保存路径aug_image_path = os.path.join(AUG_IMAGE_PATH, f'{file_heads}_{i + 1:03d}{postfix}')# 拼接后图片保存路径combined_path = os.path.join(AUG_IMAGE_PATH, f'{file_heads}_combined_{i + 1:03d}{postfix}')try:# 应用数据增强transformed = ENHANCEMENT_STRATEGY(image=np.array(original_pil))transformed_image = transformed['image']# 将增强图片转换为OpenCV格式aug_cv = cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR)# 保存增强图片cv2.imwrite(aug_image_path, aug_cv)# 创建拼接图片(水平拼接)combined_image = np.hstack((original_cv, aug_cv))# 添加分隔线separator = np.zeros((original_cv.shape[0], 5, 3), dtype=np.uint8)separator.fill(255) # 白色分隔线combined_image = np.hstack((original_cv, separator, aug_cv))# 添加文字标签font = cv2.FONT_HERSHEY_SIMPLEXcv2.putText(combined_image, "Original", (10, 30), font, 1, (0, 255, 0), 2)cv2.putText(combined_image, "Augmented",(original_cv.shape[1] + 15, 30), font, 1, (0, 255, 0), 2)# 保存拼接图片cv2.imwrite(combined_path, combined_image)print(f'拼接图片保存成功: {combined_path}')except Exception as e:print(f"处理图片 {images_name} 时出错: {str(e)}")continue def data_aug():if os.path.exists(AUG_IMAGE_PATH):shutil.rmtree(AUG_IMAGE_PATH)os.makedirs(AUG_IMAGE_PATH, exist_ok=True)for images_name in tqdm.tqdm(os.listdir(IMAGE_PATH)):data_aug_single(images_name)if __name__ == '__main__':#show_labels(IMAGE_PATH)#show_labels(AUG_IMAGE_PATH, AUG_LABEL_PATH)data_aug()