热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

如何改进模型以防止过度拟合,从而实现非常简单的图像分类

首先:我是TensorFlow(版本2)的初学者。通过阅读,我学到很多东西。但是,我似

首先:我是TensorFlow(版本2)的初学者。通过阅读,我学到很多东西。但是,我似乎找不到以下问题的答案。

我正在尝试建立一个模型,将图像分为三个标签。
正如您在下面的图表中看到的那样,我的培训专家还不错,但是验证准确性太低了。

如何改进模型以防止过度拟合,从而实现非常简单的图像分类

据我了解,这可能是一个“过拟合”问题。

也许我先解释一下我要做什么:

我想使用图像作为输入。作为输出,我想接收零个或多个属于这些图像的标签(分类器)。
我期望这将是一件容易的事,因为输入图像很简单。 (只有两种颜色,并且只有0、1、2或3个可能的“标签”。
这是图像的一些例子。它们表示字段(由蓝色多边形界定)上行走的轨道(绿色):

如何改进模型以防止过度拟合,从而实现非常简单的图像分类

可能的标签是:


  1. 十字架:(前两张图片):您可以清楚地看到绿线正在形成一个或多个“十字架”

  2. zig-zag :(第三张图片):不能完全确定这是否是英语中的正确术语,但我想您已经明白了;-)

  3. 行:绿线主要是平行线(无锯齿形或交叉形)

  4. 以上都不是(不知道这是否需要作为标签)

我正在使用以下模型:

batch_size = 128
epochs = 30
IMG_HEIGHT = 150
IMG_WIDTH = 150
model = Sequential([
Conv2D(16,3,padding='same',activation='relu',input_shape=(IMG_HEIGHT,IMG_WIDTH,3)),MaxPooling2D(),Dropout(0.2),Conv2D(32,activation='relu'),Conv2D(64,flatten(),Dense(512,Dense(1,activation='sigmoid')
])
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_15 (Conv2D) (None,150,16) 448
_________________________________________________________________
max_pooling2d_15 (MaxPooling (None,75,16) 0
_________________________________________________________________
dropout_10 (Dropout) (None,16) 0
_________________________________________________________________
conv2d_16 (Conv2D) (None,32) 4640
_________________________________________________________________
max_pooling2d_16 (MaxPooling (None,37,32) 0
_________________________________________________________________
conv2d_17 (Conv2D) (None,64) 18496
_________________________________________________________________
max_pooling2d_17 (MaxPooling (None,18,64) 0
_________________________________________________________________
dropout_11 (Dropout) (None,64) 0
_________________________________________________________________
flatten_5 (flatten) (None,20736) 0
_________________________________________________________________
dense_10 (Dense) (None,512) 10617344
_________________________________________________________________
dense_11 (Dense) (None,3) 1539
=================================================================
Total params: 10,642,467
Trainable params: 10,467
Non-trainable params: 0

我将3360张图像用作训练数据集,将496张用作验证数据集。
这些已经“增强”,因此这些集合包含其他现有图像的已经旋转和镜像的版本。

也许值得一提的是数据集是不平衡的:80%的图像的确包含标签“ cross”,而其他20%的图像则包含“ zig-zag”和“ rows”。

任何人都可以指导我正确的方向,如何改善我的模型?


您希望网络输出3个可能的标签,以便模型中的最后一层可以执行此操作。实际上,您可以将其更改为Dense(3,activation='sigmoid')

我不知道为什么在培训期间它不会给您带来任何错误,但是您还应该检查将输入和标签馈送到网络的方式。


推荐阅读
author-avatar
筷子
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有