def merge_sorted_lists(lists: List[Optional[ListNode]]) -> Optional[ListNode]:
    """Merge several ascending singly-linked lists into one ascending list.

    The existing nodes are relinked in place; no new nodes are allocated.
    When values are equal across lists, nodes from the list that appears
    earlier in ``lists`` come first.

    Args:
        lists: Head nodes of the lists to merge. Each list is expected to be
            sorted ascending by ``val``. ``None`` entries (empty lists) are
            skipped.

    Returns:
        Head of the merged list, or ``None`` if every input list is empty.
    """
    # Heap entries are (value, list_index, node). The list index breaks ties
    # between equal values, so nodes themselves are never compared. The index
    # is unique among live entries because each input list has at most one
    # node in the heap at any time.
    heap = [
        (head_node.val, idx, head_node)
        for idx, head_node in enumerate(lists)
        if head_node is not None
    ]
    heapq.heapify(heap)

    head = tail = None
    while heap:
        _, idx, node = heap[0]
        if node.next is not None:
            # Pop the smallest entry and push its successor in one operation.
            heapq.heapreplace(heap, (node.next.val, idx, node.next))
        else:
            heapq.heappop(heap)
        # Append the node to the output; track the head directly instead of
        # allocating a sentinel node.
        if tail is None:
            head = node
        else:
            tail.next = node
        tail = node
    # The last node appended ends its own input list, so its ``next`` is
    # already None and the merged list is properly terminated.
    return head