热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

深度学习基础系列(十一)|Keras中图像增强技术详解

在深度学习中,数据短缺是我们经常面临的一个问题,虽然现在有不少公开数据集,但跟大公司掌握的海量数据集相比,数量上仍然偏少,而某些特定领域的数据采集更是非常困难。根据之前的学习可知,

  在深度学习中,数据短缺是我们经常面临的一个问题,虽然现在有不少公开数据集,但跟大公司掌握的海量数据集相比,数量上仍然偏少,而某些特定领域的数据采集更是非常困难。根据之前的学习可知,数据量少带来的最直接影响就是过拟合。那有没有办法在现有少量数据基础上,降低或解决过拟合问题呢?

      答案是有的,就是数据增强技术。我们可以对现有的数据,如图片数据进行平移、翻转、旋转、缩放、亮度增强等操作,以生成新的图片来参与训练或测试。这种操作可以将图片数量提升数倍,由此大大降低了过拟合的可能。本文将详解图像增强技术在Keras中的原理和应用。

 

一、Keras中的ImageDataGenerator类

  图像增强的官网地址是:https://keras.io/preprocessing/image/ ,API使用相对简单,功能也很强大。

  先介绍的是ImageDataGenerator类,这个类定义了图片该如何进行增强操作,其API及参数定义如下:

keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False, #输入值按照均值为0进行处理     samplewise_center=False, #每个样本的均值按0处理     featurewise_std_normalization=False, #输入值按照标准正态化处理
    samplewise_std_normalization=False, #每个样本按照标准正态化处理     zca_whitening=False, # 是否开启增白     zca_epsilon=1e-06,     rotation_range=0, #图像随机旋转一定角度,最大旋转角度为设定值     width_shift_range=0.0, #图像随机水平平移,最大平移值为设定值。若值为小于1的float值,则可认为是按比例平移,若大于1,则平移的是像素;若值为整型,平移的也是像素;假设像素为2.0,则移动范围为[-1,1]之间     height_shift_range=0.0, #图像随机垂直平移,同上     brightness_range=None, # 图像随机亮度增强,给定一个含两个float值的list,亮度值取自上下限值间     shear_range=0.0, # 图像随机修剪     zoom_range=0.0, # 图像随机变焦      channel_shift_range=0.0,     fill_mode=\'nearest\', #填充模式,默认为最近原则,比如一张图片向右平移,那么最左侧部分会被临近的图案覆盖     cval=0.0,     horizontal_flip=False, #图像随机水平翻转     vertical_flip=False, #图像随机垂直翻转     rescale=None, #缩放尺寸     preprocessing_function=None,     data_format=None,     validation_split=0.0,     dtype=None)

   下文将以mnist和花类的数据集进行图片操作,其中花类(17种花,共1360张图片)数据集可见我的百度网盘: https://pan.baidu.com/s/1YDA_VOBlJSQEijcCoGC60w 。让我们以直观地方式看看各参数能带来什么样的图片变化。

   随机旋转

  我们可用mnist数据集对图片进行随机旋转,旋转的最大角度由参数定义。

from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K

K.set_image_dim_ordering(\'th\')

(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype(\'float32\')

# 创建图像生成器,指定对图像操作的内容
datagen = ImageDataGenerator(rotation_range=90)
# 图像生成器要训练的数据
datagen.fit(train_data)

# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
    for i in range(0, 9):
        # 创建一个 3*3的九宫格,以显示图片
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap(\'gray\'))
    pyplot.show()
    break

  生成结果为:

  随机平移

  我们可用花类数据集对图片进行随机平移,可以在垂直和水平方向上平移,平移最大值由参数定义。

from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_img

IMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = \'/home/yourname/Documents/tensorflow/images/17flowerclasses/train\'
TEST_PATH = \'/home/yourname/Documents/tensorflow/images/17flowerclasses/test\'
FLOWER_CLASSES = [\'Bluebell\', \'ButterCup\', \'ColtsFoot\', \'Cowslip\', \'Crocus\', \'Daffodil\', \'Daisy\',
                  \'Dandelion\', \'Fritillary\', \'Iris\', \'LilyValley\', \'Pansy\', \'Snowdrop\', \'Sunflower\',
                  \'Tigerlily\', \'tulip\', \'WindFlower\']

# 创建图像生成器,指定对图像操作的内容,平移的最大比例为50%
train_datagen = ImageDataGenerator(width_shift_range=0.5, height_shift_range=0.5)

# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(array_to_img(X_batch[i]))
    pyplot.show()
    break

  生成结果为:

  可以观察到,图片除了实现平移外,其原来的位置都被最近的图案给填充,因为默认给的填充方式是nearest。

  随机亮度调整

  我们可用花类数据集对图片进行随机亮度调整,亮度范围由参数定义。

from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_img

IMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = \'/home/yourname/Documents/tensorflow/images/17flowerclasses/train\'
TEST_PATH = \'/home/yourname/Documents/tensorflow/images/17flowerclasses/test\'
FLOWER_CLASSES = [\'Bluebell\', \'ButterCup\', \'ColtsFoot\', \'Cowslip\', \'Crocus\', \'Daffodil\', \'Daisy\',
                  \'Dandelion\', \'Fritillary\', \'Iris\', \'LilyValley\', \'Pansy\', \'Snowdrop\', \'Sunflower\',
                  \'Tigerlily\', \'tulip\', \'WindFlower\']

# 创建图像生成器,指定对图像操作的内容,亮度范围在0.1~10之间随机选择
train_datagen = ImageDataGenerator(brightness_range=[0.1, 10])

# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(array_to_img(X_batch[i]))
    pyplot.show()
    break

  生成结果为:

  随机焦距调整

  我们可用mnist数据集对图片进行随机焦距调整,焦距调整值由参数定义。

from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K

K.set_image_dim_ordering(\'th\')

(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype(\'float32\')

# 创建图像生成器,指定对图像操作的内容,焦距值在0.1~1之间
datagen = ImageDataGenerator(zoom_range=[0.1, 1])
# 图像生成器要训练的数据
datagen.fit(train_data)

# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
    for i in range(0, 9):
        # 创建一个 3*3的九宫格,以显示图片
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap(\'gray\'))
    pyplot.show()
    break

  生成结果为:

  可以看出这跟相机调焦一样,可以放大或缩小焦距。

  随机翻转

  我们可用花类数据集对图片进行随机翻转。

from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_img

IMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = \'/home/hutao/Documents/tensorflow/images/17flowerclasses/train\'
TEST_PATH = \'/home/hutao/Documents/tensorflow/images/17flowerclasses/test\'
FLOWER_CLASSES = [\'Bluebell\', \'ButterCup\', \'ColtsFoot\', \'Cowslip\', \'Crocus\', \'Daffodil\', \'Daisy\',
                  \'Dandelion\', \'Fritillary\', \'Iris\', \'LilyValley\', \'Pansy\', \'Snowdrop\', \'Sunflower\',
                  \'Tigerlily\', \'tulip\', \'WindFlower\']

# 创建图像生成器,指定对图像操作的内容,图片随机翻转
train_datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)

# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(array_to_img(X_batch[i]))
    pyplot.show()
    break

  生成结果为:

  从上图可看出,有些图片水平翻转了,有些是垂直翻转了。

  ZCA图像增白

  说实在我不太清楚该技术有何用,用花类图片实验结果显示zca不支持,可以用mnist数据集来看看效果。

from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K

K.set_image_dim_ordering(\'th\')

(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype(\'float32\')

# 创建图像生成器,指定对图像操作的内容,增白图片
datagen = ImageDataGenerator(zca_whitening=True)
# 图像生成器要训练的数据
datagen.fit(train_data)

# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
    for i in range(0, 9):
        # 创建一个 3*3的九宫格,以显示图片
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap(\'gray\'))
    pyplot.show()
    break

  生成结果为:

  特征标准化

  特征标准化的含义是使图片的像素均值为0,标准差为1,不过我试了多次,直观效果不明显。

from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K

K.set_image_dim_ordering(\'th\')

(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype(\'float32\')

# 创建图像生成器,指定对图像操作的内容,允许图片标准化处理
datagen = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)
# 图像生成器要训练的数据
datagen.fit(train_data)

# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):
    for i in range(0, 9):
        # 创建一个 3*3的九宫格,以显示图片
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap(\'gray\'))
    pyplot.show()
    break

  生成结果为:

  就个人而言,我倾向于在图像增强中使用旋转、亮度调整、翻转和平移操作。

 

二、Keras如何进行图像增强数据训练

  在之前的文章中我已经展现过数据增强的使用。在Keras中,增强图片有三种来源:

  • 图片来源于已知数据集,如mnist、cifar,数据格式为numpy格式;
  • 图片来源于我们自己搜集的图片,如本文引入的花类数据集,其图片为jpg、png等格式;
  • 图片来源于panda数据集;

  其中数据来源已知数据集,其操作方法如下:

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

#生成器绑定训练集
datagen.fit(x_train)

# 模型绑定生成器,并不停地迭代产生数据,可指定迭代次数,假设图片总数为1000张,batch默认为32,则每次迭代需要产生1000/32=32个步骤
history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                    steps_per_epoch=len(x_train) / 32, epochs=epochs)

  数据来源图片集,其操作方法如下:

batch_size = 32
# 迭代50次
epochs = 50
# 依照模型规定,图片大小被设定为224
IMAGE_SIZE = 224
TRAIN_PATH = \'/home/yourname/Documents/tensorflow/images/17flowerclasses/train\'
TEST_PATH = \'/home/yourname/Documents/tensorflow/images/17flowerclasses/test\'
FLOWER_CLASSES = [\'Bluebell\', \'ButterCup\', \'ColtsFoot\', \'Cowslip\', \'Crocus\', \'Daffodil\', \'Daisy\',\'Dandelion\', \'Fritillary\', \'Iris\', \'LilyValley\', \'Pansy\', \'Snowdrop\', \'Sunflower\',\'Tigerlily\', \'tulip\', \'WindFlower\']

# 使用数据增强
train_datagen = ImageDataGenerator(rotation_range=90)
# 可指定输出图片大小,因为深度学习要求训练图片大小保持一致 train_generator
= train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE), classes=FLOWER_CLASSES) test_datagen = ImageDataGenerator() test_generator = test_datagen.flow_from_directory(directory=TEST_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE), classes=FLOWER_CLASSES) # 运行模型 history = model.fit_generator(train_generator, epochs=epochs, validation_data=test_generator)

   需要说明的是,这些增强图片都是在内存中实时批量迭代生成的,不是一次性被读入内存,这样可以极大地节约内存空间,加快处理速度。若想保留中间过程生成的增强图片,可以在上述方法中添加保存路径等参数,此处不再赘述。

 

三、结论

  本文介绍了如何在Keras中使用图像增强技术,对图片可以进行各种操作,以生成数倍于原图片的增强图片集。这些数据集可帮助我们有效地对抗过拟合问题,更好地生成理想的模型。


推荐阅读
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 本文介绍了贝叶斯垃圾邮件分类的机器学习代码,代码来源于https://www.cnblogs.com/huangyc/p/10327209.html,并对代码进行了简介。朴素贝叶斯分类器训练函数包括求p(Ci)和基于词汇表的p(w|Ci)。 ... [详细]
author-avatar
余夫子曰
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有