/*********************************************************************
 * Software License Agreement (Modified BSD License)
 *
 * Copyright (c) 2009-2010, Willow Garage, Daniel Munoz
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the copyright holders' organizations nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *********************************************************************/

#include <m3n/regressors/ocv_rtree_wrapper.h>

#include <stdexcept>

using namespace std;

// --------------------------------------------------------------
/*!
 * \brief Validates and stores the regression tree parameters.
 *
 * The wrapper starts without a tree (m_rtree is NULL); a tree is only
 * created later by doTrain() or loadFromFile().
 *
 * \param rtree_params Training parameters; max_tree_depth must be non-zero
 *                     and regression_accuracy must lie in [0,1]
 *
 * \throws std::invalid_argument if rtree_params violates either constraint
 */
// --------------------------------------------------------------
OCVRTreeWrapper::OCVRTreeWrapper(const OCVRTreeWrapperParams& rtree_params) :
  RegressorWrapper(RegressorWrapper::OPENCV_RTREE)
{
  // No tree exists until doTrain() or loadFromFile() creates one
  m_rtree = NULL;

  // NOTE: a bare "throw;" with no exception in flight calls std::terminate(),
  // so invalid parameters are reported with a real, catchable exception.
  if (rtree_params.max_tree_depth == 0)
  {
    const char* const error_msg = "OCVRTreeWrapper Max depth factor must be non-zero";
    cerr << error_msg << endl;
    throw std::invalid_argument(error_msg);
  }

  // Written as a negated in-range test so that NaN is rejected as well
  if (!(rtree_params.regression_accuracy >= 0.0 && rtree_params.regression_accuracy <= 1.0))
  {
    const char* const error_msg = "OCVRTreeWrapper Regression accuracy must be in [0,1]";
    cerr << error_msg << endl;
    throw std::invalid_argument(error_msg);
  }

  m_rtree_params = rtree_params;
}

// --------------------------------------------------------------
/*!
 * \brief Creates a new, untrained wrapper that shares this instance's
 *        parameters.
 *
 * Only m_rtree_params is passed on; the tree held by this instance (if any)
 * is not copied.
 *
 * \return Pointer to a wrapper allocated with new
 */
// --------------------------------------------------------------
RegressorWrapper* OCVRTreeWrapper::paramClone() const
{
  OCVRTreeWrapper* const untrained_copy = new OCVRTreeWrapper(m_rtree_params);
  return untrained_copy;
}

// --------------------------------------------------------------
/*!
 * \brief Releases the regression tree, if one exists.
 *
 * The training parameters (m_rtree_params) are intentionally left untouched.
 */
// --------------------------------------------------------------
void OCVRTreeWrapper::doClear()
{
  // Deleting a NULL pointer is a no-op, so no explicit guard is necessary
  delete m_rtree;
  m_rtree = NULL;
}

// --------------------------------------------------------------
/*!
 * \brief Saves the trained regression tree to disk.
 *
 * Two files are written into \a directory:
 *   - <basename>_wrapper.rtree : text file with the wrapper's own fields
 *   - <basename>_opencv.rtree  : tree structure written by CvDTree::save
 *
 * \param directory Directory to write the files into
 * \param basename Common prefix of the two filenames
 *
 * \return 0 on success, otherwise negative value on error (untrained tree,
 *         wrapper file could not be opened or could not be fully written)
 */
// --------------------------------------------------------------
int OCVRTreeWrapper::saveToFile(const std::string directory,
                                const string basename)
{
  // -------------------------------------------
  // Verify regression tree is trained
  if (!m_trained || m_rtree == NULL)
  {
    cerr << "OCVRTreeWrapper::saveToFile Cannot save untrained regression tree" << endl;
    return -1;
  }

  // -------------------------------------------
  // Create filename: <basename>_wrapper.rtree
  const string out_filename = directory + "/" + basename + "_wrapper.rtree";
  ofstream outfile(out_filename.c_str());
  if (!outfile.is_open())
  {
    // Report the full path, the basename alone does not say where we tried to write
    cerr << "OCVRTreeWrapper::saveToFile Could not open file to save: " << out_filename << endl;
    return -1;
  }

  // -------------------------------------------
  // Create filename: <basename>_opencv.rtree
  // Save regressor tree structure to file
  const string opencv_short_filename = basename + "_opencv.rtree";
  const string opencv_full_filename = directory + "/" + opencv_short_filename;
  m_rtree->save(opencv_full_filename.c_str());

  // -------------------------------------------
  // File format:
  // m_algorithm_type
  // m_stacked_feature_dim
  //
  // m_rtree_params (max_tree_depth, min_sample_count, regression_accuracy,
  //                 nbr_xvalidation_folds; one per line)
  //
  // <opencv filename string length> (no directory)
  // <opencv filename>
  outfile << m_algorithm_type << '\n';
  outfile << m_stacked_feature_dim << '\n';
  outfile << '\n';
  outfile << m_rtree_params.max_tree_depth << '\n';
  outfile << m_rtree_params.min_sample_count << '\n';
  outfile << m_rtree_params.regression_accuracy << '\n';
  outfile << m_rtree_params.nbr_xvalidation_folds << '\n';
  outfile << '\n';
  outfile << opencv_short_filename.length() << '\n';
  outfile << opencv_short_filename << '\n';

  // close() flushes; any earlier write error or a failed flush leaves the
  // stream in a failed state, which must not be reported as success
  outfile.close();
  if (outfile.fail())
  {
    cerr << "OCVRTreeWrapper::saveToFile Could not write file: " << out_filename << endl;
    return -1;
  }
  return 0;
}

// --------------------------------------------------------------
/*! See function definition */
// --------------------------------------------------------------
int OCVRTreeWrapper::loadFromFile(const std::string directory,
                                  const string basename)
{
  clear();

  // -------------------------------------------
  // Create filename: <basename>_wrapper.rtree
  string in_filename = directory;
  in_filename.append("/");
  in_filename.append(basename);
  in_filename.append("_wrapper.rtree");
  ifstream infile(in_filename.c_str());
  if (infile.is_open() == false)
  {
    cerr << "OCVRTreeWrapper::loadFromFile Could not load file: " << basename << endl;
    return -1;
  }

  // -------------------------------------------
  // File format:
  // algorithm_type_
  // stacked_feature_dim_
  //
  // rtree_params_
  //
  // <opencv filename string length> (no directory)
  // <opencv filename>

  // -------------------------------------------
  // Verify loading a regression tree with expected stacked feature dimension
  int tempo_algo_type;
  infile >> tempo_algo_type;
  if (tempo_algo_type != OPENCV_RTREE)
  {
    cerr << "OCVRTreeWrapper::loadFromFile Algorithm mismatch: " << tempo_algo_type << endl;
    return -1;
  }
  infile >> m_stacked_feature_dim;

  // Fields
  infile >> m_rtree_params.max_tree_depth;
  infile >> m_rtree_params.min_sample_count;
  infile >> m_rtree_params.regression_accuracy;
  infile >> m_rtree_params.nbr_xvalidation_folds;

  unsigned int opencv_short_filename_strlength = 0;
  infile >> opencv_short_filename_strlength;
  char opencv_short_filename[opencv_short_filename_strlength];
  infile >> opencv_short_filename;
  infile.close();

  // prepend filename with the directory path
  string opencv_full_filename = directory;
  opencv_full_filename.append("/");
  opencv_full_filename.append(opencv_short_filename);

  m_rtree = new CvDTree;
  m_rtree->load(opencv_full_filename.c_str());
  m_trained = true;
  return 0;
}

// --------------------------------------------------------------
/*! See function definition */
// --------------------------------------------------------------
int OCVRTreeWrapper::doTrain(const vector<const vector<double>*>& interm_feature_vals,
                             const vector<unsigned int>& interm_start_idx,
                             const vector<double>& interm_target)
{
  unsigned int nbr_samples = interm_feature_vals.size();

  // Use only a subset if matrix will exceed max allocation size.
  if (m_rtree_params.max_allocation > 0.0)
  {
    double nbr_bytes = m_rtree_params.max_allocation * 1048576.0; // 1 mb = 1048576 bytes
    double nbr_matrix_entries = nbr_bytes / static_cast<double> (sizeof(double));
    unsigned int nbr_rows = static_cast<unsigned int> (nbr_matrix_entries
        / static_cast<double> (m_stacked_feature_dim));
    nbr_samples = min(nbr_samples, nbr_rows);
  }

  // -------------------------------------------
  // Check number of training samples is sensible
  if (nbr_samples < m_rtree_params.min_sample_count)
  {
    cout << "OCVRTreeWrapper::doTrain Have less training samples " << nbr_samples
        << "than needed to perform split " << m_rtree_params.min_sample_count << endl;
    return -1;
  }

  // -------------------------------------------
  // Create OpenCV data structures to train regression tree

  // ------------------------------
  // Vector with target values to regress to
  CvMat* target_vals = cvCreateMat(nbr_samples, 1, CV_32F);

  // ------------------------------
  // Create nbr_samples-by-stacked_feature_dim_ matrix to hold feature values in each row
  CvMat* train_data = cvCreateMat(nbr_samples, m_stacked_feature_dim, CV_32F);
  if (train_data == NULL)
  {
    cerr << "OCVRTreeWrapper::doTrain not enough memory to allocate CvMat" << endl;
    abort();
  }

  // ------------------------------
  // Populate target_vals and sparse_feature_matrix for each sample
  for (unsigned int i = 0 ; i < nbr_samples ; i++)
  {
    // Define the target value for current sample
    cvmSet(target_vals, i, 0, interm_target[i]);

    unsigned int curr_start_idx = interm_start_idx[i];
    const vector<double>* curr_feats = interm_feature_vals[i];
    unsigned int curr_feat_length = curr_feats->size();

    for (unsigned int j = 0 ; j < m_stacked_feature_dim ; j++)
    {
      if (j < curr_start_idx || j >= (curr_start_idx + curr_feat_length))
      {
        cvmSet(train_data, i, j, 0.0);
      }
      else
      {
        cvmSet(train_data, i, j, curr_feats->at(j - curr_start_idx));
      }
    }
  }

  // ------------------------------
  // Indicate using numbers and not categories
  CvMat* var_type = cvCreateMat(m_stacked_feature_dim + 1, 1, CV_8U);
  cvSet(var_type, cvScalarAll(CV_VAR_ORDERED));

  // ------------------------------
  // Instantiate tree
  m_rtree = new CvDTree;
  bool train_success = m_rtree->train(train_data, // training data
                                      CV_ROW_SAMPLE, // how to read train_data
                                      target_vals, // target values
                                      NULL, // var_idx
                                      NULL, // sample_idx
                                      var_type, //var_type
                                      NULL, // missing mask
                                      CvDTreeParams(m_rtree_params.max_tree_depth,
                                                    m_rtree_params.min_sample_count, // min sample count
                                                    m_rtree_params.regression_accuracy, // regression accuracy
                                                    false, // do NOT compute surrogate split (no missing data)
                                                    10, // Not used b/c doing regression (default max_categories val)
                                                    m_rtree_params.nbr_xvalidation_folds, // the number of cross-validation folds
                                                    true, // use 1SE rule => smaller tree
                                                    true, // throw away the pruned tree branches
                                                    NULL // priors
                                       ));

  cvReleaseMat(&var_type);
  cvReleaseMat(&target_vals);
  cvReleaseMat(&train_data);

  if (train_success)
  {
    m_trained = true;
    return 0;
  }
  else
  {
    return -1;
  }
}

// --------------------------------------------------------------
/*! See function definition */
// --------------------------------------------------------------
int OCVRTreeWrapper::doPredict(const vector<double>& feature_vals,
                               const unsigned int start_idx,
                               double& predicted_val) const
{
  unsigned int length = feature_vals.size();

  // Create feature vec
  CvMat* cv_feature_vec = cvCreateMat(m_stacked_feature_dim, 1, CV_32F);
  for (unsigned int i = 0 ; i < m_stacked_feature_dim ; i++)
  {
    if (i < start_idx || i >= (start_idx + length))
    {
      cvmSet(cv_feature_vec, i, 0, 0.0);
    }
    else
    {
      cvmSet(cv_feature_vec, i, 0, feature_vals.at(i - start_idx));
    }
  }

  // Predict with regression tree
  predicted_val = static_cast<double> (m_rtree->predict(cv_feature_vec, NULL)->value);

  cvReleaseMat(&cv_feature_vec);

  return 0;
}
