I was recently thinking about how we should write Spark code using the DataFrame API. It turns out that there are a lot of different choices you can make and, sometimes, innocuous-looking ones can bite you in the long run.
The Question Before The Question: DataFrame API Or Spark SQL?
(TL;DR: Use the DataFrame API!)
Before beginning: I assume you created the Spark session as spark and that you have a dataframe called df.
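(If you need the boilerplate, it could look roughly like this; the parquet path is just a placeholder:)
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.read.parquet("some/input/path")  # placeholder source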
The first point to consider is that you can write code using the DataFrame API like this1:
df = ... # IO here
df_my_column = df.select("my_column")
Or using Spark SQL, like this:
df_my_column = spark.sql("SELECT my_column FROM input")
There are advantages and disadvantages to both methods.
The Spark SQL method is immediately familiar to analysts and anyone fluent in SQL. The drawback is that, even though spark.sql returns a dataframe, every dataframe has to be registered as a temporary view before it can be queried:
df = source_df.select('a_column')

try:
    spark.sql("select mean(a_column) from df")
except:  # a py4j exception is raised here
    print("It can't find df")

df.createOrReplaceTempView("df")
spark.sql("select mean(a_column) from df")  # now it works
The DataFrame API is much more concise:
import pyspark.sql.functions as sf
df = source_df.select('a_column')
df.select(sf.mean('a_column'))
On the other hand, it can get quite involved and "scary":
from pyspark.sql import Window

d_types = ...
c_types = ...

df.withColumn('type',
              sf.when(sf.sum(sf.col('vehicle').isin(d_types).cast('Int'))
                        .over(Window.partitionBy('id')) > 0, 'd_type')
                .when(sf.col('vehicle').isin(c_types), 'c_type')
                .otherwise('other_type'))
(In all fairness, writing the above bit in SQL would also be quite daunting.)
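To give you an idea, a SQL version might look roughly like this — with the type lists hardcoded to made-up values, since you can't pass a Python list into a SQL string directly:
df.createOrReplaceTempView("input")
spark.sql("""
    SELECT *,
           CASE WHEN SUM(CAST(vehicle IN ('bus', 'truck') AS INT))
                     OVER (PARTITION BY id) > 0 THEN 'd_type'
                WHEN vehicle IN ('car', 'van') THEN 'c_type'
                ELSE 'other_type'
           END AS type
    FROM input
""")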
But for me the real advantage comes from composing and dealing with objects in a more abstract way. The above snippet of code should, ideally, be a function:
def my_function(df, d_types, c_types):
    return df.withColumn('type',
                         sf.when(sf.sum(sf.col('vehicle').isin(d_types).cast('Int'))
                                   .over(Window.partitionBy('id')) > 0, 'd_type')
                           .when(sf.col('vehicle').isin(c_types), 'c_type')
                           .otherwise('other_type'))
If I were to rewrite that in Spark SQL, I'd have to do the following:
def my_function(df, d_types, c_types):
    # do something with d_types and c_types to be able to pass them to SQL
    table_name = 'find_a_unique_table_name_not_to_clash_with_other'
    df.createOrReplaceTempView(table_name)
    return spark.sql("""
        YOUR SQL HERE WITH %s AND MORE %s's TO INSERT c_types, d_types AND table_name
    """ % (c_types, d_types, table_name))
The above function mixes IO (the createOrReplaceTempView) with logic (the SQL execution). As a cherry on top, it's doing string interpolation, which is bad (like, really really bad!).
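To see why, consider what naively interpolating a Python list into a query string produces (the values here are made up):
d_types = ['bus', 'truck']  # hypothetical values
query = "SELECT * FROM input WHERE vehicle IN %s" % (d_types,)
# query is now "SELECT * FROM input WHERE vehicle IN ['bus', 'truck']",
# which isn't valid SQL: you need extra massaging, and any quoting slip
# or untrusted input opens the door to injection.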
Disentangling the two would mean rewriting them like so:
def register_df_as_table(df):
    table_name = ...  # generate some random unique name here
    df.createOrReplaceTempView(table_name)
    return table_name

def my_function(table_name, d_types, c_types):
    # do something with d_types and c_types to be able to pass them to SQL
    return spark.sql("""
        YOUR SQL HERE WITH %s AND MORE %s's TO INSERT c_types, d_types AND table_name
    """ % (c_types, d_types, table_name))
In principle you could create a decorator out of register_df_as_table and decorate my_function with it (see the sketch below), but you can see that this is getting pretty involved. With the DataFrame API you can compose functions much more easily.
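For the curious, such a decorator could look something like this minimal sketch (the decorator name is my own invention; it reuses register_df_as_table from above):
import functools

def with_registered_table(func):
    """Register the incoming dataframe as a temporary view and call the
    wrapped function with the view name instead of the dataframe."""
    @functools.wraps(func)
    def wrapper(df, *args, **kwargs):
        return func(register_df_as_table(df), *args, **kwargs)
    return wrapper

@with_registered_table
def my_function(table_name, d_types, c_types):
    ...  # same SQL as above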
Further composing away
With that out of the way, let's see how you can further compose your functions and test them.
I won't write the code here, but let's say we have two extra functions, a_function and another_function, with a flow like this:
def load_data(..):
    pass

def my_function(df, other_args):
    pass

def a_function(df, other_args):
    pass

def another_function(df):
    pass

def main():
    df_1 = load_data(..)
    df_2 = my_function(df_1, args_1)
    df_3 = a_function(df_2, args_2)
    df_4 = another_function(df_3)
    return df_4
The naming of those variables (df_{1..4}) is terrible, but, as you all know, there are only two hard problems in computer science: naming things, off-by-one errors, and overwriting variables (such as naming them all df).
A better alternative would involve piping the various functions:
def pipe(data, *funcs):
    for func in funcs:
        data = func(data)
    return data

def main():
    partial_my_function = lambda df: my_function(df, args_1)
    partial_a_function = lambda df: a_function(df, args_2)
    return pipe(load_data(),
                partial_my_function,
                partial_a_function,
                another_function)
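If the lambdas bother you, functools.partial does the same job (assuming the other_args keyword name from the stubs above):
from functools import partial

def main():
    # same wiring as above, with partials instead of lambdas
    return pipe(load_data(),
                partial(my_function, other_args=args_1),
                partial(a_function, other_args=args_2),
                another_function)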
This makes it, to my eyes, much better. Testing such a flow would then look like this:
def get_test_data():
    # do something
    return data

def test_my_function():
    data = get_test_data()
    assert my_function(data, args_1) == something  # ideally this is a bit more involved

def test_a_function():
    partial_my_function = lambda df: my_function(df, args_1)
    data = pipe(get_test_data(), partial_my_function)
    assert a_function(data, args_2) == something

def test_another_function():
    partial_my_function = lambda df: my_function(df, args_1)
    partial_a_function = lambda df: a_function(df, args_2)
    data = pipe(get_test_data(), partial_my_function, partial_a_function)
    assert another_function(data) == something

# other tests here
This way, when one of the functions breaks, all successive tests will fail2.
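In case you're wondering, get_test_data can be as simple as building a tiny dataframe in memory (the columns here are made up to match my_function):
def get_test_data():
    # hypothetical schema: id and vehicle, matching what my_function expects
    return spark.createDataFrame(
        [(1, 'bus'), (1, 'car'), (2, 'van')],
        ['id', 'vehicle'],
    )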
Ok, that was a lot of (dummy) code. As always, let me know what you think, especially if you disagree (I’m @gglanzani on Twitter if you want to reach out!).
- Technically (thank you Andrew), this syntax mixes the DataFrame and SQL API. The DataFrame way of doing that is df.select(df.my_column) or df.select(df['my_column']) or df.select(sf.col('my_column')). I still prefer df.select('my_column') as it conveys my intent better. ↩
- You'd still want to write isolated tests, not using the pipeline, in case you introduce two regressions in different parts of the pipeline that cancel their errors out! ↩