#ifndef __OCV_RTREE_WRAPPER_H__
#define __OCV_RTREE_WRAPPER_H__
/*********************************************************************
 * Software License Agreement (Modified BSD License)
 *
 * Copyright (c) 2009-2010, Willow Garage, Daniel Munoz
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the copyright holders' organizations nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *********************************************************************/

#include <string>
#include <vector>
#include <iostream>
#include <fstream>

#include <opencv/ml.h>
#include <opencv/cv.h>
#include <opencv/cxcore.h>

#include <m3n/regressors/regressor_wrapper.h>

// --------------------------------------------------------------
/*!
 * \brief Container for the parameters for OpenCV regression tree
 */
// --------------------------------------------------------------
class OCVRTreeWrapperParams
{
  public:
    // --------------------------------------------------------------
    /**
     * \brief Build a parameter set populated with the default values
     *
     * The regression-tree fields mirror CvDTreeParams in OpenCV: \n
     * http://opencv.willowgarage.com/wiki/MachineLearning#DecisionTrees \n
     *
     * Defaults: max_tree_depth = 7, min_sample_count = 10,
     * regression_accuracy = 0.001, nbr_xvalidation_folds = 10,
     * max_allocation = -1.0 (no limit).
     */
    // --------------------------------------------------------------
    OCVRTreeWrapperParams()
    {
      max_tree_depth = 7;
      min_sample_count = 10;
      regression_accuracy = 0.001;
      nbr_xvalidation_folds = 10;
      max_allocation = -1.0;
    }

    //! Maximum tree depth (see CvDTreeParams)
    unsigned int max_tree_depth;

    //! Minimum sample count for a node (see CvDTreeParams)
    unsigned int min_sample_count;

    //! Regression accuracy termination criterion (see CvDTreeParams)
    double regression_accuracy;

    //! Number of cross-validation folds (see CvDTreeParams)
    unsigned int nbr_xvalidation_folds;

    //! Max allocation size, in MB, of the matrix used to train the
    //! regressor; a negative value means no limit
    double max_allocation;
};

// --------------------------------------------------------------
/*!
 * \brief Wrapper around the OpenCV regression tree implementation
 */
// --------------------------------------------------------------
class OCVRTreeWrapper: public RegressorWrapper
{
  public:
    // --------------------------------------------------------------
    /**
     * \brief Create an untrained wrapper that uses the default parameters
     *
     * See OCVRTreeWrapperParams for default values. The underlying
     * OpenCV tree pointer starts out NULL.
     *
     * \warning This OCVRTreeWrapper is forever bound to these parameters
     */
    // --------------------------------------------------------------
    OCVRTreeWrapper() :
      RegressorWrapper(RegressorWrapper::OPENCV_RTREE), m_rtree(NULL)
    {
    }

    // --------------------------------------------------------------
    /**
     * \brief Create an untrained wrapper bound to the specified
     *        regression tree parameters
     *
     * \param rtree_params Regression tree parameters
     *
     * \warning This regression tree is forever bound to these parameters
     */
    // --------------------------------------------------------------
    OCVRTreeWrapper(const OCVRTreeWrapperParams& rtree_params);

    // --------------------------------------------------------------
    /**
     * \brief Release the OpenCV tree owned by this wrapper
     *
     * m_rtree is private to this class. A base-class destructor cannot
     * virtually dispatch to this class's doClear(), so without this
     * destructor the owned tree would leak. The destructor delegates to
     * doClear() so the release logic lives in one place.
     *
     * NOTE(review): assumes doClear() deletes m_rtree (and resets it to
     * NULL) -- confirm against the implementation file.
     */
    // --------------------------------------------------------------
    virtual ~OCVRTreeWrapper()
    {
      doClear();
    }

    // --------------------------------------------------------------
    /**
     * \brief See RegressorWrapper::paramClone()
     */
    // --------------------------------------------------------------
    virtual RegressorWrapper* paramClone() const;

    // --------------------------------------------------------------
    /**
     * \brief See RegressorWrapper::saveToFile()
     */
    // --------------------------------------------------------------
    virtual int saveToFile(const std::string directory,
                           const std::string basename);

    // --------------------------------------------------------------
    /**
     * \brief See RegressorWrapper::loadFromFile()
     */
    // --------------------------------------------------------------
    virtual int loadFromFile(const std::string directory,
                             const std::string basename);

  protected:
    // --------------------------------------------------------------
    /**
     * \brief See RegressorWrapper::doClear()
     */
    // --------------------------------------------------------------
    virtual void doClear();

    // --------------------------------------------------------------
    /**
     * \brief See RegressorWrapper::doTrain()
     */
    // --------------------------------------------------------------
    virtual int doTrain(const std::vector<const std::vector<double>*>& interm_feature_vals,
                        const std::vector<unsigned int>& interm_start_idx,
                        const std::vector<double>& interm_target);

    // --------------------------------------------------------------
    /**
     * \brief See RegressorWrapper::doPredict()
     */
    // --------------------------------------------------------------
    virtual int doPredict(const std::vector<double>& feature_vals,
                          const unsigned int start_idx,
                          double& predicted_val) const;

  private:
    // Copying is disabled: a member-wise copy would share the owning
    // m_rtree pointer, and both copies would release it on destruction.
    // Declared but intentionally never defined (C++03 non-copyable idiom).
    OCVRTreeWrapper(const OCVRTreeWrapper&);
    OCVRTreeWrapper& operator=(const OCVRTreeWrapper&);

    //! Parameters this wrapper was constructed with
    OCVRTreeWrapperParams m_rtree_params;

    //! OpenCV tree owned by this wrapper; NULL when no tree is held
    CvDTree* m_rtree;
};
#endif
