def binary_tree_max_path_sum(root: Optional["TreeNode"]) -> int:
    """Return the maximum sum of any non-empty path in a binary tree.

    A path is a sequence of nodes joined by parent/child edges in which each
    node appears at most once; it does not have to pass through the root.
    Each node is read through its ``val``, ``left`` and ``right`` attributes.

    The tree is walked in post-order with an explicit stack instead of
    recursion, so very deep (e.g. completely skewed) trees do not hit
    Python's recursion limit.

    Args:
        root: Root node of the tree, or ``None`` for an empty tree.

    Returns:
        The maximum path sum. For an empty tree (``root is None``) no path
        exists and ``float('-inf')`` is returned.
    """
    best = float("-inf")
    # "Gain" of a finished subtree: the best sum of a downward path that
    # starts at that subtree's root (0 for a missing child). Pushed in
    # post-order, so a node's two child gains sit on top when it is finished.
    gains = []
    # (node, children_done): children_done=True means both subtrees have
    # already been processed and their gains are on top of ``gains``.
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            gains.append(0)
            continue
        if not children_done:
            # Re-visit this node after its children. Left is pushed last so
            # it is finished first; its gain therefore lies below right's.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
            continue
        # Negative branches only lower a sum, so they are clamped to 0
        # (i.e. the branch is simply not taken).
        right_gain = max(gains.pop(), 0)
        left_gain = max(gains.pop(), 0)
        # Best path that "bends" at this node, using both branches.
        best = max(best, left_gain + node.val + right_gain)
        # A path continuing up to the parent may use only ONE branch.
        gains.append(max(left_gain, right_gain) + node.val)
    return best