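"""Kruskal's algorithm: a greedy algorithm that builds a minimum spanning tree (MST).

Begin with just the nodes (no edges), then add edges one at a time in order of
increasing weight, skipping any edge that would create a cycle.

Steps:
- Initialize a union-find (disjoint-set) data structure with one set per vertex.
- Initialize an empty list for the MST edges.
- Sort the edges by weight in ascending order.
- For each edge in the sorted edge list:
  - If the edge does not form a cycle (its endpoints are in different sets),
    add it to the MST and union the sets of its two vertices.
  - If the MST contains V - 1 edges, stop.

Union-find:
- find: return the representative (root) of the set containing the input node.
  - Optimized with path compression.
- union: merge the two sets by attaching the root of the lower-rank tree under
  the root of the higher-rank tree (union by rank).
"""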
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. size - 1``.

    ``find`` uses path compression and ``union`` uses union by rank.
    """

    def __init__(self, size):
        # Each element begins as the root of its own singleton tree.
        self.parent = list(range(size))
        # Per-root rank used to decide which root absorbs the other; starts at 1.
        self.rank = [1] * size

    def find(self, node):
        """Return the root (representative) of the set containing *node*.

        Every node on the path from *node* to the root is re-pointed directly
        at the root (path compression).
        """
        # First pass: walk parent pointers up to the root.
        root = node
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: re-point every node on the walked path at the root.
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, node1, node2):
        """Merge the sets containing *node1* and *node2*.

        Returns True if they were in different sets (a merge happened) and
        False if they were already in the same set.
        """
        keep = self.find(node1)
        attach = self.find(node2)
        if keep == attach:
            return False
        # Ensure ``keep`` is the root with the higher rank; on a tie the root
        # of node1 stays on top.
        if self.rank[keep] < self.rank[attach]:
            keep, attach = attach, keep
        self.parent[attach] = keep
        # Only a merge of equal-rank trees increases the surviving root's rank.
        if self.rank[keep] == self.rank[attach]:
            self.rank[keep] += 1
        return True
def kruskal(n, edges):
    """Compute a minimum spanning tree (or forest) with Kruskal's algorithm.

    Args:
        n: Number of vertices; vertices are the integers ``0 .. n - 1``.
        edges: Iterable of ``(u, v, weight)`` tuples. It is not modified.

    Returns:
        A ``(mst_cost, mst_edges)`` tuple: the total weight of the selected
        edges, and the list of selected ``(u, v, weight)`` tuples in the order
        they were accepted (ascending weight; ties keep their input order).
        If the graph is disconnected, the result is a minimum spanning forest
        with fewer than ``n - 1`` edges.
    """
    uf = UnionFind(n)
    mst_cost = 0
    mst_edges = []
    # sorted() instead of list.sort(): never reorder the caller's edge list.
    for u, v, weight in sorted(edges, key=lambda edge: edge[2]):
        # union() returns False when u and v are already connected, i.e. when
        # this edge would close a cycle.
        if uf.union(u, v):
            mst_cost += weight
            mst_edges.append((u, v, weight))
            # A spanning tree on n vertices has exactly n - 1 edges; every
            # remaining edge would close a cycle, so stop early.
            if len(mst_edges) == n - 1:
                break
    return mst_cost, mst_edges
# Example graph as an edge list.
# (u, v, w) represents an edge between u and v with weight w.
edges = [
    (0, 1, 4), (0, 2, 1), (1, 2, 2),
    (1, 3, 1), (2, 3, 5),
]
n = 4  # Number of vertices

if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    mst_cost, mst_edges = kruskal(n, edges)
    print("Cost of MST:", mst_cost)
    print("Edges in MST:", mst_edges)