(* CSE 341, Programming Languages *)
(* Lecture 11: Type Inference *)
(*** a couple examples "how a human would do it" before we walk through
"how the type-checker does it"
***)
(* Inferred type: val f : int -> int.
   x is compared against the int literal 3, so x : int; both branches of the
   conditional are ints, so the result is int. *)
fun f x =
    if x <= 3
    then 2 * x
    else 42
(*
fun g x = (* report type error *)
if x > 3
then true
else x * 2
*)
(* val x : int -- a top-level binding that the function below refers to. *)
val x = 42

(* y is the test of a conditional, so y : bool.
   z is added to x (an int), so z : int.
   w is never used, so nothing constrains its type.
   Both branches are ints, so the result is int. *)
fun f (y, z, w) =
    if not y
    then 0
    else x + z
(* f must return an int
f must take a bool * int * ANYTHING
so val f : bool * int * 'a -> int
*)
(*
f : T1 -> T2 [must be a function; all functions take one argument]
x : T1 [must have type of f's argument]
y : T3
z : T4
T1 = T3 * T4 [else pattern-match in val-binding doesn't type-check]
T3 = int [because (abs y) where abs : int -> int]
T4 = int [because add z to an int]
So T1 = int * int
So (abs y) + z : int, so let-expression : int, so body : int, so T2=int
So f : int * int -> int
*)
(* val f : int * int -> int
   Takes one argument x, splits it into (y,z) with a val-binding pattern,
   and returns the absolute value of y plus z. *)
fun f x =
let val (y,z) = x in (* the pattern forces x to be a pair *)
(abs y) + z
end
(*
sum : T1 -> T2 [must be a function; all functions take one argument]
xs : T1 [must have type of sum's argument]
x : T3 [pattern match against T3 list]
xs' : T3 list [pattern match against T3 list]
T1 = T3 list [else pattern-match on xs doesn't type-check]
0 : int, so case-expression : int, so body : int, so T2=int
T3 = int [because x : T3 and is argument to addition]
T2 = int [because result of recursive call is argument to addition]
sum xs' type-checks because xs' has type T3 list and T1 = T3 list
case-expression type-checks because both branches have type int
from T1 = T3 list and T3 = int, we know T1 = int list
from that and T2 = int, we know sum : int list -> int
*)
(* val sum : int list -> int
   Adds up the elements of xs; the empty list sums to 0. *)
fun sum xs =
case xs of
[] => 0 (* base case: nothing to add *)
| x::xs' => x + (sum xs') (* head plus the sum of the tail *)
(*
type inference proceeds exactly like for sum for most of it:
broken_sum : T1 -> T2 [must be a function; all functions take one argument]
xs : T1 [must have type of broken_sum's argument]
x : T3 [pattern match against T3 list]
xs' : T3 list [pattern match against T3 list]
T1 = T3 list [else pattern-match on xs doesn't type-check]
0 : int, so case-expression : int, so body : int, so T2=int
T3 = int [because x : T3 and is argument to addition]
T2 = int [because result of recursive call is argument to addition]
but now to type-check (broken_sum x) we need T3 = T1 and T1 = T3 list,
so we need T3 = T3 list, which is impossible for any T3.
Note: The actual type-checker might gather facts in a different order and
therefore report a different error, but it will report an error.
*)
(*fun broken_sum xs =
case xs of
[] => 0
| x::xs' => x + (broken_sum x)
*)
(*
First several steps are just like with sum:
length : T1 -> T2 [must be a function; all functions take one argument]
xs : T1 [must have type of length's argument]
x : T3 [pattern match against T3 list]
xs' : T3 list [pattern match against T3 list]
T1 = T3 list [else pattern-match on xs doesn't type-check]
0 : int, so case-expression : int, so body : int, so T2=int
recursive call type-checks because xs' has type T3 list, which = T1
and T2=int, so fine argument to addition
so with all our constraints, length : T3 list -> int
so 'a list -> int
*)
(* val length : 'a list -> int
   Counts the elements of xs. The head x is bound but never used, so nothing
   constrains the element type and the function stays polymorphic. *)
fun length xs =
case xs of
[] => 0 (* the empty list has length 0 *)
| x::xs' => 1 + (length xs') (* one for the head plus the length of the tail *)
(*
f : T1 * T2 * T3 -> T4
x : T1
y : T2
z : T3
both conditional branches must have type T4 (the type of the function body),
so T1 * T2 * T3 = T4 and T2 * T1 * T3 = T4, which means T1 = T2
putting it all together, f : T1 * T1 * T3 -> T1 * T1 * T3
now replace unconstrained types /consistently/ with type variables:
'a * 'a * 'b -> 'a * 'a * 'b
*)
(* val f : 'a * 'a * 'b -> 'a * 'a * 'b
   The test is the constant true, so the then-branch is always the one taken.
   The else-branch is still type-checked, and because it swaps x and y, the
   two must share a type. *)
fun f (x,y,z) =
if true
then (x,y,z)
else (y,x,z) (* never evaluated, but still forces x and y to have the same type *)
(*
compose : T1 * T2 -> T3
f : T1
g : T2
x : T4
from body of compose being a function, T3 = T4->T5 for some T4 and T5
from g being passed x, T2 = T4->T6 for some T6
from f being passed result of g, T1 = T6->T7 for some T7
from f's result being the body of the anonymous function, T7=T5
putting it all together:
T1=T6->T5, T2=T4->T6, and T3=T4->T5
so compose: (T6->T5) * (T4->T6) -> (T4->T5)
now replace unconstrained types /consistently/ with type variables:
('a -> 'b) * ('c -> 'a) -> ('c -> 'b)
*)
(* val compose : ('a -> 'b) * ('c -> 'a) -> 'c -> 'b
   Takes a pair of functions and returns a new function that applies g to its
   argument and then applies f to g's result. *)
fun compose (f,g) = fn x => f (g x)
(**** the value restriction (important, but optional material) ****)
(* this first line is not polymorphic so next two lines do not type-check *)
(* ref NONE is an application of ref rather than a syntactic value, so the
   value restriction prevents r from being generalized to 'a option ref.
   If it were generalized, r could be assigned a string option and later be
   read back as an int option. *)
val r = ref NONE
(*
val _ = r := SOME "hi"
val i = 1 + valOf (!r)
*)
(* foo is only a synonym for ref, so an 'a foo is a mutable cell even though
   the name does not say so. *)
type 'a foo = 'a ref
(* f is ref under another name; the right-hand side is a plain variable (a
   value), so f itself is polymorphic. *)
val f : 'a -> 'a foo = ref
(* f NONE is a function call, and the type-checker cannot assume that a call
   does not create a ref, so the restriction must apply to every call. *)
val r2 = f NONE (* also need value restriction here *)
(* where the value restriction arises despite no mutation *)
(* The right-hand side is a partial application of List.map: a function call,
   not a syntactic value. The value restriction therefore applies and
   pairWithOne is not given the polymorphic type 'a list -> ('a * int) list. *)
val pairWithOne = List.map (fn x => (x,1))
(* a workaround *)
(* val pairWithOne2 : 'a list -> ('a * int) list
   Pairs every element of xs with the int 1. Taking the list argument
   explicitly makes this a function binding, which is a value, so the
   polymorphic type is inferred. *)
fun pairWithOne2 xs =
    let
        val withOne = fn item => (item, 1)
    in
        List.map withOne xs
    end