Merging an arbitrary number of Binary Trees

Binary Tree



November 8, 2022

Using functional Python tools to merge several Binary Trees together.


There is a classic programming interview question that asks us to merge two Binary Trees.

Below is one possible setup, borrowed from the official LeetCode problem description Merge Two Binary Trees:

You are given two binary trees root1 and root2.

Imagine that when you put one of them to cover the other, some nodes of the two trees are overlapped while the others are not. You need to merge the two trees into a new binary tree. The merge rule is that if two nodes overlap, then sum node values up as the new value of the merged node. Otherwise, the NOT null node will be used as the node of the new tree.

Return the merged tree.

Note: The merging process must start from the root nodes of both trees.
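To make the setup concrete, here is a minimal sketch of the node class the solutions in this post assume. LeetCode supplies its own TreeNode; the version below mirrors it. The to_tuple helper is a hypothetical convenience for inspecting trees and is not part of the problem. The two trees are built from the example in the problem description.

```python
from typing import Optional


class TreeNode:
    """Minimal binary tree node, mirroring the class LeetCode supplies."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def to_tuple(node):
    """Hypothetical helper: render a tree as nested (val, left, right) tuples."""
    if node is None:
        return None
    return (node.val, to_tuple(node.left), to_tuple(node.right))


# Tree 1:      1          Tree 2:    2
#             / \                   / \
#            3   2                 1   3
#           /                       \   \
#          5                         4   7
root1 = TreeNode(1, TreeNode(3, TreeNode(5)), TreeNode(2))
root2 = TreeNode(2, TreeNode(1, None, TreeNode(4)), TreeNode(3, None, TreeNode(7)))
```

Overlaying these two trees, the roots overlap (1 + 2), while the node holding 7 exists only in the second tree and would be carried over unchanged.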

Approaching the problem.

Like many Binary Tree problems, this one is a natural fit for a recursive solution where we consider the following scenarios:

  1. The base case(s): when to return and start working up the recursive stack.
  2. If we are not in a base case, what specific actions must we take?
  3. Then, call the function on the remaining sub-problems, usually the children of the current node.
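The three scenarios above can be sketched on a simpler problem first: summing every value in a single tree. The tree_sum function is a hypothetical illustration, and the TreeNode class is a stand-in for the one LeetCode provides.

```python
class TreeNode:
    """Stand-in for the node class LeetCode provides."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node):
    # 1. Base case: an empty subtree contributes nothing.
    if node is None:
        return 0
    # 2. Action for the current node: take its own value.
    total = node.val
    # 3. Recurse on the remaining sub-problems (the two children).
    return total + tree_sum(node.left) + tree_sum(node.right)


small_tree = TreeNode(1, TreeNode(2), TreeNode(3))
```

Calling tree_sum(small_tree) visits each node once and adds up 1, 2, and 3. The merge follows the same shape, except that it walks two trees in lockstep.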

The intuition to merge two Binary Trees.

The general intuition to solve this problem is:

  1. Overlay the two trees together, starting from their root nodes.
  2. Then, merge the values of the root nodes.
  3. Finally, merge both the left and right subtrees in the same way.

What will these steps look like in code?

Merging only two Binary Trees

We can translate the publicly available Java implementation to arrive at the following Python solution:

from typing import Optional

# TreeNode is the node class supplied by LeetCode (fields: val, left, right).
class Solution:
    def mergeTrees(self, t1: Optional[TreeNode], t2: Optional[TreeNode]) -> Optional[TreeNode]:
        # Base cases:
        ## 1) The first tree is null, return the second tree
        ## 2) The second tree is null, return the first tree
        if t1 is None:
            return t2
        if t2 is None:
            return t1
        # If we make it here, then there are two valid nodes we have to merge
        # Merge the nodes (add the value from the second into the first)
        t1.val += t2.val
        # Now merge the left and right subtrees. NOTE: these are recursive calls
        t1.left = self.mergeTrees(t1.left, t2.left)
        t1.right = self.mergeTrees(t1.right, t2.right)
        # At the end of the recursive stack, t1 will be the root of the valid, merged tree.
        return t1

If a matching, overlapping node exists in both trees, then we add their values together.

If a node exists in one tree but not the other, then we reuse the existing node, along with its entire subtree, as is.

Once all nodes have been visited, then the trees are fully merged and we are done.
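As a quick check, the sketch below runs the same logic on the example trees from the problem description. It is written as a standalone function, merge_two (a hypothetical name), so that it runs outside LeetCode. TreeNode and to_tuple are stand-ins defined here for illustration.

```python
class TreeNode:
    """Stand-in for the node class LeetCode provides."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def merge_two(t1, t2):
    """Same logic as mergeTrees above, as a plain function."""
    if t1 is None:
        return t2
    if t2 is None:
        return t1
    t1.val += t2.val
    t1.left = merge_two(t1.left, t2.left)
    t1.right = merge_two(t1.right, t2.right)
    return t1


def to_tuple(node):
    """Hypothetical helper: render a tree as nested (val, left, right) tuples."""
    if node is None:
        return None
    return (node.val, to_tuple(node.left), to_tuple(node.right))


root1 = TreeNode(1, TreeNode(3, TreeNode(5)), TreeNode(2))
root2 = TreeNode(2, TreeNode(1, None, TreeNode(4)), TreeNode(3, None, TreeNode(7)))

merged = merge_two(root1, root2)
# In level order the result reads [3, 4, 5, 5, 4, null, 7].
```

Note that merged is the very same object as root1: this version updates the first tree in place and borrows nodes from the second tree wherever the first has none.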

Merging an arbitrary number of Binary Trees

It turns out that we can leverage some functional tools from Python to make the solution above even more general.

Specifically, we will use Python’s functional map and lambda, together with getattr and sequence expansion via *, to merge an arbitrary number of Binary Trees.
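Before the full solution, here is a small sketch of how these pieces behave in isolation. The key point is that getattr with a default lets a missing node (None) stand in for a value of 0 or an empty child. The count_args function and the nodes tuple are hypothetical, and TreeNode is a stand-in for the class LeetCode provides.

```python
class TreeNode:
    """Stand-in for the node class LeetCode provides."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# One slot per tree; the middle tree has no node at this position.
nodes = (TreeNode(1), None, TreeNode(5, TreeNode(9)))

# None has no 'val' attribute, so getattr falls back to the default 0.
vals = list(map(lambda n: getattr(n, 'val', 0), nodes))

# Likewise, a missing node yields a missing (None) left child.
lefts = list(map(lambda n: getattr(n, 'left', None), nodes))


def count_args(*args):
    """Hypothetical helper: report how many positional arguments arrived."""
    return len(args)


# The * operator expands the map into separate positional arguments.
arg_count = count_args(*map(lambda n: getattr(n, 'left', None), nodes))
```

Here vals holds one number per tree, and arg_count equals the number of trees, which is what allows the recursive call to receive each child as its own argument.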

from typing import Optional

# TreeNode is the node class supplied by LeetCode (fields: val, left, right).
class Solution:
    def mergeTrees(self, *args: Optional[TreeNode]) -> Optional[TreeNode]:
        # Base case: all trees are empty, we have nothing to merge
        if all(n is None for n in args):
            return None
        # Get the values of every matched overlapping node, and sum them together.
        # Missing nodes (None) contribute a value of 0.
        vals = map(lambda n: getattr(n, 'val', 0), args)
        node = TreeNode(sum(vals))
        # Create the left child from the merged left-subtrees
        node.left = self.mergeTrees(*map(lambda n: getattr(n, 'left', None), args))
        # Create the right child from the merged right-subtrees
        node.right = self.mergeTrees(*map(lambda n: getattr(n, 'right', None), args))

        # Return the new, merged tree
        return node

This solution is more general at the cost of more memory: we create a new TreeNode for every position in the merged tree instead of adding to an existing node’s value.

In return, it follows the problem’s requirement that we return a “new binary tree” more literally, because the input trees are left unmodified. In our more general solution, the returned node at the top of the recursive stack will be the root of a new binary tree.
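A short usage sketch with three trees is shown below. It restates the general merge as a standalone function, merge_trees (a hypothetical name), so that it runs outside LeetCode. TreeNode and to_tuple are stand-ins defined here for illustration, and the third tree is made up for the example.

```python
class TreeNode:
    """Stand-in for the node class LeetCode provides."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def merge_trees(*trees):
    """Same logic as the general mergeTrees above, as a plain function."""
    if all(t is None for t in trees):
        return None
    node = TreeNode(sum(map(lambda n: getattr(n, 'val', 0), trees)))
    node.left = merge_trees(*map(lambda n: getattr(n, 'left', None), trees))
    node.right = merge_trees(*map(lambda n: getattr(n, 'right', None), trees))
    return node


def to_tuple(node):
    """Hypothetical helper: render a tree as nested (val, left, right) tuples."""
    if node is None:
        return None
    return (node.val, to_tuple(node.left), to_tuple(node.right))


a = TreeNode(1, TreeNode(3, TreeNode(5)), TreeNode(2))
b = TreeNode(2, TreeNode(1, None, TreeNode(4)), TreeNode(3, None, TreeNode(7)))
c = TreeNode(10, None, TreeNode(20))

merged = merge_trees(a, b, c)
```

The merged root holds 1 + 2 + 10, and its right child holds 2 + 3 + 20. None of a, b, or c is modified, and merged shares no nodes with them.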


The Binary Tree image for this post is from the good folks at Codiwan.