Albumentations: Fast and Flexible Image Augmentations.

Albumentations is a Python library for image data augmentation. Install it with:

pip install albumentations

1. Data augmentation methods in Albumentations

The augmentation methods in Albumentations fall into two groups: pixel-level transforms and spatial-level transforms.

⚪ pixel-level transforms

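Pixel-level transforms change only the pixel values of the image and leave its annotations (masks, bounding boxes, keypoints) untouched. They suit tasks such as image classification. Because the labels stay valid, they can also be used in segmentation or detection pipelines.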

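Format conversion

Adding global noise

Smoothing filters: blurring the image

Sharpening filters: enhancing edges

Contrast transforms

Color transforms

Downstream tasks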

⚪ spatial-level transforms

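Spatial-level transforms change both the image and its annotations (masks, bounding boxes, keypoints). This makes them suitable for image segmentation, object detection, pose estimation and similar tasks.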

Spatial-level transforms (each is applied to the image and, where supported, to masks, bounding boxes and keypoints; not every transform supports every target):
Affine
CenterCrop
CoarseDropout
Crop
CropAndPad
CropNonEmptyMaskIfExists
ElasticTransform
Flip
GridDistortion
GridDropout
HorizontalFlip
Lambda
LongestMaxSize
MaskDropout
NoOp
OpticalDistortion
PadIfNeeded
Perspective
PiecewiseAffine
RandomCrop
RandomCropNearBBox
RandomGridShuffle
RandomResizedCrop
RandomRotate90
RandomScale
RandomSizedBBoxSafeCrop
RandomSizedCrop
Resize
Rotate
SafeRotate
ShiftScaleRotate
SmallestMaxSize
Transpose
VerticalFlip

2. Using Albumentations

A basic example:

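import albumentations as A
import cv2

# Declare an augmentation pipeline
transform = A.Compose([
    A.RandomCrop(width=256, height=256),  # random 256x256 crop (the input must be at least 256x256)
    A.HorizontalFlip(p=0.5),  # random horizontal flip
    A.RandomBrightnessContrast(p=0.2),  # randomly adjust brightness and contrast
    A.OneOf([
        A.Blur(blur_limit=3, p=0.1),  # box blur; blur_limit=3 caps the kernel at 3x3
        A.MedianBlur(blur_limit=3, p=0.1),  # 3x3 median filter
    ], p=0.2),  # with probability 0.2, apply exactly one of the two blurs
])

# Read an image with OpenCV (which loads it as BGR) and convert it to the RGB colorspace
image = cv2.imread("image.jpg")
if image is None:
    raise FileNotFoundError("could not read image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Augment an image; the call returns a dict with the augmented array under "image"
transformed = transform(image=image)
transformed_image = transformed["image"]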

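Two key composition classes in Albumentations:

  1. A.Compose: applies the transforms it contains in order (each one still fires according to its own p)
  2. A.OneOf: with its own probability p, randomly picks exactly one of the transforms it contains and applies it (the children's p values act as relative selection weights)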

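Albumentations is also integrated into MMDetection through the Albu pipeline transform. To use it, edit train_pipeline in the config file (the example below uses the MMDetection 2.x config format):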

albu_train_transforms = [
    dict(type='HorizontalFlip', p=0.5),
    dict(type='OneOf', transforms=[
            dict(type='Blur', blur_limit=3, p=0.5),
            dict(type='MedianBlur', blur_limit=3, p=0.5),
        ],
        p=0.1),
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
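    dict(
        type='Albu',
        transforms=albu_train_transforms,  # data augmentation
        # bbox_params makes the boxes follow spatial transforms such as HorizontalFlip
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_labels'],
            min_visibility=0.0,
            filter_lost_elements=True)),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),  # overlaps with the HorizontalFlip inside Albu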
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]