作者:筷子 | 来源:互联网 | 2023-09-01 23:12
首先:我是TensorFlow(版本2)的初学者。通过阅读,我学到很多东西。但是,我似乎找不到以下问题的答案。
我正在尝试建立一个模型,将图像分为三个标签。
正如您在下面的图表中看到的那样,我的培训专家还不错,但是验证准确性太低了。
据我了解,这可能是一个“过拟合”问题。
也许我先解释一下我要做什么:
我想使用图像作为输入。作为输出,我想接收零个或多个属于这些图像的标签(分类器)。
我期望这将是一件容易的事,因为输入图像很简单。 (只有两种颜色,并且只有0、1、2或3个可能的“标签”。
这是图像的一些例子。它们表示字段(由蓝色多边形界定)上行走的轨道(绿色):
可能的标签是:
- 十字架:(前两张图片):您可以清楚地看到绿线正在形成一个或多个“十字架”
- zig-zag :(第三张图片):不能完全确定这是否是英语中的正确术语,但我想您已经明白了;-)
- 行:绿线主要是平行线(无锯齿形或交叉形)
- 以上都不是(不知道这是否需要作为标签)
我正在使用以下模型:
batch_size = 128
epochs = 30
IMG_HEIGHT = 150
IMG_WIDTH = 150
model = Sequential([
Conv2D(16,3,padding='same',activation='relu',input_shape=(IMG_HEIGHT,IMG_WIDTH,3)),MaxPooling2D(),Dropout(0.2),Conv2D(32,activation='relu'),Conv2D(64,flatten(),Dense(512,Dense(1,activation='sigmoid')
])
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_15 (Conv2D) (None,150,16) 448
_________________________________________________________________
max_pooling2d_15 (MaxPooling (None,75,16) 0
_________________________________________________________________
dropout_10 (Dropout) (None,16) 0
_________________________________________________________________
conv2d_16 (Conv2D) (None,32) 4640
_________________________________________________________________
max_pooling2d_16 (MaxPooling (None,37,32) 0
_________________________________________________________________
conv2d_17 (Conv2D) (None,64) 18496
_________________________________________________________________
max_pooling2d_17 (MaxPooling (None,18,64) 0
_________________________________________________________________
dropout_11 (Dropout) (None,64) 0
_________________________________________________________________
flatten_5 (flatten) (None,20736) 0
_________________________________________________________________
dense_10 (Dense) (None,512) 10617344
_________________________________________________________________
dense_11 (Dense) (None,3) 1539
=================================================================
Total params: 10,642,467
Trainable params: 10,467
Non-trainable params: 0
我将3360张图像用作训练数据集,将496张用作验证数据集。
这些已经“增强”,因此这些集合包含其他现有图像的已经旋转和镜像的版本。
也许值得一提的是数据集是不平衡的:80%的图像的确包含标签“ cross”,而其他20%的图像则包含“ zig-zag”和“ rows”。
任何人都可以指导我正确的方向,如何改善我的模型?
您希望网络输出3个可能的标签,以便模型中的最后一层可以执行此操作。实际上,您可以将其更改为Dense(3,activation='sigmoid')
。
我不知道为什么在培训期间它不会给您带来任何错误,但是您还应该检查将输入和标签馈送到网络的方式。