open Printf

(** One mutation of the union-find under test: [Union (a, b)] asks
    [Congbare.union] to join ids [a] and [b]. *)
type op = Union of int * int

(** One query evaluated after all ops have been applied: are [a] and [b]
    reported as equivalent? [expect] is the answer the test file asserts. *)
type check = {
  a : int;
  b : int;
  expect : bool;
}

(** A parsed test-case file. *)
type case = {
  name : string;        (** label shown in reports and used by [run_one] lookup *)
  size : int;           (** element count passed to [Congbare.uf_new] *)
  ops : op list;        (** applied in list order *)
  checks : check list;  (** evaluated in list order *)
}

(** [read_file path] returns the text of the file at [path].

    Lines are read with [input_line] and each one is re-terminated with a
    single ['\n'], so a final line lacking a newline gains one.
    The channel is closed on every exit path, including when an unexpected
    exception escapes the read loop.
    @raise Sys_error if [path] cannot be opened or read. *)
let read_file path =
  let ic = open_in path in
  Fun.protect
    ~finally:(fun () -> close_in_noerr ic)
    (fun () ->
      let buf = Buffer.create 1024 in
      (try
         while true do
           Buffer.add_string buf (input_line ic);
           Buffer.add_char buf '\n'
         done
       with End_of_file -> ());
      Buffer.contents buf)

(** [trim s] removes leading and trailing spaces, tabs, carriage returns
    and line feeds from [s]. Unlike [String.trim], a form feed ['\012']
    is NOT treated as whitespace. Returns [""] if nothing remains. *)
let trim s =
  let blank c = c = ' ' || c = '\t' || c = '\r' || c = '\n' in
  let len = String.length s in
  (* index of the first non-blank char, or [len] if none *)
  let rec first i = if i < len && blank s.[i] then first (i + 1) else i in
  let start = first 0 in
  (* index of the last non-blank char at or after [start], or [start - 1] *)
  let rec last j = if j >= start && blank s.[j] then last (j - 1) else j in
  let stop = last (len - 1) in
  if start > stop then "" else String.sub s start (stop - start + 1)

(** [bool_of_string_exn s] parses ["true"]/["false"], ignoring surrounding
    whitespace (per [trim]) and ASCII case.
    @raise Failure ["bad bool: <normalized input>"] for anything else. *)
let bool_of_string_exn s =
  let norm = String.lowercase_ascii (trim s) in
  if norm = "true" then true
  else if norm = "false" then false
  else failwith ("bad bool: " ^ norm)

(** [parse_case path] reads the test-case file at [path] and parses it.

    The file is split into sections by marker lines (compared after
    [trim]):
    - a header, up to [--ops--];
    - op lines, up to [--checks--];
    - check lines, up to [--end--].

    A missing marker makes its section run to end of file and leaves the
    later sections empty.

    - Header lines have the form [key: value], split on the first [':'].
      Only the [name] (default: the file's basename) and [size] keys are
      read. Lines without [':'] are ignored.
    - Op lines: [union A B].
    - Check lines: [check A B = BOOL] or [check A B BOOL]; [?] may be used
      in place of [check].

    In the op and check sections, blank lines and lines starting with
    ['#'] are skipped. Words are separated by spaces or tabs.

    @raise Failure if [size] is missing, not an integer, or negative; if an
    op/check line is malformed; or if an id is not an integer. Messages
    name [path].
    @raise Sys_error if [path] cannot be read. *)
let parse_case path : case =
  let lines = read_file path |> String.split_on_char '\n' in
  (* [take_until tag acc lines] splits [lines] at the first line whose
     trimmed form equals [tag]; the marker line itself is dropped. *)
  let rec take_until tag acc = function
    | [] -> (List.rev acc, [])
    | l :: ls when trim l = tag -> (List.rev acc, ls)
    | l :: ls -> take_until tag (l :: acc) ls
  in
  let header, rest = take_until "--ops--" [] lines in
  let ops_lines, rest2 = take_until "--checks--" [] rest in
  let check_lines, _rest3 = take_until "--end--" [] rest2 in
  let tbl = Hashtbl.create 5 in
  List.iter
    (fun l ->
      (* Split on the first ':' only, so a value (e.g. a name) may itself
         contain ':' instead of the whole line being silently dropped. *)
      match String.index_opt l ':' with
      | Some i ->
          let k = String.sub l 0 i in
          let v = String.sub l (i + 1) (String.length l - i - 1) in
          Hashtbl.replace tbl (trim k) (trim v)
      | None -> ())
    header;
  let name =
    match Hashtbl.find_opt tbl "name" with
    | Some n -> n
    | None -> Filename.basename path
  in
  let size =
    match Option.bind (Hashtbl.find_opt tbl "size") int_of_string_opt with
    | Some n when n >= 0 -> n
    | _ -> failwith ("missing/invalid size in " ^ path)
  in
  (* [Some l] (trimmed) for a meaningful line; [None] for blank lines and
     '#' comments. *)
  let content raw =
    let l = trim raw in
    if l = "" || l.[0] = '#' then None else Some l
  in
  let words l =
    String.map (function '\t' -> ' ' | c -> c) l
    |> String.split_on_char ' '
    |> List.filter (( <> ) "")
  in
  (* Bare [int_of_string] fails with a context-free message; report the
     offending token, file and line instead. *)
  let int_of ~line s =
    match int_of_string_opt s with
    | Some n -> n
    | None -> failwith (sprintf "bad integer '%s' in %s: %s" s path line)
  in
  let parse_op raw =
    match content raw with
    | None -> None
    | Some l -> (
        match words l with
        | [ "union"; a; b ] ->
            Some (Union (int_of ~line:l a, int_of ~line:l b))
        | _ -> failwith (sprintf "bad op line in %s: %s" path l))
  in
  let ops = List.filter_map parse_op ops_lines in
  let parse_check raw =
    match content raw with
    | None -> None
    | Some l -> (
        let bad () = failwith (sprintf "bad check line in %s: %s" path l) in
        match words l with
        | ("?" | "check") :: a :: b :: rest ->
            let expect =
              match rest with
              | [ "="; v ] | [ v ] -> bool_of_string_exn v
              | _ -> bad ()
            in
            Some { a = int_of ~line:l a; b = int_of ~line:l b; expect }
        | _ -> bad ())
  in
  let checks = List.filter_map parse_check check_lines in
  { name; size; ops; checks }

(** [load_cases base_dir] parses every file ending in [.tcase] inside
    [base_dir/tests-uf], visiting file names in sorted order.
    Each result is paired with its case name as key.
    Returns [[]] when that directory does not exist. Exceptions raised by
    [parse_case] propagate unchanged. *)
let load_cases base_dir : (string * case) list =
  let dir = Filename.concat base_dir "tests-uf" in
  if not (Sys.file_exists dir) then []
  else begin
    let files = List.sort compare (Array.to_list (Sys.readdir dir)) in
    let tcases = List.filter (fun f -> Filename.check_suffix f ".tcase") files in
    List.map
      (fun f ->
        let parsed = parse_case (Filename.concat dir f) in
        (parsed.name, parsed))
      tcases
  end

(** Converts a native [int] id or size to the arbitrary-precision [Z.t]
    expected by the [Congbare] API. *)
let as_z = Z.of_int

(** [exec_case c] runs one test case and reports the result.

    It creates a union-find with [Congbare.uf_new] over [c.size] elements
    and applies every op in order. Each check is evaluated by comparing the
    [Congbare.find] results of its two ids with [Z.equal]. A report is
    printed to stdout.

    Returns [true] iff every check's observed result equals its [expect]
    (vacuously [true] when there are no checks).

    Exits the process with status 2, after a message on stderr, if the case
    references an id outside [0, c.size). *)
let exec_case (c:case) : bool =
  let uf = Congbare.uf_new (as_z c.size) in
  let z = as_z in
  (* Smallest and largest id referenced by any op or check; the fold
     starts from the empty range (0, -1). *)
  let min_id, max_id =
    let upd (lo, hi) x = (min lo x, max hi x) in
    let acc =
      List.fold_left (fun acc (Union (a, b)) -> upd (upd acc a) b) (0, -1) c.ops
    in
    List.fold_left (fun acc { a; b; _ } -> upd (upd acc a) b) acc c.checks
  in
  (* Out-of-range ids would otherwise be passed straight to Congbare; treat
     them as a malformed test file. Negative ids were previously unchecked. *)
  if min_id < 0 then (
    eprintf "Error: test '%s' references negative id %d.\n" c.name min_id;
    exit 2
  );
  if max_id >= c.size then (
    eprintf "Error: test '%s' has size %d but references id %d.\n" c.name c.size max_id;
    exit 2
  );
  let apply_op = function Union (a,b) -> Congbare.union uf (z a) (z b) in
  List.iter apply_op c.ops;
  let eq a b = let ah = Congbare.find uf (z a) and bh = Congbare.find uf (z b) in Z.equal ah bh in
  let results = List.map (fun {a;b;expect} -> let got = eq a b in (a,b,expect,got)) c.checks in
  let all_ok = List.for_all (fun (_,_,e,g) -> e = g) results in
  printf "- %s\n  size: %d\n" c.name c.size;
  if c.ops <> [] then (
    printf "  unions:\n";
    List.iter (fun (Union(a,b)) -> printf "    %d ~ %d\n" a b) c.ops
  );
  List.iter (fun (a,b,e,g) -> printf "  check: %d = %d  expect: %-5s  got: %-5s%s\n"
    a b (string_of_bool e) (string_of_bool g) (if e = g then "" else "  <-- mismatch")) results;
  printf "  status: %s\n\n" (if all_ok then "PASS" else "FAIL");
  all_ok

(** [run_all base_dir] runs every case returned by [load_cases], in order,
    printing each case's report. It then prints a one-line pass/fail
    summary. The summary does not affect the exit status. *)
let run_all base_dir =
  let passed = ref 0 and total = ref 0 in
  load_cases base_dir
  |> List.iter (fun (_, c) ->
         incr total;
         if exec_case c then incr passed);
  printf "Summary: %d passed, %d failed, %d total\n" !passed (!total - !passed) !total

(** [run_one base_dir name] loads all cases under [base_dir] and runs the
    first one whose name is [name], printing its report. The case's
    pass/fail result does not affect the exit status.

    If no case has that name, it prints an error and the known case names
    on stderr. It then exits with status 2, the same status the
    command-line entry point uses for a usage error, so scripts can detect
    a mistyped test name. *)
let run_one base_dir name =
  let cases = load_cases base_dir in
  match List.assoc_opt name cases with
  | Some c -> ignore (exec_case c : bool)
  | None ->
      eprintf "Unknown test '%s'\n" name;
      (match cases with
       | [] -> ()
       | _ ->
           eprintf "Available tests: %s\n"
             (String.concat ", " (List.map fst cases)));
      exit 2

(* Command-line entry point: [run_uf], [run_uf all], or [run_uf NAME].
   The project root is taken to be two directory levels above the
   executable path in argv[0] (e.g. ROOT/build/run_uf -> ROOT), and cases
   are read from ROOT/tests-uf by [load_cases].
   NOTE(review): when invoked via PATH, argv[0] is a bare name and ROOT
   resolves to "." — confirm the runner is always launched by path. *)
let () =
  let base_dir = Filename.dirname (Filename.dirname Sys.argv.(0)) in
  match List.tl (Array.to_list Sys.argv) with
  | [] | [ "all" ] -> run_all base_dir
  | [ name ] -> run_one base_dir name
  | _ ->
      eprintf "Usage: run_uf [all|name]\n";
      exit 2
