(* Naive Bidirectional Type Checking *)
(* Frank Pfenning *)

signature TYPE_CHECK =
sig

    (* check env e tau ext: check expression e against type tau, starting
     * from empty type-variable and variable contexts; ext is the source
     * extent used when reporting errors *)
    val check : Ast.env -> Ast.exp -> Ast.tp -> Ast.ext -> unit
    (* synth env e ext: synthesize the type of expression e, starting
     * from empty type-variable and variable contexts *)
    val synth : Ast.env -> Ast.exp -> Ast.ext -> Ast.tp

    (* returns the name "%<n>" for the current value n of an internal
     * counter, then increments the counter *)
    val unique_tpvar : unit -> Ast.tpvarname
    (* subtyping if Flags.subtyping is set, otherwise subtyping in both directions *)
    val compare_tp : Ast.env -> Ast.tp -> Ast.tp -> bool (* tau <: tau' or tau == tau' *)
    (* always subtyping in both directions, regardless of flags *)
    val equal_tp : Ast.env -> Ast.tp -> Ast.tp -> bool (* tau == tau' *)
    (* syntactic approximation of whether synth can succeed on an expression *)
    val synthable : Ast.exp -> bool

end  (* signature TYPE_CHECK *)

structure TypeCheck :> TYPE_CHECK =
struct

(* short names for the abstract syntax and its pretty printer *)
structure A = Ast
structure PP = Ast.Print

(* error reporting entry point; it is used below at several different
 * result types (unit, types, options), so presumably it does not
 * return normally -- confirm against ErrorMsg *)
val ERROR = ErrorMsg.ERROR

(* locate e ext: extent to blame for e.  A marked expression carries
 * its own extent, which replaces the one passed in; marks are peeled
 * off repeatedly until an unmarked expression is reached. *)
fun locate exp outer =
    case exp
     of A.Marked(m) => locate (Mark.data m) (Mark.ext m)
      | _ => outer

(* pp_pats ps: the patterns printed one per line (no trailing newline);
 * the empty list prints as "[]" *)
fun pp_pats nil = "[]"
  | pp_pats pats = String.concatWith "\n" (List.map PP.pp_pat pats)

(* type_mismatch e found expected ext: report a type mismatch at ext.
 * found and expected are already-printed types (or type shapes).
 * The offending expression e is accepted but currently not printed. *)
fun type_mismatch e found expected ext =
    let val report = String.concat ["type mismatch\n",
                                    "expected: ", expected, "\n",
                                    "   found: ", found]
    in ERROR ext report end

(* type_mismatch_pat p found expected ext: like type_mismatch, but for
 * a pattern p, which is printed as the reason on a fourth line *)
fun type_mismatch_pat p found expected ext =
    let val lines = ["type mismatch",
                     "expected: " ^ expected,
                     "   found: " ^ found,
                     " because: " ^ PP.pp_pat p ^ " : " ^ found]
    in ERROR ext (String.concatWith "\n" lines) end

(* tag_mismatch found expected ext: report that tag found is not among
 * the list of tags expected *)
fun tag_mismatch found expected ext =
    ERROR ext (String.concat ["tag mismatch\n",
                              "expected: one of ", PP.pp_tags expected, "\n",
                              "   found: ", PP.pp_tag found])

(* cannot_synth e ext: report that no type can be synthesized;
 * the expression e itself is not part of the message *)
fun cannot_synth e ext =
    let val msg = "expression does not synthesize a type"
    in ERROR ext msg end

(* extraneous tg_es: message text listing the tags of the record fields
 * in tg_es, each tag followed by a single space *)
fun extraneous tg_es =
    let fun tag_and_space (tg,_) = PP.pp_tag tg ^ " "
    in "extraneous components of record type (not checked):\n"
       ^ String.concat (List.map tag_and_space tg_es)
    end

(*
fun extraneous_branch (p,f) tau ext =
    if !Flags.subtyping then () (* warning will be issued for extraneous or redundant pattern *)
    else ERROR ext ("extraneous pattern:\n" ^  PP.pp_pat p)
*)

(* rule printed above and below a highlighted message *)
val highline = "%----------------------------------------"
(* highlight msg: msg framed by a highline on the line before and after *)
fun highlight (msg) = String.concatWith "\n" [highline, msg, highline]

(* redundant_pats ps ext: nothing to do for an empty list; otherwise
 * the patterns are reported at ext, as a warning when subtyping is
 * enabled and the language is not Prf, and as an error in all other
 * configurations *)
fun redundant_pats redundants ext =
    case redundants
     of nil => ()
      | _ =>
        let val lenient = !Flags.subtyping andalso !Flags.lang <> SOME(Flags.Prf)
            val report = if lenient then ErrorMsg.warn else ErrorMsg.ERROR
        in report ext ("redundant patterns:\n" ^ highlight (pp_pats redundants)) end

(* extraneous_pats ps ext: nothing to do for an empty list; otherwise
 * the patterns are reported at ext, as a warning when subtyping is
 * enabled and the language is not Prf, and as an error in all other
 * configurations *)
fun extraneous_pats extras ext =
    case extras
     of nil => ()
      | _ =>
        let val lenient = !Flags.subtyping andalso !Flags.lang <> SOME(Flags.Prf)
            val report = if lenient then ErrorMsg.warn else ErrorMsg.ERROR
        in report ext ("extraneous patterns:\n" ^ highlight (pp_pats extras)) end

(* missing_pats ps ext: nothing to do for an empty list; otherwise the
 * patterns are reported at ext, as an error when the language is Prf
 * (proof checking) and as a warning in all other configurations *)
fun missing_pats missing ext =
    case missing
     of nil => ()
      | _ =>
        let val report = if !Flags.lang = SOME(Flags.Prf)
                         then ErrorMsg.ERROR else ErrorMsg.warn
        in report ext ("missing patterns:\n" ^ highlight (pp_pats missing)) end

(******************)
(* Kind Inference *)
(******************)

(*
fun equated_by_tp nil (a:A.tpvarname,a':A.tpvarname) = (a = a') (* free variables *)
  | equated_by_tp ((b,b')::tpeqs) (a,a') =
    if b = a then b' = a'
    else if b' = a' then b = a
    else equated_by_tp tpeqs (a,a')
*)

(* member tpeqs (a,a'): does the exact pair (a,a') occur in tpeqs?
 * The pair is ordered: (a',a) does not count. *)
fun member tpeqs (a:A.tpvarname,a':A.tpvarname) =
    let fun same (b,b') = (b = a) andalso (b' = a')
    in List.exists same tpeqs end

(* scons x opt: cons x onto the list inside opt; NONE stays NONE *)
fun scons x opt = Option.map (fn xs => x::xs) opt

local
    (* number handed out by the next call to unique_tpvar *)
    val next_id : int ref = ref 0
in
(* reset (): start numbering from 0 again *)
fun reset () = next_id := 0
(* unique_tpvar (): "%" followed by the current counter value;
 * the counter is incremented as a side effect *)
fun unique_tpvar () =
    let val n = !next_id
    in next_id := n + 1 ; "%" ^ Int.toString n end
end

(* uniquify env tau: a copy of tau in which
 *   - every type definition and type application has been expanded
 *     (using A.expand_tp and A.expand_tapp with env), and
 *   - the bound variable of every recursive type (Rho) has been
 *     replaced by a new name obtained from unique_tpvar.
 * Binders of Forall and Exists keep their names.
 * A bare type-level abstraction (TLam) raises Match.
 *)
fun uniquify env tp =
    let fun uniq t = uniquify env t
        fun uniq_tagged alts = List.map (fn (tg,t) => (tg, uniq t)) alts
    in case tp
        of A.Arrow(dom,cod) => A.Arrow(uniq dom, uniq cod)
         | A.Forall(a,body) => A.Forall(a, uniq body)
         | A.TpVar _ => tp
         | A.Times(left,right) => A.Times(uniq left, uniq right)
         | A.One => tp
         | A.Plus(sum) => A.Plus(uniq_tagged sum)
         | A.With(prod) => A.With(uniq_tagged prod)
         | A.Rho(a,body) =>
           (* rename the binder first, then continue inside the body *)
           let val u = unique_tpvar()
           in A.Rho(u, uniq (A.subst_tp_tp (A.TpVar(u)) a body)) end
         | A.Exists(a,body) => A.Exists(a, uniq body)
         | A.TLam _ => raise Match
         | A.TApp _ => uniq (A.expand_tapp env tp)
         | A.TpDef(a) => uniq (A.expand_tp env a)
    end

(* empty hyps tau: is tau recognized as an empty type?
 * Requires tau to be uniquified; TLam, TApp and TpDef raise Match.
 * hyps collects the binders of the enclosing recursive types; reaching
 * one of them as a type variable counts as empty.
 *)
fun empty hyps tp =
    case tp
     of A.TpVar(a) => List.exists (fn h => a = h) hyps
      | A.Times(left,right) => empty hyps left orelse empty hyps right
      | A.Plus(sum) => List.all (fn (_,alt) => empty hyps alt) sum
      | A.Rho(u,body) => empty (u::hyps) body (* u is unique *)
      | A.Exists(_,body) => empty hyps body (* the bound variable is nonempty *)
      | A.Arrow _ => false
      | A.Forall _ => false
      | A.One => false
      | A.With _ => false (* lazy record non-empty *)
      | A.TLam _ => raise Match
      | A.TApp _ => raise Match
      | A.TpDef _ => raise Match

(* is_empty tau: emptiness test for a uniquified type, starting with
 * no recursive binders assumed *)
val is_empty : A.tp -> bool = empty nil


(*****************)
(* Type Equality *)
(*****************)

(* THIS IS NOW DEAD CODE *)
(* REPLACED BY SUBTYPING IN BOTH DIRECTIONS *)

(*
(* no defs; recursive types have been uniquified *)
fun eq_tp tpeqs (A.Arrow(tau1,tau2)) (A.Arrow(tau1',tau2')) =
    (eq_tp tpeqs tau1' tau1 andalso eq_tp tpeqs tau2 tau2')
    orelse (is_empty tau1' andalso is_empty tau1) (* rule S-T in Liggati17toplas *)
  | eq_tp tpeqs (A.Forall(a,tau)) (A.Forall(a',tau')) =
    let val u = unique_tpvar()
    in eq_tp tpeqs (A.subst_tp_tp (A.TpVar(u)) a tau)
             (A.subst_tp_tp (A.TpVar(u)) a' tau')
    end
  | eq_tp tpeqs (A.TpVar(a)) (A.TpVar(a')) = ( a = a' )
  | eq_tp tpeqs (tau as A.Times(tau1,tau2)) (tau' as A.Times(tau1',tau2')) =
    (eq_tp tpeqs tau1 tau1' andalso eq_tp tpeqs tau2 tau2')
    orelse (is_empty tau andalso is_empty tau')
  | eq_tp tpeqs (A.One) (A.One) = true
  | eq_tp tpeqs (A.Plus(sum)) (A.Plus(sum')) =
    eq_sum tpeqs sum sum'
  | eq_tp tpeqs (tau as A.Rho(u,tau1)) (tau' as A.Rho(u',tau1')) =
    (u = u' orelse member tpeqs (u,u') (* could also be here *)
     orelse eq_tp ((u,u')::(u',u)::tpeqs) (* build in symmetry here *)
                  (A.subst_tp_tp tau u tau1)
                  (A.subst_tp_tp tau' u' tau1'))
    orelse (is_empty tau andalso is_empty tau')
  | eq_tp tpeqs (A.TpDef _) _ = raise Match (* impossible *)
  | eq_tp tpeqs _ (A.TpDef _) = raise Match (* impossible *)
  | eq_tp tpeqs tau tau' = (* false *)
    (is_empty tau andalso is_empty tau')

and eq_sum tpeqs nil nil = true
  | eq_sum tpeqs ((i,tau)::sum) sum' =
    (case eq_tagged tpeqs (i,tau) sum'
      of NONE => false
       | SOME(sum'') => eq_sum tpeqs sum sum'')
  | eq_sum tpeqs nil ((j,sigma)::sum') =
    is_empty sigma andalso eq_sum tpeqs nil sum'

and eq_tagged tpeqs (i,tau) nil =
    if is_empty tau then SOME(nil) else NONE
  | eq_tagged tpeqs (i,tau) ((j,sigma)::sum') =
    if i = j
    then if eq_tp tpeqs tau sigma then SOME(sum') else NONE
    else scons (j,sigma) (eq_tagged tpeqs (i,tau) sum')

fun equal_tp env tau tau' =
    let val utau = uniquify env tau
        val utau' = uniquify env tau'
    in eq_tp nil utau utau' end

fun equal_tp_ env NONE tau' = true
  | equal_tp_ env (SOME(tau)) tau' = equal_tp env tau tau'
*)

(*************)
(* Subtyping *)
(*************)

(* tau == sigma iff tau <: sigma and sigma <: tau *)

(* subtyping, according to Ligatti et al., TOPLAS 2017 *)
(* no defs; recursive types have been uniquified *)
(* sub_tp tpeqs tau tau' decides tau <: tau'.
 * tpeqs holds the pairs (u,u') of recursive-type binders whose
 * unfoldings are already being compared; meeting such a pair again
 * counts as success.  An empty type on the left is a subtype of
 * anything (last clause, and the "orelse is_empty" alternatives).
 * TLam, TApp and TpDef are not expected here and raise Match, except
 * where an earlier clause has already matched.
 *)
fun sub_tp tpeqs (A.Arrow(tau1,tau2)) (A.Arrow(tau1',tau2')) =
    (* contravariant in the domain, covariant in the codomain *)
    (sub_tp tpeqs tau1' tau1 andalso sub_tp tpeqs tau2 tau2')
    (* rule S-T must be tried here as well: this clause shadows the
     * next one whenever both sides are arrows, so without it an arrow
     * on the left would be treated more strictly than any other type *)
    orelse is_empty tau1'
  | sub_tp tpeqs tau (A.Arrow(tau1',tau2')) =
    (* special case, rule S-T in Ligatti17toplas *)
    is_empty tau1'
  | sub_tp tpeqs (A.Forall(a,tau)) (A.Forall(a',tau')) =
    (* compare the bodies under one common new variable *)
    let val u = unique_tpvar()
    in sub_tp tpeqs (A.subst_tp_tp (A.TpVar(u)) a tau)
              (A.subst_tp_tp (A.TpVar(u)) a' tau')
    end
  | sub_tp tpeqs (A.TpVar(a)) (A.TpVar(a')) = ( a = a' )
  | sub_tp tpeqs (tau as A.Times(tau1,tau2)) (tau' as A.Times(tau1',tau2')) =
    (sub_tp tpeqs tau1 tau1' andalso sub_tp tpeqs tau2 tau2')
    orelse is_empty tau
  | sub_tp tpeqs (A.One) (A.One) = true
  | sub_tp tpeqs (A.Plus(sum)) (A.Plus(sum')) =
    sub_sum tpeqs sum sum'
  | sub_tp tpeqs (A.With(prod)) (A.With(prod')) =
    sub_prod tpeqs prod prod'
  | sub_tp tpeqs (tau as A.Rho(u,tau1)) (tau' as A.Rho(u',tau1')) =
    (* same binder or already assumed: succeed; otherwise record the
     * assumption and compare the unfoldings *)
    (u = u' orelse member tpeqs (u,u')
     orelse sub_tp ((u,u')::tpeqs)
                   (A.subst_tp_tp tau u tau1)
                   (A.subst_tp_tp tau' u' tau1'))
    orelse is_empty tau
  | sub_tp tpeqs (A.Exists(a,tau)) (A.Exists(a',tau')) =
    (* compare the bodies under one common new variable *)
    let val u = unique_tpvar()
    in sub_tp tpeqs (A.subst_tp_tp (A.TpVar(u)) a tau)
              (A.subst_tp_tp (A.TpVar(u)) a' tau')
    end
  | sub_tp tpeqs (A.TLam _) _ = raise Match
  | sub_tp tpeqs _ (A.TLam _) = raise Match
  | sub_tp tpeqs (A.TApp _) _ = raise Match
  | sub_tp tpeqs _ (A.TApp _) = raise Match
  | sub_tp tpeqs (A.TpDef _) _ = raise Match (* impossible *)
  | sub_tp tpeqs _ (A.TpDef _) = raise Match (* impossible *)
  | sub_tp tpeqs tau tau' = is_empty tau

(* sub_sum tpeqs sum sum': every alternative of sum must be accounted
 * for in sum' (see sub_tagged); alternatives left over in sum' are
 * permitted *)
and sub_sum tpeqs sum sum' =
    case sum
     of nil => true (* width subtyping *)
      | (i,tau)::rest =>
        (case sub_tagged tpeqs (i,tau) sum'
          of SOME(remaining) => sub_sum tpeqs rest remaining
           | NONE => false)

(* sub_tagged tpeqs (i,tau) sum': look for tag i in sum'.
 * Found with tau a subtype of its type: SOME of sum' without that entry.
 * Found but not a subtype: NONE.
 * Not found: SOME(nil) if tau is empty (which drops the entries that
 * were scanned), NONE otherwise.
 *)
and sub_tagged tpeqs (i,tau) sum' =
    case sum'
     of nil => if is_empty tau then SOME(nil) else NONE
      | (j,sigma)::rest =>
        if i <> j
        then scons (j,sigma) (sub_tagged tpeqs (i,tau) rest)
        else if sub_tp tpeqs tau sigma
        then SOME(rest) (* depth subtyping *)
        else NONE

(* sub_prod tpeqs prod prod': every field required by prod' must be
 * provided by prod at a subtype (see sub_tagged_prod); fields left
 * over in prod are permitted *)
and sub_prod tpeqs prod prod' =
    case prod'
     of nil => true (* width subtyping *)
      | (i,tau)::rest =>
        (case sub_tagged_prod tpeqs prod (i,tau)
          of SOME(remaining) => sub_prod tpeqs remaining rest
           | NONE => false)

(* sub_tagged_prod tpeqs prod (i,tau): look for field i in prod.
 * Found with its type a subtype of tau: SOME of prod without that field.
 * Found but not a subtype, or not found at all: NONE.
 *)
and sub_tagged_prod tpeqs prod (i,tau) =
    case prod
     of nil =>
        (* lazy record, so component types are considered non-empty *)
        NONE
      | (j,sigma)::rest =>
        if i <> j
        then scons (j,sigma) (sub_tagged_prod tpeqs rest (i,tau))
        else if sub_tp tpeqs sigma tau
        then SOME(rest) (* depth subtyping *)
        else NONE

(* compare_tp env tau tau': type comparison governed by Flags.subtyping.
 * With the flag set it tests tau <: tau'; without it, subtyping must
 * hold in both directions.  Both types are uniquified first.
 * The unique_tpvar counter is deliberately not reset here; it is also
 * used in Norm.alpha.
 *)
fun compare_tp env tau tau' =
    let val utau = uniquify env tau
        val utau' = uniquify env tau'
        val forward = sub_tp nil utau utau'
    in
        (* the reverse direction is only needed without subtyping *)
        forward andalso (!Flags.subtyping orelse sub_tp nil utau' utau)
    end

(* equal_tp env tau tau': always tests equality, i.e. subtyping in both
 * directions, independent of Flags.subtyping; used for conversion test *)
fun equal_tp env tau tau' =
    let val (left, right) = (uniquify env tau, uniquify env tau')
    in sub_tp nil left right andalso sub_tp nil right left end

(* synthable e: syntactic test for whether e is of a form that
 * synthesizes its type.  Record and Case are always rejected for
 * simplicity, although they may sometimes be synthable. *)
fun synthable exp =
    case exp
     of A.Marked(m) => synthable (Mark.data m)
      (* forms that always synthesize *)
      | A.Var _ => true
      | A.Unit => true
      | A.Def _ => true
      | A.Fix(_,SOME _,_) => true
      (* forms that synthesize if a subexpression does *)
      | A.Lam(_,SOME _,body) => synthable body
      | A.App(f,_) => synthable f
      | A.TpLam(_,body) => synthable body
      | A.TpApp(f,_) => synthable f
      | A.Pair(left,right) => synthable left andalso synthable right
      | A.Project(r,_) => synthable r
      | A.Unfold(body) => synthable body
      (* forms that never synthesize *)
      | A.Lam(_,NONE,_) => false
      | A.Inject _ => false
      | A.Record _ => false
      | A.Case _ => false
      | A.Fold _ => false
      | A.Pack _ => false
      | A.Fix(_,NONE,_) => false

(***************************)
(* Exhaustiveness Checking *)
(***************************)

(*
fun anonymize (A.VarPat _) = A.VarPat("_")
  | anonymize (A.PairPat(p1,p2)) = A.PairPat(anonymize p1, anonymize p2)
  | anonymize (A.UnitPat) = A.UnitPat
  | anonymize (A.InjectPat(i,p)) = A.InjectPat(i,anonymize p)
  | anonymize (A.FoldPat(p)) = A.FoldPat(anonymize p)
  | anonymize (A.PackPat(a,p)) = A.PackPat(a,anonymize p)
*)

(* 
 * next section represents a different algorithm based on
 * splitting candidates.  it may be more efficient in some
 * cases, but doesn't identify redundant patterns
 *)
(*
local
    val ctr : int ref = ref 0
in
fun reset () = ( ctr := 0 )
fun fresh () = A.VarPat("_" ^ Int.toString (!ctr)) before ctr := !ctr + 1
end

datatype splits = NotSplittable | Splits of A.pat list
                                            
fun fresh_pats (A.Times(sigma1,sigma2)) = Splits [A.PairPat(fresh(),fresh())]
  | fresh_pats (A.One) = Splits [A.UnitPat]
  | fresh_pats (A.Plus(sum)) = Splits (List.map (fn (i,sigma) => A.InjectPat(i,fresh())) sum)
  | fresh_pats (A.Rho _) = Splits [A.FoldPat(fresh())]
  | fresh_pats (A.Exists(a,_)) = Splits [A.PackPat(a,fresh())]
  | fresh_pats _ = NotSplittable (* negative types not splittable *)

fun pairs1 q1 (NotSplittable) = NotSplittable
  | pairs1 q1 (Splits(q2s)) = Splits (List.map (fn q2 => A.PairPat(q1,q2)) q2s)

fun pairs2 q1s q2 = Splits (List.map (fn q1 => A.PairPat(q1,q2)) q1s)

fun map1 f (NotSplittable) = NotSplittable
  | map1 f (Splits(qs)) = Splits(List.map f qs)

fun split xs (A.VarPat(x)) sigma =
    if List.exists (fn x' => x = x') xs (* splittable *)
    then fresh_pats sigma (* could be unsplittable if sigma is negative *)
    else if is_empty sigma (* sigma is an empty type; split into 0 cases *)
    then Splits []
    else NotSplittable (* case not covered and x not splittable *)
  | split xs (A.PairPat(q1,q2)) (A.Times(sigma1,sigma2)) =
    (* do only the leftmost allowable split *)
    (case split xs q1 sigma1
      of NotSplittable => pairs1 q1 (split xs q2 sigma2)
       | Splits(q1s) => pairs2 q1s q2)
  | split xs (A.UnitPat) (A.One) = NotSplittable
  | split xs (A.InjectPat(i,q)) (A.Plus(sum)) =
    map1 (fn q' => A.InjectPat(i,q')) (split xs q (lookup i sum))
  | split xs (A.FoldPat(q)) (sigma as A.Rho(a,sigma')) =
    map1 (fn q' => A.FoldPat(q')) (split xs q (A.subst_tp_tp sigma a sigma'))
  | split xs (A.PackPat(b,q)) (A.Exists(a,sigma)) =
    map1 (fn q' => A.PackPat(b,q')) (split xs q sigma)
  (* other cases should be impossible for well-typed terms *)

datatype candidates = Covered | Candidates of A.varname list

fun union (Covered) (Covered) = Covered
  | union (Candidates(xs)) (Covered) = Candidates(xs)
  | union (Covered) (Candidates(ys)) = Candidates(ys)
  | union (Candidates(xs)) (Candidates(ys)) = Candidates (xs @ ys)

fun instance_of (A.VarPat _) q = Covered
  | instance_of _ (A.VarPat(x)) = Candidates [x]
  | instance_of (A.PairPat(p1,p2)) (A.PairPat(q1,q2)) =
        union (instance_of p1 q1) (instance_of p2 q2) (* lazier? *)
  | instance_of (A.UnitPat) (A.UnitPat) = Covered
  | instance_of (A.InjectPat(i,p)) (A.InjectPat(j,q)) =
    if i = j then instance_of p q
    else Candidates []
  | instance_of (A.FoldPat(p)) (A.FoldPat(q)) = instance_of p q
  | instance_of (A.PackPat(_,p)) (A.PackPat(_,q)) = instance_of p q
  (* by canonical forms, these should be the only possible cases *)

(* sigma uniquified *)
fun covered q sigma ((p,e)::branches) =
    (case instance_of p q
      of Covered => Covered
       | Candidates(xs) => (case covered q sigma branches
                             of Covered => Covered
                              | Candidates(ys) => Candidates(xs @ ys)))
  | covered q sigma nil = Candidates []

(* sigma uniquified *)
fun exhaustive (q::qs) missing sigma branches =
    (case covered q sigma branches
      of Covered => exhaustive qs missing sigma branches
       | Candidates(xs) => case split xs q sigma
                            of NotSplittable => exhaustive qs (missing @ [q]) sigma branches
                             | Splits(qs') => exhaustive (qs' @ qs) missing sigma branches)
  | exhaustive nil missing sigma branches = missing

fun missing_patterns env sigma branches =
    let val sigma' = uniquify env sigma
        val () = reset ()       (* fresh variable counter *)
    in exhaustive [fresh()] nil sigma' branches end
*)

(* lookup tg sum: the type stored under the first occurrence of tag tg.
 * There is deliberately no case for the empty list (Match is raised):
 * it should be impossible for well-typed patterns. *)
fun lookup tg sum =
    case sum
     of (j:A.tag,sigma)::rest => if j = tg then sigma else lookup tg rest

(* the anonymous variable pattern "_"; used below as the initial
 * pattern to subtract from and to fill in unconstrained positions *)
val anon = A.VarPat("_")

(* intersect q p: intersection = unification, where both q and p are
 * linear.  SOME(r) is the pattern matching what both q and p match;
 * NONE means different injection tags were encountered.  A variable
 * pattern on either side yields the other side unchanged.
 * Pairs of patterns with different top-level constructors (other than
 * variables) have no case and raise Match.
 *)
fun intersect q p =
    case (q, p)
     of (A.VarPat _, _) => SOME(p)
      | (_, A.VarPat _) => SOME(q)
      | (A.PairPat(q1,q2), A.PairPat(p1,p2)) =>
        (case (intersect q1 p1, intersect q2 p2)
          of (SOME(r1), SOME(r2)) => SOME(A.PairPat(r1,r2))
           | _ => NONE)
      | (A.UnitPat, A.UnitPat) => SOME(A.UnitPat)
      | (A.InjectPat(j,q'), A.InjectPat(i,p')) =>
        if i = j
        then Option.map (fn r => A.InjectPat(i,r)) (intersect q' p')
        else NONE
      | (A.FoldPat(q'), A.FoldPat(p')) =>
        Option.map A.FoldPat (intersect q' p')
      | (A.PackPat(_,q'), A.PackPat(a,p')) =>
        (* the type variable name is taken from p *)
        Option.map (fn r => A.PackPat(a,r)) (intersect q' p')

(* intersections qs p: each q in qs paired with intersect q p *)
fun intersections qs p =
    let fun with_meet q = (q, intersect q p)
    in List.map with_meet qs end

(* feasible sigma p: does the shape of pattern p fit type sigma?
 * A variable fits any type; an injection fits a sum only if its tag
 * occurs in the sum with a fitting argument; recursive types are
 * unfolded one step per FoldPat.  Every combination not listed is
 * infeasible. *)
fun feasible sigma pat =
    case (sigma, pat)
     of (_, A.VarPat _) => true
      | (A.Times(sigma1,sigma2), A.PairPat(p1,p2)) =>
        feasible sigma1 p1 andalso feasible sigma2 p2
      | (A.One, A.UnitPat) => true
      | (A.Plus(sum), A.InjectPat(j,pj)) =>
        List.exists (fn (i,sigmai) => j = i andalso feasible sigmai pj) sum
      | (A.Rho(a,body), A.FoldPat(p)) =>
        feasible (A.subst_tp_tp sigma a body) p
      | (A.Exists(_,body), A.PackPat(_,p)) =>
        feasible body p (* correct? *)
      | _ => false

(* 
   (q1 \/ q2) - p = (q1 - p) \/ (q2 - p)

   with some redundancies in the output:
   q1 * q2 - p1 * p2 = (q1 - p1) * q2 \/ q1 * (q2 - p2)

   for disjoint union:
   p1' = q1 /\ p1, p2' = q2 /\ p2
   q1 * q2 - p1' * p2' = (q1 - p1') * p2' \/ p1' * (q2 - p2') \/ (q1 - p1') * (q2 - p2')

   diff q sigma p
   assume q : sigma, p : sigma and q superset p
*)
(* diff q sigma p: list of patterns covering what q matches and p does
 * not, at type sigma (assumes q : sigma, p : sigma, q superset p).
 * A variable p removes everything.  A variable q is first refined to
 * the constructor pattern(s) of p's shape, with anon in the argument
 * positions.  Combinations not listed have no case and raise Match.
 *)
fun diff q sigma p =
    case (q, sigma, p)
     of (_, _, A.VarPat _) => []
      | (A.VarPat _, _, A.PairPat _) => diff (A.PairPat(anon,anon)) sigma p
      | (A.VarPat _, _, A.UnitPat) => diff (A.UnitPat) sigma p
      | (A.VarPat _, A.Plus(sum), A.InjectPat _) =>
        (* one candidate per alternative of the sum *)
        diffs (List.map (fn (j,_) => A.InjectPat(j,anon)) sum) sigma p
      | (A.VarPat _, _, A.FoldPat _) => diff (A.FoldPat(anon)) sigma p
      | (A.VarPat _, A.Exists(a,_), A.PackPat _) =>
        diff (A.PackPat(a,anon)) sigma p
      | (A.PairPat(q1,q2), A.Times(sigma1,sigma2), A.PairPat(p1,p2)) =>
        (* disjoint union: left differs, right differs, both differ *)
        let val rest1 = diff q1 sigma1 p1
            val rest2 = diff q2 sigma2 p2
            val left_only = List.map (fn r1 => A.PairPat(r1,p2)) rest1
            val right_only = List.map (fn r2 => A.PairPat(p1,r2)) rest2
            val both = List.concat
                           (List.map (fn r1 => List.map (fn r2 => A.PairPat(r1,r2)) rest2)
                                     rest1)
        in left_only @ right_only @ both end
      | (A.UnitPat, A.One, A.UnitPat) => []
      | (A.InjectPat(j,q'), A.Plus(sum), A.InjectPat(i,p')) =>
        if j <> i
        then [q] (* different tag: nothing is removed *)
        else List.map (fn r => A.InjectPat(j,r)) (diff q' (lookup j sum) p')
      | (A.FoldPat(q'), A.Rho(a,body), A.FoldPat(p')) =>
        List.map A.FoldPat (diff q' (A.subst_tp_tp sigma a body) p')
      | (A.PackPat(a,q'), A.Exists(_,body), A.PackPat(_,p')) =>
        List.map (fn r => A.PackPat(a,r)) (diff q' body p')
(* diffs qs sigma p: the results of diff q sigma p for each q in qs,
 * appended in order *)
and diffs qs sigma p =
    List.concat (List.map (fn q => diff q sigma p) qs)

(* union_diffs qps sigma: for each (q, meet) pair, what remains of q:
 * diff q sigma p if the meet is SOME(p), and q itself if it is NONE;
 * all results appended in order *)
and union_diffs qps sigma =
    let fun remainder (q,SOME(p)) = diff q sigma p
          | remainder (q,NONE) = [q]
    in List.concat (List.map remainder qps) end

(* subtraction qs sigma branches redundants extras
 * subtracts the pattern of each branch, in order, from the patterns qs
 * still to be covered.  Returns (remaining qs, redundants, extras),
 * where redundants and extras accumulate in reverse branch order:
 *   - a branch pattern that intersects none of the qs is redundant
 *     (syntactic test) and leaves qs unchanged;
 *   - otherwise it is subtracted, and additionally recorded as extra
 *     if none of its intersections is feasible at sigma (semantic test).
 *)
fun subtraction qs sigma branches redundants extras =
    case branches
     of nil => (qs, redundants, extras)
      | (p,_)::rest =>
        let val qps = intersections qs p (* syntactic *)
            fun disjoint (_,NONE) = true
              | disjoint (_,SOME _) = false
            fun infeasible (_,NONE) = true
              | infeasible (_,SOME(r)) = not (feasible sigma r)
            val redundant = List.all disjoint qps
            val extraneous = List.all infeasible qps (* semantic *)
        in
            if redundant
            then subtraction qs sigma rest (p::redundants) extras
            else subtraction (union_diffs qps sigma) sigma rest redundants
                             (if extraneous then p::extras else extras)
        end

(* empty_pat p sigma: is some variable position of p at an empty type,
 * so that p cannot match any value?  sigma is already uniquified.
 * Pattern/type combinations not listed have no case and raise Match. *)
fun empty_pat pat sigma =
    case (pat, sigma)
     of (A.VarPat _, _) => is_empty sigma
      | (A.PairPat(p1,p2), A.Times(sigma1,sigma2)) =>
        empty_pat p1 sigma1 orelse empty_pat p2 sigma2
      | (A.UnitPat, A.One) => false
      | (A.InjectPat(i,p), A.Plus(sum)) => empty_pat p (lookup i sum)
      | (A.FoldPat(p), A.Rho(a,body)) =>
        empty_pat p (A.subst_tp_tp sigma a body)
      | (A.PackPat(_,p), A.Exists(_,body)) => empty_pat p body

(* nonempty_pats qs sigma: the patterns in qs, in order, that are not
 * empty at sigma according to empty_pat *)
fun nonempty_pats qs sigma =
    List.filter (fn q => not (empty_pat q sigma)) qs

(* missing_by_subtraction env sigma branches
 * runs subtraction starting from the single pattern anon at the
 * uniquified subject type.  Returns (missing, redundants, extras),
 * where patterns that are empty at that type have been dropped from
 * the missing ones. *)
fun missing_by_subtraction env sigma branches =
    let val usigma = uniquify env sigma
    in case subtraction [anon] usigma branches nil nil
        of (uncovered, redundants, extras) =>
           (nonempty_pats uncovered usigma, redundants, extras)
    end

(**************************)
(* Checking and Synthesis *)
(**************************)

(* synth_var ctx x ext: type of the first binding of x in ctx;
 * an error at ext if x is not bound *)
fun synth_var ctx x ext =
    case List.find (fn (y,_) => y = x) ctx
     of SOME(_,tau) => tau
      | NONE => ERROR ext ("unbound variable " ^ x)

(* synth_def env x ext: type that A.lookup_var_tp finds for the
 * definition x in env; an error at ext if it finds none *)
fun synth_def env x ext =
    let val declared = A.lookup_var_tp env x
    in case declared
        of NONE => ERROR ext ("no type declared for variable " ^ x)
         | SOME(tau) => tau
    end

(* check env tpctx ctx e tau ext
 * checks expression e against type tau; returns () on success and
 * reports through type_mismatch / ERROR otherwise.
 *   env   : global environment, used to expose and compare types
 *   tpctx : type variables in scope
 *   ctx   : (variable, type) bindings in scope, innermost first
 *   ext   : extent for error messages; entering a Marked expression
 *           replaces it with that expression's own extent
 * Introduction forms are checked against the exposed structure of tau.
 * Any expression without a dedicated clause is synthesized and the
 * result compared against tau with compare_tp (last clause).
 *)
fun check env tpctx ctx (A.Marked(marked_e)) tau ext =
    check env tpctx ctx (Mark.data marked_e) tau (Mark.ext marked_e)

  | check env tpctx ctx (e as A.Lam(x1,NONE,e2)) tau ext =
    (case A.expose_tp env tau
      of A.Arrow(tau1,tau2) => check env tpctx ((x1,tau1)::ctx) e2 tau2 ext
       | _ => type_mismatch e "<tp> -> <tp>" (PP.pp_tp tau) ext)
  | check env tpctx ctx (e as A.Lam(x1,SOME(tau1'),e2)) tau ext =
    (* the annotation may be more general than the expected domain *)
    (case A.expose_tp env tau
      of A.Arrow(tau1,tau2) => if compare_tp env tau1 tau1'  (* tau1 <: tau1' *)
                               then check env tpctx ((x1,tau1')::ctx) e2 tau2 ext
                               else type_mismatch (A.Var(x1)) (PP.pp_tp tau1) (PP.pp_tp tau1') ext
       | _ => type_mismatch e "<tp> -> <tp>" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.TpLam(a,e2)) tau ext =
    (case A.expose_tp env tau
      of A.Forall(b,tau2) =>
         (* a' comes from A.fresh for a and tpctx; when it differs from a
          * (and a is not "_"), a is renamed to a' in the body.  The
          * bound variable b of the type is renamed to a' as well. *)
         let val a' = A.fresh tpctx a
             val e2' = if a = a' orelse a = "_" then e2 else A.subst_tp_exp (A.TpVar(a')) a e2
             val tau2' = A.subst_tp_tp (A.TpVar(a')) b tau2
         in check env (a'::tpctx) ctx e2' tau2' ext end
       | _ => type_mismatch e "!<id>. <tp>" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.Pair(e1,e2)) tau ext =
    (case A.expose_tp env tau
      of A.Times(tau1,tau2) =>
         ( check env tpctx ctx e1 tau1 ext
         ; check env tpctx ctx e2 tau2 ext )
       | _ => type_mismatch e "<tp> * <tp>" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.Unit) tau ext =
    (case A.expose_tp env tau
      of A.One => ()
       | _ => type_mismatch e "1" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.Inject(i,ei)) tau ext =
    (case A.expose_tp env tau
      of A.Plus(sum) => check_sum env tpctx ctx (i,ei) sum sum ext
       | _ => type_mismatch e "<sum>" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.Record(texps)) tau ext =
    (case A.expose_tp env tau
      of A.With(prod) => check_prod env tpctx ctx texps prod ext
       | _ => type_mismatch e "<prod>" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.Case(e',branches)) tau ext =
    let val tau' = synth env tpctx ctx e' ext
        (* typecheck first, to establish invariant for exhaustiveness checking *)
        val () = check_branches env tpctx ctx tau' branches tau ext
        (* val missing = missing_patterns env tau' branches *)
        (* val anon = List.map anonymize missing *)
        val (missing, redundants, extras) = missing_by_subtraction env tau' branches
        val () = redundant_pats redundants ext (* always error? *)
        val () = extraneous_pats extras ext (* okay, warning, or error *)
        val () = missing_pats missing ext      (* okay, warning, or error *)
    in
        ()
    end

  | check env tpctx ctx (e as A.Fold(e')) tau ext =
    (* check the body against the one-step unfolding of the recursive type *)
    (case A.expose_tp env tau
      of A.Rho(a,tau1) => check env tpctx ctx e' (A.subst_tp_tp tau a tau1) ext
       | _ => type_mismatch e "$<id>. <tp>" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.Pack(sigma,e')) tau ext =
    (* check the body with the witness type sigma substituted for a *)
    (case A.expose_tp env tau
      of A.Exists(a,tau1) => check env tpctx ctx e' (A.subst_tp_tp sigma a tau1) ext
       | _ => type_mismatch e "?<id>. <tp>" (PP.pp_tp tau) ext)

  | check env tpctx ctx (e as A.Fix(g,NONE,e1)) tau ext =
    check env tpctx ((g,tau)::ctx) e1 tau ext

  | check env tpctx ctx (e as A.Fix(g,SOME(tau1),e1)) tau ext =
    if compare_tp env tau1 tau (* tau1 <: tau *)
    then
        (* the body is assumed to have type tau1 at every recursive use
         * of g, so it must be checked against tau1 itself; checking it
         * only against the supertype tau would be unsound.  This
         * agrees with the annotated Fix case of synth. *)
        check env tpctx ((g,tau1)::ctx) e1 tau1 ext
    else type_mismatch (A.Var(g)) (PP.pp_tp tau1) (PP.pp_tp tau) ext

  | check env tpctx ctx e tau ext =
    (* mode switch: synthesize, then compare *)
    let val tau' = synth env tpctx ctx e ext
    in if compare_tp env tau' tau (* tau' <: tau *)
       then ()
       else type_mismatch e (PP.pp_tp tau') (PP.pp_tp tau) ext
    end

(* check_prod env tpctx ctx texps prod ext
 * checks the fields texps of a lazy record against the record type
 * prod, one required field at a time.  Fields of the record that the
 * type does not mention are reported once all required fields are
 * done: as a warning with subtyping, as an error without. *)
and check_prod env tpctx ctx texps prod ext =
    case (texps, prod)
     of (nil, nil) => ()
      | (_, nil) =>
        (if !Flags.subtyping then ErrorMsg.warn else ERROR) ext (extraneous texps)
      | (_, (i,tau)::rest) =>
        let val unchecked = check_tag env tpctx ctx texps (i,tau) ext
        in check_prod env tpctx ctx unchecked rest ext end
(* check_tag env tpctx ctx texps (i,tau) ext
 * finds the first field tagged i in texps, checks its expression
 * against tau, and returns texps without that field; it is an error
 * if no field is tagged i *)
and check_tag env tpctx ctx texps (i,tau) ext =
    case texps
     of nil =>
        ERROR ext ("type mismatch\n"
                   ^ "tag " ^ PP.pp_tag i ^ " : " ^ PP.pp_tp tau ^ "\n"
                   ^ "not present in lazy record")
      | (j,e)::rest =>
        if i <> j
        then (j,e)::check_tag env tpctx ctx rest (i,tau) ext
        else ( check env tpctx ctx e tau ext
             ; rest )

(* check_branches env tpctx ctx sigma branches tau ext
 * checks the body of every branch against tau, in the contexts
 * extended by the branch pattern matched at type sigma.  A branch for
 * which synth_ctx returns NONE is skipped here; it is left to the
 * exhaustiveness/redundancy check to report it as extraneous. *)
and check_branches env tpctx ctx sigma branches tau ext =
    case branches
     of nil => ()
      | (p,f)::rest =>
        ( (case synth_ctx env (tpctx, ctx) sigma p ext
            of SOME(tpctx', ctx') => check env tpctx' ctx' f tau ext
             | NONE => ())
        ; check_branches env tpctx ctx sigma rest tau ext )

(* synth_ctx env (tpctx, ctx) sigma p ext
 * matches pattern p against type sigma and returns SOME of the
 * contexts extended with the variables and type variables p binds.
 * NONE is returned when project returns NONE for an injection tag
 * (tag missing from the sum, with subtyping enabled).
 * A pattern whose shape does not fit the exposed type is an error.
 * The variable "_" binds nothing. *)
and synth_ctx env (tpctx, ctx) sigma pat ext =
    case pat
     of A.VarPat("_") => SOME(tpctx, ctx)
      | A.VarPat(x) => SOME(tpctx, (x,sigma)::ctx)
      | A.UnitPat =>
        (case A.expose_tp env sigma
          of A.One => SOME(tpctx, ctx)
           | _ => type_mismatch_pat pat "1" (PP.pp_tp sigma) ext)
      | A.PairPat(p1,p2) =>
        (case A.expose_tp env sigma
          of A.Times(sigma1,sigma2) =>
             (* bindings of the left component are visible to the right one *)
             (case synth_ctx env (tpctx, ctx) sigma1 p1 ext
               of SOME(after_left) => synth_ctx env after_left sigma2 p2 ext
                | NONE => NONE)
           | _ => type_mismatch_pat pat "<tp> * <tp>" (PP.pp_tp sigma) ext)
      | A.InjectPat(i,arg) =>
        (case A.expose_tp env sigma
          of A.Plus(sum) =>
             (case project sum i sum ext
               of SOME(tau) => synth_ctx env (tpctx, ctx) tau arg ext
                | NONE => NONE)
           | _ => type_mismatch_pat pat "<sum>" (PP.pp_tp sigma) ext)
      | A.FoldPat(arg) =>
        (case A.expose_tp env sigma
          of A.Rho(a,body) =>
             synth_ctx env (tpctx, ctx) (A.subst_tp_tp sigma a body) arg ext
           | _ => type_mismatch_pat pat "$a. <tp>" (PP.pp_tp sigma) ext)
      | A.PackPat(a,arg) =>
        (case A.expose_tp env sigma
          of A.Exists(a',body) =>
             (* the pattern's type variable a stands for the bound a' *)
             synth_ctx env ((a::tpctx), ctx) (A.subst_tp_tp (A.TpVar(a)) a' body) arg ext
           | _ => type_mismatch_pat pat "?a. <tp>" (PP.pp_tp sigma) ext)

(* project remaining i sum ext
 * SOME of the type stored under tag i in remaining.  If i does not
 * occur: NONE with subtyping enabled, otherwise a tag mismatch error
 * listing the tags of the complete sum. *)
and project remaining i sum ext =
    case remaining
     of (j,tau)::rest =>
        if j = i then SOME(tau) else project rest i sum ext
      | nil =>
        if !Flags.subtyping
        then NONE
        else tag_mismatch i (List.map (fn (tg,_) => tg) sum) ext

(* check_sum env tpctx ctx (i,e) remaining sum ext
 * checks e against the type stored under tag i in remaining; if i
 * does not occur, reports a tag mismatch listing the tags of the
 * complete sum *)
and check_sum env tpctx ctx (i,e) remaining sum ext =
    case remaining
     of nil => tag_mismatch i (List.map (fn (tg,_) => tg) sum) ext
      | (j,tau)::rest =>
        if i <> j
        then check_sum env tpctx ctx (i,e) rest sum ext
        else check env tpctx ctx e tau ext

(* synth env tpctx ctx e ext
 * synthesizes and returns the type of e; parameters as for check.
 * Unannotated Lam and Fix, as well as Inject, Record, Case, Fold and
 * Pack, do not synthesize and are reported via cannot_synth.
 * For elimination forms the type of the subject is synthesized and
 * exposed; a subject of the wrong shape is a type mismatch located at
 * the subject. *)
and synth env tpctx ctx exp ext =
    case exp
     of A.Var(x) => synth_var ctx x ext
      | A.Lam(x,NONE,body) => cannot_synth body ext
      | A.Lam(x,SOME(tau),body) =>
        A.Arrow(tau, synth env tpctx ((x,tau)::ctx) body ext)
      | A.App(e1,e2) =>
        let val tau = synth env tpctx ctx e1 ext
        in case A.expose_tp env tau
            of A.Arrow(dom,cod) => ( check env tpctx ctx e2 dom ext ; cod )
             | _ => type_mismatch e1 (PP.pp_tp tau) "<tp> -> <tp>" (locate e1 ext)
        end

      | A.TpLam(a,body) =>
        (* rename a to the name A.fresh returns, unless it is unchanged or a is "_" *)
        let val a' = A.fresh tpctx a
            val body' = if a <> a' andalso a <> "_"
                        then A.subst_tp_exp (A.TpVar(a')) a body
                        else body
        in A.Forall(a', synth env (a'::tpctx) ctx body' ext) end
      | A.TpApp(e,tau) =>
        let val sigma = synth env tpctx ctx e ext
        in case A.expose_tp env sigma
            of A.Forall(a,sigma') => A.subst_tp_tp tau a sigma'
             | _ => type_mismatch e (PP.pp_tp sigma) "!<id>. <tp>" (locate e ext)
        end

      | A.Pair(e1,e2) =>
        A.Times(synth env tpctx ctx e1 ext, synth env tpctx ctx e2 ext)
      | A.Unit => A.One
      | A.Inject _ => cannot_synth exp ext
      | A.Record _ => cannot_synth exp ext
      | A.Project(e',i) =>
        let val tau = synth env tpctx ctx e' ext
        in case A.expose_tp env tau
            of A.With(prod) => synth_project prod i tau ext
             | _ => type_mismatch e' (PP.pp_tp tau) "<tp> & <tp>" (locate e' ext)
        end
      | A.Case _ => cannot_synth exp ext (* for simplicity *)
      | A.Fold _ => cannot_synth exp ext
      | A.Unfold(e') =>
        (* the result is the one-step unfolding of the recursive type *)
        let val tau = synth env tpctx ctx e' ext
        in case A.expose_tp env tau
            of A.Rho(a,tau') => A.subst_tp_tp tau a tau'
             | _ => type_mismatch e' (PP.pp_tp tau) "$<id>. <tp>" (locate e' ext)
        end
      | A.Pack _ => cannot_synth exp ext
      | A.Fix(g,NONE,e') => cannot_synth exp ext
      | A.Fix(g,SOME(tau),e') =>
        ( check env tpctx ((g,tau)::ctx) e' tau ext ; tau )

      | A.Def(x) => synth_def env x ext
      | A.Marked(marked_e) =>
        synth env tpctx ctx (Mark.data marked_e) (Mark.ext marked_e)

(* synth_project fields i tau ext
 * the type stored under the first occurrence of tag i in fields; if i
 * does not occur, an error that mentions the complete record type tau *)
and synth_project fields i tau ext =
    case fields
     of (j,sigma)::rest =>
        if i = j then sigma else synth_project rest i tau ext
      | nil =>
        ERROR ext ("tag " ^ PP.pp_tag i ^ " not present in type " ^ PP.pp_tp tau)

(* exported entry points: the internal check and synth, started with
 * an empty type-variable context and an empty variable context *)
val check = fn env => check env nil nil
val synth = fn env => synth env nil nil

end  (* structure TypeCheck *)
