Spark SQL Window with UDAF function

This short post shows a full example of using Spark window functions with a UDAF (user-defined aggregate function).

1) I have a Cassandra cluster with a column family (table) that will be used as the source data.


2) I want to calculate the average of "disp" over a window that covers the current row plus the two following rows.

3) Scala code to load the source Dataset from Cassandra. I exclude some columns, take only part of the data (ticker_id = 1 and bar_width_sec = 30) and sort it ascending by ts_begin (a Unix timestamp from the external system):



  def getBarsFromCass(TickerID :Int, BarWidthSec :Int) = {
    import org.apache.spark.sql.functions._
    spark.read.format("org.apache.spark.sql.cassandra")
      .options(Map("table" -> "bars", "keyspace" -> "mts_bars"))
      .load()
      .where(col("ticker_id")     === TickerID &&
             col("bar_width_sec") === BarWidthSec)
      .select(col("ts_begin"), col("btype"), col("disp"), col("log_co"))
      .sort(asc("ts_begin"))
  }
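
To verify that the ticker_id and bar_width_sec predicates are pushed down to Cassandra rather than filtered on the Spark side (my own check, not part of the original post), you can print the query plan:

  // The pushed filters should show up in the physical plan
  getBarsFromCass(1, 30).explain(true)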



Spark Session



  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("PattSearch")
    .config("spark.cassandra.connection.host","10.241.5.234")
    .config("spark.submit.deployMode","client")
    .config("spark.shuffle.service.enabled", "false")
    .config("spark.dynamicAllocation.enabled", "false")
    .config("spark.driver.allowMultipleContexts","true")
    .config("spark.cassandra.input.split.size_in_mb","128")
    .config("spark.cassandra.input.fetch.size_in_rows","10000")
    .config("spark.driver.cores","1")
    .config("spark.cores.max","2")
    .config("spark.driver.memory","1g")
    .config("spark.executor.memory", "1g")
    .config("spark.executor.cores","1")
    .getOrCreate()



Scala code: we get the Dataset from Cassandra and print some information.



  val t1_common = System.currentTimeMillis

  val listBars = getBarsFromCass(1,30)
  listBars.printSchema()
  
  otocLogg.log.info("listBars.count()=["+listBars.count()+"]")
  
  listBars.take(10) foreach{
   thisRow =>
    println(" -> "+
            thisRow.getAs("ts_begin").toString+" "+
            thisRow.getAs("btype").toString+" "+
            thisRow.getAs("disp").toString+" "+
            thisRow.getAs("log_co").toString)
  }

> OUTPUT:
root
 |-- ts_begin: long (nullable = true)
 |-- btype: string (nullable = true)
 |-- disp: double (nullable = true)
 |-- log_co: double (nullable = true)

 INFO  [02.11.2018 16:47:06,812]  [main] (PattSearch.scala:136)   listBars.count()=[28173]

 -> 1535536647 g 3.9E-5 4.0E-5
 -> 1535536677 r 4.8E-5 -1.5E-4
 -> 1535536707 r 6.7E-5 -1.5E-4
 -> 1535536737 g 4.2E-5 1.0E-4
 -> 1535536770 r 1.9E-5 -6.0E-5
 -> 1535536797 g 1.66E-4 4.6E-4
 -> 1535536827 g 1.13E-4 2.1E-4
 -> 1535536857 r 5.1E-5 -1.2E-4
 -> 1535536887 g 8.6E-5 2.4E-4
 -> 1535536919 r 6.2E-5 -1.7E-4



UDAF code



  class ComparePatter() extends UserDefinedAggregateFunction {

    // Input Data Type Schema of Rows.
    def inputSchema: StructType = StructType(Array(
                                                   StructField("ts_begin", IntegerType),
                                                   StructField("btype",    StringType),
                                                   StructField("disp",     DoubleType),
                                                   StructField("log_co",   DoubleType)
                                                  )
                                            )

    // Intermediate Schema
    def bufferSchema = StructType(Array(
      StructField("sum", DoubleType),
      StructField("cnt", LongType)
    ))

    // Returned data type
    def dataType: DataType = DoubleType

    // Self-explaining
    def deterministic = true

    // Called to initialize (reset) the aggregation buffer for a new group
    def initialize(buffer: MutableAggregationBuffer) = {
      buffer(0) = 0.toDouble // set sum to zero
      buffer(1) = 0L // set number of items to 0
    }

    // Iterate over each entry of a group
    def update(buffer: MutableAggregationBuffer, input: Row) = {
      // Using input.getDouble(0) would throw java.lang.ClassCastException:
      // java.lang.Integer cannot be cast to java.lang.Double, because field 0 is ts_begin (Int)
      buffer(0) = buffer.getDouble(0) + input.getDouble(2)
      buffer(1) = buffer.getLong(1) + 1
    }

    // Merge two partial aggregates (both buffers follow bufferSchema: sum, cnt)
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // Called after all the entries are exhausted.
    def evaluate(buffer: Row) = {
      buffer.getDouble(0)/buffer.getLong(1).toDouble
    }

  }



Using the Spark window with the UDAF



  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.col

  val windowSpec = Window
    .orderBy(col("ts_begin").asc)
    .rowsBetween(Window.currentRow, 2)

  val compPatt = new ComparePatter()

  listBars.withColumn("comp_patt", compPatt(listBars.col("ts_begin"),
                                            listBars.col("btype"),
                                            listBars.col("disp"),
                                            listBars.col("log_co")
                                           ) over windowSpec).show()



For each row, Spark sends every row of the window frame into "compPatt". The output of show() is:


+----------+-----+-------+-------+--------------------+
|  ts_begin|btype|   disp| log_co|           comp_patt|
+----------+-----+-------+-------+--------------------+
|1535536647|    g| 3.9E-5| 4.0E-5|5.133333333333333E-5|
|1535536677|    r| 4.8E-5|-1.5E-4|5.233333333333333E-5|
|1535536707|    r| 6.7E-5|-1.5E-4|4.266666666666666E-5|
|1535536737|    g| 4.2E-5| 1.0E-4|7.566666666666666E-5|
|1535536770|    r| 1.9E-5|-6.0E-5|9.933333333333333E-5|
|1535536797|    g|1.66E-4| 4.6E-4|              1.1E-4|
|1535536827|    g|1.13E-4| 2.1E-4|8.333333333333333E-5|
|1535536857|    r| 5.1E-5|-1.2E-4|6.633333333333334E-5|
|1535536887|    g| 8.6E-5| 2.4E-4|6.800000000000001E-5|
|1535536919|    r| 6.2E-5|-1.7E-4|5.233333333333333E-5|
|1535536947|    g| 5.6E-5| 9.0E-5|5.400000000000000...|
|1535536977|    g| 3.9E-5| 1.2E-4|              6.1E-5|
|1535537009|    g| 6.7E-5| 1.5E-4|6.133333333333334E-5|
|1535537037|    r| 7.7E-5|-2.1E-4|6.366666666666666E-5|
|1535537067|    g| 4.0E-5| 9.0E-5|              5.8E-5|
|1535537097|    g| 7.4E-5| 1.1E-4|5.566666666666667E-5|
|1535537127|    r| 6.0E-5|-3.0E-5|4.566666666666667...|
|1535537157|    r| 3.3E-5|-6.0E-5|4.299999999999999...|
|1535537187|    n| 4.4E-5|    0.0|4.166666666666666...|
|1535537217|    g| 5.2E-5| 1.2E-4|3.933333333333333E-5|
+----------+-----+-------+-------+--------------------+


A little explanation:

comp_patt in the first row is calculated from the current row plus the two following rows (see .rowsBetween(Window.currentRow, 2)):

comp_patt = (3.9E-5 + 4.8E-5 + 6.7E-5)/3  = 5.133333333333333E-5

For the second row:

comp_patt = (4.8E-5 + 6.7E-5 + 4.2E-5)/3  = 5.233333333333333E-5
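
As a quick sanity check (my own addition, not part of the original job), the built-in avg over the same frame should produce exactly the same numbers, because the UDAF only averages "disp":

  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.{avg, col}

  // Same frame: current row plus the two following rows
  val checkSpec = Window.orderBy(col("ts_begin").asc).rowsBetween(Window.currentRow, 2)

  listBars.withColumn("avg_disp", avg(col("disp")) over checkSpec).show()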

Next, my idea is not to calculate a simple average but to accumulate N rows inside comp_patt and, on the last row of the group, compare them with an external pattern - PATTERN SEARCH.
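
A minimal sketch of that idea (my assumption of how it could look, not the final implementation): the buffer accumulates the "disp" values of the frame into an array, and evaluate() scores them against an external pattern as a sum of squared differences.

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  import org.apache.spark.sql.types._

  // Hypothetical UDAF: accumulate the frame values and score them against a pattern.
  class PatternCompare(pattern: Seq[Double]) extends UserDefinedAggregateFunction {
    def inputSchema: StructType  = StructType(Array(StructField("disp", DoubleType)))
    def bufferSchema: StructType = StructType(Array(StructField("vals", ArrayType(DoubleType))))
    def dataType: DataType = DoubleType
    def deterministic = true

    def initialize(buffer: MutableAggregationBuffer) = {
      buffer(0) = Seq.empty[Double]
    }

    // Append the "disp" value of every row in the frame
    def update(buffer: MutableAggregationBuffer, input: Row) = {
      buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
    }

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
    }

    // Sum of squared differences with the pattern; smaller means closer to the pattern
    def evaluate(buffer: Row) = {
      buffer.getSeq[Double](0).zip(pattern).map { case (v, p) => (v - p) * (v - p) }.sum
    }
  }

It could be applied over the same windowSpec, e.g. new PatternCompare(Seq(4.0E-5, 5.0E-5, 6.0E-5))(col("disp")) over windowSpec.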

P.S. build.sbt:


name := "PattSearch"
version := "1.0"
scalaVersion := "2.11.8"

val sparkVersion = "2.3.0"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-hive" % sparkVersion % "provided",
  "org.apache.spark" %% "spark-core" % sparkVersion % "provided",
  "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
  "org.apache.spark" %% "spark-mllib" % sparkVersion,
  "com.datastax.spark" %% "spark-cassandra-connector" % "2.3.2"
)

assemblyMergeStrategy in assembly := {
  case PathList("org","aopalliance", xs @ _*) => MergeStrategy.last
  case PathList("javax", "inject", xs @ _*) => MergeStrategy.last
  case PathList("javax", "servlet", xs @ _*) => MergeStrategy.last
  case PathList("javax", "activation", xs @ _*) => MergeStrategy.last
  case PathList("org", "apache", xs @ _*) => MergeStrategy.last
  case PathList("com", "google", xs @ _*) => MergeStrategy.last
  case PathList("com", "esotericsoftware", xs @ _*) => MergeStrategy.last
  case PathList("com", "codahale", xs @ _*) => MergeStrategy.last
  case PathList("com", "yammer", xs @ _*) => MergeStrategy.last
  case "about.html" => MergeStrategy.rename
  case "META-INF/ECLIPSEF.RSA" => MergeStrategy.last
  case "META-INF/mailcap" => MergeStrategy.last
  case "META-INF/mimetypes.default" => MergeStrategy.last
  case "plugin.properties" => MergeStrategy.last
  case "log4j.properties" => MergeStrategy.last
  case x =>
    val oldStrategy = (assemblyMergeStrategy in assembly).value
    oldStrategy(x)
}

assemblyJarName in assembly := "kmeans_v1.jar"
mainClass in assembly := Some("PattSearch")


You can see the full code example on GitHub.

A little more description:

This class extends Spark's org.apache.spark.sql.expressions.UserDefinedAggregateFunction abstract class, and therefore the following members must be implemented.

inputSchema: The schema of the input rows.
bufferSchema: The schema of the intermediate results.
dataType: The data type of the final result.
deterministic: Denotes whether the same inputs always produce the same result.
initialize(): Called once per group to initialize the aggregation buffer.
update(): Called once per input row to update the buffer.
merge(): Called to combine partial aggregation buffers.
evaluate(): Called to compute the final result from the buffer.
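
If you prefer SQL syntax, the same UDAF can also be registered and called as a window function from Spark SQL (the view name bars_v below is just an illustrative name of mine):

  // Register the UDAF and use it from SQL over the same frame
  spark.udf.register("comp_patt", new ComparePatter())
  listBars.createOrReplaceTempView("bars_v")

  spark.sql(
    """SELECT ts_begin, btype, disp, log_co,
      |       comp_patt(ts_begin, btype, disp, log_co)
      |         OVER (ORDER BY ts_begin ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) AS comp_patt
      |FROM bars_v""".stripMargin).show()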




