(* Abstract Syntax *)
(* Frank Pfenning <fp@cs.cmu.edu> *)

signature AST =
sig

(* Names: term variables, type variables / type definition names and
 * alternative tags are all plain strings. *)
type varname = string           (* x, for variable names *)
type tpvarname = string         (* a, for type variables and type definition names *)
type tag = string               (* i, printed as 'i *)
type ext = Mark.ext option      (* optional extent (source region info) *)
type bound = int option         (* optional bound in norm and eval declarations *)

(* Kinds *)
datatype kind =
         KType
       | KArrow of kind * kind

(* Polymorphic Types *)
(* TpVar a is a type variable; TpDef a refers to a type definition
 * `type a = tau` in the environment (see expand_tp). *)
datatype tp =
         Arrow of tp * tp            (* tau1 -> tau2 *)
       | Forall of tpvarname * tp    (* !a. tau *)
       | TpVar of tpvarname          (* a *)
       | Times of tp * tp            (* tau1 * tau2 *)
       | One                         (* 1 *)
       | Plus of (tag * tp) list     (* ('i1 : tau1) + ... + ('in : taun) *)
       | With of (tag * tp) list     (* ('i1 : tau1) & ... & ('in : taun) *)
       | Exists of tpvarname * tp    (* ?a. tau *)
       | Rho of tpvarname * tp       (* $a. tau *)
       | TLam of tpvarname * tp      (* \a. tau *)
       | TApp of tp * tp             (* tau1 tau2 *)
       | TpDef of tpvarname          (* a *)

(* Typed Lambda Expressions *)
(* Var x is a term variable; Def x presumably refers to a top-level
 * `defn x = e` (cf. lookup_var / expand_exp). *)
datatype exp =
         Var of varname                     (* x *)
       | Lam of varname * tp option * exp   (* \x.e or \x:tau.e *)
       | App of exp * exp                   (* e1 e2 *)
       | TpLam of tpvarname * exp           (* /\a .e *)
       | TpApp of exp * tp                  (* e [tau] *)
       | Pair of exp * exp                  (* (e1,e2) *)
       | Unit                               (* ( ) *)
       | Inject of tag * exp                (* 'i e *)
       | Record of (tag * exp) list         (* (| 'i1 => e1 | ... |) *)
       | Project of exp * tag               (* e.'i *)
       | Case of exp * (pat * exp) list     (* case e (bs) *)
       | Fold of exp                        (* fold e *)
       | Unfold of exp                      (* unfold e == case e of (fold x => x) *)
       | Pack of tp * exp                   (* ([tau],e) *)
       | Fix of varname * tp option * exp   (* $x.e or $x:tau.e *)
       | Def of varname                     (* x *)
       | Marked of exp Mark.marked
     (* patterns bind term variables; PackPat also binds a type variable *)
     and pat =
         VarPat of varname                  (* x *)
       | PairPat of pat * pat               (* (p1,p2) *)
       | UnitPat                            (* 1 *)
       | InjectPat of tag * pat             (* 'i p *)
       | FoldPat of pat                     (* fold p *)
       | PackPat of tpvarname * pat         (* ([a],p) *)

(* Declarations *)
datatype dec =
         Type of tpvarname * tp * ext          (* type a = tau *)
       | Decl of varname * tp * ext            (* decl x : tau *)
       | Defn of varname * exp * ext           (* defn x = e *)
       | Norm of varname * exp * bound * ext   (* norm [bound] x = e *)
       | Conv of exp * exp * ext               (* conv e1 = e2 *)
       | Eval of varname * exp * bound * ext   (* eval [bound] x = e *)
       | Not of dec * ext                      (* fail <dec> *)
       | Pragma of string * string * ext       (* #options, #test *)

(* environments are searched front to back by the lookup functions *)
type env = dec list

(* [(e1,x1),...,(en,xn)]: simultaneous substitution of each ei for xi *)
type subst = (exp * varname) list

(* fresh tpctx a: a, or a with primes appended, so that it is not in tpctx *)
val fresh : tpvarname list -> tpvarname -> tpvarname

(* Capture-avoiding substitutions:
 *   subst_exp theta e        applies theta to e
 *   subst_exp_exp e x f      is [e/x]f
 *   subst_tp_exp tau a e     is [tau/a]e
 *   subst_tp_tp tau a sigma  is [tau/a]sigma *)
val subst_exp : subst -> exp -> exp
val subst_exp_exp : exp -> varname -> exp -> exp
val subst_tp_exp : tp -> tpvarname -> exp -> exp
val subst_tp_tp : tp -> tpvarname -> tp -> tp

(* lookups return the first matching declaration in env, or NONE *)
val lookup_var : env -> varname -> exp option
val lookup_tpvar : env -> tpvarname -> tp option
val lookup_var_tp : env -> varname -> tp option

val defined_tpvar : env -> tpvarname -> bool
val defined_var : env -> varname -> bool
val declared_var : env -> varname -> bool

(* the expand functions raise Match when the name is not defined *)
val expand_tp : env -> tpvarname -> tp  (* if known to be defined *)
val expand_tapp : env -> tp -> tp       (* expand d tau1 ... taun *)
val expose_tp : env -> tp -> tp         (* expose type constructor; must be known to be defined *)
val expand_exp : env -> varname -> exp  (* if known to be defined *)

(* free type variables / free (type variables, term variables) /
 * variables bound by a pattern *)
val free_tpvars : env -> tp -> tpvarname list
val free_vars : env -> exp -> (tpvarname list * varname list)
val free_vars_pat : env -> pat -> varname list

(* rename bound variables that clash with definitions; stdize_exp also
 * removes marks *)
val stdize_tp : tp -> tp
val stdize_exp : exp -> exp

(* Printing *)

(* string renderings of the abstract syntax *)
structure Print :
sig
    val pp_tag : tag -> string
    val pp_tags : tag list -> string
    val pp_tp  : tp -> string
    val pp_pat : pat -> string
    val pp_exp : exp -> string
    val pp_dec : dec -> string
    val pp_env : env -> string
    val pp_valdecs : dec list -> string
end

end (* signature AST *)

structure Ast :> AST =
struct

(* Names: all plain strings (see signature AST). *)
type varname = string
type tpvarname = string
type tag = string
type ext = Mark.ext option      (* optional source extent *)
type bound = int option         (* optional bound for norm / eval *)

(* Kinds *)
datatype kind =
         KType
       | KArrow of kind * kind

(* Polymorphic Types *)
(* TpVar a is a type variable, TpDef a a reference to `type a = tau` *)
datatype tp =
         Arrow of tp * tp            (* tau1 -> tau2 *)
       | Forall of tpvarname * tp    (* !a. tau *)
       | TpVar of tpvarname          (* a *)
       | Times of tp * tp            (* tau1 * tau2 *)
       | One                         (* 1 *)
       | Plus of (tag * tp) list     (* ('i1 : tau1) + ... + ('in : taun) *)
       | With of (tag * tp) list     (* ('i1 : tau1) & ... & ('in : taun) *)
       | Exists of tpvarname * tp    (* ?a. tau *)
       | Rho of tpvarname * tp       (* $a. tau *)
       | TLam of tpvarname * tp      (* \a. tau *)
       | TApp of tp * tp             (* tau1 tau2 *)
       | TpDef of tpvarname          (* a *)

(* Typed Lambda Expressions *)
datatype exp =
         Var of varname                     (* x *)
       | Lam of varname * tp option * exp   (* \x.e or \x:tau.e *)
       | App of exp * exp                   (* e1 e2 *)
       | TpLam of tpvarname * exp           (* /\a .e *)
       | TpApp of exp * tp                  (* e [tau] *)
       | Pair of exp * exp                  (* (e1,e2) *)
       | Unit                               (* ( ) *)
       | Inject of tag * exp                (* 'i e *)
       | Record of (tag * exp) list         (* (| 'i1 => e1 | ... |) *)
       | Project of exp * tag               (* e.'i *)
       | Case of exp * (pat * exp) list     (* case e (bs) *)
       | Fold of exp                        (* fold e *)
       | Unfold of exp                      (* unfold e == case e of (fold x => x) *)
       | Pack of tp * exp                   (* ([tau],e) *)
       | Fix of varname * tp option * exp   (* $x.e or $x:tau.e *)
       | Def of varname                     (* x *)
       | Marked of exp Mark.marked
     (* PackPat additionally binds a type variable in the branch body *)
     and pat =
         VarPat of varname                  (* x *)
       | PairPat of pat * pat               (* (p1,p2) *)
       | UnitPat                            (* 1 *)
       | InjectPat of tag * pat             (* 'i p *)
       | FoldPat of pat                     (* fold p *)
       | PackPat of tpvarname * pat         (* ([a],p) *)

(* Declarations *)
datatype dec =
         Type of tpvarname * tp * ext          (* type a = tau *)
       | Decl of varname * tp * ext            (* decl x : tau *)
       | Defn of varname * exp * ext           (* defn x = e *)
       | Norm of varname * exp * bound * ext   (* norm [bound] x = e *)
       | Conv of exp * exp * ext               (* conv e1 = e2 *)
       | Eval of varname * exp * bound * ext   (* eval [bound] x = e *)
       | Not of dec * ext                      (* fail <dec> *)
       | Pragma of string * string * ext       (* #options, #test *)

type env = dec list                  (* searched front to back by lookups *)
type subst = (exp * varname) list    (* (e,x): replace x by e, simultaneously *)

(****************)
(* Substitution *)
(****************)

(* free_in_tp a tau: does the type variable a occur free in tau?
 * Only TpVar occurrences count; TpDef b names a definition, not a variable.
 * The anonymous variable "_" is never considered free. *)
fun free_in_tp "_" _ = false
  | free_in_tp a tau =
    let val occurs = free_in_tp a
        fun in_alts alts = List.exists (fn (_,sigma) => occurs sigma) alts
    in case tau
        of TpVar(b) => a = b
         | TpDef(_) => false
         | One => false
         | Arrow(t1,t2) => occurs t1 orelse occurs t2
         | Times(t1,t2) => occurs t1 orelse occurs t2
         | TApp(t1,t2) => occurs t1 orelse occurs t2
         | Plus(alts) => in_alts alts
         | With(alts) => in_alts alts
         | Forall(b,t) => free_in_tp_bind a (b,t)
         | Exists(b,t) => free_in_tp_bind a (b,t)
         | Rho(b,t) => free_in_tp_bind a (b,t)
         | TLam(b,t) => free_in_tp_bind a (b,t)
    end

(* free_in_tp_bind a (b,tau): is a free in tau underneath the binder b? *)
and free_in_tp_bind a (b,tau) = a <> b andalso free_in_tp a tau

(* def_in_tp a tau: does tau mention the type definition a, i.e. contain
 * TpDef(a)?  Binders are not consulted: bound type variables are TpVar,
 * so they never hide a TpDef.  The anonymous name "_" is never defined. *)
fun def_in_tp "_" _ = false
  | def_in_tp a tau =
    let val mentions = def_in_tp a
    in case tau
        of TpDef(b) => a = b
         | TpVar(_) => false
         | One => false
         | Arrow(t1,t2) => mentions t1 orelse mentions t2
         | Times(t1,t2) => mentions t1 orelse mentions t2
         | TApp(t1,t2) => mentions t1 orelse mentions t2
         | Plus(alts) => List.exists (fn (_,t) => mentions t) alts
         | With(alts) => List.exists (fn (_,t) => mentions t) alts
         | Forall(_,t) => mentions t
         | Exists(_,t) => mentions t
         | Rho(_,t) => mentions t
         | TLam(_,t) => mentions t
    end

(* fresh tpctx a: a itself when it does not occur in tpctx, otherwise the
 * first of a', a'', ... that does not occur in tpctx. *)
fun fresh tpctx a =
    if List.exists (fn b => b = a) tpctx
    then fresh tpctx (a ^ "'")
    else a

(* subst_tp_tp tau a sigma = [tau/a]sigma: replace the free type variable a
 * by tau in sigma, renaming binders of sigma that would capture a type
 * variable free in tau.  Substituting for "_" is the identity; TpDef
 * references are never substituted. *)
fun subst_tp_tp tau "_" sigma = sigma
  | subst_tp_tp tau a (Arrow(sigma1,sigma2)) = Arrow(subst_tp_tp tau a sigma1, subst_tp_tp tau a sigma2)
  | subst_tp_tp tau a (Forall(b,sigma)) = Forall (subst_tp_tp_bind tau a (b,sigma))
  | subst_tp_tp tau a (sigma as TpVar(b)) = if a = b then tau else sigma
  | subst_tp_tp tau a (Times(sigma1,sigma2)) = Times(subst_tp_tp tau a sigma1, subst_tp_tp tau a sigma2)
  | subst_tp_tp tau a (sigma as One) = sigma
  | subst_tp_tp tau a (Plus(sums)) =
      Plus(List.map (fn (i,sigma) => (i,subst_tp_tp tau a sigma)) sums)
  | subst_tp_tp tau a (With(prod)) =
      With(List.map (fn (i,sigma) => (i,subst_tp_tp tau a sigma)) prod)
  | subst_tp_tp tau a (Exists(b,sigma)) = Exists (subst_tp_tp_bind tau a (b,sigma))
  | subst_tp_tp tau a (Rho(b,sigma)) = Rho (subst_tp_tp_bind tau a (b,sigma))
  | subst_tp_tp tau a (TLam(b,sigma)) = TLam (subst_tp_tp_bind tau a (b,sigma))
  | subst_tp_tp tau a (TApp(sigma1,sigma2))= TApp(subst_tp_tp tau a sigma1, subst_tp_tp tau a sigma2)
  | subst_tp_tp tau a (TpDef(b)) = TpDef(b)
(* subst_tp_tp_bind tau a (b,sigma): substitute under the binder b.
 * - a not free under b (including a = b): leave the binding untouched;
 * - b free in tau: rename b to a primed name that is free neither in tau
 *   nor under the binder, then substitute into the renamed body;
 * - otherwise substitute directly. *)
and subst_tp_tp_bind tau a (bSigma as (b,sigma)) =
    if not (free_in_tp_bind a bSigma) then bSigma
    else if free_in_tp b tau
    then let val (b',sigma') = rename_apart_tp_tp (fn c => not (free_in_tp c tau)
                                                           andalso not (free_in_tp_bind c bSigma))
                                                  b bSigma
         in (b', subst_tp_tp tau a sigma') end
    else (b, subst_tp_tp tau a sigma)

(* rename_apart_tp_tp pred c (b,sigma): try c', c'', ... until pred accepts
 * a candidate, then return it with b renamed to it in sigma. *)
and rename_apart_tp_tp pred c (b,sigma) = (* c is failed candidate for new name of b *)
    let val c' = c ^ "'"
    in if pred c'
       then (c', subst_tp_tp (TpVar(c')) b sigma)
       else rename_apart_tp_tp pred c' (b,sigma)
    end

(* free_in x e: does the term variable x occur free in e?
 * The wildcard "_" is never free.  Def(y) names a definition, not a
 * variable, so it never counts as an occurrence. *)
fun free_in "_" _ = false
  | free_in x e =
    let val occ = free_in x
    in case e
        of Var(y) => x = y
         | Lam(y,_,body) => x <> y andalso occ body
         | Fix(y,_,body) => x <> y andalso occ body
         | App(e1,e2) => occ e1 orelse occ e2
         | Pair(e1,e2) => occ e1 orelse occ e2
         | TpLam(_,body) => occ body
         | TpApp(body,_) => occ body
         | Inject(_,body) => occ body
         | Project(body,_) => occ body
         | Fold(body) => occ body
         | Unfold(body) => occ body
         | Pack(_,body) => occ body
         | Unit => false
         | Def(_) => false
         | Record(fields) => List.exists (fn (_,ei) => occ ei) fields
         | Case(scrut,branches) =>
             occ scrut orelse List.exists (free_in_branch x) branches
         | Marked(marked_e) => occ (Mark.data marked_e)
    end
(* free_in_branch x (p,e): is x free in the branch p => e (not bound by p)? *)
and free_in_branch x (p,e) = not (free_in_pat x p) andalso free_in x e
(* free_in_pat x p: does the pattern p bind the variable x? *)
and free_in_pat "_" _ = false
  | free_in_pat x p =
    (case p
      of VarPat(y) => x = y
       | PairPat(p1,p2) => free_in_pat x p1 orelse free_in_pat x p2
       | UnitPat => false
       | InjectPat(_,q) => free_in_pat x q
       | FoldPat(q) => free_in_pat x q
       | PackPat(_,q) => free_in_pat x q)
(* free_in_bind x (y,e): is x free in e underneath the binder y? *)
and free_in_bind x (y,e) = x <> y andalso free_in x e
(* free_in_subst x theta: is x free in some expression in the range of theta? *)
and free_in_subst x theta = List.exists (fn (e,_) => free_in x e) theta

(* free_in_tp_exp a e: does the type variable a occur free in e?
 * Type variables are bound by TpLam(b,_) and by pack patterns ([b],p) in
 * case branches (cf. frees_pat_tp).  Optional annotations on Lam and Fix
 * are searched as well. *)
fun free_in_tp_exp a (Var _) = false
  | free_in_tp_exp a (Lam(y,NONE,e)) = free_in_tp_exp a e
  | free_in_tp_exp a (Lam(y,SOME(tau),e)) = free_in_tp a tau orelse free_in_tp_exp a e
  | free_in_tp_exp a (App(e1,e2)) = free_in_tp_exp a e1 orelse free_in_tp_exp a e2
  | free_in_tp_exp a (TpLam(b,e)) = a <> b andalso free_in_tp_exp a e
  | free_in_tp_exp a (TpApp(e,tau)) = free_in_tp_exp a e orelse free_in_tp a tau
  | free_in_tp_exp a (Pair(e1,e2)) = free_in_tp_exp a e1 orelse free_in_tp_exp a e2
  | free_in_tp_exp a (Unit) = false
  | free_in_tp_exp a (Inject(i,e)) = free_in_tp_exp a e
  | free_in_tp_exp a (Record(prod)) =
       List.exists (fn (i,e) => free_in_tp_exp a e) prod
  | free_in_tp_exp a (Project(e,i)) = free_in_tp_exp a e
  | free_in_tp_exp a (Case(e,branches)) =
       free_in_tp_exp a e orelse List.exists (free_in_tp_branch a) branches
  | free_in_tp_exp a (Fold(e)) = free_in_tp_exp a e
  | free_in_tp_exp a (Unfold(e)) = free_in_tp_exp a e
  | free_in_tp_exp a (Pack(tau,e)) = free_in_tp a tau orelse free_in_tp_exp a e
  | free_in_tp_exp a (Fix(y,NONE,e)) = free_in_tp_exp a e
  | free_in_tp_exp a (Fix(y,SOME(tau),e)) = free_in_tp a tau orelse free_in_tp_exp a e
  | free_in_tp_exp a (Def _) = false
  | free_in_tp_exp a (Marked(marked_e)) = free_in_tp_exp a (Mark.data marked_e)
(* free_in_tp_branch a (p,e): a is free in the branch unless the pattern
 * binds it; a pack pattern ([b],p') binds b in the body e *)
and free_in_tp_branch a (p,e) =
    not (pat_binds_tpvar a p) andalso free_in_tp_exp a e
(* pat_binds_tpvar a p: does pattern p bind the type variable a?
 * Only pack patterns bind type variables. *)
and pat_binds_tpvar a (VarPat _) = false
  | pat_binds_tpvar a (PairPat(p1,p2)) = pat_binds_tpvar a p1 orelse pat_binds_tpvar a p2
  | pat_binds_tpvar a (UnitPat) = false
  | pat_binds_tpvar a (InjectPat(_,p)) = pat_binds_tpvar a p
  | pat_binds_tpvar a (FoldPat(p)) = pat_binds_tpvar a p
  | pat_binds_tpvar a (PackPat(b,p)) = (a = b) orelse pat_binds_tpvar a p

(* def_in_tp_exp a e: does e mention the type definition a (TpDef a) in a
 * type annotation, a type application or a package type?  Type binders are
 * not consulted: bound type variables are TpVar, so they never hide a
 * TpDef, and likewise pattern-bound type variables are ignored. *)
fun def_in_tp_exp a e =
    let val here = def_in_tp_exp a
        val in_tp = def_in_tp a
        fun in_ann NONE = false
          | in_ann (SOME tau) = in_tp tau
    in case e
        of Var _ => false
         | Def _ => false
         | Unit => false
         | Lam(_,tau_,body) => in_ann tau_ orelse here body
         | Fix(_,tau_,body) => in_ann tau_ orelse here body
         | App(e1,e2) => here e1 orelse here e2
         | Pair(e1,e2) => here e1 orelse here e2
         | TpLam(_,body) => here body
         | TpApp(body,tau) => here body orelse in_tp tau
         | Pack(tau,body) => in_tp tau orelse here body
         | Inject(_,body) => here body
         | Project(body,_) => here body
         | Fold(body) => here body
         | Unfold(body) => here body
         | Record(fields) => List.exists (fn (_,ei) => here ei) fields
         | Case(scrut,branches) =>
             here scrut orelse List.exists (def_in_tp_branch a) branches
         | Marked(marked_e) => here (Mark.data marked_e)
    end
(* def_in_tp_branch a (p,e): only the branch body can mention a definition *)
and def_in_tp_branch a (_,e) = def_in_tp_exp a e

(* tpvars_bound_by_pat p: the type variables bound by pattern p, left to
 * right.  Only pack patterns ([b],p') bind type variables. *)
fun tpvars_bound_by_pat (VarPat _) = nil
  | tpvars_bound_by_pat (PairPat(p1,p2)) = tpvars_bound_by_pat p1 @ tpvars_bound_by_pat p2
  | tpvars_bound_by_pat (UnitPat) = nil
  | tpvars_bound_by_pat (InjectPat(_,p)) = tpvars_bound_by_pat p
  | tpvars_bound_by_pat (FoldPat(p)) = tpvars_bound_by_pat p
  | tpvars_bound_by_pat (PackPat(b,p)) = b :: tpvars_bound_by_pat p

(* rename_tpvar_pat b b' p: replace the pack binder b by b' throughout p *)
fun rename_tpvar_pat b b' (p as VarPat _) = p
  | rename_tpvar_pat b b' (PairPat(p1,p2)) =
       PairPat(rename_tpvar_pat b b' p1, rename_tpvar_pat b b' p2)
  | rename_tpvar_pat b b' (p as UnitPat) = p
  | rename_tpvar_pat b b' (InjectPat(i,p)) = InjectPat(i, rename_tpvar_pat b b' p)
  | rename_tpvar_pat b b' (FoldPat(p)) = FoldPat(rename_tpvar_pat b b' p)
  | rename_tpvar_pat b b' (PackPat(c,p)) =
       PackPat(if c = b then b' else c, rename_tpvar_pat b b' p)

(* subst_tp_exp tau a e = [tau/a]e: replace the free type variable a by tau
 * in every type inside e, renaming type binders (TpLam and pack patterns)
 * that would capture a type variable free in tau.  Marks are dropped. *)
fun subst_tp_exp tau a (e as Var _) = e
  | subst_tp_exp tau a (Lam(y,sigma_,e)) =
       Lam(y, Option.map (subst_tp_tp tau a) sigma_, subst_tp_exp tau a e)
  | subst_tp_exp tau a (App(e1,e2)) =
       App(subst_tp_exp tau a e1, subst_tp_exp tau a e2)
  | subst_tp_exp tau a (e as TpLam(b,e')) =
    if not (free_in_tp_exp a e) then e
    else if free_in_tp b tau
    then let val (b',e'') = rename_apart_tp_exp (fn c => not (free_in_tp c tau)
                                                         andalso not (free_in_tp_exp c e'))
                                                b (b,e')
         (* substitute into the renamed body e'', not the original e' *)
         in TpLam(b', subst_tp_exp tau a e'') end
    else TpLam(b, subst_tp_exp tau a e')
  | subst_tp_exp tau a (TpApp(e,sigma)) =
    TpApp(subst_tp_exp tau a e, subst_tp_tp tau a sigma)
  | subst_tp_exp tau a (Pair(e1,e2)) =
    Pair(subst_tp_exp tau a e1, subst_tp_exp tau a e2)
  | subst_tp_exp tau a (e as Unit) = e
  | subst_tp_exp tau a (Inject(i,e)) = Inject(i,subst_tp_exp tau a e)
  | subst_tp_exp tau a (Record(prod)) =
       Record(List.map (fn (i,e) => (i, subst_tp_exp tau a e)) prod)
  | subst_tp_exp tau a (Project(e,i)) =
       Project(subst_tp_exp tau a e,i)
  | subst_tp_exp tau a (Case(e,branches)) =
       Case(subst_tp_exp tau a e, List.map (subst_tp_branch tau a) branches)
  | subst_tp_exp tau a (Fold(e)) = Fold(subst_tp_exp tau a e)
  | subst_tp_exp tau a (Unfold(e)) = Unfold(subst_tp_exp tau a e)
  | subst_tp_exp tau a (Pack(sigma,e)) = Pack(subst_tp_tp tau a sigma, subst_tp_exp tau a e)
  | subst_tp_exp tau a (Fix(y,sigma_,e)) =
       Fix(y, Option.map (subst_tp_tp tau a) sigma_, subst_tp_exp tau a e)
  | subst_tp_exp tau a (e as Def _) = e
  | subst_tp_exp tau a (Marked(marked_e)) = subst_tp_exp tau a (Mark.data marked_e)

(* subst_tp_branch tau a (p,e): substitute into a case branch.  Pack
 * patterns bind type variables in e: if p binds a the branch is left
 * alone; binders of p occurring free in tau are renamed first so that tau
 * is not captured. *)
and subst_tp_branch tau a (p,e) =
    let val bound = tpvars_bound_by_pat p
        fun member c cs = List.exists (fn d => d = c) cs
        (* rename every binder that occurs free in tau; taken records the
         * binder names already in use so renamed binders stay distinct *)
        fun avoid (nil, _, branch) = branch
          | avoid (b::bs, taken, branch as (q,f)) =
            if free_in_tp b tau
            then let val (b',f') =
                         rename_apart_tp_exp (fn c => not (free_in_tp c tau)
                                                      andalso not (free_in_tp_exp c f)
                                                      andalso not (member c taken))
                                             b (b,f)
                 in avoid (bs, b'::taken, (rename_tpvar_pat b b' q, f')) end
            else avoid (bs, taken, branch)
    in if member a bound orelse not (free_in_tp_exp a e)
       then (p,e)  (* a is shadowed by the pattern, or does not occur *)
       else let val (p',e') = avoid (bound, bound, (p,e))
            in (p', subst_tp_exp tau a e') end
    end

(* rename_apart_tp_exp pred c (b,e): try c', c'', ... until pred accepts a
 * candidate, then return it with the type variable b renamed to it in e *)
and rename_apart_tp_exp pred c (b,e) = (* c is failed candidate for new name of b *)
    let val c' = c ^ "'"
    in if pred c'
       then (c', subst_tp_exp (TpVar(c')) b e)
       else rename_apart_tp_exp pred c' (b,e)
    end

(* def_in x e: does e refer to the definition x, i.e. contain Def(x)?
 * Term binders are not consulted: bound variables are Var, so they never
 * hide a Def. *)
fun def_in x e =
    let val here = def_in x
    in case e
        of Def(y) => x = y
         | Var _ => false
         | Unit => false
         | Lam(_,_,body) => here body
         | Fix(_,_,body) => here body
         | TpLam(_,body) => here body
         | TpApp(body,_) => here body
         | Inject(_,body) => here body
         | Project(body,_) => here body
         | Fold(body) => here body
         | Unfold(body) => here body
         | Pack(_,body) => here body
         | App(e1,e2) => here e1 orelse here e2
         | Pair(e1,e2) => here e1 orelse here e2
         | Record(fields) => List.exists (fn (_,ei) => here ei) fields
         | Case(scrut,branches) =>
             here scrut orelse List.exists (def_in_branch x) branches
         | Marked(marked_e) => here (Mark.data marked_e)
    end
(* def_in_branch x (p,e): only the branch body can refer to a definition *)
and def_in_branch x (_,e) = def_in x e

(* subst_var_var theta y: the variable name that y is renamed to by theta.
 * Only variable-for-variable entries (Var z, x) are consulted, first match
 * wins; y itself is returned when theta has no such entry for y. *)
fun subst_var_var theta y =
    case List.find (fn (Var _, x) => x = y | _ => false) theta
     of SOME (Var z, _) => z
      | _ => y

(* subst_exp_var theta y: the expression theta assigns to y (first entry
 * wins), or Var(y) when y is not in the domain of theta. *)
fun subst_exp_var theta y =
    case List.find (fn (_,x) => x = y) theta
     of SOME (e,_) => e
      | NONE => Var(y)

(* subst_exp_patvar theta y: the pattern variable that y becomes under theta.
 * Only variable-for-variable entries (Var z, x) can rename a pattern
 * variable; other entries are skipped, as in subst_var_var, instead of
 * failing with Match.  Identity when y is not renamed by theta. *)
fun subst_exp_patvar ((Var(z),x)::theta) y =
    if x = y then VarPat(z) else subst_exp_patvar theta y
  | subst_exp_patvar (_::theta) y = subst_exp_patvar theta y
  | subst_exp_patvar (nil) y = VarPat(y) (* identity if not explicitly in substitution *)

(* subst_exp_exp theta e: apply the simultaneous substitution theta to e.
 * Term binders (Lam, Fix, case-branch patterns) are renamed when they would
 * capture a variable free in the range of theta.  Var "_" is never
 * substituted for.  Marks are dropped.
 * NOTE(review): type binders (TpLam, pack patterns) are not renamed here,
 * so a type variable free in theta's range could be captured; confirm
 * callers only substitute expressions without free type variables. *)
fun subst_exp_exp theta (e as Var("_")) = e
  | subst_exp_exp theta (Var(y)) = subst_exp_var theta y
  | subst_exp_exp theta (Lam(y,tau_,f)) =
    let val (y',f') = subst_exp_exp_bind theta (y,f)
    in Lam(y',tau_,f') end
  | subst_exp_exp theta (App(e1,e2)) = App(subst_exp_exp theta e1, subst_exp_exp theta e2)
  | subst_exp_exp theta (TpLam(a,e')) = TpLam(a, subst_exp_exp theta e')
  | subst_exp_exp theta (TpApp(e',tau)) = TpApp(subst_exp_exp theta e', tau)
  | subst_exp_exp theta (Pair(e1,e2)) = Pair(subst_exp_exp theta e1, subst_exp_exp theta e2)
  | subst_exp_exp theta (e' as Unit) = e'
  | subst_exp_exp theta (Inject(i,e')) = Inject(i,subst_exp_exp theta e')
  | subst_exp_exp theta (Record(prod)) =
       Record(List.map (fn (i,e) => (i,subst_exp_exp theta e)) prod)
  | subst_exp_exp theta (Project(e,i)) = Project(subst_exp_exp theta e, i)
  | subst_exp_exp theta (Case(e',branches)) =
       Case(subst_exp_exp theta e', List.map (subst_exp_branch theta) branches)
  | subst_exp_exp theta (Fold(e')) = Fold(subst_exp_exp theta e')
  | subst_exp_exp theta (Unfold(e')) = Unfold(subst_exp_exp theta e')
  | subst_exp_exp theta (Pack(tau,e')) = Pack(tau, subst_exp_exp theta e')
  | subst_exp_exp theta (Fix(y,tau_,f)) =
    let val (y',f') = subst_exp_exp_bind theta (y,f)
    in Fix(y',tau_,f') end
  | subst_exp_exp theta (e' as Def _) = e'
  | subst_exp_exp theta (Marked(marked_e')) = subst_exp_exp theta (Mark.data marked_e')

(* subst_exp_branch theta (p,f): extend theta with a (possibly renamed)
 * entry for every variable bound by p, then substitute in p and f *)
and subst_exp_branch theta (p,f) =
    let val theta' = fresh_subst (fn z => not (free_in_branch z (p,f))) theta p
    in (subst_exp_pat theta' p, subst_exp_exp theta' f) end

(* subst_exp_exp_bind theta (y,f): substitute under the single binder y,
 * returning the (possibly renamed) binder and the new body *)
and subst_exp_exp_bind theta (y,f) =
    let val theta' = fresh_subst (fn z => not (free_in_bind z (y,f))) theta (VarPat(y))
    in (subst_var_var theta' y, subst_exp_exp theta' f) end

(* subst_exp_pat theta p: rename the variables of p according to theta *)
and subst_exp_pat theta (p as VarPat("_")) = p
  | subst_exp_pat theta (VarPat(x)) = subst_exp_patvar theta x
  | subst_exp_pat theta (PairPat(p1,p2)) =
      PairPat(subst_exp_pat theta p1, subst_exp_pat theta p2)
  | subst_exp_pat theta (UnitPat) = UnitPat
  | subst_exp_pat theta (InjectPat(i,p)) = InjectPat(i,subst_exp_pat theta p)
  | subst_exp_pat theta (FoldPat(p)) = FoldPat(subst_exp_pat theta p)
  | subst_exp_pat theta (PackPat(a,p)) = PackPat(a,subst_exp_pat theta p)

(* fresh_subst pred theta p: for each term variable y bound by p (left to
 * right, "_" skipped), prepend (Var y', y) to theta, where y' is chosen by
 * fresh_var.  The new entry shadows any outer entry for y. *)
and fresh_subst pred theta (VarPat("_")) = theta
  | fresh_subst pred theta (VarPat(y)) = ((fresh_var pred theta y,y)::theta)
  | fresh_subst pred theta (PairPat(p1,p2)) =
       fresh_subst pred (fresh_subst pred theta p1) p2
  | fresh_subst pred theta (UnitPat) = theta
  | fresh_subst pred theta (InjectPat(i,p)) = fresh_subst pred theta p
  | fresh_subst pred theta (FoldPat(p)) = fresh_subst pred theta p
  | fresh_subst pred theta (PackPat(a,p)) = fresh_subst pred theta p

(* fresh_var pred theta y: Var of the first name among y, y', y'', ... that
 * satisfies pred and occurs free in no expression in the range of theta *)
and fresh_var pred theta y =
    if pred y andalso not (free_in_subst y theta)
    then Var(y)
    else fresh_var pred theta (y ^ "'")

(* rename_apart_exp_exp pred z (y,e): try z', z'', ... until pred accepts a
 * candidate, then rename y to it in e.  Not referenced elsewhere in the
 * visible part of this file. *)
and rename_apart_exp_exp pred z (y,e) = (* z is failed candidate for new name of y *)
    let val z' = z ^ "'"
    in if pred z'
       then (z', subst_exp_exp [(Var(z'),y)] e)
       else rename_apart_exp_exp pred z' (y,e)
    end

(* stdize_tp tau: rename every bound type variable whose name coincides with
 * a type definition (TpDef) mentioned in its scope.  TpVar a and TpDef a
 * print identically, so such a binder would otherwise be ambiguous. *)
fun stdize_tp (Arrow(tau1,tau2)) = Arrow(stdize_tp tau1, stdize_tp tau2)
  | stdize_tp (Forall(a,tau)) = Forall(stdize_tp_bind(a,tau))
  | stdize_tp (Times(tau1,tau2)) = Times(stdize_tp tau1, stdize_tp tau2)
  | stdize_tp (tau as One) = tau
  | stdize_tp (Plus(sum)) = Plus(List.map (fn (i,tau) => (i, stdize_tp tau)) sum)
  | stdize_tp (With(prod)) = With(List.map (fn (i,tau) => (i, stdize_tp tau)) prod)
  | stdize_tp (Exists(a,tau)) = Exists(stdize_tp_bind(a,tau))
  | stdize_tp (Rho(a,tau)) = Rho(stdize_tp_bind(a,tau))
  | stdize_tp (TLam(a,tau)) = TLam(stdize_tp_bind(a,tau))
  | stdize_tp (TApp(tau1,tau2)) = TApp(stdize_tp tau1, stdize_tp tau2)
  | stdize_tp (tau as TpVar _) = tau
  | stdize_tp (tau as TpDef _) = tau
(* stdize_tp_bind (a,tau): standardize the binding of a over tau.  The new
 * name (a', a'', ...; never a itself) must clash neither with a definition
 * nor with a type variable free in tau, or the new binder would capture it. *)
and stdize_tp_bind (a,tau) =
    if def_in_tp a tau
    then let val (a',tau') = rename_apart_tp_tp (fn b => not (def_in_tp b tau)
                                                         andalso not (free_in_tp b tau))
                                                a (a,tau)
         in (a', stdize_tp tau') end
    else (a,stdize_tp tau)

(* stdize_exp e: rename bound term and type variables of e whose names
 * coincide with a definition (Def / TpDef) mentioned in their scope;
 * presumably so that printed output is unambiguous.  Replacement names also
 * avoid variables occurring free in the body, so nothing is captured.
 * stdize_exp removes any remaining marks. *)
fun stdize_exp (e as Var _) = e
  | stdize_exp (Lam(x,tau_,e)) =
    let val (x',e') = stdize_exp_bind (x,e)
    in Lam(x', Option.map stdize_tp tau_, e') end
  | stdize_exp (App(e1,e2)) = App (stdize_exp e1, stdize_exp e2)
  | stdize_exp (TpLam(a,e)) =
    if def_in_tp_exp a e
    then let (* candidates are a', a'', ...; each must neither name a
              * definition nor capture a type variable free in e *)
             val (a',e') = rename_apart_tp_exp (fn b => not (def_in_tp_exp b e)
                                                        andalso not (free_in_tp_exp b e))
                                               a (a,e)
         in TpLam(a', stdize_exp e') end
    else TpLam(a, stdize_exp e)
  | stdize_exp (TpApp(e,tau)) = TpApp(stdize_exp e, stdize_tp tau)
  | stdize_exp (Pair(e1,e2)) = Pair(stdize_exp e1, stdize_exp e2)
  | stdize_exp (e as Unit) = e
  | stdize_exp (Inject(i,e)) = Inject(i,stdize_exp e)
  | stdize_exp (Record(prod)) = Record(List.map (fn (i,e) => (i, stdize_exp e)) prod)
  | stdize_exp (Project(e,i)) = Project(stdize_exp e, i)
  | stdize_exp (Case(e,branches)) =
       Case(stdize_exp e, List.map stdize_branch branches)
  | stdize_exp (Fold(e)) = Fold(stdize_exp e)
  | stdize_exp (Unfold(e)) = Unfold(stdize_exp e)
  | stdize_exp (Pack(tau,e)) = Pack(stdize_tp tau, stdize_exp e)
  | stdize_exp (Fix(x,tau_,e)) =
    let val (x',e') = stdize_exp_bind (x,e)
    in Fix(x', Option.map stdize_tp tau_, e') end
  | stdize_exp (e as Def(x)) = e
  | stdize_exp (Marked(marked_e)) = stdize_exp (Mark.data marked_e)
(* stdize_exp_bind (x,e): keep x unless it names a definition used in e;
 * otherwise pick a name that is neither a definition used in e nor free
 * in e under the binder x *)
and stdize_exp_bind ("_",e) = ("_",stdize_exp e)
  | stdize_exp_bind (x,e) =
    let val theta = fresh_subst (fn z => not (def_in z e)
                                         andalso not (free_in_bind z (x,e)))
                                nil (VarPat(x))
    in (subst_var_var theta x, stdize_exp (subst_exp_exp theta e)) end
(* stdize_branch (p,e): the same, for every variable bound by pattern p *)
and stdize_branch (p,e) =
    let val theta = fresh_subst (fn z => not (def_in z e)
                                         andalso not (free_in_branch z (p,e)))
                                nil p
    in (subst_exp_pat theta p, stdize_exp (subst_exp_exp theta e)) end

(* lookup_tpvar env a: SOME tau for the first declaration `type a = tau`
 * in env (list order), NONE if env has none *)
fun lookup_tpvar env a =
    case List.find (fn Type(b,_,_) => b = a | _ => false) env
     of SOME (Type(_,tau,_)) => SOME tau
      | _ => NONE

(* expand_tp env a: the definition of type name a in env.
 * Callers must know that a is defined; otherwise Match is raised. *)
fun expand_tp env a =
    let val definition = lookup_tpvar env a
    in case definition
        of NONE => raise Match
         | SOME tau => tau
    end

(* whnf_tp env tau: one head-reduction step on a type definition or a type
 * application.  TpDef(a) is expanded once; an application tau1 tau2 is
 * reduced by substituting tau2 into the type lambda at the head of tau1.
 * Raises Match on other types, on undefined names, and when the head does
 * not reduce to a type lambda. *)
fun whnf_tp env (TpDef(a)) = expand_tp env a
  | whnf_tp env (TApp(tau1,tau2)) =
    apply_tp env (whnf_tp env tau1) tau2
(* apply_tp env tau1 tau2: beta-reduce tau1 applied to tau2.  A head that
 * is itself a definition or an application (an alias `type a = b`, or a
 * lambda whose body is an application) is reduced further until a type
 * lambda appears, instead of failing with Match. *)
and apply_tp env (TLam(a,tau1)) tau2 =
    subst_tp_tp tau2 a tau1
  | apply_tp env (tau1 as TpDef _) tau2 = apply_tp env (whnf_tp env tau1) tau2
  | apply_tp env (tau1 as TApp _) tau2 = apply_tp env (whnf_tp env tau1) tau2

(* expand_tapp env tau: one whnf_tp step on a definition or application *)
fun expand_tapp env tau = whnf_tp env tau

(* expose_tp env tau: expand definitions and applications until the outer
 * constructor is neither TpDef nor TApp *)
fun expose_tp env (TpDef(a)) = expose_tp env (expand_tp env a)
  | expose_tp env (tau as TApp _) = expose_tp env (expand_tapp env tau)
  | expose_tp env A = A

(* lookup_var_tp env x: SOME tau for the first declaration `decl x : tau`
 * in env (list order), NONE if env has none *)
fun lookup_var_tp env x =
    case List.find (fn Decl(y,_,_) => y = x | _ => false) env
     of SOME (Decl(_,tau,_)) => SOME tau
      | _ => NONE

(* lookup_var env x: SOME e for the first definition `defn x = e` in env
 * (list order), NONE if env has none *)
fun lookup_var env x =
    case List.find (fn Defn(y,_,_) => y = x | _ => false) env
     of SOME (Defn(_,e,_)) => SOME e
      | _ => NONE

(* expand_exp env x: the definition of x in env.
 * Callers must know that x is defined; otherwise Match is raised. *)
fun expand_exp env x =
    let val definition = lookup_var env x
    in case definition
        of NONE => raise Match
         | SOME e => e
    end

(* defined_tpvar env a: is there a `type a = ...` declaration in env? *)
fun defined_tpvar env a = Option.isSome (lookup_tpvar env a)

(* defined_var env x: is there a `defn x = ...` declaration in env? *)
fun defined_var env x = Option.isSome (lookup_var env x)

(* declared_var env x: is there a `decl x : ...` declaration in env? *)
fun declared_var env x = Option.isSome (lookup_var_tp env x)

(* frees_tp env tpctx tau tpvars: add to the accumulator tpvars every type
 * variable free in tau that is neither bound (listed in tpctx) nor already
 * recorded.  Subterms are visited right to left.  "_" is ignored; TpDef
 * names are definitions, not variables.  env is passed along unused. *)
fun frees_tp env tpctx tau tpvars =
    let fun go t acc = frees_tp env tpctx t acc
        fun under a t acc = frees_tp env (a::tpctx) t acc
        fun mem a l = List.exists (fn b => a = b) l
        fun alts sum acc = List.foldr (fn ((_,t),acc') => go t acc') acc sum
    in case tau
        of Arrow(t1,t2) => go t1 (go t2 tpvars)
         | Times(t1,t2) => go t1 (go t2 tpvars)
         | TApp(t1,t2) => go t1 (go t2 tpvars)
         | Forall(a,t) => under a t tpvars
         | Exists(a,t) => under a t tpvars
         | Rho(a,t) => under a t tpvars
         | TLam(a,t) => under a t tpvars
         | TpVar("_") => tpvars
         | TpVar(a) =>
             if mem a tpctx orelse mem a tpvars then tpvars else a::tpvars
         | One => tpvars
         | Plus(sum) => alts sum tpvars
         | With(prod) => alts prod tpvars
         | TpDef(_) => tpvars
    end

(* frees_tp_: frees_tp lifted to an optional type annotation *)
fun frees_tp_ env tpctx tau_ tpvars =
    case tau_
     of NONE => tpvars
      | SOME tau => frees_tp env tpctx tau tpvars

(* frees env tpctx ctx e (tpvars, vars): accumulate the free type variables
 * and free term variables of e.  tpctx / ctx list the type / term variables
 * bound at this point; the accumulator is threaded through subterms right
 * to left, and a name is added only if it is not already present.  Var "_"
 * and Def references are ignored.  env is passed along unused. *)
fun frees env tpctx ctx (Var("_")) (tpvars, vars) = (tpvars, vars)
  | frees env tpctx ctx (Var(x)) (tpvars, vars) =
    if List.exists (fn y => x = y) ctx
       orelse List.exists (fn y => x = y) vars
    then (tpvars,vars) else (tpvars,x::vars)
  | frees env tpctx ctx (Lam(x,tau_,e)) (tpvars, vars) =
      frees env tpctx (x::ctx) e (frees_tp_ env tpctx tau_ tpvars, vars)
  | frees env tpctx ctx (App(e1,e2)) (tpvars, vars) =
      frees env tpctx ctx e1 (frees env tpctx ctx e2 (tpvars, vars))
  | frees env tpctx ctx (TpLam(a,e)) (tpvars, vars) =
      frees env (a::tpctx) ctx e (tpvars, vars)
  | frees env tpctx ctx (TpApp(e,tau)) (tpvars, vars) =
      frees env tpctx ctx e (frees_tp env tpctx tau tpvars, vars)
  | frees env tpctx ctx (Pair(e1,e2)) (tpvars, vars) =
      frees env tpctx ctx e1 (frees env tpctx ctx e2 (tpvars, vars))
  | frees env tpctx ctx (Unit) (tpvars, vars) = (tpvars, vars)
  | frees env tpctx ctx (Inject(i,e)) (tpvars, vars) =
      frees env tpctx ctx e (tpvars, vars)
  | frees env tpctx ctx (Record(prod)) (tpvars, vars) =
      frees_prod env tpctx ctx prod (tpvars, vars)
  | frees env tpctx ctx (Project(e,i)) (tpvars, vars) =
      frees env tpctx ctx e (tpvars, vars)
  | frees env tpctx ctx (Case(e,branches)) (tpvars, vars) =
      frees env tpctx ctx e (frees_branches env tpctx ctx branches (tpvars, vars))
  | frees env tpctx ctx (Fold(e)) (tpvars, vars) =
      frees env tpctx ctx e (tpvars, vars)
  | frees env tpctx ctx (Unfold(e)) (tpvars, vars) =
      frees env tpctx ctx e (tpvars, vars)
  | frees env tpctx ctx (Pack(tau,e)) (tpvars, vars) =
      frees env tpctx ctx e (frees_tp env tpctx tau tpvars, vars)
  | frees env tpctx ctx (Fix(x,tau_,e)) (tpvars, vars) =
      frees env tpctx (x::ctx) e (frees_tp_ env tpctx tau_ tpvars, vars)
  | frees env tpctx ctx (Def(x)) (tpvars, vars) = (tpvars, vars)
  | frees env tpctx ctx (Marked(marked_e)) (tpvars, vars) =
      frees env tpctx ctx (Mark.data marked_e) (tpvars, vars)
(* frees_branches: frees over a list of case branches, last branch first *)
and frees_branches env tpctx ctx nil (tpvars, vars) = (tpvars, vars)
  | frees_branches env tpctx ctx (branch::branches) (tpvars, vars) =
    frees_branch env tpctx ctx branch
                 (frees_branches env tpctx ctx branches (tpvars, vars))
(* frees_branch: the pattern binds its type variables (pack patterns) and
 * its term variables in the branch body *)
and frees_branch env tpctx ctx (p,e) (tpvars, vars) =
    frees env (frees_pat_tp p @ tpctx) (frees_pat p @ ctx) e (tpvars, vars)
(* frees_prod: frees over the fields of a record, last field first *)
and frees_prod env tpctx ctx ((i,e)::prod) (tpvars, vars) =
    frees env tpctx ctx e (frees_prod env tpctx ctx prod (tpvars, vars))
  | frees_prod env tpctx ctx nil (tpvars, vars) = (tpvars, vars)
(* frees_pat p: the term variables bound by p, left to right, "_" excluded *)
and frees_pat (VarPat("_")) = []
  | frees_pat (VarPat(x)) = [x]
  | frees_pat (PairPat(p1,p2)) = frees_pat p1 @ frees_pat p2
  | frees_pat (UnitPat) = []
  | frees_pat (InjectPat(i,p)) = frees_pat p
  | frees_pat (FoldPat(p)) = frees_pat p
  | frees_pat (PackPat(a,p)) = frees_pat p
(* frees_pat_tp p: the type variables bound by the pack patterns in p *)
and frees_pat_tp (VarPat _) = []
  | frees_pat_tp (PairPat(p1,p2)) =
    frees_pat_tp p1 @ frees_pat_tp p2
  | frees_pat_tp (UnitPat) = []
  | frees_pat_tp (InjectPat(i,p)) = frees_pat_tp p
  | frees_pat_tp (FoldPat(p)) = frees_pat_tp p
  | frees_pat_tp (PackPat(a,p)) = [a] @ frees_pat_tp p

(* free_tpvars env tau: the type variables occurring free in tau *)
fun free_tpvars env tau = frees_tp env [] tau []

(* free_vars env e: the (type variables, term variables) occurring free in e *)
fun free_vars env e = frees env [] [] e ([], [])

(* free_vars_pat env p: the term variables bound by pattern p *)
fun free_vars_pat _ p = frees_pat p

(* rename_apart_defs env z (y,e'): choose the first of z', z'', ... that is
 * not a defined variable in env, and rename y to it in e' *)
fun rename_apart_defs env z (y,e') =
    let fun next c = c ^ "'"
        fun pick c = if defined_var env c then pick (next c) else c
        val z' = pick (next z)
    in (z', subst_exp_exp [(Var(z'),y)] e') end

(* Exported substitution entry points.  subst_exp applies a simultaneous
 * substitution; the exported subst_exp_exp e x f is the single-variable
 * case [e/x]f and shadows the list-based subst_exp_exp defined above. *)
local
  val subst_list = subst_exp_exp
in
  val subst_exp = fn theta => fn e => subst_list theta e
  val subst_exp_exp = fn e => fn x => fn f => subst_list [(e,x)] f
end
                            
(************)
(* Printing *)
(************)

structure Print =
struct

(* pp_tag i: a tag as written in source, with a leading quote *)
fun pp_tag i = String.concat ["'", i]

(* pp_tags tags: space-separated tags, or a placeholder when empty *)
fun pp_tags nil = "(no alternatives)"
  | pp_tags tags = String.concatWith " " (List.map pp_tag tags)

(* parens s: s wrapped in parentheses *)
fun parens s = String.concat ["(", s, ")"]

(* pparens prec_here prec_above s: parenthesize s, an operator of precedence
 * prec_here, unless it binds strictly tighter than its context prec_above *)
fun pparens prec_here prec_above s =
    if prec_above < prec_here then s else parens s

(* all binary type operators except juxtaposition are right associative *)
(*
   prec 0: !a. ?a. $a. \a.  (prefix)
   prec 1: ->               (right assoc)
   prec 2: + &              (n-ary tagged alternatives)
   prec 3: *                (right assoc)
   prec 4: "juxtaposition"  (left assoc)
*)
(* pp_tp prec tau: print tau in a context of precedence prec; pparens adds
 * parentheses when tau's operator does not bind tighter than prec.
 * ~1 is the top-level context where nothing needs parentheses. *)
fun pp_tp prec (Arrow(tau1,tau2)) = pparens 1 prec (pp_tp 1 tau1 ^ " -> " ^ pp_tp 0 tau2)
  | pp_tp prec (Forall(a,tau)) = pparens 0 prec ("!" ^ a ^ ". " ^ pp_tp ~1 tau)
  | pp_tp prec (TpVar(a)) = a
  | pp_tp prec (Times(tau1,tau2)) = pparens 3 prec (pp_tp 3 tau1 ^ " * " ^ pp_tp 2 tau2)
  | pp_tp prec (One) = "1"
  (* empty and singleton sums / products have their own concrete forms *)
  | pp_tp prec (Plus(nil)) = "0"
  | pp_tp prec (Plus([(i,tau)])) = pparens 2 prec (pp_tagged (i,tau) ^ " + " ^ "()")
  | pp_tp prec (Plus(sum)) = pparens 2 prec (pp_sum sum)
  | pp_tp prec (With(nil)) = "() & ()"
  | pp_tp prec (With([(i,tau)])) = pparens 2 prec (pp_tagged (i,tau) ^ " & " ^ "()")
  | pp_tp prec (With(prod)) = pparens 2 prec (pp_prod prod)
  | pp_tp prec (Exists(a,tau)) = pparens 0 prec ("?" ^ a ^ ". " ^ pp_tp ~1 tau)
  | pp_tp prec (Rho(a,tau)) = pparens 0 prec ("$" ^ a ^ ". " ^ pp_tp ~1 tau)
  | pp_tp prec (TLam(a,tau)) = pparens 0 prec ("\\" ^ a ^ ". " ^ pp_tp ~1 tau)
  (* left operand at 3: a nested application on the left needs no parens
   * (left associative); right operand at 4: one on the right does *)
  | pp_tp prec (TApp(tau1,tau2)) = pparens 4 prec (pp_tp 3 tau1 ^ " " ^ pp_tp 4 tau2) (* todo: 3? *)
  | pp_tp prec (TpDef(a)) = a
(* pp_sum / pp_prod: only reached with two or more alternatives; the empty
 * and singleton cases are handled in pp_tp, and nil would raise Match *)
and pp_sum ((i,tau)::nil) = pp_tagged (i,tau)
  | pp_sum ((i,tau)::sum) = pp_tagged (i,tau) ^ " + " ^ pp_sum sum
and pp_prod ((i,tau)::nil) = pp_tagged (i,tau)
  | pp_prod ((i,tau)::prod) = pp_tagged (i,tau) ^ " & " ^ pp_prod prod
(* pp_tagged (i,tau): one alternative, always parenthesized *)
and pp_tagged (i,tau) = "(" ^ pp_tag i ^ ":" ^ pp_tp ~1 tau ^ ")"

val pp_tp = fn tau => pp_tp ~1 tau

(* operator precedence *)
(* 
  prec 2: "juxtaposition"              left associative, strongest precedence
  prec 1: fold unfold 'l 'r \x. $x.    prefix
  prec 0: ,                            right associative, weakest precedence
 *)
fun pp_exp_tp tau = "[" ^ pp_tp tau ^ "]"
fun pp_exp prec (Var(x)) = x
  | pp_exp prec (Lam(x,NONE,e)) = pparens 1 prec ("\\" ^ x ^ ". " ^ pp_exp 0 e)
  | pp_exp prec (Lam(x,SOME(tau),e)) = pparens 1 prec ("\\" ^ x ^ ":" ^ pp_tp tau ^ ". " ^ pp_exp 0 e)
  | pp_exp prec (App(e1,e2)) = pparens 2 prec (pp_exp 1 e1 ^ " " ^ pp_exp 2 e2)
  | pp_exp prec (TpLam(a,e)) = pparens 1 prec ("/\\" ^ a ^ ". " ^ pp_exp 0 e)
  | pp_exp prec (TpApp(e,tau)) = pparens 2 prec (pp_exp 1 e ^ " " ^ pp_exp_tp tau)
  | pp_exp prec (Pair(e1,e2)) = pparens 0 prec (pp_exp 0 e1 ^ ", " ^ pp_exp ~1 e2)
  | pp_exp prec (Unit) = "()"
  | pp_exp prec (Inject(i,e)) = pparens 1 prec (pp_tag i ^ " " ^ pp_exp 0 e)
  | pp_exp prec (Record(prod)) = "(| " ^ pp_prod prod ^ " |)"
  | pp_exp prec (Project(e,i)) = pparens 2 prec (pp_exp 1 e ^ "." ^ pp_tag i) (* TODO: precedences *)
  | pp_exp prec (Case(e,branches)) = "case " ^ pp_exp ~1 e ^ " (" ^ pp_branches branches ^ ")"
  | pp_exp prec (Fold(e)) = pparens 1 prec ("fold " ^ pp_exp 0 e)
  | pp_exp prec (Unfold(e)) = pparens 1 prec ("unfold " ^ pp_exp 0 e)
  | pp_exp prec (Pack(tau,e)) = pparens 0 prec (pp_exp_tp tau ^ ", " ^ pp_exp ~1 e)
  | pp_exp prec (Fix(x,NONE,e)) = pparens 1 prec ("$" ^ x ^ ". " ^ pp_exp 0 e)
  | pp_exp prec (Fix(x,SOME(tau),e)) = pparens 1 prec ("$" ^ x ^ ":" ^ pp_tp tau ^ ". " ^ pp_exp 0 e)
  | pp_exp prec (Def(x)) = x
  | pp_exp prec (Marked(marked_e)) = pp_exp prec (Mark.data marked_e)
and pp_branches nil = ""
  | pp_branches ((p,e)::nil) = pp_branch (p,e)
  | pp_branches ((p,e)::branches) = pp_branch (p,e) ^ " | " ^ pp_branches branches
and pp_branch (p,e) = pp_pat 0 p ^ " => " ^ pp_exp ~1 e  (* 0 to force outermost parens *)
and pp_prod nil = ""
  | pp_prod ((i,e)::nil) = pp_tagged (i,e)
  | pp_prod ((i,e)::prod) = pp_tagged (i,e) ^ " | " ^ pp_prod prod
and pp_tagged (i,e) = pp_tag i ^ " => " ^ pp_exp ~1 e
and pp_pat prec (VarPat(x)) = x
  | pp_pat prec (PairPat(p1,p2)) = pparens 0 prec (pp_pat 0 p1 ^ ", " ^ pp_pat ~1 p2)
  | pp_pat prec (UnitPat) = "()"
  | pp_pat prec (InjectPat(i,p)) = pparens 1 prec (pp_tag i ^ " " ^ pp_pat 0 p)
  | pp_pat prec (FoldPat(p)) = pparens 1 prec ("fold " ^ pp_pat 0 p)
  | pp_pat prec (PackPat(a,p)) = pparens 0 prec ("[" ^ a ^ "]" ^ ", " ^ pp_pat ~1 p)

val pp_pat = fn p => pp_pat ~1 p
val pp_exp = fn e => pp_exp ~1 e

fun pp_val prec (Lam _) = "---"
  | pp_val prec (TpLam _) = "---"
  | pp_val prec (Pair(v1,v2)) = pparens 0 prec (pp_val 0 v1 ^ ", " ^ pp_val ~1 v2)
  | pp_val prec (Unit) = "()"
  | pp_val prec (Inject(i,v)) = pparens 1 prec (pp_tag i ^ " " ^ pp_val 0 v)
  | pp_val prec (Record _) = "---"
  | pp_val prec (Fold(v)) = pparens 1 prec ("fold " ^ pp_val 0 v)
  | pp_val prec (Pack(tau,v)) = pparens 0 prec ("[---]" ^ ", " ^ pp_val ~1 v)
  | pp_val prec e = raise Match  (* other expressions are not values, including marked ones *)

val pp_val = fn v => pp_val ~1 v

fun pp_dec (Type(a,tau,ext)) = "type " ^ a ^ " = " ^ pp_tp tau
  | pp_dec (Decl(x,tau,ext)) = "decl " ^ x ^ " : " ^ pp_tp tau
  | pp_dec (Defn(x,e,ext)) = "defn " ^ x ^ " = " ^ pp_exp e
  | pp_dec (Norm(x,e,NONE,ext)) = "% norm " ^ x ^ " = " ^ pp_exp e
  | pp_dec (Norm(x,e,SOME(n),ext)) = "% norm " ^ Int.toString n ^ " " ^ x ^ " = " ^ pp_exp e
  | pp_dec (Conv(e1,e2,ext)) = "conv " ^ pp_exp e1 ^ " = " ^ pp_exp e2
  | pp_dec (Eval(x,e,NONE,ext)) = "% eval " ^ x ^ " = " ^ pp_exp e
  | pp_dec (Eval(x,e,SOME(n),ext)) = "% eval " ^ Int.toString n ^ " " ^ x ^ " = " ^ pp_exp e
  | pp_dec (Not(dec,ext)) = "fail " ^ pp_dec dec
  | pp_dec (Pragma(s1,s2,ext)) = s1 ^ " " ^ s2

fun pp_env (nil) = ""
  | pp_env (d::nil) = pp_dec d
  | pp_env (d::env) = pp_dec d ^ "\n" ^ pp_env env

fun pp_valdec (d as Type _) = pp_dec d
  | pp_valdec (d as Decl _) = pp_dec d
  | pp_valdec (Defn(x,v,ext)) = "defn " ^ x ^ " = " ^ pp_val v
  | pp_valdec _ = raise Match

fun pp_valdecs (nil) = ""
  | pp_valdecs (d::nil) = pp_valdec d
  | pp_valdecs (d::ds) = pp_dec d ^ "\n" ^ pp_valdecs ds

end (* structure Print *)

end (* structure Ast *)
