In this lesson, we'll learn how to solve the very useful union-find problem.

Other resources on the same topic are available online, for example on TopCoder and Wikipedia.

The Motivation

We will be motivated by the following problem (which we'll call the friendship problem):

The Friendship Problem. There are $n$ people in a room, and you don't know who is friends with whom. However, from time to time, you obtain a piece of information of the following form:

- Person $a$ is friends with person $b$.

From time to time, you would also like to answer questions of the following form:

- Is person $a$ friends with person $b$?

Since we only have partial information at any point, the answer to "is person $a$ friends with person $b$?" can be either "definitely yes" or "we're not sure".

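Deductions about friendship can be made using the following rules (which we may call the principles of friendship):

- If $a$ is friends with $b$, then $b$ is friends with $a$.
- If $a$ is friends with $b$, and $b$ is friends with $c$, then $a$ is friends with $c$.

Never mind that friendship is complicated and these are not necessarily true in real life! Actually, let's also add the following rule to make things even simpler:

- Everyone is friends with themselves.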

Of course, we don't have to limit ourselves to friendships. One can replace the relation "$a$ is friends with $b$" with anything else that satisfies the properties above. For example:

(Verify that these indeed satisfy the properties above.)

As long as the relation satisfies the properties above, we basically have the same problem. Relations satisfying the above three properties are called equivalence relations. The nice thing about such relations is that they partition the objects into disjoint sets such that two objects are related if and only if they belong to the same set. (It is a good exercise to show why.) For example, in the friendship problem, suppose there are $n = 6$ people, and we know that $1$ is friends with $4$, $2$ is friends with $6$, and $3$ is friends with $6$. Then our set of people is partitioned as $$\{\{1, 4\}, \{2, 3, 6\}, \{5\}\}.$$ We can also infer that $2$ is friends with $3$, even though this wasn't stated explicitly.

The Union-Find Problem

Let's consider an abstract version of the problem. (Doing so will enable us to answer any of the variants of the problem as described above.)

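The Union-Find Problem. We are given the set $$\{0, 1, 2, \ldots, n-1\}.$$ Initially, it is partitioned into $n$ sets such that each number belongs to its own set. In other words, the partition is $$\{\{0\}, \{1\}, \{2\}, \ldots, \{n-1\}\}.$$ You need to implement two kinds of operations:

- $\mathrm{union}(a, b)$: merge the set containing $a$ and the set containing $b$ into a single set.
- $\mathrm{find}(a)$: return a representative of the set containing $a$.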

A representative of a set is anything that will allow us to identify that set. We can use this operation, for example, to find out whether $a$ and $b$ belong to the same set (iff $\mathrm{find}(a) = \mathrm{find}(b)$). Note that when we perform a union operation, the partition may change, so the representatives may change as well. Thus, the value returned by the find operation is only useful before we perform the next union.

A solution to this problem involves a data structure that can support the required operations correctly. It should now be clear how to use such a solution to solve the friendship problem (and the related ones).
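To make this concrete, here is a minimal sketch of the friendship problem built on top of a union-find. The wrapper `FriendshipTracker` and its methods `learn` and `definitelyFriends` are hypothetical names chosen for illustration, and the sketch assumes the union-find type offers `find(a)` and `union_(a, b)` as in the implementations in this lesson. The tiny `NaiveUnionFind` is included only so the snippet compiles on its own; it relabels a whole set on every union, and any of the faster structures below could be plugged in instead.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Deliberately naive union-find, only so this sketch is self-contained:
// label[x] names the set containing x; union relabels every member of one set.
struct NaiveUnionFind {
    vector<int> label;
    NaiveUnionFind(int n) : label(n) {
        for (int i = 0; i < n; i++) label[i] = i;
    }
    int find(int a) { return label[a]; }
    void union_(int a, int b) {
        int la = find(a), lb = find(b);
        if (la == lb) return;
        for (int& x : label)
            if (x == la) x = lb; // move everyone labeled la into set lb
    }
};

// Hypothetical wrapper: answers the friendship problem with any union-find
// type that provides find() and union_().
template <typename UF>
struct FriendshipTracker {
    UF uf;
    FriendshipTracker(int n) : uf(n) {}
    // record the information "a is friends with b"
    void learn(int a, int b) { uf.union_(a, b); }
    // true means "definitely yes"; false means "we're not sure"
    bool definitelyFriends(int a, int b) { return uf.find(a) == uf.find(b); }
};
```

With the earlier example (people $1$ to $6$, where $1$ and $4$, $2$ and $6$, and $3$ and $6$ are reported as friends), `definitelyFriends(2, 3)` returns true even though that pair was never reported directly.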

Actually, a slightly different version of the problem, called the Union-Find-Make problem, is sometimes seen. It's defined as follows:

The Union-Find-Make problem. Initially, we have an empty set. You need to implement three kinds of operations:

The main differences with the union-find problem are that:

Nonetheless, these two problems are very similar in the sense that one can use the solution to one problem to solve the other, and in fact, most of the solutions we'll discuss can easily be modified to solve either version. Thus we'll only focus on one of them, namely the union-find problem.

Exercise. Given a solution to the union-find-make problem, use it to create a solution to the union-find problem.

Exercise. Given a solution to the union-find problem, use it to create a solution to the union-find-make problem, assuming you know beforehand an upper bound, say $u$, on the number of make operations.

Let's take an example instance of the problem. Suppose $n = 11$. Thus, initially, the partition looks like $$\{\{0\}, \{1\}, \{2\}, \{3\}, \{4\}, \{5\}, \{6\}, \{7\}, \{8\}, \{9\}, \{10\}\}.$$

Let's consider a bunch of union operations. Suppose we call $\mathrm{union}(2, 7)$. Then the partition will look like $$\{\{0\}, \{1\}, \{2, 7\}, \{3\}, \{4\}, \{5\}, \{6\}, \{8\}, \{9\}, \{10\}\}.$$

Next, suppose we call $\mathrm{union}(3, 5)$. Then the partition will look like $$\{\{0\}, \{1\}, \{2, 7\}, \{3, 5\}, \{4\}, \{6\}, \{8\}, \{9\}, \{10\}\}.$$

Next, suppose we call $\mathrm{union}(5, 10)$. Then the partition will look like $$\{\{0\}, \{1\}, \{2, 7\}, \{3, 5, 10\}, \{4\}, \{6\}, \{8\}, \{9\}\}.$$

Next, suppose we call $\mathrm{union}(3, 10)$. Note that $3$ and $10$ already belong to the same set, so nothing needs to be done.

Next, suppose we call $\mathrm{union}(6, 5)$. Then the partition will look like $$\{\{0\}, \{1\}, \{2, 7\}, \{3, 5, 6, 10\}, \{4\}, \{8\}, \{9\}\}.$$

At any point during this sequence of operations, we can call $\mathrm{find}(a)$, and it should return a representative/identifier for the set containing $a$. What this representative is depends on the implementation.

Approach 1: List of Lists

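The above example actually suggests a simple solution: implement the partition as a list of lists $L[0], L[1], \ldots, L[n-1]$! This is as straightforward as you can get. We just need to figure out how to implement the operations:

- $\mathrm{find}(a)$: scan the lists until we find the list $L[i]$ that contains $a$, and return its index $i$.
- $\mathrm{union}(a, b)$: let $i = \mathrm{find}(a)$ and $j = \mathrm{find}(b)$. If $i \ne j$, append all elements of $L[i]$ to $L[j]$, then empty $L[i]$.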

Here's an implementation of the above idea in C++. We will use a vector<int> to represent each list $L[i]$. Remember that a vector<int> is a dynamic array of ints.

```cpp
#include <vector>
using namespace std;

struct UnionFind {
    int n;
    vector<int>* L;
    UnionFind(int n) {
        this->n = n;
        L = new vector<int>[n];
        for (int i = 0; i < n; i++) {
            L[i] = vector<int>();
            L[i].push_back(i); // initially, L[i] contains only i
        }
    }
    int find(int a) {
        // loop through all lists
        for (int i = 0; i < n; i++) {
            // check if L[i] contains a
            for (int x = 0; x < L[i].size(); x++) {
                if (L[i][x] == a) {
                    return i;
                }
            }
        }
        return -1; // not reached when 0 <= a < n
    }
    // we use 'union_' since 'union' is a keyword in C++
    void union_(int a, int b) {
        int i = find(a);
        int j = find(b);
        if (i == j) return; // do nothing if they are in the same list
        // add all elements of L[i] to L[j]...
        for (int x = 0; x < L[i].size(); x++) {
            L[j].push_back(L[i][x]);
        }
        L[i].clear(); // ...and empty L[i]
    }
};
```

Here's some sample usage.

```cpp
#include <iostream> // for cout; relies on 'using namespace std' from above

int main() {
    UnionFind uf = UnionFind(11);
    uf.union_(2, 7);
    uf.union_(3, 5);
    uf.union_(5, 10);
    uf.union_(3, 10);
    if (uf.find(6) == uf.find(3)) {
        cout << "6 and 3 belong to the same set!" << endl;
    } else {
        cout << "6 and 3 do not belong to the same set." << endl;
    }
    uf.union_(6, 5);
    if (uf.find(6) == uf.find(3)) {
        cout << "6 and 3 belong to the same set!" << endl;
    } else {
        cout << "6 and 3 do not belong to the same set." << endl;
    }
}
```

While this works, it is quite slow: every find operation may have to go through the whole list of lists just to locate $a$. Our goal, of course, is to make everything as efficient as possible.

In terms of Big-O, we find that each find or union operation takes $O(n)$ in the worst case. Thus, performing $q$ operations takes $O(qn)$ time.

We can optimize the solution above by also maintaining, for every element, the index of the list that currently contains it; that way, a find operation becomes a single lookup. For example:

```cpp
struct UnionFind {
    int n;
    vector<int>* L;
    int* location;
    UnionFind(int n) {
        this->n = n;
        L = new vector<int>[n];
        location = new int[n];
        for (int i = 0; i < n; i++) {
            L[i] = vector<int>();
            L[i].push_back(i);
            location[i] = i; // initialize location[i] as i
        }
    }
    int find(int a) {
        return location[a]; // just return the location of a
    }
    void union_(int a, int b) {
        int i = find(a);
        int j = find(b);
        if (i == j) return;
        for (int x = 0; x < L[i].size(); x++) {
            L[j].push_back(L[i][x]);
            location[L[i][x]] = j; // set the new location of L[i][x]
        }
        L[i].clear();
    }
};
```

Much better! Now, find is much faster: it now takes $O(1)$. However, union still takes $O(n)$ time, so $q$ operations still take $O(qn)$ in the worst case. To trigger this worst case, one can do the following series of unions: $\mathrm{union}(0, 1), \mathrm{union}(1, 2), \mathrm{union}(2, 3), \mathrm{union}(3, 4), \ldots$.
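To see the quadratic behavior concretely, here is a sketch that adds a hypothetical `moves` counter to the location-based structure above (using `vector`s instead of raw arrays, purely for brevity) and replays that chain of unions. The call $\mathrm{union}(k, k+1)$ moves the $k+1$ elements of $\{0, \ldots, k\}$, so the total is $1 + 2 + \cdots + (n-1) = \frac{n(n-1)}{2}$ element moves for just $n - 1$ unions.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// The location-based list-of-lists structure, plus a hypothetical counter
// `moves` recording how many elements union_ has copied so far.
struct CountingUnionFind {
    vector<vector<int>> L;
    vector<int> location;
    long long moves = 0;
    CountingUnionFind(int n) : L(n), location(n) {
        for (int i = 0; i < n; i++) {
            L[i].push_back(i);
            location[i] = i;
        }
    }
    int find(int a) { return location[a]; }
    void union_(int a, int b) {
        int i = find(a), j = find(b);
        if (i == j) return;
        for (int x : L[i]) { // always moves L[i], regardless of its length
            L[j].push_back(x);
            location[x] = j;
            moves++;
        }
        L[i].clear();
    }
};

// Runs union(0, 1), union(1, 2), ..., union(n-2, n-1); returns total moves.
long long chainMoves(int n) {
    CountingUnionFind uf(n);
    for (int k = 0; k + 1 < n; k++) uf.union_(k, k + 1);
    return uf.moves;
}
```

For example, `chainMoves(1000)` should return $499500 = \frac{1000 \cdot 999}{2}$.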

There's one way to partially optimize the union. Note that in the union operation above, we only loop over one of the two lists $L[i]$ and $L[j]$, but it doesn't matter which one we choose; thus, it makes sense to loop over the shorter one! In other words, we can always append the shorter list to the longer list.

Here's an implementation of this idea:

```cpp
void union_(int a, int b) {
    int i = find(a);
    int j = find(b);
    if (i == j) return;
    if (L[i].size() < L[j].size()) {
        // L[i] is shorter: move its elements into L[j]
        for (int x = 0; x < L[i].size(); x++) {
            L[j].push_back(L[i][x]);
            location[L[i][x]] = j;
        }
        L[i].clear();
    } else {
        // L[j] is shorter (or equal in length): move its elements into L[i]
        for (int x = 0; x < L[j].size(); x++) {
            L[i].push_back(L[j][x]);
            location[L[j][x]] = i;
        }
        L[j].clear();
    }
}
```

There's some duplication in this code. There should be a better implementation. Indeed, here is one:

```cpp
void union_(int a, int b) {
    int i = find(a);
    int j = find(b);
    if (i == j) return;
    // swap (from <utility>) to ensure that L[i] is not longer than L[j]
    if (L[i].size() > L[j].size()) swap(i, j);
    for (int x = 0; x < L[i].size(); x++) {
        L[j].push_back(L[i][x]);
        location[L[i][x]] = j;
    }
    L[i].clear();
}
```

This swaps i and j to ensure that $L[i]$ is shorter than (or equal in length to) $L[j]$.

Of course, a single union operation is still $O(n)$ in the worst case. However, amazingly, a more careful analysis shows that $q$ operations take $O(n + q \log n)$ in the worst case. This simple change leads to a huge improvement in running time, and it actually gives us a good-enough solution!

We will not prove right now why it takes $O(n + q \log n)$; we will do so when we learn about amortized analysis.
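While the proof is deferred, a sketch can at least illustrate the observation usually behind it: with shorter-into-longer merging, whenever an element moves, the list it lands in is at least twice as long as the list it left, so no element moves more than $\lfloor \log_2 n \rfloor$ times. The struct below is the improved list-of-lists structure (with `vector`s) plus a hypothetical `timesMoved` counter per element.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Shorter-into-longer list-of-lists with a hypothetical per-element counter.
struct SmallToLargeUnionFind {
    vector<vector<int>> L;
    vector<int> location;
    vector<int> timesMoved; // timesMoved[x]: how often x changed lists
    SmallToLargeUnionFind(int n) : L(n), location(n), timesMoved(n, 0) {
        for (int i = 0; i < n; i++) {
            L[i].push_back(i);
            location[i] = i;
        }
    }
    int find(int a) { return location[a]; }
    void union_(int a, int b) {
        int i = find(a), j = find(b);
        if (i == j) return;
        if (L[i].size() > L[j].size()) swap(i, j); // L[i] is now the shorter list
        for (int x : L[i]) {
            L[j].push_back(x);
            location[x] = j;
            timesMoved[x]++;
        }
        L[i].clear();
    }
};
```

For example, with $n = 1024$, merging pairs, then pairs of pairs, and so on, element $0$ is moved in every one of the $10$ rounds, which matches $\log_2 1024 = 10$; on the chain $\mathrm{union}(0, 1), \mathrm{union}(1, 2), \ldots$ each union now moves just one element.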

Approach 2: Forests

It turns out that we can do much better: there's a solution to the union-find problem that's faster and much easier to implement! Thus, you can forget about the previous solution.

The solution involves forming a forest among the $n$ objects. Specifically, our $n$ objects will be organized into a forest of rooted trees such that each rooted tree represents a set in our partition.

It turns out that for this implementation to work, we only need to remember the parent of each object in the tree. For simplicity in the implementation, we say that the parent of a root node is itself.

To get a representative of each set, the most natural choice is to simply use the root of the tree. Thus, to implement $\mathrm{find}(a)$, one simply has to go up the tree starting from $a$ until we reach the root. To perform a union, first find the roots of the two trees to unite, and then simply set one of them to be the parent of the other!

Here's an implementation:

```cpp
struct UnionFind {
    int n;
    int* parent;
    UnionFind(int n) {
        this->n = n;
        parent = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
    }
    int find(int a) {
        // go up the tree by following the "parent" link
        while (parent[a] != a) {
            a = parent[a];
        }
        return a;
    }
    void union_(int a, int b) {
        int ra = find(a);
        int rb = find(b);
        if (ra == rb) return;
        // point ra to rb
        parent[ra] = rb;
    }
};
```

This is clearly much easier to implement!

Let's see this thing in action. Consider $n = 11$ again. Initially, the forest consists of $11$ single-node trees, since every node is its own parent.

The parent array looks like $$\begin{array}{r|rrrrrrrrrrr} i & 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 \\ \hline \mathrm{parent}[i] & 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}.$$

Now, suppose we call $\mathrm{union}(2, 7)$. Since $2$ and $7$ are both roots, node $2$ simply becomes a child of node $7$.

The parent array looks like $$\begin{array}{r|rrrrrrrrrrr} i & 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 \\ \hline \mathrm{parent}[i] & 0 & 1 & 7 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}.$$

Next, suppose we call $\mathrm{union}(3, 5)$. Similarly, node $3$ becomes a child of node $5$.

The parent array looks like $$\begin{array}{r|rrrrrrrrrrr} i & 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 \\ \hline \mathrm{parent}[i] & 0 & 1 & 7 & 5 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}.$$

Next, suppose we call $\mathrm{union}(3, 2)$. This time $\mathrm{find}(3) = 5$ and $\mathrm{find}(2) = 7$, so node $5$ becomes a child of node $7$.

The parent array looks like $$\begin{array}{r|rrrrrrrrrrr} i & 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 \\ \hline \mathrm{parent}[i] & 0 & 1 & 7 & 5 & 4 & 7 & 6 & 7 & 8 & 9 & 10\end{array}.$$

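Notice that even though we called union on $2$ and $3$, neither of them is a root, so only the roots of their trees are updated: $\mathrm{parent}[5]$ changes, while $\mathrm{parent}[2]$ and $\mathrm{parent}[3]$ stay the same.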
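The trace above can be checked mechanically. This sketch is the basic forest implementation with a `vector<int>` in place of the raw array (an assumption made only for brevity); replaying $\mathrm{union}(2, 7)$, $\mathrm{union}(3, 5)$, $\mathrm{union}(3, 2)$ should leave `parent` equal to the last table, $(0, 1, 7, 5, 4, 7, 6, 7, 8, 9, 10)$, with $\mathrm{find}(3)$ walking $3 \to 5 \to 7$.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Basic forest union-find: parent[x] == x means x is a root.
struct ForestUnionFind {
    vector<int> parent;
    ForestUnionFind(int n) : parent(n) {
        for (int i = 0; i < n; i++) parent[i] = i;
    }
    int find(int a) {
        while (parent[a] != a) a = parent[a]; // climb to the root
        return a;
    }
    void union_(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        parent[ra] = rb; // only the two roots are touched
    }
};
```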

There's an alternative implementation of the find operation using recursion:

```cpp
int find(int a) {
    if (parent[a] == a) {
        return a;
    } else {
        return find(parent[a]);
    }
}
```

One can even write the operations as one-liners:

```cpp
int find(int a) {
    return parent[a] == a ? a : find(parent[a]);
}
void union_(int a, int b) {
    parent[find(a)] = find(b); // harmless if both finds return the same root
}
```

Short and sweet!

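Unfortunately, the above implementation is not really that efficient: in the worst case, $q$ operations can still take $O(qn)$ time. For example, the unions $\mathrm{union}(0, 1), \mathrm{union}(1, 2), \mathrm{union}(2, 3), \ldots$ build a single path of length $n - 1$, and every subsequent $\mathrm{find}(0)$ must walk the whole path. We need to find some improvements.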
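Here is a sketch of that worst case, assuming the basic forest implementation with a hypothetical `steps` counter that records how many parent links `find` follows. Building the chain costs no steps, because every union is called on two roots, but afterwards each $\mathrm{find}(0)$ walks all $n - 1$ links, so $q$ such calls cost $q(n-1)$ steps.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Basic forest union-find plus a hypothetical `steps` counter.
struct StepCountingForest {
    vector<int> parent;
    long long steps = 0; // parent links followed by find so far
    StepCountingForest(int n) : parent(n) {
        for (int i = 0; i < n; i++) parent[i] = i;
    }
    int find(int a) {
        while (parent[a] != a) {
            a = parent[a];
            steps++;
        }
        return a;
    }
    void union_(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        parent[ra] = rb; // no balancing: ra's whole tree hangs below rb
    }
};
```

For example, with $n = 1000$, building the chain and then calling $\mathrm{find}(0)$ $1000$ times should follow exactly $1000 \cdot 999$ links.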

Improvement 1

First, we can use the same kind of improvement as in the list-of-lists solution: make the root of the smaller tree point to the root of the larger tree. This optimization is called union by weight, and it also gives us an $O(n + q \log n)$ solution.

Alternatively, since the cost of a find depends on how far the node is from its root, the height of a tree matters more than its weight. So we can compare heights instead of weights, i.e., point the root of the shorter tree to the root of the taller one. This also yields $O(n + q \log n)$.

Here's an implementation of union by height. Here, we store the heights in a separate array called height.

```cpp
struct UnionFind {
    int n;
    int* parent;
    int* height;
    UnionFind(int n) {
        this->n = n;
        parent = new int[n];
        height = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            height[i] = 0;
        }
    }
    int find(int a) {
        return parent[a] == a ? a : find(parent[a]);
    }
    void union_(int a, int b) {
        int ra = find(a);
        int rb = find(b);
        if (ra == rb) return;
        // point to the taller tree
        if (height[ra] < height[rb]) {
            parent[ra] = rb;
        } else if (height[ra] > height[rb]) {
            parent[rb] = ra;
        } else {
            // arbitrarily choose the new root
            parent[ra] = rb;
            height[rb]++;
        }
    }
};
```
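The text describes union by weight but only shows code for union by height. A sketch of the weight version follows; `UnionFindBySize`, its `size` array, and the `depth` helper are assumptions for illustration, with `size[r]` meaningful only when `r` is a root. The lighter tree's root is attached under the heavier tree's root, and the weights are added.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Union by weight: size[r] = number of nodes in the tree rooted at r.
struct UnionFindBySize {
    vector<int> parent, size;
    UnionFindBySize(int n) : parent(n), size(n, 1) {
        for (int i = 0; i < n; i++) parent[i] = i;
    }
    int find(int a) {
        while (parent[a] != a) a = parent[a];
        return a;
    }
    void union_(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        if (size[ra] > size[rb]) swap(ra, rb); // ra now roots the lighter tree
        parent[ra] = rb;
        size[rb] += size[ra];
    }
    // Hypothetical helper: number of links from a up to its root.
    int depth(int a) {
        int d = 0;
        while (parent[a] != a) {
            a = parent[a];
            d++;
        }
        return d;
    }
};
```

On the chain $\mathrm{union}(0, 1), \mathrm{union}(1, 2), \ldots$, every node ends up at depth at most $1$. Merging pairs, then pairs of pairs, and so on for $n = 1024$ produces depth exactly $10 = \log_2 1024$, the most union by weight allows.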

Improvement 2

Next, notice that for every node, all we really need to know is the root of its tree; we don't need to remember anything else. Hence, we can reorganize a tree however we want, as long as it still contains the same set of nodes. Of course, we want trees to be as short as possible so that find is as fast as possible. One thing we can do, then, is the following: whenever we call $\mathrm{find}(a)$, we compress the path we traversed to the root. This means that we set the parent of every node on that path to the root, so that later calls to $\mathrm{find}(a)$ (or to $\mathrm{find}(x)$ for any $x$ on this path) are fast, since they only need to traverse one link to reach the root! This technique is called path compression.

Here's an implementation, which only requires a slight change to the recursive version:

```cpp
int find(int a) {
    if (parent[a] == a) {
        return a;
    } else {
        int res = find(parent[a]);
        parent[a] = res; // point this node to the root
        return res;
    }
}
```

Or, in one line:

```cpp
int find(int a) {
    return parent[a] == a ? a : parent[a] = find(parent[a]);
}
```

This simple change works because, as the recursion unwinds, the assignment parent[a] = res is executed for every node on the path, pointing each of them directly at the root.
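Because the recursive `find` recurses once per node on the path, a very deep tree (possible when path compression is used without union by height) may exhaust the call stack. A common alternative is an iterative two-pass version: first walk up to find the root, then walk the same path again and repoint each node at the root. Below is a sketch of that variant; `UnionFindIterative` is a hypothetical name, and its union is the basic unbalanced one, so long chains can actually form.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Path compression without recursion (two passes over the path).
struct UnionFindIterative {
    vector<int> parent;
    UnionFindIterative(int n) : parent(n) {
        for (int i = 0; i < n; i++) parent[i] = i;
    }
    int find(int a) {
        // pass 1: walk up to the root
        int root = a;
        while (parent[root] != root) root = parent[root];
        // pass 2: walk the path again, pointing each node directly at the root
        while (a != root) {
            int next = parent[a];
            parent[a] = root;
            a = next;
        }
        return root;
    }
    void union_(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        parent[ra] = rb; // basic union, no balancing
    }
};
```

After building the chain $\mathrm{union}(0, 1), \ldots, \mathrm{union}(n-2, n-1)$, a single $\mathrm{find}(0)$ should leave every node pointing directly at $n - 1$.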

Using only this improvement (without union by height), one can also show that $q$ operations take $O(n + q \log n)$ time. But we get an amazing improvement if we use both union by height and path compression! Specifically, with both, the running time for $q$ operations becomes $O(n + q\cdot \alpha(n))$, where $\alpha(n)$ is the inverse Ackermann function. We won't prove this bound, but to get a feel for how good it is, all you need to know is that $\alpha(n)$ grows veeeeerrryyy slowly. It grows even more slowly than $\log n$, $\log \log n$, or any fixed number of iterated $\log$s! To give you an idea, one can show that $\alpha(n) \le 5$ for all $n \le 2^{2^{2^{65536}}}$, a number much, much larger than any $n$ we will ever need. Hence, for all practical purposes, we can consider $\alpha(n)$ to be a constant. (Strictly speaking, however, it is not a constant: as slow-growing as it is, it still goes to infinity as $n$ goes to infinity.)

As a technical note, when we use both improvements, path compression can change the height of a tree. However, we will not bother to keep height[x] equal to the true height of the tree (since that would be hard to maintain). Instead, we will just call the value by another name: rank. A root's rank is still an upper bound on the height of its tree, since path compression only makes trees shorter. It can be shown that the running time is still $O(n + q\cdot \alpha(n))$ with this change.

Here's the full implementation, which contains both improvements: path compression and union by rank.

```cpp
struct UnionFind {
    int n;
    int* parent;
    int* rank;
    UnionFind(int n) {
        this->n = n;
        parent = new int[n];
        rank = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            rank[i] = 0;
        }
    }
    int find(int a) {
        // path compression
        return parent[a] == a ? a : parent[a] = find(parent[a]);
    }
    void union_(int a, int b) {
        int ra = find(a);
        int rb = find(b);
        if (ra == rb) return;
        // union by rank: point the lower-ranked root to the higher-ranked one
        if (rank[ra] < rank[rb]) {
            parent[ra] = rb;
        } else if (rank[ra] > rank[rb]) {
            parent[rb] = ra;
        } else {
            parent[ra] = rb;
            rank[rb]++;
        }
    }
};
```
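To see why the value is called rank rather than height, here is a small sketch using the implementation above (renamed `UnionFindByRank` and using `vector`s), plus hypothetical helpers that measure the true height without compressing anything. After $\mathrm{union}(0, 1)$, $\mathrm{union}(2, 3)$, $\mathrm{union}(0, 2)$, the tree rooted at $3$ is the path $0 \to 1 \to 3$ plus the child $2$, so its height equals its rank, $2$. A single $\mathrm{find}(0)$ then repoints $0$ directly at $3$: the height drops to $1$, but `rank[3]` stays $2$.

```cpp
#include <cassert>
#include <vector>
using namespace std;

struct UnionFindByRank {
    vector<int> parent, rank;
    UnionFindByRank(int n) : parent(n), rank(n, 0) {
        for (int i = 0; i < n; i++) parent[i] = i;
    }
    int find(int a) {
        return parent[a] == a ? a : parent[a] = find(parent[a]);
    }
    void union_(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        if (rank[ra] < rank[rb]) {
            parent[ra] = rb;
        } else if (rank[ra] > rank[rb]) {
            parent[rb] = ra;
        } else {
            parent[ra] = rb;
            rank[rb]++;
        }
    }
    // Hypothetical helpers: follow parent links WITHOUT compressing.
    int rootOf(int x) const {
        while (parent[x] != x) x = parent[x];
        return x;
    }
    int depthOf(int x) const {
        int d = 0;
        while (parent[x] != x) {
            x = parent[x];
            d++;
        }
        return d;
    }
    // True height of the tree containing a: the largest depth among its nodes.
    int actualHeight(int a) const {
        int r = rootOf(a), h = 0;
        for (int x = 0; x < (int)parent.size(); x++)
            if (rootOf(x) == r && depthOf(x) > h) h = depthOf(x);
        return h;
    }
};
```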

Extensions

We can extend the problem by asking a few more types of questions. For example, considering the friendship problem again, we could ask the following:

But it turns out that it's straightforward to modify the solutions above to answer these questions. For example, the first question simply asks for the size of the set containing $a$, which can be answered by augmenting the structure with additional information, namely the size of each component. This only requires an additional array, which we can call size, and being careful to update it when merging two trees!
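Here is a sketch of the size augmentation on top of path compression and union by rank; `UnionFindWithSize` and `setSize` are hypothetical names. `size[r]` is kept correct only for roots, which suffices because every query goes through `find`. The union is written with a swap, which is equivalent to the three-way branch in the full implementation.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

struct UnionFindWithSize {
    vector<int> parent, rank, size; // size[r] is valid only when r is a root
    UnionFindWithSize(int n) : parent(n), rank(n, 0), size(n, 1) {
        for (int i = 0; i < n; i++) parent[i] = i;
    }
    int find(int a) {
        return parent[a] == a ? a : parent[a] = find(parent[a]);
    }
    void union_(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        if (rank[ra] > rank[rb]) swap(ra, rb); // now rank[ra] <= rank[rb]
        parent[ra] = rb;
        size[rb] += size[ra]; // rb's tree absorbs all of ra's nodes
        if (rank[ra] == rank[rb]) rank[rb]++;
    }
    // number of elements in the set containing a
    int setSize(int a) { return size[find(a)]; }
};
```

With the friendship example from the beginning (people $1$ to $6$), after recording that $1$ and $4$, $2$ and $6$, and $3$ and $6$ are friends, `setSize(3)` should return $3$ and `setSize(5)` should return $1$.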

Of course, one can augment the structure to have any kind of information you need, for example, “the number of female members” or “the number of zombies”, or whatever. It really depends on what you need.

Problems

Non-Coding Problems