
Avoiding, and Preventing, Joins

Joins are expensive and they should be avoided when possible. There are a few ways to prevent them. In this blog post I'll demonstrate two options: one in R and one in Spark.

The Situation

Let's say that we have a database of users who performed some action on a website that we own. For the purpose of our research, we want to omit users who have interacted with the website only once. In the example dataset below, that means removing user c:

  user val
1    a   3
2    a   2
3    a   1
4    b   5
5    b   6
6    b   4
7    b   1
8    c   2

You could do this with a join, but in this situation you can avoid it.

Tidy Nests in R

A standard way to do this in R would be to first create a dataframe that aggregates the number of interactions per user. This dataset can then be joined back to the original dataframe such that it can be filtered.

library(tidyverse)

df <- data.frame(
  user = c("a", "a", "a", "b", "b", "b", "b", "c"),
  val  = c(3, 2, 1, 5, 6, 4, 1, 2)
)

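# aggregate table: number of interactions per user
agg <- df %>% 
  group_by(user) %>% 
  summarise(n = n())

# join the counts back onto every row, then filter
df %>% 
  left_join(agg, by = "user") %>% 
  filter(n >= 2) %>% 
  select(user, val)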
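For comparison, here is the same aggregate-then-join pattern as a plain-Python sketch. It assumes the data fits in memory as a list of (user, val) tuples; the names rows and filter_with_join are hypothetical, and a dictionary lookup stands in for the left join.

```python
from collections import Counter

# Example data from the post: one (user, val) tuple per interaction.
rows = [("a", 3), ("a", 2), ("a", 1),
        ("b", 5), ("b", 6), ("b", 4), ("b", 1),
        ("c", 2)]


def filter_with_join(rows, min_n=2):
    """Hypothetical helper: aggregate, join the counts back, then filter."""
    # Step 1: the separate aggregate table (plays the role of `agg`).
    agg = Counter(user for user, _ in rows)
    # Step 2: the "left join": look up each row's user in the aggregate table.
    joined = [(user, val, agg[user]) for user, val in rows]
    # Step 3: keep users with at least min_n interactions, drop the n column.
    return [(user, val) for user, val, n in joined if n >= min_n]


print(filter_with_join(rows))
```

The separate agg table is exactly the intermediate result the rest of this post tries to get rid of.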

The output will be correct, but because of the join the calculation can become expensive when the dataset is very large. An alternative in R is to use the nest function from the tidyr package (which comes with the tidyverse). The nest function allows you to have a dataframe inside a dataframe. You can see how it works by calling:

> df %>% 
  group_by(user) %>% 
  nest() 

    user             data
  <fctr>           <list>
1      a <tibble [3 x 1]>
2      b <tibble [4 x 1]>
3      c <tibble [1 x 1]>

In this form you can create a new column that queries each nested dataset. The map family of functions from the purrr package (again, part of the tidyverse) will help you do just that.

> df %>% 
  group_by(user) %>% 
  nest() %>% 
  mutate(n = map_int(data, nrow))

    user             data     n
  <fctr>           <list> <int>
1      a <tibble [3 x 1]>     3
2      b <tibble [4 x 1]>     4
3      c <tibble [1 x 1]>     1
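The nest-then-count idea can be sketched in plain Python as well. This is only an analogy, not how tidyr is implemented: a dict of lists plays the role of the nested data column and len() plays the role of nrow(). The helper nest() here is hypothetical.

```python
from collections import defaultdict

# Example data from the post: one (user, val) tuple per interaction.
rows = [("a", 3), ("a", 2), ("a", 1),
        ("b", 5), ("b", 6), ("b", 4), ("b", 1),
        ("c", 2)]


def nest(rows):
    """Hypothetical stand-in for group_by(user) %>% nest().

    Each user maps to the list of that user's values, which plays the
    role of the nested tibble in the `data` column.
    """
    nested = defaultdict(list)
    for user, val in rows:
        nested[user].append(val)
    return dict(nested)


nested = nest(rows)
# Query each nested group for its size, like mutate(n = map_int(data, nrow)).
counts = {user: len(vals) for user, vals in nested.items()}
print(nested)
print(counts)
```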

There is no need for a join anymore; the only thing left to do is to unnest the data column.

> df %>% 
  group_by(user) %>% 
  nest() %>% 
  mutate(n = map_int(data, nrow)) %>% 
  unnest(data)

    user   val     n
  <fctr> <dbl> <int>
1      a     3     3
2      a     2     3
3      a     1     3
4      b     5     4
5      b     6     4
6      b     4     4
7      b     1     4
8      c     2     1

You can wrap everything up by filtering on n and removing the columns you no longer need.

> df %>% 
  group_by(user) %>% 
  nest() %>% 
  mutate(n = map_int(data, nrow)) %>% 
  unnest(data) %>% 
  filter(n >= 2) %>% 
  select(user, val)

    user   val
  <fctr> <dbl>
1      a     3
2      a     2
3      a     1
4      b     5
5      b     6
6      b     4
7      b     1
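Put together, nest, count, unnest and filter can be sketched in plain Python as one pass over the groups. This sketch assumes the same in-memory list of tuples, and filter_via_nesting is a hypothetical name. Because the output is emitted group by group, it only matches the input order when the input is already sorted by user, as the example data is.

```python
from collections import defaultdict

# Example data from the post: one (user, val) tuple per interaction.
rows = [("a", 3), ("a", 2), ("a", 1),
        ("b", 5), ("b", 6), ("b", 4), ("b", 1),
        ("c", 2)]


def filter_via_nesting(rows, min_n=2):
    """Hypothetical helper: nest per user, count, unnest, filter. No join."""
    # Nest: one list of values per user (dicts keep insertion order).
    nested = defaultdict(list)
    for user, val in rows:
        nested[user].append(val)
    # Count + unnest: every value is emitted together with its group's size.
    unnested = [(user, val, len(vals))
                for user, vals in nested.items()
                for val in vals]
    # Filter on the group size and drop the helper column.
    return [(user, val) for user, val, n in unnested if n >= min_n]


print(filter_via_nesting(rows))
```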

There is, however, a much simpler way in R.

> df %>% 
  group_by(user) %>% 
  mutate(n = n()) %>% 
  ungroup()

   user   val     n
  <chr> <dbl> <int>
1     a     3     3
2     a     2     3
3     a     1     3
4     b     5     4
5     b     6     4
6     b     4     4
7     b     1     4
8     c     2     1

> df %>% 
  group_by(user) %>% 
  filter(n() >= 2) %>% 
  ungroup()

    user   val
  <fctr> <dbl>
1      a     3
2      a     2
3      a     1
4      b     5
5      b     6
6      b     4
7      b     1
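The grouped filter translates to a short Python sketch too: compute each group's size once and use it directly in the row predicate, without building a joined table or keeping a helper column. As before, filter_by_group_size is a hypothetical helper working on an in-memory list of (user, val) tuples. Unlike the nesting sketch, it keeps the rows in their original order.

```python
from collections import Counter

# Example data from the post: one (user, val) tuple per interaction.
rows = [("a", 3), ("a", 2), ("a", 1),
        ("b", 5), ("b", 6), ("b", 4), ("b", 1),
        ("c", 2)]


def filter_by_group_size(rows, min_n=2):
    """Hypothetical helper: keep rows whose user appears at least min_n times.

    The group size feeds the predicate directly, in the spirit of
    filter(n() >= 2) on grouped data; rows stay in their original order.
    """
    group_size = Counter(user for user, _ in rows)
    return [(user, val) for user, val in rows if group_size[user] >= min_n]


print(filter_by_group_size(rows))
```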

Gotta love that dplyr.

Window Partitions in Spark

The tidyr workflow is awesome, but it won't work everywhere. Spark has no direct equivalent of a dataframe-in-a-column, so we need to do it some other way. The trick is to use a window function that partitions the data by user.

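Let's first create the dataframe in PySpark.

import pandas as pd
from pyspark.sql import SparkSession

# The original post used an existing SQLContext named `sqlCtx`;
# a SparkSession is the current entry point.
spark = SparkSession.builder.getOrCreate()

df = pd.DataFrame({
    "user": ["a", "a", "a", "b", "b", "b", "b", "c"],
    "val" : [3, 2, 1, 5, 6, 4, 1, 2]
})
ddf = spark.createDataFrame(df)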

We can confirm that this is the same dataframe.

> ddf.show()

+----+---+
|user|val|
+----+---+
|   a|  3|
|   a|  2|
|   a|  1|
|   b|  5|
|   b|  6|
|   b|  4|
|   b|  1|
|   c|  2|
+----+---+

Again, the naive way to filter users is to use a join. You could do this via the code below.

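from pyspark.sql import functions as sf

# aggregate table: number of interactions per user
agg = (ddf
  .groupBy("user")
  .agg(sf.count("user").alias("n")))

# joining on the column name keeps a single, unambiguous `user` column
res = (ddf
  .join(agg, on="user", how="left")
  .filter(sf.col("n") >= 2)
  .select("user", "val"))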

The resulting res dataframe keeps exactly the rows we want.

> res.show()

+----+---+
|user|val|
+----+---+
|   a|  3|
|   a|  2|
|   a|  1|
|   b|  5|
|   b|  6|
|   b|  4|
|   b|  1|
+----+---+

To prevent the join, we define a window specification that lets us apply a function over a partition of the data.

from pyspark.sql import Window 
window_spec = Window.partitionBy(ddf.user)

This window_spec can be combined with many aggregate functions in Spark, such as min, max, sum and count. You can see it in action via:

> (ddf
  .withColumn("n", sf.count(sf.col("user")).over(window_spec))
  .show())

+----+---+---+
|user|val|  n|
+----+---+---+
|   a|  3|  3|
|   a|  2|  3|
|   a|  1|  3|
|   b|  5|  4|
|   b|  6|  4|
|   b|  4|  4|
|   b|  1|  4|
|   c|  2|  1|
+----+---+---+

This new n column can be used for filtering, just like we saw in R.

> (ddf
  .withColumn("n", sf.count(sf.col("user")).over(window_spec))
  .filter(sf.col("n") >= 2)
  .select("user", "val")
  .show())

+----+---+
|user|val|
+----+---+
|   a|  3|
|   a|  2|
|   a|  1|
|   b|  5|
|   b|  6|
|   b|  4|
|   b|  1|
+----+---+

Conclusion

You can't always prevent joins, but this is a use case where you definitely don't need them. Joins are error-prone, expensive to compute and require extra tables to be created. You can apply this trick to smaller datasets in R or to larger datasets in Spark. Please avoid joins where you can.