热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

朴素贝叶斯算法原理及SparkMLlib调用(Scala/Java/Python)

朴素贝叶斯算法介绍:朴素贝叶斯法是基于贝叶斯定理与特征条件独立假设的分类方法。朴素贝叶斯的思想基础是这样的:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,在没有其它

朴素贝叶斯

算法介绍:

朴素贝叶斯法是基于贝叶斯定理与特征条件独立假设的分类方法。

朴素贝叶斯的思想基础是这样的:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,在没有其它可用信息下,我们会选择条件概率最大的类别作为此待分类项应属的类别 。

朴素贝叶斯分类的正式定义如下:

1、设《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》为一个待分类项,而每个a为x的一个特征属性。

2、有类别集合《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》

3、计算《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》

4、如果《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》,则《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》

那么现在的关键就是如何计算第3步中的各个条件概率。我们可以这么做:

1、找到一个已知分类的待分类项集合,这个集合叫做训练样本集。

2、统计得到在各类别下各个特征属性的条件概率估计。即《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》

3、如果各个特征属性是条件独立的,则根据贝叶斯定理有如下推导:

《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》

因为分母对于所有类别为常数,因为我们只要将分子最大化皆可。又因为各特征属性是条件独立的,所以有:

《朴素贝叶斯算法原理及Spark MLlib调用(Scala/Java/Python)》

spark.ml现在支持多项朴素贝叶斯和伯努利朴素贝叶斯。

参数:

featuresCol:

类型:字符串型。

含义:特征列名。

labelCol:

类型:字符串型。

含义:标签列名。

modelType:

类型:字符串型。

含义:模型类型(区分大小写)。

predictionCol:

类型:字符串型。

含义:预测结果列名。

probabilityCol:

类型:字符串型。

含义:用以预测类别条件概率的列名。

rawPredictionCol:

类型:字符串型。

含义:原始预测。

smoothing:

类型:双精度型。

含义:平滑参数。

thresholds:

类型:双精度数组型。

含义:多分类预测的阀值,以调整预测结果在各个类别的概率。

示例:

Scala:

import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// Load the data stored in LIBSVM format as a DataFrame. val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing) val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)
// Train a NaiveBayes model. val model = new NaiveBayes()
.fit(trainingData)
// Select example rows to display. val predictions = model.transform(testData)
predictions.show()
// Select (prediction, true label) and compute test error val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("Accuracy: " + accuracy)

Java:

import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.classification.NaiveBayesModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// Load training data Dataset<Row> dataFrame =
spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Split the data into train and test Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];
// create the trainer and set its parameters NaiveBayes nb = new NaiveBayes();
// train the model NaiveBayesModel model = nb.fit(train);
// compute accuracy on the test set Dataset<Row> result = model.transform(test);
Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setMetricName("accuracy");
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));

Python:

from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Load training data
data = spark.read.format("libsvm") \
.load("data/mllib/sample_libsvm_data.txt")
# Split the data into train and test
splits = data.randomSplit([0.6, 0.4], 1234)
train = splits[0]
test = splits[1]
# create the trainer and set its parameters
nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
# train the model
model = nb.fit(train)
# compute accuracy on the test set
result = model.transform(test)
predictionAndLabels = result.select("prediction", "label")
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
print("Accuracy: " + str(evaluator.evaluate(predictionAndLabels)))

推荐阅读
author-avatar
追求的幸福2012_102
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有