/*********************************************************************
 * Software License Agreement (Modified BSD License)
 *
 * Copyright (c) 2009-2010, Willow Garage, Daniel Munoz
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the copyright holders' organizations nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *********************************************************************/

#include <m3n/m3n_model.h>

using namespace std;
using namespace submodular_graphcut;

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
int M3NModel::infer(const RandomField& random_field,
                    map<unsigned int, unsigned int>& inferred_labels,
                    unsigned int max_iterations,
                    unsigned int max_nbr_regressors) const
{
  if (!m_trained)
  {
    cerr << "M3NModel::infer() Cant infer on untrained model" << endl;
    return -1;
  }
  if (random_field.getNumberOfCliqueSets() != m_robust_potts_params.size())
  {
    cerr << "M3NModel::infer() Inconsistent clique sets" << endl;
    return -1;
  }
  if (random_field.getNodeDim() != m_node_feature_dim)
  {
    cerr << "M3NModel::infer() Inconsistent node feature dimensions" << endl;
    return -1;
  }
  if (!equal(random_field.getCliqueSetDims().begin(), random_field.getCliqueSetDims().end(),
             m_clique_set_feature_dims.begin()))
  {
    cerr << "M3NModel::infer() Inconsistent clique feature dimensions" << endl;
    return -1;
  }
  return doInfer(random_field, inferred_labels, max_iterations, max_nbr_regressors);
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
int M3NModel::cachePotentials(const RandomField& random_field,
                              const unsigned int max_nbr_regressors,
                              map<unsigned int, map<unsigned int, double> >& cache_node_potentials,
                              vector<map<unsigned int, map<unsigned int, double> > >& cache_clique_set_potentials) const
{
  int ret_val = 0;

  // -------------------------------------------
  // Clear out old values
  cache_node_potentials.clear();
  cache_clique_set_potentials.clear();
  cache_clique_set_potentials.resize(m_clique_set_feature_dims.size());

  // Fill map: [label -> value]
  unsigned int nbr_training_labels = m_training_labels.size();
  map<unsigned int, double> label_place_holders;
  for (unsigned int i = 0 ; i < nbr_training_labels ; i++)
  {
    label_place_holders[m_training_labels[i]] = 0.0;
  }

  // Retrieve random field info
  cout << "Starting to cache all potential configurations..." << endl;
  time_t start_timer, end_timer;
  time(&start_timer);
  const map<unsigned int, RandomField::Node*>& nodes = random_field.getNodes();
  const vector<map<unsigned int, RandomField::Clique*> >& clique_sets =
      random_field.getCliqueSets();

  // node->[label->value]
  for (map<unsigned int, RandomField::Node*>::const_iterator iter_nodes = nodes.begin() ; iter_nodes
      != nodes.end() ; iter_nodes++)
  {
    cache_node_potentials[iter_nodes->first] = label_place_holders;
  }

  // clique_set_idx->clique->[label->value]
  const unsigned int nbr_clique_sets = clique_sets.size();
  for (unsigned int cs_idx = 0 ; cs_idx < nbr_clique_sets ; cs_idx++)
  {
    const map<unsigned int, RandomField::Clique*>& curr_clique_set = clique_sets[cs_idx];
    for (map<unsigned int, RandomField::Clique*>::const_iterator iter_cliques =
        curr_clique_set.begin() ; iter_cliques != curr_clique_set.end() ; iter_cliques++)
    {
      cache_clique_set_potentials[cs_idx][iter_cliques->first] = label_place_holders;
    }
  }

  // Compute scores for each label
#pragma omp parallel for num_threads(m_nbr_threads_cache)
  for (unsigned int i = 0 ; i < nbr_training_labels ; i++)
  {
    const unsigned int curr_label = m_training_labels[i];

    // -------------------------------------------
    // Populate node scores
    for (map<unsigned int, RandomField::Node*>::const_iterator iter_nodes = nodes.begin() ; iter_nodes
        != nodes.end() ; iter_nodes++)
    {
      // Retrieve current node info
      const unsigned int curr_node_id = iter_nodes->first;
      const RandomField::Node* curr_node = iter_nodes->second;

      double potential_value = 0.0;
      if (computePotential(*curr_node, curr_label, max_nbr_regressors, potential_value) < 0)
      {
        ret_val = -1;
      }

      cache_node_potentials.find(curr_node_id)->second.find(curr_label)->second = potential_value;
      //cache_node_potentials_[curr_node_id][curr_label] = potential_value;
    }

    if (ret_val == 0)
    {
      // -------------------------------------------
      // Populate clique scores
      for (unsigned int cs_idx = 0 ; cs_idx < nbr_clique_sets ; cs_idx++)
      {
        // Iterate over the cliques in each clique set
        const map<unsigned int, RandomField::Clique*>& curr_clique_set = clique_sets[cs_idx];
        for (map<unsigned int, RandomField::Clique*>::const_iterator iter_cliques =
            curr_clique_set.begin() ; iter_cliques != curr_clique_set.end() ; iter_cliques++)
        {
          // Retrieve current clique info
          const unsigned int curr_clique_id = iter_cliques->first;
          const RandomField::Clique* curr_clique = iter_cliques->second;

          double potential_value = 0.0;
          if (computePotential(*curr_clique, cs_idx, curr_label, max_nbr_regressors,
                               potential_value) < 0)
          {
            ret_val = -1;
          }

          cache_clique_set_potentials[cs_idx].find(curr_clique_id)->second.find(curr_label)->second
              = potential_value;
          //cache_clique_set_potentials_[cs_idx][curr_clique_id][curr_label] = potential_value;
        }
      }
    }
  }
  time(&end_timer);

  if (ret_val == 0)
  {
    cout << "Successfully cached potentials in " << difftime(end_timer, start_timer) << " seconds"
        << endl;
  }
  else
  {
    cout << "Failed to cache all potentials" << endl;
  }
  return ret_val;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
void M3NModel::generateInitialLabeling(const RandomField& random_field,
                                       map<unsigned int, unsigned int>& inferred_labels,
                                       const map<unsigned int, map<unsigned int, double> >& cache_node_potentials) const
{
  inferred_labels.clear();

  const unsigned int nbr_labels = m_training_labels.size();

  const map<unsigned int, RandomField::Node*>& nodes = random_field.getNodes();
  for (map<unsigned int, RandomField::Node*>::const_iterator iter_nodes = nodes.begin() ; iter_nodes
      != nodes.end() ; iter_nodes++)
  {
    unsigned int curr_node_id = iter_nodes->first;

    // ---------------
    // Use random labeling
    inferred_labels[curr_node_id] = m_training_labels.at(rand() % nbr_labels);
    // ---------------

    // ---------------
    // Initialize to best node label
    /*
     unsigned int curr_best_label = m_training_labels[0];
     double curr_best_score =
     cache_node_potentials.find(curr_node_id)->second.find(curr_best_label)->second;
     for (unsigned int i = 1 ; i < nbr_labels ; i++)
     {
     unsigned int next_label = m_training_labels[i];
     double next_score = cache_node_potentials.find(curr_node_id)->second.find(next_label)->second;
     if (next_score > curr_best_score)
     {
     curr_best_label = next_label;
     curr_best_score = next_score;
     }
     }
     inferred_labels[curr_node_id] = curr_best_label;
     */
    // ---------------
  }
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
int M3NModel::doInfer(const RandomField& random_field,
                      map<unsigned int, unsigned int>& inferred_labels,
                      unsigned int max_iterations,
                      unsigned int max_nbr_regressors) const
{
  time_t start_timer, end_timer;
  time(&start_timer);

  // WARNING: if you change this value, you must change the calls in addNodeEnergy, addCliqueEnergy
  const unsigned int ALPHA_VALUE = 0;

  // -------------------------------------------
  // Retrieve random field information
  const map<unsigned int, RandomField::Node*>& nodes = random_field.getNodes();
  const vector<map<unsigned int, RandomField::Clique*> >& clique_sets =
      random_field.getCliqueSets();
  const unsigned int nbr_clique_sets = clique_sets.size();
  // Estimate the max number of nodes and edges in the graph for alpha-expansion
  unsigned int est_nbr_energy_nodes = nodes.size() + (clique_sets.size() * 3);
  unsigned int est_nbr_energy_edges = est_nbr_energy_nodes * 5;

  // -------------------------------------------
  // Cache potentials
  map<unsigned int, map<unsigned int, double> > cache_node_potentials;
  vector<map<unsigned int, map<unsigned int, double> > > cache_clique_set_potentials;
  if (cachePotentials(random_field, max_nbr_regressors, cache_node_potentials,
                      cache_clique_set_potentials) < 0)
  {
    return -1;
  }

  // -------------------------------------------
  // Setup label information.
  // Generate random initializing labeling if passed empty labeling
  inferred_labels.clear();
  generateInitialLabeling(random_field, inferred_labels, cache_node_potentials);
  const unsigned int nbr_labels = m_training_labels.size();

  // -------------------------------------------
  // Loop until hit max number of iterations or converged on minimum energy
  if (max_iterations == 0)
  {
    max_iterations = numeric_limits<unsigned int>::max();
  }
  double prev_energy = 0.0;
  int ret_val = 0;
  bool converged = false;
  unsigned int t = 0;
  for (t = 0; ret_val == 0 && !converged && t < max_iterations ; t++)
  {
    converged = true; // will be set false if energy changes
    // -------------------------------------------
    // Alpha-expand over each label
    for (unsigned int label_idx = 0 ; ret_val == 0 && label_idx < nbr_labels ; label_idx++)
    {
      unsigned int alpha_label = m_training_labels[label_idx];

      // -------------------------------------------
      // Create new energy function
      //GraphcutVK3 energy_func(est_nbr_energy_nodes, est_nbr_energy_edges);
      GraphcutBGL energy_func; // slower
      //GraphcutVK2 energy_func; // very very slow with high order cliques
      map<unsigned int, SubmodularGraphcut::EnergyVar> energy_vars;

      // -------------------------------------------
      // Create energy variables and Compute node scores
      for (map<unsigned int, RandomField::Node*>::const_iterator iter_nodes = nodes.begin() ; ret_val
          == 0 && iter_nodes != nodes.end() ; iter_nodes++)
      {
        unsigned int curr_node_id = iter_nodes->first;
        const RandomField::Node* curr_node = iter_nodes->second;
        if (curr_node->getID() != curr_node_id)
        {
          abort();
        }
        if (inferred_labels.count(curr_node_id) == 0)
        {
          abort();
        }

        // Add new energy variable
        energy_vars[curr_node_id] = energy_func.addVariable();

        ret_val = addNodeEnergy(*curr_node, energy_func, energy_vars[curr_node_id],
                                inferred_labels[curr_node_id], alpha_label, cache_node_potentials);
      }

      // -------------------------------------------
      // Iterate over clique sets to compute cliques' scores
      for (unsigned int cs_idx = 0 ; ret_val == 0 && cs_idx < nbr_clique_sets ; cs_idx++)
      {
        // -------------------------------------------
        // Iterate over clique scores
        const map<unsigned int, RandomField::Clique*>& curr_clique_set = clique_sets[cs_idx];
        for (map<unsigned int, RandomField::Clique*>::const_iterator iter_cliques =
            curr_clique_set.begin() ; ret_val == 0 && iter_cliques != curr_clique_set.end() ; iter_cliques++)
        {
          const RandomField::Clique* curr_clique = iter_cliques->second;
          unsigned int curr_clique_order = curr_clique->getOrder();
          if (curr_clique_order < 2)
          {
            // Ignore cliques with order 0 or 1
            continue;
          }
          // add edge potential (assuming all edges are associative)
          else if (curr_clique_order == 2)
          {
            ret_val = addEdgeEnergy(*curr_clique, energy_func, energy_vars, inferred_labels,
                                    alpha_label, cache_clique_set_potentials[cs_idx]);
          }
          // add high-order clique potential
          else
          {
            // add Robust Pn Potts
            if (m_robust_potts_params[cs_idx] > 1e-9)
            {
              ret_val = addCliqueEnergyRobustPotts(*curr_clique, cs_idx, energy_func, energy_vars,
                                                   inferred_labels, alpha_label,
                                                   cache_clique_set_potentials[cs_idx]);
            }
            // add Pn Potts
            else
            {
              ret_val = addCliqueEnergyPotts(*curr_clique, energy_func, energy_vars,
                                             inferred_labels, alpha_label,
                                             cache_clique_set_potentials[cs_idx]);
            }
          }
        }
      }

      // -------------------------------------------
      // Minimize function if there are no errors
      if (ret_val == 0)
      {
        double curr_energy = energy_func.minimize();

        // Update labeling if: first iteration OR energy decreases with expansion move
        if ((t == 0 && label_idx == 0) || ((curr_energy + 1e-6) < prev_energy))
        {
          // Change respective labels to the alpha label
          for (map<unsigned int, SubmodularGraphcut::EnergyVar>::iterator iter_energy_vars =
              energy_vars.begin() ; iter_energy_vars != energy_vars.end() ; iter_energy_vars++)
          {
            if (energy_func.getValue(iter_energy_vars->second) == ALPHA_VALUE)
            {
              inferred_labels[iter_energy_vars->first] = alpha_label;
            }
          }

          // Made an alpha expansion, so did not reach convergence yet
          prev_energy = curr_energy;
          converged = false;
        }
      }
    }
  }
  time(&end_timer);

  if (ret_val == 0)
  {
    cout << "Inference converged to energy " << prev_energy << " after " << t << " iterations in "
        << difftime(end_timer, start_timer) << " seconds" << endl;
  }
  else
  {
    cout << "Inference failed after " << t << " iterations" << endl;
  }
  return ret_val;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
/*
 * Adds the unary term of one node to the binary alpha-expansion energy.
 *
 * Variable value 0 means "switch to alpha_label" and value 1 means "keep
 * curr_label"; this relies on ALPHA_VALUE == 0 in doInfer.  Cached potentials
 * are scores to maximize, so they are negated to obtain energies to minimize.
 * When m_loss_augmented_inference is set, a Hamming loss of 1 is added to
 * every candidate label that differs from the node's ground-truth label.
 *
 * Always returns 0.
 */
int M3NModel::addNodeEnergy(const RandomField::Node& node,
                            SubmodularGraphcut& energy_func,
                            const SubmodularGraphcut::EnergyVar& energy_var,
                            const unsigned int curr_label,
                            const unsigned int alpha_label,
                            const map<unsigned int, map<unsigned int, double> >& cache_node_potentials) const
{
  // Per-label cached scores of this node
  const map<unsigned int, double>& label_scores = cache_node_potentials.find(node.getID())->second;

  double alpha_score = label_scores.find(alpha_label)->second;
  double curr_score = (curr_label == alpha_label) ? alpha_score
                                                  : label_scores.find(curr_label)->second;

  // Hamming-loss margin for loss-augmented inference
  if (m_loss_augmented_inference)
  {
    const unsigned int gt_label = node.getLabel();
    curr_score += (gt_label != curr_label) ? 1.0 : 0.0;
    alpha_score += (gt_label != alpha_label) ? 1.0 : 0.0;
  }

  // Energy for value 0 (alpha) first, then value 1 (keep); max +score == min -score
  energy_func.addUnary(energy_var, -alpha_score, -curr_score);
  return 0;
}

// --------------------------------------------------------------
/* See function definition.  Assumes edge contains only 2 nodes */
// --------------------------------------------------------------
/*
 * Adds the pairwise (Potts-style) term of an order-2 clique to the binary
 * alpha-expansion energy.  Assumes the edge contains exactly two nodes.
 *
 * For each pair of variable values (0 = switch to alpha, 1 = keep current
 * label), the term is the cached edge potential of the common label when both
 * endpoints end up with the same label, and 0 otherwise.  Scores are negated
 * into energies; this relies on ALPHA_VALUE == 0 in doInfer.
 *
 * Returns -1 if the graph-cut backend rejects the term, 0 otherwise.
 */
int M3NModel::addEdgeEnergy(const RandomField::Clique& edge,
                            SubmodularGraphcut& energy_func,
                            const map<unsigned int, SubmodularGraphcut::EnergyVar>& energy_vars,
                            const map<unsigned int, unsigned int>& curr_labeling,
                            const unsigned int alpha_label,
                            const map<unsigned int, map<unsigned int, double> >& cache_cs_potentials) const
{
  // Endpoints of the edge and their current labels
  const list<unsigned int>& endpoints = edge.getNodeIDs();
  const unsigned int first_id = endpoints.front();
  const unsigned int second_id = endpoints.back();
  const unsigned int first_label = curr_labeling.find(first_id)->second;
  const unsigned int second_label = curr_labeling.find(second_id)->second;

  // Per-label cached scores of this edge
  const map<unsigned int, double>& edge_scores = cache_cs_potentials.find(edge.getID())->second;

  // Both endpoints switch to alpha
  const double score_00 = edge_scores.find(alpha_label)->second;
  // First switches, second keeps: agree only if the second already has alpha
  const double score_01 = (second_label == alpha_label) ? score_00 : 0.0;
  // First keeps, second switches: agree only if the first already has alpha
  const double score_10 = (first_label == alpha_label) ? score_00 : 0.0;
  // Both keep their labels: agree only if the labels already match
  double score_11 = 0.0;
  if (first_label == second_label)
  {
    score_11 = (first_label == alpha_label) ? score_00 : edge_scores.find(first_label)->second;
  }

  // max +score == min -score
  return (energy_func.addPairwise(energy_vars.find(first_id)->second,
                                  energy_vars.find(second_id)->second,
                                  -score_00, -score_01, -score_10, -score_11) < 0) ? -1 : 0;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
/*
 * Adds a Pn Potts term for a high-order clique to the binary alpha-expansion
 * energy.
 *
 * Ec0 is the cached clique potential for alpha_label, used when every node
 * switches to alpha.  Ec1 is the potential when every node keeps its current
 * label; it is non-zero only when all nodes currently share one label.  Scores
 * are negated into energies; this relies on ALPHA_VALUE == 0 in doInfer.
 *
 * Returns -1 if the mode labels cannot be computed or the backend rejects the
 * term, 0 otherwise.
 */
int M3NModel::addCliqueEnergyPotts(const RandomField::Clique& clique,
                                   SubmodularGraphcut& energy_func,
                                   const map<unsigned int, SubmodularGraphcut::EnergyVar>& energy_vars,
                                   const map<unsigned int, unsigned int>& curr_labeling,
                                   const unsigned int alpha_label,
                                   const map<unsigned int, map<unsigned int, double> >& cache_cs_potentials) const
{
  // Per-label cached scores of this clique
  const map<unsigned int, double>& clique_scores = cache_cs_potentials.find(clique.getID())->second;

  // Potential if every node switches to alpha
  const double Ec0 = clique_scores.find(alpha_label)->second;

  // Two most frequent current labels in the clique
  unsigned int top_label = 0;
  unsigned int top_count = 0;
  unsigned int runner_up_label = 0;
  unsigned int runner_up_count = 0;
  if (clique.getModeLabels(top_label, top_count, runner_up_label, runner_up_count, NULL,
                           &curr_labeling) < 0)
  {
    return -1;
  }

  // Potential if every node keeps its label: only a uniformly labeled clique scores
  double Ec1 = 0.0;
  const bool uniformly_labeled = (runner_up_label == RandomField::UNKNOWN_LABEL);
  if (uniformly_labeled)
  {
    Ec1 = (top_label == alpha_label) ? Ec0 : clique_scores.find(top_label)->second;
  }

  // Energy variables of the clique's nodes, in clique order
  list<SubmodularGraphcut::EnergyVar> member_vars;
  const list<unsigned int>& member_ids = clique.getNodeIDs();
  for (list<unsigned int>::const_iterator id_it = member_ids.begin() ; id_it != member_ids.end() ; ++id_it)
  {
    member_vars.push_back(energy_vars.find(*id_it)->second);
  }

  // max +score == min -score
  return (energy_func.addPnPotts(member_vars, -Ec0, -Ec1, 0.0) < 0) ? -1 : 0;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
/*
 * Adds a Robust Pn Potts term for a high-order clique to the binary
 * alpha-expansion energy.
 *
 * The truncation parameter is Q = m_robust_potts_params[clique_set_idx] * P,
 * where P is the clique order.  A non-alpha "dominant" label exists when the
 * most frequent current label (mode1) differs from alpha and covers more than
 * P - Q nodes.  In that case the dominant-label variant of the term is added,
 * together with the energy variables of the nodes holding that label.
 * Otherwise the no-dominant variant is added.  Scores are negated into
 * energies; this relies on ALPHA_VALUE == 0 in doInfer.
 *
 * Returns -1 if the mode labels cannot be computed, otherwise the value
 * returned by the graph-cut backend call.
 */
int M3NModel::addCliqueEnergyRobustPotts(const RandomField::Clique& clique,
                                         const unsigned int clique_set_idx,
                                         SubmodularGraphcut& energy_func,
                                         const map<unsigned int, SubmodularGraphcut::EnergyVar>& energy_vars,
                                         const map<unsigned int, unsigned int>& curr_labeling,
                                         const unsigned int alpha_label,
                                         const map<unsigned int, map<unsigned int, double> >& cache_cs_potentials) const
{
  // Two most frequent current labels, plus the ids of the nodes holding the top one
  unsigned int top_label = 0;
  unsigned int top_count = 0;
  unsigned int runner_up_label = 0;
  unsigned int runner_up_count = 0;
  list<unsigned int> dominant_node_ids;
  if (clique.getModeLabels(top_label, top_count, runner_up_label, runner_up_count,
                           &dominant_node_ids, &curr_labeling) < 0)
  {
    return -1;
  }

  const map<unsigned int, double>& clique_scores = cache_cs_potentials.find(clique.getID())->second;

  // Dominant label test: D > P - Q with D = count of the (non-alpha) top label
  const double P = static_cast<double> (clique.getOrder());
  const double Q = m_robust_potts_params[clique_set_idx] * P;
  bool found_dominant_label = false;
  double gamma_dominant = -1.0;
  if (top_label != alpha_label && (static_cast<double> (top_count) - 1e-9) > (P - Q))
  {
    gamma_dominant = clique_scores.find(top_label)->second;
    found_dominant_label = true;
  }

  // Potential if every node switches to alpha
  const double gamma_alpha = clique_scores.find(alpha_label)->second;

  // Collect the variables of all clique nodes; when a dominant label was found,
  // walk the dominant ids in lock-step and collect their variables as well.
  list<unsigned int>::iterator dominant_it = dominant_node_ids.end();
  if (found_dominant_label)
  {
    dominant_it = dominant_node_ids.begin();
  }
  list<SubmodularGraphcut::EnergyVar> member_vars;
  list<SubmodularGraphcut::EnergyVar> dominant_vars;
  const list<unsigned int>& member_ids = clique.getNodeIDs();
  for (list<unsigned int>::const_iterator id_it = member_ids.begin() ; id_it != member_ids.end() ; ++id_it)
  {
    member_vars.push_back(energy_vars.find(*id_it)->second);

    if (dominant_it == dominant_node_ids.end())
    {
      continue;
    }
    dominant_vars.push_back(energy_vars.find(*dominant_it)->second);
    ++dominant_it;
  }

  // max +score == min -score
  if (!found_dominant_label)
  {
    return energy_func.addRobustPottsNoDominantExpand0(member_vars, -gamma_alpha, 0.0, Q);
  }
  return energy_func.addRobustPottsDominantExpand0(member_vars, dominant_vars, -gamma_alpha,
                                                   -gamma_dominant, 0.0, Q);
}
