Python深度学习albumentations数据增强库

时间:2022-12-31 15:43:27

数据增强的必要性

深度学习在最近十年得以风靡得益于计算机算力的提高以及数据资源获取的难度下降。一个好的深度模型往往需要大量具有label的数据,使得模型能够很好的学习这种数据的分布。而给数据打标签往往是一件耗时耗力的工作。
拿cv里的经典任务为例,classification需要人准确识别物品类别或者生物种类,object detection需要人工画出bounding box, 确定其坐标,semantic segmentation甚至需要在像素级别进行标签标注。对于一些专业领域的图像标注,依赖于专业人士的知识素养(例如医疗,遥感等),这无疑对有标签数据的收集带来了麻烦。

那么有没有什么方法能够在数据集规模很小的情况,尽可能提高模型的表现力呢?

1.transfer learning或者说是domain adaptation,这种方法期望降低源域与目标域之间的数据分布差异,使得具有大量标注数据的源域帮助提升模型的训练效果。

2.对现有数据进行数据增强深度学习能够学习到的空间不变性,像素级别的不变性特征都有限。所以对图片进行平移,缩放,旋转,改变色调值等方法,可以使得模型见过各种类型的数据,提高模型在测试数据上的判别力。

 

albumentations

上面我只是笼统的谈了下数据增强的必要性,对于其更加深刻的理解往往需要在实验中不断体会或者总结。

albumentations的安装

这步没什么好说,利用包管理工具直接安装。

pip install albumentations

albumentations的流水线工作方式

导入所需要的库

import albumentations as A
from PIL import Image
import numpy as np

读入数据这步需要其它库进行配合,可以利用CV2,PIL等,这里出于习惯我选择使用PIL

image_path = './your/image/path'
image = np.array(Image.open(image_path))  # 获得了一个[H, W, C]的三维数组

创建流水线

transform = A.Compose([
	A.Resize(width=256, height=256),
	A.HorizontalFlip(p=0.5),
	A.RandomBrightnessContrast(p=0.2)
])

A.Compose中需要传入一个list, list包含了一系列数据增强操作的对象。这里可以理解为A.Compose返回一条工业流水线, 第一步进行A.Resize操作,将图片缩放成256 * 256;第二步在上一步的基础上以0.5的概率对图片进行镜像翻转(p这个参数代表进行这个操作的概率);第三步同理,对第一步第二步处理完的图像以0.2的概率进行亮度和对比度的改变。

transform就是我们将要对图片进行的操作流程,下一步就需要将图片数据传入进去。

获得数据增强完的图片数据

transformed = transform(image=image)
tranformed_image = transformed['image']

将图片数据传递给transform(很明显这是个可调用的对象)的image参数,它会返回一个处理完的对象,对象的key值image对应的value就是处理完的图像数据。

图像处理结果展示

Python深度学习albumentations数据增强库

object detection的数据增强

上述对albumentations流水线工作过程的简要说明其实就是classification任务的大致流程。
当然,albumentations如果仅仅只能做到上述的功能,那么torchvision中transform API可以把它完全替代,并且它也满足不了大多数cv任务的数据增强需求。

拿object detection为例,一张图片数据往往对应了若干个bounding box,如果你对图片数据进行的操作具有空间变换性,那么原有的bounding box数据画出的目标框必然已经对应不了图片中的对象了。
所以对图片数据进行变换的同时也必须对bounding box数据进行变换,保持二者的一致性。

绘制目标框

在介绍object detection的数据增强之前,先介绍一个绘制目标框的函数。在albumentation中展示的代码是用cv2实现,个人觉得画出的bounding box不太美观,下面使用的是matplotlib实现的代码。

import matplotlib.pyplot as plt
import matplotlib.patches as patches
def visualize_bbox(img, bbox, class_name, color, ax):
	"""
	img:图片数据 (H, W, C)数据格式
	bbox:array或者tensor, 假定数据格式是 [x_mid, y_mid, width, height]
	classname:str 目标框对应的种类
	color:str
	thickness:目标框的宽度
	"""
	x_mid, y_mid, width, height = bbox
	x_min = int(x_mid - width / 2)
	y_min = int(y_mid - height / 2)
	# 画目标检测框
	rect = patches.Rectangle((x_min, y_min), 
								width, 
								height, 
								linewidth=3,
								edgecolor=color,
								facecolor="none"
								)
	ax.imshow(img)
	ax.add_patch(rect)
	ax.text(x_min + 1, y_min - 3, class_name, fontSize=10, bbox={'facecolor':color, 'pad': 3, 'edgecolor':color})
def visualize(img, bboxes, category_ids, category_id_to_name, category_id_to_color):
	fig, ax = plt.subplots(1, figsize=(8, 8))
	ax.axis('off')
	for box, category in zip(bboxes, category_ids):
		class_name = category_id_to_name[category]
		color = category_id_to_color[category]
		visualize_bbox(img, box, class_name, color, ax)
	plt.show()

Python深度学习albumentations数据增强库

对bounding box进行空间变换

导入所需要的库

import albumentations as A
from PIL import Image
import numpy as np
image_path = './your/image/path'
image = np.array(Image.open(image_path))

构造流水线

transform = A.Compose([
	A.Resize(width=256, height=256),
	A.HorizontalFlip(p=0.5),
	A.RandomBrightnessContrast(p=0.2)
], bbox_params = A.BboxParams(format='yolo'))

相较于最简单的流水线(for classification),oject detection需要传入一个叫做bbox_params的参数,它接收的是用于配置bounding box参数的对象。
format表示的是bounding box数据的格式,albumentations提供了4种格式。

Python深度学习albumentations数据增强库

1.pascal_voc [x_min, y_min, x_max, y_max] 数值并没有归一化

直接使用像素值[98, 345, 420, 462]

2.albumentations [x_min, y_min, x_max, y_max] 与上一种格式不一样的是

这里值都是normalized 做了归一化处理[0.153125, 0.71875, 0.65625, 0.9625]

3.coco [x_min, y_min, width, height] 没有归一化

4.yolo [x_center, y_center, width, height] 归一化了

传入image数据和bounding box数据进行变换

label = np.array([
      [0.339, 0.6693333333333333, 0.402, 0.42133333333333334],
      [0.379, 0.5666666666666667, 0.158, 0.3813333333333333],
      [0.612, 0.7093333333333333, 0.084, 0.3466666666666667],
      [0.555, 0.7026666666666667, 0.078, 0.34933333333333333]
])  # normalized (x_center, y_center, width, height) 对应format yolo
category_ids = [12, 14, 14, 14]
category_id_to_name = {
  12: 'horse',
  14: 'people'
}
category_id_to_color = {
  12: 'yellow',
  14: 'red'
}
transformed = transform(image=image,bboxes=label)
transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']
height, width, _ = transformed_image.shape
transformed_bboxes[:, [0, 2]] = transformed_bboxes[:, [0, 2]] * width
transformed_bboxes[:, [1, 3]] = transformed_bboxes[:, [1, 3]] * height
visualize(transformed_image, transformed_bboxes, category_ids, category_id_to_name, category_id_to_color)

Python深度学习albumentations数据增强库

BboxParams中不止format这一个参数。当我们做随机裁剪操作的时候,bounding box完全可能只保留了一部分,当保留比例小于某一个阈值的时候,我们可以将其drop掉,具体的操作细节可以查看albumentations的相关教程。

semantic segmentation的数据增强

object detection和semantic segmentation在像素级别的data agumentation和classification没什么区别,而在空间变换上segmentation没有bounding box变换,与之对应的是mask变换。
mask是像素级别的label,与原图中的像素一一对应。
albumentations上的教程使用的是kaggle上的数据集,这里为了方便展示我们使用同样的数据集。

数据集网址

Python深度学习albumentations数据增强库

下载完数据并解压缩完成后可以得到如上的目录结构,通过train.csv文件可以得到所用的image和mask名称。

image = np.array(Image.open(image_path))  # 这里使用的是/train/images/0fea4b5049.png
mask = np.array(Image.open(mask_path))  # /train/masks/0fea4b5049.png

下面介绍一下展示结果的函数

from matplotlib import pyplot as plt
def visualize(image, mask, original_image=None, original_mask=None):
	fontsize=8
	if original_image == None and original_mask == None:
		fg, ax = plt.subplots(2, 1, figsize=(8, 8))
		ax[0].axis('off')
		ax[0].imshow(image)
		ax[0].set_title('image', fontsize=fontsize)
		ax[1].axis('off')
		ax[1].imshow(mask)
		ax[1].set_title('mask', fontsize=fontsize)
	else:
		fg, ax = plt.subplots(2, 2, figsize=(8, 8))
		ax[0, 0].axis('off')
		ax[0, 0].imshow(original_image)
		ax[0, 0].set_title('Original Image', fontsize=fontsize)
		ax[0, 1].axis('off')
		ax[0, 1].imshow(original_mask)
		ax[0, 1].set_title('Original Mask', fontsize=fontsize)
		ax[1, 0].axis('off')
		ax[1, 0].imshow(image)
		ax[1, 0].set_title('Transformed Image', fontsize=fontsize)
		ax[1, 1].axis('off')
		ax[1, 1].imshow(mask)
		ax[1, 1].set_title('Transformed Mask', fontsize=fontsize)	

data agumentation的流水线操作

aug = A.PadIfNeeded(min_height=128, min_width=128, p=1)
augmented = aug(image=image, mask=mask)
augmented_img = augmented['image']
augmented_mask = augmented['mask']
visualize(augmented_img, augmented_mask, original_image=image, original_mask=mask)

这里相较于classification就是多了个mask函数,将mask数据直接传进入即可。

Python深度学习albumentations数据增强库

padding的填充方式默认是reflection, 可以看到变换以后的mask右侧多了些黄色区域。
对于一些分割任务而言,我们不想增加或者删除额外的信息,所以往往采用 Non destructive transformations(非破坏性变换)如HorizontalFlip(水平翻转), VerticalFlip(垂直翻转), RandomRotate90(Randomly rotates by 0, 90, 180, 270 degrees)

aug = A.RandomRotate(p=1)
augmented = aug(image=image, mask=mask)
augmented_image = augmented['image']
augmented_mask = augmented['mask']
visualize(augmented_image, augmented_mask, original_image=image, original_mask=mask)

Python深度学习albumentations数据增强库

下面介绍下多个transform综合起来的流水线操作

original_height, original_width = image.shape[:2]
aug = A.Compose([
  A.OneOf([
      A.RandomSizedCrop(min_max_height=(50, 101), height=original_height, width=original_width, p=0.5),
      A.PadIfNeeded(min_height=original_height, min_width=original_width, p=0.5)
  ]),
  A.VerticalFlip(p=0.5),
  A.RandomRotate90(p=0.5),
  A.OneOf([
      A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
      A.GridDistortion(p=0.5),
      A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=1)
  ], p=0.8)
])
augmented = aug(image=image, mask=mask)
image_medium = augmented['image']
mask_medium = augmented['mask']
visualize(image_medium, mask_medium, original_image=image, original_mask=mask)

这里一个较新的知识点是A.OneOf,它接收的transform对象的list, 从中按照权重随机选择一个进行变换,它本身也有概率。

Python深度学习albumentations数据增强库

可以看到OneOf将list中的transform的概率进行归一化再重新分配。所以这里transform的p不再理解为概率,而是权重,取到1,甚至比1大都没有关系。

以上就是Python深度学习albumentations数据增强库的详细内容,更多关于Python数据增强库albumentations的资料请关注服务器之家其它相关文章!

原文链接:https://blog.csdn.net/qq_43152622/article/details/120541332