(* Elaboration *)
(* Frank Pfenning <fp@cs.cmu.edu> *)

signature ELAB =
sig

    (* elab_decs env decs = SOME(env')
     * elaborates the declarations decs in order, starting from the
     * environment env, and returns the environment env' extended with
     * the elaborated declarations.
     * Returns NONE if a static error (ErrorMsg.Error) is raised during
     * elaboration. *)
    val elab_decs : Ast.env -> Ast.dec list -> Ast.env option

end (* signature ELAB *)

structure Elab :> ELAB =
struct

(* Local abbreviations for the modules the elaborator depends on *)
structure A = Ast
structure PP = Ast.Print
structure TC = TypeCheck
structure N = Norm
structure E = Eval
(* ERROR ext msg reports a static error at source extent ext.  It is used
 * at arbitrary result types below, so it never returns normally;
 * presumably it raises ErrorMsg.Error, which elab_decs catches -- confirm
 * in ErrorMsg *)
val ERROR = ErrorMsg.ERROR

(* tpfam_to_kind tau = kappa
 * the kind of type family tau: one A.KArrow(A.KType, _) for every
 * leading type-level lambda, ending in A.KType *)
fun tpfam_to_kind tau =
    case tau
     of A.TLam(_, body) => A.KArrow(A.KType, tpfam_to_kind body)
      | _ => A.KType

(* head_tp tau = a
 * the type name at the head of a (possibly nested) application
 * A.TApp(...A.TApp(A.TpDef(a), _)..., _); raises Match on any other shape *)
fun head_tp tau =
    case tau
     of A.TpDef(name) => name
      | A.TApp(head, _) => head_tp head
      | _ => raise Match

(* count_args_tp kappa = n
 * the number of arrows in kind kappa, i.e. how many more type arguments
 * a constructor of that kind expects *)
fun count_args_tp kappa =
    case kappa
     of A.KType => 0
      | A.KArrow(_, rest) => count_args_tp rest + 1

(* not_defined env x ext
 * signals an error at ext if env already has a binding for variable x *)
fun not_defined env x ext =
    if Option.isSome (A.lookup_var env x)
    then ERROR ext ("variable " ^ x ^ " already defined")
    else ()

(* not_defined_tp env a ext
 * signals an error at ext if env already has a binding for type variable a *)
fun not_defined_tp env a ext =
    if Option.isSome (A.lookup_tpvar env a)
    then ERROR ext ("type variable " ^ a ^ " already defined")
    else ()

(* check_lang lang dec
 * rejects declaration forms the selected language does not support:
 * type definitions and type declarations in Lam; normalization,
 * conversion and evaluation in Prf.  Everything else is accepted. *)
fun check_lang lang dec =
    let fun reject ext prefix l = ERROR ext (prefix ^ Flags.pp_lang l ^ "'")
    in
        case lang
         of SOME(Flags.Lam) =>
            (case dec
              of A.Type(_,_,ext) => reject ext "type definition not allowed in language '" Flags.Lam
               | A.Decl(_,_,ext) => reject ext "type declaration not allowed for language '" Flags.Lam
               | _ => ())
          | SOME(Flags.Prf) =>
            (case dec
              of A.Norm(_,_,_,ext) => reject ext "normalization not supported for language '" Flags.Prf
               | A.Conv(_,_,ext) => reject ext "conversion not supported for language '" Flags.Prf
               | A.Eval(_,_,_,ext) => reject ext "evaluation not supported for language '" Flags.Prf
               | _ => ())
          | _ => ()
    end

(* type_of env x ext = tau
 * the declared type of x in env; signals an error at ext if there is none *)
fun type_of env x ext =
    case A.lookup_var_tp env x
     of NONE => ERROR ext ("type of " ^ x ^ " not declared and not synthesizable")
      | SOME(tau) => tau

(* check_type env x e ext = decls
 * In the untyped language Lam this does nothing and returns [].
 * Otherwise, there are two cases:
 * - x has no declared type and e's type is synthesizable: returns
 *   [A.Decl(x, tau, ext)] for the synthesized tau.
 * - otherwise: checks e against the declared type of x (an error if
 *   there is none) and returns [].
 * NOTE(review): elab' also calls this with x = "_" for anonymous
 * definitions, yielding a Decl for "_"; confirm that is intended. *)
fun check_type env x e ext =
    if !Flags.lang = SOME(Flags.Lam)
    then []
    else if A.declared_var env x orelse not (TC.synthable e)
    then ( TC.check env e (type_of env x ext) ext ; [] )
    else [A.Decl(x, TC.synth env e ext, ext)] (* ext?; may raise ErrorMsg.Error *)

(* synth_eq_types env e1 e2 ext
 * In typed languages (anything but Lam), synthesizes the types of e1 and
 * e2 and signals an error at ext if they are not equal. *)
fun synth_eq_types env e1 e2 ext =
    if !Flags.lang = SOME(Flags.Lam)
    then ()
    else let val lhs_tp = TC.synth env e1 ext
             val rhs_tp = TC.synth env e2 ext
         in
             if TC.equal_tp env lhs_tp rhs_tp
             then ()
             else ERROR ext (String.concat ["types not equal\n",
                                            "LHS: ", PP.pp_tp lhs_tp, "\n",
                                            "RHS: ", PP.pp_tp rhs_tp])
         end

(* plural xs = "s" when xs has two or more elements, "" otherwise *)
fun plural xs = if List.length xs >= 2 then "s" else ""

(* newline xs = "\n" when xs is non-empty, "" otherwise *)
fun newline xs = if List.null xs then "" else "\n"

(* varsToString xs = the names in xs, in order, each followed by one space *)
fun varsToString xs = String.concat (List.map (fn x => x ^ " ") xs)

(* tpvarsToString renders type variable names exactly like varsToString *)
val tpvarsToString = varsToString

(* undeclared_tpvars bs = msg
 * error text listing the undeclared type variables bs, or "" if none *)
fun undeclared_tpvars bs =
    if List.null bs
    then ""
    else "undeclared type variable" ^ plural bs ^ ": " ^ tpvarsToString bs
(* undeclared_vars xs = msg
 * error text listing the undeclared (expression) variables xs, or "" if
 * none; expression variables are rendered with varsToString, not the
 * type-variable printer *)
fun undeclared_vars nil = ""
  | undeclared_vars xs = ("undeclared variable" ^ plural xs ^ ": " ^ varsToString xs)

(* closed_tp env tpctx tau ext
 * signals an error at ext listing the type variables of tau that
 * A.free_tpvars reports as free with respect to env.  tpctx is accepted
 * but not consulted. *)
fun closed_tp env tpctx tau ext =
    case A.free_tpvars env tau
     of [] => ()
      | free => ERROR ext (undeclared_tpvars free)
    
(* closed env e ext
 * signals an error at ext if A.free_vars reports free type variables or
 * free expression variables in e with respect to env.  The message lists
 * type variables first, then expression variables, separated by a
 * newline only when both kinds are present (no trailing newline). *)
fun closed env e ext =
    case A.free_vars env e
     of (nil, nil) => ()
      | (bs, nil) => ERROR ext (undeclared_tpvars bs)
      | (nil, xs) => ERROR ext (undeclared_vars xs)
      | (bs, xs) => ERROR ext (undeclared_tpvars bs ^ "\n" ^ undeclared_vars xs)

(* repeated_tags alts = tags
 * lists, in order, every tag of alts that occurs again later in alts
 * (a tag occurring n > 1 times is therefore reported n - 1 times) *)
fun repeated_tags nil = nil
  | repeated_tags ((tag : A.tag, _) :: rest) =
    let val dups = repeated_tags rest
    in
        if List.exists (fn (other, _) => other = tag) rest
        then tag :: dups
        else dups
    end

(* resolve_tpdefs env tpctx tau ext = tau'
 * Resolves type names in tau.  tpctx lists the type variables bound by
 * enclosing quantifiers (Forall, Exists, Rho).
 * - A TpVar bound in tpctx is kept.
 * - An unbound TpVar defined in env is resolved as an A.TpDef.
 * - Any other TpVar is returned unchanged as a free type variable, for
 *   the caller to reject.
 * Uses of type names must be fully applied (remaining kind A.KType).
 * Sum and record types must not repeat tags.  Type-level lambdas are
 * rejected here (families are handled by resolve_tpdefs_fam).
 * Errors are signalled at ext.
 *)
fun resolve_tpdefs env tpctx (A.Arrow(tau1,tau2)) ext =
    A.Arrow(resolve_tpdefs env tpctx tau1 ext, resolve_tpdefs env tpctx tau2 ext)
  | resolve_tpdefs env tpctx (A.Forall(a,tau)) ext =
    A.Forall(a, resolve_tpdefs env (a::tpctx) tau ext)
  | resolve_tpdefs env tpctx (tau as A.TpVar(a)) ext =
    if List.exists (fn b => a = b) tpctx then tau
    else if A.defined_tpvar env a
    then resolve_tpdefs env tpctx (A.TpDef(a)) ext
    else tau (* free type variable! *)
  | resolve_tpdefs env tpctx (A.Times(tau1,tau2)) ext =
    A.Times(resolve_tpdefs env tpctx tau1 ext, resolve_tpdefs env tpctx tau2 ext)
  | resolve_tpdefs env tpctx (A.One) ext = A.One
  | resolve_tpdefs env tpctx (A.Plus(sum)) ext = A.Plus(resolve_alts env tpctx sum ext)
  | resolve_tpdefs env tpctx (A.With(prod)) ext = A.With(resolve_alts env tpctx prod ext)
  | resolve_tpdefs env tpctx (A.Exists(a,tau)) ext =
    A.Exists(a, resolve_tpdefs env (a::tpctx) tau ext)
  | resolve_tpdefs env tpctx (A.Rho(a,tau)) ext =
    A.Rho(a, resolve_tpdefs env (a::tpctx) tau ext)
  | resolve_tpdefs env tpctx (tau as A.TLam _) ext =
    ERROR ext ("unexpected type function: " ^ PP.pp_tp tau)
  | resolve_tpdefs env tpctx (tau as A.TApp _) ext =
    (case resolve_tapp env tpctx tau ext
      of (tau',A.KType) => tau'
       | (_,kappa) => ERROR ext ("too few type arguments to type constructor "
                                 ^ PP.pp_tp tau ^ " requires " ^ Int.toString (count_args_tp kappa)
                                 ^ " more"))
  | resolve_tpdefs env tpctx (tau as A.TpDef(a)) ext =
    (case resolve_tapp env tpctx tau ext
      of (tau',A.KType) => tau'
       | (_,kappa) => ERROR ext ("type constructor " ^ a ^ " requires "
                                 ^ Int.toString (count_args_tp kappa) ^ " type arguments"))

(* resolve_alts env tpctx alts ext
 * resolves every component type of a sum or record after checking that
 * no tag is repeated *)
and resolve_alts env tpctx sum ext =
    (case repeated_tags sum
      of nil => resolve_alts' env tpctx sum ext
       | tags => ERROR ext ("repeated tags in sum: " ^ PP.pp_tags tags))
and resolve_alts' env tpctx nil ext = nil
  | resolve_alts' env tpctx ((i,tau)::sum) ext =
    (i,resolve_tpdefs env tpctx tau ext)::resolve_alts' env tpctx sum ext

(* resolve_tapp env tpctx tau ext = (tau', kappa)
 * resolves a type application (or a bare type name) tau and returns it
 * together with its remaining kind kappa, the kind still to be applied
 * to further arguments.  The head must be a type name that is defined in
 * env and not bound in tpctx. *)
and resolve_tapp env tpctx (tau as A.TApp(tau1,tau2)) ext =
    (case resolve_tapp env tpctx tau1 ext
      of (tau1', A.KArrow(A.KType,kappa)) =>
         (A.TApp(tau1', resolve_tpdefs env tpctx tau2 ext), kappa)
       | (_, A.KType) =>
         ERROR ext ("too many arguments to type constructor " ^ PP.pp_tp tau))
  | resolve_tapp env tpctx (tau as A.TpDef(a)) ext =
    (case A.lookup_tpvar env a
      of SOME(tau') => (tau, tpfam_to_kind tau')
       | NONE => ERROR ext ("undefined type name " ^ a))
  | resolve_tapp env tpctx (A.TpVar(a)) ext =
    if not (List.exists (fn b => a = b) tpctx)
    then resolve_tapp env tpctx (A.TpDef(a)) ext
    else ERROR ext ("head of type application " ^ a ^ " is bound")
  | resolve_tapp env tpctx tau ext =
    ERROR ext ("head of type application not a type name: "
               ^ PP.pp_tp tau)

(* resolve_tpdefs_ env tpctx tau_ ext
 * resolve_tpdefs lifted to an optional type annotation *)
fun resolve_tpdefs_ env tpctx tau_ ext =
    case tau_
     of NONE => NONE
      | SOME(tau) => SOME(resolve_tpdefs env tpctx tau ext)

(* resolve_tpdefs_fam env kctx tau ext = tau'
 * resolves a type family: the variables of leading type-level lambdas
 * are added to kctx, and the body is resolved with resolve_tpdefs *)
fun resolve_tpdefs_fam env kctx tau ext =
    case tau
     of A.TLam(a, body) => A.TLam(a, resolve_tpdefs_fam env (a::kctx) body ext)
      | _ => resolve_tpdefs env kctx tau ext

(* all_distinct ctx p ext = ctx'
 * Adds the variables bound by pattern p to ctx.  Signals an error at ext
 * if an expression variable is already present; the wildcard "_" is
 * never added.  The type variable of a pack pattern goes into the same
 * list, unchecked.
 * Maybe separate name spaces for type variables and expression variables? *)
fun all_distinct ctx p ext =
    case p
     of A.VarPat("_") => ctx
      | A.VarPat(x) =>
        if List.exists (fn y => y = x) ctx
        then ERROR ext ("repeated variable in pattern: " ^ x)
        else x :: ctx
      | A.PairPat(p1, p2) =>
        let val ctx1 = all_distinct ctx p1 ext
        in all_distinct ctx1 p2 ext end
      | A.UnitPat => ctx
      | A.InjectPat(_, p') => all_distinct ctx p' ext
      | A.FoldPat(p') => all_distinct ctx p' ext
      | A.PackPat(a, p') => all_distinct (a :: ctx) p' ext

(* resolve_defs env tpctx ctx e ext = e'
 * Resolves identifiers in e.  ctx lists the locally bound expression
 * variables and tpctx the locally bound type variables.
 * - A variable in ctx stays an A.Var.
 * - A variable defined in env becomes an A.Def.
 * - Any other variable stays an A.Var (a free variable).
 * Type annotations are resolved with resolve_tpdefs under tpctx.  Using
 * "_" as a variable occurrence is an error.  ext is the extent for
 * errors; it is replaced by the mark's extent under A.Marked.
 * A.Def never occurs in input and raises Match.
 *)
fun resolve_defs env tpctx ctx e ext =
    let (* recurse in the same contexts *)
        fun same e' = resolve_defs env tpctx ctx e' ext
        fun tp tau = resolve_tpdefs env tpctx tau ext
    in
        case e
         of A.Var("_") => ERROR ext ("cannot use '_' as a variable occurrence")
          | A.Var(x) =>
            if List.exists (fn y => y = x) ctx then e
            else if A.defined_var env x then A.Def(x)
            else e (* free variable! *)
          | A.Lam(x, tau_, body) =>
            A.Lam(x, resolve_tpdefs_ env tpctx tau_ ext, resolve_defs env tpctx (x::ctx) body ext)
          | A.App(e1, e2) => A.App(same e1, same e2)
          | A.TpLam(a, body) => A.TpLam(a, resolve_defs env (a::tpctx) ctx body ext)
          | A.TpApp(body, tau) => A.TpApp(same body, tp tau)
          | A.Pair(e1, e2) => A.Pair(same e1, same e2)
          | A.Unit => A.Unit
          | A.Inject(i, body) => A.Inject(i, same body)
          | A.Record(texps) => A.Record(List.map (fn (i, ei) => (i, same ei)) texps)
          | A.Project(body, i) => A.Project(same body, i)
          | A.Case(subject, branches) =>
            A.Case(same subject, resolve_defs_branches env tpctx ctx branches ext)
          | A.Fold(body) => A.Fold(same body)
          | A.Unfold(body) => A.Unfold(same body)
          | A.Pack(tau, body) => A.Pack(tp tau, same body)
          | A.Fix(g, tau_, body) =>
            A.Fix(g, resolve_tpdefs_ env tpctx tau_ ext, resolve_defs env tpctx (g::ctx) body ext)
          | A.Marked(marked_e) =>
            let val inner_ext = Mark.ext marked_e
            in A.Marked(Mark.mark'(resolve_defs env tpctx ctx (Mark.data marked_e) inner_ext,
                                   inner_ext))
            end
          | A.Def _ => raise Match
    end
(* resolve_defs_branches: resolve_defs_branch applied to each case branch, in order *)
and resolve_defs_branches env tpctx ctx branches ext =
    List.map (fn branch => resolve_defs_branch env tpctx ctx branch ext) branches
(* resolve_defs_branch: checks the pattern's variables are distinct, then
 * resolves the body with the pattern's variables (A.free_vars_pat) bound *)
and resolve_defs_branch env tpctx ctx (p,e) ext =
    let val _ = all_distinct nil p ext (* check all pattern variables are distinct *)
        val ctx' = A.free_vars_pat env p @ ctx
    in (p, resolve_defs env tpctx ctx' e ext) end
    
(* resolve_ids env x e ext = e'
 * Resolves the identifiers in e: is each one global, local, or free?
 * Free variables and free type variables are only permitted when the
 * result will not be bound to a global name, i.e. when x is "_";
 * otherwise e' must be closed.
 *)
fun resolve_ids env x e ext =
    let val resolved = resolve_defs env nil nil e ext
    in
        ( if x <> "_" then closed env resolved ext else ()
        ; resolved )
    end

(* resolve_ids_tp env tau ext = tau'
 * resolves type names in tau and requires the result to be closed *)
fun resolve_ids_tp env tau ext =
    let val resolved = resolve_tpdefs env nil tau ext
    in ( closed_tp env nil resolved ext ; resolved ) end

(* resolve_ids_tpfam env tau ext = tau'
 * resolves type names in the type family tau and requires the result to
 * be closed *)
fun resolve_ids_tpfam env tau ext =
    let val resolved = resolve_tpdefs_fam env nil tau ext
    in ( closed_tp env nil resolved ext ; resolved ) end

(* is_prop tau ext
 * Propositions are the sublanguage of CBV types that describes proofs.
 * This checks that tau contains no recursive type (A.Rho) and signals
 * an error at ext otherwise.  Type definitions (A.TpDef) are accepted
 * without inspection: all type definitions have been checked already. *)
fun is_prop tau ext =
    let fun both (t1, t2) = ( is_prop t1 ext ; is_prop t2 ext )
        fun alts ts = List.app (fn (_, t) => is_prop t ext) ts
    in
        case tau
         of A.Arrow(tau1, tau2) => both (tau1, tau2)
          | A.Forall(_, body) => is_prop body ext
          | A.TpVar _ => ()
          | A.Times(tau1, tau2) => both (tau1, tau2)
          | A.One => ()
          | A.Plus(sum) => alts sum
          | A.With(prod) => alts prod
          | A.Exists(_, body) => is_prop body ext
          | A.TLam(_, body) => is_prop body ext
          | A.TApp(tau1, tau2) => both (tau1, tau2)
          | A.Rho _ => ERROR ext ("recursive types not permitted in propositions\n"
                                  ^ PP.pp_tp tau)
          | A.TpDef _ => ()
    end

(* is_prop_ tau_ ext: is_prop on an optional type annotation *)
fun is_prop_ tau_ ext = Option.app (fn tau => is_prop tau ext) tau_

(* resolve_prop env tau ext = tau'
 * resolves type names in tau (a possible family), then requires the
 * result to be closed and to be a proposition *)
fun resolve_prop env tau ext =
    let val resolved = resolve_tpdefs_fam env nil tau ext
    in
        ( closed_tp env nil resolved ext
        ; is_prop resolved ext
        ; resolved )
    end

(* is_pat p ext
 * checks that pattern p may occur in a proof: fold patterns are
 * rejected with an error at ext *)
fun is_pat p ext =
    case p
     of A.FoldPat _ => ERROR ext "fold not permitted as a pattern in proofs"
      | A.PairPat(p1, p2) => ( is_pat p1 ext ; is_pat p2 ext )
      | A.InjectPat(_, p') => is_pat p' ext
      | A.PackPat(_, p') => is_pat p' ext
      | A.VarPat _ => ()
      | A.UnitPat => ()

(* is_proof e ext
 * checks that e lies in the sublanguage of CBV that describes proofs:
 * - fold, unfold and fixed points are rejected;
 * - fold patterns are rejected (is_pat);
 * - every type annotation or type argument must be a proposition
 *   (is_prop).
 * For case expressions both the scrutinee and every branch are checked.
 * Errors are signalled at ext, which is refined at Marked nodes.
 * Definitions (A.Def) are accepted: prior definitions were already
 * checked. *)
fun is_proof (A.Var _) ext = ()
  | is_proof (A.Lam(_,tau_,e)) ext = ( is_prop_ tau_ ext ; is_proof e ext )
  | is_proof (A.App(e1,e2)) ext = ( is_proof e1 ext ; is_proof e2 ext )
  | is_proof (A.TpLam(_,e)) ext = is_proof e ext
  | is_proof (A.TpApp(e,tau)) ext = ( is_proof e ext ; is_prop tau ext )
  | is_proof (A.Pair(e1,e2)) ext = ( is_proof e1 ext ; is_proof e2 ext )
  | is_proof (A.Unit) ext = ()
  | is_proof (A.Inject(_,e)) ext = is_proof e ext
  | is_proof (A.Record(texps)) ext = List.app (fn (_,e) => is_proof e ext) texps
  | is_proof (A.Project(e,_)) ext = is_proof e ext
  | is_proof (A.Case(subject,branches)) ext =
    (* the scrutinee must be a proof too, not only the branches *)
    ( is_proof subject ext
    ; List.app (fn (p,body) => ( is_pat p ext ; is_proof body ext )) branches )
  | is_proof (A.Fold _) ext = ERROR ext ("fold not permitted in proofs")
  | is_proof (A.Unfold _) ext = ERROR ext ("unfold not permitted in proofs")
  | is_proof (A.Pack(tau,e)) ext = ( is_prop tau ext ; is_proof e ext )
  | is_proof (A.Fix _) ext = ERROR ext ("fixed point expression not permitted in proofs")
  | is_proof (A.Def _) ext = () (* prior definitions already checked *)
  | is_proof (A.Marked(marked_e)) ext = is_proof (Mark.data marked_e) (Mark.ext marked_e)

(* defn x e ext = decs
 * the definition of x as e, as a list; anonymous definitions (x = "_")
 * produce the empty list *)
fun defn x e ext =
    if x = "_" then [] else [A.Defn(x, e, ext)]

(* conv_lam env dec
 * Checks a conversion A.Conv(e1,e2,ext) for the lambda calculi.
 * - Both sides may contain free variables and free type variables.
 * - In typed languages, their synthesized types must be equal.
 * - Their normal forms, computed without a step bound (~1), must be
 *   alpha-equivalent.
 * Errors are signalled at ext. *)
fun conv_lam env (A.Conv(e1,e2,ext)) =
    let (* normal form of e; diverged is the message when the step bound is hit *)
        fun normal e diverged =
            (case N.normalize env ~1 e of (nf, _) => nf)
            handle N.BoundExceeded => ERROR ext diverged
                 | N.NotSupported(e') => ERROR ext ("conversion supported only for untyped and polymorphic lambda-calculus\n"
                                                    ^ PP.pp_exp e')
        val lhs = resolve_ids env "_" e1 ext (* allow free variables, free type vars *)
        val rhs = resolve_ids env "_" e2 ext (* allow free variables, free type vars *)
        val () = synth_eq_types env lhs rhs ext
        val lhs_nf = normal lhs "left-hand side does not converge"
        val rhs_nf = normal rhs "right-hand side does not converge"
    in
        if N.alpha env lhs_nf rhs_nf
        then ()
        else ERROR ext ("normal forms not equal\n"
                        ^ "LHS: " ^ PP.pp_exp lhs_nf ^ "\n"
                        ^ "RHS: " ^ PP.pp_exp rhs_nf)
    end

(* conv_cbv env dec
 * Checks a conversion A.Conv(e1,e2,ext) for call-by-value.
 * - Both sides must be closed, but nothing is bound.
 * - Their synthesized types must be equal.
 * - Evaluated without a step bound (~1), their values must be equal
 *   according to E.eq_val.
 * Errors are signalled at ext. *)
fun conv_cbv env (A.Conv(e1,e2,ext)) =
    let fun value e =
            (case E.eval env ~1 e of (v, _) => v) (* no bound *)
            handle E.NoMatch => ERROR ext ("value did not match any branch during evaluation")
        val lhs = resolve_ids env "" e1 ext (* do not allow free variables, but do not bind anything *)
        val rhs = resolve_ids env "" e2 ext (* do not allow free variables, but do not bind anything *)
        val () = synth_eq_types env lhs rhs ext
        val v1 = value lhs
        val v2 = value rhs
        val same = E.eq_val env v1 v2
                   handle E.NotComparable(msg) => ERROR ext ("cannot compare " ^ msg ^ " for equality")
    in
        if same
        then ()
        else ERROR ext ("values not equal\n"
                        ^ "LHS: " ^ PP.pp_exp v1 ^ "\n"
                        ^ "RHS: " ^ PP.pp_exp v2 ^ "\n")
    end

(* check_conv env dec ext
 * dispatches a conversion declaration on the current language.  Lam and
 * Poly use normalization; CBV uses evaluation; Prf is rejected.  Raises
 * Match if no language is selected. *)
fun check_conv env dec ext =
    let val lang = !Flags.lang
    in
        case lang
         of NONE => raise Match
          | SOME(Flags.CBV) => conv_cbv env dec
          | SOME(Flags.Prf) => ERROR ext "conversion test on proof not supported"
          | SOME(Flags.Lam) => conv_lam env dec
          | SOME(Flags.Poly) => conv_lam env dec
    end

(* check_lang_tp env tau ext
 * in language Prf, requires tau to be a proposition; no check otherwise *)
fun check_lang_tp env tau ext =
    if !Flags.lang = SOME(Flags.Prf)
    then is_prop tau ext
    else ()              (* check here for Lam or Poly? *)

(* check_lang_exp env e ext
 * in language Prf, requires e to be a proof term; no check otherwise *)
fun check_lang_exp env e ext =
    if !Flags.lang = SOME(Flags.Prf)
    then is_proof e ext
    else ()              (* check here for Lam or Poly? *)

(* elab env decs = env'
 * Elaborates the declarations decs in order, appending the elaborated
 * form of each one to env.  Each declaration is first echoed when
 * !Flags.verbosity >= 1 and then checked against the current language
 * (check_lang).  elab' then elaborates it and continues with elab on the
 * remaining declarations.
 *)
fun elab env nil = env
  | elab env (dec::decs) =
    ( if !Flags.verbosity >= 1
      then TextIO.print (PP.pp_dec dec ^ "\n")
      else ()
    ; check_lang (!Flags.lang) dec
    ; elab' env (dec::decs) )

(* elab' env (dec::decs) elaborates the first declaration dec and then
 * calls elab on decs.  Raises Match on the empty list, which elab never
 * passes. *)
and elab' env (A.Type(a,tau,ext)::decs) =
    (* type definition: a must be new; tau is a closed type family *)
    let val () = not_defined_tp env a ext
        val tau' = resolve_ids_tpfam env tau ext
        val () = check_lang_tp env tau' ext
    in elab (env @ [A.Type(a,tau',ext)]) decs end
  | elab' env (A.Decl(x,tau,ext)::decs) =
    (* type declaration: x must be new; tau must be closed *)
    let val () = not_defined env x ext
        val tau' = resolve_ids_tp env tau ext
        val () = check_lang_tp env tau' ext
    in elab (env @ [A.Decl(x,tau',ext)]) decs end
  | elab' env (A.Defn(x,e,ext)::decs) =
    (* definition: resolve, check language and type, then bind x (unless "_") *)
    let val () = not_defined env x ext
        val e' = resolve_ids env x e ext
        (* check the resolved expression, as the Type and Decl cases do for
         * types; is_proof accepts the A.Def nodes resolution introduces *)
        val () = check_lang_exp env e' ext
        val decls_x = check_type env x e' ext
        val () = if !Flags.verbosity >= 1
                 then TextIO.print (PP.pp_env decls_x ^ newline decls_x)
                 else ()
        val defns_x = defn x e' ext
    in elab (env @ decls_x @ defns_x) decs end (* inefficient, but okay *)
  | elab' env (A.Norm(x,e,bd_opt,ext)::decs) =
    (* normalize, and turn into variable declaration; ~1 means no bound *)
    let val () = not_defined env x ext
        val n = case bd_opt of NONE => ~1 | SOME(n) => n
        val e' = resolve_ids env x e ext
        val decls_x = check_type env x e' ext
        val (e'', steps) = N.normalize env n e'
            handle N.BoundExceeded => ERROR ext ("bound on number of steps exceeded")
                 | N.NotSupported(e) => ERROR ext ("normalization supported only for untyped and polymorphic lambda-calculus\n"
                                                   ^ PP.pp_exp e)
        val defns_x = defn x e'' ext
        val ds = decls_x @ defns_x
        val () = if !Flags.verbosity >= 1
                 then ( TextIO.print ("% " ^ Int.toString steps ^ " beta-reductions\n")
                      ; TextIO.print (PP.pp_env ds ^ newline ds) )
                 else ()
    in elab (env @ ds) decs end
  | elab' env ((dec as A.Conv(_,_,ext))::decs) =
    (* conversion test; the declaration itself is kept in the environment *)
    ( check_conv env dec ext ;
      elab (env @ [dec]) decs )
  | elab' env (A.Eval(x,e,bd_opt,ext)::decs) =
    (* evaluate, and turn into a variable declaration; ~1 means no bound *)
    let val () = not_defined env x ext
        val n = case bd_opt of NONE => ~1 | SOME(n) => n
        val e' = resolve_ids env x e ext
        val decls_x = check_type env x e' ext
        val (e'', steps) = E.eval env n e'
            handle E.BoundExceeded => ERROR ext ("bound on number of steps exceeded")
                 | E.NoMatch => ERROR ext ("value did not match any branch during evaluation")
        val defns_x = defn x e'' ext
        val ds = decls_x @ defns_x
        val () = if !Flags.verbosity >= 1
                 then ( TextIO.print ("% " ^ Int.toString steps ^ " evaluation steps\n")
                      ; TextIO.print (PP.pp_valdecs ds ^ newline ds) )
                 else ()
    in elab (env @ ds) decs end
  | elab' env (A.Not(dec',ext)::decs) =
    (* negations cannot be nested; the negated declaration must fail, and
     * its effect on the environment is discarded *)
    let val result = SOME (ErrorMsg.suppress (fn () => elab' env [dec']))
                     handle ErrorMsg.Error => NONE
    in case result
        of NONE => ( if !Flags.verbosity >= 1
                     then TextIO.print ("% passed due to expected error in negated declaration\n") else ()
                   ; elab env decs )
         | SOME _ => ERROR ext ("negated declaration passed")
    end
  | elab' env ((dec as A.Pragma(_,_,ext))::decs) =
    ERROR ext ("unexpected pragma:\n" ^ PP.pp_dec dec ^ "\n"
               ^ "pragmas must precede all other declarations\n")
  | elab' env nil = raise Match

(* elab_decs env decs = SOME(env')
 * when elaborating decs with respect to env succeeds, yielding env';
 * NONE when a static error (ErrorMsg.Error) is raised
 *)
fun elab_decs env decs =
    SOME (elab env decs)
    handle ErrorMsg.Error => NONE

end (* structure Elab *)
