Recursion in Clojure

One of the hardest things for a non-functional programmer to get used to in a functional language like Clojure is the absence of conventional loops. We OO and procedural programmers are so used to reaching for a for loop to do any sort of repetitive task that when this is pulled out from under our feet, we’re left floundering.

Clojure replaces loops with recursion, which is the process of a function calling itself. Most OO and procedural programming languages such as Java also support recursion, but it is not often used except in specialized situations where the algorithms being implemented are explicitly recursive (traversing a binary tree is a good example of this). Recursion, we are told, is an inefficient way of implementing repetitive tasks and should be avoided unless it really does clarify the code.

The simplest examples of recursion typically involve elementary arithmetic, so we’ll have a look at one such problem: raising a number to an integer power. Suppose we wish to raise the number x to the power y. A typical bit of procedural code for doing this would be:

double result = 1.0;
for (int i = 0; i < y; i++)
  result = result * x;

After the loop, the answer is stored in result.

How would we do this recursively? There are two main parts of a recursive algorithm. First, we need a base or anchor condition, which provides a termination condition for the algorithm. Second, we need a recursive step, in which a function will call itself. The recursive step must be such that each step takes the calculation closer to the anchor condition.

In the case of raising x to the power y, the anchor step is when y has the value 0. In this case, the answer will always be 1 (unless x=0, in which case the answer isn’t defined, but we’ll assume x isn’t 0 here). For the recursive step, we note that (assuming y>0) x to the power y is x times x to the power (y-1). That is, if the function we’re evaluating is f(x,y), then f(x,y)=x*f(x,y-1). We can see that this takes y one step closer to the anchor condition of y=0.

Our first implementation of such a function in Clojure might look like this (the reason for calling the function bad_power will become apparent afterwards):

(defn bad_power
  ([x y]
    (if (= y 0)
      1
      (* x (bad_power x (- y 1))))))

This function introduces the if statement, which has the general form (if a b c). Here ‘a’ is the condition to test, ‘b’ is what to return if ‘a’ is true, and ‘c’ is an optional third argument which is returned if ‘a’ is false. If ‘c’ is left out and ‘a’ is false, the result is nil.
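As a quick illustration of the two shapes of if described above, here is a minimal sketch. The function names describe and only-zero are hypothetical and exist only for this example.

```clojure
;; Hypothetical helper: three-argument form of `if`, with an else branch.
(defn describe [n]
  (if (= n 0)
    "zero"        ; returned when the test is true
    "non-zero"))  ; returned when the test is false

;; Hypothetical helper: two-argument form, with no else branch.
;; When the test is false, the result is nil.
(defn only-zero [n]
  (if (= n 0)
    "zero"))

(describe 0)   ; => "zero"
(describe 7)   ; => "non-zero"
(only-zero 7)  ; => nil
```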

Thus bad_power takes two arguments, x and y. The if statement tests if the anchor condition is true (is y=0?). If it is, we return 1 immediately and that’s the end of the function. If it isn’t, then the last line is run. This multiplies x by a recursive call to bad_power with arguments of x and y-1.

Running this program produces the expected result: for example, (bad_power 2 3) gives the answer 8.

Now, why have we called this function bad_power? It seems to work well enough. The answer is a bit technical, and deals with the way recursion is implemented.

In a running program, each function call pushes the current state of the program onto a stack, with execution jumping to a new instance of the function that is called. When the function call ends, the previous state of the program is popped off the stack and execution continues from that point. This all works well enough if function calls are sequential: program starts, function A is called, then returns, then function B is called and returns and so on. In this case, the stack never gets very deep since each function ends before the next one is called.

In the case of recursion, however, a function will call itself a potentially large number of times before any of these calls returns, and for each call, a separate instance of the program state must be saved on the stack. In the bad_power function, there will be y instances of the function called before any of them returns, so if we want to calculate a large power, the stack will get very deep.
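To make the growing stack concrete, the comments in the sketch below expand (bad_power 2 3) by hand. Each line of the expansion has one more multiplication waiting for an inner call to return, and that waiting work is what occupies the stack.

```clojure
(defn bad_power [x y]
  (if (= y 0)
    1
    (* x (bad_power x (- y 1)))))

;; Hand expansion of (bad_power 2 3):
;;   (bad_power 2 3)
;;   (* 2 (bad_power 2 2))
;;   (* 2 (* 2 (bad_power 2 1)))
;;   (* 2 (* 2 (* 2 (bad_power 2 0))))
;;   (* 2 (* 2 (* 2 1)))      ; anchor reached, now the calls unwind
;;   (* 2 (* 2 2))
;;   (* 2 4)
;;   8
(bad_power 2 3) ; => 8
```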

Any computer will have a limit to the size of this stack, so if we give it too many recursive calls, we’ll get a stack overflow. The limit varies from one machine to another and depends on how the JVM is configured, but with default settings you can expect to hit it after a few thousand nested calls. Try it with this function by giving it a large exponent. On my machine raising 2 to the power of 2000 runs OK, but 3000 causes a stack overflow.
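If you would like to probe the limit without crashing your session, something like the following sketch may help. The helper try-depth is hypothetical. It uses a base of 1 so that the numbers stay small and only the depth of the recursion is being tested. The depth at which the overflow appears is an assumption about default JVM settings and will differ between machines.

```clojure
(defn bad_power [x y]
  (if (= y 0)
    1
    (* x (bad_power x (- y 1)))))

;; Hypothetical helper: reports whether a recursion of the given depth
;; completes or exhausts the stack. StackOverflowError is the JVM error
;; thrown when the call stack is full.
(defn try-depth [depth]
  (try
    (bad_power 1 depth)
    :ok
    (catch StackOverflowError _
      :stack-overflow)))

(try-depth 100)      ; => :ok
(try-depth 1000000)  ; expected to be :stack-overflow on a default JVM
```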

We can see that this problem doesn’t arise with a loop in a language like Java, since no function calls are made, and the result is built up by saving the value in a single variable. Since loops are to be replaced by recursion in Clojure, and some loops can be many thousands of iterations in length, how can we get recursion to work properly without causing a stack overflow?

Tail recursion

The answer is that recursive algorithms have to be written in a special way called tail recursion. If we examine the algorithm above, we see that the result of each recursive call is needed, since this result is multiplied by x before being passed back up to the next function in the chain. If we could rewrite the algorithm in a way that didn’t require saving each recursive call until the end of the recursion, then we could throw away each function after it made its recursive call to the next function in the chain. The final answer is obtained when the bottom of the recursive chain is reached, so there is no need to climb all the way back up to the top of the chain again, as there was in the example above.

If we can arrange things in this way, the resulting algorithm is called tail recursive. How can we rewrite the bad_power function to be tail recursive? The answer is as follows, with additions to allow negative exponents to be handled as well:

(defn power
  ([x y] (power x y 1))
  ([x y current]
  (if (= y 0)
    current
    (if (> y 0)
      (power x (- y 1) (* x current))
      (power x (+ y 1) (/ current x))))))

The function uses multiple arity, so we’d better explain that first. Basically, multiple arity allows one function to be defined with several argument lists of different lengths, each with its own body (so it’s like function overloading in Java). In this case, power can take either two or three arguments. If it is given two arguments, the code on line 2 is run, which is a recursive call to power, but this time with 3 arguments. This will call the code starting on line 3, in which the symbol current is initialized to 1, and x and y are just passed along as they are.
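Multiple arity is easier to see in a function that does less. The sketch below uses a hypothetical greet function, not part of the power example, in which the shorter argument list simply supplies a default and delegates to the longer one, just as the two-argument power does.

```clojure
;; Hypothetical function, used only to illustrate arity dispatch.
(defn greet
  ([] (greet "world"))                ; no arguments: supply a default
  ([who] (str "Hello, " who "!")))    ; one argument: do the actual work

(greet)            ; => "Hello, world!"
(greet "Clojure")  ; => "Hello, Clojure!"
```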

On line 4, we test the anchor condition as before, but this time, if y=0, we return current. We can see this works if the initial value of y is 0, since it will just return 1.

If y isn’t 0, we then test to see if y>0. If so, we then recursively call power, passing x along unchanged, but reducing y by 1, as before. This time, however, we perform the incremental calculation by multiplying current by x and passing this along as the new value of current in the recursive call. Note that the return value of this recursive call is not used in any further calculations, so the value of this call need not be saved. All the required information is passed along to the next recursive call in the chain, so we can just forget about the current function. This means that it need not be stored on the stack.

The last line of this function adds some code that calculates x raised to a negative exponent, but it works in exactly the same way so it too is tail recursive.
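Expanding a call by hand shows the difference from bad_power. In the sketch below, each step is a complete call in its own right: nothing is left waiting to be multiplied afterwards, because the running product travels along in current.

```clojure
(defn power
  ([x y] (power x y 1))
  ([x y current]
   (if (= y 0)
     current
     (if (> y 0)
       (power x (- y 1) (* x current))
       (power x (+ y 1) (/ current x))))))

;; Hand expansion of (power 2 3):
;;   (power 2 3)
;;   (power 2 3 1)
;;   (power 2 2 2)
;;   (power 2 1 4)
;;   (power 2 0 8)
;;   8
(power 2 3)   ; => 8

;; Negative exponents divide instead of multiplying:
;;   (power 2 -2 1) -> (power 2 -1 1/2) -> (power 2 0 1/4)
(power 2 -2)  ; => 1/4
```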

Now, if you run this code, you might expect to be able to calculate much higher exponents than with bad_power, but you’d be disappointed. It still chokes on high exponent values. Why?

It turns out that the Clojure compiler does not automatically optimize tail recursive functions; it must be told to do so. This is easy enough to do: we replace the explicit name of the function in the recursive call by the special form recur. The final version of the function thus looks like this:

(defn power
  ([x y] (power x y 1))
  ([x y current]
  (if (= y 0)
    current
    (if (> y 0)
      (recur x (- y 1) (* x current))
      (recur x (+ y 1) (/ current x))))))

The only changes we have made are on lines 7 and 8, where recur replaces the name power. Now, finally, if you run this program, you can enter enormous exponents and the program doesn’t choke.
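The same idea is often written with loop, which sets up the accumulator locally so that no extra arity is needed. The sketch below is a variant rather than the version developed above, and the name power-loop is made up for it. Note that recur is only accepted in tail position: writing (* x (recur ...)) is rejected by the compiler, which is a handy check that a function really is tail recursive.

```clojure
;; Variant sketch: the accumulator lives in a `loop` binding instead of
;; a third function argument. `recur` jumps back to the nearest `loop`.
(defn power-loop [x y]
  (loop [n y
         acc 1]
    (cond
      (= n 0) acc
      (> n 0) (recur (- n 1) (* x acc))
      :else   (recur (+ n 1) (/ acc x)))))

(power-loop 2 10)       ; => 1024
(power-loop 2 -2)       ; => 1/4
(power-loop 1 1000000)  ; => 1, a million iterations with no stack growth
```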

Incidentally, you might notice that Clojure can handle some truly huge numbers. Raising 2 to the power of 5000 gives a 1,506-digit answer beginning with 14124670321394260368352096670161473336688961751845411168136880858...
I’ll take the program’s word for it that it is correct (at least it’s an even number so that much is correct).
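A few properties of the result can be checked without reading every digit, as in the sketch below. One caveat for current versions: from Clojure 1.3 onwards, * on ordinary integers throws an ArithmeticException on overflow instead of switching to big integers, so the sketch passes the base as the BigInt literal 2N. That choice is an assumption about the version you are running; the function itself is unchanged.

```clojure
(defn power
  ([x y] (power x y 1))
  ([x y current]
   (if (= y 0)
     current
     (if (> y 0)
       (recur x (- y 1) (* x current))
       (recur x (+ y 1) (/ current x))))))

;; 2N is a BigInt literal, so the running product never overflows.
(def big (power 2N 5000))

(count (str big))                   ; => 1506 digits
(even? big)                         ; => true
(.startsWith (str big) "14124670")  ; => true
```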

You’ll also note that, for negative exponents, if you enter x as an integer, the result is returned as a fraction rather than a decimal number. Clojure deals with rational numbers separately from decimal expansions of them, so if you want the decimal representation, enter x as a floating point number, as in (power 2.0 -2), which gives an answer of 0.25.
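The difference between the two kinds of result can be seen side by side in the following sketch, which reuses the final version of power. Converting with double is one way to get a decimal from an exact fraction after the fact.

```clojure
(defn power
  ([x y] (power x y 1))
  ([x y current]
   (if (= y 0)
     current
     (if (> y 0)
       (recur x (- y 1) (* x current))
       (recur x (+ y 1) (/ current x))))))

(power 2 -2)            ; => 1/4, an exact rational
(ratio? (power 2 -2))   ; => true
(power 2.0 -2)          ; => 0.25, a floating point number
(double (power 2 -2))   ; => 0.25, converting the rational afterwards
```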