热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Spark实现高斯朴素贝叶斯模型的低配版

本文介绍了使用Spark实现低配版高斯朴素贝叶斯模型的原因和原理。随着数据量的增大,单机上运行高斯朴素贝叶斯模型会变得很慢,因此考虑使用Spark来加速运行。然而,Spark的MLlib并没有实现高斯朴素贝叶斯模型,因此需要自己动手实现。文章还介绍了朴素贝叶斯的原理和公式,并对具有多个特征和类别的模型进行了讨论。最后,作者总结了实现低配版高斯朴素贝叶斯模型的步骤。
Motivation

最近有项目用到Scikit-learn上的高斯朴素贝叶斯模型(简称GNB),随着数据量增大,单机上跑GNB肯定会很慢,所以打算转Spark上。然后发现MLlib并没有实现GNB,自己动手,丰衣足食~

原理

GNB的原理是基于朴素贝叶斯,所以先交代朴素贝叶斯的原理。

朴素贝叶斯

贝叶斯公式

![](http://www.forkosh.com/mathtex.cgi? P(Y \mid X) = \frac{P(X \mid Y)*P(Y)}{P(X)})

利用贝叶斯公式我们就可以在已知P(X|Y)和P(Y)的情况下计算得出P(Y|X)。现在把Y看成类别,把X看成特征,那么利用贝叶斯公式,我们在已知“特征X出现的时候类别为Y的概率P(X|Y)” 和 “类别为Y的概率P(Y)”的情况下,我们就可以计算在特征X出现的情况下其类别为Y的概率P(Y|X)。
  上面只考虑了只有一种特征的情况,现在考虑模型有N种特征和C种类别的情况。在给定特征X的情况下,求类别为k的概率,公式可以表示成

![](http://www.forkosh.com/mathtex.cgi? P(Y=k \mid X_{1},…,X_{N}) \= \frac{P(X_{1},…,X_{N} \mid Y=k)P(Y=k)}{P(X_{1},…,X_{N})} \= \frac{P(Y=k)\prod_{i}^{N}P(X_{i}\mid Y=k)}{\sum_{j}{C}P(Y=j)*\prod_{i}{N}P(X_{i}\mid Y=j)} )

根据上式,我们可以计算在特征X出现的情况下其类别为Y=k的概率,对于所有的k,我们取概率最大的(最大后验)作为我们的Predict,这就是朴素贝叶斯的思路。
  等等,好像有点问题,凭什么说
![](http://www.forkosh.com/mathtex.cgi? \prod_{i}^{N}P(X_{i}\mid Y=k) = P(X_{1},…,P_{N}|Y=k) )

对的,这就是朴素贝叶斯Naive的地方,它基于一个很强的假设——所有特征的出现是相互独立的,这也是朴素贝叶斯的局限性。
  在实际应用中,还需要考虑极端情况——某个类别没有出现在样本集中 or 某个特征没有出现在某类样本集中。这个时候就需要加入平滑因子lambda去调整。

![](http://www.forkosh.com/mathtex.cgi? P(Y=k)=\frac{Number\ of\ Labeled\ k\ Samples\ +\ lambda}{Number\ of\ Samples\ +\ Number\ of\ Labels\ * \ lambda} )

多项式模型下:
![](http://www.forkosh.com/mathtex.cgi? P(X_{i} \mid Y=k) = \frac{Count\ of\ Feature\ i\ in Labeled\ k\ Samples\ +\ lambda}{Count\ of\ All\ Features\ in\ Labeled\ k\ Samples\ +\ Count\ of Feature’s kind\ * \ lambda} )

伯努力模型下:
![](http://www.forkosh.com/mathtex.cgi? P(X_{i} \mid Y=k) = \frac{Count\ of\ Feature\ i\ in Labeled\ k\ Samples\ +\ lambda}{Count\ of\ All\ Features\ in\ Labeled\ k\ Samples\ +2 \ * \ lambda} )

朴素贝叶斯有两种常用的模型,一种叫伯努利模型,另一种叫多项式模型。两者的区别就在于伯努利模型只考虑在一个样本中,特征是否出现了(例如某个词语是否出现了,0 or 1),而多项式模型则会考虑一个样本中特征出现的次数(例如某个词语出现的次数,一个具体的数字)。两种模型都是面向离散型的特征,如果被建模对象的特征是连续变量时,一般有两个解决方案,一是量化连续型的特征成离散型的,另一种则使用高斯朴素贝叶斯。

高斯朴素贝叶斯

高斯模型下的朴素贝叶斯与上面介绍的两种模型不同的地方是在计算P(X|Y)时,假设其服从高斯分布,这是对于连续型的特征有很友好的表现。

![](http://www.forkosh.com/mathtex.cgi? P(X \mid Y) \backsim N(\mu,\sigma^{2}) \P(X=a \mid Y = k)=\frac{1}{\sqrt{2\pi}\sigma}e{-\frac{(a-\mu){2}}{2\sigma^{2}}})
  对于上式的均值和方差都是可以从样本集中统计得出。
  上述利用高斯分布,我们把连续变量转变成一个概率,上一小节提到的特征是连续变量的问题解决了,其它一切照搬Naive Bayes即可。

实现

Talk is cheap,show me the code. 接下来讲讲具体实现,由于Spark MLlib中实现的向量对外API甚少,所以自己动手写了个LabeledPoint

class LabeledPoint(val label: Double, val denseVector: DenseVector[Double])
extends Serializable {
}
object LabeledPoint extends Serializable {
def apply(label: Double, denseVector: DenseVector[Double]) = {
new LabeledPoint(label, denseVector)
}
}

高斯分布函数,给入均值和方差,生成分布函数,使用柯里化

def distributiveFunc(mean: Double, variance: Double)(x: Double) : Double = {
if (variance == 0.0) {
if (x == mean) 1.0
else 0.0
} else {
1.0 / sqrt(2 * Pi * variance) * exp(- pow(x - mean, 2.0) / (2 * variance))
}
}

核心代码全览

import breeze.linalg.DenseVector
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import breeze.numerics._
import scala.math.Pi
import xyz.qspring.spark.ml.base.LabeledPoint//注意:就是上面的LabeledPoint
/**
* Created by qero on 16/8/7.
*/
class GuassianNaiveBayes private (private val input: RDD[LabeledPoint], private val lambda: Double = 1.0) extends Serializable with Logging{
def distributiveFunc(mean: Double, variance: Double)(x: Double) : Double = { //柯里化分布函数
if (variance == 0.0) {
if (x == mean) 1.0
else 0.0
} else {
1.0 / sqrt(2 * Pi * variance) * exp(- pow(x - mean, 2.0) / (2 * variance))
}
}
def run() = {
val sampleN = input.count
val grouped = input.map(point => (point.label, point.denseVector)).groupByKey().cache
val classN = grouped.count
//计算各类的出现概率(注意平滑因子lambda)
val pi = grouped.map{case (c, a) => {
val p = (a.toList.length * 1.0 + lambda) / (sampleN + lambda * classN)
(c, log2(p)) //取对数,防止后期出现连乘(小数连乘容易精度丢失)
}}
//计算在各类情况下的各特征的均值和方差
val pji = grouped.mapValues(a => {
val aSum = a.reduce((v1 ,v2) => v1 + v2) //求总数
val aSampleN = a.toArray.length //求总数
val mean = aSum / (aSampleN * 1.0) //求均值
val variance = a.map(i => { //求方差(去中心化->求和->求均值)
(i - mean) :* (i - mean)
}).reduce((v1 ,v2) => v1 + v2) / (aSampleN * 1.0)
val paras = mean.toArray.zip(variance.toArray)
paras.map(p => distributiveFunc(p._1, p._2)_) //返回(类别,[特征1的分布函数, ..., 特征n的分布函数])
})
new GuassianNBModel(pi.collectAsMap(), pji.collectAsMap())
}
}
class GuassianNBModel(val pi:collection.Map[Double, Double], val pji:collection.Map[Double, Array[Double => Double]]) extends Serializable {
def predict(features: DenseVector[Double]) = {
pji.map{case (label, models) => {
val score = models.zip(features.toArray).map{case (m, v) => {
log2(m(v)) //取对数,防止后期出现连乘(小数连乘容易精度丢失)
}}.sum + pi(label)
(score, label) //返回(log(P(F1...Fn|Label)*P(Label)), Label)
}}.max //选概率最大的,其对应的Label就是模型的预测
}
}
object GuassianNaiveBayes extends Serializable {
def fit(input: RDD[LabeledPoint]) = {
new GuassianNaiveBayes(input).run()
}
}

测试文件,训练集train.dat

-0.017612 14.053064 0
-1.395634 4.662541 1
-0.752157 6.538620 0
-1.322371 7.152853 0
0.423363 11.054677 0
0.406704 7.067335 1
0.667394 12.741452 0
-2.460150 6.866805 1
0.569411 9.548755 0
-0.026632 10.427743 0
0.850433 6.920334 1
1.347183 13.175500 0
1.176813 3.167020 1
-1.781871 9.097953 0
-0.566606 5.749003 1
0.931635 1.589505 1
-0.024205 6.151823 1
-0.036453 2.690988 1
-0.196949 0.444165 1
1.014459 5.754399 1
1.985298 3.230619 1
-1.693453 -0.557540 1
-0.576525 11.778922 0
-0.346811 -1.678730 1
-2.124484 2.672471 1
1.217916 9.597015 0
-0.733928 9.098687 0
-3.642001 -1.618087 1
0.315985 3.523953 1
1.416614 9.619232 0
-0.386323 3.989286 1
0.556921 8.294984 1
1.224863 11.587360 0
-1.347803 -2.406051 1
-0.445678 3.297303 1
1.042222 6.105155 1
-0.618787 10.320986 0
1.152083 0.548467 1
0.828534 2.676045 1
-1.237728 10.549033 0
-0.683565 -2.166125 1
0.229456 5.921938 1
-0.959885 11.555336 0
0.492911 10.993324 0
0.184992 8.721488 0
-0.355715 10.325976 0
-0.397822 8.058397 0
0.824839 13.730343 0
1.507278 5.027866 1
0.099671 6.835839 1
-0.344008 10.717485 0
1.785928 7.718645 1
-0.918801 11.560217 0
-0.364009 4.747300 1
-0.841722 4.119083 1
0.490426 1.960539 1
-0.007194 9.075792 0
0.356107 12.447863 0
0.342578 12.281162 0
-0.810823 -1.466018 1
2.530777 6.476801 1
1.296683 11.607559 0
0.475487 12.040035 0
-0.783277 11.009725 0
0.074798 11.023650 0
-1.337472 0.468339 1
-0.102781 13.763651 0
-0.147324 2.874846 1
0.518389 9.887035 0
1.015399 7.571882 0
-1.658086 -0.027255 1
1.319944 2.171228 1
2.056216 5.019981 1
-0.851633 4.375691 1
-1.510047 6.061992 0
-1.076637 -3.181888 1
1.821096 10.283990 0
3.010150 8.401766 1
-1.099458 1.688274 1
-0.834872 -1.733869 1
-0.846637 3.849075 1

测试文件,测试集test.dat

1.400102 12.628781 0
1.752842 5.468166 1
0.078557 0.059736 1
0.089392 -0.715300 1
1.825662 12.693808 0
0.197445 9.744638 0
0.126117 0.922311 1
-0.679797 1.220530 1
0.677983 2.556666 1
0.761349 10.693862 0
-2.168791 0.143632 1
1.388610 9.341997 0
0.275221 9.543647 0
0.470575 9.332488 0
-1.889567 9.542662 0
-1.527893 12.150579 0
-1.185247 11.309318 0

测试程序

object Main extends App {
override def main(args: Array[String]) {
val cOnf= new SparkConf().setAppName("naive_bayes")
val sc = new SparkContext(conf)
val data = sc.textFile("data/train.dat")
Logger.getRootLogger.setLevel(Level.WARN)
val trainData = data.map(line => {
val items = line.split("\\s+")
LabeledPoint(items(items.length-1).toDouble, DenseVector(items.slice(0, items.length-1).map(_.toDouble)))
})
val model = GuassianNaiveBayes.fit(trainData)
val testData = sc.textFile("data/test.dat").foreach(line => {
val items = line.split("\\s+")
val res = model.predict(DenseVector(items.slice(0, items.length-1).map(_.toDouble)))
println("true is " + items(items.length - 1) + ", predict is " + res._2 + ", score = " + pow(2, res._1))
})
}
}

结果

true is 0, predict is 0.0, score = 0.007287035226911837
true is 1, predict is 1.0, score = 0.006537938765007012
true is 1, predict is 1.0, score = 0.012801368971056088
true is 1, predict is 1.0, score = 0.00970655657450153
true is 0, predict is 0.0, score = 0.00305462018270487
true is 0, predict is 0.0, score = 0.03716655013066987
true is 1, predict is 1.0, score = 0.01613160178250759
true is 1, predict is 1.0, score = 0.01548224987302873
true is 1, predict is 1.0, score = 0.01784234527209572
true is 0, predict is 0.0, score = 0.029683595996118462
true is 1, predict is 1.0, score = 0.0037636068269885714
true is 0, predict is 0.0, score = 0.011051732411404247
true is 0, predict is 0.0, score = 0.034819190499309864
true is 0, predict is 0.0, score = 0.03027279470621322
true is 0, predict is 0.0, score = 0.003400879969005375
true is 0, predict is 0.0, score = 0.0060605923826227105
true is 0, predict is 0.0, score = 0.014488715477020412

推荐阅读
  • 扫描线三巨头 hdu1928hdu 1255  hdu 1542 [POJ 1151]
    学习链接:http:blog.csdn.netlwt36articledetails48908031学习扫描线主要学习的是一种扫描的思想,后期可以求解很 ... [详细]
  • This document outlines the recommended naming conventions for HTML attributes in Fast Components, focusing on readability and consistency with existing standards. ... [详细]
  • libsodium 1.0.15 发布:引入重大不兼容更新
    最新发布的 libsodium 1.0.15 版本带来了若干不兼容的变更,其中包括默认密码散列算法的更改和其他重要调整。 ... [详细]
  • ServiceStack与Swagger的无缝集成指南
    本文详细介绍了如何在ServiceStack项目中集成Swagger,以实现API文档的自动生成和在线测试。通过本指南,您将了解从配置到部署的完整流程,并掌握如何优化API接口的开发和维护。 ... [详细]
  • 机器学习中的相似度度量与模型优化
    本文探讨了机器学习中常见的相似度度量方法,包括余弦相似度、欧氏距离和马氏距离,并详细介绍了如何通过选择合适的模型复杂度和正则化来提高模型的泛化能力。此外,文章还涵盖了模型评估的各种方法和指标,以及不同分类器的工作原理和应用场景。 ... [详细]
  • 本文详细介绍了macOS系统的核心组件,包括如何管理其安全特性——系统完整性保护(SIP),并探讨了不同版本的更新亮点。对于使用macOS系统的用户来说,了解这些信息有助于更好地管理和优化系统性能。 ... [详细]
  • 2023年京东Android面试真题解析与经验分享
    本文由一位拥有6年Android开发经验的工程师撰写,详细解析了京东面试中常见的技术问题。涵盖引用传递、Handler机制、ListView优化、多线程控制及ANR处理等核心知识点。 ... [详细]
  • 本文详细介绍了Java中的访问器(getter)和修改器(setter),探讨了它们在保护数据完整性、增强代码可维护性方面的重要作用。通过具体示例,展示了如何正确使用这些方法来控制类属性的访问和更新。 ... [详细]
  • 高效解决应用崩溃问题!友盟新版错误分析工具全面升级
    友盟推出的最新版错误分析工具,专为移动开发者设计,提供强大的Crash收集与分析功能。该工具能够实时监控App运行状态,快速发现并修复错误,显著提升应用的稳定性和用户体验。 ... [详细]
  • 使用Python在SAE上开发新浪微博应用的初步探索
    最近重新审视了新浪云平台(SAE)提供的服务,发现其已支持Python开发。本文将详细介绍如何利用Django框架构建一个简单的新浪微博应用,并分享开发过程中的关键步骤。 ... [详细]
  • 图数据库中的知识表示与推理机制
    本文探讨了图数据库及其技术生态系统在知识表示和推理问题上的应用。通过理解图数据结构,尤其是属性图的特性,可以为复杂的数据关系提供高效且优雅的解决方案。我们将详细介绍属性图的基本概念、对象建模、概念建模以及自动推理的过程,并结合实际代码示例进行说明。 ... [详细]
  • 本文详细介绍了VMware的多种认证选项,帮助你根据职业需求和个人技能选择最合适的认证路径,涵盖从基础到高级的不同层次认证。 ... [详细]
  • 本文详细介绍了 Java 中 org.apache.xmlbeans.SchemaType 类的 getBaseEnumType() 方法,提供了多个代码示例,并解释了其在不同场景下的使用方法。 ... [详细]
  • VPX611是北京青翼科技推出的一款采用6U VPX架构的高性能数据存储板。该板卡搭载两片Xilinx Kintex-7系列FPGA作为主控单元,内置RAID控制器,支持多达8个mSATA盘,最大存储容量可达8TB,持续写入带宽高达3.2GB/s。 ... [详细]
  • 获取计算机硬盘序列号的方法与实现
    本文介绍了如何通过编程方法获取计算机硬盘的唯一标识符(序列号),并提供了详细的代码示例和解释。此外,还涵盖了如何使用这些信息进行身份验证或注册保护。 ... [详细]
author-avatar
台艾辉_435
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有