Overview

In this post I will describe how I created a custom Apache Spark JdbcDialect and how I used it from PySpark.

Lately our ML engineers requested a unified framework that would allow them to run the same code whether the source is Presto or, in other cases, the Glue catalog/S3. The same Glue catalog is also used by Presto.

Our data lake contains complex events with nested structures that include Arrays and Structs.

Currently, Spark's default JDBC support (JdbcUtils) does not handle array types. To be able to read arrays directly from Presto into Spark, we had to create a custom dialect.

In the following sections I will describe how I created one in Scala, and how it can be used from PySpark.

Creating the custom dialect

It is pretty simple: as mentioned in the overview, we need to create a class that extends Spark's JdbcDialect. It must override the canHandle method, which declares which JDBC URLs the custom dialect can handle.

For our case we also needed to override getCatalystType, which converts Presto's types into Spark types. Most of Spark's default type conversions are good enough, but as noted earlier the Array type is not supported by default, so we need to implement that conversion ourselves.

Presto's Array type looks like ARRAY(inner_type), for example ARRAY(INTEGER).

The following snippet is the PrestoDialect that I created.

package my.company.jdbc

import java.sql.Types
import java.util.Locale

import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types._

class PrestoDialect extends JdbcDialect {
    // Matches Presto's ARRAY(<inner type>) string and captures the inner type
    private val arrayTypePat = """ARRAY\(([\w\d\(\)].*)\)"""

    override def canHandle(url: String): Boolean =
        url.toLowerCase(Locale.ROOT).startsWith("jdbc:presto")

    override def getCatalystType(sqlType: Int,
                               typeName: String,
                               size: Int,
                               md: MetadataBuilder): Option[DataType] = {
        if (sqlType == Types.ARRAY) {
            val innerType = stripType(arrayTypePat, typeName) match {
                case Some(inner) if inner.startsWith("ARRAY(") =>
                    throw new IllegalArgumentException(s"Nested arrays are not supported: $typeName")
                case None => throw new IllegalArgumentException(s"The type $typeName is unsupported in ARRAY")
                case Some(inner) => inner
            }

            toCatalystType(innerType).map(ArrayType(_))
        } else None
    }

    /**
     * Strip outer types of the provided type.
     * @param regexPat The regex pattern to use for stripping.
     * @param typeName The type to use the strip regex on.
     * @return Inner type of the stripped type.
     */
    private def stripType(regexPat: String, typeName: String): Option[String] = {
        val Pattern = regexPat.r

        typeName match {
            case Pattern(inner) => Some(inner)
            case _ => None
        }
    }

    /**
     * Convert a basic Presto type into a Spark type.
     * @param typeName Presto's type name.
     * @return The matching Spark type, if supported.
     */
    private def toCatalystType(typeName: String): Option[DataType] = typeName match {
        case "BOOLEAN" => Some(BooleanType)
        case "INTEGER" => Some(IntegerType)
        case "BIGINT" => Some(LongType)
        case "REAL" => Some(FloatType)
        case "DOUBLE" => Some(DoubleType)
        case "varchar" => Some(StringType)
        case x if x.startsWith("VARCHAR") => Some(StringType)
        case x if x.startsWith("DECIMAL") => throw new IllegalArgumentException("Unsupported Array element DECIMAL. " +
            "Please cast array to REAL or DOUBLE.")
        case _ => None
    }
}
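
As a quick sanity check, the conversion can be exercised directly from a spark-shell session. This is just a minimal sketch, assuming the compiled class is available on the shell's classpath:

import java.sql.Types
import org.apache.spark.sql.types.MetadataBuilder
import my.company.jdbc.PrestoDialect

val dialect = new PrestoDialect()

// A simple integer array maps to Spark's ArrayType(IntegerType)
dialect.getCatalystType(Types.ARRAY, "ARRAY(INTEGER)", 0, new MetadataBuilder())
// expected: Some(ArrayType(IntegerType,true))

// Anything that is not an ARRAY falls back to Spark's default mapping (we return None)
dialect.getCatalystType(Types.VARCHAR, "VARCHAR", 0, new MetadataBuilder())
// expected: None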

After creating this class, compile and package it into a jar (a minimal build sketch follows). Then let's move on to how we can use the class to query Presto arrays through PySpark.
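
Here is a minimal build.sbt sketch for that step; the Scala and Spark versions below are illustrative placeholders, so adjust them to whatever your cluster runs:

// build.sbt - minimal sketch, version numbers are illustrative
name := "spark-presto-dialect"
version := "1.0"
scalaVersion := "2.12.15"

// The dialect only compiles against Spark's APIs, so a provided dependency is enough
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.1.2" % "provided"

Running sbt package should then produce a jar along the lines of target/scala-2.12/spark-presto-dialect_2.12-1.0.jar.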

Use in PySpark

In order to use the PrestoDialect class inside Spark, we need to load the jar from the previous step (for this example I've called it spark-presto-dialect_2.12-1.0.jar) and then import the class itself using py4j. And of course we need to load Presto's JDBC driver.

Moreover, we need to register the custom dialect with Spark using the JdbcDialects.registerDialect method.

All of this is shown in the next code snippet:

from pyspark.sql import SparkSession
from py4j.java_gateway import java_import


def get_sparksession():
    return SparkSession\
        .builder\
        .appName("PySpark-Presto")\
        .config("spark.driver.extraClassPath", "../presto-jdbc-0.234.1.jar:../spark-presto-dialect_2.12-1.0.jar")\
        .getOrCreate()

def get_spark_presto_session(spark):
    return spark.read \
        .format("jdbc") \
        .option("driver", "com.facebook.presto.jdbc.PrestoDriver") \
        .option("url", "jdbc:presto://localhost:8889/hive/default") \
        .option("user", "hadoop")\
        .option("numPartitions", "4")


def load_presto_query(spark_presto):
    return spark_presto\
        .option("query", get_sql()) \
        .load()

def get_sql():
    return '''
    SELECT
        ARRAY[1,2,3]
    FROM hive.some_glue_db.some_glue_table 
    '''

if __name__ == "__main__":
    spark = get_sparksession()
    gw = spark.sparkContext._gateway
    java_import(gw.jvm, "my.company.jdbc.PrestoDialect")

    gw.jvm.org.apache.spark.sql.jdbc.JdbcDialects.registerDialect(gw.jvm.my.company.jdbc.PrestoDialect())
    presto_df = load_presto_query(get_spark_presto_session(spark))

    presto_df.show()
    presto_df.printSchema()
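
With the custom dialect registered, the array arrives as a real Spark array column. For the example query above, the printed schema should look roughly like the following (the column name shown here is just a placeholder for whatever Presto assigns to the unaliased expression):

root
 |-- _col0: array (nullable = true)
 |    |-- element: integer (containsNull = true)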

Well, now our ML engineers are able to load Presto arrays directly into Spark DataFrames.

Thanks a lot for reading.