热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow2.0中的Keras数据归一化实践

数据预处理是机器学习任务中的关键步骤,特别是在深度学习领域。通过将数据归一化至特定范围,可以在梯度下降过程中实现更快的收敛速度和更高的模型性能。本文探讨了如何使用TensorFlow2.0和Keras进行有效的数据归一化。

在机器学习项目中,数据预处理是确保模型能够有效学习的重要环节。特别是对于深度学习模型而言,数据的归一化处理尤为重要,因为它能帮助模型更快地收敛,并提高最终的预测准确性。本文将介绍如何使用 TensorFlow 2.0 和 Keras 来实现这一过程。



在深度学习中,数据归一化是指将不同量级的数据调整到同一尺度,从而避免某些特征因为数值较大而主导模型的学习过程。这种做法有助于保持所有特征在训练过程中的相对重要性,进而提升模型的整体表现。



数据归一化的数学表达式通常为:
数据归一化公式



实战演练:使用 TensorFlow 2.0 和 Keras 进行数据归一化



为了更好地理解数据归一化的过程,我们将通过一个具体的例子来演示如何在 TensorFlow 2.0 中使用 Keras 对 Fashion MNIST 数据集进行归一化处理。以下是详细的代码示例:



首先,我们需要导入必要的库:



import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(sys.version_info)

for module in mpl, np, pd, sklearn, tf, keras:
print(module.__name__, module.__version__)



接下来,加载并分割数据集:



fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()

x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)



检查原始数据的最大值和最小值:



print(np.max(x_train), np.min(x_train))



输出结果显示,训练集中的像素值范围为 0 到 255。为了使这些值更适合神经网络的输入,我们对其进行归一化处理:



from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(
x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)



在这个过程中,StandardScaler 类用于计算训练数据的均值和标准差,并据此对数据进行缩放。值得注意的是,训练数据使用 fit_transform 方法,而验证和测试数据则仅使用 transform 方法,以确保模型不会受到未见数据的影响。



归一化后,再次检查数据的最大值和最小值:



print(np.max(x_train_scaled), np.min(x_train_scaled))



接下来,构建并训练一个简单的多层感知器模型:



model = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(300, activation='relu'),
keras.layers.Dense(100, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])

model.compile(loss='sparse_categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])

history = model.fit(x_train_scaled, y_train, epochs=10,
validation_data=(x_valid_scaled, y_valid))



训练过程中的日志显示,随着训练的进行,模型的准确率逐渐提高,验证集上的表现也有所改善。这表明数据归一化确实有助于提升模型性能。



最后,可以通过绘制学习曲线来直观地观察模型的训练情况:



def plot_learning_curves(history):
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()

plot_learning_curves(history)



此外,还可以对测试集进行评估,以了解模型在未见过的数据上的表现:



model.evaluate(x_test_scaled, y_test)



评估结果显示,模型在测试集上的准确率为 0.8825,进一步验证了数据归一化对模型性能的积极影响。


推荐阅读
  • 本文探讨了如何在TensorFlow中使用张量来处理和分析数字图像,特别是通过具体的代码示例展示了张量在图像处理中的作用。 ... [详细]
  • 本文针对公司项目中普遍存在的IE浏览器兼容性问题,特别是IE9及以下版本,提出了具体的解决方案,确保用户在这些旧版浏览器中也能顺利实现图片上传预览功能。 ... [详细]
  • 本文探讨了使用匈牙利算法解决二分图中的最大权匹配问题,并通过HDU1533题目实例进行详细解析。代码实现中包括了必要的数据结构定义、输入处理以及求解过程。 ... [详细]
  • 转载网址:http:www.open-open.comlibviewopen1326597582452.html参考资料:http:www.cocos2d-ip ... [详细]
  • 本文将详细介绍如何使用ViewPager实现多页面滑动切换,并探讨如何去掉其默认的左右切换动画效果。ViewPager是Android开发中常用的组件之一,用于实现屏幕间的内容切换。 ... [详细]
  • 图神经网络模型综述
    本文综述了图神经网络(Graph Neural Networks, GNN)的发展,从传统的数据存储模型转向图和动态模型,探讨了模型中的显性和隐性结构,并详细介绍了GNN的关键组件及其应用。 ... [详细]
  • 本文详细介绍了在Hive中创建表的基本语法,包括临时表、外部表的创建方法,以及如何设置表的各种属性和约束条件。 ... [详细]
  • 框图|中将_DA14531 学习笔记经验总结
    框图|中将_DA14531 学习笔记经验总结 ... [详细]
  • 本文探讨如何通过贪心算法有效地安排一系列活动,确保使用最少数量的会场来完成所有活动的调度。 ... [详细]
  • 一、数据更新操作DML语法中主要包括两个内容:查询与更新,更新主要包括:增加数据、修改数据、删除数据。其中这些操作是离不开查询的。1、增加数据语法:INSERTINTO表名称[(字 ... [详细]
  • HTML中用于创建表单的标签是什么
    本文将详细介绍HTML中用于创建表单的标签及其基本用法,包括表单的主要特性和常用的属性设置。如果您正在学习HTML或需要了解如何在网页中添加表单,这将是一个很好的起点。 ... [详细]
  • 统计报表模板及其实现方法
    本文介绍两个实用的统计报表模板,并提供如何将这些静态模板转换为动态JSP页面的方法。同时,文中附上了详细的代码示例。 ... [详细]
  • 本文详细介绍了HTML5中的文件操作API,包括FileList、Blob、File和FileReader等重要JavaScript对象的接口定义及其功能特性。 ... [详细]
  • TensorFlow核心函数解析与应用
    本文详细介绍了TensorFlow中几个常用的基础函数及其应用场景,包括常量创建、张量扩展以及二维卷积操作等,旨在帮助开发者更好地理解和使用这些功能。 ... [详细]
  • Python图像处理库概览
    本文详细介绍了Python中常用的图像处理库,包括scikit-image、Numpy、Scipy、Pillow、OpenCV-Python、SimpleCV、Mahotas、SimpleITK、pgmagick和Pycairo,旨在帮助开发者和研究人员选择合适的工具进行图像处理任务。 ... [详细]
author-avatar
dmcm0001
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有