热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于Keras实现的卷积神经网络(CNN)示例

本文介绍了一个使用Keras框架构建的卷积神经网络(CNN)实例,主要利用了Keras提供的MNIST数据集以及相关的层,如Dense、Dropout、Activation等,构建了一个具有两层卷积和两层全连接层的CNN模型。
### 基于Keras的CNN实现

本文展示了一个使用Keras框架构建的卷积神经网络(CNN)实例。该模型主要用于处理图像分类任务,特别是手写数字识别。通过使用Keras内置的MNIST数据集,我们构建了一个包含两层卷积层和两层全连接层的简单CNN模型。经过12轮训练后,模型的准确率达到了99.25%,展示了CNN在图像识别领域的强大能力。

```python
import numpy as np
np.random.seed(1337)
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12
img_rows, img_cols = 28, 28
filters = 32
pool_size = (2, 2)
kernel_size = (3, 3)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(filters, kernel_size, padding='same', input_shape=input_shape))
model.add(Activation('relu'))
model.add(Conv2D(filters, kernel_size))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=pool_size))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])
```

### Keras关键组件详解

#### 1. Keras的Backend支持
Keras支持多种后端,包括Theano和TensorFlow。不同的后端在数据格式上有所差异,例如Theano模式下,图像数据的形状为(样本数, 通道数, 高度, 宽度),而TensorFlow模式下则为(样本数, 高度, 宽度, 通道数)。这些差异在代码中通过`K.image_data_format()`来自动适应。

#### 2. 二维卷积层(Conv2D)
`Conv2D`层用于对二维输入进行卷积操作。主要参数包括:
- `filters`: 卷积核的数量。
- `kernel_size`: 卷积核的大小。
- `strides`: 卷积核的步长。
- `padding`: 边界填充方式,常见的有`'valid'`和`'same'`。

#### 3. 二维池化层(MaxPooling2D)
`MaxPooling2D`层用于对输入进行最大池化操作,减少数据的维度。主要参数包括:
- `pool_size`: 池化窗口的大小。
- `strides`: 池化窗口的步长。
- `padding`: 边界填充方式。

#### 4. 激活层(Activation)
`Activation`层用于对前一层的输出应用激活函数,常见的激活函数有`relu`、`sigmoid`、`tanh`等。

#### 5. Dropout层
`Dropout`层用于在训练过程中随机丢弃一部分神经元,以防止过拟合。主要参数包括:
- `rate`: 丢弃神经元的比例。

#### 6. Flatten层
`Flatten`层用于将多维输入展平为一维,常用于从卷积层到全连接层的过渡。

#### 7. 全连接层(Dense)
`Dense`层是一个全连接层,用于处理特征的线性组合。主要参数包括:
- `units`: 输出单元的数量。
- `activation`: 激活函数。

#### 8. 编译模型(compile)
`compile`方法用于配置模型的学习过程,主要参数包括:
- `optimizer`: 优化器。
- `loss`: 损失函数。
- `metrics`: 评估指标。

#### 9. 训练模型(fit)
`fit`方法用于训练模型,主要参数包括:
- `x`: 输入数据。
- `y`: 标签数据。
- `batch_size`: 批次大小。
- `epochs`: 训练轮数。
- `validation_data`: 验证数据。

#### 10. 评估模型(evaluate)
`evaluate`方法用于评估模型在测试集上的表现,主要参数与`fit`方法类似。

更多详细信息请参考Keras官方文档:
- 中文版:http://keras-cn.readthedocs.io/en/latest/
- 英文版:https://keras.io/
推荐阅读
  • 本文详细介绍了Java中org.w3c.dom.Text类的splitText()方法,通过多个代码示例展示了其实际应用。该方法用于将文本节点在指定位置拆分为两个节点,并保持在文档树中。 ... [详细]
  • 本文介绍了Java并发库中的阻塞队列(BlockingQueue)及其典型应用场景。通过具体实例,展示了如何利用LinkedBlockingQueue实现线程间高效、安全的数据传递,并结合线程池和原子类优化性能。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • PHP 过滤器详解
    本文深入探讨了 PHP 中的过滤器机制,包括常见的 $_SERVER 变量、filter_has_var() 函数、filter_id() 函数、filter_input() 函数及其数组形式、filter_list() 函数以及 filter_var() 和其数组形式。同时,详细介绍了各种过滤器的用途和用法。 ... [详细]
  • 本教程详细介绍了如何使用 TensorFlow 2.0 构建和训练多层感知机(MLP)网络,涵盖回归和分类任务。通过具体示例和代码实现,帮助初学者快速掌握 TensorFlow 的核心概念和操作。 ... [详细]
  • Coursera ML 机器学习
    2019独角兽企业重金招聘Python工程师标准线性回归算法计算过程CostFunction梯度下降算法多变量回归![选择特征](https:static.oschina.n ... [详细]
  • 深入解析Spring启动过程
    本文详细介绍了Spring框架的启动流程,帮助开发者理解其内部机制。通过具体示例和代码片段,解释了Bean定义、工厂类、读取器以及条件评估等关键概念,使读者能够更全面地掌握Spring的初始化过程。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 深入理解Java中的volatile、内存屏障与CPU指令
    本文详细探讨了Java中volatile关键字的作用机制,以及其与内存屏障和CPU指令之间的关系。通过具体示例和专业解析,帮助读者更好地理解多线程编程中的同步问题。 ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文介绍如何使用阿里云的fastjson库解析包含时间戳、IP地址和参数等信息的JSON格式文本,并进行数据处理和保存。 ... [详细]
  • Scala 实现 UTF-8 编码属性文件读取与克隆
    本文介绍如何使用 Scala 以 UTF-8 编码方式读取属性文件,并实现属性文件的克隆功能。通过这种方式,可以确保配置文件在多线程环境下的一致性和高效性。 ... [详细]
  • 深入解析 Spring Security 用户认证机制
    本文将详细介绍 Spring Security 中用户登录认证的核心流程,重点分析 AbstractAuthenticationProcessingFilter 和 AuthenticationManager 的工作原理。通过理解这些组件的实现,读者可以更好地掌握 Spring Security 的认证机制。 ... [详细]
author-avatar
微笑5885
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有