/*********************************************************************
 * Software License Agreement (Modified BSD License)
 *
 * Copyright (c) 2009-2010, Willow Garage, Daniel Munoz
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of the copyright holders' organizations nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *********************************************************************/

#include <m3n/random_field.h>

using namespace std;

// Sentinel label for a node whose label is not known; createNode() marks the
// field as not fully labeled whenever a node carries this value.
const unsigned int RandomField::UNKNOWN_LABEL(0);

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
RandomField::RandomField()
{
  // Start with no nodes and no clique sets; clear() establishes the initial
  // bookkeeping values (node dimension 0, fully-labeled flag set).
  this->clear();
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
RandomField::RandomField(unsigned int nbr_clique_sets) :
  m_clique_sets(nbr_clique_sets)
{
  // clear() keeps the number of clique sets established above, empties each
  // of them and sizes m_clique_set_dims to match (all dimensions 0).
  clear();
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
RandomField::~RandomField()
{
  // Frees every Node and Clique allocated by createNode()/createClique().
  this->clear();
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
void RandomField::clear()
{
  // Free nodes (using random field id)
  for (map<unsigned int, Node*>::iterator iter_rf_nodes = m_nodes.begin() ; iter_rf_nodes
      != m_nodes.end() ; iter_rf_nodes++)
  {
    delete iter_rf_nodes->second;
  }

  // Free cliques in each clique set
  unsigned int nbr_cs = m_clique_sets.size();
  for (unsigned int i = 0 ; i < nbr_cs ; i++)
  {
    for (map<unsigned int, Clique*>::iterator iter_cliques = m_clique_sets[i].begin() ; iter_cliques
        != m_clique_sets[i].end() ; iter_cliques++)
    {
      delete iter_cliques->second;
    }
  }

  // Empty data structures
  m_nodes.clear();
  m_node_dim = 0;
  m_clique_sets.clear();
  m_clique_sets.resize(nbr_cs);
  m_clique_set_dims.assign(nbr_cs, 0);
  m_labels.clear();
  m_fully_labeled = true;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
int RandomField::updateLabelings(const map<unsigned int, unsigned int>& new_labeling)
{
  map<unsigned int, unsigned int>::const_iterator iter_new_labeling;

  // Ensure the number of nodes equals the new labeling mapping
  if (m_nodes.size() != new_labeling.size())
  {
    cerr << "Inconsistent number of nodes (" << m_nodes.size() << ") with mapping size ("
        << new_labeling.size() << ")" << endl;
    return -1;
  }

  // Ensure the keys (node id) in new_labeling exist in this RandomField
  for (iter_new_labeling = new_labeling.begin(); iter_new_labeling != new_labeling.end() ; iter_new_labeling++)
  {
    unsigned int node_id = iter_new_labeling->first;
    if (m_nodes.count(node_id) == 0)
    {
      cerr << "RandomField::updateLabelings Unknown node id from the map: " << node_id << endl;
      return -1;
    }
  }

  // -------------------------------------------------------
  // Update node labels
  m_labels.clear();
  for (iter_new_labeling = new_labeling.begin(); iter_new_labeling != new_labeling.end() ; iter_new_labeling++)
  {
    unsigned int new_label = iter_new_labeling->second;
    (m_nodes[iter_new_labeling->first])->setLabel(new_label);
    m_labels.insert(new_label);
  }

  // -------------------------------------------------------
  // Update label information in each clique in the clique sets
  map<unsigned int, Clique*>::iterator iter_cliques;
  Clique* curr_clique = NULL;
  for (unsigned int i = 0 ; i < m_clique_sets.size() ; i++)
  {
    map<unsigned int, Clique*>& curr_cs = m_clique_sets[i];
    for (iter_cliques = curr_cs.begin(); iter_cliques != curr_cs.end() ; iter_cliques++)
    {
      curr_clique = iter_cliques->second;
      curr_clique->updateLabels(new_labeling);
    }
  }

  return 0;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
const RandomField::Node* RandomField::createNode(const std::vector<double>& feature_vals,
                                                 unsigned int label,
                                                 double x,
                                                 double y,
                                                 double z)
{
  // Choose the smallest unused id that is >= the current node count, then let
  // the explicit-id overload perform validation and construction.
  unsigned int candidate_id = m_nodes.size();
  while (m_nodes.find(candidate_id) != m_nodes.end())
  {
    ++candidate_id;
  }
  return createNode(candidate_id, feature_vals, label, x, y, z);
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
const RandomField::Node* RandomField::createNode(const unsigned int node_id,
                                                 const std::vector<double>& feature_vals,
                                                 unsigned int label,
                                                 double x,
                                                 double y,
                                                 double z)
{
  // Creates a node with the given id, features, label and position.
  // Returns the new node, or NULL when the id is taken or the feature
  // dimension is invalid.

  // Reject node ids that are already in use
  if (m_nodes.find(node_id) != m_nodes.end())
  {
    cerr << "RandomField::createNode Node id already exists: " << node_id << endl;
    return NULL;
  }

  // The first node fixes the (non-zero) node feature dimension; every later
  // node must match it.
  const size_t feature_dim = feature_vals.size();
  const bool is_first_node = m_nodes.empty();
  if (is_first_node && feature_dim == 0)
  {
    cerr << "RandomField::createNode 0 length features" << endl;
    return NULL;
  }
  if (!is_first_node && feature_dim != m_node_dim)
  {
    cerr << "RandomField::createNode Mismatch feature dim" << endl;
    return NULL;
  }
  if (is_first_node)
  {
    m_node_dim = feature_dim;
  }

  // Build the node and register it under its id
  RandomField::Node* node = new RandomField::Node(node_id, label);
  node->setFeatures(feature_vals);
  node->setXYZ(x, y, z);
  m_nodes[node_id] = node;

  // Record the label (the set keeps unique values only); an unknown label
  // means the field is no longer fully labeled.
  m_labels.insert(label);
  if (label == RandomField::UNKNOWN_LABEL)
  {
    m_fully_labeled = false;
  }
  return node;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
const RandomField::Clique* RandomField::createClique(const unsigned int clique_set_idx,
                                                     const list<const RandomField::Node*>& nodes,
                                                     const std::vector<double>& feature_vals,
                                                     double x,
                                                     double y,
                                                     double z)
{
  // The index must be validated before it is used to pick the clique set
  // in which a free id is searched for.
  if (clique_set_idx >= m_clique_sets.size())
  {
    cerr << "RandomField::createClique clique_set_idx " << clique_set_idx << " exceeds boundary "
        << m_clique_sets.size() << endl;
    return NULL;
  }

  // Smallest unused id >= the current size of that clique set; the
  // explicit-id overload performs the remaining validation and construction.
  const map<unsigned int, RandomField::Clique*>& target_set = m_clique_sets[clique_set_idx];
  unsigned int candidate_id = target_set.size();
  while (target_set.find(candidate_id) != target_set.end())
  {
    ++candidate_id;
  }
  return createClique(candidate_id, clique_set_idx, nodes, feature_vals, x, y, z);
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
const RandomField::Clique* RandomField::createClique(const unsigned int clique_id,
                                                     const unsigned int clique_set_idx,
                                                     const list<const RandomField::Node*>& nodes,
                                                     const std::vector<double>& feature_vals,
                                                     double x,
                                                     double y,
                                                     double z)
{
  // Creates a clique with the given id in clique set clique_set_idx that
  // contains the given nodes. The nodes must be (non-NULL) objects owned by
  // this RandomField.
  // Returns the new clique, or NULL if any validation fails (nothing is
  // modified in that case).
  // NOTE(review): an empty node list is currently accepted; confirm whether
  // empty cliques are meaningful for callers.

  // verify the clique set index is within bounds
  if (clique_set_idx >= m_clique_sets.size())
  {
    cerr << "RandomField::createClique cs index " << clique_set_idx << " OOB" << endl;
    return NULL;
  }
  map<unsigned int, RandomField::Clique*>& clique_set = m_clique_sets[clique_set_idx];

  // verify clique id doesn't already exist
  if (clique_set.find(clique_id) != clique_set.end())
  {
    cerr << "RandomField::createClique clique id " << clique_id << " already exists in cs "
        << clique_set_idx << endl;
    return NULL;
  }

  // verify the nodes are contained in this RandomField: each pointer must be
  // non-NULL and be the very object stored under its id
  list<const RandomField::Node*>::const_iterator iter_nodes;
  for (iter_nodes = nodes.begin(); iter_nodes != nodes.end() ; ++iter_nodes)
  {
    const RandomField::Node* candidate_node = *iter_nodes;
    if (candidate_node == NULL)
    {
      // e.g. a failed createNode() result forwarded by the caller
      cerr << "RandomField::createClique NULL node in node list" << endl;
      return NULL;
    }
    const unsigned int curr_node_id = candidate_node->getID();
    map<unsigned int, RandomField::Node*>::const_iterator iter_rf_node = m_nodes.find(curr_node_id);
    if (iter_rf_node == m_nodes.end())
    {
      cerr << "RandomField::createClique node id " << curr_node_id << " not in RF" << endl;
      return NULL;
    }
    if (candidate_node != iter_rf_node->second)
    {
      cerr << "RandomField::createClique node " << curr_node_id << " object not in RF" << endl;
      return NULL;
    }
  }

  // verify feature dimensions are consistent with the other cliques in this
  // clique set (the first clique fixes the set's non-zero dimension)
  size_t feature_dim = feature_vals.size();
  if (clique_set.empty())
  {
    if (feature_dim == 0)
    {
      cerr << "RandomField::createClique() 0 length features" << endl;
      return NULL;
    }
    m_clique_set_dims[clique_set_idx] = feature_dim;
  }
  else if (feature_dim != m_clique_set_dims[clique_set_idx])
  {
    cerr << "RandomField::createClique() mismatch feature dim" << endl;
    return NULL;
  }

  // instantiate clique
  RandomField::Clique* new_clique = new RandomField::Clique(clique_id);
  for (iter_nodes = nodes.begin(); iter_nodes != nodes.end() ; ++iter_nodes)
  {
    new_clique->addNode(**iter_nodes);
  }
  new_clique->setFeatures(feature_vals);
  new_clique->setXYZ(x, y, z);

  // add clique to map container
  clique_set[clique_id] = new_clique;
  return new_clique;
}

// -----------------------------------------------------------------------------------------------------------
// RandomField::GenericClique, RandomField::Node, RandomField::Clique definitions below
// -----------------------------------------------------------------------------------------------------------
// Out-of-line destructor definition; GenericClique has nothing to release here.
RandomField::GenericClique::~GenericClique() {}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
void RandomField::Clique::addNode(const Node& new_node)
{
  // Add node id to list
  m_node_ids.push_back(new_node.getID());

  // Add label->node_id to mapping
  unsigned int new_label = new_node.getLabel();
  if (m_labels_to_node_ids.count(new_label) == 0)
  {
    m_labels_to_node_ids[new_label] = list<unsigned int> ();
  }
  m_labels_to_node_ids[new_label].push_back(new_node.getID());
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
int RandomField::Clique::updateLabels(const map<unsigned int, unsigned int>& node_labels)
{
  list<unsigned int>::iterator iter_node_ids;
  unsigned int curr_node_id = 0;

  // Verify labeling contains each node id contained in this Clique
  for (iter_node_ids = m_node_ids.begin(); iter_node_ids != m_node_ids.end() ; iter_node_ids++)
  {
    curr_node_id = *iter_node_ids;
    if (node_labels.count(curr_node_id) == 0)
    {
      cerr << "Clique::updateLabels Mismatch node ids" << endl;
      return -1;
    }
  }

  // Update label --> node_id mapping (m_labels_to_node_ids)
  m_labels_to_node_ids.clear();
  unsigned int curr_label = 0;
  for (iter_node_ids = m_node_ids.begin(); iter_node_ids != m_node_ids.end() ; iter_node_ids++)
  {
    curr_node_id = *iter_node_ids;
    curr_label = node_labels.find(curr_node_id)->second; // labeling[curr_node_id]
    if (m_labels_to_node_ids.count(curr_label) == 0)
    {
      m_labels_to_node_ids[curr_label] = list<unsigned int> ();
    }
    m_labels_to_node_ids[curr_label].push_back(curr_node_id);
  }
  return 0;
}

// --------------------------------------------------------------
/* See function definition */
// --------------------------------------------------------------
int RandomField::Clique::getModeLabels(unsigned int& mode1_label,
                                       unsigned int& mode1_count,
                                       unsigned int& mode2_label,
                                       unsigned int& mode2_count,
                                       list<unsigned int>* mode1_node_ids,
                                       const map<unsigned int, unsigned int>* tempo_labeling) const
{
  mode1_label = RandomField::UNKNOWN_LABEL;
  mode1_count = RandomField::UNKNOWN_LABEL;
  mode2_label = RandomField::UNKNOWN_LABEL;
  mode2_count = RandomField::UNKNOWN_LABEL;

  map<unsigned int, list<unsigned int> > tempo_labels_to_node_ids;

  // -------------------------------------------------------
  // Determine a mapping from label->[node_ids].
  // Recalculate if using temporary label information,
  // otherwise use internal information
  const map<unsigned int, list<unsigned int> >* labels_to_node_ids = NULL;
  if (tempo_labeling != NULL)
  {
    // populate temporary mapping: temporary_label --> [node ids]
    unsigned int curr_node_id = 0;
    unsigned int curr_tempo_node_label = 0;
    for (list<unsigned int>::const_iterator iter_node_ids = m_node_ids.begin() ; iter_node_ids
        != m_node_ids.end() ; iter_node_ids++)
    {
      curr_node_id = *iter_node_ids;

      // ---------
      // Get the node's temporary label
      if (tempo_labeling->count(curr_node_id) == 0)
      {
        cerr << "Clique::getModeLabels Couldnt find id in temp labeling: " << curr_node_id << endl;
        return -1;
      }
      curr_tempo_node_label = tempo_labeling->find(curr_node_id)->second;

      // ---------
      // Add node temporary label's list
      if (tempo_labels_to_node_ids.count(curr_tempo_node_label) == 0)
      {
        tempo_labels_to_node_ids[curr_tempo_node_label] = list<unsigned int> ();
      }
      tempo_labels_to_node_ids[curr_tempo_node_label].push_back(curr_node_id);
    }

    labels_to_node_ids = &tempo_labels_to_node_ids;
  }
  else
  {
    labels_to_node_ids = &m_labels_to_node_ids;
  }

  // -------------------------------------------------------
  // Iterate over each label, compare each's number of associated nodes and update modes appropriately
  unsigned int curr_count = 0;
  for (map<unsigned int, list<unsigned int> >::const_iterator iter = labels_to_node_ids->begin() ; iter
      != labels_to_node_ids->end() ; iter++)
  {
    const list<unsigned int>& curr_node_list = iter->second;
    curr_count = curr_node_list.size();

    // Update mode 1 if necessary
    if (curr_count > mode1_count)
    {
      // shift mode 1 to second place
      mode2_label = mode1_label;
      mode2_count = mode1_count;

      // update mode 1
      mode1_label = iter->first; // label
      mode1_count = curr_count;
    }
    // Update mode 2 if necessary
    else if (curr_count > mode2_count)
    {
      mode2_label = iter->first; // label
      mode2_count = curr_count;
    }
  }

  // -------------------------------------------------------
  // Save node ids that are labeled mode1_label if indicated
  if (mode1_node_ids != NULL)
  {
    const list<unsigned int>& mode_node_ids = labels_to_node_ids->find(mode1_label)->second;
    mode1_node_ids->assign(mode_node_ids.begin(), mode_node_ids.end());
  }

  return 0;
}
