Kruskal's algorithm: a greedy algorithm that builds a minimum spanning tree (MST). Begin with just the nodes and no edges, then add edges one at a time in order of increasing weight.

Steps:

  • init union-find data structure
  • init empty list for MST edges
  • sort edges by weight in ascending order
  • for each edge in the sorted edge list
    • if the edge does not form a cycle (its two vertices are in different sets), add it to the MST and union the sets of the two vertices
    • if the MST contains V-1 edges (V = number of vertices), stop

Union-find:

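  • find function: return the representative (root) of the set containing the input node
    • optimize with path compression: re-point every node visited on the way up directly at the root
  • union function: merge the two sets (the implementation below uses union by rank: the root with the lower rank is attached under the root with the higher rank)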
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. size-1``.

    Every element starts out as the sole member (and representative) of
    its own set.  ``find`` compresses paths and ``union`` links by rank.
    """

    def __init__(self, size):
        # parent[i] == i marks i as the representative of its set.
        self.parent = [i for i in range(size)]
        # Every singleton set starts with rank 1.
        self.rank = [1 for _ in range(size)]

    def find(self, node):
        """Return the representative of the set that contains ``node``.

        Every non-root node visited on the way up is re-pointed directly
        at the root (path compression).
        """
        # Pass 1: climb until we reach an element that is its own parent.
        root = node
        while self.parent[root] != root:
            root = self.parent[root]

        # Pass 2: re-point each node on the walked path at the root.
        while node != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above

        return root

    def union(self, node1, node2):
        """Merge the sets containing ``node1`` and ``node2``.

        Returns ``True`` when two distinct sets were merged and ``False``
        when both nodes already shared a representative.
        """
        keep = self.find(node1)
        absorb = self.find(node2)
        if keep == absorb:
            return False

        # Make ``keep`` the root whose rank is at least as large, so the
        # lower-ranked root is the one attached underneath.
        if self.rank[keep] < self.rank[absorb]:
            keep, absorb = absorb, keep

        self.parent[absorb] = keep
        # Only a merge of equal-rank roots raises the surviving rank.
        if self.rank[keep] == self.rank[absorb]:
            self.rank[keep] += 1
        return True
 
def kruskal(n, edges):
    """Build a minimum spanning tree with Kruskal's algorithm.

    Args:
        n: Number of vertices; vertices are the integers ``0 .. n-1``.
        edges: Iterable of ``(u, v, weight)`` triples.  The caller's
            collection is left untouched.

    Returns:
        A ``(mst_cost, mst_edges)`` tuple: the summed weight of the
        accepted edges and the accepted ``(u, v, weight)`` triples in the
        order they were added.  If the graph is disconnected, fewer than
        ``n - 1`` edges are returned (a minimum spanning forest).
    """
    uf = UnionFind(n)
    mst_cost = 0
    mst_edges = []

    # sorted() returns a new list, so the caller's edge collection is not
    # reordered in place and any iterable (tuple, generator, ...) works.
    # It is stable, so equal-weight edges keep their input order.
    for u, v, weight in sorted(edges, key=lambda edge: edge[2]):
        # union() returns False when u and v are already connected, i.e.
        # when adding this edge would close a cycle.
        if uf.union(u, v):
            mst_cost += weight
            mst_edges.append((u, v, weight))
            # A spanning tree on n vertices has exactly n - 1 edges, so
            # every remaining edge would be rejected anyway.
            if len(mst_edges) == n - 1:
                break

    return mst_cost, mst_edges
 
# Demo: a 4-vertex graph given as an edge list, where each
# (u, v, w) triple is an edge between u and v with weight w.
n = 4  # vertex count
edges = [
    (0, 1, 4),
    (0, 2, 1),
    (1, 2, 2),
    (1, 3, 1),
    (2, 3, 5),
]

# Build the MST, then report its total weight and the chosen edges.
mst_cost, mst_edges = kruskal(n, edges)
print("Cost of MST:", mst_cost)
print("Edges in MST:", mst_edges)