热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Python西瓜书使用数据集3.0α线性核和高斯核训练SVM+散点图可视化

西瓜数据集3.0α#-*-coding:utf-8-*-importnumpyasnpimportmatplotlib.pyplotaspltfromsklearnimpo

西瓜数据集3.0α

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
import pandas as pd
from sklearn.metrics import accuracy_score#返回正确的比例
from sklearn.preprocessing import LabelEncoderplt.rcParams['font.sans-serif'] = ['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False #用来正常显示负号
plt.close('all')def main():
#1.获取x,ydata = pd.read_table('watermelon30a.txt',delimiter=',') x = pd.DataFrame({'密度':data['密度'],'含糖率':data['含糖率']})x = x.values.tolist() encoder = LabelEncoder()#将好瓜坏瓜映射为1/0y = encoder.fit_transform(data['好瓜']).tolist()x,y = np.array(x),np.array(y)
#2.1.线性核 linear_svm = svm.SVC(C=0.5, #惩罚参数kernel='linear')linear_svm.fit(x,y)y_pred = linear_svm.predict(x)print('**linear_svm的准确率**: %s' %(accuracy_score(y_pred=y_pred, y_true=y)))
##2.2.高斯核gauss_svm = svm.SVC(C=0.5,kernel='rbf')gauss_svm.fit(x,y)y_pred2 = gauss_svm.predict(x)print('**gauss_svm的准确率**: %s' %(accuracy_score(y_pred=y_pred2, y_true=y))) class_method = {'线性核':linear_svm,'高斯核':gauss_svm}visual(data,class_method)##数据特征可视化
def visual(data,class_method):colormap = dict(zip(data['好瓜'].value_counts().index.tolist(),['blue','green']))#坏瓜好瓜颜色die = data.groupby('好瓜') plt.figure()for species,klass in die:plt.scatter(klass['密度'],klass['含糖率'],color = colormap[species],label = species)for name,model in class_method.items():sv = model.support_vectors_plt.plot(sv[:,0],sv[:,1],label=str(name)+'_supported_vector') plt.legend(frameon=True, title='好瓜',loc="upper left") plt.title('SVC')plt.show()if __name__=="__main__":main()

结果表明,使用线性核和高斯训练核的支持向量实际是一样的(两条线重合):



推荐阅读
author-avatar
哈罗xeh_406
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有