#include <iostream>
#include <list>
#include <map>
#include <string>
#include <vector>

#include <m3n/random_field.h>
#include <m3n/functional_m3n.h>
#include <m3n/regressors/regressor_includes.h>

using namespace std;

/**
 * Example program for the m3n library.
 *
 * Steps performed, in order:
 *   1. Build a small in-memory RandomField with four nodes, two edges and two
 *      high-order cliques (illustration only; it is not used afterwards).
 *   2. Load a random field from "random_field/train_rf".
 *   3. Train a FunctionalM3N with a linear regressor on that random field.
 *   4. Save the model, load it back from disk, and save the loaded copy again.
 *   5. Run inference with the loaded model on the training random field and
 *      print the predicted label of the first node.
 *
 * @return 0 on success, -1 as soon as any step reports a failure
 */
int main()
{
  // ----------------------------------------------------------
  // Example how to create a random field with edges and high-order cliques
  unsigned int nbr_clique_sets = 2; // edges, high-order cliques
  vector<double> empty_features(5); // dummy 5 dimension feature vector
  RandomField tmp_rf(nbr_clique_sets);
  // Create Nodes (ids or labels do not need to be sequential)
  unsigned int node1_id = 5;
  unsigned int node1_label = 9;
  const RandomField::Node* node1 = tmp_rf.createNode(node1_id, empty_features, node1_label);
  unsigned int node2_id = 10;
  unsigned int node2_label = 50;
  const RandomField::Node* node2 = tmp_rf.createNode(node2_id, empty_features, node2_label);
  unsigned int node3_id = 500;
  unsigned int node3_label = 50;
  const RandomField::Node* node3 = tmp_rf.createNode(node3_id, empty_features, node3_label);
  unsigned int node4_id = 901;
  unsigned int node4_label = 9;
  const RandomField::Node* node4 = tmp_rf.createNode(node4_id, empty_features, node4_label);
  // Create edges (in the first clique set (cs))
  unsigned int edges_cs_idx = 0;
  list<const RandomField::Node*> edge1_nodes;
  edge1_nodes.push_back(node1);
  edge1_nodes.push_back(node2);
  tmp_rf.createClique(edges_cs_idx, edge1_nodes, empty_features);
  list<const RandomField::Node*> edge2_nodes;
  edge2_nodes.push_back(node2);
  edge2_nodes.push_back(node4);
  tmp_rf.createClique(edges_cs_idx, edge2_nodes, empty_features);
  // Create high-order cliques (in the second clique set)
  unsigned int cliques_cs_idx = 1;
  list<const RandomField::Node*> clique1_nodes;
  clique1_nodes.push_back(node1);
  clique1_nodes.push_back(node2);
  clique1_nodes.push_back(node3);
  tmp_rf.createClique(cliques_cs_idx, clique1_nodes, empty_features);
  list<const RandomField::Node*> clique2_nodes;
  clique2_nodes.push_back(node1);
  clique2_nodes.push_back(node2);
  clique2_nodes.push_back(node3);
  clique2_nodes.push_back(node4);
  tmp_rf.createClique(cliques_cs_idx, clique2_nodes, empty_features);

  // ----------------------------------------------------------
  // Load a real random field (point cloud example) to train on
  RandomField training_rf;
  if (training_rf.loadRandomFieldWGASCII("random_field/train_rf") < 0)
  {
    cerr << "Failed to load random field" << endl;
    return -1;
  }
  // The trainer takes a collection of random fields; here there is only one
  vector<const RandomField*> training_rfs(1, &training_rf);

  // ----------------------------------------------
  // Define learning parameters
  // NOTE(review): the iteration count is held in a double; confirm this
  // matches the parameter type that FunctionalM3N::train() expects.
  double nbr_iters = 7;
  double step_size = 0.3;
  double truncation_param = -1.0; // indicates use of the Potts model
  // One truncation parameter per clique set of the training random field
  vector<double> robust_potts_params(training_rf.getNumberOfCliqueSets(), truncation_param);
  // --- Linear Model ---
  LinearRegression lin_reg(0.0001);
  FunctionalM3N m3n_model(robust_potts_params, lin_reg);
  // --- Non-linear Model ---
  //OCVRTreeWrapper rtree; // default training parameters
  //FunctionalM3N m3n_model(robust_potts_params, rtree);

  // ----------------------------------------------------------
  // Train the M3N
  cout << "Starting to train..." << endl;
  if (m3n_model.train(training_rfs, nbr_iters, step_size) < 0)
  {
    cerr << "Failed to train M3N" << endl;
    return -1;
  }
  cout << "Successfully trained M3N" << endl;

  // ----------------------------------------------------------
  // Verify save and load work correctly
  string dir1("learned_m3n_model");
  string dir2("resaved_m3n_model");
  string model_name("my_m3n_model");
  if (m3n_model.saveToFile(dir1, model_name) < 0)
  {
    cerr << "Failed to save M3N" << endl;
    return -1;
  }
  // Construct a second model from the files that were just written, then
  // write it out again under a different directory.
  FunctionalM3N loaded_m3n(dir1, model_name);
  if (loaded_m3n.saveToFile(dir2, model_name) < 0)
  {
    // The failing call here is the second save, so report it as such
    cerr << "Failed to re-save loaded M3N" << endl;
    return -1;
  }

  // ----------------------------------------------------------
  // Perform inference
  map<unsigned int, unsigned int> inferred_labels;
  if (loaded_m3n.infer(training_rf, inferred_labels) < 0)
  {
    cerr << "Failed to perform inference" << endl;
    return -1;
  }
  const map<unsigned int, RandomField::Node*>& nodes = training_rf.getNodes();
  // Dereferencing begin() of an empty map is undefined behaviour
  if (nodes.empty())
  {
    cerr << "Random field contains no nodes" << endl;
    return -1;
  }
  const unsigned int first_node_id = nodes.begin()->second->getID();
  // Look the label up with find(): operator[] would silently insert and
  // print a default label of 0 if no entry exists for this node.
  map<unsigned int, unsigned int>::const_iterator label_it = inferred_labels.find(first_node_id);
  if (label_it == inferred_labels.end())
  {
    cerr << "No label was inferred for node with id: " << first_node_id << endl;
    return -1;
  }
  cout << "Node with id: " << first_node_id << " has predicted label: "
      << label_it->second << endl;

  return 0;
}
