/*********************************************************************
 * Software License Agreement (Modified BSD License)
 *
 * Copyright (c) 2009-2010, Willow Garage, Daniel Munoz
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the copyright holders' organizations nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *********************************************************************/

#include <m3n/m3n_model.h>

#include <stdexcept>

using namespace std;

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
M3NModel::M3NModel(const std::vector<double>& robust_potts_params)
{
  // --------
  // Verify the truncation parameters are valid:
  //   less than or equal to 0 means use Potts
  //   Robust Potts truncation parameters must be between [0, 0.5]
  for (size_t i = 0 ; i < robust_potts_params.size() ; i++)
  {
    // Written as !(x <= 0.5) rather than (x > 0.5) so that NaN is rejected as
    // well: every ordered comparison involving NaN evaluates to false.
    if (!(robust_potts_params[i] <= 0.5))
    {
      cerr << "M3NModel Invalid Robust Potts truncation parameters cant be bigger than 0.5" << endl;
      // A bare 'throw;' with no exception currently being handled calls
      // std::terminate(); throw a real, catchable exception instead.
      throw std::invalid_argument(
          "M3NModel: Robust Potts truncation parameters cannot be bigger than 0.5");
    }
  }
  m_robust_potts_params.assign(robust_potts_params.begin(), robust_potts_params.end());

  // Reset every other piece of model state (clear() does not touch
  // m_robust_potts_params, so the values assigned above are kept).
  clear();
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
void M3NModel::clear()
{
  // Feature dimensionality and the stacked-feature layout derived from it
  m_node_feature_dim = 0;
  m_total_stack_feature_dim = 0;
  m_clique_set_feature_dims.clear();
  m_node_stacked_feature_start_idx.clear();
  m_clique_set_stacked_feature_start_idx.clear();

  // Training bookkeeping: no labels, not trained, normal (non loss-augmented) inference
  m_training_labels.clear();
  m_loss_augmented_inference = false;
  m_trained = false;

  // Cache the processor count reported by OpenMP.
  // NOTE: m_robust_potts_params is intentionally left untouched here; the
  // constructor assigns it immediately before calling clear().
  m_nbr_threads_cache = omp_get_num_procs();
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
// Lays out one flat "stacked" feature vector: first one node block of
// m_node_feature_dim entries per training label, then, for each clique set i,
// one block of m_clique_set_feature_dims[i] entries per training label.
// Records each block's start offset and the total length.
// Requires m_training_labels, m_node_feature_dim and m_clique_set_feature_dims
// to already be populated.
void M3NModel::initStackedFeatureIndices()
{
  size_t curr_start_idx = 0;

  // Rebuild both index tables from scratch so that offsets for labels that are
  // no longer in m_training_labels cannot survive from a previous call.
  m_node_stacked_feature_start_idx.clear();
  m_clique_set_stacked_feature_start_idx.clear();

  // Populate start index for nodes
  vector<unsigned int>::const_iterator iter_labels;
  for (iter_labels = m_training_labels.begin(); iter_labels != m_training_labels.end() ; ++iter_labels)
  {
    m_node_stacked_feature_start_idx[*iter_labels] = curr_start_idx;
    curr_start_idx += m_node_feature_dim;
  }

  // Populate start index for cliques
  const size_t nbr_clique_sets = m_clique_set_feature_dims.size();
  m_clique_set_stacked_feature_start_idx.resize(nbr_clique_sets);
  for (size_t i = 0 ; i < nbr_clique_sets ; i++)
  {
    for (iter_labels = m_training_labels.begin(); iter_labels != m_training_labels.end() ; ++iter_labels)
    {
      m_clique_set_stacked_feature_start_idx[i][*iter_labels] = curr_start_idx;
      curr_start_idx += m_clique_set_feature_dims[i];
    }
  }

  m_total_stack_feature_dim = curr_start_idx;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
// Loads a model previously written by saveToFile() from
// <directory>/<basename>.m3n_model.
// Reads the common header (labels, robust Potts params, feature dimensions),
// rebuilds the stacked feature indices, then delegates the remainder of the
// stream to the subclass via doLoadFromFile().
// Returns 0 on success (and marks the model trained), -1 if the file cannot be
// opened or its header is malformed, otherwise doLoadFromFile()'s return value.
int M3NModel::loadFromFile(std::string directory,
                           std::string basename)
{
  clear();

  // -------------------------------------------
  // Create filename: <directory>/<basename>.m3n_model
  string in_filename = directory;
  in_filename.append("/");
  in_filename.append(basename);
  in_filename.append(".m3n_model");
  ifstream infile(in_filename.c_str());
  if (infile.is_open() == false)
  {
    cerr << "M3NModel::loadFromFile could not open: " << in_filename << endl;
    return -1;
  }

  // -------------------------------------------
  // File format:
  // L = <training_labels_.size()>
  // training_labels_[0] ... training_labels_[L-1]
  //
  // S = <robust_potts_params_.size()>
  // robust_potts_params_[0] ... robust_potts_params_[S-1]
  //
  // node_feature_dim_
  // clique_set_feature_dims_[0] ... clique_set_feature_dims_[S-1]

  // ------------------------
  // read labels
  // (if a count fails to parse it stays 0, so no bogus allocation happens and
  //  the stream-state check below reports the failure)
  unsigned int nbr_training_labels = 0;
  infile >> nbr_training_labels;
  m_training_labels.assign(nbr_training_labels, 0);
  for (unsigned int i = 0 ; i < nbr_training_labels ; i++)
  {
    infile >> m_training_labels[i];
  }

  // ------------------------
  // read robust potts params
  unsigned int nbr_clique_sets = 0;
  infile >> nbr_clique_sets;
  m_robust_potts_params.assign(nbr_clique_sets, -1.0);
  for (unsigned int i = 0 ; i < nbr_clique_sets ; i++)
  {
    infile >> m_robust_potts_params[i];
  }

  // ------------------------
  // read feature dimensions
  infile >> m_node_feature_dim;
  m_clique_set_feature_dims.assign(nbr_clique_sets, 0);
  for (unsigned int i = 0 ; i < nbr_clique_sets ; i++)
  {
    infile >> m_clique_set_feature_dims[i];
  }

  // ------------------------
  // Any extraction failure above (truncated file, non-numeric token) leaves the
  // stream in a failed state; refuse to build a model from a partial header.
  if (infile.fail())
  {
    cerr << "M3NModel::loadFromFile malformed header in: " << in_filename << endl;
    clear();
    return -1;
  }

  // ------------------------
  // This MUST be called ONLY when the following are defined:
  //   training_labels_
  //   node_feature_dim_
  //   clique_set_feature_dims_
  initStackedFeatureIndices();

  int ret_val = doLoadFromFile(infile, directory, basename);
  infile.close();

  if (ret_val == 0)
  {
    m_trained = true;
  }
  return ret_val;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
// Saves a trained model to <directory>/<basename>.m3n_model by delegating to
// the suffixed overload with an empty suffix.
// Returns -1 (after printing an error) when the model has not been trained,
// otherwise whatever the three-argument saveToFile() returns.
int M3NModel::saveToFile(std::string directory,
                         std::string basename)
{
  if (m_trained)
  {
    // No suffix: the output name is exactly <basename>.m3n_model
    return saveToFile(directory, basename, string(""));
  }

  cerr << "M3NModel::saveToFile cannot save untrained model" << endl;
  return -1;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
int M3NModel::saveToFile(std::string directory,
                         std::string basename,
                         std::string basename_suffix)
{
  // -------------------------------------------
  // Create filename: <basename>.m3n_model
  string out_filename = directory;
  out_filename.append("/");
  out_filename.append(basename);
  out_filename.append(basename_suffix);
  out_filename.append(".m3n_model");
  ofstream outfile(out_filename.c_str());
  if (outfile.is_open() == false)
  {
    cerr << "M3NModel::saveToFile could not save to: " << out_filename << endl;
    return -1;
  }

  // -------------------------------------------
  // File format:
  // L = <training_labels_.size()>
  // training_labels_[0] ... training_labels_[L-1]
  //
  // S = <robust_potts_params_.size()>
  // robust_potts_params_[0] ... robust_potts_params_[S-1]
  //
  // node_feature_dim_
  // clique_set_feature_dims_[0] ... clique_set_feature_dims_[S-1]

  // ------------------------
  // labels
  outfile << m_training_labels.size() << endl;
  for (unsigned int i = 0 ; i < m_training_labels.size() ; i++)
  {
    outfile << m_training_labels[i] << " ";
  }
  outfile << endl;

  outfile << endl;

  // ------------------------
  // robust potts params
  outfile << m_robust_potts_params.size() << endl;
  for (unsigned int i = 0 ; i < m_robust_potts_params.size() ; i++)
  {
    outfile << m_robust_potts_params[i] << " ";
  }
  outfile << endl;

  outfile << endl;

  // ------------------------
  // node & clique set feature dimensions
  outfile << m_node_feature_dim << endl;
  for (unsigned int i = 0 ; i < m_clique_set_feature_dims.size() ; i++)
  {
    outfile << m_clique_set_feature_dims[i] << " ";
  }
  outfile << endl;

  outfile << endl;

  int ret_val = doSaveToFile(outfile, directory, basename);
  outfile.close();
  return ret_val;
}
