热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【Python神经网络预测】

Python-神经网络预测目的:预测钢铁成锭率导入模块读取文件标准化转换输出均值,方差测试数据输出权重矩阵,系数矩阵模型评价画图目的&#x


Python-神经网络预测

  • 目的:预测钢铁成锭率
  • 导入模块
  • 读取文件
  • 标准化转换
  • 输出均值,方差
  • 测试数据
  • 输出权重矩阵,系数矩阵
  • 模型评价
  • 画图


目的:预测钢铁成锭率


导入模块

from matplotlib import pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd


读取文件

在这里插入图片描述
#原数据差不多是下面那样
原数据

file_path = ‘data/成锭率.csv’
df = pd.read_csv(file_path)
Y = df.iloc[:, 14]
X = df.values[:, 0:14]


标准化转换

scaler=StandardScaler()
X = scaler.fit_transform(X)


输出均值,方差

print(“拟合后的均值为:”, scaler.mean_)
print(“拟合后的方差:”, scaler.var_)

clf = MLPRegressor(solver=‘lbfgs’, activation=‘relu’, learning_rate_init=0.001,alpha=0.001,max_iter=1000000, hidden_layer_sizes=(40,40))
clf.fit(X,Y)


测试数据

pred = clf.predict(scaler.transform([[3.66,4.55,6.27,0.24,5.68,8.75,2.35,2.86,0.123,0.536,5.145,2.48,0.55,0.49],[4.68,4.12,7.32,4.55,6.02,8.01,2.35,2.86,0.225,1.251,5.145,2.48,0.55,0.49],[6.38,7.32,8.61,4.25,6.86,3.61,1.55,1.65,0.144,1.652,1.035,1.55,0.27,0.54],
[3.26,4.55,6.23,4.56,6.55,3.29,0.62,0.86,0.429,2.409,1.035,1.75,0.46,0.53],
[6.22,6.54,8.66,4.65,4.58,3.25,1.45,1.66,0.555,0.456,0.756,1.57,0.46,0.52],
[8.05,9.12,0.26,0.12,6.08,8.87,1.26,2.31,0.552,2.548,0.185,2.15,0.48,0.51],
[8.23,7.85,4.87,6.54,6.08,9.12,1.26,3.56,0.463,0.255,1.013,2.15,0.48,0.51],
[7.56,7.56,7.47,6.25,6.55,8.88,1.26,2.31,0.552,2.456,0.185,2.15,0.48,0.51],
[6.45,8.01,8.56,6.54,6.08,8.48,1.26,3.56,0.463,0.574,1.013,2.15,0.48,0.51]]))
print(‘回归预测结果:’, pred)
ypred = clf.predict(X)
print(ypred)


输出权重矩阵,系数矩阵

index=0
for w in clf.coefs_:
index += 1
print(‘第{}层网络层:’.format(index))
print(‘权重矩阵:’, w.shape)
print(‘系数矩阵:’, w)


模型评价

score = clf.score(X, Y)# 相关系数
print(np.abs(df.iloc[:,14]-ypred).mean() )


画图

plt.figure()
plt.plot(np.arange(len(Y)), Y, “bo-”, label=“真实值”) # 训练数据和训练标签
plt.plot(np.arange(len(ypred )), ypred , “ro-”, label=“预测值”) # 训练数据和模型预测的标签
plt.rcParams[‘font.sans-serif’] = [‘SimHei’] # 显示中文
plt.title(f’sklearn神经网络—拟合度:{score}\n’)
plt.legend(loc=“best”)
plt.show()
在这里插入图片描述

完整代码

from matplotlib import pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
file_path = 'data/成锭率.csv'
df = pd.read_csv(file_path)
Y = df.iloc[:, 14]
X = df.values[:, 0:14]
# 标准化转换
scaler=StandardScaler()
X = scaler.fit_transform(X)
print("拟合后的均值为:", scaler.mean_)
print("拟合后的方差:", scaler.var_)
clf = MLPRegressor(solver='lbfgs', activation='relu', learning_rate_init=0.001,alpha=0.001,max_iter=1000000, hidden_layer_sizes=(40,40))
clf.fit(X,Y)
# 测试数据
pred = clf.predict(scaler.transform([[3.66,4.55,6.27,0.24,5.68,8.75,2.35,2.86,0.123,0.536,5.145,2.48,0.55,0.49],[4.68,4.12,7.32,4.55,6.02,8.01,2.35,2.86,0.225,1.251,5.145,2.48,0.55,0.49],[6.38,7.32,8.61,4.25,6.86,3.61,1.55,1.65,0.144,1.652,1.035,1.55,0.27,0.54],
[3.26,4.55,6.23,4.56,6.55,3.29,0.62,0.86,0.429,2.409,1.035,1.75,0.46,0.53],
[6.22,6.54,8.66,4.65,4.58,3.25,1.45,1.66,0.555,0.456,0.756,1.57,0.46,0.52],
[8.05,9.12,0.26,0.12,6.08,8.87,1.26,2.31,0.552,2.548,0.185,2.15,0.48,0.51],
[8.23,7.85,4.87,6.54,6.08,9.12,1.26,3.56,0.463,0.255,1.013,2.15,0.48,0.51],
[7.56,7.56,7.47,6.25,6.55,8.88,1.26,2.31,0.552,2.456,0.185,2.15,0.48,0.51],
[6.45,8.01,8.56,6.54,6.08,8.48,1.26,3.56,0.463,0.574,1.013,2.15,0.48,0.51]]))
print('回归预测结果:', pred)
ypred = clf.predict(X)
print(ypred)
index=0
for w in clf.coefs_:index += 1print('第{}层网络层:'.format(index))print('权重矩阵:', w.shape)print('系数矩阵:', w)
score = clf.score(X, Y)# 模型评价
print(np.abs(df.iloc[:,14]-ypred).mean() ) # 模型评价
plt.figure()
plt.plot(np.arange(len(Y)), Y, "bo-", label="真实值") # 训练数据和训练标签
plt.plot(np.arange(len(ypred )), ypred , "ro-", label="预测值") # 训练数据和模型预测的标签
plt.rcParams['font.sans-serif'] = ['SimHei'] # 显示中文
plt.title(f'sklearn神经网络---拟合度:{score}\n')
plt.legend(loc="best")
plt.show()

推荐阅读
author-avatar
栋逼逼丶
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有