Understanding Disjoint Set (Union-Find) in Python

Introduction

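Disjoint Set, also known as Union-Find, is a data structure that maintains a collection of elements partitioned into non-overlapping sets and efficiently answers dynamic connectivity queries, such as whether two elements currently belong to the same set. It supports two key operations: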

  1. Find(x): Determines the representative (or root) of the set containing element x.
  2. Union(x, y): Merges the sets containing x and y.

Union-Find is widely used in applications such as:

  • Graph algorithms (e.g., Kruskal’s Minimum Spanning Tree)
  • Connected components detection
  • Network connectivity problems
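To make the first application concrete, here is a minimal sketch of Kruskal's algorithm. The function name kruskal_mst, the (weight, u, v) edge format, and the sample graph are assumptions made for illustration. A compact union-find is inlined so the snippet runs on its own; it uses path halving (a variant of path compression) and omits union by rank for brevity.

```python
def kruskal_mst(n, edges):
    """Return (total_weight, chosen_edges) for an undirected graph.

    n     -- number of vertices, labeled 0..n-1 (assumed labeling)
    edges -- iterable of (weight, u, v) tuples (assumed format)
    """
    parent = list(range(n))  # every vertex starts as its own root

    def find(x):
        # Walk up to the root, pointing each visited node at its grandparent.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total, chosen = 0, []
    for weight, u, v in sorted(edges):  # cheapest edges first
        root_u, root_v = find(u), find(v)
        if root_u != root_v:            # edge joins two different components
            parent[root_u] = root_v     # plain union, no rank
            total += weight
            chosen.append((u, v, weight))
    return total, chosen


# Hypothetical 4-vertex graph
edges = [(1, 0, 1), (4, 0, 2), (2, 1, 2), (7, 2, 3), (3, 1, 3)]
total, chosen = kruskal_mst(4, edges)
print(total)   # 6
print(chosen)  # [(0, 1, 1), (1, 2, 2), (1, 3, 3)]
```

An edge is accepted only when its endpoints have different roots, which is exactly the check that prevents cycles in the spanning tree.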

Implementation in Python

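Below is an implementation of Disjoint Set that uses two standard optimizations: Path Compression, which points every node visited during find directly at the root, and Union by Rank, which attaches the lower-rank tree under the higher-rank one to keep trees shallow. With both in place, the amortized cost per operation is O(α(n)), where α is the inverse Ackermann function, which is effectively constant for practical input sizes.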

class DisjointSet:
    def __init__(self, n):
        """Initialize the disjoint set with n elements."""
        self.parent = list(range(n))
        self.rank = [1] * n  # Upper bound on each tree's height; union attaches the lower-rank root under the higher

    def find(self, x):
        """Find the root of x with path compression."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]

    def union(self, x, y):
        """Union two sets by rank."""
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            if self.rank[root_x] > self.rank[root_y]:
                self.parent[root_y] = root_x
            elif self.rank[root_x] < self.rank[root_y]:
                self.parent[root_x] = root_y
            else:
                self.parent[root_y] = root_x
                self.rank[root_x] += 1

# Example usage
ds = DisjointSet(5)
ds.union(0, 1)
ds.union(1, 2)
print(ds.find(2))  # Output: 0 (the root this implementation picks for the set {0, 1, 2})
print(ds.find(3))  # Output: 3 (since 3 is isolated)
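
The recursive find above stays shallow when every merge goes through union by rank. If the parent list is built some other way, a long chain could exceed Python's default recursion limit. As a hedged alternative, the sketch below performs the same path compression with two loops. The standalone function find_iterative and the sample chain list are assumptions for illustration, not part of the class above.

```python
def find_iterative(parent, x):
    """Return the root of x, compressing the path without recursion."""
    # First pass: locate the root.
    root = x
    while parent[root] != root:
        root = parent[root]

    # Second pass: point every node on the path directly at the root.
    while parent[x] != root:
        next_node = parent[x]
        parent[x] = root
        x = next_node
    return root


# Hypothetical worst case: a chain 4 -> 3 -> 2 -> 1 -> 0
chain = [0, 0, 1, 2, 3]
print(find_iterative(chain, 4))  # 0
print(chain)                     # [0, 0, 0, 0, 0]
```

After one call, every node on the visited path points straight at the root, so later lookups on those nodes take a single step.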