Tail Calls
Consider the factorial function below:
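It takes an accumulator argument acc along with n, something like this (the exact base case is a sketch):

```python
def fac(n, acc=1):
    # Base case: the accumulated product is the answer.
    if n <= 1:
        return acc
    # Tail call: the recursive call is the very last thing the function does.
    return fac(n - 1, n * acc)
```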
When we make the call fac(3), two recursive calls are made: fac(2, 3) and fac(1, 6). The last call returns 6, then fac(2, 3) returns 6, and finally the original call returns 6. I would recommend stepping through the execution in Python Tutor.
If you look carefully, you can see that first a huge call stack is created, then a base case is reached, and then the return value is simply bubbled back up to the fac(3) call, which hands that value back to the global frame. This happens because after the recursive call is made by the caller, no further computation needs to be done by the caller. This kind of function call is called a tail call, and languages like Haskell, Scala, and Scheme can avoid keeping around unnecessary stack frames in such calls. This is called tail call optimization (TCO) or tail call elimination.
This is useful because without TCO the computation of fac(n) requires
$\mathcal{O}(n)$ space to hold the $n$ stack frames, and for large $n$ this
causes the stack to overflow; with TCO it takes only $\mathcal{O}(1)$
memory, since a constant number of stack frames is used regardless of $n$.
The optimized code should look much like the iterative version of factorial below:
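Something along these lines, using the same n and an accumulator acc (called fac_iter here only to keep it apart from the recursive version):

```python
def fac_iter(n):
    acc = 1
    # Multiply acc by n, n-1, ..., 2 in a plain loop instead of recursing.
    while n > 1:
        acc *= n
        n -= 1
    return acc
```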
If you step through its execution, you can see that it only ever creates a constant number of stack frames (just one).
Of course, this code uses a loop and mutation, so, as a diligent functional programmer, I will deride it and instead suggest that we restrict such behavior to a single function and abstract it away behind a decorator, so that we can make pristine tail calls in Python and also not blow the stack.
Tail Recursive Functions to Loops
Notice that the variables n
and acc
are the ones that change in every
iteration of the loop, and those are the parameters to each tail recursive
call. So maybe if we can keep track of the parameters and turn each recursive
call into an iteration in a loop, we will be able to avoid recursive calls.
The decorator should be a higher-order function which takes in a function fn and returns an inner function which, when called, calls fn, but with some scaffolding. fn must follow a specific form: it must return something which instructs the inner function (often called the trampoline function) whether it wants to recurse or return. For this, we need two classes representing the two cases:
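A minimal version of the two classes could look like this (only the names TailCall and Return are essential; the fields just carry whatever the trampoline needs):

```python
class TailCall:
    """fn returns this to say: call me again with these arguments."""
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class Return:
    """fn returns this to say: we are done, unwrap and return this value."""
    def __init__(self, value):
        self.value = value
```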
fn should return an instance of TailCall when it wants to make a tail recursive call, and it should feed the arguments of the next call into the instance. When it wants to simply return without making a recursive call, it should return an instance of Return, which wraps the return value. The decorator looks like this:
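In sketch form, it is just a loop that keeps calling fn until it sees a Return:

```python
import functools

def tco(fn):
    @functools.wraps(fn)
    def trampoline(*args, **kwargs):
        # Keep "bouncing": call fn, and as long as it asks for another
        # tail call, feed the new arguments back in without growing the stack.
        while True:
            result = fn(*args, **kwargs)
            if isinstance(result, Return):
                return result.value
            args, kwargs = result.args, result.kwargs
    return trampoline
```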
Finally, fac
looks like this:
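With the decorator in place, each plain return becomes a Return and each recursive call becomes a TailCall, roughly:

```python
@tco
def fac(n, acc=1):
    if n <= 1:
        return Return(acc)           # instead of returning acc directly
    return TailCall(n - 1, n * acc)  # instead of calling fac(n - 1, n * acc)
```

Calling fac(5) now runs the whole computation inside the trampoline's loop and returns 120.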
And thus we have achieved the functional ideal: restricting mutation and loops to a single location, which in this case is the decorator tco, without any (severe) overhead. (Note that a good compiler would look at the original fac and replace the entire function body with a loop to guarantee zero overhead.)
Notice how there is only a single stack frame belonging to the function fac
at any point in time. This will let you compute fac(1000)
and beyond without
a stack overflow error!
And this is how you implement tail call optimization in a language which does not have native support for it. Below is a GitHub Gist with all the code, some examples, and static types.