Spark SQL Window with UDAF function

This short post shows a full example of using Spark window functions with a UDAF.

1) I have a Cassandra cluster with a column family (table) that will be used as our source data.


2) I want to calculate the average of "disp" over a window made of the current row plus the two following rows.
  (figure: SparkWindow)

3) Scala code to load the source Dataset from Cassandra. I select only the columns I need, take just part of the data (ticker_id = 1 and bar_width_sec = 30) and sort it ascending by ts_begin (a Unix timestamp from the external system).



  def getBarsFromCass(TickerID :Int, BarWidthSec :Int) = {
    import org.apache.spark.sql.functions._
    spark.read.format("org.apache.spark.sql.cassandra")
      .options(Map("table" -> "bars", "keyspace" -> "mts_bars"))
      .load()
      .where(col("ticker_id")     === TickerID &&
             col("bar_width_sec") === BarWidthSec)
      .select(col("ts_begin"), col("btype"), col("disp"), col("log_co"))
      .sort(asc("ts_begin"))
  }



Spark Session



  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("PattSearch")
    .config("spark.cassandra.connection.host","10.241.5.234")
    .config("spark.submit.deployMode","client")
    .config("spark.shuffle.service.enabled", "false")
    .config("spark.dynamicAllocation.enabled", "false")
    .config("spark.driver.allowMultipleContexts","true")
    .config("spark.cassandra.input.split.size_in_mb","128")
    .config("spark.cassandra.input.fetch.size_in_rows","10000")
    .config("spark.driver.cores","1")
    .config("spark.cores.max","2")
    .config("spark.driver.memory","1g")
    .config("spark.executor.memory", "1g")
    .config("spark.executor.cores","1")
    .getOrCreate()



Scala code: we get the Dataset from Cassandra and print some information about it.



  val t1_common = System.currentTimeMillis

  val listBars = getBarsFromCass(1,30)
  listBars.printSchema()
  
  otocLogg.log.info("listBars.count()=["+listBars.count()+"]")
  
  listBars.take(10) foreach{
   thisRow =>
    println(" -> "+
            thisRow.getAs("ts_begin").toString+" "+
            thisRow.getAs("btype").toString+" "+
            thisRow.getAs("disp").toString+" "+
            thisRow.getAs("log_co").toString)
  }

> OUTPUT:
root
 |-- ts_begin: long (nullable = true)
 |-- btype: string (nullable = true)
 |-- disp: double (nullable = true)
 |-- log_co: double (nullable = true)

 INFO  [02.11.2018 16:47:06,812]  [main] (PattSearch.scala:136)   listBars.count()=[28173]

 -> 1535536647 g 3.9E-5 4.0E-5
 -> 1535536677 r 4.8E-5 -1.5E-4
 -> 1535536707 r 6.7E-5 -1.5E-4
 -> 1535536737 g 4.2E-5 1.0E-4
 -> 1535536770 r 1.9E-5 -6.0E-5
 -> 1535536797 g 1.66E-4 4.6E-4
 -> 1535536827 g 1.13E-4 2.1E-4
 -> 1535536857 r 5.1E-5 -1.2E-4
 -> 1535536887 g 8.6E-5 2.4E-4
 -> 1535536919 r 6.2E-5 -1.7E-4



UDAF code



  import org.apache.spark.sql.Row
  import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  import org.apache.spark.sql.types._

  class ComparePatter() extends UserDefinedAggregateFunction {

    // Input Data Type Schema of Rows.
    def inputSchema: StructType = StructType(Array(
                                                   StructField("ts_begin", IntegerType),
                                                   StructField("btype",    StringType),
                                                   StructField("disp",     DoubleType),
                                                   StructField("log_co",   DoubleType)
                                                  )
                                            )

    // Intermediate Schema
    def bufferSchema = StructType(Array(
      StructField("sum", DoubleType),
      StructField("cnt", LongType)
    ))

    // Returned Data Type .
    def dataType: DataType = DoubleType

    // Self-explaining
    def deterministic = true

    // This function is called whenever key changes
    def initialize(buffer: MutableAggregationBuffer) = {
      buffer(0) = 0.toDouble // set sum to zero
      buffer(1) = 0L // set number of items to 0
    }

    // Iterate over each entry of a group
    def update(buffer: MutableAggregationBuffer, input: Row) = {
      // Using input.getDouble(0) would throw java.lang.ClassCastException:
      // java.lang.Integer cannot be cast to java.lang.Double, because field 0 is ts_begin (Int)
      buffer(0) = buffer.getDouble(0) + input.getDouble(2)
      buffer(1) = buffer.getLong(1) + 1
    }

    // Merge two partial aggregation buffers (combine partial sums and counts)
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
      buffer1(1) = buffer1.getLong(1)   + buffer2.getLong(1)
    }

    // Called after all the entries are exhausted.
    def evaluate(buffer: Row) = {
      buffer.getDouble(0)/buffer.getLong(1).toDouble
    }

  }
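
As a quick aside (my own sanity check, not part of the original flow): before wiring the UDAF into a window, it can be used as an ordinary aggregate, for example to get the average "disp" per bar type. This assumes the listBars Dataset defined above.

  // Sanity check: the same UDAF in a regular groupBy aggregation.
  import org.apache.spark.sql.functions.col
  val compPattCheck = new ComparePatter()
  listBars
    .groupBy(col("btype"))
    .agg(compPattCheck(col("ts_begin"), col("btype"), col("disp"), col("log_co")).as("avg_disp"))
    .show()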



Code using the Spark Window with the UDAF



  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.col

  val windowSpec = Window
    .orderBy(col("ts_begin").asc)
    .rowsBetween(Window.currentRow, 2)

  val compPatt = new ComparePatter()

  listBars.withColumn("comp_patt",
    compPatt(listBars.col("ts_begin"),
             listBars.col("btype"),
             listBars.col("disp"),
             listBars.col("log_co")) over windowSpec
  ).show()



For each row, Spark feeds all rows of that row's window frame into "compPatt". The output of show() is:


+----------+-----+-------+-------+--------------------+
|  ts_begin|btype|   disp| log_co|           comp_patt|
+----------+-----+-------+-------+--------------------+
|1535536647|    g| 3.9E-5| 4.0E-5|5.133333333333333E-5|
|1535536677|    r| 4.8E-5|-1.5E-4|5.233333333333333E-5|
|1535536707|    r| 6.7E-5|-1.5E-4|4.266666666666666E-5|
|1535536737|    g| 4.2E-5| 1.0E-4|7.566666666666666E-5|
|1535536770|    r| 1.9E-5|-6.0E-5|9.933333333333333E-5|
|1535536797|    g|1.66E-4| 4.6E-4|              1.1E-4|
|1535536827|    g|1.13E-4| 2.1E-4|8.333333333333333E-5|
|1535536857|    r| 5.1E-5|-1.2E-4|6.633333333333334E-5|
|1535536887|    g| 8.6E-5| 2.4E-4|6.800000000000001E-5|
|1535536919|    r| 6.2E-5|-1.7E-4|5.233333333333333E-5|
|1535536947|    g| 5.6E-5| 9.0E-5|5.400000000000000...|
|1535536977|    g| 3.9E-5| 1.2E-4|              6.1E-5|
|1535537009|    g| 6.7E-5| 1.5E-4|6.133333333333334E-5|
|1535537037|    r| 7.7E-5|-2.1E-4|6.366666666666666E-5|
|1535537067|    g| 4.0E-5| 9.0E-5|              5.8E-5|
|1535537097|    g| 7.4E-5| 1.1E-4|5.566666666666667E-5|
|1535537127|    r| 6.0E-5|-3.0E-5|4.566666666666667...|
|1535537157|    r| 3.3E-5|-6.0E-5|4.299999999999999...|
|1535537187|    n| 4.4E-5|    0.0|4.166666666666666...|
|1535537217|    g| 5.2E-5| 1.2E-4|3.933333333333333E-5|
+----------+-----+-------+-------+--------------------+


A little explanation.

The comp_patt value in the first row is calculated from the current row plus the two following rows (see .rowsBetween(Window.currentRow, 2)):

comp_patt = (3.9E-5 + 4.8E-5 + 6.7E-5)/3  = 5.133333333333333E-5

and for the second row:

comp_patt = (4.8E-5 + 6.7E-5 + 4.2E-5)/3  = 5.233333333333333E-5
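
These numbers can be cross-checked with the built-in avg aggregate over the same window (a quick check, assuming the listBars and windowSpec values defined above):

  // The built-in avg over the identical frame must reproduce comp_patt.
  import org.apache.spark.sql.functions.{avg, col}
  listBars
    .withColumn("avg_disp", avg(col("disp")) over windowSpec)
    .show(5)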

Next, my idea is not to calculate a simple average, but to accumulate N rows inside the aggregation buffer and, on the last row of the frame, compare them with an external pattern - PATTERN SEARCH.
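
A minimal sketch of that idea (my own assumptions, not code from the post: the UDAF takes only "disp", the pattern is a hard-coded Seq[Double], and the comparison is a simple sum of squared differences; the imports are the same as for ComparePatter above):

  class PattCompare(pattern: Seq[Double]) extends UserDefinedAggregateFunction {

    def inputSchema: StructType = StructType(Array(StructField("disp", DoubleType)))

    // Accumulate the raw values of the frame instead of a running sum.
    def bufferSchema = StructType(Array(StructField("vals", ArrayType(DoubleType))))

    def dataType: DataType = DoubleType
    def deterministic = true

    def initialize(buffer: MutableAggregationBuffer) = {
      buffer(0) = Seq.empty[Double]
    }

    def update(buffer: MutableAggregationBuffer, input: Row) = {
      buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
    }

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
    }

    // Distance between the accumulated frame and the external pattern:
    // the smaller the value, the closer the match.
    def evaluate(buffer: Row) = {
      val vals = buffer.getSeq[Double](0)
      vals.zip(pattern).map { case (v, p) => (v - p) * (v - p) }.sum
    }
  }

  // Hypothetical pattern values, just for illustration.
  val pattSearch = new PattCompare(Seq(4.0E-5, 5.0E-5, 6.0E-5))
  listBars.withColumn("patt_dist", pattSearch(col("disp")) over windowSpec).show(5)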

p.s. build.sbt


name := "PattSearch"
version := "0.1"
scalaVersion := "2.11.8"
version := "1.0"

val sparkVersion = "2.3.0"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-hive" % sparkVersion % "provided",
  "org.apache.spark" %% "spark-core" % sparkVersion % "provided",
  "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
  "org.apache.spark" %% "spark-mllib" % sparkVersion,
  "com.datastax.spark" %% "spark-cassandra-connector" % "2.3.2"
)

assemblyMergeStrategy in assembly := {
  case PathList("org","aopalliance", xs @ _*) => MergeStrategy.last
  case PathList("javax", "inject", xs @ _*) => MergeStrategy.last
  case PathList("javax", "servlet", xs @ _*) => MergeStrategy.last
  case PathList("javax", "activation", xs @ _*) => MergeStrategy.last
  case PathList("org", "apache", xs @ _*) => MergeStrategy.last
  case PathList("com", "google", xs @ _*) => MergeStrategy.last
  case PathList("com", "esotericsoftware", xs @ _*) => MergeStrategy.last
  case PathList("com", "codahale", xs @ _*) => MergeStrategy.last
  case PathList("com", "yammer", xs @ _*) => MergeStrategy.last
  case "about.html" => MergeStrategy.rename
  case "META-INF/ECLIPSEF.RSA" => MergeStrategy.last
  case "META-INF/mailcap" => MergeStrategy.last
  case "META-INF/mimetypes.default" => MergeStrategy.last
  case "plugin.properties" => MergeStrategy.last
  case "log4j.properties" => MergeStrategy.last
  case x =>
    val oldStrategy = (assemblyMergeStrategy in assembly).value
    oldStrategy(x)
}

assemblyJarName in assembly := "kmeans_v1.jar"
mainClass in assembly := Some("PattSearch")


You can see the full code example on GitHub.

A little more description:

This class extends Spark's org.apache.spark.sql.expressions.UserDefinedAggregateFunction abstract class, and therefore the following members must be implemented (a SQL-registration sketch follows the list).

inputSchema: The schema of the input rows.
bufferSchema: The schema of the intermediate aggregation buffer.
dataType: The data type of the final result.
deterministic: Whether the same input always produces the same result.
initialize(): Called to initialize the aggregation buffer, once per group on each node.
update(): Called once per input row to update the buffer.
merge(): Called to combine partial aggregation buffers computed on different nodes.
evaluate(): Called once per group to compute the final result from the buffer.
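
For completeness, a UDAF like this can also be registered for use from Spark SQL (a hedged sketch; the view name bars_view and the selected columns are my own choice):

  // Register the UDAF under a SQL name and call it as a window function in SQL.
  spark.udf.register("comp_patt", new ComparePatter())
  listBars.createOrReplaceTempView("bars_view")
  spark.sql("""
    SELECT ts_begin, disp,
           comp_patt(ts_begin, btype, disp, log_co)
             OVER (ORDER BY ts_begin ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) AS comp_patt
    FROM bars_view
  """).show(5)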




