热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

数据分析(4)sklearn入门

如何选择机器学习方法http:scikit-learn.orgstabletutorialmachine_learning_mapindex.html通用学习模式只需要先定义

如何选择机器学习方法

http://scikit-learn.org/stable/tutorial/machine_learning_map/index.html
这里写图片描述


通用学习模式

只需要先定义 用什么model学习,然后再 model.fit(数据), 这样 model 就能从数据中学到东西. 最后还可以用 model.predict() 来预测值.

from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier
iris = datasets.load_iris()
iris_X = iris.data
iris_Y = iris.target
'''
输入有四个属性:[[ 5.1 3.5 1.4 0.2] [ 4.9 3. 1.4 0.2] ...]
输出类别:[0 0 0 ... 1 1 1 ... 2 2 2 ...]
'''

X_train,X_test,ytrain,y_test = train_test_split(iris_X,iris_Y,test_size=0.3) # 顺序也被打乱,按7:3
knn = KNeighborsClassifier()
knn.fit(X_train,ytrain) # 训练
print(knn.predict(X_test)) # 预测
print(y_test)

sklearn 的 datasets 数据库

Sklearn 提供了很多的有用的数据库,既有真实数据也有你可以编造的数据!特别的强大.http://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets

from sklearn import datasets
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
loaded_data = datasets.load_boston()
data_X = loaded_data.data
data_y = loaded_data.target
model = LinearRegression() # 里面有参数可以改变
model.fit(data_X,data_y)
print(model.predict(data_X[:4,:]))
print(data_y[:4])
'''
[ 30.00821269 25.0298606 30.5702317 28.60814055]
[ 24. 21.6 34.7 33.4]
'''

X, y = datasets.make_regression(n_samples=100, n_features=1, n_targets=1, noise=10)
plt.scatter(X, y)
plt.show()

这里写图片描述
model 常用属性和功能

# y = 0.1x + 0.3
print(model.coef_) # 输出0.1
print(model.intercept_) # 输出0.3
print(model.get_params()) # 返回给model认定的参数,比如{'copy_X': True, 'n_jobs': 1, 'normalize': False, 'fit_intercept': True}
print(model.score(data_X, data_y)) # R^2 coefficient of determination

normalization 标准化数据

normalization 在数据跨度不一的情况下对机器学习有很重要的作用.特别是各种数据属性还会互相影响的情况之下. Scikit-learn 中标准化的语句是 preprocessing.scale() . scale 以后, model 就更能从标准化数据中学到东西.
这里写图片描述

from sklearn.model_selection import train_test_split
from sklearn.datasets.samples_generator import make_classification
from sklearn.svm import SVC
import matplotlib.pyplot as plt
X, y = make_classification(n_samples=300, n_features=2 , n_redundant=0, n_informative=2,random_state=22, n_clusters_per_class=1, scale=100)
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
X = preprocessing.scale(X) # normalization step
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)
clf = SVC()
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test)) # 0.944444444444

cross validation 交叉验证1
sklearn 中的 cross validation 交叉验证 对于我们选择正确的 model 和model 的参数是非常有帮助的. 有了他的帮助, 我们能直观的看出不同 model 或者参数对结构准确度的影响.

from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split,cross_val_score
from sklearn.neighbors import KNeighborsClassifier
iris = load_iris()
X = iris.data
y = iris.target
# test train split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(knn.score(X_test, y_test)) # 0.973684210526
# this is cross_val_score
knn = KNeighborsClassifier(n_neighbors=5)
scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')
print(scores) # [ 0.96666667 1. 0.93333333 0.96666667 1. ]
print(scores.mean()) # 0.973333333333

这里写图片描述

import matplotlib.pyplot as plt
k_range = range(1, 31)
k_scores = []
for k in k_range:knn = KNeighborsClassifier(n_neighbors=k)# loss = -cross_val_score(knn, X, y, cv=10, scoring='mean_squared_error') # for regressionscores = cross_val_score(knn, X, y, cv=10, scoring='accuracy') # for classificationk_scores.append(scores.mean())
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

cross validation 交叉验证2
sklearn.learning_curve 中的 learning curve 可以很直观的看出我们的 model 学习的进度,对比发现有没有 overfitting 的问题.然后我们可以对我们的 model 进行调整,克服 overfitting 的问题.

from sklearn.learning_curve import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
X = digits.data
y = digits.target
train_sizes, train_loss, test_loss= learning_curve(SVC(gamma=0.01), X, y, cv=10, scoring='mean_squared_error',train_sizes=[0.1, 0.25, 0.5, 0.75, 1])
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
plt.plot(train_sizes, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",label="Cross-validation")
plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

这里写图片描述
cross validation 交叉验证3
连续三节的 cross validation让我们知道在机器学习中 validation 是有多么的重要, 这一次的 sklearn 中我们用到了 sklearn.learning_curve 当中的另外一种, 叫做 validation_curve, 用这一种 curve 我们就能更加直观看出改变 model 中的参数的时候有没有 overfitting 的问题了.这也是可以让我们更好的选择参数的方法.

from sklearn.learning_curve import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
X = digits.data
y = digits.target
param_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(SVC(), X, y, param_name='gamma', param_range=param_range, cv=10,scoring='mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
plt.plot(param_range, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",label="Cross-validation")
plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

这里写图片描述
Save
练习好了一个 model 以后总需要保存和再次预测, 所以保存和读取我们的 sklearn model 也是同样重要的一步.

from sklearn import svm
from sklearn import datasets
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)
# method 1: pickle
import pickle
# save
with open('save/clf.pickle', 'wb') as f:pickle.dump(clf, f)
# restore
with open('save/clf.pickle', 'rb') as f:clf2 = pickle.load(f)print(clf2.predict(X[0:1]))
# method 2: joblib
from sklearn.externals import joblib
# Save
joblib.dump(clf, 'save/clf.pkl')
# restore
clf3 = joblib.load('save/clf.pkl')
print(clf3.predict(X[0:1]))

推荐阅读
  • Python自动化测试入门:Selenium环境搭建
    本文详细介绍如何在Python环境中安装和配置Selenium,包括开发工具PyCharm的安装、Python环境的设置以及Selenium包的安装方法。此外,还提供了编写和运行第一个自动化测试脚本的步骤。 ... [详细]
  • 本文详细介绍了一种通过MySQL弱口令漏洞在Windows操作系统上获取SYSTEM权限的方法。该方法涉及使用自定义UDF DLL文件来执行任意命令,从而实现对远程服务器的完全控制。 ... [详细]
  • 通常情况下,修改my.cnf配置文件后需要重启MySQL服务才能使新参数生效。然而,通过特定命令可以在不重启服务的情况下实现配置的即时更新。本文将详细介绍如何在线调整MySQL配置,并验证其有效性。 ... [详细]
  • 访问一个网页的全过程
    准备:DHCPUDPIP和以太网启动主机,用一根以太网电缆连接到学校的以太网交换机,交换机又与学校的路由器相连.学校的这台路由器与一个ISP链接,此ISP(Intern ... [详细]
  • 在高并发需求的C++项目中,我们最初选择了JsonCpp进行JSON解析和序列化。然而,在处理大数据量时,JsonCpp频繁抛出异常,尤其是在多线程环境下问题更为突出。通过分析发现,旧版本的JsonCpp存在多线程安全性和性能瓶颈。经过评估,我们最终选择了RapidJSON作为替代方案,并实现了显著的性能提升。 ... [详细]
  • 本文详细介绍了钩子(hook)的概念、原理及其在编程中的实际应用。通过对比回调函数和注册函数,解释了钩子的工作机制,并提供了具体的Python示例代码,帮助读者更好地理解和掌握这一重要编程工具。 ... [详细]
  • 本文详细介绍了 phpMyAdmin 的安装与配置方法,适用于多个版本的 phpMyAdmin。通过本教程,您将掌握从下载到部署的完整流程,并了解如何根据不同的环境进行必要的配置调整。 ... [详细]
  • 本文探讨了如何在Classic ASP中实现与PHP的hash_hmac('SHA256', $message, pack('H*', $secret))函数等效的哈希生成方法。通过分析不同实现方式及其产生的差异,提供了一种使用Microsoft .NET Framework的解决方案。 ... [详细]
  • 本文详细介绍了如何在云服务器上配置Nginx、Tomcat、JDK和MySQL。涵盖从下载、安装到配置的完整步骤,帮助读者快速搭建Java Web开发环境。 ... [详细]
  • 优化SQL Server批量数据插入存储过程的实现
    本文介绍了一种改进的SQL Server存储过程,用于生成批量插入语句。该方法不仅提高了性能,还支持单行和多行模式,适用于SQL Server 2005及以上版本。 ... [详细]
  • 本题要求在一组数中反复取出两个数相加,并将结果放回数组中,最终求出最小的总加法代价。这是一个经典的哈夫曼编码问题,利用贪心算法可以有效地解决。 ... [详细]
  • 使用JS、HTML5和C3创建自定义弹出窗口
    本文介绍如何结合JavaScript、HTML5和C3.js来实现一个功能丰富的自定义弹出窗口。通过具体的代码示例,详细讲解了实现过程中的关键步骤和技术要点。 ... [详细]
  • 本文探讨了如何利用HTML5和JavaScript在浏览器中进行本地文件的读取和写入操作,并介绍了获取本地文件路径的方法。HTML5提供了一系列API,使得这些操作变得更加简便和安全。 ... [详细]
  • 本文详细介绍了Java中实现异步调用的多种方式,包括线程创建、Future接口、CompletableFuture类以及Spring框架的@Async注解。通过代码示例和深入解析,帮助读者理解并掌握这些技术。 ... [详细]
  • ArcXML:互联网空间数据交换的专用语言
    ArcXML是一种专为ArcIMS平台设计的数据交换协议,基于XML标准,用于在不同组件之间传输和描述地理空间数据。本文将详细介绍ArcXML的背景、用途及其与XML的关系。 ... [详细]
author-avatar
mobiledu2502856411
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有