基于Keras实现的卷积神经网络(CNN)示例
作者:微笑5885 | 来源:互联网 | 2024-12-03 19:35
本文介绍了一个使用Keras框架构建的卷积神经网络(CNN)实例,主要利用了Keras提供的MNIST数据集以及相关的层,如Dense、Dropout、Activation等,构建了一个具有两层卷积和两层全连接层的CNN模型。
### 基于Keras的CNN实现 本文展示了一个使用Keras框架构建的卷积神经网络(CNN)实例。该模型主要用于处理图像分类任务,特别是手写数字识别。通过使用Keras内置的MNIST数据集,我们构建了一个包含两层卷积层和两层全连接层的简单CNN模型。经过12轮训练后,模型的准确率达到了99.25%,展示了CNN在图像识别领域的强大能力。 ```python import numpy as np np.random.seed(1337) from keras.datasets import mnist from keras.models import Sequential from keras.layers import Dense, Dropout, Activation, Flatten from keras.layers import Conv2D, MaxPooling2D from keras.utils import np_utils from keras import backend as K batch_size = 128 num_classes = 10 epochs = 12 img_rows, img_cols = 28, 28 filters = 32 pool_size = (2, 2) kernel_size = (3, 3) (x_train, y_train), (x_test, y_test) = mnist.load_data() if K.image_data_format() == 'channels_first': x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) input_shape = (1, img_rows, img_cols) else: x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) input_shape = (img_rows, img_cols, 1) x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 y_train = np_utils.to_categorical(y_train, num_classes) y_test = np_utils.to_categorical(y_test, num_classes) model = Sequential() model.add(Conv2D(filters, kernel_size, padding='same', input_shape=input_shape)) model.add(Activation('relu')) model.add(Conv2D(filters, kernel_size)) model.add(Activation('relu')) model.add(MaxPooling2D(pool_size=pool_size)) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(128)) model.add(Activation('relu')) model.add(Dropout(0.5)) model.add(Dense(num_classes)) model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy']) model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test)) score = model.evaluate(x_test, y_test, verbose=0) print('Test score:', score[0]) print('Test accuracy:', score[1]) ``` ### Keras关键组件详解 #### 1. Keras的Backend支持 Keras支持多种后端,包括Theano和TensorFlow。不同的后端在数据格式上有所差异,例如Theano模式下,图像数据的形状为(样本数, 通道数, 高度, 宽度),而TensorFlow模式下则为(样本数, 高度, 宽度, 通道数)。这些差异在代码中通过`K.image_data_format()`来自动适应。 #### 2. 二维卷积层(Conv2D) `Conv2D`层用于对二维输入进行卷积操作。主要参数包括: - `filters`: 卷积核的数量。 - `kernel_size`: 卷积核的大小。 - `strides`: 卷积核的步长。 - `padding`: 边界填充方式,常见的有`'valid'`和`'same'`。 #### 3. 二维池化层(MaxPooling2D) `MaxPooling2D`层用于对输入进行最大池化操作,减少数据的维度。主要参数包括: - `pool_size`: 池化窗口的大小。 - `strides`: 池化窗口的步长。 - `padding`: 边界填充方式。 #### 4. 激活层(Activation) `Activation`层用于对前一层的输出应用激活函数,常见的激活函数有`relu`、`sigmoid`、`tanh`等。 #### 5. Dropout层 `Dropout`层用于在训练过程中随机丢弃一部分神经元,以防止过拟合。主要参数包括: - `rate`: 丢弃神经元的比例。 #### 6. Flatten层 `Flatten`层用于将多维输入展平为一维,常用于从卷积层到全连接层的过渡。 #### 7. 全连接层(Dense) `Dense`层是一个全连接层,用于处理特征的线性组合。主要参数包括: - `units`: 输出单元的数量。 - `activation`: 激活函数。 #### 8. 编译模型(compile) `compile`方法用于配置模型的学习过程,主要参数包括: - `optimizer`: 优化器。 - `loss`: 损失函数。 - `metrics`: 评估指标。 #### 9. 训练模型(fit) `fit`方法用于训练模型,主要参数包括: - `x`: 输入数据。 - `y`: 标签数据。 - `batch_size`: 批次大小。 - `epochs`: 训练轮数。 - `validation_data`: 验证数据。 #### 10. 评估模型(evaluate) `evaluate`方法用于评估模型在测试集上的表现,主要参数与`fit`方法类似。 更多详细信息请参考Keras官方文档: - 中文版:http://keras-cn.readthedocs.io/en/latest/ - 英文版:https://keras.io/
推荐阅读
前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ...
[详细]
蜡笔小新 2024-12-27 15:19:01
本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ...
[详细]
蜡笔小新 2024-12-28 09:46:23
1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ...
[详细]
蜡笔小新 2024-12-27 18:36:54
本文详细介绍了如何使用 Yii2 的 GridView 组件在列表页面实现数据的直接编辑功能。通过具体的代码示例和步骤,帮助开发者快速掌握这一实用技巧。 ...
[详细]
蜡笔小新 2024-12-27 16:27:52
本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ...
[详细]
蜡笔小新 2024-12-27 15:06:12
本章将深入探讨移动 UI 设计的核心原则,帮助开发者构建简洁、高效且用户友好的界面。通过学习设计规则和用户体验优化技巧,您将能够创建出既美观又实用的移动应用。 ...
[详细]
蜡笔小新 2024-12-27 08:43:40
学习链接:http:blog.csdn.netlwt36articledetails48908031学习扫描线主要学习的是一种扫描的思想,后期可以求解很 ...
[详细]
蜡笔小新 2024-12-26 20:04:36
本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ...
[详细]
蜡笔小新 2024-12-25 17:38:50
本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ...
[详细]
蜡笔小新 2024-12-28 04:11:47
本文详细介绍了如何在Linux系统上安装和配置Smokeping,以实现对网络链路质量的实时监控。通过详细的步骤和必要的依赖包安装,确保用户能够顺利完成部署并优化其网络性能监控。 ...
[详细]
蜡笔小新 2024-12-27 19:31:05
本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ...
[详细]
蜡笔小新 2024-12-27 13:10:20
本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ...
[详细]
蜡笔小新 2024-12-27 11:39:44
本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ...
[详细]
蜡笔小新 2024-12-26 22:04:19
梯度方向指示函数值增加的方向,由各轴方向的偏导数综合而成,其模长表示函数值变化的速率。本文详细探讨了导数、偏导数、梯度等概念,并结合Softmax函数、卷积神经网络(CNN)中的卷积计算、权值共享及池化操作进行了深入分析。 ...
[详细]
蜡笔小新 2024-12-26 18:23:11
本文介绍如何使用 Scala 以 UTF-8 编码方式读取属性文件,并实现属性文件的克隆功能。通过这种方式,可以确保配置文件在多线程环境下的一致性和高效性。 ...
[详细]
蜡笔小新 2024-12-26 08:25:19