# Time:  O(n) expected (BFS is linear; quick select with a random pivot is
#        linear in the number of levels on average)
# Space: O(w + h), w = max level width (BFS queue), h = number of levels

import random


class TreeNode(object):
    """Binary tree node with a value and optional left/right children."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# bfs, quick select
class Solution(object):
    def kthLargestLevelSum(self, root, k):
        """
        Return the k-th largest level sum of the tree (1-indexed).

        Level sums are collected with a breadth-first traversal, then the
        k-th largest is selected in place with a randomized quick select.

        :type root: Optional[TreeNode]
        :type k: int
        :rtype: int  (-1 if k < 1 or the tree has fewer than k levels)
        """
        def nth_element(nums, n, left=0, compare=lambda a, b: a < b):
            """
            Rearrange nums in place so that nums[n] holds the element that
            would be at index n if nums[left:] were sorted by `compare`.
            """
            def tri_partition(nums, left, right, target, compare):
                """
                Three-way partition of nums[left..right] around target.

                Returns (lo, hi) such that nums[lo..hi] are all equal to
                target, everything before lo satisfies compare(x, target),
                and everything after hi does not.
                """
                mid = left
                while mid <= right:
                    if nums[mid] == target:
                        mid += 1
                    elif compare(nums[mid], target):
                        nums[left], nums[mid] = nums[mid], nums[left]
                        left += 1
                        mid += 1
                    else:
                        # The element swapped in from the right is
                        # unclassified, so mid must not advance here.
                        nums[mid], nums[right] = nums[right], nums[mid]
                        right -= 1
                return left, right

            right = len(nums) - 1
            while left <= right:
                pivot_idx = random.randint(left, right)
                pivot_left, pivot_right = tri_partition(
                    nums, left, right, nums[pivot_idx], compare)
                if pivot_left <= n <= pivot_right:
                    return
                if pivot_left > n:
                    right = pivot_left - 1
                else:  # pivot_right < n.
                    left = pivot_right + 1

        level_sums = []
        # An empty tree has no levels; never put None in the queue.
        level = [root] if root is not None else []
        while level:
            total = 0
            next_level = []
            for node in level:
                total += node.val
                if node.left is not None:
                    next_level.append(node.left)
                if node.right is not None:
                    next_level.append(node.right)
            level_sums.append(total)
            level = next_level

        # Reject ranks outside 1..len(level_sums) before selecting; a
        # non-positive k would otherwise index from the end of the list.
        if k < 1 or k > len(level_sums):
            return -1
        # Descending order: index k-1 is then the k-th largest sum.
        nth_element(level_sums, k - 1, compare=lambda a, b: a > b)
        return level_sums[k - 1]