/*********************************************************************
 * Software License Agreement (Modified BSD License)
 *
 * Copyright (c) 2009-2010, Willow Garage, Daniel Munoz
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the copyright holders' organizations nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *********************************************************************/

#include <m3n/m3n_model.h>

using namespace std;

// --------------------------------------------------------------
/* See function declaration */
// --------------------------------------------------------------
int M3NModel::train(const vector<const RandomField*>& training_rfs,
                    const unsigned int nbr_iterations,
                    const double step_size)
{
  unsigned int saving_interval = 0; // indicates don't save until completed
  string empty_string(""); // will be ignored
  return train(training_rfs, nbr_iterations, step_size, saving_interval, empty_string, empty_string);
}

// --------------------------------------------------------------
/* See function declaration */
// --------------------------------------------------------------
int M3NModel::train(const vector<const RandomField*>& training_rfs,
                    const unsigned int nbr_iterations,
                    const double step_size,
                    unsigned int saving_interval,
                    std::string directory,
                    std::string basename)
{
  // Runs nbr_iterations subgradient updates over training_rfs.
  // Returns 0 on success and -1 if the training data fails verification.
  // A failure to write an intermediate model aborts the process.

  // Gather labels and feature dimensions, rejecting inconsistent data
  if (extractVerifyLabelsFeatures(training_rfs) < 0)
  {
    return -1;
  }

  // Must come after extractVerifyLabelsFeatures(), which provides the
  // dimension information used here
  initStackedFeatureIndices();

  // With more than one random field, inference is parallelized over the
  // random fields instead
  if (training_rfs.size() > 1)
  {
    m_nbr_threads_cache = 1;
  }

  // Training uses loss-augmented inference (currently only hamming loss)
  m_loss_augmented_inference = true;

  M3NLogger logger;
  const bool periodic_saving = (saving_interval > 0);
  for (unsigned int iter = 0 ; iter < nbr_iterations ; iter++)
  {
    cout << "-------- Starting iteration " << iter << " --------" << endl;

    // Random labeling is requested only on the first iteration of a model
    // that has not been trained before
    bool infer_random = (iter == 0) && !m_trained;
    // Step size shrinks with the square root of the 1-based iteration count
    double decayed_step_size = step_size / sqrt(static_cast<double> (iter + 1));
    doSubgradientUpdate(training_rfs, infer_random, decayed_step_size, logger);

    // Intermediate models are written every saving_interval iterations,
    // never at iteration 0 and never when the interval is 0
    if (!periodic_saving || iter == 0 || (iter % saving_interval) != 0)
    {
      continue;
    }

    stringstream basename_suffix;
    basename_suffix << "_iteration" << iter;
    if (saveToFile(directory, basename, basename_suffix.str()) < 0)
    {
      cerr << "Could not save intermediate model to: " << directory << endl;
      abort();
    }
  } // end training iterations

  // Leave training mode
  m_loss_augmented_inference = false;
  m_nbr_threads_cache = omp_get_num_procs();
  m_trained = true;
  return 0;
}

// --------------------------------------------------------------
/* See function declaration */
// --------------------------------------------------------------
void M3NModel::structLossResiduals(const vector<const RandomField*>& training_rfs,
                                   const bool infer_random,
                                   RegressorWrapper& regressor,
                                   M3NLogger& logger)
{
  double total_inference_time = 0.0;

  // ---------------------------------------------------
  // Iterate over each RandomField
  unsigned int nbr_training_rfs = training_rfs.size();
#pragma omp parallel for
  for (unsigned int i = 0 ; i < nbr_training_rfs ; i++)
  {
    const RandomField* curr_rf = training_rfs[i];
    const map<unsigned int, RandomField::Node*>& nodes = curr_rf->getNodes();
    const vector<map<unsigned int, RandomField::Clique*> >& clique_sets = curr_rf->getCliqueSets();

    // ---------------------------------------------------
    // Perform inference with the current model
    map<unsigned int, unsigned int> curr_inferred_labeling;
    time_t start_timer, end_timer;
    time(&start_timer);
    if (infer_random) // in first iteration of learning
    {
      unsigned int nbr_labels = m_training_labels.size();
      for (map<unsigned int, RandomField::Node*>::const_iterator iter_nodes = nodes.begin() ; iter_nodes
          != nodes.end() ; iter_nodes++)
      {
        unsigned int random_label = m_training_labels[rand() % nbr_labels];
        curr_inferred_labeling[iter_nodes->first] = random_label;
      }
    }
    else
    {
      doInfer(*curr_rf, curr_inferred_labeling);
    }
    time(&end_timer);

#pragma omp critical
{
    total_inference_time += difftime(end_timer, start_timer);
    logger.addErrorRate(nodes, curr_inferred_labeling, m_training_labels);

    // ---------------------------------------------------
    // Create training set for the new regressor from node and clique features.
    // When classification is wrong, do +1/-1 with features with ground truth/inferred label
    // ------------------------
    // Node features
    for (map<unsigned int, RandomField::Node*>::const_iterator iter_nodes = nodes.begin() ; iter_nodes
        != nodes.end() ; iter_nodes++)
    {
      unsigned int curr_node_id = iter_nodes->first;
      unsigned int curr_node_gt_label = iter_nodes->second->getLabel();
      unsigned int curr_node_infer_label = curr_inferred_labeling[curr_node_id];
      if (curr_node_gt_label != curr_node_infer_label)
      {
        // +1 features with ground truth label
        if (regressor.addTrainingSample(iter_nodes->second->getFeatureVals(),
                                        m_node_stacked_feature_start_idx[curr_node_gt_label], 1.0)
            < 0)
        {
          abort();
        }

        // -1 features with wrong inferred label
        if (regressor.addTrainingSample(iter_nodes->second->getFeatureVals(),
                                        m_node_stacked_feature_start_idx[curr_node_infer_label],
                                        -1.0) < 0)
        {
          abort();
        }
      }
    }

    // ------------------------
    // Iterate over clique sets to get clique features
    for (unsigned int clique_set_idx = 0 ; clique_set_idx < clique_sets.size() ; clique_set_idx++)
    {
      // ------------------------
      // Iterate over cliques
      const map<unsigned int, RandomField::Clique*>& curr_cliques = clique_sets[clique_set_idx];
      for (map<unsigned int, RandomField::Clique*>::const_iterator iter_cliques =
          curr_cliques.begin() ; iter_cliques != curr_cliques.end() ; iter_cliques++)
      {
        if (iter_cliques->second->getOrder() == 1)
          continue;

        unsigned int curr_clique_gt_mode1_label = 0;
        unsigned int curr_clique_gt_mode1_count = 0;
        unsigned int curr_clique_gt_mode2_label = 0; // unused
        unsigned int curr_clique_gt_mode2_count = 0; // unused
        if (iter_cliques->second->getModeLabels(curr_clique_gt_mode1_label,
                                                curr_clique_gt_mode1_count,
                                                curr_clique_gt_mode2_label,
                                                curr_clique_gt_mode2_count) < 0)
        {
          abort();
        }

        unsigned int curr_clique_infer_mode1_label = 0;
        unsigned int curr_clique_infer_mode1_count = 0;
        unsigned int curr_clique_infer_mode2_label = 0; // unused
        unsigned int curr_clique_infer_mode2_count = 0; // unused
        if (iter_cliques->second->getModeLabels(curr_clique_infer_mode1_label,
                                                curr_clique_infer_mode1_count,
                                                curr_clique_infer_mode2_label,
                                                curr_clique_infer_mode2_count, NULL,
                                                &curr_inferred_labeling) < 0)
        {
          abort();
        }

        // sanity check
        if (curr_clique_gt_mode1_label == RandomField::UNKNOWN_LABEL
            || curr_clique_infer_mode1_label == RandomField::UNKNOWN_LABEL)
        {
          cerr << "Error, mode label is UNKNOWN" << endl;
          abort();
        }

        // ------------------------
        // Compute functional gradient for current clique
        // Note that addTrainingSample will update targets from previous calls
        // ------------------------
        // Compute functional gradient residual from ground truth label (gt_residual)
        double gt_residual = calcFuncGradResidual(m_robust_potts_params[clique_set_idx],
                                                  iter_cliques->second->getOrder(),
                                                  curr_clique_gt_mode1_count);
        // +1 features with ground truth label
        if (regressor.addTrainingSample(
                                        iter_cliques->second->getFeatureVals(),
                                        m_clique_set_stacked_feature_start_idx[clique_set_idx][curr_clique_gt_mode1_label],
                                        gt_residual) < 0)
        {
          abort();
        }

        // ------------------------
        // Compute functional gradient residual from inferred label (gt_infer)
        double infer_residual = calcFuncGradResidual(m_robust_potts_params[clique_set_idx],
                                                     iter_cliques->second->getOrder(),
                                                     curr_clique_infer_mode1_count);
        // -1 features with wrong inferred label
        if (regressor.addTrainingSample(
                                        iter_cliques->second->getFeatureVals(),
                                        m_clique_set_stacked_feature_start_idx[clique_set_idx][curr_clique_infer_mode1_label],
                                        -infer_residual) < 0)
        {
          abort();
        }
      } // end iterate over cliques
    } // end iterate over clique sets
}
  } // end iterate over random fields
  logger.addTimingInference(total_inference_time);
}

// --------------------------------------------------------------
/*
 * See function declaration
 * Invariant: truncation_params are valid
 */
// --------------------------------------------------------------
double M3NModel::calcFuncGradResidual(const double truncation_param,
                                      const unsigned int clique_order,
                                      const unsigned int nbr_mode_label) const
{
  // Without a positive truncation parameter the Potts model is used:
  // the residual is 1 when nbr_mode_label equals clique_order and 0 otherwise
  if (!(truncation_param > 0.0))
  {
    return (clique_order == nbr_mode_label) ? 1.0 : 0.0;
  }

  // Robust Potts: the residual is 1 when nbr_mode_label equals clique_order
  // and falls linearly as nbr_mode_label decreases, reaching 0 once the
  // difference is Q = truncation_param * clique_order
  const double mode_count = static_cast<double> (nbr_mode_label);
  const double order = static_cast<double> (clique_order);
  const double Q = truncation_param * order;
  const double residual = ((mode_count - order) / Q) + 1.0;

  // A residual above 1 (beyond rounding slack) means the mode count exceeds
  // the clique order, which is treated as a fatal inconsistency
  if (residual > 1.000001)
  {
    abort();
  }

  // Clamp at zero from below
  return (residual < 0.0) ? 0.0 : residual;
}

// --------------------------------------------------------------
/* See function declaration */
// --------------------------------------------------------------
int M3NModel::extractVerifyLabelsFeatures(const vector<const RandomField*>& training_rfs)
{
  // Checks that training_rfs can be trained on and records what is learned from it.
  //
  // If the model is not yet trained, the node/clique feature dimensions are taken
  // from the first random field and the training labels are the union of the
  // labels of all random fields. If the model is already trained, the random
  // fields must match the stored dimensions and may not introduce new labels.
  //
  // Returns 0 on success and -1 on failure. When an untrained model fails a
  // check inside or after the per-field loop, the dimensions and labels are reset.

  // ---------------------------------------------------
  // Verify there is something to train on and that no pointer is NULL.
  // This happens before any member is modified, so a NULL entry cannot leave
  // dimensions taken from an earlier random field behind.
  const unsigned int nbr_rfs = training_rfs.size();
  if (nbr_rfs == 0)
  {
    cerr << "M3NModel::extractVerifyLabelsFeatures() nothing to train on" << endl;
    return -1;
  }
  for (unsigned int i = 0 ; i < nbr_rfs ; i++)
  {
    if (training_rfs[i] == NULL)
    {
      cerr << "M3NModel::extractVerifyLabelsFeatures() NULL random field " << i << endl;
      return -1;
    }
  }

  // ---------------------------------------------------
  // Store labels in set so easy to check for duplicates
  set<unsigned int> training_labels_set;
  if (m_trained)
  {
    training_labels_set.insert(m_training_labels.begin(), m_training_labels.end());
  }

  // ---------------------------------------------------
  // Check each random field; stop after the first one that fails
  int ret_val = 0;
  for (unsigned int i = 0 ; ret_val == 0 && i < nbr_rfs ; i++)
  {
    const RandomField& curr_rf = *(training_rfs[i]);

    // --------------------------------
    // Ensure fully labeled random field
    if (curr_rf.isFullyLabeled() == false)
    {
      cerr << "M3NModel::extractVerifyLabelsFeatures() partially labeled random field" << endl;
      ret_val = -1;
    }

    // --------------------------------
    // Ensure consistent number of clique sets for training
    if (curr_rf.getNumberOfCliqueSets() != m_robust_potts_params.size())
    {
      cerr << "M3NModel::extractVerifyLabelsFeatures() inconsistent clique sets" << endl;
      ret_val = -1;
    }

    // --------------------------------
    // If NOT trained, use dimensions from first random field
    if (!m_trained && i == 0)
    {
      m_node_feature_dim = curr_rf.getNodeDim();
      m_clique_set_feature_dims = curr_rf.getCliqueSetDims();
    }
    // Otherwise, verify the node & clique dimensions are consistent
    if (m_trained || i > 0)
    {
      if (m_node_feature_dim != curr_rf.getNodeDim())
      {
        cerr << "M3NModel::extractVerifyLabelsFeatures() inconsistent node dim" << endl;
        ret_val = -1;
      }
      // The three-iterator std::equal() reads as many elements from the second
      // range as the first range holds, so the sizes are compared first to
      // avoid reading past the end of a shorter dimension list
      if (m_clique_set_feature_dims.size() != curr_rf.getCliqueSetDims().size()
          || equal(m_clique_set_feature_dims.begin(), m_clique_set_feature_dims.end(),
                   curr_rf.getCliqueSetDims().begin()) == false)
      {
        cerr << "M3NModel::extractVerifyLabelsFeatures() inconsistent clique dim" << endl;
        ret_val = -1;
      }
    }

    // --------------------------------
    // If trained: verify RF's labels are contained in the training labels
    // otherwise: add to the training labels (b/c all labels may not be present in 1 random field)
    const set<unsigned int>& curr_labels = curr_rf.getLabels();
    if (m_trained)
    {
      if (includes(training_labels_set.begin(), training_labels_set.end(), curr_labels.begin(),
                   curr_labels.end()) == false)
      {
        cerr << "M3NModel::extractVerifyLabelsFeatures() introduced new labels" << endl;
        ret_val = -1;
      }
    }
    else
    {
      training_labels_set.insert(curr_labels.begin(), curr_labels.end());
    }
  }

  // ---------------------------------------------------
  // Verify the feature dimensions are non-zero
  if (m_node_feature_dim == 0 || (count(m_clique_set_feature_dims.begin(),
                                        m_clique_set_feature_dims.end(), 0) != 0))
  {
    cerr << "M3NModel::extractVerifyLabelsFeatures() 0 feature dimension" << endl;
    ret_val = -1;
  }

  // ---------------------------------------------------
  // Assign/reset values upon success/failure
  // (a trained model keeps its existing dimensions and labels either way)
  if (!m_trained)
  {
    if (ret_val == 0)
    {
      m_training_labels.assign(training_labels_set.begin(), training_labels_set.end());
    }
    else
    {
      m_node_feature_dim = 0;
      m_clique_set_feature_dims.clear();
      m_training_labels.clear();
    }
  }
  sort(m_training_labels.begin(), m_training_labels.end());
  return ret_val;
}
