
Recursion in Clojure

One of the hardest things for a non-functional programmer to get used to in a functional language like Clojure is the absence of traditional loops. We OO and procedural programmers are so used to reaching for a for loop whenever a task is repetitive that, when it is pulled out from under our feet, we’re left floundering.

Clojure replaces loops with recursion, which is the process of a function calling itself. Most OO and procedural programming languages, such as Java, also support recursion, but it is rarely used except in specialized situations where the algorithm being implemented is explicitly recursive (operations on binary trees are a good example). Recursion, we are told, is an inefficient way of implementing repetitive tasks and should be avoided unless it really does clarify the code.

The simplest examples of recursion typically involve elementary arithmetic, so we’ll have a look at one such problem: raising a number to an integer power. Suppose we wish to raise the number x to the power y. A typical bit of procedural code for doing this would be:

double result = 1.0;        // x is a double, y a non-negative int
for (int i = 0; i < y; i++)
{
  result = result * x;      // multiply in one more factor of x
}

After the loop, the answer is stored in result.

How would we do this recursively? A recursive algorithm has two main parts. First, we need a base (or anchor) condition, which tells the algorithm when to stop. Second, we need a recursive step, in which the function calls itself. The recursive step must bring the calculation closer to the anchor condition each time, or the recursion will never end.

In the case of raising x to the power y, the anchor condition is y = 0. In this case, the answer is always 1 (0 to the power 0 is a special case that is sometimes left undefined, so we’ll assume x isn’t 0 here). For the recursive step, we note that (assuming y > 0) x to the power y is x times x to the power (y-1). That is, if the function we’re evaluating is f(x,y), then f(x,y) = x*f(x,y-1). Each step reduces y by 1, taking it closer to the anchor condition y = 0.
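Unwinding the recurrence by hand for a small case, x = 2 and y = 3, shows how the recursive step works its way down to the anchor condition. This uses only the f(x,y) defined above:

```latex
\begin{aligned}
f(2,3) &= 2 \cdot f(2,2) \\
       &= 2 \cdot \bigl(2 \cdot f(2,1)\bigr) \\
       &= 2 \cdot \bigl(2 \cdot \bigl(2 \cdot f(2,0)\bigr)\bigr) \\
       &= 2 \cdot 2 \cdot 2 \cdot 1 = 8
\end{aligned}
```

Notice that each multiplication has to wait until the call nested inside it has returned a value.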

Our first implementation of such a function in Clojure might look like this (the reason for calling the function bad_power will become apparent afterwards):

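(defn bad_power
  "Raises x to the power y, for a non-negative integer y. Not tail recursive."
  [x y]
  (if (= y 0)
    1
    ;; *' (rather than *) promotes to BigInt instead of throwing on long overflow
    (*' x (bad_power x (- y 1)))))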

This function introduces if, a special form written (if a b c). Here ‘a’ is the condition to test, ‘b’ is evaluated and returned if ‘a’ is logically true (any value other than false or nil), and ‘c’ is an optional third expression that is evaluated and returned if ‘a’ is false or nil. If ‘c’ is omitted and the test fails, if returns nil.
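A few REPL expressions illustrating these rules. The function sign-word is a made-up example; everything else is plain clojure.core:

```clojure
;; if is an expression: it returns the value of whichever branch runs.
(defn sign-word [n]
  (if (neg? n) "negative" "non-negative"))

(sign-word -5)             ; => "negative"
(sign-word 3)              ; => "non-negative"

;; Only false and nil count as false; 0 is truthy in Clojure.
(if 0 :truthy :falsy)      ; => :truthy
(if nil :truthy :falsy)    ; => :falsy

;; With the else branch omitted, a failing test makes if return nil.
(if false :yes)            ; => nil
```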

Thus bad_power takes two arguments, x and y. The if form tests whether the anchor condition holds (is y=0?). If it does, we return 1 immediately and that’s the end of the function. If not, the last line is run: it multiplies x by the result of a recursive call to bad_power with arguments x and y-1.

Running this produces the expected result: (bad_power 2 3) returns 8.

Now, why have we called this function bad_power? It seems to work well enough. The answer is a bit technical, and deals with the way recursion is implemented.

In a running program, each function call pushes a frame onto the call stack, holding the caller’s state (its local variables and the point to return to), and execution jumps to the called function. When the function returns, that frame is popped off the stack and execution continues from where it left off. This works well enough if function calls are sequential: the program starts, function A is called and returns, then function B is called and returns, and so on. In this case, the stack never gets very deep, since each function ends before the next one is called.

In the case of recursion, however, a function may call itself a large number of times before any of those calls returns, and each call needs its own frame on the stack. The bad_power function makes y+1 nested calls (for y, y-1, ..., 0) before any of them returns, so if we want to calculate a large power, the stack gets very deep.

Any computer has a limit on the size of this stack, so if we make too many nested calls, we’ll get a stack overflow (on the JVM, a StackOverflowError). The limit varies from one machine and JVM configuration to another, since it depends on the thread stack size (which the JVM’s -Xss option controls), but for a function like this it is often reached within a few thousand calls. Try it with this function by giving it a large exponent. On my machine, raising 2 to the power of 2000 runs OK, but 3000 causes a stack overflow.
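One way to try this without the numbers themselves becoming huge is to use x = 1, so that only the depth of the recursion grows. The sketch below repeats bad_power so it runs on its own; the helper overflows? is hypothetical, and it assumes a default-sized JVM stack, which a million nested calls far exceeds:

```clojure
;; The non-tail-recursive power function from above.
(defn bad_power [x y]
  (if (= y 0)
    1
    (*' x (bad_power x (- y 1)))))

;; Hypothetical helper: true if computing 1^y exhausts the call stack.
(defn overflows? [y]
  (try
    (bad_power 1 y)
    false
    (catch StackOverflowError _ true)))

(overflows? 1000)      ; => false on a typical JVM
(overflows? 1000000)   ; => true: a million nested calls exceed the stack
```

Since 1 to any power is 1, the only thing that changes between the two calls is how deep the recursion goes.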

We can see that this problem doesn’t arise with a loop in a language like Java, since no function calls are made, and the result is built up by saving the value in a single variable. Since loops are to be replaced by recursion in Clojure, and some loops can be many thousands of iterations in length, how can we get recursion to work properly without causing a stack overflow?

Tail recursion

The answer is that recursive algorithms have to be written in a special way, called tail recursion. If we examine the algorithm above, we see that the result of each recursive call is still needed, since it is multiplied by x before being passed back up to the previous call in the chain. If we could rewrite the algorithm so that no call has any work left to do once its recursive call is made, then we could discard each call’s stack frame as soon as it calls the next function in the chain. The final answer would be available when the bottom of the recursive chain is reached, so there would be no need to climb all the way back up to the top of the chain, as there was in the example above.

If we can arrange things in this way, the resulting algorithm is called tail recursive. How can we rewrite the bad_power function to be tail recursive? The answer is as follows, with additions to allow negative exponents to be handled as well:

(defn power
  ([x y] (power x y 1))                     ; two args: start current at 1
  ([x y current]
   (if (= y 0)
     current
     (if (> y 0)
       (power x (- y 1) (*' x current))     ; *' promotes to BigInt on overflow
       (power x (+ y 1) (/ current x))))))  ; y < 0: divide by x instead

The function uses multiple arity, so we’d better explain that first. Multiple arity lets a single function have a separate body for each number of arguments it accepts (much like method overloading in Java). Here, power can take either two or three arguments. Given two arguments, it runs the first body, ([x y] (power x y 1)), which simply calls power again with three arguments: x and y are passed along unchanged, and current starts at 1. That call runs the second body, the one beginning ([x y current].
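A stripped-down illustration of the same pattern, in which the shorter arity supplies a default and delegates to the longer one. The function greet and its default greeting are made up for this example:

```clojure
(defn greet
  ;; One argument: fill in a default greeting and call the two-argument body.
  ([who] (greet "Hello" who))
  ;; Two arguments: do the actual work.
  ([greeting who] (str greeting ", " who "!")))

(greet "Ada")          ; => "Hello, Ada!"
(greet "Hi" "Ada")     ; => "Hi, Ada!"
```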

In the three-argument body, we test the anchor condition as before, but this time, if y=0, we return current. This clearly works when the initial value of y is 0, since the function then just returns 1.

If y isn’t 0, we then test whether y>0. If so, we recursively call power, passing x along unchanged and reducing y by 1, as before. This time, however, we do the incremental calculation ourselves, multiplying current by x and passing the result along as the new value of current. The recursive call is now the last thing the function does: its return value is passed straight back without being used in any further calculation. All the information needed is passed along to the next call in the chain, so the current call has nothing left to remember, and in principle its frame need not be kept on the stack.
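Tracing (power 2 3) by hand shows the accumulator at work. Each call carries the partial product in current, and the last call’s return value is the final answer, with no multiplications left pending:

```latex
\begin{aligned}
\text{power}(2,3) &\to \text{power}(2,3,1) \\
                  &\to \text{power}(2,2,2) \\
                  &\to \text{power}(2,1,4) \\
                  &\to \text{power}(2,0,8) \\
                  &\to 8
\end{aligned}
```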

The last line of this function handles negative exponents: it moves y up by 1 towards 0 and divides current by x at each step. It works in exactly the same way, so it too is tail recursive.
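The same kind of hand trace for a negative exponent, (power 2 -2), shows current being divided by x while y climbs to 0:

```latex
\begin{aligned}
\text{power}(2,-2) &\to \text{power}(2,-2,1) \\
                   &\to \text{power}\bigl(2,-1,\tfrac{1}{2}\bigr) \\
                   &\to \text{power}\bigl(2,0,\tfrac{1}{4}\bigr) \\
                   &\to \tfrac{1}{4}
\end{aligned}
```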

Now, if you run this code, you might expect to be able to calculate much higher exponents than with bad_power, but you’d be disappointed. It still overflows the stack for large exponents. Why?

It turns out that the JVM does not perform tail call optimization, and the Clojure compiler does not do it automatically either; it must be told to do so. This is easy enough: we replace the name of the function in each recursive call with the special form recur. The compiler only accepts recur in tail position, so it also checks that the call really is a tail call. The final version of the function thus looks like this:
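The difference is easy to see with a function that does nothing but count down. Both versions below are tail recursive in form, but only the one using recur runs in constant stack space. The names count-down-call and count-down-recur are hypothetical, and a million iterations is assumed to exceed a default-sized JVM stack:

```clojure
;; Tail call by name: the JVM still pushes a new frame for every call.
(defn count-down-call [n]
  (if (zero? n)
    :done
    (count-down-call (dec n))))

;; Same shape, but recur rebinds n and jumps back to the top of the function.
(defn count-down-recur [n]
  (if (zero? n)
    :done
    (recur (dec n))))

(count-down-recur 1000000)                        ; => :done

(try
  (count-down-call 1000000)
  (catch StackOverflowError _ :stack-overflow))   ; => :stack-overflow
```

Moving recur out of tail position, for example writing (inc (recur (dec n))), is rejected at compile time, which makes recur a handy check that a function really is tail recursive.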

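(defn power
  ([x y] (power x y 1))                     ; two args: start current at 1
  ([x y current]
   (if (= y 0)
     current
     (if (> y 0)
       (recur x (- y 1) (*' x current))     ; rebind x, y, current and jump back
       (recur x (+ y 1) (/ current x))))))  ; no new stack frame either way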

The only change is that the two recursive calls to power, on the last two lines, have been replaced by recur. Now, finally, if you run this program, you can enter enormous exponents without the program choking.

Incidentally, you might notice that Clojure can handle some truly huge numbers. Raising 2 to the power of 5000 gives an answer of

14124670321394260368352096670161473336688961751845411168136880858571181698427075125580891263167115263733560320843136608276420383806997933833597118572663992343105177785186539901187799964513170706937349821263132375255311121537284403595090053595486073341845340557556673680156558740546469964049905084969947235790090561757137661822821643421318152099155667712649865178220417406183093923917686134138329401824022583869272559614700514424328107527562949533909381319896673563360632969102384245412583588865687313398128724098000883807366822180426443291089403078902021944057819848826733976823887227990215742030724757051042384586887259673589180581872779643575301851808664135601285130254672682300925021832801825190734024544986318326563798786219851104636298546194958728111913990722800438594288095395881655456762529608691688577482893444994136241658867532694033256110366455698262220683447421981108187240492950348199137674037982599879141187980271758388549857511529947174346924111707023039810337861523279371029099265644484289551183035573315202080415792009004181195188045670551546834944618273174232768598927760762070952587831876648836834896501547499786411976544143335692801234411176573533639355787921493700434756820866595871776405929359288751429284355704708916487648311661569188620381299755569017189216973375522446903247507879783090132157994012733721069437728343992228027406079823478674043489345812019834110103381250672004660989116070028400210098045296403978870433530261933759786205219228037148113216414718651416909091719190937
6

I’ll take the program’s word that it is correct (it is at least an even number, so that much checks out).
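A couple of cheap sanity checks go a little further than evenness: 2 to the power 5000 should have floor(5000 log10 2) + 1 = 1506 digits, and since the exponent is divisible by 4 it must end in 6. This sketch computes the power independently with reduce and repeat rather than with the power function. It also shows why the code above uses *': in Clojure 1.3 and later, plain * throws an ArithmeticException on long overflow instead of promoting to BigInt.

```clojure
;; Compute 2^5000 independently by multiplying 5000 twos together.
;; *' promotes to BigInt as soon as a long would overflow.
(def two-to-5000 (reduce *' (repeat 5000 2)))

(def digits (str two-to-5000))

(count digits)   ; => 1506
(last digits)    ; => \6

;; Plain * does not promote: it throws on long overflow.
(try
  (* Long/MAX_VALUE 2)
  (catch ArithmeticException _ :integer-overflow))   ; => :integer-overflow
```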

You’ll also notice that, for negative exponents, if you enter x as an integer, the result is returned as a fraction rather than a decimal number: (power 2 -2) returns 1/4. Clojure represents exact rational numbers with its own ratio type, separate from floating-point decimals, so if you want the decimal representation, enter x as a floating-point number, as in (power 2.0 -2), which gives an answer of 0.25.
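A few REPL expressions illustrating the distinction, plus the double function for converting a ratio after the fact. These use only clojure.core:

```clojure
;; Dividing integers gives an exact ratio, not a rounded decimal.
(/ 1 4)               ; => 1/4
(ratio? (/ 1 4))      ; => true

;; A floating-point operand makes the result a double.
(/ 1 4.0)             ; => 0.25

;; An existing ratio can be converted explicitly.
(double 1/4)          ; => 0.25

;; Ratios stay exact where doubles round.
(+ 1/10 2/10)         ; => 3/10
(+ 0.1 0.2)           ; => 0.30000000000000004
```

Working in ratios and converting once at the end avoids accumulating rounding error across many divisions.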