1 Evaluation of Nested Function Calls
2 Continuation-Passing
3 CPS as a Monad
4 Continuations as a Language Feature

CSE P 505 Spring 2013 Lecture #5 Notes

1 Evaluation of Nested Function Calls

In the first homework, we implemented append and reverse:

(define (append lst1 lst2)
  (if (empty? lst1)
      lst2
      (cons (first lst1) (append (rest lst1) lst2))))
 
(define (reverse lst)
  (if (empty? lst)
      empty
      (append (reverse (rest lst)) (list (first lst)))))

Since these procedures are purely functional, we can use substitution to understand how they evaluate. For example,

; Application expands to function body with var replaced by arg.
(reverse (list 1 2 3))
  =>
(if (empty? (list 1 2 3))
    empty
    (append (reverse (rest (list 1 2 3)))
            (list (first (list 1 2 3)))))
 
; (empty? (list 1 2 3)) => false
(if (empty? (list 1 2 3))
    empty
    (append (reverse (rest (list 1 2 3)))
            (list (first (list 1 2 3)))))
  =>
(if false
    empty
    (append (reverse (rest (list 1 2 3)))
            (list (first (list 1 2 3)))))
 
; Conditional selects 'else' branch.
(if false
    empty
    (append (reverse (rest (list 1 2 3)))
            (list (first (list 1 2 3)))))
  =>
(append (reverse (rest (list 1 2 3)))
        (list (first (list 1 2 3))))
 
; (rest (list 1 2 3)) => (list 2 3)
(append (reverse (rest (list 1 2 3)))
        (list (first (list 1 2 3))))
  =>
(append (reverse (list 2 3))
        (list (first (list 1 2 3))))
 
; *** Recursive call to (reverse (list 2 3)) proceeds similarly.
(append (reverse (list 2 3))
        (list (first (list 1 2 3))))
  =>
(append (append (reverse (list 3))
                (list (first (list 2 3))))
        (list (first (list 1 2 3))))
 
; *** Recursive call to (reverse (list 3)) proceeds similarly.
(append (append (reverse (list 3))
                (list (first (list 2 3))))
        (list (first (list 1 2 3))))
  =>
(append (append (append (reverse empty)
                        (list (first (list 3))))
                (list (first (list 2 3))))
        (list (first (list 1 2 3))))
 
; *** Base case: (reverse empty) => empty
(append (append (append (reverse empty)
                        (list (first (list 3))))
                (list (first (list 2 3))))
        (list (first (list 1 2 3))))
  =>
(append (append (append empty
                        (list (first (list 3))))
                (list (first (list 2 3))))
        (list (first (list 1 2 3))))
  => ...

And so forth.

The key thing to note is that, each time we make a recursive call to reverse, there’s a call to append waiting for the result of this call. These suspended calls build up a stack of pending sub-computations that don’t get discharged until we reach the base case of the recursion and start returning the results of the sub-problems.

In contrast, consider rev-append, which appends the reversal of a list with another list; rev-append uses its second argument as an accumulator, which holds the result of the largest sub-problem computed so far.

(define (rev-append lst1 lst2)
  (if (empty? lst1)
      lst2
      (rev-append (rest lst1) (cons (first lst1) lst2))))

If we evaluate (rev-append (list 1 2 3) empty), the trace looks quite different:

; Application expands to function body with var replaced by arg.
(rev-append (list 1 2 3) empty)
  =>
(if (empty? (list 1 2 3))
    empty
    (rev-append (rest (list 1 2 3)) (cons (first (list 1 2 3)) empty)))
 
; (empty? (list 1 2 3)) => false
(if (empty? (list 1 2 3))
    empty
    (rev-append (rest (list 1 2 3)) (cons (first (list 1 2 3)) empty)))
  =>
(if false
    empty
    (rev-append (rest (list 1 2 3)) (cons (first (list 1 2 3)) empty)))
 
; Conditional selects 'else' branch.
(if false
    empty
    (rev-append (rest (list 1 2 3)) (cons (first (list 1 2 3)) empty)))
  =>
(rev-append (rest (list 1 2 3)) (cons (first (list 1 2 3)) empty))
 
; Function arguments are evaluated.
(rev-append (rest (list 1 2 3)) (cons (first (list 1 2 3)) empty))
  =>
(rev-append (list 2 3) (cons 1 empty))
 
; Similar steps for recursive call.
(rev-append (list 2 3) (cons 1 empty))
  =>
(rev-append (list 3) (cons 2 (cons 1 empty)))
 
; And next recursive call.
(rev-append (list 3) (cons 2 (cons 1 empty)))
  =>
(rev-append empty (cons 3 (cons 2 (cons 1 empty))))
 
; Now we're at the base case (lst1 is empty).
(rev-append empty (cons 3 (cons 2 (cons 1 empty))))
  =>
(cons 3 (cons 2 (cons 1 empty))) = (list 3 2 1)

Note the difference: when rev-append calls itself recursively, there’s nothing more to be done once the result is computed. That is, the result of the recursive call is itself the result of the original call. A subexpression with this property relative to the enclosing expression is said to be in tail position. When this is a function call, we say it’s a tail call, and when it’s a call to the same function we refer to it as tail recursion.
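As a small usage sketch, rev-append yields a tail-recursive reverse when the accumulator starts out empty. The name reverse/acc is hypothetical, chosen here for illustration, and the code assumes plain #lang racket rather than the typed language used in class.

```racket
#lang racket

;; rev-append from the notes: the recursive call is in tail position.
(define (rev-append lst1 lst2)
  (if (empty? lst1)
      lst2
      (rev-append (rest lst1) (cons (first lst1) lst2))))

;; Hypothetical helper: reverse a list by reversing it onto empty.
(define (reverse/acc lst)
  (rev-append lst empty))

(reverse/acc (list 1 2 3))            ; => (list 3 2 1)
(rev-append (list 1 2 3) (list 4 5))  ; => (list 3 2 1 4 5)
```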

As this example shows, tail calls do not build up a stack of pending subcomputations. The Racket runtime exploits this fact and reuses the parent stack frame to evaluate a subexpression that’s in tail position. This is called tail call optimization or TCO. TCO is essential for functional languages, where "looping" is achieved through recursion; without it, reactive programs would quickly overflow the runtime stack. Most imperative/object-oriented languages (e.g., C, C++, Java, JavaScript, Python) do not guarantee TCO, although some compilers apply it opportunistically.
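To make the "looping through recursion" point concrete, here is a minimal sketch of a loop written as a tail-recursive function. The function sum-to is a hypothetical example, not something from the homework; it assumes plain #lang racket.

```racket
#lang racket

;; Tail-recursive "loop": the recursive call is the entire result of the
;; else branch, so nothing is left pending when it returns.
(define (sum-to n acc)
  (if (zero? n)
      acc
      (sum-to (sub1 n) (+ n acc))))

(sum-to 1000000 0)  ; => 500000500000
```

Each iteration only updates the two arguments, which is exactly what a while loop with two mutable variables would do.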

2 Continuation-Passing

One natural question we might ask after the preceding discussion is whether any recursive function can be written in this sort of accumulator style and made tail-recursive. For example, recall append from the first homework:

(define (append lst1 lst2)
  (if (empty? lst1)
      lst2
      (cons (first lst1) (append (rest lst1) lst2))))

We’d like to rewrite this so it has the following form:

(define (append/acc lst1 lst2 acc)
  (if (empty? lst1)
      ...
      (append/acc (rest lst1) lst2 ...)))

We might try to do something similar to what we did in rev-append, e.g. consing up a list of the elements we’ve traversed so far:

(define (append/acc lst1 lst2 acc)
  (if (empty? lst1)
      ...
      (append/acc (rest lst1) lst2 (cons (first lst1) acc))))

However, all this will do is the same thing rev-append did, so when we get to the base case, we’ll have the reversal of the original lst1, and then we’ll have to traverse it again to cons its reversal back onto lst2. That would work but isn’t entirely satisfying (though it’s interesting to note how this approach takes advantage of the fact that Racket’s lists are a natural implementation of a stack).
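The two-traversal approach described above might look like the following sketch. The name append/two-pass is hypothetical, and the code assumes plain #lang racket.

```racket
#lang racket

(define (rev-append lst1 lst2)
  (if (empty? lst1)
      lst2
      (rev-append (rest lst1) (cons (first lst1) lst2))))

;; Hypothetical two-pass append: first reverse lst1 onto empty, then
;; reverse that result onto lst2. Both passes are tail-recursive.
(define (append/two-pass lst1 lst2)
  (rev-append (rev-append lst1 empty) lst2))

(append/two-pass (list 1 2) (list 3))  ; => (list 1 2 3)
```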

Let’s think for a moment about what’s going on. We’re making a recursive call to append the rest of the first list onto the second list. What’s the relationship between the result of this call and the result we need to return? ...

The result we need to return is what we’d get by consing the first element of lst1 onto the result of the recursive sub-problem. How can we express this relationship as a Racket value? How about as a function? For example,

(λ (v) (cons (first lst1) v))

is a function that exactly captures the transformation we need to perform on the result of the sub-problem. Let’s try to use this to help rewrite append.

This is incorrect. Don’t use this.

(define (append/acc lst1 lst2 acc)
  (if (empty? lst1)
      ...
      (append/acc (rest lst1) lst2 (λ (v) (cons (first lst1) v)))))

Now, when we get to the base case, we apply the "accumulator" to lst2.

This is also wrong.

(define (append/acc lst1 lst2 acc)
  (if (empty? lst1)
      (acc lst2)
      (append/acc (rest lst1) lst2 (λ (v) (cons (first lst1) v)))))

We also need to figure out what to pass for the initial accumulator. This should be a function that transforms the result of the final call into the answer we want. In this case, these are the same thing, so we can pass the identity function (λ (x) x). Let’s see what happens when we call it:

(append/acc (list 1 2) (list 3) identity)
  =>
(append/acc (list 2) (list 3) (λ (v) (cons (first (list 1 2)) v)))
  =>
(append/acc empty (list 3) (λ (v) (cons (first (list 2)) v)))
  =>
((λ (v) (cons (first (list 2)) v)) (list 3))
  =>
(cons (first (list 2)) (list 3))
  =>
(cons 2 (list 3))
  =>
(list 2 3)

Oops. We lost an element somewhere. Looking back at the code, notice that we don’t refer to acc in the "else" branch. In other words, while we’re keeping track of what we need to do with the result of the call we make, we’ve neglected the fact that our caller gave us something that represented what it needed to do with our result. We need to hold up our end of the bargain and give our result to it.

This code is correct.

(define (append/acc lst1 lst2 acc)
  (if (empty? lst1)
      (acc lst2)
      (append/acc (rest lst1) lst2 (λ (v) (acc (cons (first lst1) v))))))

Now when we run it, we get the result we expect:

(append/acc (list 1 2) (list 3) identity)
  =>
(append/acc (list 2) (list 3) (λ (v) (identity (cons (first (list 1 2)) v))))
  =>
(append/acc empty (list 3)
        (λ (v) ((λ (v) (identity (cons (first (list 1 2)) v)))
                (cons (first (list 2)) v))))
 
; This is the base case, so we apply the accumulator to the second list.
(append/acc empty (list 3)
        (λ (v) ((λ (v) (identity (cons (first (list 1 2)) v)))
                (cons (first (list 2)) v))))
  =>
((λ (v) ((λ (v) (identity (cons (first (list 1 2)) v)))
         (cons (first (list 2)) v)))
 (list 3))
 
; Application becomes function body with arg substituted for var.
((λ (v) ((λ (v) (identity (cons (first (list 1 2)) v)))
         (cons (first (list 2)) v)))
 (list 3))
  =>
((λ (v) (identity (cons (first (list 1 2)) v)))
 (cons (first (list 2)) (list 3)))
 
; (first (list 2)) => 2
((λ (v) (identity (cons (first (list 1 2)) v))) (cons (first (list 2)) (list 3)))
  =>
((λ (v) (identity (cons (first (list 1 2)) v))) (cons 2 (list 3)))
 
; (cons 2 (list 3)) => (list 2 3)
((λ (v) (identity (cons (first (list 1 2)) v))) (cons 2 (list 3)))
  =>
((λ (v) (identity (cons (first (list 1 2)) v))) (list 2 3))
 
; Another function call.
((λ (v) (identity (cons (first (list 1 2)) v))) (list 2 3))
  =>
(identity (cons (first (list 1 2)) (list 2 3)))
 
; (first (list 1 2)) => 1
(identity (cons (first (list 1 2)) (list 2 3)))
  =>
(identity (cons 1 (list 2 3)))

Etc.

So, we succeeded in rewriting append in "accumulator style", but the word "accumulator" is sort of misleading. What we’re really passing around is a function that represents the rest of the computation once we’ve computed our result. Such a function is usually called a continuation, and code that has been rewritten like this is said to be in continuation-passing style, or CPS.

Continuations are typically named k (or something ending in -k), and we’ll distinguish functions that are in CPS by appending /k to their names. So, the final version of append looks like this:

(define (append/k lst1 lst2 k)
  (if (empty? lst1)
      (k lst2)
      (append/k (rest lst1) lst2 (λ (v) (k (cons (first lst1) v))))))
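A short usage sketch, assuming plain #lang racket: the caller picks the initial continuation, so passing identity gives the ordinary result, while passing some other one-argument function (length is used here purely as an illustration) applies that function to the appended list.

```racket
#lang racket

(define (append/k lst1 lst2 k)
  (if (empty? lst1)
      (k lst2)
      (append/k (rest lst1) lst2 (λ (v) (k (cons (first lst1) v))))))

(append/k (list 1 2) (list 3) identity)  ; => (list 1 2 3)
(append/k (list 1 2) (list 3) length)    ; => 3
```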

You might wonder what the point was in performing this transformation. Actually, there are a couple of points. First, transforming into CPS creates code where every function call (aside from primitives like +, cons, first, etc., which we assume to be atomic) is in tail position. This means our code no longer uses the Racket runtime stack at all; we are effectively managing it ourselves by passing around the continuation. The continuation is a representation of the stack as a closure. We’ve also seen previously how to represent closures as data structures, so this implies that, if we tried, we could represent the stack as a data structure.
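As a sketch of what "the stack as a data structure" could look like for append, the continuation below is a plain list of the elements still waiting to be consed on, most recent first. This rewriting is sometimes called defunctionalization. The names append/dk and apply-k are hypothetical, and the code assumes plain #lang racket.

```racket
#lang racket

;; A continuation is represented as a list of pending elements.
;; Applying it conses each pending element onto the value in turn.
(define (apply-k k v)
  (if (empty? k)
      v
      (apply-k (rest k) (cons (first k) v))))

;; Same shape as append/k, but "extending the continuation" is now
;; just consing onto the list that represents it.
(define (append/dk lst1 lst2 k)
  (if (empty? lst1)
      (apply-k k lst2)
      (append/dk (rest lst1) lst2 (cons (first lst1) k))))

(append/dk (list 1 2) (list 3) empty)  ; => (list 1 2 3)
```

Here the empty list plays the role of the identity continuation, and the explicit stack ends up holding the reversal of lst1, much like the accumulator in the rev-append attempt.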

The second point is that making the stack explicit like this gives us some capabilities that we didn’t have before. For example, recall the homework exercise of rewriting contains? with foldl.

(define (contains? lst thing)
  (foldl (λ (elt acc) (or acc (eq? elt thing))) false lst))

This is semantically equivalent to the hand-coded version we wrote in the first lecture, but in terms of performance it can be much worse. In particular, if the thing we’re looking for is near the beginning of the list, then the hand-coded version will complete very quickly, whereas the fold-based version always traverses the entire list. Let’s try rewriting this in CPS and see what happens.

(define (foldl/k fn/k acc lst k)
  (if (empty? lst)
      (k acc)
      (fn/k (first lst) acc
            (λ (new-acc) (foldl/k fn/k new-acc (rest lst) k)))))
 
(define (contains?/k lst thing k)
  (foldl/k (λ (elt acc k) (k (or acc (eq? elt thing)))) false lst k))

This is also equivalent to the code we had before, but in CPS. (Note how, in the recursive call to foldl/k, the continuation is the same as the one originally passed in. This reflects the fact that foldl is tail-recursive to begin with. In general, when a tail call is transformed into CPS, the new call uses the previous continuation.)
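A brief usage sketch of foldl/k with two other folding functions, assuming plain #lang racket. The folding functions take the continuation as a third argument and pass their result to it.

```racket
#lang racket

(define (foldl/k fn/k acc lst k)
  (if (empty? lst)
      (k acc)
      (fn/k (first lst) acc
            (λ (new-acc) (foldl/k fn/k new-acc (rest lst) k)))))

;; Sum a list.
(foldl/k (λ (elt acc k) (k (+ elt acc))) 0 (list 1 2 3) identity)          ; => 6

;; Reverse a list by consing each element onto the accumulator.
(foldl/k (λ (elt acc k) (k (cons elt acc))) empty (list 1 2 3) identity)   ; => (list 3 2 1)
```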

Returning to the contains? example, the key idea is that, if we find the element we’re looking for, there’s no point in having foldl/k continue traversing the rest of the list. Since k is our continuation, it’s the calling of k that causes this traversal to occur. Maybe we can skip calling k to short-circuit all of that execution... Here’s an attempt:

(define (contains?/k lst thing k)
  (foldl/k (λ (elt acc k) (if (or acc (eq? elt thing))
                              true
                              (k false)))
           false lst k))

In this version, when we find the element we’re looking for, we bypass the continuation and simply return true. This works if the call to contains?/k is the only thing our program is doing. However, as written, the code is also bypassing the continuation that was originally passed to contains?/k; this means, effectively, that we’re escaping all the way out past the program’s "main" function and making true the result of the entire program. That can’t be right in general.

What we want is to escape from the rest of the foldl, but only to the continuation for contains?/k. Fortunately, since CPS has made our continuations explicit, this is simple: we just rename that k to something that’s not shadowed by the inner one, and then we can call it from our function.

(define (contains?/k lst thing k0)
  (foldl/k (λ (elt acc k) (if (or acc (eq? elt thing))
                              (k0 true)
                              (k false)))
           false lst k0))

This code does what we want, and it also illustrates one feature that continuations enable: non-local jumps in the program’s control flow, of which exceptions are one example. However, continuations are more general than exceptions, as we’ll soon see. First, though, let’s look at transforming some of the cases in our interpreter into CPS, to help build some intuition for this process.
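One way to observe the early exit is to count how many elements the folding function visits. The counter visited and the helper search-and-count are hypothetical instrumentation added for this sketch, which assumes plain #lang racket.

```racket
#lang racket

(define visited 0)  ; instrumentation only: number of elements examined

(define (foldl/k fn/k acc lst k)
  (if (empty? lst)
      (k acc)
      (fn/k (first lst) acc
            (λ (new-acc) (foldl/k fn/k new-acc (rest lst) k)))))

(define (contains?/k lst thing k0)
  (foldl/k (λ (elt acc k)
             (set! visited (add1 visited))
             (if (or acc (eq? elt thing))
                 (k0 true)     ; escape: skip the rest of the fold
                 (k false)))   ; keep traversing
           false lst k0))

;; Hypothetical helper: returns (list result number-of-elements-visited).
(define (search-and-count lst thing)
  (set! visited 0)
  (let ([result (contains?/k lst thing identity)])
    (list result visited)))

(search-and-count (list 'a 'b 'c 'd) 'b)  ; => (list #t 2)
(search-and-count (list 'a 'b 'c 'd) 'z)  ; => (list #f 4)
```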

The first rule is, anywhere that we’d simply have returned a value, we instead pass that value to the continuation.

(define (interp/k expr env k)
  (type-case Expr expr
    [numE (val) (k (numV val))]
    ...))

It’s tempting to try to do this everywhere, e.g.:

This code is incorrect.

(define (interp/k expr env k)
  (type-case Expr expr
    ...
    [addE (lhs rhs) (k (numV+ (interp/k lhs env ?) (interp/k rhs env ?)))]
    ...))

However, we’d then need to find something to pass to the recursive calls to interp/k. Clearly, k is not the right continuation—that would be like returning the result of each branch to our caller. Instead we could pass the identity function, and we’d get the right answer, but doing so would defeat the purpose of CPS. We’d be back to using the Racket stack again, so the "continuations" for the sub-problems would not reflect the full "rest of the computation". (It would be a little bit like spawning a new thread for each call to interp/k, which, while an interesting option to consider, is not what we’re trying to do.)

The second rule is, whenever we’re doing any non-trivial computation (e.g., a recursive call to interp/k), we need to pass a continuation. If the result of this subcomputation is the same thing that we’d return (i.e., it’s a tail call), then we can pass k directly. Otherwise, we pass a new continuation that will take the result of the sub-computation and finish the work. For the case of addE, the code looks like this:

(define (interp/k expr env k)
  (type-case Expr expr
    ...
    [addE (lhs rhs)
          (interp/k lhs env
                    (λ (lv)
                      (interp/k rhs env
                                (λ (rv) (k (numV+ lv rv))))))]
    ...))

Note the difference here: k is applied to something that’s (essentially) a value, not to an expression that involves any (non-trivial) computation.

One more interesting case is conditional expressions, where transforming into CPS helps to illustrate how the branch selected by the condition is in tail position with respect to the overall conditional.

(define (interp/k expr env k)
  (type-case Expr expr
    ...
    [if0E (test-e then-e else-e)
          (interp/k test-e env
                    (λ (tv)
                      (if (zero? tv)
                          (interp/k then-e env k)
                          (interp/k else-e env k))))]
    ...))
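The three cases above can be assembled into a small runnable sketch. This is not the course interpreter: it assumes plain #lang racket with structs and match in place of define-type and type-case, represents values as bare Racket numbers instead of numV, and drops env because the fragment has no variables.

```racket
#lang racket

;; Hypothetical expression representation for this sketch.
(struct numE (n) #:transparent)
(struct addE (lhs rhs) #:transparent)
(struct if0E (test-e then-e else-e) #:transparent)

(define (interp/k expr k)
  (match expr
    ;; Rule 1: a plain value is passed to the continuation.
    [(numE n) (k n)]
    ;; Rule 2: non-tail sub-computations get a new continuation.
    [(addE lhs rhs)
     (interp/k lhs
               (λ (lv)
                 (interp/k rhs
                           (λ (rv) (k (+ lv rv))))))]
    ;; The selected branch is a tail call, so it reuses k.
    [(if0E test-e then-e else-e)
     (interp/k test-e
               (λ (tv)
                 (if (zero? tv)
                     (interp/k then-e k)
                     (interp/k else-e k))))]))

(interp/k (addE (numE 1) (if0E (numE 0) (numE 2) (numE 3))) identity)  ; => 3
```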

3 CPS as a Monad

You may have noticed that the threading around of continuations feels a bit similar to the way that we passed around the store in store-passing style. Recall how, with store-passing, we factored out the pattern into a type and functions that captured notions of identity and composition for operations of that type. With store-passing style, we defined a store-transformer type:

(define-type-alias (ST 'a) (Store -> (Result 'a)))

For continuation-passing, we have a different type:

(define-type-alias (CPS 'a) (('a -> 'b) -> 'b))

The type (CPS 'a) represents an expression that, prior to transformation, would have had type 'a. After transformation, it instead takes a continuation whose argument is of type 'a. Since the continuation represents the entire rest of the computation, which could return any sort of value depending on the context, we don’t know its type. Hence, we just say that its return type is some generic type 'b. However, the expression needs to call the continuation (or something else that will eventually call it), so its value (and so its type) is the same as that returned by the continuation.

In CPS, we can also easily define the standard unit and bind functions:

(define (cps-unit [v : 'a]) : (CPS 'a)
  (λ (k) (k v)))
 
(define (cps-bind [v/k : (CPS 'a)]
                  [fn/k : ('a -> (CPS 'b))]) : (CPS 'b)
  (λ (k)
    (v/k (λ (v) ((fn/k v) k)))))
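A usage sketch of the two combinators with the type annotations dropped, assuming plain #lang racket. A (CPS 'a) value does nothing until it is handed a final continuation; the name add-3-4 is hypothetical.

```racket
#lang racket

(define (cps-unit v)
  (λ (k) (k v)))

(define (cps-bind v/k fn/k)
  (λ (k)
    (v/k (λ (v) ((fn/k v) k)))))

;; A CPS computation that adds 3 and 4, built only from unit and bind.
(define add-3-4
  (cps-bind (cps-unit 3)
            (λ (lv) (cps-bind (cps-unit 4)
                              (λ (rv) (cps-unit (+ lv rv)))))))

(add-3-4 identity)           ; => 7
(add-3-4 (λ (v) (* v 10)))   ; => 70
```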

Like with store-passing, if we rewrite with these combinators, we can remove the explicit mention of continuations from the code. For example:

(define (interp/k expr env)
  (type-case Expr expr
    [numE (val) (cps-unit (numV val))]
    [addE (lhs rhs)
          (cps-bind
           (interp/k lhs env)
           (λ (lv) (cps-bind
                    (interp/k rhs env)
                    (λ (rv) (cps-unit (numV+ lv rv))))))]
    [if0E (test-e then-e else-e)
          (cps-bind (interp/k test-e env)
                    (λ (tv)
                      (if (zero? tv)
                          (interp/k then-e env)
                          (interp/k else-e env))))]
    ...))

Note how the code looks the same as it did in store-passing style, except for the names of the unit and bind functions. This is not an accident: monads are a general framework for implementing features that require global program rewriting to install custom hooks. Once the program has been written in monadic style, different monad implementations can be substituted easily.

We can also define special operators that only make sense within the CPS monad. The only one we actually need is call-cc, which seems perhaps like it should just be the identity function, since that would result in the current continuation being passed to fn:

This code is incorrect.

(define (call-cc fn)
  (λ (k) (fn k)))

However, the result of (fn k) doesn’t conform to the CPS protocol. In particular, we need the function passed to fn to be something that consumes and ignores its current continuation, instead invoking the captured one. The correct code looks like this:

(define (call-cc fn)
  (λ (k0)
    ((fn (λ (v)
           (λ (k) (k0 v)))) k0)))
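To see the escaping behaviour, here is a sketch that uses call-cc together with unit and bind, assuming plain #lang racket. The names with-escape and without-escape are hypothetical. In the first computation, invoking the captured continuation discards the pending step that would have produced 99.

```racket
#lang racket

(define (cps-unit v) (λ (k) (k v)))

(define (cps-bind v/k fn/k)
  (λ (k) (v/k (λ (v) ((fn/k v) k)))))

(define (call-cc fn)
  (λ (k0)
    ((fn (λ (v)
           (λ (k) (k0 v)))) k0)))

;; (escape 10) ignores its own continuation, so 99 is never produced;
;; 10 flows straight to the code waiting on call-cc, which adds 1.
(define with-escape
  (cps-bind
   (call-cc (λ (escape)
              (cps-bind (escape 10)
                        (λ (ignored) (cps-unit 99)))))
   (λ (v) (cps-unit (+ v 1)))))

;; If the captured continuation is never invoked, call-cc is transparent.
(define without-escape
  (cps-bind
   (call-cc (λ (escape) (cps-unit 5)))
   (λ (v) (cps-unit (+ v 1)))))

(with-escape identity)     ; => 11
(without-escape identity)  ; => 6
```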

4 Continuations as a Language Feature

Transforming a program into CPS involves a great deal of work. Fortunately, Racket provides a way to get continuations implicitly. It provides this via a function, call-with-current-continuation, which (as the name implies) takes a function and calls it with the current continuation (i.e., the k that would have been there if we’d written the program in CPS). The function is generally abbreviated call/cc, since the original name is a lot to type. With call/cc, we can rewrite contains? much more directly:

(define (contains? lst thing)
  (call/cc
   (λ (k)
     (foldl (λ (elt acc) (if (or acc (eq? elt thing))
                             (k true)
                             false))
            false lst))))

Note the difference here: for ordinary control flow, we no longer need to pass continuations around at all. It’s only in cases where we want to perform non-local control flow operations that we need to capture or invoke a continuation.
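The same early-exit pattern applies to other traversals. As a sketch, assuming plain #lang racket, the hypothetical function find-first below uses call/cc to leave a for-each as soon as a matching element turns up.

```racket
#lang racket

;; Hypothetical helper: return the first element satisfying pred,
;; or false if there is none.
(define (find-first pred lst)
  (call/cc
   (λ (k)
     ;; Invoking k abandons the rest of the for-each.
     (for-each (λ (elt) (when (pred elt) (k elt))) lst)
     ;; Reached only if k was never invoked.
     false)))

(find-first negative? (list 3 -1 4 -5))  ; => -1
(find-first negative? (list 1 2))        ; => #f
```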

Here’s a more interesting example:

(define (proc-a)
  (begin
    (display 'a1)
    (switch)
    (display 'a2)
    (switch)
    (display 'a3)
    (switch)))
 
(define (proc-b _)
  (begin
    (display 'b1)
    (switch)
    (display 'b2)
    (switch)
    (display 'b3)
    (switch)))
 
(define other : (boxof (void -> 'a)) (box proc-b))
 
(define (switch)
  (call/cc
   (λ (k)
     (let ([other-k (unbox other)])
       (set-box! other k)
       (other-k (void))))))

What happens if we run (proc-a)? For sure, it begins by displaying 'a1, and then it calls switch, which captures the current continuation, puts it in the other box, and calls the procedure that used to be there (in this case, proc-b). This proceeds to display 'b1 and then call switch again. This stashes away the current continuation and calls the one that proc-a put there before. This "jumps out" of proc-b and back into where we were when we called switch from proc-a, so it next displays 'a2. The two routines continue "ping-ponging" like this, until each has printed its three symbols.
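For experimenting with this example outside the typed language, here is an untyped sketch assuming plain #lang racket. To make the interleaving easy to inspect, it records symbols in a list via a hypothetical emit helper instead of calling display, and a hypothetical run function resets the state before starting proc-a.

```racket
#lang racket

(define trace empty)  ; symbols emitted so far, newest first
(define (emit sym) (set! trace (cons sym trace)))

(define (proc-a)
  (emit 'a1) (switch)
  (emit 'a2) (switch)
  (emit 'a3) (switch))

(define (proc-b _)
  (emit 'b1) (switch)
  (emit 'b2) (switch)
  (emit 'b3) (switch))

(define other (box proc-b))

;; Capture the current continuation, store it in the box, and resume
;; whatever was stored there before.
(define (switch)
  (call/cc
   (λ (k)
     (let ([other-k (unbox other)])
       (set-box! other k)
       (other-k (void))))))

;; Hypothetical driver: reset state, run proc-a, return the trace in order.
(define (run)
  (set! trace empty)
  (set-box! other proc-b)
  (proc-a)
  (reverse trace))

(run)  ; => '(a1 b1 a2 b2 a3 b3)
```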

For another puzzler, try to figure out what happens if you remove the last call to switch from proc-b. (Hint: think about what the stack would have looked like the first time switch was called, and remember that the initial content of other was proc-b, which is an ordinary procedure, not a continuation.)

We also covered in class an example involving generators, but this is described in the textbook, so we won’t repeat it here.