热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

数据分析(4)sklearn入门

如何选择机器学习方法http:scikit-learn.orgstabletutorialmachine_learning_mapindex.html通用学习模式只需要先定义

如何选择机器学习方法

http://scikit-learn.org/stable/tutorial/machine_learning_map/index.html
这里写图片描述


通用学习模式

只需要先定义 用什么model学习,然后再 model.fit(数据), 这样 model 就能从数据中学到东西. 最后还可以用 model.predict() 来预测值.

from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier
iris = datasets.load_iris()
iris_X = iris.data
iris_Y = iris.target
'''
输入有四个属性:[[ 5.1 3.5 1.4 0.2] [ 4.9 3. 1.4 0.2] ...]
输出类别:[0 0 0 ... 1 1 1 ... 2 2 2 ...]
'''

X_train,X_test,ytrain,y_test = train_test_split(iris_X,iris_Y,test_size=0.3) # 顺序也被打乱,按7:3
knn = KNeighborsClassifier()
knn.fit(X_train,ytrain) # 训练
print(knn.predict(X_test)) # 预测
print(y_test)

sklearn 的 datasets 数据库

Sklearn 提供了很多的有用的数据库,既有真实数据也有你可以编造的数据!特别的强大.http://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets

from sklearn import datasets
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
loaded_data = datasets.load_boston()
data_X = loaded_data.data
data_y = loaded_data.target
model = LinearRegression() # 里面有参数可以改变
model.fit(data_X,data_y)
print(model.predict(data_X[:4,:]))
print(data_y[:4])
'''
[ 30.00821269 25.0298606 30.5702317 28.60814055]
[ 24. 21.6 34.7 33.4]
'''

X, y = datasets.make_regression(n_samples=100, n_features=1, n_targets=1, noise=10)
plt.scatter(X, y)
plt.show()

这里写图片描述
model 常用属性和功能

# y = 0.1x + 0.3
print(model.coef_) # 输出0.1
print(model.intercept_) # 输出0.3
print(model.get_params()) # 返回给model认定的参数,比如{'copy_X': True, 'n_jobs': 1, 'normalize': False, 'fit_intercept': True}
print(model.score(data_X, data_y)) # R^2 coefficient of determination

normalization 标准化数据

normalization 在数据跨度不一的情况下对机器学习有很重要的作用.特别是各种数据属性还会互相影响的情况之下. Scikit-learn 中标准化的语句是 preprocessing.scale() . scale 以后, model 就更能从标准化数据中学到东西.
这里写图片描述

from sklearn.model_selection import train_test_split
from sklearn.datasets.samples_generator import make_classification
from sklearn.svm import SVC
import matplotlib.pyplot as plt
X, y = make_classification(n_samples=300, n_features=2 , n_redundant=0, n_informative=2,random_state=22, n_clusters_per_class=1, scale=100)
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
X = preprocessing.scale(X) # normalization step
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)
clf = SVC()
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test)) # 0.944444444444

cross validation 交叉验证1
sklearn 中的 cross validation 交叉验证 对于我们选择正确的 model 和model 的参数是非常有帮助的. 有了他的帮助, 我们能直观的看出不同 model 或者参数对结构准确度的影响.

from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split,cross_val_score
from sklearn.neighbors import KNeighborsClassifier
iris = load_iris()
X = iris.data
y = iris.target
# test train split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(knn.score(X_test, y_test)) # 0.973684210526
# this is cross_val_score
knn = KNeighborsClassifier(n_neighbors=5)
scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')
print(scores) # [ 0.96666667 1. 0.93333333 0.96666667 1. ]
print(scores.mean()) # 0.973333333333

这里写图片描述

import matplotlib.pyplot as plt
k_range = range(1, 31)
k_scores = []
for k in k_range:knn = KNeighborsClassifier(n_neighbors=k)# loss = -cross_val_score(knn, X, y, cv=10, scoring='mean_squared_error') # for regressionscores = cross_val_score(knn, X, y, cv=10, scoring='accuracy') # for classificationk_scores.append(scores.mean())
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

cross validation 交叉验证2
sklearn.learning_curve 中的 learning curve 可以很直观的看出我们的 model 学习的进度,对比发现有没有 overfitting 的问题.然后我们可以对我们的 model 进行调整,克服 overfitting 的问题.

from sklearn.learning_curve import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
X = digits.data
y = digits.target
train_sizes, train_loss, test_loss= learning_curve(SVC(gamma=0.01), X, y, cv=10, scoring='mean_squared_error',train_sizes=[0.1, 0.25, 0.5, 0.75, 1])
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
plt.plot(train_sizes, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",label="Cross-validation")
plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

这里写图片描述
cross validation 交叉验证3
连续三节的 cross validation让我们知道在机器学习中 validation 是有多么的重要, 这一次的 sklearn 中我们用到了 sklearn.learning_curve 当中的另外一种, 叫做 validation_curve, 用这一种 curve 我们就能更加直观看出改变 model 中的参数的时候有没有 overfitting 的问题了.这也是可以让我们更好的选择参数的方法.

from sklearn.learning_curve import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
X = digits.data
y = digits.target
param_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(SVC(), X, y, param_name='gamma', param_range=param_range, cv=10,scoring='mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
plt.plot(param_range, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",label="Cross-validation")
plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

这里写图片描述
Save
练习好了一个 model 以后总需要保存和再次预测, 所以保存和读取我们的 sklearn model 也是同样重要的一步.

from sklearn import svm
from sklearn import datasets
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)
# method 1: pickle
import pickle
# save
with open('save/clf.pickle', 'wb') as f:pickle.dump(clf, f)
# restore
with open('save/clf.pickle', 'rb') as f:clf2 = pickle.load(f)print(clf2.predict(X[0:1]))
# method 2: joblib
from sklearn.externals import joblib
# Save
joblib.dump(clf, 'save/clf.pkl')
# restore
clf3 = joblib.load('save/clf.pkl')
print(clf3.predict(X[0:1]))

推荐阅读
  • 权限_MS17010远程溢出漏洞(CVE20170143)拿权限
    本文由编程笔记#小编为大家整理,主要介绍了MS17-010远程溢出漏洞(CVE-2017-0143)拿权限相关的知识,希望对你有一定的参考价值。0x00 ... [详细]
  • http:blog.chinaunix.netuid-23204078-id-2525171.html[误解]#define_XOPEN_SOURCE决不是简单的宏定义它是使程序符 ... [详细]
  • 本文主要参考《Python机器学习经典实例》  在介绍凝聚层次聚类之前,我们需要先理解层次聚类(hierarchicalclustering)。层次聚类是一组聚类算法,通过不断地分 ... [详细]
  • 点击上方[全栈开发者社区]→右上角[]→[设为星标⭐]Nginx是一个高性能的HTTP和反向代理服务器,特点是占用内存少,并发能力强, ... [详细]
  • nsitionalENhttp:www.w3.orgTRxhtml1DTDxhtml1-transitional.dtd ... [详细]
  • 内容:[J2SE5.0]用Executor灵活处理事件下发作者:AndrewThompson译者:xMatrix版权声明:任何获得Matrix授权的网站,转载时请务必以超链接形式标 ... [详细]
  • nginx配置浅谈
    nginx配置浅谈 ... [详细]
  • 什么都不说,先上总结的图~ SelectorsAPI(选择符API)querySelector()方法接收一个css选择符,返回与该模式匹配的第一个元素,如果没有找到匹配的元素,返 ... [详细]
  • angular熟练使用(demo)增删改查排序
    <!DOCTYPEhtml><html>   &am ... [详细]
  • 引起w3wp.exe(IIS)Cpu占用100%的常见原因如下:1.Web访问量大,从而服务器压力大而引起的2.动态页面(.aspx)的程序逻辑复杂程度 ... [详细]
  • 1、war是一个web模块,其中需要包括WEB-INF,是可以直接运行的WEB模块;jar一般只是包括一些class文件,在声明了Main_class之后是可以用java命令运行的。2、wa ... [详细]
  • 传统c语言开发,C语言系统开发
    本文目录一览:1、简述开发一个c语言程序的步骤 ... [详细]
  • 1http:blog.csdn.netlfdfhlarticledetails8220729代码如下:imageView.startAnimation(welcomeAnimation) ... [详细]
  • 有关partitionaligned,gparted支持所以请使用gparted来进行分区,如果不是aligned,gparted会warning的。有 ... [详细]
  • AOP是Spring的核心,Spring不但自身对多种框架的集成是基于AOP,并且以非常方便的形式暴露给普通使用者。以前用AOP不多,主要是因为它以横截面的方式插入到主流程中,担心导致主流程代码 ... [详细]
author-avatar
mobiledu2502856411
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有