Write a function count_nodes(root) that returns the total number of nodes in a binary tree.
```python
#       1
#      / \
#     2   3
#    / \   \
#   4   5   6
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.left = TreeNode(4)
root.left.right = TreeNode(5)
root.right.right = TreeNode(6)
print(count_nodes(root))  # 6

# Single node:
#   1
print(count_nodes(TreeNode(1)))  # 1

# Empty tree:
print(count_nodes(None))  # 0
```
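| Resource | Complexity |
|---|---|
| Time | O(n): every node is visited exactly once |
| Space | O(h): recursion stack depth, where h is the tree height (h = n for a fully skewed tree, O(log n) for a balanced one) |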
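The O(h) space bound matters in Python: each recursive call takes a frame on the call stack, and CPython's default recursion limit (see `sys.getrecursionlimit()`, typically 1000) means a very deep, skewed tree can make the recursive version raise `RecursionError`. A minimal sketch of one workaround replaces the call stack with an explicit list used as a stack. It assumes a simple `TreeNode` class with `val`, `left`, and `right` attributes (the exercise supplies its own), and `count_nodes_iterative` is a hypothetical name; this is an alternative technique, not the lesson's recursive solution.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for the exercise's class)."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_iterative(root):
    """Count nodes using an explicit stack instead of recursion."""
    if root is None:
        return 0
    count = 0
    stack = [root]
    while stack:
        node = stack.pop()
        count += 1  # count each node as it comes off the stack
        # Push non-empty children; they will be counted later
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)
    return count


# A right-skewed tree of 5000 nodes: its height equals n, the worst case for O(h)
skewed = TreeNode(0)
tail = skewed
for i in range(1, 5000):
    tail.right = TreeNode(i)
    tail = tail.right

print(count_nodes_iterative(skewed))  # 5000
```

The order in which nodes are visited does not matter for counting, so any traversal works; only the running total is kept.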
Before writing code, answer these questions:
- Base case: if `root is None`, return 0 (no nodes here).
- Contribution of each node: 1 for the current root.
- Recursive case: return `1 + left_count + right_count` (the 1 is for the current node).

This follows the exact same recursive template as `max_depth`:
```python
def count_nodes(root):
    if root is None:
        return ???
    left = count_nodes(root.left)
    right = count_nodes(root.right)
    return ???
```
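For comparison, here is a sketch of `max_depth` written in the same shape. It assumes depth counts the nodes on the longest root-to-leaf path (so an empty tree has depth 0 and a single node has depth 1); the earlier exercise's exact version may differ, and `TreeNode` is a minimal stand-in. The base case and the two recursive calls match the template; only the line that combines the subtree results changes.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for the exercise's class)."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth(root):
    if root is None:
        return 0  # base case: an empty tree has depth 0
    left = max_depth(root.left)    # solve the left subtree
    right = max_depth(root.right)  # solve the right subtree
    return 1 + max(left, right)    # combine: this node plus the deeper side


# The six-node example tree from the problem statement
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, None, TreeNode(6)))
print(max_depth(root))  # 3
```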
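```python
def count_nodes(root):
    # Base case: an empty subtree contains no nodes
    if root is None:
        return 0
    # Count this node (1), plus every node in the left and right subtrees
    return 1 + count_nodes(root.left) + count_nodes(root.right)
```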
Why this works: Each node returns "I am 1, plus however many nodes are in my left and right subtrees." Leaf nodes return 1 + 0 + 0 = 1. The total accumulates up to the root.
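To watch the total accumulate, here is a sketch of the same recursion with a print added after each node combines its subtree counts. `count_nodes_traced` and its `depth` parameter (used only for indentation) are hypothetical names for illustration, and `TreeNode` is again a minimal stand-in for the exercise's class.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical stand-in for the exercise's class)."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_traced(root, depth=0):
    """Same recursion as count_nodes, but prints what each node returns."""
    if root is None:
        return 0
    left = count_nodes_traced(root.left, depth + 1)
    right = count_nodes_traced(root.right, depth + 1)
    total = 1 + left + right  # this node plus both subtree counts
    print(f"{'  ' * depth}node {root.val}: 1 + {left} + {right} = {total}")
    return total


# The six-node example tree from the problem statement
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, None, TreeNode(6)))
count_nodes_traced(root)
```

The lines appear in post-order, because a node can only report after both of its subtrees have returned:

```
    node 4: 1 + 0 + 0 = 1
    node 5: 1 + 0 + 0 = 1
  node 2: 1 + 1 + 1 = 3
    node 6: 1 + 0 + 0 = 1
  node 3: 1 + 0 + 1 = 2
node 1: 1 + 3 + 2 = 6
```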