热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于Tensorflow框架的BP神经网络回归小案例预测跳高

(案例):我们将14组国内男子跳高运动员各项素质指标作为输入,即(30m行进跑,立定三级跳远,助跑摸高,助跑4-6步跳高,负重深蹲杠铃,杠铃半蹲系数,100m,抓举),将对应的跳高


(案例):我们将14组国内男子跳高运动员各项素质指标作为输入,即(30m行进跑,立定三级跳远,助跑摸高,助跑4-6步跳高,负重深蹲杠铃,杠铃半蹲系数,100m,抓举),将对应的跳高成绩作为输出,通过对14位选手的数据训练建立模型,预测第15位选手的跳高成绩。

待预测样本a=[[3.0,9.3,3.3,2.05,100,2.8,11.2,50]]

import tensorflow as tf

import pandas as pd

import numpy as np

from sklearn.preprocessing import MinMaxScaler

构造数据:14个样本,8个特征1个标签

x=[[3.2,3.2,3,3.2,3.2,3.4,3.2,3,3.2,3.2,3.2,3.9,3.1,3.2],

[9.6,10.3,9,10.3,10.1,10,9.6,9,9.6,9.2,9.5,9,9.5,9.7],

[3.45,3.75,3.5,3.65,3.5,3.4,3.55,3.5,3.55,3.5,3.4,3.1,3.6,3.45],

[2.15,2.2,2.2,2.2,2,2.15,2.14,2.1,2.1,2.1,2.15,2,2.1,2.15],

[140,120,140,150,80,130,130,100,130,140,115,80,90,130],

[2.8,3.4,3.5,2.8,1.5,3.2,3.5,1.8,3.5,2.5,2.8,2.2,2.7,4.6],

[11,10.9,11.4,10.8,11.3,11.5,11.8,11.3,11.8,11,11.9,13,11.1,10.85],

[50,70,50,80,50,60,65,40,65,50,50,50,70,70]]

y=[[2.24],[2.33],[2.24],[2.32],[2.2],[2.27],[2.2],[2.26],[2.2],[2.24],[2.24],[2.2],

[2.2],[2.35]]

获取数据集

x_t=np.array(x,dtype=‘float32’).T #[148]

y_true=np.array(y,dtype=‘float32’) #[14
1]

将特征数据最值归一,范围在(0,1)之间

mm=MinMaxScaler() #实例化

std=mm.fit(x_t) #训练模型

x_true=std.transform(x_t) #转化

print(x_true)

print(y_true)

通过占位符,预定义输入X,输出Y

即输入层8*1个神经元,输出层1个神经元


X=tf.placeholder(tf.float32,[None,8])

Y=tf.placeholder(tf.float32,[None,1])

随机数列生成,创建隐含层的神经网络,隐含层4个神经元

truncated_normal:选取位于正态分布方差在0.1附近的随机数据


w1=tf.Variable(tf.truncated_normal([8,4],stddev=0.1))

b1=tf.Variable(tf.zeros([4]))

w2=tf.Variable(tf.zeros([4,1]))

b2=tf.Variable(tf.zeros([1]))

relu,为激活函数,增加非线性关系,隐藏层和输出层的计算

L1=tf.nn.relu(tf.matmul(X,w1)+b1)

y_pre=tf.matmul(L1,w2)+b2

计算损失函数:均方误差

loss=tf.reduce_mean(tf.cast(tf.square(Y-y_pre),tf.float32))

梯度下降优化损失函数,学习率过大容易导致权重非常大,会出现nan值

train_op=tf.train.GradientDescentOptimizer(0.01).minimize(loss)

初始化变量

init_op=tf.global_variables_initializer()

创建一个saver,用来保存训练模型

saver=tf.train.Saver()

开启回话

with tf.Session() as sess:

sess.run(init_op)

训练模型15次

for i in range(1,300): #控制训练批次

for j in range (len(y_true)):#控制每批次训练的样本数

sess.run(train_op,feed_dict={X:[x_true[j,:]],Y:[y_true[j,:]]})#[[]]是为了匹配占位的类型

输出每次训练的损失

print(‘第%s批次第%s个样本训练的损失为:%s,真实值为:%s,预测值为:%s’% (i,j+1,

sess.run(loss, feed_dict={X:[x_true[j,:]],Y:[y_true[j,:]]}),

y_true[j,:],

sess.run(y_pre,feed_dict={X:[x_true[j,:]],Y:[y_true[j,:]]})))

保存模型:需要在会话里完成(注意缩进代码)

saver.save(sess,’./BP_demo/BP_model’)

加载模型,预测15号选手的跳高成绩

saver.restore(sess,’./BP_demo/BP_model’)

样本原始数据

a = [[3.0,9.3,3.3,2.05,100,2.8,11.2,50]]

获取测试样本

x_test=np.array(a,dtype=‘float32’)

将数据最值归一

x_test=std.transform(x_test)

print(‘15号选手的跳高成绩预测值为:’, sess.run(y_pre,feed_dict={X:x_test})

结果:



第299批次第11个样本训练的损失为:0.00016767633,真实值为:[2.24],预测值为:[[2.227051]]

第299批次第12个样本训练的损失为:1.5376372e-06,真实值为:[2.2],预测值为:[[2.19876]]

第299批次第13个样本训练的损失为:0.00062711653,真实值为:[2.2],预测值为:[[2.2250423]]

第299批次第14个样本训练的损失为:0.008744326,真实值为:[2.35],预测值为:[[2.2564888]]

15号选手的跳高成绩预测值为: [[2.1450984]]

问题:

只是简单实现数据预测,误差还是较大,有更好的优化方法,欢迎大家一起来分享哦!



推荐阅读
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文介绍了利用ARMA模型对平稳非白噪声序列进行建模的步骤及代码实现。首先对观察值序列进行样本自相关系数和样本偏自相关系数的计算,然后根据这些系数的性质选择适当的ARMA模型进行拟合,并估计模型中的位置参数。接着进行模型的有效性检验,如果不通过则重新选择模型再拟合,如果通过则进行模型优化。最后利用拟合模型预测序列的未来走势。文章还介绍了绘制时序图、平稳性检验、白噪声检验、确定ARMA阶数和预测未来走势的代码实现。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • MATLAB函数重名问题解决方法及数据导入导出操作详解
    本文介绍了解决MATLAB函数重名的方法,并详细讲解了数据导入和导出的操作。包括使用菜单导入数据、在工作区直接新建变量、粘贴数据到.m文件或.txt文件并用load命令调用、使用save命令导出数据等方法。同时还介绍了使用dlmread函数调用数据的方法。通过本文的内容,读者可以更好地处理MATLAB中的函数重名问题,并掌握数据导入导出的各种操作。 ... [详细]
  • RouterOS 5.16软路由安装图解教程
    本文介绍了如何安装RouterOS 5.16软路由系统,包括系统要求、安装步骤和登录方式。同时提供了详细的图解教程,方便读者进行操作。 ... [详细]
  • MPLS VP恩 后门链路shamlink实验及配置步骤
    本文介绍了MPLS VP恩 后门链路shamlink的实验步骤及配置过程,包括拓扑、CE1、PE1、P1、P2、PE2和CE2的配置。详细讲解了shamlink实验的目的和操作步骤,帮助读者理解和实践该技术。 ... [详细]
  • 本文介绍了如何在Mac上使用Pillow库加载不同于默认字体和大小的字体,并提供了一个简单的示例代码。通过该示例,读者可以了解如何在Python中使用Pillow库来写入不同字体的文本。同时,本文也解决了在Mac上使用Pillow库加载字体时可能遇到的问题。读者可以根据本文提供的示例代码,轻松实现在Mac上使用Pillow库加载不同字体的功能。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • 本文讨论了如何使用GStreamer来删除H264格式视频文件中的中间部分,而不需要进行重编码。作者提出了使用gst_element_seek(...)函数来实现这个目标的思路,并提到遇到了一个解决不了的BUG。文章还列举了8个解决方案,希望能够得到更好的思路。 ... [详细]
  • 本文介绍了使用readlink命令获取文件的完整路径的简单方法,并提供了一个示例命令来打印文件的完整路径。共有28种解决方案可供选择。 ... [详细]
  • 1Lock与ReadWriteLock1.1LockpublicinterfaceLock{voidlock();voidlockInterruptibl ... [详细]
author-avatar
我的明天谁2502931447
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有