Disjoint Set Data Structure in C
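The Disjoint Set, also known as Union-Find, is a data structure that keeps track of a set of elements partitioned into disjoint (non-overlapping) subsets. When path compression and union by rank are combined, union and find run in near-constant amortized time: O(α(n)) per operation, where α is the inverse Ackermann function, which stays below 5 for any practical input size.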
In this implementation, we’ll use two primary functions:
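- find(DisjointSet *set, int x): Determines the root of the set containing x and applies path compression to optimize future queries.
- unionSets(DisjointSet *set, int x, int y): Merges the sets containing x and y. We’ll implement union by rank to keep the trees shallow and improve efficiency. The function is named unionSets because union is a reserved keyword in C.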
Program Structure
We’ll break down the implementation into the following parts:
- Struct definition to hold the parent and rank arrays.
- Initialization function to set up the data structure.
- find() function to locate the root of an element with path compression.
- unionSets() function to merge two sets with union by rank.
- A main function to demonstrate the usage of the Disjoint Set data structure.
Code Implementation
#include <stdio.h>
#include <stdlib.h>
// Define the structure for the Disjoint Set
typedef struct {
    int *parent; // Array to hold the parent of each element
    int *rank;   // Array to hold the rank (an upper bound on the height) of each tree
    int size;    // Number of elements in the set
} DisjointSet;
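// Function to initialize the Disjoint Set; returns NULL if allocation fails
DisjointSet* createSet(int n) {
    DisjointSet *set = (DisjointSet*)malloc(sizeof(DisjointSet));
    if (set == NULL) {
        return NULL;
    }
    set->parent = (int*)malloc(n * sizeof(int));
    set->rank = (int*)malloc(n * sizeof(int));
    if (set->parent == NULL || set->rank == NULL) {
        free(set->parent); // free(NULL) is a no-op, so this is safe
        free(set->rank);
        free(set);
        return NULL;
    }
    set->size = n;
    for (int i = 0; i < n; i++) {
        set->parent[i] = i; // Each element is initially its own parent
        set->rank[i] = 0;   // Rank starts at 0
    }
    return set;
}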
// Function to find the root of an element with path compression
int find(DisjointSet *set, int x) {
    if (set->parent[x] != x) {
        set->parent[x] = find(set, set->parent[x]); // Path compression
    }
    return set->parent[x];
}
// Function to perform union by rank
void unionSets(DisjointSet *set, int x, int y) {
    int rootX = find(set, x);
    int rootY = find(set, y);
    if (rootX != rootY) {
        // Union by rank: attach the lower-rank root under the higher-rank root
        if (set->rank[rootX] > set->rank[rootY]) {
            set->parent[rootY] = rootX;
        } else if (set->rank[rootX] < set->rank[rootY]) {
            set->parent[rootX] = rootY;
        } else {
            // Equal ranks: pick rootX as the new root and bump its rank
            set->parent[rootY] = rootX;
            set->rank[rootX]++;
        }
    }
}
// Main function to demonstrate the usage of Disjoint Set
int main(void) {
    int n = 5; // Number of elements (0 to 4)
    DisjointSet *set = createSet(n);
    if (set == NULL) {
        fprintf(stderr, "Failed to allocate the disjoint set.\n");
        return 1;
    }
    // Perform some unions
    unionSets(set, 0, 2);
    unionSets(set, 4, 2);
    unionSets(set, 3, 1);
    // Find the root of each element
    printf("Find(0): %d\n", find(set, 0));
    printf("Find(1): %d\n", find(set, 1));
    printf("Find(2): %d\n", find(set, 2));
    printf("Find(3): %d\n", find(set, 3));
    printf("Find(4): %d\n", find(set, 4));
    // Check if 0 and 4 are in the same set
    if (find(set, 0) == find(set, 4)) {
        printf("0 and 4 are in the same set.\n");
    } else {
        printf("0 and 4 are in different sets.\n");
    }
    // Free allocated memory
    free(set->parent);
    free(set->rank);
    free(set);
    return 0;
}
Explanation
The program begins by defining a struct DisjointSet that contains two arrays, parent and rank, along with the element count size. The parent array holds the parent of each element (a root is its own parent), and the rank array helps keep the trees shallow during union operations.
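To make the role of the parent array concrete, here is a small sketch that follows parent links by hand. hopsToRoot is a hypothetical helper, not part of the program above, and it assumes a bare int array of length n rather than the DisjointSet struct.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: number of parent links between x and its root.
   It only reads the array, so no path compression happens here. */
int hopsToRoot(const int *parent, int n, int x) {
    assert(parent != NULL && x >= 0 && x < n);
    int hops = 0;
    while (parent[x] != x) { /* a root is its own parent */
        x = parent[x];
        hops++;
    }
    return hops;
}
```

In the chain {0, 0, 1, 2}, element 3 is three links away from its root, while in the flat array {0, 0, 0, 0} it is one link away. Keeping trees close to the second shape is what the rank array is for.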
The createSet() function initializes the disjoint set with each element as its own parent and rank 0. The find() function uses path compression to optimize the find operation by making each node in the path point directly to the root. The unionSets() function merges two sets based on their rank, attaching the root with the lower rank under the root with the higher rank. When both ranks are equal, rootY is attached under rootX and the rank of rootX is incremented.
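The find() in the program is recursive, so a very long parent chain costs one stack frame per link. As an alternative, path compression can be written as two loops. The sketch below is one possible version, not the program's own method: findIterative is a hypothetical name, and it assumes a bare parent array of length n instead of the DisjointSet struct.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: iterative find with full path compression. */
int findIterative(int *parent, int n, int x) {
    assert(parent != NULL && x >= 0 && x < n);

    /* Pass 1: walk up the parent links until we reach the root. */
    int root = x;
    while (parent[root] != root) {
        root = parent[root];
    }

    /* Pass 2: repoint every node on the path directly at the root. */
    while (parent[x] != root) {
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}
```

Calling findIterative on element 4 of the chain {0, 0, 1, 2, 3} returns 0 and leaves the array as {0, 0, 0, 0, 0}, which is the same end state the recursive version would produce.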
The main() function demonstrates the usage of the disjoint set by performing some union operations and then using the find() function to check the root of various elements.
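Tracing the three unions in main() by hand gives the parent array {0, 3, 0, 3, 0}: elements 0, 2 and 4 share root 0, and elements 1 and 3 share root 3. One way to confirm how many sets remain is to count the roots. The sketch below assumes the same struct layout as the program; countSets is a hypothetical helper, and the struct is repeated only so the snippet compiles on its own.

```c
#include <assert.h>
#include <stddef.h>

/* Same layout as the DisjointSet struct in the program above. */
typedef struct {
    int *parent;
    int *rank;
    int size;
} DisjointSet;

/* Hypothetical helper: every set has exactly one root (parent[i] == i),
   so counting roots counts the sets. */
int countSets(const DisjointSet *set) {
    assert(set != NULL && set->parent != NULL);
    int count = 0;
    for (int i = 0; i < set->size; i++) {
        if (set->parent[i] == i) {
            count++;
        }
    }
    return count;
}
```

For the state reached in main(), countSets reports 2. This is also a natural use for the size field, which the program stores but does not otherwise read.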