#ifndef __M3N_LOGGING_H__
#define __M3N_LOGGING_H__
/*********************************************************************
 * Software License Agreement (Modified BSD License)
 *
 * Copyright (c) 2009-2010, Willow Garage, Daniel Munoz
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the copyright holders' organizations nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *********************************************************************/

#include <time.h>

#include <iostream>
#include <map>
#include <sstream>
#include <vector>

#include <m3n/random_field.h>
#include <m3n/util/m3n_stats.h>

// --------------------------------------------------------------
/**
 * \file m3n_logging.h
 *
 * \brief Functions for printing and logging statistics from the classifier
 */
// --------------------------------------------------------------
class M3NLogger
{
  public:
    // --------------------------------------------------------------
    /**
     * \brief Records the regressor training time of one iteration and
     *        prints it to standard output
     *
     * \param iteration_regressor_time The time value to append to
     *                                 m_timings_regressors (units are
     *                                 whatever the caller measured)
     */
    // --------------------------------------------------------------
    void addTimingRegressors(double iteration_regressor_time)
    {
      m_timings_regressors.push_back(iteration_regressor_time);
      std::cout << "Iteration REGRESSOR training time: " << m_timings_regressors.back()
          << std::endl;
    }

    // --------------------------------------------------------------
    /**
     * \brief Records the inference time of one iteration and prints it
     *        to standard output
     *
     * \param iteration_inference_time The time value to append to
     *                                 m_timings_inference (units are
     *                                 whatever the caller measured)
     */
    // --------------------------------------------------------------
    void addTimingInference(double iteration_inference_time)
    {
      m_timings_inference.push_back(iteration_inference_time);
      std::cout << "Iteration INFERENCE overall time: " << m_timings_inference.back() << std::endl;
    }

    // --------------------------------------------------------------
    /**
     * \brief Computes accuracy, recalls and precisions through M3NStats,
     *        stores the accuracy and prints all statistics to standard output
     *
     * The confusion matrix is built by M3NStats::computeConfusionMatrix and
     * then handed to M3NStats::computeAccuracy, computeRecalls and
     * computePrecisions.  The accuracy is appended to m_accuracies.
     *
     * \param nodes The nodes that were classified; nodes.size() is used as
     *              the total count in the printed summary
     * \param inferred_labels Forwarded to M3NStats::computeConfusionMatrix
     *                        (presumably node id -> inferred label; confirm
     *                        against M3NStats)
     * \param labels The labels to report; entry i of the recall/precision
     *               vectors is printed next to labels[i] (assumes M3NStats
     *               returns them in the same order -- TODO confirm)
     */
    // --------------------------------------------------------------
    void addErrorRate(const std::map<unsigned int, RandomField::Node*>& nodes,
                      const std::map<unsigned int, unsigned int>& inferred_labels,
                      const std::vector<unsigned int>& labels)
    {
      std::vector<std::vector<unsigned int> > confusion_matrix;
      M3NStats::computeConfusionMatrix(nodes, labels, inferred_labels, confusion_matrix);

      // Start from a defined value: the status of computeAccuracy is not
      // inspected here, so the out-parameter must never be read uninitialized
      double accuracy = 0.0;
      M3NStats::computeAccuracy(confusion_matrix, accuracy);
      m_accuracies.push_back(accuracy);

      std::vector<double> recalls;
      M3NStats::computeRecalls(confusion_matrix, recalls);
      std::vector<double> precisions;
      M3NStats::computePrecisions(confusion_matrix, precisions);

      // Print statistics.
      // Round to nearest instead of truncating: accuracy * size can land a
      // hair below the true integer count (e.g. 28.999999...) due to floating
      // point error.  Non-positive or NaN products are reported as 0 rather
      // than being converted to unsigned.
      const double scaled_correct = accuracy * static_cast<double> (nodes.size());
      unsigned int nbr_correct = 0;
      if (scaled_correct > 0.0)
      {
        nbr_correct = static_cast<unsigned int> (scaled_correct + 0.5);
      }
      std::cout << "Total correct: " << nbr_correct << " / " << nodes.size() << " = " << accuracy
          << std::endl;

      printPerLabelValues("Recalls: ", labels, recalls);
      printPerLabelValues("Precisions: ", labels, precisions);
    }

    /** \brief Values passed to addTimingRegressors, in call order */
    std::vector<double> m_timings_regressors;
    /** \brief Values passed to addTimingInference, in call order */
    std::vector<double> m_timings_inference;
    /** \brief Accuracy computed by each call to addErrorRate, in call order */
    std::vector<double> m_accuracies;

  private:
    // --------------------------------------------------------------
    /**
     * \brief Prints one line of the form "<title>[label: value]  ..."
     *
     * Only indices valid in both vectors are printed, so a values vector
     * shorter than labels cannot cause an out-of-bounds read.
     *
     * \param title Text printed at the start of the line
     * \param labels The label printed for each entry
     * \param values The value printed next to the label with the same index
     */
    // --------------------------------------------------------------
    static void printPerLabelValues(const char* title,
                                    const std::vector<unsigned int>& labels,
                                    const std::vector<double>& values)
    {
      std::cout << title;
      for (unsigned int i = 0 ; i < labels.size() && i < values.size() ; i++)
      {
        std::cout << "[" << labels[i] << ": " << values[i] << "]  ";
      }
      std::cout << std::endl;
    }
};

#endif
