Finding the max branch sum of a sequentially represented binary tree in Python

Given a binary tree represented sequentially as a list of integer node values, we want to identify whether the sum of the left branch (the subtree rooted at the root’s left child) or the sum of the right branch is larger. In other words, we want to identify which branch has the maximum sum.

For example, this binary tree:

····································
················┌──┐················
00··············│05│················
················└──┘················
··········╱··············╲··········
········┌──┐············┌──┐········
01······│04│············│06│········
········└──┘············└──┘········
······╱······╲······················
····┌──┐····┌──┐····················
02··│16│····│12│····················
····└──┘····└──┘····················
····································

Could be represented sequentially as [5, 4, 6, 16, 12]. Missing nodes are represented as -1.

(The numbers on the left of the diagram are the level, or depth, of each row, with the root at level 0.)

If we label each node with its index in the sequential representation, it looks like this:

····································
················┌──┐················
00··············│00│················
················└──┘················
··········╱··············╲··········
········┌──┐············┌──┐········
01······│01│············│02│········
········└──┘············└──┘········
······╱······╲········╱······╲······
····┌──┐····┌──┐····┌──┐····┌──┐····
02··│03│····│04│····│05│····│06│····
····└──┘····└──┘····└──┘····└──┘····
····································
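Grouping the list by level makes the layout above easy to check. This is a minimal sketch; the helper name levels is hypothetical and not part of the original solution. It relies on level $l$ starting at index $2^l - 1$ and holding at most $2^l$ entries, with missing nodes left in place as -1.

```python
def levels(nodes):
  # Hypothetical helper: split a sequential tree into its levels.
  # Level l starts at index 2**l - 1 and holds at most 2**l entries;
  # missing nodes (-1) stay in place so positions match the diagram.
  result = []
  start, width = 0, 1
  while start < len(nodes):
    result.append(nodes[start:start + width])
    start += width
    width *= 2
  return result


print(levels([5, 4, 6, 16, 12]))  # [[5], [4, 6], [16, 12]]
print(levels(list(range(7))))     # [[0], [1, 2], [3, 4, 5, 6]]
```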

It’s possible to determine whether the sum of the left or right branch is larger in $O(n)$ time, where $n$ is the length of the list.

The main problem to solve is summing a branch of a sequentially represented binary tree, after which we just need to compare the two sums. The sum can be computed either recursively or iteratively; we’ll look at both below.

O(n) recursive solution

The recursive solution to this problem is the more readable and elegant of the two.

The key insight is to be able to jump to the left or right child node index given a parent node index. Then we can recursively call the sum function from a given parent index to sum up a whole branch.

Looking at the indexed diagram above, you can jump to the left child of a parent node index with $2p + 1$, and to the right child index with $2p + 2$, where $p$ is the index of the parent node.

For example, if we’re on parent index $01$ in the diagram, then the child indexes are at:

  • Left: $2p + 1 = (2 \times 1) + 1 = 2 + 1 = 3$
  • Right: $2p + 2 = (2 \times 1) + 2 = 2 + 2 = 4$

Similarly, if we’re on parent index $02$, then the child indexes are at:

  • Left: $2p + 1 = (2 \times 2) + 1 = 4 + 1 = 5$
  • Right: $2p + 2 = (2 \times 2) + 2 = 4 + 2 = 6$

So these two formulas let us jump from a parent index to its left and right child indexes.
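The two formulas can be written as tiny helpers and checked against the index-labelled diagram. The names left_child and right_child are hypothetical, used only for illustration.

```python
def left_child(p):
  # Index of the left child of the node at index p
  return 2 * p + 1


def right_child(p):
  # Index of the right child of the node at index p
  return 2 * p + 2


# Parents 0-2 in the 7-node index-labelled diagram
for p in range(3):
  print(p, "->", left_child(p), right_child(p))
# 0 -> 1 2
# 1 -> 3 4
# 2 -> 5 6
```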

Using that, we can write a recursive function to sum a branch of a sequentially represented binary tree:

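def sum_branch_recur(nodes, parent_idx):
  # Past the end of the list: there is no node here
  if parent_idx >= len(nodes):
    return 0
  # Missing node (-1): stop here, its descendants are missing too
  if nodes[parent_idx] == -1:
    return 0
  # This node's value plus the sums of its left (2p + 1) and right (2p + 2) subtrees
  return (
      nodes[parent_idx]
      + sum_branch_recur(nodes, (parent_idx * 2) + 1)
      + sum_branch_recur(nodes, (parent_idx * 2) + 2)
  )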

For example, calling sum_branch_recur([5, 7, 6, 16, 12], 1) will sum 7, 16 and 12 for a branch sum of 35.
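As a quick check of how missing nodes are handled, here is the same function run on the tree from the first diagram and on a hypothetical variant where node 16 is missing (stored as -1). The function is repeated, with its two base cases merged into one condition, so the snippet runs on its own.

```python
def sum_branch_recur(nodes, parent_idx):
  # Out of range or missing (-1): contributes nothing
  if parent_idx >= len(nodes) or nodes[parent_idx] == -1:
    return 0
  return (
      nodes[parent_idx]
      + sum_branch_recur(nodes, (parent_idx * 2) + 1)
      + sum_branch_recur(nodes, (parent_idx * 2) + 2)
  )


print(sum_branch_recur([5, 4, 6, 16, 12], 1))  # 32: 4 + 16 + 12
print(sum_branch_recur([5, 4, 6, -1, 12], 1))  # 16: 4 + 12, the -1 is skipped
print(sum_branch_recur([5, 4, 6, 16, 12], 2))  # 6: node 6 has no children
```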

We can then use that to determine the larger branch of a given binary tree:

def seq_max_branch(nodes):
  # Index 1 is the root's left child, index 2 its right child
  left_sum = sum_branch_recur(nodes, 1)
  right_sum = sum_branch_recur(nodes, 2)
  # Equal sums: neither branch is larger
  if left_sum == right_sum:
    return ""
  return "Left" if left_sum > right_sum else "Right"
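A few example calls, including a tie. Apart from the tree in the first diagram, the sample lists are made up for illustration, and both functions are repeated so the snippet is self-contained.

```python
def sum_branch_recur(nodes, parent_idx):
  if parent_idx >= len(nodes) or nodes[parent_idx] == -1:
    return 0
  return (
      nodes[parent_idx]
      + sum_branch_recur(nodes, (parent_idx * 2) + 1)
      + sum_branch_recur(nodes, (parent_idx * 2) + 2)
  )


def seq_max_branch(nodes):
  left_sum = sum_branch_recur(nodes, 1)
  right_sum = sum_branch_recur(nodes, 2)
  if left_sum == right_sum:
    return ""
  return "Left" if left_sum > right_sum else "Right"


print(seq_max_branch([5, 4, 6, 16, 12]))  # Left: 4 + 16 + 12 = 32 vs 6
print(seq_max_branch([5, 1, 9, 2, 3]))    # Right: 1 + 2 + 3 = 6 vs 9
print(repr(seq_max_branch([5, 3, 3])))    # '': both branches sum to 3
```

A root with no children also yields an empty string, since both branch sums are 0.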

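This has $O(n)$ time complexity, as it visits each node in the list at most once and does a constant amount of work per visit apart from its two recursive calls. The recursion only goes as deep as the tree has levels, roughly $\log_2 n$, so Python’s default recursion limit is not a concern.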
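To see the "each node once" claim in numbers, a hypothetical instrumented copy of the function can also count how many times the recursion is entered. Each present node accounts for its own call plus two child calls, so a branch with $k$ present nodes takes $2k + 1$ calls, which is linear in the size of the list.

```python
def count_calls(nodes, parent_idx):
  # Hypothetical instrumented copy of sum_branch_recur.
  # Returns (branch_sum, number_of_calls_made).
  if parent_idx >= len(nodes) or nodes[parent_idx] == -1:
    return 0, 1
  left_sum, left_calls = count_calls(nodes, 2 * parent_idx + 1)
  right_sum, right_calls = count_calls(nodes, 2 * parent_idx + 2)
  return nodes[parent_idx] + left_sum + right_sum, 1 + left_calls + right_calls


# 3 present nodes (4, 16, 12) in the left branch -> 2 * 3 + 1 = 7 calls
print(count_calls([5, 4, 6, 16, 12], 1))  # (32, 7)
```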

O(n) iterative solution

The recursive solution is simpler and more readable, but it is also possible to solve this problem iteratively in $O(n)$ time.

The iterative solution looks like this:

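from math import ceil, log

def seq_max_branch_iter(nodes):
  n = len(nodes)
  # Number of levels needed to hold n nodes
  height = ceil(log(n + 1, 2))
  left_sum = 0
  right_sum = 0
  # Start at level 1: the root belongs to neither branch
  for level in range(1, height):
    width = 2 ** level          # maximum number of nodes on this level
    mid = -(-width // 2)        # ceiling division: size of the left half
    # The first half of the level belongs to the left branch...
    for left in range(0, mid):
      li = (width - 1) + left   # width - 1 nodes come before this level
      # Skip indices past the end of the list and missing (-1) nodes
      if li < n and nodes[li] != -1:
        left_sum += nodes[li]
    # ...and the second half to the right branch
    for right in range(mid, width):
      ri = (width - 1) + right
      if ri < n and nodes[ri] != -1:
        right_sum += nodes[ri]
  return left_sum, right_sum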

At first glance this might look like it’s $O(n^2)$ due to the nested loops, but the outer loop iterates over the levels of the tree, and each inner loop covers only half of the positions on that level. Each index position is examined exactly once, and there are fewer than $2n$ positions across the levels visited, so the time complexity is still $O(n)$.
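Counting the positions the inner loops examine gives the bound explicitly. With $h = \lceil \log_2(n+1) \rceil$ levels, the outer loop visits levels $1$ to $h - 1$, and on level $l$ the two inner loops together cover all $2^l$ positions:

$$\sum_{l=1}^{h-1} 2^l = 2^h - 2$$

Since $h$ is the smallest integer with $2^h \ge n + 1$, we have $2^{h-1} \le n$, so $2^h - 2 \le 2n - 2$ for $n \ge 1$. That count includes the positions on the last level that fall past the end of the list, and it is still linear in $n$.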

Using this index-labelled example again:

····································
················┌──┐················
00··············│00│················
················└──┘················
··········╱··············╲··········
········┌──┐············┌──┐········
01······│01│············│02│········
········└──┘············└──┘········
······╱······╲········╱······╲······
····┌──┐····┌──┐····┌──┐····┌──┐····
02··│03│····│04│····│05│····│06│····
····└──┘····└──┘····└──┘····└──┘····
····································

The first insight is that we can calculate the height of the tree (its number of levels) from the number of nodes $n$ with $\lceil \log_2(n+1) \rceil$, or ceil(log(n + 1, 2)) in Python.
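Since math.log works in floating point, an integer-only way to get the same height may be handy. This is a minimal sketch using Python’s built-in int.bit_length(); the helper name tree_height is hypothetical.

```python
def tree_height(n):
  # Hypothetical helper: number of levels needed to store n nodes sequentially.
  # For n >= 1, int.bit_length() is floor(log2(n)) + 1, which equals
  # ceil(log2(n + 1)); for n == 0 it is 0. No floating point is involved.
  return n.bit_length()


# Cross-check against the definition: the smallest h with 2**h - 1 >= n
for n in range(1000):
  h = 0
  while 2 ** h - 1 < n:
    h += 1
  assert tree_height(n) == h

print(tree_height(5), tree_height(7), tree_height(8))  # 3 3 4
```

In seq_max_branch_iter, height = n.bit_length() could stand in for height = ceil(log(n + 1, 2)) and would remove the floating-point step.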

We can then iterate over the levels below the root, adding the first half of each level to the left sum and the second half to the right sum.

Each level $l$ (counting the root as level 0) contains up to $2^l$ nodes; for example, the level at index 2 has up to $2^2 = 4$ nodes. This maximum is the level’s width.

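We can then calculate the mid-point of the level with -(-width // 2), which uses a little trick to do “ceil division” in Python; alternatively, this could be math.ceil(width / 2). Because every level below the root has an even width, plain width // 2 would give the same result.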

The last piece is that each level has $w - 1$ nodes before it in the sequential list, where $w = 2^l$ is the width of the level, because the levels above it hold $1 + 2 + \dots + 2^{l-1} = 2^l - 1$ nodes. For example, the level at index 2 has a width of 4 and has 3 nodes before it.
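Putting the width, the mid-point and the $w - 1$ offset together, a hypothetical helper level_halves can list which indices each branch owns on a given level. This is a minimal sketch intended for levels $l \ge 1$.

```python
def level_halves(level):
  # Hypothetical helper: (left_indices, right_indices) for one level, level >= 1.
  width = 2 ** level          # level l holds up to 2**l nodes
  start = width - 1           # nodes on all levels above this one
  mid = -(-width // 2)        # ceiling division; equals width // 2 for level >= 1
  left = list(range(start, start + mid))
  right = list(range(start + mid, start + width))
  return left, right


for level in range(1, 4):
  print(level, level_halves(level))
# 1 ([1], [2])
# 2 ([3, 4], [5, 6])
# 3 ([7, 8, 9, 10], [11, 12, 13, 14])
```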

Then we just iterate over the level’s nodes up to the mid-point, adding to the left sum, and from the mid-point onwards, adding to the right sum. Each time, we need to check that a node actually exists at that index in the sequential list and that it isn’t a missing node marked -1.

With this iterative function, we can then compare the left and right branch sums as with the recursive solution.
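A minimal sketch of that comparison, plus a randomized cross-check against the recursive version. The names larger_branch and random_tree are hypothetical. The check assumes that every descendant of a missing node is also -1 or past the end of the list: the iterative version sums every non-(-1) entry on each level, while the recursive one stops at a missing node, so the two only agree under that assumption.

```python
import random
from math import ceil, log


def sum_branch_recur(nodes, parent_idx):
  if parent_idx >= len(nodes) or nodes[parent_idx] == -1:
    return 0
  return (
      nodes[parent_idx]
      + sum_branch_recur(nodes, (parent_idx * 2) + 1)
      + sum_branch_recur(nodes, (parent_idx * 2) + 2)
  )


def seq_max_branch_iter(nodes):
  # Iterative level-by-level sums, skipping out-of-range and -1 entries
  n = len(nodes)
  height = ceil(log(n + 1, 2))
  left_sum = 0
  right_sum = 0
  for level in range(1, height):
    width = 2 ** level
    mid = -(-width // 2)
    for left in range(0, mid):
      li = (width - 1) + left
      if li < n and nodes[li] != -1:
        left_sum += nodes[li]
    for right in range(mid, width):
      ri = (width - 1) + right
      if ri < n and nodes[ri] != -1:
        right_sum += nodes[ri]
  return left_sum, right_sum


def larger_branch(nodes):
  # Hypothetical wrapper: same return values as seq_max_branch
  left_sum, right_sum = seq_max_branch_iter(nodes)
  if left_sum == right_sum:
    return ""
  return "Left" if left_sum > right_sum else "Right"


def random_tree(rng, n):
  # Random valid tree: if the parent at (i - 1) // 2 is missing, so is node i
  nodes = []
  for i in range(n):
    if i > 0 and nodes[(i - 1) // 2] == -1:
      nodes.append(-1)
    else:
      nodes.append(rng.choice([-1, rng.randint(0, 20)]))
  return nodes


# Cross-check the iterative sums against the recursive ones
rng = random.Random(0)
for _ in range(1000):
  nodes = random_tree(rng, rng.randint(0, 40))
  expected = (sum_branch_recur(nodes, 1), sum_branch_recur(nodes, 2))
  assert seq_max_branch_iter(nodes) == expected

print(larger_branch([5, 4, 6, 16, 12]))  # Left
```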

