Spark UDFs

Published on Feb 5, 2024 · 5 min read

Prelude

You have some complex code to run on pyspark that just cannot be expressed with native Spark functions. You need a UDF, but you are not sure which approach is best.

Prerequisites

  • a machine
  • a terminal
  • Java 8 & JAVA_HOME set
  • Python 3.8+ including pip
  • pyspark[sql], see an example here

How to implement the UDFs?

The example is simple: add 1 to a column. We use Spark 3.4.1 in this post.

Setup

from pyspark.sql import SparkSession, types as t

spark = (
    SparkSession
    .builder
    .getOrCreate()
)

schema = t.StructType([t.StructField("integers", t.IntegerType(), True)])
df = spark.createDataFrame([
    (i,) for i in range(100)
], schema=schema)

Native

from pyspark.sql import functions as f

df_native = df.withColumn("integersPlusOne", f.col("integers") + f.lit(1))

Python UDF

from pyspark.sql import types as t
from pyspark.sql.functions import udf

@udf(returnType=t.IntegerType())
def add_one_udf(num: int) -> int:
    return num + 1


df_python_udf = df.withColumn("integersPlusOne", add_one_udf(f.col("integers")))

Pandas UDF

Pandas UDFs are flexible but also carry some additional complexity. By default, Spark splits each column into Arrow batches and passes them to the UDF batch by batch, which is why the input and output are pandas Series. Keep in mind that skewed partitions translate into skewed workloads, so a shuffle (repartition) may still be needed to balance the computation.

import pandas as pd
from pyspark.sql import types as t
from pyspark.sql.functions import pandas_udf

@pandas_udf(t.IntegerType())
def add_one_pandas_udf(nums: pd.Series) -> pd.Series:
    return nums + 1


df_pandas_udf = df.withColumn("integersPlusOne", add_one_pandas_udf(f.col("integers")))

What useful config to know?

Ensure Arrow is used:

spark = (
    SparkSession
    .builder
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
    .getOrCreate()
)

Other configurations can be found here; search for spark.sql.execution.arrow. You will most likely need to tune spark.sql.execution.arrow.maxRecordsPerBatch, which defaults to 10000.

Shuffling is a deep topic in its own right; however, tuning spark.sql.shuffle.partitions can come in handy too.
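
For illustration, here is a builder combining both knobs; the values below are placeholders rather than recommendations and should be tuned to your data:

spark = (
    SparkSession
    .builder
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    # placeholder batch size; the default is 10000
    .config("spark.sql.execution.arrow.maxRecordsPerBatch", "50000")
    # placeholder partition count; the default is 200
    .config("spark.sql.shuffle.partitions", "64")
    .getOrCreate()
)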

Other Pandas UDF types

Several other types are outlined here in the official docs. Most notably, they offer more flexibility over how much data each call receives (see the sketch after this list). The options are:

  • Series to Series (the example in this blogpost)
  • Iterator of Series to Iterator of Series
  • Iterator of Multiple Series to Iterator of Series
  • Series to Scalar (aka “UDAF”)
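
As a rough sketch of the “Iterator of Series to Iterator of Series” variant (names are illustrative), the same add-one logic looks like this; the main benefit is that any expensive setup runs once per task rather than once per batch:

from typing import Iterator

import pandas as pd
from pyspark.sql import types as t
from pyspark.sql.functions import pandas_udf

@pandas_udf(t.IntegerType())
def add_one_iter_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # expensive, one-time initialization could go here
    offset = 1
    for batch in batches:
        yield batch + offset


df_pandas_iter_udf = df.withColumn("integersPlusOne", add_one_iter_udf(f.col("integers")))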

Spark 3.5 also introduces User-Defined Table Functions (UDTFs), which can return an entire table as output.
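
A minimal sketch of that API, assuming Spark 3.5+ (the class and column names are illustrative):

from pyspark.sql.functions import lit, udtf

@udtf(returnType="num: int, plus_one: int")
class AddOneTable:
    def eval(self, x: int):
        # yield one (or more) rows per input
        yield x, x + 1


AddOneTable(lit(5)).show()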

Java UDF

This requires the Spark SQL jar built for Scala 2.12, which can be fetched from mvn.

  1. Write the UDF in Java. It needs to be in a package, here sparkudfs.
package sparkudfs;

import org.apache.spark.sql.api.java.UDF1;


public class AddOne implements UDF1<Integer, Integer> {
    private static final long serialVersionUID = 1L;
    @Override
    public Integer call(Integer num) throws Exception {
        return num + 1;
    }
}
  2. Compile and package it into a jar.
javac -cp spark-sql_2.12-3.4.1.jar:. sparkudfs/AddOne.java
jar cf add-one.jar sparkudfs/AddOne.class
  3. Use the UDF. It can only be used via SQL.
from pyspark.sql import SparkSession, types as t

spark = (
    SparkSession
    .builder
    .config("spark.jars", "add-one.jar")
    .getOrCreate()
)

df.createOrReplaceTempView("df")
spark.udf.registerJavaFunction("addOne", "sparkudfs.AddOne", t.IntegerType())
df_java_udf = spark.sql("SELECT integers, addOne(integers) AS integersPlusOne FROM df")
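
If you prefer to stay in the DataFrame API, the registered function can still be reached through a SQL expression; a small sketch, assuming the registration above has already run:

df_java_udf_expr = df.withColumn("integersPlusOne", f.expr("addOne(integers)"))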

Scala UDF

This additionally requires the Scala 2.12 compiler (scalac). It follows the same flow as the Java UDF.

package scalaudfs

import org.apache.spark.sql.api.java.UDF1

class AddOne extends UDF1[Int, Int] {

  override def call(x: Int): Int = {
    x + 1
  }
}
scalac -classpath spark-sql_2.12-3.4.1.jar scalaudfs/AddOne.scala
jar cf add-one-scala.jar scalaudfs/AddOne.class
from pyspark.sql import SparkSession, types as t

spark = (
    SparkSession
    .builder
    .config("spark.jars", "add-one-scala.jar")
    .getOrCreate()
)

df.createOrReplaceTempView("df")
spark.udf.registerJavaFunction("addOneScala", "scalaudfs.AddOne", t.IntegerType())
df_scala_udf = spark.sql("SELECT integers, addOneScala(integers) AS integersPlusOne FROM df")

Note that registering the UDF can also be done on the Scala side.

How do the UDFs compare?

Well, let’s look at the physical plans.

Native

== Physical Plan ==
*(1) Project [integers#0, (integers#0 + 1) AS integersPlusOne#20]
+- *(1) Scan ExistingRDD[integers#0]

Python UDF

== Physical Plan ==
*(2) Project [integers#0, pythonUDF0#11 AS integersPlusOne#8]
+- BatchEvalPython [add_one_udf(integers#0)#7], [pythonUDF0#11]
   +- *(1) Scan ExistingRDD[integers#0]

Pandas UDF

== Physical Plan ==
*(2) Project [integers#0, pythonUDF0#16 AS integersPlusOne#13]
+- ArrowEvalPython [add_one_pandas_udf(integers#0)#12], [pythonUDF0#16], 200
   +- *(1) Scan ExistingRDD[integers#0]

Java UDF

== Physical Plan ==
*(1) Project [integers#0, addOne(integers#0) AS integersPlusOne#17]
+- *(1) Scan ExistingRDD[integers#0]

Scala UDF

== Physical Plan ==
*(1) Project [integers#0, addOneScala(integers#0) AS integersPlusOne#20]
+- *(1) Scan ExistingRDD[integers#0]

Wait, aren’t the Java/Scala UDFs the same as the native function?

Almost…

The difference is that the UDF is still a black box to Spark, more specifically to its Catalyst optimizer. Let’s see what happens if we try to add another operation.

spark.sql("SELECT integers + 1 + 1 AS result FROM df").explain()
spark.sql("SELECT addOne(integers) + 1 AS result FROM df").explain()
== Physical Plan ==
*(1) Project [(integers#0 + 2) AS result#35]
+- *(1) Scan ExistingRDD[integers#0]

== Physical Plan ==
*(1) Project [(addOne(integers#0) + 1) AS result#37]
+- *(1) Scan ExistingRDD[integers#0]

Catalyst was able to merge the two native operations into a single “add two”. Internally, this is possible because there is a rule in Catalyst for exactly this scenario. This also means that one could extend the optimizer with rules for a custom UDF, but that is a topic for another day :)

Addendum

Is going through the hassle of implementing it in Java/Scala really worth it?

There are several benchmarks available online, such as this one, that find up to a 15x improvement on both simple and more complex operations.

Why Java 8 specifically?

Java UDFs compiled with Java 11 or higher are not supported as of writing this post. This is in line with Databricks still using Java 8 for their Runtimes; see here.

Why use Pandas UDF?

It is the middle ground between pure Python and Java/Scala UDFs. Being vectorized, Pandas UDFs are expected to be faster than pure Python UDFs but still slower than Java/Scala ones. This article goes deeper into how pyspark works and why its data operations are inherently slower.

There are also other valid use cases, such as reusing functionality already implemented in Pandas, or libraries that work with Pandas data structures. One can also perform expensive Python initialization less frequently on the workers, such as in this example.
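
As a hedged illustration of reusing Pandas functionality (the function name is made up), here is a UDF that leans on the built-in Series.clip:

import pandas as pd
from pyspark.sql import types as t
from pyspark.sql.functions import pandas_udf

@pandas_udf(t.IntegerType())
def clip_to_ten_udf(nums: pd.Series) -> pd.Series:
    # pandas already provides the element-wise clipping logic
    return nums.clip(upper=10)


df_clipped = df.withColumn("integersClipped", clip_to_ten_udf(f.col("integers")))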

Can we find this code somewhere?

Yes, it is posted here.

Notice something wrong? Have an additional tip?

Contribute to the discussion here