热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow学习笔记(2):多元线性回归

前言本文使用tensorflow训练多元线性回归模型,并将其与scikit-learn做比较。数据集来自AndrewNg的网上公开课程DeepLearning代码#!usrbine

前言

本文使用tensorflow训练多元线性回归模型,并将其与scikit-learn做比较。数据集来自Andrew Ng的网上公开课程Deep Learning

代码

#!/usr/bin/env python
# -*- coding=utf-8 -*-
# @author: 陈水平
# @date: 2016-12-30
# @description: compare multi linear regression of tensor flow to scikit-learn based on data from deep learning cource of Andrew Ng
# @ref: http://openclassroom.stanford.edu/MainFolder/DocumentPage.php?course=DeepLearning&doc=exercises/ex3/ex3.html
#
import numpy as np
import tensorflow as tf
from sklearn import linear_model
from sklearn import preprocessing
# Read x and y
x_data = np.loadtxt("ex3x.dat").astype(np.float32)
y_data = np.loadtxt("ex3y.dat").astype(np.float32)
# We evaluate the x and y by sklearn to get a sense of the coefficients.
reg = linear_model.LinearRegression()
reg.fit(x_data, y_data)
print "Coefficients of sklearn: K=%s, b=%f" % (reg.coef_, reg.intercept_)
# Now we use tensorflow to get similar results.
# Before we put the x_data into tensorflow, we need to standardize it
# in order to achieve better performance in gradient descent;
# If not standardized, the convergency speed could not be tolearated.
# Reason: If a feature has a variance that is orders of magnitude larger than others,
# it might dominate the objective function
# and make the estimator unable to learn from other features correctly as expected.
scaler = preprocessing.StandardScaler().fit(x_data)
print scaler.mean_, scaler.scale_
x_data_standard = scaler.transform(x_data)
W = tf.Variable(tf.zeros([2, 1]))
b = tf.Variable(tf.zeros([1, 1]))
y = tf.matmul(x_data_standard, W) + b
loss = tf.reduce_mean(tf.square(y - y_data.reshape(-1, 1)))/2
optimizer = tf.train.GradientDescentOptimizer(0.3)
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for step in range(100):
sess.run(train)
if step % 10 == 0:
print step, sess.run(W).flatten(), sess.run(b).flatten()
print "Coefficients of tensorflow (input should be standardized): K=%s, b=%s" % (sess.run(W).flatten(), sess.run(b).flatten())
print "Coefficients of tensorflow (raw input): K=%s, b=%s" % (sess.run(W).flatten() / scaler.scale_, sess.run(b).flatten() - np.dot(scaler.mean_ / scaler.scale_, sess.run(W)))

输出如下:

Coefficients of sklearn: K=[ 139.21066284 -8738.02148438], b=89597.927966
[ 2000.6809082 3.17021275] [ 7.86202576e+02 7.52842903e-01]
0 [ 31729.23632812 16412.6484375 ] [ 102123.7890625]
10 [ 97174.78125 5595.25585938] [ 333681.59375]
20 [ 106480.5703125 -3611.31201172] [ 340222.53125]
30 [ 108727.5390625 -5858.10302734] [ 340407.28125]
40 [ 109272.953125 -6403.52148438] [ 340412.5]
50 [ 109405.3515625 -6535.91503906] [ 340412.625]
60 [ 109437.4921875 -6568.05371094] [ 340412.625]
70 [ 109445.296875 -6575.85644531] [ 340412.625]
80 [ 109447.1875 -6577.75097656] [ 340412.625]
90 [ 109447.640625 -6578.20654297] [ 340412.625]
Coefficients of tensorflow (input should be standardized): K=[ 109447.7421875 -6578.31152344], b=[ 340412.625]
Coefficients of tensorflow (raw input): K=[ 139.21061707 -8737.9609375 ], b=[ 89597.78125]

思考

对于梯度下降算法,变量是否标准化很重要。在这个例子中,变量一个是面积,一个是房间数,量级相差很大,如果不归一化,面积在目标函数和梯度中就会占据主导地位,导致收敛极慢。


推荐阅读
  • XNA 3.0 游戏编程:从 XML 文件加载数据
    本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ... [详细]
  • 1:有如下一段程序:packagea.b.c;publicclassTest{privatestaticinti0;publicintgetNext(){return ... [详细]
  • 本文介绍如何使用阿里云的fastjson库解析包含时间戳、IP地址和参数等信息的JSON格式文本,并进行数据处理和保存。 ... [详细]
  • Java 中的 BigDecimal pow()方法,示例 ... [详细]
  • 本文介绍如何利用动态规划算法解决经典的0-1背包问题。通过具体实例和代码实现,详细解释了在给定容量的背包中选择若干物品以最大化总价值的过程。 ... [详细]
  • 本文介绍了Java并发库中的阻塞队列(BlockingQueue)及其典型应用场景。通过具体实例,展示了如何利用LinkedBlockingQueue实现线程间高效、安全的数据传递,并结合线程池和原子类优化性能。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • 本文详细介绍了Java编程语言中的核心概念和常见面试问题,包括集合类、数据结构、线程处理、Java虚拟机(JVM)、HTTP协议以及Git操作等方面的内容。通过深入分析每个主题,帮助读者更好地理解Java的关键特性和最佳实践。 ... [详细]
  • DNN Community 和 Professional 版本的主要差异
    本文详细解析了 DotNetNuke (DNN) 的两种主要版本:Community 和 Professional。通过对比两者的功能和附加组件,帮助用户选择最适合其需求的版本。 ... [详细]
  • JavaScript中属性节点的类型及应用
    本文深入探讨了JavaScript中属性节点的不同类型及其在实际开发中的应用,帮助开发者更好地理解和处理HTML元素的属性。通过具体的案例和代码示例,我们将详细解析如何操作这些属性节点。 ... [详细]
  • 2023年京东Android面试真题解析与经验分享
    本文由一位拥有6年Android开发经验的工程师撰写,详细解析了京东面试中常见的技术问题。涵盖引用传递、Handler机制、ListView优化、多线程控制及ANR处理等核心知识点。 ... [详细]
  • 从 .NET 转 Java 的自学之路:IO 流基础篇
    本文详细介绍了 Java 中的 IO 流,包括字节流和字符流的基本概念及其操作方式。探讨了如何处理不同类型的文件数据,并结合编码机制确保字符数据的正确读写。同时,文中还涵盖了装饰设计模式的应用,以及多种常见的 IO 操作实例。 ... [详细]
  • 本文探讨了 C++ 中普通数组和标准库类型 vector 的初始化方法。普通数组具有固定长度,而 vector 是一种可扩展的容器,允许动态调整大小。文章详细介绍了不同初始化方式及其应用场景,并提供了代码示例以加深理解。 ... [详细]
  • 本文介绍如何使用Python进行文本处理,包括分词和生成词云图。通过整合多个文本文件、去除停用词并生成词云图,展示文本数据的可视化分析方法。 ... [详细]
author-avatar
爱做梦的蓝梦
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有