1. When should you use a UDAF?
Consider writing your own UDAF when the built-in aggregate functions cannot meet your needs. In most cases, though, the built-in aggregate functions cover common development requirements.
2. How do you register a UDAF, and how do you call it once registered?
2.1 Registering a UDAF
spark.udf.register("myUDAF", new MyUDAFExample)
2.2 Calling a UDAF
2.2.1 Calling it in Spark SQL
// Group with GROUP BY, then call the UDAF as an aggregate function.
spark.sql("""select group_id, myUDAF(id) from simple group by group_id""")
2.2.2 Calling it with the DataFrame API
// Using a UDAF with the DataFrame API
// Create an instance
val myUDAF = new MyUDAFExample
// After groupBy, call the UDAF inside agg()
df.groupBy("group_id").agg(myUDAF(col("id")).as("udafCnt")).show()
// After groupBy, call the UDAF inside an agg(expr()) expression
df.groupBy("group_id").agg(expr("myUDAF(id) as udafCnt")).show()
3. How do you write a UDAF?
3.1 First, let's look at how the official UDAF abstract class is defined
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A `StructType` represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of `DoubleType` and `LongType`, the returned `StructType` will look like
   *
   * ```
   * new StructType()
   *   .add("doubleInput", DoubleType)
   *   .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   *
   * Defines the schema of the input rows; its field names identify the input arguments.
   */
  def inputSchema: StructType

  /**
   * A `StructType` represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of `DoubleType` and `LongType`,
   * the returned `StructType` will look like
   *
   * ```
   * new StructType()
   *   .add("doubleInput", DoubleType)
   *   .add("longInput", LongType)
   * ```
   *
   * The name of a field of this `StructType` is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   *
   * Defines the schema of the buffer; values in the buffer are identified by the field
   * names declared here.
   */
  def bufferSchema: StructType

  /**
   * The `DataType` of the returned value of this [[UserDefinedAggregateFunction]].
   * Defines the data type of the UDAF's final result.
   */
  def dataType: DataType

  /**
   * Returns true iff this function is deterministic, i.e. given the same input,
   * always return the same output.
   * Determinism flag, i.e. whether the result is strictly exact: if the same input
   * always produces the same output, no matter how many times it is computed, set
   * this to true; otherwise set it to false (false corresponds to approximate computation).
   */
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   * The contract should be that applying the merge function on two initial buffers should
   * just return the initial buffer itself, i.e. `merge(initialBuffer, initialBuffer)`
   * should equal `initialBuffer`.
   * Initializes the buffer to its zero value: merging two freshly initialized buffers
   * must yield an initialized buffer again.
   */
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   * This is called once per input row.
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   * This is called when we merge two partially aggregated data together.
   * Partially aggregated buffers are merged pairwise until a single fully aggregated
   * buffer remains.
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the
   * given aggregation buffer.
   * Computes the output from the fully aggregated buffer, e.g. by applying a function
   * to it; if no further computation is needed, the buffer value can be returned directly.
   */
  def evaluate(buffer: Row): Any

  /**
   * Creates a `Column` for this UDAF using given `Column`s as input arguments.
   * Creates columns for the UDAF according to the names declared in inputSchema.
   */
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  /**
   * Creates a `Column` for this UDAF using the distinct values of the given
   * `Column`s as input arguments.
   * Deduplicates the UDAF's input by the columns named in inputSchema.
   */
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}
3.2 Understanding the computation model and data flow behind the UDAF definition
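As a rough mental model: Spark initializes one buffer per group per partition, folds every input row into it with update, merges the partial buffers across partitions with merge, and finally calls evaluate on the merged buffer. The toy simulation below mimics that lifecycle with plain Scala functions; it is illustrative only (hypothetical names, not Spark internals):

// A toy simulation of the UDAF lifecycle (illustrative only; Spark manages
// real buffers via MutableAggregationBuffer internally).
object UdafFlowDemo {
  type Buffer = Long // stand-in buffer: a running sum

  def initialize: Buffer = 0L                           // the zero value
  def update(buf: Buffer, row: Int): Buffer = buf + row // called once per input row
  def merge(b1: Buffer, b2: Buffer): Buffer = b1 + b2   // combines partial buffers
  def evaluate(buf: Buffer): Long = buf                 // produces the final result

  def main(args: Array[String]): Unit = {
    // One group's rows, split across two "partitions".
    val partitions = Seq(Seq(1, 2, 3), Seq(4, 5))
    // Each partition folds its rows into a partial buffer ...
    val partials = partitions.map(_.foldLeft(initialize)(update))
    // ... the partial buffers are merged pairwise ...
    val finalBuffer = partials.reduce(merge)
    // ... and evaluate turns the merged buffer into the group's result.
    println(evaluate(finalBuffer)) // 15
  }
}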
3.3 Writing your own UDAF
Core code: building a UDAF in 8 steps
abstract class UserDefinedAggregateFunction extends Serializable {
  // Step 1: define the schema of the input data
  def inputSchema: StructType
  // Step 2: define the schema of the buffer
  def bufferSchema: StructType
  // Step 3: define the data type of the UDAF's output
  def dataType: DataType
  // Step 4: initialize the buffer
  def initialize(buffer: MutableAggregationBuffer): Unit
  // Step 5: update the buffer
  def update(buffer: MutableAggregationBuffer, input: Row): Unit
  // Step 6: merge buffers to aggregate
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
  // Step 7: compute the output result
  def evaluate(buffer: Row): Any
  // Step 8: set the determinism flag, usually true
  def deterministic: Boolean
}
UDAF example 1: output the aggregated buffer directly as the function result, i.e. the UDAF returns an Array of aggregated values
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyUDAFExample extends UserDefinedAggregateFunction {
  // Step 1: define the input schema
  override def inputSchema: StructType = StructType(Seq(
    StructField("col1", IntegerType),
    StructField("col2", IntegerType)))

  // Step 2: define the buffer schema (three Long counters, matching the getLong calls below)
  override def bufferSchema: StructType = StructType(Seq(
    StructField("count1", LongType),
    StructField("count2", LongType),
    StructField("count3", LongType)))

  // Step 3: define the output data type
  override def dataType: DataType = ArrayType(LongType)

  // Step 4: initialize the buffer to zero
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
    buffer(2) = 0L
  }

  // Step 5: update the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      // Fetch input values by index; getAs enforces the expected type.
      val col1 = input.getAs[Int](0)
      val col2 = input.getAs[Int](1)
      // Incrementally update each counter according to its rule.
      buffer(0) = buffer.getLong(0) + (if ((col1 == 0 || col1 == 4) && col2 == 1) 1L else 0L)
      buffer(1) = buffer.getLong(1) + (if (col1 == 1 && col2 == 1) 1L else 0L)
      buffer(2) = buffer.getLong(2) + (if (col1 == 2 && col2 == 1) 1L else 0L)
    }
  }

  // Step 6: merge partial buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    buffer1(2) = buffer1.getLong(2) + buffer2.getLong(2)
  }

  // Step 7: produce the output
  override def evaluate(buffer: Row): Seq[Long] =
    Seq(buffer.getLong(0), buffer.getLong(1), buffer.getLong(2))

  // Step 8: determinism flag, set to true
  override def deterministic: Boolean = true
}
Register the UDAF
spark.udf.register("myUDAF",new MyUDAFExample)
Call the UDAF
spark.sql("""
  select
    brand_code,
    type_code,
    result[0] as a1_count, -- reference elements of the UDAF result
    result[1] as a2_count,
    result[2] as a3_count
  from
  (
    select
      brand_code,
      type_code,
      myUDAF(status_flag, flow_flag) as result -- call the UDAF
    from tb_status
    group by
      brand_code,
      type_code
  ) a
""")
UDAF example 2: feed the aggregated buffer through a function to compute the UDAF's final output
This example is adapted from: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class GeometricMean extends UserDefinedAggregateFunction {
  // Step 1: define the input schema. This is the input fields for your aggregate function.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // Step 2: define the buffer schema. These are the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("product", DoubleType) :: Nil)

  // Step 3: define the output type. This is the output type of your aggregation function.
  override def dataType: DataType = DoubleType

  // Step 4: initialize the buffer. This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }

  // Step 5: update the buffer. This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)
  }

  // Step 6: define how buffers merge. This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
  }

  // Step 7: compute the final result. This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
  }

  // Step 8: determinism flag
  override def deterministic: Boolean = true
}
Registering and calling this UDAF follows the same approach as above.
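Concretely, following the same pattern as section 2 (the view name and sample data below are illustrative; the Long input column is implicitly cast to the DoubleType the inputSchema declares):

import spark.implicits._

spark.udf.register("gm", new GeometricMean)

val ids = spark.range(1, 5).withColumn("group_id", $"id" % 2)
ids.createOrReplaceTempView("simple")

// SQL: geometric mean of id per group.
spark.sql("select group_id, gm(id) as geo_mean from simple group by group_id").show()

// DataFrame API: call an instance directly.
val gm = new GeometricMean
ids.groupBy("group_id").agg(gm($"id").as("geo_mean")).show()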
That completes the end-to-end process of writing a UDAF by hand.