Saturday, 31 December 2016

Disjoint Set Data Structure

This post illustrates a basic implementation of the disjoint-set data structure. We will use two popular heuristics:
1. Union by rank
2. Path Compression

Refer to this TopCoder tutorial for a detailed explanation of this data structure.

To illustrate with an example, we solve this problem on HackerRank.

The basic implementation using a class is as follows:

#include <bits/stdc++.h>
using namespace std;


// heuristics:
//    1. Union by rank
//    2. Path compression
// Disjoint-set (union-find) forest over a fixed number of elements.
//
// Elements are addressed through the public `nodes` vector: nodes[i] is the
// handle for element i. find() returns the representative of an element's
// set and merge() unites two sets. The representative's `associatedNodes`
// field holds the number of elements in its set.
//
// The nodes are owned by the object and released when it is destroyed.
// Copying is disabled because the handles point into the object's own
// storage; moving is allowed and keeps existing handles valid.
class DisjointSet
{
public:
    // One element of the forest.
    class Node
    {
    public:
        Node*parent;         // parent link; a representative points to itself
        int rank;            // used for union by rank (meaningful for representatives)
        int associatedNodes; // number of nodes in the group (meaningful for representatives)

        // A fresh node is a singleton set: it is its own representative.
        Node()
        {
            rank = 1;
            associatedNodes = 1;
            parent = this;
        }
    };

    // Non-owning handles: nodes[i] points at element i.
    std::vector<Node*> nodes;

    // Creates n singleton sets.
    DisjointSet(int n)
        : nodes(static_cast<std::size_t>(n)),
          storage(static_cast<std::size_t>(n))
    {
        // storage is sized once and never resized, so these addresses stay stable.
        for (std::size_t i = 0; i < storage.size(); ++i)
            nodes[i] = &storage[i];
    }

    // A copy would hold handles into the source object's storage, so forbid it.
    DisjointSet(const DisjointSet&) = delete;
    DisjointSet& operator=(const DisjointSet&) = delete;

    // Moving a vector hands over its buffer, so the handles remain valid.
    DisjointSet(DisjointSet&&) = default;
    DisjointSet& operator=(DisjointSet&&) = default;

    // Returns the representative element of node's set.
    // Every node visited on the way is re-pointed directly at the
    // representative (path compression).
    Node*find(Node*node)
    {
        Node*root = node;
        while (root->parent != root)
            root = root->parent;

        while (node->parent != root)
        {
            Node*next = node->parent;
            node->parent = root;
            node = next;
        }
        return root;
    }

    // Unites the sets containing n1 and n2; does nothing if they already
    // share a representative.
    void merge(Node*n1,Node*n2)
    {
        Node*pn1 = find(n1);
        Node*pn2 = find(n2);
        if (pn1 == pn2)
            return;

        // Union by rank: the representative with the strictly larger rank
        // stays the root; on a tie pn2 becomes the root and its rank grows.
        Node*child = pn1;
        Node*root = pn2;
        if (pn1->rank > pn2->rank)
        {
            child = pn2;
            root = pn1;
        }

        if (child->rank == root->rank)
            root->rank++;
        child->parent = root;
        root->associatedNodes += child->associatedNodes;
    }

private:
    // Owns every node; `nodes` holds pointers into this vector.
    std::vector<Node> storage;
};

int main()
{
    int n;
    cin>>n;
    DisjointSet ds(n);

    int q;
    cin>>q;
    while(q--)
    {
        char ch;
        cin>>ch;
        if(ch=='M')
        {
            int a,b;
            cin>>a>>b;
            ds.merge(ds.nodes[a-1],ds.nodes[b-1]);
        }
        else if(ch=='Q')
        {
            int num;
            cin>>num;
            cout<<ds.find(ds.nodes[num-1])->associatedNodes<<endl;
        }
    }
    return 0;
}
But that's too much of a hassle.
The following code illustrates a clever implementation of the disjoint-set data structure.
#include <bits/stdc++.h>
using namespace std;
// Capacity of the disjoint-set tables below.
constexpr int MAXN = 100010;

// parent[i] is the parent link of element i; a root satisfies parent[r] == r.
// size[r] is the number of elements in the set rooted at r (only meaningful
// for roots).
// Before the first find()/merge() the caller must set parent[i] = i and
// size[i] = 1 for every element it uses.
//
// NOTE: with `using namespace std;` in effect, the unqualified name `size`
// is ambiguous with std::size (C++17). Always refer to this table as `::size`.
std::vector<int> parent(MAXN), size(MAXN);

// Returns the root (representative) of the set containing n.
// Every element visited on the way is re-pointed directly at the root
// (path compression).
int find(int n) {
    int root = n;
    while (parent[root] != root)
        root = parent[root];

    while (parent[n] != root) {
        const int next = parent[n];
        parent[n] = root;
        n = next;
    }
    return root;
}

// Unites the sets containing a and b; does nothing if they are already the
// same set. Union by size: the root of the strictly larger set stays the
// root; on a tie the root of b's set is kept.
void merge(int a,int b) {
    const int pa = find(a);
    const int pb = find(b);
    if (pa == pb)
        return;

    if (::size[pa] > ::size[pb]) {
        parent[pb] = pa;
        ::size[pa] += ::size[pb];
    }
    else {
        parent[pa] = pb;
        ::size[pb] += ::size[pa];
    }
}

// Reads the element count and the number of queries, then serves each query:
//   M a b  - unite the sets containing elements a and b (1-based ids)
//   Q x    - print the size of the set containing element x (1-based id)
// Any other query letter is ignored.
int main()
{
    // Every slot starts as its own singleton set.
    // `::size` names the global table explicitly: with `using namespace std;`
    // the unqualified name is ambiguous with std::size under C++17.
    for (int i = 0; i < MAXN; i++) {
        parent[i] = i;
        ::size[i] = 1;
    }

    // n is part of the input format but is not needed: the tables are
    // pre-sized to MAXN.
    int n, q;
    std::cin >> n >> q;

    while (q--) {
        char ch;
        std::cin >> ch;
        if (ch == 'M') {
            int a, b;
            std::cin >> a >> b;
            // Input ids are 1-based; the tables are 0-based.
            merge(a - 1, b - 1);
        }
        else if (ch == 'Q') {
            int num;
            std::cin >> num;
            // '\n' avoids a flush per query; std::cin is tied to std::cout,
            // so output is still flushed before the next read.
            std::cout << ::size[find(num - 1)] << '\n';
        }
    }
    return 0;
}

No comments:

Post a Comment