Reversing posexplode in SparkSQL

I had a table with many rows per primary key, each row holding an array. I needed to return one row per key containing the averaged array, in SparkSQL.

To average a set of arrays, we just average each position independently, but managing nested types in SQL is notoriously a PITA. I searched around and didn’t find any answers on Stack Overflow or elsewhere on the net that I liked, so I thought I would take a crack at a “pure” SQL approach.
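Here is what “average each position independently” means, as a minimal plain-Python sketch (not Spark code) that computes the target result on toy data. It can be handy for spot-checking the SQL output on a small sample. It assumes rows arrive as (primary_key, array) pairs and that every array for a key has the same length. `elementwise_mean` is a hypothetical name.

```python
from collections import defaultdict


def elementwise_mean(rows):
    """Hypothetical reference helper: average arrays position by position, per key.

    rows: iterable of (key, list_of_numbers); arrays for one key must share a length.
    Returns {key: [mean of position 0, mean of position 1, ...]}.
    """
    grouped = defaultdict(list)
    for key, arr in rows:
        grouped[key].append(arr)
    # zip(*arrays) yields one tuple per position, so each position is averaged on its own
    return {
        key: [sum(column) / len(column) for column in zip(*arrays)]
        for key, arrays in grouped.items()
    }


rows = [("a", [1.0, 2.0, 3.0]), ("a", [3.0, 4.0, 5.0]), ("b", [10.0, 0.0, 6.0])]
print(elementwise_mean(rows))  # {'a': [2.0, 3.0, 4.0], 'b': [10.0, 0.0, 6.0]}
```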

First idea – what if I could explode the values along with their positions, then reassemble them by grouping on the key and the ordinal and averaging the values?
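This plain-Python sketch shows the row shapes that `posexplode` followed by a `group by` on key and ordinal would produce. It is a toy model on in-memory lists, not Spark code, and the helper names are hypothetical. The output is one averaged value per (key, ordinal) pair rather than an array. That gap is the reassembly problem.

```python
from collections import defaultdict


def toy_posexplode(rows):
    """Hypothetical stand-in for LATERAL VIEW posexplode: one (key, ord, val) per element."""
    for key, arr in rows:
        for ord_, val in enumerate(arr):
            yield key, ord_, val


def avg_by_key_and_ord(exploded):
    """Hypothetical stand-in for GROUP BY key, ord with AVG(val)."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for key, ord_, val in exploded:
        sums[(key, ord_)] += val
        counts[(key, ord_)] += 1
    return {k: sums[k] / counts[k] for k in sums}


rows = [("a", [1.0, 2.0]), ("a", [3.0, 6.0]), ("b", [5.0, 7.0])]
long_rows = avg_by_key_and_ord(toy_posexplode(rows))
print(sorted(long_rows.items()))
# [(('a', 0), 2.0), (('a', 1), 4.0), (('b', 0), 5.0), (('b', 1), 7.0)]
```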

Cool idea! The problem is that Spark has no ordered array rollup function (at least none that I found reading the docs and SO)… so what can I do about that? Am I stuck?

I reviewed the SparkSQL function documentation and realized I didn’t have any magic bullets, so I reached back into my SQL hat and asked myself, “How would I force ordering without an ORDER BY?”

  • What about a subquery for each element in the array? A correlated subquery would “work”… in the most disgusting way possible.
  • Well, we could emit a big ol’ case statement I guess…
  • Or wait, isn’t that just what I always tell people to use instead of … PIVOT?
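PIVOT, or a CASE per position, avoids the ordering problem because each ordinal lands in its own fixed, named column. The order rows arrive in never matters. Below is a small plain-Python model of that slotting, as a toy on in-memory rows. It is not Spark code, and `pivot_to_arrays` is a hypothetical name.

```python
def pivot_to_arrays(long_rows, card):
    """Hypothetical model of PIVOT (... FOR ord IN (0, ..., card-1)) then array(`0`, ...).

    long_rows: iterable of ((key, ord), avg_val) in any order.
    Each ordinal is written into its own fixed slot, so input order is irrelevant.
    """
    result = {}
    for (key, ord_), avg_val in long_rows:
        slots = result.setdefault(key, [None] * card)  # missing ordinals stay None (NULL)
        if ord_ < card:  # like PIVOT, values not listed in IN (...) are dropped
            slots[ord_] = avg_val
    return result


shuffled = [(("b", 1), 7.0), (("a", 1), 4.0), (("b", 0), 5.0), (("a", 0), 2.0)]
print(pivot_to_arrays(shuffled, card=2))  # {'b': [5.0, 7.0], 'a': [2.0, 4.0]}
```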

Ok, let’s try this:

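# Assumes `spark` is an active SparkSession (e.g. a Databricks notebook).
# card is read from the first row, so every array must have the same length.
card = spark.sql("select size(array_col) as size from array_table").first()["size"]

print(f"We see the arrays have {card} dimensions.")

# Pivot values (0, 1, ...) and the backtick-quoted column names PIVOT creates for them
cols_as_values = ', '.join(str(x) for x in range(card))
cols_as_cols = ', '.join('`' + str(x) + '`' for x in range(card))

query = f"""
select
  primary_key,
  array({cols_as_cols}) as avg_array /* reassemble the pivoted columns in ordinal order */
from (
  select
    primary_key,
    ord,
    avg(val) as avg_val
  from array_table as t0
  lateral view posexplode(array_col) t1 as ord, val /* one row per element, with its position */
  group by
    primary_key,
    ord /* average each position separately for each key */
) as avg_arr
pivot (
  first_value(avg_val) /* one row per (key, ord), so first_value just picks it */
  as avg_dim_val for ord in ({cols_as_values})
)
order by primary_key
"""

spark.sql(query).show(truncate=False)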
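The “big ol’ case statement” from the list above could also be written as conditional aggregation instead of PIVOT. Below is a minimal sketch that generates that variant as a query string. It is an untested alternative, not the approach described in this post. The function and its default table/column names, copied from the query above, are assumptions. Because `avg` ignores NULLs, `avg(case when ord = i then val end)` averages only position i, so this variant needs neither a separate group by on ord nor a pivot.

```python
def build_case_avg_query(card, table="array_table", key_col="primary_key", array_col="array_col"):
    """Hypothetical builder for a CASE-per-position version of the averaging query.

    Each array slot i is avg(case when ord = i then val end); avg skips the NULLs
    produced for other ordinals. Identifiers are interpolated as-is (no escaping).
    """
    if card < 1:
        raise ValueError("card must be at least 1")
    slots = ",\n    ".join(f"avg(case when ord = {i} then val end)" for i in range(card))
    return (
        "select\n"
        f"  {key_col},\n"
        f"  array(\n    {slots}\n  ) as avg_array\n"
        f"from {table} as t0\n"
        f"lateral view posexplode({array_col}) t1 as ord, val\n"
        f"group by {key_col}\n"
        f"order by {key_col}"
    )


print(build_case_avg_query(3))
```

It would be run the same way as the pivot query, e.g. `spark.sql(build_case_avg_query(card)).show(truncate=False)`, with `card` computed as in the code above.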

For those with Databricks or Databricks Community Edition (free as of this writing), you can also review and fork the notebook here.

Yeah, this is ugly, but it’s significantly faster than similar code running the same steps with a numpy UDF. I need to do more testing to make this claim a bit more solid.