/* This class describes an MDP, with all state transitions */

#include <cmath>
#include <iostream>
#include <set>
#include "MDP.hpp"
#include "QTable.hpp"
#include "Logger.hpp"

using namespace std;
using namespace spades;

// Builds an empty MDP. Both dimensions start at -1, the "invalid" marker that
// operator<< tests for (num_states < 0 || num_actions < 0) before printing.
// Leaving num_actions uninitialised would make that test, and verify(), read
// an indeterminate value on a default-constructed object.
// NOTE(review): assumes num_actions is declared after num_states in MDP.hpp;
// swap the two initialisers if the declaration order differs.
MDP::MDP()
  : num_states(-1),
    num_actions(-1)
{
}

// The destructor body is intentionally empty: nothing is released by hand here.
MDP::~MDP() {}

/***************************************************************************************/

bool
MDP::verify(bool print_error)
{
  unsigned size = num_states * num_actions;
  if (v_state_action_info.size() != size)
    {
      if (print_error)
	errorlog << "Bad state action size: "
		 << num_states << ' ' << num_actions << ' ' << v_state_action_info.size()
		 << ende;
      return false;
    }

  //now check that the total probability for any transitions is about 1.0
  for (StateActionInfoList::const_iterator iter = v_state_action_info.begin();
       iter != v_state_action_info.end();
       iter++)
    {
      float total_prob = 0.0;

      //we allow a state to have no transitions in order reflect end of episode positions
      if (iter->empty())
	continue;
      
      for (MDP::StateActionInfo::const_iterator iternextstate = iter->begin();
	   iternextstate != iter->end();
	   iternextstate++)
	{
	  total_prob += iternextstate->prob;
	}

      if (fabs(total_prob - 1.0) > .00001)
	{
	  if (print_error)
	    errorlog << "Probability not equal to 1.0: "
		     << (iter - v_state_action_info.begin()) << ' '
		     << total_prob
		     << ende;
	  return false;
	}
      
    }
  
  return true;
}

/***************************************************************************************/
//any transitions to the same state are combined
//the probabilities are added and the rewards are averaged
void
MDP::collapseTransitions()
{
  typedef multiset<TranInfo, lt_TranInfo_state> combine_set_t;
  for (int state = 0; state < num_states; state++)
    {
      actionlog(100) << "MDP::collapseTransitions: working on " << state << ende;
      combine_set_t combine_set;
      for (StateActionInfo::iterator iter = v_state_action_info[state].begin();
	   iter != v_state_action_info[state].end();
	   iter++)
	combine_set.insert(*iter);

      v_state_action_info[state].clear();
      
      TranInfo current_tran;
      int current_count = 0; //number of transitions in current_tran
      for (combine_set_t::iterator iter = combine_set.begin();
	   iter != combine_set.end();
	   iter++)
	{
	  if (current_count == 0)
	    {
	      current_tran = *iter;
	      current_count = 1;
	    }
	  else if (current_tran.nextstate == iter->nextstate)
	    {
	      //combine this transition, we add the rewards, and will average it out later
	      current_tran.prob += iter->prob;
	      current_tran.reward += iter->reward;
	      current_count++;
	    }
	  else
	    {
	      current_tran.reward /= current_count;
	      v_state_action_info[state].push_back(current_tran);
	      current_tran = *iter;
	      current_count = 1;
	    }
	}
      if (current_count)
	{
	  current_tran.reward /= current_count;
	  v_state_action_info[state].push_back(current_tran);
	}
    }

  actionlog(50) << "MDP::collapseTransitions: done!" << ende;
}

/***************************************************************************************/
//Repeatedly applies QTable::mdpDPUpdate to this MDP until a sweep reports no
//updates. The QTable is handed in by the caller so that it can carry settings
//such as disabled actions. Per the original author's note, this is a simple DP
//method that updates all states in order and in place (the update itself
//lives in QTable::mdpDPUpdate, which is not part of this file).
//  qt:           table to update; its state count must equal num_states
//  print_status: when true, a progress line is written to cout, one '.' per sweep
//Returns the number of sweeps performed, or 0 (after logging an error) when the
//state counts do not match.
int
MDP::solveByQTable(QTable& qt, bool print_status)
{
  if (qt.getNumStates() != num_states)
    {
      errorlog << "Can't solve by QTable without size match! "
               << qt.getNumStates() << " == " << num_states << ende;
      return 0;
    }

  if (print_status)
    {
      cout << "MDP solving by QTable: " << flush;
    }

  int sweeps_done = 0;
  bool keep_going = true;
  while (keep_going)
    {
      const int updates_this_sweep = qt.mdpDPUpdate(*this);

      actionlog(100) << "MDP::solveByQTable: iteration " << sweeps_done << " has "
                     << updates_this_sweep << " updates" << ende;

      actionlog(240) << "MDPsolve: iteration " << sweeps_done
                     << ende;

      ++sweeps_done;
      if (print_status)
        {
          cout << "." << flush;
        }

      //stop as soon as a sweep changes nothing
      keep_going = (updates_this_sweep > 0);
    }

  if (print_status)
    {
      cout << endl;
    }

  return sweeps_done;
}


/***************************************************************************************/

// Maps a (state, action) pair to its position in v_state_action_info,
// computed as state * num_actions + action.
// An out-of-range request is reported on errorlog, but the computed index is
// still returned; the caller receives no other failure signal.
unsigned int
MDP::getIdx(int state, int action) const
{
  const unsigned int idx = state * num_actions + action;

  // Valid positions are 0 .. size()-1, so idx == size() is already out of
  // range. Negative inputs are flagged explicitly because the conversion to
  // unsigned could otherwise wrap them onto a seemingly valid position.
  if (state < 0 || action < 0 || idx >= v_state_action_info.size())
    errorlog << "MDP: idx out of range: "
             << idx << ' ' << v_state_action_info.size() << ' '
             << num_states << ' ' << num_actions << ' '
             << ende;

  return idx;
}

/***************************************************************************************/

// Writes the MDP in the text format that operator>> reads back:
//   line 1:      num_states num_actions
//   then, for each entry of the transition table, one line holding
//                the number of transitions followed by a
//                "prob nextstate reward" triple per transition
// (all fields separated by single spaces, each triple followed by a space).
// If either dimension is negative an error is logged, nothing is written and
// the stream is returned as-is.
// Lines end with '\n' and the stream is flushed once at the end, instead of
// being flushed after every line by endl.
// NOTE(review): numbers are written with the stream's current formatting; if
// prob is a floating-point type, confirm the precision is sufficient for a
// re-read MDP to still pass verify()'s 1e-5 tolerance.
std::ostream&
operator<<(std::ostream& os, const MDP& m)
{
  if (m.num_states < 0 || m.num_actions < 0)
    {
      errorlog << "Asked to print an invalid MDP" << ende;
      return os;
    }

  os << m.num_states << ' ' << m.num_actions << '\n';

  for (MDP::StateActionInfoList::const_iterator iter = m.v_state_action_info.begin();
       iter != m.v_state_action_info.end();
       iter++)
    {
      os << iter->size() << ' ';

      for (MDP::StateActionInfo::const_iterator iternextstate = iter->begin();
           iternextstate != iter->end();
           iternextstate++)
        {
          os << iternextstate->prob << ' '
             << iternextstate->nextstate << ' '
             << iternextstate->reward << ' ';
        }

      os << '\n';
    }

  // One flush on the way out keeps the "everything is flushed when we return"
  // property of the per-line endl version.
  os << std::flush;

  return os;
}

std::istream&
operator>>(std::istream& is, MDP& m)
{
  is >> m.num_states >> m.num_actions;

  if (!is)
    return is;

  int size = m.num_states * m.num_actions;
  m.v_state_action_info.resize(size);

  for (int i=0; i < size; i++)
    {
      MDP::StateActionInfo info;
      int tran_size;

      is >> tran_size;
      
      if (!is)
	return is;

      for (int j=0; j < tran_size; j++)
	{
	  MDP::TranInfo tran_info;

	  is >> tran_info.prob >> tran_info.nextstate >> tran_info.reward;
	  
	  if (!is)
	    return is;

	  info.push_back(tran_info);
	}

      m.v_state_action_info[i] = info;
    }

  return is;
}

