@dataclass()
class Edge:
    """Weighted edge joining vertex ``u`` and vertex ``v``.

    Vertices are identified by integer labels; ``w`` is the edge's weight.
    Instances compare equal field-by-field and are mutable (default
    ``dataclass`` settings).
    """

    u: int  # first endpoint label
    v: int  # second endpoint label
    w: float  # weight (cost) of the edge
def prims_mst(num_vertices: int, edges: List[Edge]) -> Tuple[List[Edge], float, bool]:
    """Compute a minimum spanning tree with Prim's algorithm (lazy binary heap).

    The tree is grown from vertex 0. For a disconnected graph only the
    component containing vertex 0 is spanned; the third return value tells
    the caller whether every vertex was reached.

    Args:
        num_vertices: Number of vertices; vertices are labelled
            ``0 .. num_vertices - 1``.
        edges: Undirected weighted edges. Every ``edge.u`` / ``edge.v`` must be
            a valid vertex label, otherwise ``KeyError`` is raised.

    Returns:
        ``(mst_edges, total_cost, spans_all)`` where ``mst_edges`` holds the
        chosen edges as ``Edge(parent, child, weight)`` in the order they were
        added to the tree, ``total_cost`` is the sum of their weights, and
        ``spans_all`` is True iff every vertex was reached from vertex 0.

    Raises:
        ValueError: if ``num_vertices`` is negative.
    """
    if num_vertices < 0:
        raise ValueError(f"num_vertices must be non-negative, got {num_vertices}")
    if num_vertices == 0:
        # Empty graph: the empty tree trivially spans it.
        return [], 0.0, True

    graph = {u: [] for u in range(num_vertices)}
    for edge in edges:
        graph[edge.u].append((edge.v, edge.w))
        graph[edge.v].append((edge.u, edge.w))

    visited = [False] * num_vertices
    # Heap entries are (weight, vertex, parent). Carrying the parent with each
    # candidate edge lets us record the edge's true endpoint when it is
    # selected. The start vertex has no parent (-1). Using the parent, not the
    # weight, to detect the start vertex keeps zero/negative-weight tree edges.
    min_heap = [(0.0, 0, -1)]
    total_cost = 0.0
    mst_edges = []
    while min_heap:
        weight, u, parent = heapq.heappop(min_heap)
        if visited[u]:
            # Stale entry: u already joined the tree via an earlier,
            # no-more-expensive edge.
            continue
        visited[u] = True
        if parent != -1:
            total_cost += weight
            mst_edges.append(Edge(parent, u, weight))
        for v, edge_weight in graph[u]:
            if not visited[v]:
                heapq.heappush(min_heap, (edge_weight, v, u))
    return mst_edges, total_cost, all(visited)