Using functional Python tools to merge several Binary Trees together.
There is a classic programming interview question that asks us to merge two Binary Trees.
Below is one possible setup, borrowed from the official LeetCode problem description Merge Two Binary Trees:
You are given two binary trees root1 and root2. Imagine that when you put one of them to cover the other, some nodes of the two trees are overlapped while the others are not. You need to merge the two trees into a new binary tree. The merge rule is that if two nodes overlap, then sum node values up as the new value of the merged node. Otherwise, the NOT null node will be used as the node of the new tree.
Return the merged tree.
Note: The merging process must start from the root nodes of both trees.
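Like many binary tree problems, this one is a natural fit for a recursive solution.
The general intuition is to walk both trees together starting from their roots, and at each position consider two scenarios: either both trees have a node there, or only one of them does.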
What will these steps look like in code?
We can translate the publicly available Java implementation to arrive at the following Python solution:
#|output: true
#|echo: false
# first we import the typing helpers and define the TreeNode
from typing import Optional, List

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def mergeTrees(self, t1: Optional[TreeNode], t2: Optional[TreeNode]) -> Optional[TreeNode]:
        # Base cases:
        ## 1) The first tree is null, return the second tree
        ## 2) The second tree is null, return the first tree
        if t1 is None:
            return t2
        if t2 is None:
            return t1

        # If we make it here, then there are two valid nodes we have to merge
        # Merge the nodes (add the value from the second into the first)
        t1.val += t2.val

        # Now merge the left and right subtrees. NOTE: these are recursive calls
        t1.left = self.mergeTrees(t1.left, t2.left)
        t1.right = self.mergeTrees(t1.right, t2.right)

        # At the end of the recursive stack, t1 will be the root of the valid, merged tree.
        return t1
If a matching, overlapping node exists in both trees, then we add their values together.
If a node exists in one tree but not the other, then we take the existing node, along with its subtree, as-is.
Once all nodes have been visited, then the trees are fully merged and we are done.
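To see the two-tree merge in action, here is a minimal, self-contained sketch that runs the same recursion on the example trees from the LeetCode problem. The names merge_two and level_order are hypothetical helpers introduced only for this illustration; they are not part of the original solution.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def merge_two(t1, t2):
    """Same recursion as mergeTrees above, written as a plain function."""
    if t1 is None:
        return t2
    if t2 is None:
        return t1
    t1.val += t2.val
    t1.left = merge_two(t1.left, t2.left)
    t1.right = merge_two(t1.right, t2.right)
    return t1


def level_order(root):
    """Hypothetical helper: breadth-first values, with None marking gaps."""
    out, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    # Trim trailing gaps so the list matches LeetCode's notation
    while out and out[-1] is None:
        out.pop()
    return out


# root1 = [1,3,2,5] and root2 = [2,1,3,null,4,null,7] in LeetCode notation
root1 = TreeNode(1, TreeNode(3, TreeNode(5)), TreeNode(2))
root2 = TreeNode(2, TreeNode(1, None, TreeNode(4)), TreeNode(3, None, TreeNode(7)))

merged = merge_two(root1, root2)
print(level_order(merged))  # [3, 4, 5, 5, 4, None, 7]
print(merged is root1)      # True: the first tree was modified in place
```

Note that the result is the first tree itself, with its values and children updated in place.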
It turns out that we can leverage some functional tools from Python to make the solution above even more general.
Specifically, we will use Python's functional map and lambda, together with getattr and sequence expansion via *, to merge an arbitrary number of Binary Trees.
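As a small illustration of these tools in isolation, the sketch below shows how getattr with a default value lets a missing (None) node stand in as a zero value or an empty child, and how * expands a map object into separate positional arguments. The nodes list and the count_args function are made up for this example.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# A mix of real nodes and a missing node, as we would see mid-recursion
nodes = [TreeNode(1, left=TreeNode(10)), None, TreeNode(5)]

# getattr falls back to the default when the attribute is missing,
# which is the case for None: it has no 'val' or 'left' attribute.
vals = list(map(lambda n: getattr(n, 'val', 0), nodes))
lefts = list(map(lambda n: getattr(n, 'left', None), nodes))
print(vals)  # [1, 0, 5]
print([getattr(n, 'val', None) for n in lefts])  # [10, None, None]


def count_args(*args):
    # *args collects every positional argument into a tuple
    return len(args)


# The * operator expands the map object into separate positional arguments
print(count_args(*map(lambda n: getattr(n, 'left', None), nodes)))  # 3
```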
class Solution:
    def mergeTrees(self, *args: Optional[TreeNode]) -> Optional[TreeNode]:
        # Base case: all trees are empty, we have nothing to merge
        if not any(args):
            return None

        # Get the values of every matched overlapping node, and sum them together.
        # Missing (None) nodes contribute 0 through the getattr default.
        vals = map(lambda n: getattr(n, 'val', 0), args)
        node = TreeNode(sum(vals))

        # Create the left child from the merged left-subtrees
        node.left = self.mergeTrees(*map(lambda n: getattr(n, 'left', None), args))

        # Create the right child from the merged right-subtrees
        node.right = self.mergeTrees(*map(lambda n: getattr(n, 'right', None), args))

        # Return the new, merged tree
        return node
This solution is more general at the cost of more memory: we create a new TreeNode instead of adding to an existing node's value.
However, this still follows the problem's constraints that we return a "new binary tree". In our more general solution, the returned node at the top of the recursive stack will be the root of a new binary tree.
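As a rough check of both claims, the sketch below merges three small trees with the same map, lambda and getattr approach, then confirms that the inputs are left untouched. The function name merge_many and the sample trees are assumptions made for this illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def merge_many(*trees):
    """Plain-function version of the n-way mergeTrees shown above."""
    if not any(trees):
        return None
    node = TreeNode(sum(map(lambda n: getattr(n, 'val', 0), trees)))
    node.left = merge_many(*map(lambda n: getattr(n, 'left', None), trees))
    node.right = merge_many(*map(lambda n: getattr(n, 'right', None), trees))
    return node


a = TreeNode(1, TreeNode(2))                     # left child only
b = TreeNode(10, None, TreeNode(30))             # right child only
c = TreeNode(100, TreeNode(200), TreeNode(300))  # both children

merged = merge_many(a, b, c)
print(merged.val, merged.left.val, merged.right.val)  # 111 202 330

# The inputs are unchanged, and the result shares no nodes with them
print(a.val, b.val, c.val)          # 1 10 100
print(merged is a)                  # False
print(merged.right is b.right)      # False
```

Every node in the result is freshly allocated, so the extra memory grows with the size of the merged tree, while the original trees remain safe to reuse.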
The Binary Tree image for this post is from the good folks at Codiwan.