热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow2.0中的Keras数据归一化实践

数据预处理是机器学习任务中的关键步骤,特别是在深度学习领域。通过将数据归一化至特定范围,可以在梯度下降过程中实现更快的收敛速度和更高的模型性能。本文探讨了如何使用TensorFlow2.0和Keras进行有效的数据归一化。

在机器学习项目中,数据预处理是确保模型能够有效学习的重要环节。特别是对于深度学习模型而言,数据的归一化处理尤为重要,因为它能帮助模型更快地收敛,并提高最终的预测准确性。本文将介绍如何使用 TensorFlow 2.0 和 Keras 来实现这一过程。



在深度学习中,数据归一化是指将不同量级的数据调整到同一尺度,从而避免某些特征因为数值较大而主导模型的学习过程。这种做法有助于保持所有特征在训练过程中的相对重要性,进而提升模型的整体表现。



数据归一化的数学表达式通常为:
数据归一化公式



实战演练:使用 TensorFlow 2.0 和 Keras 进行数据归一化



为了更好地理解数据归一化的过程,我们将通过一个具体的例子来演示如何在 TensorFlow 2.0 中使用 Keras 对 Fashion MNIST 数据集进行归一化处理。以下是详细的代码示例:



首先,我们需要导入必要的库:



import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(sys.version_info)

for module in mpl, np, pd, sklearn, tf, keras:
print(module.__name__, module.__version__)



接下来,加载并分割数据集:



fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()

x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)



检查原始数据的最大值和最小值:



print(np.max(x_train), np.min(x_train))



输出结果显示,训练集中的像素值范围为 0 到 255。为了使这些值更适合神经网络的输入,我们对其进行归一化处理:



from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(
x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)



在这个过程中,StandardScaler 类用于计算训练数据的均值和标准差,并据此对数据进行缩放。值得注意的是,训练数据使用 fit_transform 方法,而验证和测试数据则仅使用 transform 方法,以确保模型不会受到未见数据的影响。



归一化后,再次检查数据的最大值和最小值:



print(np.max(x_train_scaled), np.min(x_train_scaled))



接下来,构建并训练一个简单的多层感知器模型:



model = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(300, activation='relu'),
keras.layers.Dense(100, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])

model.compile(loss='sparse_categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])

history = model.fit(x_train_scaled, y_train, epochs=10,
validation_data=(x_valid_scaled, y_valid))



训练过程中的日志显示,随着训练的进行,模型的准确率逐渐提高,验证集上的表现也有所改善。这表明数据归一化确实有助于提升模型性能。



最后,可以通过绘制学习曲线来直观地观察模型的训练情况:



def plot_learning_curves(history):
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()

plot_learning_curves(history)



此外,还可以对测试集进行评估,以了解模型在未见过的数据上的表现:



model.evaluate(x_test_scaled, y_test)



评估结果显示,模型在测试集上的准确率为 0.8825,进一步验证了数据归一化对模型性能的积极影响。


推荐阅读
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • XNA 3.0 游戏编程:从 XML 文件加载数据
    本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ... [详细]
  • MongoDB集群配置:副本集与分片详解
    本文详细介绍了如何在MongoDB中配置副本集(Replica Sets)和分片(Sharding),并提供了具体的步骤和命令,帮助读者理解并实现高可用性和水平扩展的MongoDB集群。 ... [详细]
  • 优化局域网SSH连接延迟问题的解决方案
    本文介绍了解决局域网内SSH连接到服务器时出现长时间等待问题的方法。通过调整配置和优化网络设置,可以显著缩短SSH连接的时间。 ... [详细]
  • Keras 实战:自编码器入门指南
    本文介绍了使用 Keras 框架实现自编码器的基本方法。自编码器是一种用于无监督学习的神经网络模型,主要功能包括数据降维、特征提取等。通过实际案例,我们将展示如何使用全连接层和卷积层来构建自编码器,并讨论不同维度对重建效果的影响。 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • 资源推荐 | TensorFlow官方中文教程助力英语非母语者学习
    来源:机器之心。本文详细介绍了TensorFlow官方提供的中文版教程和指南,帮助开发者更好地理解和应用这一强大的开源机器学习平台。 ... [详细]
  • DNN Community 和 Professional 版本的主要差异
    本文详细解析了 DotNetNuke (DNN) 的两种主要版本:Community 和 Professional。通过对比两者的功能和附加组件,帮助用户选择最适合其需求的版本。 ... [详细]
  • 本文深入探讨了Linux系统中网卡绑定(bonding)的七种工作模式。网卡绑定技术通过将多个物理网卡组合成一个逻辑网卡,实现网络冗余、带宽聚合和负载均衡,在生产环境中广泛应用。文章详细介绍了每种模式的特点、适用场景及配置方法。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 使用Python在SAE上开发新浪微博应用的初步探索
    最近重新审视了新浪云平台(SAE)提供的服务,发现其已支持Python开发。本文将详细介绍如何利用Django框架构建一个简单的新浪微博应用,并分享开发过程中的关键步骤。 ... [详细]
  • 实体映射最强工具类:MapStruct真香 ... [详细]
  • 本文介绍如何在Linux Mint系统上搭建Rust开发环境,包括安装IntelliJ IDEA、Rust工具链及必要的插件。通过详细步骤,帮助开发者快速上手。 ... [详细]
author-avatar
dmcm0001
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有