ML infers polymorphic types, just as it infers non-polymorphic types, with very little user input. In fact, ML nearly always infers the "most general" type of a value: its principal type. How does polymorphic type inference work? Informally, the type checker looks at how each name is used, collects the type constraints that this usage implies, and solves them. Anything left unconstrained remains a type variable. Overloading adds one wrinkle. An overloaded operator such as + defaults to int when nothing else fixes its type, so

fn (x, y) => x + y

is assumed to have type (int * int) -> int.

When doing type inference by hand, we can approximate this algorithm by assigning a fresh type variable to each bound name and each function return value, and solving type constraints until there is nothing more to do.
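The overloading default can be checked directly in SML. In this small sketch, the names add and addReal are illustrative and do not come from the notes. With nothing else to fix its type, + resolves to int. A single annotation on one argument is enough to resolve it at real instead.

```sml
(* No other type information: the overloaded + defaults to int,
   so add : int * int -> int. *)
val add = fn (x, y) => x + y

(* Annotating one argument resolves + at real instead:
   addReal : real * real -> real. *)
val addReal = fn (x : real, y) => x + y

val five = add (2, 3)            (* 5 *)
val total = addReal (1.5, 2.0)   (* 3.5 *)
```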
Here's a concrete example, in which the constraints are "tight" enough that all the type variables are uniquely determined:
fun myFun (w, x, y, z) = if y then w::tl(x) else y::z;
We begin by assigning a fresh type variable to the first name we encounter, myFun:

    myFun : 'a

Read this as "myFun has type 'a."

Then, we make the following observations:

1. myFun is defined as a function; therefore, add the type constraint:

       myFun : 'a = 'b -> 'c

   Read this as "myFun has type 'a, which equals 'b -> 'c." Note that we make fresh type variables for the argument and result types, because we do not know what they are.

2. myFun's argument pattern is a tuple of four elements, whose patterns are bound names. Assign these names fresh type variables, and add a type equality for the argument type:

       'b = ('d * 'e * 'f * 'g)
       w : 'd
       x : 'e
       y : 'f
       z : 'g

3. y is used as the condition in an if expression; therefore, add the type constraint:

       y : 'f = bool

4. We apply tl to x; therefore, x must be a list. Add the following type constraints:

       x : 'e = 'h list
       tl(x) : 'h list
       w::tl(x) : 'h list

5. We cons w onto the tail of x. Cons takes a head element whose type matches the elements of its tail list, so we conclude that:

       w : 'd = 'h

6. y is consed onto z, so we conclude that z is a list:

       z : 'g = 'i list

7. Since y is consed onto z, y must have the same type as the elements of z. This also allows us to conclude a non-polymorphic type for z, as well as for the entire cons expression:

       y : 'f = bool = 'i
       z : 'i list = bool list
       y::z : bool list

8. Both branches of the if expression must have the same type. This allows us to propagate and solve further type equalities:

       w::tl(x) : 'h list = bool list
       tl(x) : 'h list = bool list
       x : 'e = 'h list = bool list
       w : 'd = bool

9. Finally, the result type 'c is the type of the if expression, yielding:

       (if y then w::tl(x) else y::z) : bool list

   Substituting everything back into 'a = 'b -> 'c gives the type of the whole function:

       myFun : bool * bool list * bool * bool list -> bool list
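The following SML sketch checks the hand-derived type against the compiler. The annotated `val _` line type-checks only if myFun really has the type derived above. The two sample calls, which are illustrative, exercise each branch of the if expression.

```sml
(* The function from the worked example. *)
fun myFun (w, x, y, z) = if y then w :: tl x else y :: z

(* Compile-time check: fails to type-check if the inferred type differs. *)
val _ : bool * bool list * bool * bool list -> bool list = myFun

(* y = true: w is consed onto the tail of x. *)
val r1 = myFun (false, [true, true], true, [false])   (* [false, true] *)

(* y = false: y is consed onto z. *)
val r2 = myFun (true, [true], false, [true, true])    (* [false, true, true] *)
```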
In the next example, which these notes present in slightly less detail, some of the types remain unconstrained, so the result is a polymorphic type. Consider the (curried) function:
fun foo w1 nil _ = (w1, nil)
  | foo w2 (v::x) y =
      let
        val bar = (w2, y v);
        val bif = v :: tl(x);
      in
        (#1(bar), bif)
      end;
Type inference proceeds as follows:
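1. Begin with foo, and observe that foo is a curried function. Assign fresh type variables to its three curried arguments and return value:

       foo : 'a -> 'b -> 'c -> 'd

2. In the first clause, w1 is bound to the first argument. The second argument is matched against nil, so it must be some kind of list:

       w1 : 'a
       'b = 'e list

3. The first clause returns a tuple whose first element is w1 and whose second element is some kind of list; therefore:

       'd = 'a * 'f list

4. In the second clause, the arguments are bound as follows. y is the third argument, so its type 'h equals 'c:

       w2 : 'a
       (v::x) : 'b = 'g list
       v : 'g
       x : 'g list
       y : 'h = 'c

5. From the first clause we already know that 'b = 'e list, so we further derive:

       'g list = 'b = 'e list

6. In the bar binding, y is used as a function and applied to v, so:

       y : 'h = 'g -> 'i
       bar : 'a * 'i

7. In the bif binding, we obtain:

       v : 'g
       tl(x) : 'g list
       bif : 'j = 'k list = 'g list

8. The body of the let, (#1(bar), bif), therefore has type 'a * 'g list. This expression determines the return type of foo, so we get:

       'd = 'a * 'g list = 'a * 'f list

9. Substituting the solved type variables back into foo's type gives:

       foo : 'a -> 'g list -> ('g -> 'i) -> 'a * 'g list

   After renaming the type variables, this is:

       foo : 'a -> 'b list -> ('b -> 'c) -> 'a * 'b list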
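The following SML sketch confirms foo's principal type, 'a -> 'b list -> ('b -> 'c) -> 'a * 'b list, and shows foo being used at two different instantiations of its type variables. The sample arguments are illustrative. Note that tl x raises Empty when the second argument has exactly one element, so the calls below avoid that case.

```sml
fun foo w1 nil _ = (w1, nil)
  | foo w2 (v :: x) y =
      let
        val bar = (w2, y v)       (* y is applied to v, so y : 'b -> 'c *)
        val bif = v :: tl x       (* raises Empty if x is nil *)
      in
        (#1 bar, bif)
      end

(* Compile-time check of the principal type derived by hand. *)
val _ : 'a -> 'b list -> ('b -> 'c) -> 'a * 'b list = foo

(* 'a = string, 'b = int, 'c = string *)
val p1 = foo "tag" [1, 2, 3] Int.toString    (* ("tag", [1, 3]) *)

(* 'a = bool, 'b = string, 'c = int; the first clause applies *)
val p2 = foo true ([] : string list) size    (* (true, []) *)
```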
In previous lectures, I have occasionally referred informally to "unifying types". My hope was that you would use your intuitive sense of the English definition of this word until the time came to define it more precisely, which is now. Consider the expression:
fun ... = let ... val (x, y) = (3, z); in ...
where z is defined earlier in the function. Suppose that we had assigned the fresh type variable 'a to z, and the fresh type variables 'b and 'c to x and y. Our first step in inferring the types of x and y would be to set up a type equation like the following:
('b * 'c) = (int * 'a)
In order to proceed further, we must attempt to unify the left- and right-hand sides of this equation. The result may seem obvious at a glance, but it is useful to look more precisely at the procedure for unifying two types. To wit, here is a function in ML-like pseudocode that determines whether two types can be unified:
fun unify(t1, t2) =
  if t1 and t2 are base types then
    if t1 = t2 then true else false
  else if t1 is a type variable or t2 is a type variable then
    true  (* adding the constraint t1 = t2 *)
  else if t1 and t2 are compound types then
    if t1 and t2 have different constructors then false
    else if each corresponding component of t1 and t2 can be unified then true
    else false
  else
    false  (* a base type against a compound type *)
This is a simplification of real unification, which returns not only true or false, but a set of type equalities resulting from successful unifications. Nevertheless, it gives us the intuition for how to unify the tuple example above:
1. unify('b, int): returns true, with type constraint 'b = int.
2. unify('c, 'a): returns true, with type constraint 'c = 'a.
3. Both components unify, so the two tuple types unify, and the unified type is int * 'a.
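SML's own type checker performs exactly this unification. In the following sketch, the function name demo is hypothetical. z is otherwise unconstrained, so unifying the pattern (x, y) with (3, z) should give x : int and y : 'a, making the result type int * 'a. The annotated `val _` line type-checks only if that is indeed the inferred type.

```sml
(* z gets a type variable 'a; unifying ('b * 'c) with (int * 'a)
   gives x : int and y : 'a. *)
fun demo z =
  let
    val (x, y) = (3, z)
  in
    (x, y)
  end

(* Compile-time check that unification produced int * 'a. *)
val _ : 'a -> int * 'a = demo

val d = demo "hi"   (* (3, "hi") *)
```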
Here is a slightly more interesting type equation:
'a * ('b * 'c) * 'c list = ('v * 'w) * ('x * 'y) * 'z
We begin with the top-level type constructors, which are 3-tuples, and produce the following unifications:
unify('a, ('v * 'w))
unify(('b * 'c), ('x * 'y))
unify('c list, 'z)
For the first and third of these, at least one of the two types is a type variable, so we bottom out. For the second, both sides are pairs, so we can proceed and recursively unify:
unify('b, 'x)
unify('c, 'y)
We bottom out and return. All told, we have the following type equations:
'a = 'v * 'w
'b = 'x
'c = 'y
'c list = 'z
Solving all these equations, the unified type is:
('v * 'w) * ('b * 'c) * 'c list
If this were the type inferred for some expression, the SML/NJ read-eval-print loop would rename its type variables and print:
('a * 'b) * ('c * 'd) * 'd list
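The pseudocode unify above returns only true or false. The following SML sketch is a somewhat more realistic (though still simplified) version that returns the substitution, that is, the set of type equalities produced by a successful unification. The ty datatype and the names apply, occurs, bind, unify, tuple, listOf, and the Unify exception are hypothetical, chosen for illustration. The sketch also adds the occurs check, which the pseudocode omits. This standard part of real unification refuses to solve an equation such as 'a = 'a list.

```sml
(* Hypothetical representation of ML types for this sketch. *)
datatype ty =
    Var of string              (* a type variable; Var a stands for 'a *)
  | Base of string             (* a base type such as int *)
  | Con of string * ty list    (* a constructor applied to arguments, e.g. pairs, lists *)

exception Unify of string

(* A substitution is a list of (variable, type) bindings. apply replaces
   every bound variable in a type, following chains of bindings. *)
fun apply s (Var a) =
      (case List.find (fn (b, _) => a = b) s of
           SOME (_, t) => apply s t
         | NONE => Var a)
  | apply _ (Base b) = Base b
  | apply s (Con (c, ts)) = Con (c, map (apply s) ts)

(* Occurs check: does variable a appear anywhere inside t? *)
fun occurs a (Var b) = a = b
  | occurs _ (Base _) = false
  | occurs a (Con (_, ts)) = List.exists (occurs a) ts

(* Bind an unbound variable, rejecting cyclic types like 'a = 'a list. *)
fun bind (a, t, s) =
  if occurs a t then raise Unify ("occurs check failed for '" ^ a)
  else (a, t) :: s

(* Extend substitution s so that t1 and t2 become equal, or raise Unify. *)
fun unify (t1, t2, s) =
  case (apply s t1, apply s t2) of
      (Var a, Var b) => if a = b then s else (a, Var b) :: s
    | (Var a, t) => bind (a, t, s)
    | (t, Var a) => bind (a, t, s)
    | (Base a, Base b) =>
        if a = b then s else raise Unify (a ^ " vs. " ^ b)
    | (Con (c1, ts1), Con (c2, ts2)) =>
        if c1 = c2 andalso length ts1 = length ts2
        then ListPair.foldl (fn (x, y, acc) => unify (x, y, acc)) s (ts1, ts2)
        else raise Unify (c1 ^ " vs. " ^ c2)
    | _ => raise Unify "base type vs. compound type"

fun tuple ts = Con ("*", ts)
fun listOf t = Con ("list", [t])

(* ('b * 'c) = (int * 'a); the result is int * 'a *)
val s1 = unify (tuple [Var "b", Var "c"], tuple [Base "int", Var "a"], [])
val r1 = apply s1 (tuple [Var "b", Var "c"])

(* 'a * ('b * 'c) * 'c list = ('v * 'w) * ('x * 'y) * 'z *)
val lhs = tuple [Var "a", tuple [Var "b", Var "c"], listOf (Var "c")]
val rhs = tuple [tuple [Var "v", Var "w"], tuple [Var "x", Var "y"], Var "z"]
val r2 = apply (unify (lhs, rhs, [])) lhs
```

For the second equation, the sketch produces ('v * 'w) * ('x * 'y) * 'y list. This matches the hand-derived answer up to renaming, because it keeps 'x and 'y where the notes keep 'b and 'c.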
Aside: You should see some similarity between type unification and pattern-matching, which also recursively matches two trees. Pattern-matching matches patterns against values, whereas type unification unifies types with types, but the recursive descent through the trees based on constructors is quite similar. In fact, pattern-matching can be viewed as a restricted, one-way form of unification, in which only the pattern side contains variables.
This similarity may be clearer if you do the following exercise: first, evaluate the following pattern-match by hand, using pen and paper to show the correspondence between pattern trees and value trees...
val (a, b, {c=d, e=f}) = ("u", ("v", "w"), {c="x", e=("y", "z")});
...and then draw the process for performing the following type unification, showing correspondences between types, in the same fashion:
'a * 'b * {c:'d, e:'f} = 'u * ('v * 'w) * {c:'x, e:('y * 'z)}
The regular 'a syntax means "any type." As noted in previous lectures, we say that these variables are universally quantified. However, sometimes you want to write a function over only those types that satisfy some property, perhaps all those types that define an equality operator =, or a less-than operator <. This is called bounded polymorphism, as opposed to universal polymorphism, because there is a "bound" (limit) on the types at which the type variable can be instantiated.
Unfortunately, ML does not provide any form of general bounded polymorphism. You cannot write a type that says, "This function applies to all types that have a prettyPrint function defined." However, ML does hard-code one form of bounded polymorphism, using equality types. An equality type is either:

- a base type (such as int or string) that supports the equality operator =, or
- a compound type built from equality types, such as int * string, bool list, or (int * bool) list.

Note that real and function types are not equality types.

An equality type variable is written with two quotes, e.g. ''a. For example, consider the type of the following function, which searches for a value in a list:
- fun search (aValue, nil) = false
    | search (aValue, x::xs) =
        if aValue = x then true else search (aValue, xs);
val search = fn : ''a * ''a list -> bool
This function takes an argument of type ''a * ''a list: a value of some equality type, and a list of elements of that same equality type. Attempting to apply it to a non-equality type such as real will fail with a type error:
- search (1.0, [1.0, 2.0, 3.0]);
stdIn:1.1-20.10 Error: operator and operand don't agree [equality type required]
  operator domain: ''Z * ''Z list
  operand:         real * real list
  in expression:
    search (1.0,1.0 :: 2.0 :: <exp> :: <exp>)
- search ((1.0, 2.0), [(1.0, 2.0), (3.0, 4.0)]);
stdIn:1.1-20.26 Error: operator and operand don't agree [equality type required]
  operator domain: ''Z * ''Z list
  operand:         (real * real) * (real * real) list
  in expression:
    search ((1.0,2.0),(1.0,2.0) :: (<exp>,<exp>) :: nil)
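The following SML sketch shows search succeeding on compound equality types. It also shows the usual substitute for the bounded polymorphism that ML lacks: passing the required operation in as an explicit argument. The name searchBy is hypothetical. Real.== is the Basis Library's equality test on reals, so searchBy can search a real list even though real is not an equality type.

```sml
fun search (aValue, nil) = false
  | search (aValue, x :: xs) =
      if aValue = x then true else search (aValue, xs)

(* Tuples and lists built from equality types are equality types. *)
val found1 = search ((1, "b"), [(0, "a"), (1, "b")])   (* true *)
val found2 = search ([true], [[false], []])            (* false *)

(* Hypothetical workaround: take the comparison as a parameter, so the
   element type needs no equality bound; searchBy works for any 'a, 'b. *)
fun searchBy eq (aValue, nil) = false
  | searchBy eq (aValue, x :: xs) =
      if eq (aValue, x) then true else searchBy eq (aValue, xs)

val found3 = searchBy Real.== (2.0, [1.0, 2.0, 3.0])    (* true *)
```

This is the same "pass the operation explicitly" idea one would use for the prettyPrint example. The caller supplies the function that a bounded type variable would otherwise have guaranteed.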
Manually infer the types in the following declarations. If a declaration implies type constraints that are overconstrained (cannot be satisfied by any type assignment), then explain why.
fun f a b = (a ^ ".", b);
fun g (c, d) = (f c c)::d;
fun m x nil = nil | m x (y::ys) = (x y)::(m x ys);
fun squid(a, b, c) =
  let
    val d = [2.0,3.0];
    val e:(int * (int * string)) = (3, c);
  in
    if a > hd(d) then [hd(b), #2(c)] else tl(b)
  end;