热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

数据分析(4)sklearn入门

如何选择机器学习方法http:scikit-learn.orgstabletutorialmachine_learning_mapindex.html通用学习模式只需要先定义

如何选择机器学习方法

http://scikit-learn.org/stable/tutorial/machine_learning_map/index.html
这里写图片描述


通用学习模式

只需要先定义 用什么model学习,然后再 model.fit(数据), 这样 model 就能从数据中学到东西. 最后还可以用 model.predict() 来预测值.

from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier
iris = datasets.load_iris()
iris_X = iris.data
iris_Y = iris.target
'''
输入有四个属性:[[ 5.1 3.5 1.4 0.2] [ 4.9 3. 1.4 0.2] ...]
输出类别:[0 0 0 ... 1 1 1 ... 2 2 2 ...]
'''

X_train,X_test,ytrain,y_test = train_test_split(iris_X,iris_Y,test_size=0.3) # 顺序也被打乱,按7:3
knn = KNeighborsClassifier()
knn.fit(X_train,ytrain) # 训练
print(knn.predict(X_test)) # 预测
print(y_test)

sklearn 的 datasets 数据库

Sklearn 提供了很多的有用的数据库,既有真实数据也有你可以编造的数据!特别的强大.http://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets

from sklearn import datasets
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
loaded_data = datasets.load_boston()
data_X = loaded_data.data
data_y = loaded_data.target
model = LinearRegression() # 里面有参数可以改变
model.fit(data_X,data_y)
print(model.predict(data_X[:4,:]))
print(data_y[:4])
'''
[ 30.00821269 25.0298606 30.5702317 28.60814055]
[ 24. 21.6 34.7 33.4]
'''

X, y = datasets.make_regression(n_samples=100, n_features=1, n_targets=1, noise=10)
plt.scatter(X, y)
plt.show()

这里写图片描述
model 常用属性和功能

# y = 0.1x + 0.3
print(model.coef_) # 输出0.1
print(model.intercept_) # 输出0.3
print(model.get_params()) # 返回给model认定的参数,比如{'copy_X': True, 'n_jobs': 1, 'normalize': False, 'fit_intercept': True}
print(model.score(data_X, data_y)) # R^2 coefficient of determination

normalization 标准化数据

normalization 在数据跨度不一的情况下对机器学习有很重要的作用.特别是各种数据属性还会互相影响的情况之下. Scikit-learn 中标准化的语句是 preprocessing.scale() . scale 以后, model 就更能从标准化数据中学到东西.
这里写图片描述

from sklearn.model_selection import train_test_split
from sklearn.datasets.samples_generator import make_classification
from sklearn.svm import SVC
import matplotlib.pyplot as plt
X, y = make_classification(n_samples=300, n_features=2 , n_redundant=0, n_informative=2,random_state=22, n_clusters_per_class=1, scale=100)
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
X = preprocessing.scale(X) # normalization step
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)
clf = SVC()
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test)) # 0.944444444444

cross validation 交叉验证1
sklearn 中的 cross validation 交叉验证 对于我们选择正确的 model 和model 的参数是非常有帮助的. 有了他的帮助, 我们能直观的看出不同 model 或者参数对结构准确度的影响.

from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split,cross_val_score
from sklearn.neighbors import KNeighborsClassifier
iris = load_iris()
X = iris.data
y = iris.target
# test train split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(knn.score(X_test, y_test)) # 0.973684210526
# this is cross_val_score
knn = KNeighborsClassifier(n_neighbors=5)
scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')
print(scores) # [ 0.96666667 1. 0.93333333 0.96666667 1. ]
print(scores.mean()) # 0.973333333333

这里写图片描述

import matplotlib.pyplot as plt
k_range = range(1, 31)
k_scores = []
for k in k_range:knn = KNeighborsClassifier(n_neighbors=k)# loss = -cross_val_score(knn, X, y, cv=10, scoring='mean_squared_error') # for regressionscores = cross_val_score(knn, X, y, cv=10, scoring='accuracy') # for classificationk_scores.append(scores.mean())
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

cross validation 交叉验证2
sklearn.learning_curve 中的 learning curve 可以很直观的看出我们的 model 学习的进度,对比发现有没有 overfitting 的问题.然后我们可以对我们的 model 进行调整,克服 overfitting 的问题.

from sklearn.learning_curve import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
X = digits.data
y = digits.target
train_sizes, train_loss, test_loss= learning_curve(SVC(gamma=0.01), X, y, cv=10, scoring='mean_squared_error',train_sizes=[0.1, 0.25, 0.5, 0.75, 1])
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
plt.plot(train_sizes, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",label="Cross-validation")
plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

这里写图片描述
cross validation 交叉验证3
连续三节的 cross validation让我们知道在机器学习中 validation 是有多么的重要, 这一次的 sklearn 中我们用到了 sklearn.learning_curve 当中的另外一种, 叫做 validation_curve, 用这一种 curve 我们就能更加直观看出改变 model 中的参数的时候有没有 overfitting 的问题了.这也是可以让我们更好的选择参数的方法.

from sklearn.learning_curve import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
digits = load_digits()
X = digits.data
y = digits.target
param_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(SVC(), X, y, param_name='gamma', param_range=param_range, cv=10,scoring='mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
plt.plot(param_range, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",label="Cross-validation")
plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

这里写图片描述
Save
练习好了一个 model 以后总需要保存和再次预测, 所以保存和读取我们的 sklearn model 也是同样重要的一步.

from sklearn import svm
from sklearn import datasets
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)
# method 1: pickle
import pickle
# save
with open('save/clf.pickle', 'wb') as f:pickle.dump(clf, f)
# restore
with open('save/clf.pickle', 'rb') as f:clf2 = pickle.load(f)print(clf2.predict(X[0:1]))
# method 2: joblib
from sklearn.externals import joblib
# Save
joblib.dump(clf, 'save/clf.pkl')
# restore
clf3 = joblib.load('save/clf.pkl')
print(clf3.predict(X[0:1]))

推荐阅读
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文详细介绍了IBM DB2数据库在大型应用系统中的应用,强调其卓越的可扩展性和多环境支持能力。文章深入分析了DB2在数据利用性、完整性、安全性和恢复性方面的优势,并提供了优化建议以提升其在不同规模应用程序中的表现。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 优化ListView性能
    本文深入探讨了如何通过多种技术手段优化ListView的性能,包括视图复用、ViewHolder模式、分批加载数据、图片优化及内存管理等。这些方法能够显著提升应用的响应速度和用户体验。 ... [详细]
  • 本文详细介绍如何使用Python进行配置文件的读写操作,涵盖常见的配置文件格式(如INI、JSON、TOML和YAML),并提供具体的代码示例。 ... [详细]
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 1:有如下一段程序:packagea.b.c;publicclassTest{privatestaticinti0;publicintgetNext(){return ... [详细]
  • 本文介绍了Java并发库中的阻塞队列(BlockingQueue)及其典型应用场景。通过具体实例,展示了如何利用LinkedBlockingQueue实现线程间高效、安全的数据传递,并结合线程池和原子类优化性能。 ... [详细]
  • 本文详细介绍了Java中org.eclipse.ui.forms.widgets.ExpandableComposite类的addExpansionListener()方法,并提供了多个实际代码示例,帮助开发者更好地理解和使用该方法。这些示例来源于多个知名开源项目,具有很高的参考价值。 ... [详细]
  • 本文介绍了如何在C#中启动一个应用程序,并通过枚举窗口来获取其主窗口句柄。当使用Process类启动程序时,我们通常只能获得进程的句柄,而主窗口句柄可能为0。因此,我们需要使用API函数和回调机制来准确获取主窗口句柄。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 本文探讨了如何优化和正确配置Kafka Streams应用程序以确保准确的状态存储查询。通过调整配置参数和代码逻辑,可以有效解决数据不一致的问题。 ... [详细]
  • 本文详细介绍了macOS系统的核心组件,包括如何管理其安全特性——系统完整性保护(SIP),并探讨了不同版本的更新亮点。对于使用macOS系统的用户来说,了解这些信息有助于更好地管理和优化系统性能。 ... [详细]
  • IT项目管理过程中的方法、工具、技术
    工欲善其事,必先利其器。而对于一个软件开发项目,最重要的器就是方法,工具和技术。而这三要素中重要的又是方法论,方法是基础&# ... [详细]
author-avatar
mobiledu2502856411
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有