热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow2.0中的Keras数据归一化实践

数据预处理是机器学习任务中的关键步骤,特别是在深度学习领域。通过将数据归一化至特定范围,可以在梯度下降过程中实现更快的收敛速度和更高的模型性能。本文探讨了如何使用TensorFlow2.0和Keras进行有效的数据归一化。

在机器学习项目中,数据预处理是确保模型能够有效学习的重要环节。特别是对于深度学习模型而言,数据的归一化处理尤为重要,因为它能帮助模型更快地收敛,并提高最终的预测准确性。本文将介绍如何使用 TensorFlow 2.0 和 Keras 来实现这一过程。



在深度学习中,数据归一化是指将不同量级的数据调整到同一尺度,从而避免某些特征因为数值较大而主导模型的学习过程。这种做法有助于保持所有特征在训练过程中的相对重要性,进而提升模型的整体表现。



数据归一化的数学表达式通常为:
数据归一化公式



实战演练:使用 TensorFlow 2.0 和 Keras 进行数据归一化



为了更好地理解数据归一化的过程,我们将通过一个具体的例子来演示如何在 TensorFlow 2.0 中使用 Keras 对 Fashion MNIST 数据集进行归一化处理。以下是详细的代码示例:



首先,我们需要导入必要的库:



import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(sys.version_info)

for module in mpl, np, pd, sklearn, tf, keras:
print(module.__name__, module.__version__)



接下来,加载并分割数据集:



fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()

x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)



检查原始数据的最大值和最小值:



print(np.max(x_train), np.min(x_train))



输出结果显示,训练集中的像素值范围为 0 到 255。为了使这些值更适合神经网络的输入,我们对其进行归一化处理:



from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(
x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)



在这个过程中,StandardScaler 类用于计算训练数据的均值和标准差,并据此对数据进行缩放。值得注意的是,训练数据使用 fit_transform 方法,而验证和测试数据则仅使用 transform 方法,以确保模型不会受到未见数据的影响。



归一化后,再次检查数据的最大值和最小值:



print(np.max(x_train_scaled), np.min(x_train_scaled))



接下来,构建并训练一个简单的多层感知器模型:



model = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(300, activation='relu'),
keras.layers.Dense(100, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])

model.compile(loss='sparse_categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])

history = model.fit(x_train_scaled, y_train, epochs=10,
validation_data=(x_valid_scaled, y_valid))



训练过程中的日志显示,随着训练的进行,模型的准确率逐渐提高,验证集上的表现也有所改善。这表明数据归一化确实有助于提升模型性能。



最后,可以通过绘制学习曲线来直观地观察模型的训练情况:



def plot_learning_curves(history):
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()

plot_learning_curves(history)



此外,还可以对测试集进行评估,以了解模型在未见过的数据上的表现:



model.evaluate(x_test_scaled, y_test)



评估结果显示,模型在测试集上的准确率为 0.8825,进一步验证了数据归一化对模型性能的积极影响。


推荐阅读
  • 本文探讨了图像标签的多种分类场景及其在以图搜图技术中的应用,涵盖了从基础理论到实际项目实施的全面解析。 ... [详细]
  • 本文详细介绍了 TensorFlow 的入门实践,特别是使用 MNIST 数据集进行数字识别的项目。文章首先解析了项目文件结构,并解释了各部分的作用,随后逐步讲解了如何通过 TensorFlow 实现基本的神经网络模型。 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • 本文介绍了一种根据目标检测结果,从原始XML文件中提取并分析特定类别的方法。通过解析XML文件,筛选出特定类别的图像和标注信息,并保存到新的文件夹中,以便进一步分析和处理。 ... [详细]
  • Keras 实战:自编码器入门指南
    本文介绍了使用 Keras 框架实现自编码器的基本方法。自编码器是一种用于无监督学习的神经网络模型,主要功能包括数据降维、特征提取等。通过实际案例,我们将展示如何使用全连接层和卷积层来构建自编码器,并讨论不同维度对重建效果的影响。 ... [详细]
  • 本文详细介绍了使用NumPy和TensorFlow实现的逻辑回归算法。通过具体代码示例,解释了数据加载、模型训练及分类预测的过程。 ... [详细]
  • 在Ubuntu 16.04中使用Anaconda安装TensorFlow
    本文详细介绍了如何在Ubuntu 16.04系统上通过Anaconda环境管理工具安装TensorFlow。首先,需要下载并安装Anaconda,然后配置环境变量以确保系统能够识别Anaconda命令。接着,创建一个特定的Python环境用于安装TensorFlow,并通过指定的镜像源加速安装过程。最后,通过一个简单的线性回归示例验证TensorFlow的安装是否成功。 ... [详细]
  • 本文详细介绍如何通过Anaconda 3.5.01快速安装TensorFlow,包括环境配置和具体步骤。 ... [详细]
  • 深入解析JVM垃圾收集器
    本文基于《深入理解Java虚拟机:JVM高级特性与最佳实践》第二版,详细探讨了JVM中不同类型的垃圾收集器及其工作原理。通过介绍各种垃圾收集器的特性和应用场景,帮助读者更好地理解和优化JVM内存管理。 ... [详细]
  • 本文将介绍如何使用 Go 语言编写和运行一个简单的“Hello, World!”程序。内容涵盖开发环境配置、代码结构解析及执行步骤。 ... [详细]
  • 本文探讨了 Objective-C 中的一些重要语法特性,包括 goto 语句、块(block)的使用、访问修饰符以及属性管理等。通过实例代码和详细解释,帮助开发者更好地理解和应用这些特性。 ... [详细]
  • 2017年人工智能领域的十大里程碑事件回顾
    随着2018年的临近,我们一同回顾过去一年中人工智能领域的重要进展。这一年,无论是政策层面的支持,还是技术上的突破,都显示了人工智能发展的迅猛势头。以下是精选的2017年人工智能领域最具影响力的事件。 ... [详细]
author-avatar
dmcm0001
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有