热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Scala利用Buffer写UDAF详解

1、什么情况下使用UDAF?当官方提供的聚合函数不能满足需求的时候可以考虑自己写一个UDAF。但通常官方提供的聚合函数基本是能满足常用的开发需求的了。2、怎么注册

1、什么情况下使用UDAF? 

当官方提供的聚合函数不能满足需求的时候可以考虑自己写一个UDAF。但通常官方提供的聚合函数基本是能满足常用的开发需求的了。

2、怎么注册UDAF?注册后怎么调用UDAF?

2.1 注册UDAF

spark.udf.register("myUDAF", new MyUDAFExample)

2.2 调用UDAF

2.2.1 SparkSql中调用

// 使用group by 做分组,然后调用 UDAF聚合函数.
spark.sql("""select group_id, myUDAFExampe(id) from simple group by group_id""")

2.2.2 使用DataFrame语法调用

// 在DataFrame语法中使用UDAF//创建实例
val myUDAF = new MyUDAFExample// groupBy 分组后,在agg()中调用UDAF
df.groupBy("group_id").agg(myUDAF(col("id")).as("udafCnt")).show()// groupBy 分组后,在agg(expr())表达式中调用UDAF
df.groupBy("group_id").agg(expr("myUDAF(id) as udafCnt")).show()

3、怎么写UDAF?

3.1 首先我们看看官方的UDAF的抽象类是怎么定义的

abstract class UserDefinedAggregateFunction extends Serializable {/*** A `StructType` represents data types of input arguments of this aggregate function.* For example, if a [[UserDefinedAggregateFunction]] expects two input arguments* with type of `DoubleType` and `LongType`, the returned `StructType` will look like** ```* new StructType()* .add("doubleInput", DoubleType)* .add("longInput", LongType)* ```** The name of a field of this `StructType` is only used to identify the corresponding input argument. Users can choose names to identify the input arguments.** 定义输入row的Schema,可以用来识别对应的输入参数*/def inputSchema: StructType/*** A `StructType` represents data types of values in the aggregation buffer.* For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values* (i.e. two intermediate values) with type of `DoubleType` and `LongType`,* the returned `StructType` will look like** ```* new StructType()* .add("doubleInput", DoubleType)* .add("longInput", LongType)* ```** The name of a field of this `StructType` is only used to identify the corresponding buffer value. Users can choose names to identify the input arguments.** 定义buffer的Schema。根据Schema中定义的列名来识别其在buffer中的值。*/def bufferSchema: StructType/*** The `DataType` of the returned value of this [[UserDefinedAggregateFunction]].* 定义UDAF最终返回结果的数据类型。*/def dataType: DataType/*** Returns true iff this function is deterministic, i.e. given the same input,* always return the same output.*一致性检验,即要求计算结果是严格准确的。对于相同的输入,不管计算几次,结果都是一样的,则设置为true,否则设置为false。false对应的是近似计算。*/def deterministic: Boolean/*** Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.* The contract should be that applying the merge function on two initial buffers should just return the initial buffer itself, i.e.* `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.* 将聚合buffer初始化为0。如果对两个初始化的buffer调用merge方法,得到的还是一个初始化的buffer。*/def initialize(buffer: MutableAggregationBuffer): Unit/*** Updates the given aggregation buffer `buffer` with new input data from `input`.This is called once per input row.* 根据输入row对聚合buffer进行更新。每一条输入调用一次更新。*/def update(buffer: MutableAggregationBuffer, input: Row): Unit/*** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.*This is called when we merge two partially aggregated data together.*两两将部分聚合的buffer进行聚合,通过归并的方法直到剩下最后一个聚合好的buffer*/def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit/*** Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given aggregation buffer.*根据聚合好的buffer计算输出结果。比如对buffer结果调用一个函数得出最后的聚合结果。如果不再需要计算,可以直接将buffer结果作为输出。*/def evaluate(buffer: Row): Any/*** Creates a `Column` for this UDAF using given `Column`s as input arguments.*按照inputSchema中指定的列名为UDAF创建列。*/@scala.annotation.varargsdef apply(exprs: Column*): Column = {val aggregateExpression =AggregateExpression(ScalaUDAF(exprs.map(_.expr), this),Complete,isDistinct = false)Column(aggregateExpression)}/*** Creates a `Column` for this UDAF using the distinct values of the given* `Column`s as input arguments.* 对UDAF的输入数据,按列名进行去重。列名是inputSchema中指定的。*/@scala.annotation.varargsdef distinct(exprs: Column*): Column = {val aggregateExpression =AggregateExpression(ScalaUDAF(exprs.map(_.expr), this),Complete,isDistinct = true)Column(aggregateExpression)}
}

3.2 理解UDAF的定义背后对应的计算实现原理 或数据流

3.3 自己动手写一个UDAF

核心代码:8步法构建一个UDAF

abstract class UserDefinedAggregateFunction extends Serializable {//第1步,定义输入数据的Schemadef inputSchema: StructType//第2步,定义buffer的Schemadef bufferSchema: StructType//第3步,定义UDAF输出结果的数据类型def dataType: DataType//第4步,初始化bufferdef initialize(buffer: MutableAggregationBuffer): Unit//第5步,更新bufferdef update(buffer: MutableAggregationBuffer, input: Row): Unit//第6步,merge buffer进行聚合def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit//第7步,计算输出结果def evaluate(buffer: Row): Any//第8步,设置一致性检验,一般设置为truedef deterministic: Boolean}

UDAF的代码构建示例1: 直接将buffer聚合后的结果作为UDAF的函数结果输出,即UDAF生成的是一个Array的聚合结果

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types._class MyUDAFExample extends UserDefinedAggregateFunction{//第1步,定义输入数据Schemaoverride def inputSchema:StructType = StructType(Seq(StructField("col1",IntegerType),StructField("col2",IntegerType)))//第2步,定义缓存bufferSchemaoverride def bufferSchema: StructType = StructType(Seq(StructField("col1",IntegerType),StructField("col2",IntegerType),StructField("count",LongType)))//第3步,定义输出数据类型override def dataType: DataType = ArrayType(LongType)//第4步,初始化缓存为0override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) = 0buffer(1) = 0buffer(2) = 0}//第5步,更新缓存override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {if (!input.isNullAt(0)) {val col1 = input.getAs[Int](0) //使用索引获取input中的数据,getAs用来保证数据类型的正确性val col2 = input.getAs[Int](1)buffer(0) = buffer.getLong(0) + (if ((col1 == 0 || col1 == 4) && col2 == 1) 1L else 0L)buffer(1) = buffer.getLong(1) + (if (col1 == 1 && col2 == 1) 1L else 0L) //根据规则对buffer进行增量更新buffer(2) = buffer.getLong(1) + (if (col1 == 2 && col2 == 1) 1L else 0L)}}//第6步,聚合缓存override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) //定义对缓存的数据进行合并的方法buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)buffer1(2) = buffer1.getLong(2) + buffer2.getLong(2) }//第7步,选择输出结果override def evaluate(buffer: Row): Seq[Long] = Seq(buffer.getLong(0),buffer.getLong(1),buffer.getLong(2))//第8步,设置一致性检验override def deterministic: Boolean = true //一致性检验设置为true
}

注册UDAF

spark.udf.register("myUDAF",new MyUDAFExample)

调用UDAF

spark.sql("""
selectbrand_code,type_code,result[0] as a1_count, --引用UDAF结果中的元素result[1] as a2_count,result[2] as a3_count
from
(selectcode,type_code,myUDAF(status_flag, flow_flag) as result --调用UDAFfrom tb_statusgroup bybrand_code,type_code
)a""")

UDAF的代码构建示例2: 将buffer聚合后的结果作为输入,通过一个函数后,得到UDAF的计算输出结果

此示例可参考:https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._class GeometricMean extends UserDefinedAggregateFunction {// 第1步,定义输入Schema。This is the input fields for your aggregate function.override def inputSchema: org.apache.spark.sql.types.StructType =StructType(StructField("value", DoubleType) :: Nil)//第2步,定义buffer Schema。 This is the internal fields you keep for computing your aggregate.override def bufferSchema: StructType = StructType(StructField("count", LongType) ::StructField("product", DoubleType) :: Nil)//第3步,定义输出Scheme This is the output type of your aggregatation function.override def dataType: DataType = DoubleType//第4步,buffer初始化为0。 This is the initial value for your buffer schema.override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) = 0Lbuffer(1) = 1.0}//第5步,更新buffer This is how to update your buffer schema given an input.override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {buffer(0) = buffer.getAs[Long](0) + 1buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)}//第6步,定义merge buffer的方法。 This is how to merge two objects with the bufferSchema type.override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)}//第7步,定义最终结果的计算方法。 This is where you output the final value, given the final value of your bufferSchema.override def evaluate(buffer: Row): Any = {math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))}//第8步,一致性检验override def deterministic: Boolean = true
}

注册UDAF调用UDAF, 方法同。 

至此,一个完整的手写UDAF的过程就结束了。


推荐阅读
  • 深入剖析Java中SimpleDateFormat在多线程环境下的潜在风险与解决方案
    深入剖析Java中SimpleDateFormat在多线程环境下的潜在风险与解决方案 ... [详细]
  • 零拷贝技术是提高I/O性能的重要手段,常用于Java NIO、Netty、Kafka等框架中。本文将详细解析零拷贝技术的原理及其应用。 ... [详细]
  • 在Android平台中,播放音频的采样率通常固定为44.1kHz,而录音的采样率则固定为8kHz。为了确保音频设备的正常工作,底层驱动必须预先设定这些固定的采样率。当上层应用提供的采样率与这些预设值不匹配时,需要通过重采样(resample)技术来调整采样率,以保证音频数据的正确处理和传输。本文将详细探讨FFMpeg在音频处理中的基础理论及重采样技术的应用。 ... [详细]
  • 使用Maven JAR插件将单个或多个文件及其依赖项合并为一个可引用的JAR包
    本文介绍了如何利用Maven中的maven-assembly-plugin插件将单个或多个Java文件及其依赖项打包成一个可引用的JAR文件。首先,需要创建一个新的Maven项目,并将待打包的Java文件复制到该项目中。通过配置maven-assembly-plugin,可以实现将所有文件及其依赖项合并为一个独立的JAR包,方便在其他项目中引用和使用。此外,该方法还支持自定义装配描述符,以满足不同场景下的需求。 ... [详细]
  • 在Android应用开发中,实现与MySQL数据库的连接是一项重要的技术任务。本文详细介绍了Android连接MySQL数据库的操作流程和技术要点。首先,Android平台提供了SQLiteOpenHelper类作为数据库辅助工具,用于创建或打开数据库。开发者可以通过继承并扩展该类,实现对数据库的初始化和版本管理。此外,文章还探讨了使用第三方库如Retrofit或Volley进行网络请求,以及如何通过JSON格式交换数据,确保与MySQL服务器的高效通信。 ... [详细]
  • 在 Python 中,eval() 函数用于将字符串转换为相应的 Python 表达式。然而,eval() 存在安全风险,因为它会执行任何有效的 Python 代码。相比之下,ast.literal_eval() 只评估有限的表达式,确保安全性。 ... [详细]
  • 字节流(InputStream和OutputStream),字节流读写文件,字节流的缓冲区,字节缓冲流
    字节流抽象类InputStream和OutputStream是字节流的顶级父类所有的字节输入流都继承自InputStream,所有的输出流都继承子OutputStreamInput ... [详细]
  • Flowable 流程图路径与节点展示:已执行节点高亮红色标记,增强可视化效果
    在Flowable流程图中,通常仅显示当前节点,而路径则需自行获取。特别是在多次驳回的情况下,节点可能会出现混乱。本文重点探讨了如何准确地展示流程图效果,包括已结束的流程和正在执行的流程。具体实现方法包括生成带有高亮红色标记的图片,以增强可视化效果,确保用户能够清晰地了解每个节点的状态。 ... [详细]
  • 本文详细介绍了在MySQL中如何高效利用EXPLAIN命令进行查询优化。通过实例解析和步骤说明,文章旨在帮助读者深入理解EXPLAIN命令的工作原理及其在性能调优中的应用,内容通俗易懂且结构清晰,适合各水平的数据库管理员和技术人员参考学习。 ... [详细]
  • 【问题】在Android开发中,当为EditText添加TextWatcher并实现onTextChanged方法时,会遇到一个问题:即使只对EditText进行一次修改(例如使用删除键删除一个字符),该方法也会被频繁触发。这不仅影响性能,还可能导致逻辑错误。本文将探讨这一问题的原因,并提供有效的解决方案,包括使用Handler或计时器来限制方法的调用频率,以及通过自定义TextWatcher来优化事件处理,从而提高应用的稳定性和用户体验。 ... [详细]
  • Java Socket 关键参数详解与优化建议
    Java Socket 的 API 虽然被广泛使用,但其关键参数的用途却鲜为人知。本文详细解析了 Java Socket 中的重要参数,如 backlog 参数,它用于控制服务器等待连接请求的队列长度。此外,还探讨了其他参数如 SO_TIMEOUT、SO_REUSEADDR 等的配置方法及其对性能的影响,并提供了优化建议,帮助开发者提升网络通信的稳定性和效率。 ... [详细]
  • 深入解析Android 4.4中的Fence机制及其应用
    在Android 4.4中,Fence机制是处理缓冲区交换和同步问题的关键技术。该机制广泛应用于生产者-消费者模式中,确保了不同组件之间高效、安全的数据传输。通过深入解析Fence机制的工作原理和应用场景,本文探讨了其在系统性能优化和资源管理中的重要作用。 ... [详细]
  • 使用 ListView 浏览安卓系统中的回收站文件 ... [详细]
  • 本文探讨了如何利用Java代码获取当前本地操作系统中正在运行的进程列表及其详细信息。通过引入必要的包和类,开发者可以轻松地实现这一功能,为系统监控和管理提供有力支持。示例代码展示了具体实现方法,适用于需要了解系统进程状态的开发人员。 ... [详细]
  • Java能否直接通过HTTP将字节流绕过HEAP写入SD卡? ... [详细]
author-avatar
旭89浪子_499
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有