(* Normalization and Equality *)
(* Frank Pfenning *)

signature NORM =
sig

    (* In structure Norm: raised when the step budget passed to normalize
       is used up; normalize also converts an arithmetic Overflow into
       this exception. *)
    exception BoundExceeded
    (* In structure Norm: raised, carrying the offending expression, when
       head reduction meets an expression form it has no case for. *)
    exception NotSupported of Ast.exp

    (* normalize env n e: n is the step budget; the second component of
       the result is the budget consumed (n minus what is left over). *)
    val normalize : Ast.env -> int -> Ast.exp -> Ast.exp * int (* returns (nf, #steps) *)
    (* alpha env e1 e2: intended for expressions that are already normal
       forms, e.g. results of normalize. *)
    val alpha : Ast.env -> Ast.exp -> Ast.exp -> bool (* alpha-equal normal forms *)

end (* signature NORM *)

structure Norm :> NORM =
struct

(* Short aliases for the structures this module relies on. *)
structure A = Ast
(* NOTE(review): PP is not referenced anywhere in the visible part of this
   structure - confirm whether it is still needed. *)
structure PP = Ast.Print
structure TC = TypeCheck

(* Implementations of the exceptions specified in signature NORM. *)
exception BoundExceeded
exception NotSupported of A.exp

(* Step budget ("gas") shared by all reductions.  The counter itself is
 * private to these three functions:
 *   set_gas n   overwrites the remaining budget with n
 *   get_gas ()  reads the remaining budget
 *   tick ()     spends one unit, raising BoundExceeded if none is left
 *)
local
    val fuel : int ref = ref 0
in
fun set_gas n = fuel := n
fun get_gas () = !fuel
fun tick () =
    (case !fuel of
         0 => raise BoundExceeded     (* budget exhausted *)
       | remaining => fuel := remaining - 1)
end

(* whnf env e reduces e at its head only.
 * Applications and type applications reduce their function part first and
 * then hand over to apply / apply_tp; variables, lambdas and type lambdas
 * are returned unchanged; definitions are expanded via A.expand_exp;
 * marks are stripped via Mark.data.
 * Raises NotSupported(e) for any other form of expression.
 *)
fun whnf env exp =
    (case exp of
         A.App(func,arg) => apply env (whnf env func) arg
       | A.TpApp(body,tau) => apply_tp env (whnf env body) tau
       | A.Var _ => exp
       | A.Lam _ => exp
       | A.TpLam _ => exp
       | A.Def(name) => whnf env (A.expand_exp env name)
       | A.Marked(m) => whnf env (Mark.data m)
       | _ => raise NotSupported(exp))

(* apply env head arg: head is already head-reduced.
 * A lambda head is a beta-redex: spend one unit of gas, substitute, and
 * continue reducing.  A variable, application or type application head is
 * stuck, so the application is rebuilt as is.  Any other head raises Match.
 *)
and apply env head arg =
    (case head of
         A.Lam(x,_,body) => ( tick () ; whnf env (A.subst_exp_exp arg x body) )
       | A.Var _ => A.App(head,arg)
       | A.App _ => A.App(head,arg)
       | A.TpApp _ => A.App(head,arg)
       | _ => raise Match)

(* apply_tp env head tau: type-level counterpart of apply.
 * Instantiating a type lambda does not call tick; only the lambda case
 * of apply consumes gas.
 *)
and apply_tp env head tau =
    (case head of
         A.TpLam(a,body) => whnf env (A.subst_tp_exp tau a body)
       | A.Var _ => A.TpApp(head,tau)
       | A.App _ => A.TpApp(head,tau)
       | A.TpApp _ => A.TpApp(head,tau)
       | _ => raise Match)

(* norm env e: head-reduce e with whnf, then normalize underneath. *)
fun norm env e =
    let val head_reduced = whnf env e
    in descend env head_reduced end

(* descend env w: w is expected to be a result of whnf.
 * Bodies of lambdas / type lambdas and arguments of stuck applications are
 * fully normalized with norm; the function part of a stuck application is
 * only traversed with descend.  Types are left untouched.
 * Raises Match on any form other than the five handled here.
 *)
and descend env w =
    (case w of
         A.Lam(x,tau_,body) => A.Lam(x,tau_,norm env body)
       | A.TpLam(a,body) => A.TpLam(a,norm env body)
       | A.Var _ => w
       | A.App(f,arg) =>
         let val f' = descend env f        (* function part first ... *)
             val arg' = norm env arg       (* ... then the argument *)
         in A.App(f',arg') end
       | A.TpApp(f,tau) => A.TpApp(descend env f,tau)
       | _ => raise Match)

(* normalize env n e = (nf, steps)
 * Resets the gas to n, normalizes e with norm and passes the result
 * through A.stdize_exp; steps is n minus the gas left afterwards.
 * BoundExceeded from tick propagates to the caller; an Overflow raised
 * anywhere in the computation is reported as BoundExceeded as well.
 *)
fun normalize env n e =
    let
        val () = set_gas n
        val nf = A.stdize_exp (norm env e)
        val steps = n - get_gas ()
    in
        (nf, steps)
    end
    handle Overflow => raise BoundExceeded

(* equated_by_exp eqs (x,x'): are x and x' identified by eqs?
 * eqs pairs up bound variables, innermost binder first.  The first pair
 * that mentions x on the left or x' on the right decides the answer: it
 * must mention both.  If no pair mentions either, both are free
 * variables and have to be the same name.
 *)
fun equated_by_exp eqs (x:A.varname, x':A.varname) =
    (case eqs of
         nil => x = x'                    (* free variables *)
       | (y,y')::rest =>
         if y = x orelse y' = x'
         then y = x andalso y' = x'
         else equated_by_exp rest (x,x'))

(* eq_tp__ env tau_ tau'_ compares two optional types.
 * Only when both are present are they compared with TC.equal_tp;
 * a missing type on either side counts as equal.
 *)
fun eq_tp__ env (SOME(tau)) (SOME(tau')) = TC.equal_tp env tau tau'
  | eq_tp__ _ _ _ = true

(* eq env eqs e e': structural comparison up to renaming of bound variables.
 * Meant for normal forms only.  eqs records which bound term variables of
 * e correspond to which bound term variables of e'.
 *)
fun eq env eqs e e' =
    (case (e, e') of
         (A.Var(x), A.Var(x')) => equated_by_exp eqs (x,x')
       | (A.Lam(x,tau_,body), A.Lam(x',tau'_,body')) =>
         (* optional annotations first, then bodies with x ~ x' recorded *)
         eq_tp__ env tau_ tau'_ andalso eq env ((x,x')::eqs) body body'
       | (A.App(f,arg), A.App(f',arg')) =>
         eq env eqs f f' andalso eq env eqs arg arg'
       | (A.TpLam(a,body), A.TpLam(a',body')) =>
         (* rename both bound type variables to the same variable obtained
            from TC.unique_tpvar before comparing the bodies *)
         let val shared = A.TpVar(TC.unique_tpvar ())
         in eq env eqs (A.subst_tp_exp shared a body)
               (A.subst_tp_exp shared a' body')
         end
       | (A.TpApp(f,tau), A.TpApp(f',tau')) =>
         eq env eqs f f' andalso TC.equal_tp env tau tau'
       | _ => false)
  (* mismatched forms are unequal; A.Marked and A.Def are not expected
     on either side of a comparison between normal forms *)

(* alpha env e1 e2: alpha-equality of two normal forms, i.e. eq started
   with no bound variables equated yet *)
fun alpha env e1 e2 = eq env [] e1 e2

end (* structure Norm *)
