Avoiding, and Preventing, Joins
Joins are expensive and they should be avoided when possible. There are a few ways you can prevent them. In this blog post I'll demonstrate two options: one in R and one in Spark.
The Situation
Let's say that we have a database of users that performed some action on a website that we own. For the purpose of our research, we want to omit users that have had only one interaction with the website. If we take the example dataset below, we will want to remove user c.
user val
   a   3
   a   2
   a   1
   b   5
   b   6
   b   4
   b   1
   c   2
You can do this with a join. But you can avoid it in this situation.
Tidy Nests in R
A standard way to do this in R would be to first create a dataframe that aggregates the number of interactions per user. This dataset can then be joined back to the original dataframe such that it can be filtered.
library(tidyverse)

df <- data.frame(
  user = c("a", "a", "a", "b", "b", "b", "b", "c"),
  val  = c(3, 2, 1, 5, 6, 4, 1, 2)
)

agg <- df %>%
  group_by(user) %>%
  summarise(n = n())

df %>%
  left_join(agg) %>%
  filter(n >= 2) %>%
  select(user, val)
The output will be correct, but because of the join the calculation can become rather expensive when the dataset grows very large. An alternative in R is to use the nest function from the tidyr package (which comes with tidyverse). This nest function allows you to have a dataframe inside of a dataframe. You can see how it works by calling:
> df %>%
    group_by(user) %>%
    nest()
user data
<fctr> <list>
1 a <tibble [3 x 1]>
2 b <tibble [4 x 1]>
3 c <tibble [1 x 1]>
In this form you can now create a new column that queries each dataset. The map function from the purrr package (again, this comes with tidyverse) will help you do just that.
> df %>%
    group_by(user) %>%
    nest() %>%
    unnest(n = map(data, ~nrow(.)))
user data n
<fctr> <list> <int>
1 a <tibble [3 x 1]> 3
2 b <tibble [4 x 1]> 4
3 c <tibble [1 x 1]> 1
There is no need for a join anymore; the only thing we need to do is unnest the data column.
> df %>%
    group_by(user) %>%
    nest() %>%
    unnest(n = map(data, ~nrow(.))) %>%
    unnest(data)
user n val
<fctr> <int> <dbl>
1 a 3 3
2 a 3 2
3 a 3 1
4 b 4 5
5 b 4 6
6 b 4 4
7 b 4 1
8 c 1 2
You can wrap everything up by filtering on n and removing the columns you don't need anymore.
> df %>%
    group_by(user) %>%
    nest() %>%
    unnest(n = map(data, ~nrow(.))) %>%
    unnest(data) %>%
    filter(n >= 2) %>%
    select(user, val)
user val
<fctr> <dbl>
1 a 3
2 a 2
3 a 1
4 b 5
5 b 6
6 b 4
7 b 1
Obviously there is a much better way to do this in R though: dplyr can calculate the count per group directly.
> df %>%
    group_by(user) %>%
    mutate(n = n()) %>%
    ungroup()
user val n
<fctr> <dbl> <int>
1 a 3 3
2 a 2 3
3 a 1 3
4 b 5 4
5 b 6 4
6 b 4 4
7 b 1 4
8 c 2 1
> df %>%
    group_by(user) %>%
    filter(n() >= 2) %>%
    ungroup()
user val
<fctr> <dbl>
1 a 3
2 a 2
3 a 1
4 b 5
5 b 6
6 b 4
7 b 1
Gotta love that dplyr.
Window Partitions in Spark
The tidyr workflow is awesome but it won't work everywhere. Spark does not have support for a dataframe-in-a-column, so we need to do it some other way. The trick is to use window functions, where we partition the data by user.
Let's first create the dataframe in pyspark.
import pandas as pd

df = pd.DataFrame({
    "user": ["a", "a", "a", "b", "b", "b", "b", "c"],
    "val":  [3, 2, 1, 5, 6, 4, 1, 2]
})

# sqlCtx is the SQLContext that comes predefined in the pyspark shell;
# in newer Spark versions you would use a SparkSession instead
ddf = sqlCtx.createDataFrame(df)
We can confirm that this is the same dataframe.
> ddf.show()
+----+---+
|user|val|
+----+---+
| a| 3|
| a| 2|
| a| 1|
| b| 5|
| b| 6|
| b| 4|
| b| 1|
| c| 2|
+----+---+
Again, the naive way to filter users is to use a join. You could do this via the code below.
from pyspark.sql import functions as sf
agg = (ddf
    .groupBy("user")
    .agg(sf.count("user").alias("n")))

res = (ddf
    .join(agg, ddf.user == agg.user, "left")
    .filter(sf.col("n") >= 2)
    .select(ddf.user, ddf.val))
This new res dataframe keeps only the rows we want; user c has been filtered out.
> res.show()
+----+---+
|user|val|
+----+---+
| a| 3|
| a| 2|
| a| 1|
| b| 5|
| b| 6|
| b| 4|
| b| 1|
+----+---+
To prevent the join we need to define a window function, which allows us to apply a function over a partition of the data.
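In pyspark such a window specification can be created by partitioning on the user column.
from pyspark.sql.window import Window

# all rows that share the same user end up in the same partition
window_spec = Window.partitionBy("user")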
This window_spec can be used with many functions in Spark: min, max, sum and also count. You can see it in action via:
> (ddf
    .withColumn("n", sf.count(sf.col("user")).over(window_spec))
    .show())
+----+---+---+
|user|val| n|
+----+---+---+
| a| 3| 3|
| a| 2| 3|
| a| 1| 3|
| b| 5| 4|
| b| 6| 4|
| b| 4| 4|
| b| 1| 4|
| c| 2| 1|
+----+---+---+
This new n column can be used for filtering, just like we saw in R.
> (ddf
    .withColumn("n", sf.count(sf.col("user")).over(window_spec))
    .filter(sf.col("n") >= 2)
    .select("user", "val")
    .show())
+----+---+
|user|val|
+----+---+
| a| 3|
| a| 2|
| a| 1|
| b| 5|
| b| 6|
| b| 4|
| b| 1|
+----+---+
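The same window_spec works for the other functions mentioned earlier as well. As a quick sketch (the column names total_val and max_val are just illustrative), you can add the per-user sum and maximum of val without a join either:
> (ddf
    .withColumn("total_val", sf.sum(sf.col("val")).over(window_spec))
    .withColumn("max_val", sf.max(sf.col("val")).over(window_spec))
    .show())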
Conclusion
You can't always prevent joins, but this is a use case where you definitely don't need them. Joins are error-prone, expensive to calculate and require extra tables to be created. You can apply this trick on smaller datasets in R or even on larger datasets via Spark. Please prevent them where applicable.