Polymorphic type inference in ML

How does type inference work?

ML infers polymorphic types, just as it infers non-polymorphic types, with very little user input. In fact, ML nearly always infers the "most general" type of a value --- its principal type. How does polymorphic type inference work? Informally, it goes something like this:

  1. ML assigns a fresh type variable to each (sub)expression and (sub)pattern.
  2. Each use of an expression or pattern may imply some property of that entity's type: this is a type constraint. ML builds up a set of type constraints from the smallest expressions/patterns to progressively larger ones.
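  3. ML solves the constraints, with three possible results:
       a. The constraints are overconstrained: no type assignment satisfies them all, so ML reports a type error.
       b. The constraints determine every type variable uniquely, so ML infers a non-polymorphic type.
       c. The constraints are underconstrained: some type variables remain free, so ML infers a polymorphic type in which those variables are universally quantified.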

When doing type inference by hand, we can approximate this algorithm by assigning a fresh type variable to each bound name and each function return value, and solving type constraints until there is nothing more to do.
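As a small illustration of this hand approximation (our own example; the names twice, four, and shout are hypothetical and not from the notes), here is the derivation for a function that applies its argument twice, written as SML comments above the definition. One type variable is still free when the constraints run out, and that free variable is what makes the principal type polymorphic.

```sml
(* Hand inference for twice:
     twice : 'r, f : 'f, x : 'x, body : 'b
     'r = 'f -> 'x -> 'b          (a curried function of f and then x)
     f x          => 'f = 'x -> 'y  (f is applied to x)
     f (f x)      => 'f = 'y -> 'b  (f is applied to the result of f x)
   Equating the two forms of 'f gives 'x = 'y and 'y = 'b, and nothing
   else constrains 'x, so
     twice : ('x -> 'x) -> 'x -> 'x
   which ML prints as ('a -> 'a) -> 'a -> 'a *)
fun twice f x = f (f x)

(* Because that type is principal, twice works at any instance of 'a. *)
val four = twice (fn n => n + 2) 0          (* 'a = int *)
val shout = twice (fn s => s ^ "!") "hi"    (* 'a = string *)
```

Here four is 4 and shout is "hi!!": the same definition is used at int and at string without change.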

Examples

Example 1: Inferring a uniquely determined type

Here's a concrete example, in which the constraints are "tight" enough that all the type variables are uniquely determined:

fun myFun (w, x, y, z) =
     if y then w::tl(x) else y::z;

We begin by assigning a fresh type variable to the first name we encounter, myFun:

myFun : 'a

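Read this as "myFun has type 'a." Next, we assign fresh type variables to the parameters w, x, y, and z and to the function's result, and examine each use of a name in the body to see what constraint it implies.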
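Carrying this out for myFun gives the constraints below. This is our own hand derivation, written as SML comments, and the variables 'b through 'g are our choice of fresh names. The annotated definition records the solution; a compiler would reject it if any annotation disagreed with the constraints.

```sml
(* Constraints for  myFun (w, x, y, z) = if y then w::tl(x) else y::z
     myFun : 'a, w : 'b, x : 'c, y : 'd, z : 'e, result : 'f
     'a = 'b * 'c * 'd * 'e -> 'f   (myFun takes a 4-tuple)
     y is the test of an if        => 'd = bool
     tl(x)                         => 'c = 'g list, and tl(x) : 'g list
     w :: tl(x)                    => 'b = 'g, then-branch : 'g list
     y :: z                        => 'e = 'd list = bool list,
                                      else-branch : bool list
     both branches must agree      => 'g list = bool list, so 'g = bool
     the if is the whole body      => 'f = bool list
   Every type variable is determined, so the type is not polymorphic: *)
fun myFun (w : bool, x : bool list, y : bool, z : bool list) : bool list =
  if y then w :: tl x else y :: z
```

For instance, myFun (true, [false, true], true, [false]) takes the then-branch and evaluates to true :: [true], that is, [true, true].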

Example 2: Inferring a polymorphic type

In this example, which these notes present in slightly less detail, the constraints leave some type variables undetermined, so the result is a polymorphic type. Consider the (curried) function:

fun foo w1 nil    _ = (w1, nil)
  | foo w2 (v::x) y =
    let
        val bar = (w2, y v);
        val bif = v :: tl(x);
    in
        (#1(bar), bif)
    end;

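Type inference proceeds as in Example 1, except that foo has two clauses: each clause yields its own constraints, and the two clauses must agree on foo's type.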
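A hand derivation for foo, sketched as SML comments. The variable names and the layout of the derivation are our own choices, not the notes' original presentation. The definition is foo as given above, with the optional semicolons and parentheses dropped.

```sml
(* Constraints for foo:
     clause 1: w1 : 'a, nil : 'b list, _ : 'c, result : 'a * 'd list
     clause 2: w2 : 'a, (v::x) : 'b list, so v : 'b and x : 'b list
               y v          => y : 'b -> 'e, so bar : 'a * 'e
               v :: tl(x)   => bif : 'b list
               (#1(bar), bif) : 'a * 'b list
     Both clauses must have the same type, so 'c = 'b -> 'e and 'd = 'b.
   Nothing constrains 'a, 'b, or 'e any further, so they stay polymorphic:
     foo : 'a -> 'b list -> ('b -> 'e) -> 'a * 'b list
   (SML/NJ would print this with 'e renamed to 'c.) *)
fun foo w1 nil _ = (w1, nil)
  | foo w2 (v :: x) y =
    let
        val bar = (w2, y v)
        val bif = v :: tl x
    in
        (#1 bar, bif)
    end
```

For example, foo "w" [1, 2, 3] (fn n => n * 10) evaluates to ("w", [1, 3]): bar is ("w", 10), and bif is 1 :: tl [2, 3]. Note that the third argument's result type 'e never reaches foo's result, which is why it remains free.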

Unification

In previous lectures, I have occasionally referred informally to "unifying types". My hope was that you would use your intuitive sense of the English definition of this word until the time came to define it more precisely, which is now. Consider the expression:

fun ... =
    let
        ...
        val (x, y) = (3, z);
    in ...

where z is defined earlier in the function. Suppose that we had assigned the fresh type variable 'a to z, and the fresh type variables 'b and 'c to x and y. Our first step in inferring the types of x and y would be to set up a type equation like the following:

('b * 'c) = (int * 'a)
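To see where this equation comes from in running code, here is a complete enclosing function. It is hypothetical: demo is our name, and the notes elide the real context. Its body contains exactly the binding above, so unifying 'b * 'c with int * 'a fixes x at int and lets y share z's type.

```sml
(* z : 'a, and the pattern (x, y) : 'b * 'c is matched against (3, z) : int * 'a.
   Unification gives 'b = int and 'c = 'a. *)
fun demo z =
  let
    val (x, y) = (3, z)
  in
    (x + 1, y)   (* x : int, so x + 1 is fine; y keeps z's type *)
  end
```

We would expect the REPL to report demo : 'a -> int * 'a, so that, for example, demo "s" evaluates to (4, "s") and demo true to (4, true).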

In order to proceed further, we must attempt to unify the left- and right-hand sides of this equation. Although the result seems obvious at a glance, it is useful to consider more precisely the exact procedure for unifying two types. To wit, here is a function in ML-like pseudocode that determines whether two types can be unified:

fun unify(t1, t2) =
   if t1 and t2 are base types then
      if t1 = t2 then true else false
   else if t1 is a type variable
           or t2 is a type variable then
      true (and record the constraint t1 = t2)
   else if t1 and t2 are compound types then
      if t1 and t2 have different constructors then
          false
      else
          if each corresponding element of t1 and t2 can be unified then
             true
          else
             false
   else
      false (one is a base type and the other is a compound type)

This is a simplification of real unification, which returns not only true or false, but a set of type equalities resulting from successful unifications. Nevertheless, it gives us the intuition for how to unify the tuple example above:

  1. First, observe that the types are the same constructor (2-tuple). More precisely, the constructors have the same kind (tuple) and arity (2); for record constructors, we would also check that they have the same fields.
  2. Second, recursively attempt to unify the corresponding positions in each tuple: unify('b, int) succeeds, yielding the equality 'b = int, and unify('c, 'a) succeeds, yielding 'c = 'a.
  3. Third, because all recursive unifications succeed, return true (and keep all type equalities accumulated while performing recursive unification). The unified type is int * 'a.
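The boolean pseudocode can be turned into something closer to real unification. The sketch below is our own, not the notes' implementation: the datatype ty, its constructors TVar, TBase, and TCon, and the functions apply, occurs, unify, unifyAll, and tuple are all hypothetical. Instead of true or false it returns SOME of a substitution (the accumulated type equalities) or NONE, and it adds an occurs check, which the simplified pseudocode omits, so that an equation such as 'a = 'a list is rejected.

```sml
(* A hypothetical representation of types. *)
datatype ty =
    TVar of string             (* type variable: TVar "a" stands for 'a *)
  | TBase of string            (* base type: TBase "int" *)
  | TCon of string * ty list   (* constructor and arguments: TCon ("list", [t]),
                                  or TCon ("*", [t1, t2]) for a 2-tuple *)

(* A substitution maps type-variable names to types. *)
type subst = (string * ty) list

(* Apply a substitution, following chains such as 'a = 'b list, 'b = int. *)
fun apply (s : subst) (t : ty) : ty =
  case t of
      TVar v =>
        (case List.find (fn (v', _) => v' = v) s of
             SOME (_, t') => apply s t'
           | NONE => t)
    | TBase _ => t
    | TCon (c, args) => TCon (c, map (apply s) args)

(* Does variable v occur in t? *)
fun occurs v (TVar v') = v = v'
  | occurs _ (TBase _) = false
  | occurs v (TCon (_, args)) = List.exists (occurs v) args

(* Extend s so that t1 and t2 become equal, or return NONE. *)
fun unify (t1, t2, s : subst) : subst option =
  case (apply s t1, apply s t2) of
      (TVar a, TVar b) => if a = b then SOME s else SOME ((a, TVar b) :: s)
    | (TVar a, t) => if occurs a t then NONE else SOME ((a, t) :: s)
    | (t, TVar a) => if occurs a t then NONE else SOME ((a, t) :: s)
    | (TBase a, TBase b) => if a = b then SOME s else NONE
    | (TCon (c1, args1), TCon (c2, args2)) =>
        (* same constructor and same arity, then unify the components *)
        if c1 = c2 andalso length args1 = length args2
        then unifyAll (ListPair.zip (args1, args2), s)
        else NONE
    | _ => NONE   (* a base type never unifies with a compound type *)

(* Unify each pair in turn, threading the substitution through. *)
and unifyAll ([], s) = SOME s
  | unifyAll ((t1, t2) :: rest, s) =
      (case unify (t1, t2, s) of
           SOME s' => unifyAll (rest, s')
         | NONE => NONE)

(* The equation ('b * 'c) = (int * 'a) from the text. *)
fun tuple ts = TCon ("*", ts)
val lhs = tuple [TVar "b", TVar "c"]
val rhs = tuple [TBase "int", TVar "a"]
val result = unify (lhs, rhs, [])
(* result = SOME [("c", TVar "a"), ("b", TBase "int")] *)
```

Applying the resulting substitution to lhs gives tuple [TBase "int", TVar "a"], the int * 'a found by hand. Run the same way on the three-tuple equation below, unify yields a substitution under which both sides become ('v * 'w) * ('x * 'y) * 'y list, which is the type derived below once 'b = 'x and 'c = 'y are substituted.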

Here is a slightly more interesting type equation:

'a * ('b * 'c) * 'c list = ('v * 'w) * ('x * 'y) * 'z

We begin with the top-level type constructors, which are 3-tuples, and produce the following unifications:

unify('a, ('v * 'w))
unify(('b * 'c), ('x * 'y))
unify('c list, 'z)

For the first and third of these, at least one side is a type variable, so we bottom out; but for the second, both sides are 2-tuples, so we can proceed and recursively unify their components:

unify('b, 'x)
unify('c, 'y)

We bottom out and return. All told, we have the following type equations:

'a = 'v * 'w
'b = 'x
'c = 'y
'c list = 'z

Solving all these equations, the unified type is:

('v * 'w) * ('b * 'c) * 'c list

If this were the type inferred for some expression, the SML/NJ read-eval-print loop would print it with the type variables renamed:

('a * 'b) * ('c * 'd) * 'd list

Pattern matching vs. unification

Aside: You should see some similarity between type unification and pattern-matching, which also recursively matches two trees. Pattern-matching unifies patterns with values, whereas type unification unifies types with types, but the recursive descent through the trees based on constructors is quite similar. In fact, pattern-matching is a restricted form of unification, in which only one side (the pattern) contains variables.

This similarity may be clearer if you do the following exercise: first, evaluate the following pattern-match by hand, using pen and paper to show the correspondence between pattern trees and value trees...

val (a, b, {c=d, e=f}) = ("u", ("v", "w"), {c="x", e=("y", "z")});

...and then draw the process for performing the following type unification, showing correspondences between types, in the same fashion:

'a * 'b * {c:'d, e:'f} = 'u * ('v * 'w) * {c:'x, e:('y * 'z)}
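Once you have drawn both trees by hand, you can check the value side in SML/NJ. The binding below is the exercise's pattern match verbatim; the comments are ours and record each variable's value and type. The type column lines up with the unification 'a = 'u, 'b = 'v * 'w, 'd = 'x, 'f = 'y * 'z.

```sml
val (a, b, {c = d, e = f}) = ("u", ("v", "w"), {c = "x", e = ("y", "z")})
(* a = "u"           a : string
   b = ("v", "w")    b : string * string
   d = "x"           d : string
   f = ("y", "z")    f : string * string *)
```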

Equality types

The regular 'a syntax means "any type" --- as noted in previous lectures, we say that these variables are universally quantified. However, sometimes you want to be able to write a function over only those types that satisfy some property --- perhaps all those types that define an equality operator =, or a less-than operator <. This is called bounded polymorphism, as opposed to universal polymorphism, because there's a "bound" (limit) on those types at which the type variable can be instantiated.

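Unfortunately, ML does not provide any form of general bounded polymorphism --- you cannot write a type that says, "This function applies to all types that have a prettyPrint function defined." However, ML does hard-code one form of bounded polymorphism, using equality types. An equality type is either:

  1. a type built from types that support =: base types such as int, bool, char, and string, together with tuples, records, lists, and other datatypes whose components are all equality types (ref types also qualify, whatever they contain; function types and real never do); or
  2. an equality type variable.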

An equality type variable is written with two quotes, e.g. ''a. For example, consider the type of the following function, which searches for a value in a list:

- fun search (aValue, nil) = false
  | search (aValue, x::xs) =
    if aValue = x then true else search (aValue, xs);
val search = fn : ''a * ''a list -> bool

This function takes an argument of type ''a * ''a list: a value of some equality type, paired with a list of that same type. Applying it at a non-equality type, such as real, fails with a type error:

- search (1.0, [1.0, 2.0, 3.0]);
stdIn:1.1-20.10 Error: operator and operand don't agree
    [equality type required]
  operator domain: ''Z * ''Z list
  operand:         real * real list
  in expression:
    search (1.0,1.0 :: 2.0 :: <exp> :: <exp>)
- search ((1.0, 2.0), [(1.0, 2.0), (3.0, 4.0)]);
stdIn:1.1-20.26 Error: operator and operand don't agree
    [equality type required]
  operator domain: ''Z * ''Z list
  operand:         (real * real) * (real * real) list
  in expression:
    search ((1.0,2.0),(1.0,2.0) :: (<exp>,<exp>) :: nil)
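By contrast, search type-checks at any type that does admit equality. The sketch below restates search with orelse in place of if ... then true else (the behaviour is the same) and instantiates ''a at int, string, and a tuple type. The last part is our own illustration rather than part of the notes: a hypothetical searchBy that takes the comparison as an argument, so its type needs no equality type variable and it can be used on reals through the Basis function Real.==.

```sml
fun search (aValue, nil) = false
  | search (aValue, x :: xs) = (aValue = x) orelse search (aValue, xs)

val foundInt = search (3, [1, 2, 3])                      (* ''a = int *)
val foundString = search ("z", ["x", "y"])                (* ''a = string *)
val foundPair = search ((1, "b"), [(1, "a"), (1, "b")])  (* ''a = int * string *)

(* Hypothetical workaround: pass the comparison explicitly.
   searchBy : ('a * 'b -> bool) -> 'a * 'b list -> bool *)
fun searchBy eq (aValue, nil) = false
  | searchBy eq (aValue, x :: xs) =
      eq (aValue, x) orelse searchBy eq (aValue, xs)

val foundReal = searchBy Real.== (1.0, [1.0, 2.0, 3.0])
```

Here foundInt, foundPair, and foundReal are true and foundString is false. The price of searchBy is that every caller must supply the comparison, which is exactly the bookkeeping that a general bounded polymorphism would let the type system do instead.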

Supplemental exercises

Manually infer the types in the following declarations. If a declaration implies type constraints that are overconstrained (that no type assignment can satisfy), explain why.