/* This class represents, for each state, a subset of actions (typically the
   disabled ones). It is closely tied to the QTable notion of disabled actions. */

#include <algorithm>
#include <fstream>
#include <ios>
#include <istream>
#include <iterator>
#include <ostream>
#include <set>
#include "ActionSubsetMap.hpp"
#include "QTable.hpp"
#include "Policy.hpp"
#include "random.hpp"
#include "Logger.hpp"


using namespace spades;
using namespace std;

/***********************************************************************************/

// Builds a map for num_states states in which every subset starts out empty.
ActionSubsetMap::ActionSubsetMap(int num_states)
  : num_states(num_states)
{
  // resize() allocates one empty ActionSet per state, or logs an error and
  // allocates nothing if num_states is negative.
  this->resize();
}

// Loads a map from the text file fn (the format written by operator<<).
// A NULL or empty filename yields a map with num_states == -1.
// Failure to open or parse the file is logged; after a parse failure the
// object keeps whatever operator>> managed to read.
ActionSubsetMap::ActionSubsetMap(const char* fn)
  : num_states(-1)
{
  const bool no_filename = (fn == NULL) || (*fn == '\0');
  if (no_filename)
    return;

  ifstream infile(fn);
  if (!infile)
    {
      errorlog << "ActionSubsetMap: could not open file '" << fn << "'" << ende;
      return;
    }

  infile >> *this;

  if (!infile.fail())
    return;

  errorlog << "ActionSubsetMap: error reading '" << fn << "'" << ende;
}


/***********************************************************************************/

void
ActionSubsetMap::extractFromQTable(const QTable& qt, bool enabled)
{
  num_states = qt.getNumStates();
  
  //this also clears the map
  resize();

  for (int s=0; s<num_states; s++)
    {
      int num_actions = qt.getNumActions(s);
      for (int a=0; a<num_actions; a++)
	{
	  if (qt.isEnabled(s, a) == enabled)
	    {
	      v_actions[s].insert(a);
	    }
	}
    }
}

/***********************************************************************************/

// Discards all subsets and allocates num_states empty ones.
// A negative num_states is rejected with an error log, and v_actions is then
// left untouched.
void
ActionSubsetMap::resize()
{
  if (num_states >= 0)
    {
      v_actions.clear();
      v_actions.resize(num_states);
      return;
    }

  errorlog << "Can't resize ActionSubsetMap with these sizes: "
           << num_states << ende;
}

/***********************************************************************************/
/***********************************************************************************/

void
ActionSubsetMap::invert(int num_actions)
{
  ActionSet full_action_set;
  
  for (int a=0; a<num_actions; a++)
    full_action_set.insert(a);

  for (int s=0; s<num_states; s++)
    {
      ActionSet new_action_set;

      set_difference(full_action_set.begin(), full_action_set.end(),
		     v_actions[s].begin(), v_actions[s].end(),
		     insert_iterator<ActionSet>(new_action_set, new_action_set.end()));

      v_actions[s].swap(new_action_set);
    }
}

/***********************************************************************************/

//reads the number of actions from the QTable
void
ActionSubsetMap::invert(const QTable& qt)
{
  errorlog << "unimple" << ende;
}


/***********************************************************************************/

void
ActionSubsetMap::unionWith(const ActionSubsetMap& m)
{
  for (int s=0; s<num_states; s++)
    {
      ActionSet new_action_set;

      set_union(v_actions[s].begin(), v_actions[s].end(),
		m.v_actions[s].begin(), m.v_actions[s].end(),
		insert_iterator<ActionSet>(new_action_set, new_action_set.end()));

      v_actions[s].swap(new_action_set);
    }
}

/***********************************************************************************/

bool
ActionSubsetMap::isIn(int state, int action)
{
  if (state < 0 || state >= num_states)
    {
      errorlog << "isIn: State out of range: " << state << ' ' << num_states << ende;
      return false;
    }
  
  return (v_actions[state].count(action) != 0);
}

/***********************************************************************************/
int
ActionSubsetMap::count(int state) const
{
  // Returns the number of actions in the subset for `state`.
  // If the state is out of range, an error is logged and 0 is returned.
  if (state < 0 || state >= num_states)
    {
      errorlog << "count: State out of range: " << state << ' ' << num_states << ende;
      return 0;   // an int count, not a bool
    }

  return static_cast<int>(v_actions[state].size());
}

/***********************************************************************************/
int
ActionSubsetMap::getElement(int state, int idx)
{
  if (state < 0 || state >= num_states)
    {
      errorlog << "getElement: State out of range: " << state << ' ' << num_states << ende;
      return false;
    }

  ActionSet::iterator iter = v_actions[state].begin();
  
  for (int i=0; i<idx; i++)
    iter++;

  return *iter;
}


/***********************************************************************************/
/***********************************************************************************/

int
ActionSubsetMap::removeRandom(float prob)
{
  int count = 0;
  for (int s=0; s<num_states; s++)
    {
      ActionSet::iterator iter;
      ActionSet::iterator nextiter;
      for (iter = v_actions[s].begin();
	   iter != v_actions[s].end();
	   iter = nextiter)
	{
	  nextiter = iter;
	  nextiter++;
	  if (prob_random() < prob)
	    {
	      v_actions[s].erase(iter);
	      count++;
	    }
	}
    }
  return count;
}

/***********************************************************************************/

int
ActionSubsetMap::removeOnePerState()
{
  int count = 0;
  for (int s=0; s<num_states; s++)
    {
      if (v_actions[s].empty())
	continue;

      int idx = int_random(v_actions[s].size());
      ActionSet::iterator iter = v_actions[s].begin();
      for (int i=0; i<idx; i++)
	iter++;
      v_actions[s].erase(iter);
      count++;
    }
  return count;
}

/***********************************************************************************/

// Classifies every state by its filter status and by whether the policy's
// action p.getAction(state) lies in this map's subset for that state.
// Outputs, all reset to 0 first:
//   *pcnt_in   - states accepted by pfilter whose policy action is in the subset
//   *pcnt_out  - states accepted by pfilter whose policy action is not
//   *pcnt_skip - states rejected by pfilter->useState()
// States are visited from num_states-1 down to 0. pfilter and all three
// counter pointers are dereferenced unconditionally, so none may be NULL.
void
ActionSubsetMap::compareToPolicy(Policy& p, StateFilter* pfilter,
                                 int* pcnt_in, int* pcnt_out, int* pcnt_skip)
{
  *pcnt_in = 0;
  *pcnt_out = 0;
  *pcnt_skip = 0;

  for (int sidx = num_states - 1; sidx >= 0; --sidx)
    {
      if (!pfilter->useState(sidx))
        ++*pcnt_skip;
      else if (isIn(sidx, p.getAction(sidx)))
        ++*pcnt_in;
      else
        ++*pcnt_out;
    }
}


/***********************************************************************************/
/***********************************************************************************/

//Format: <num_states>
//Format: then 1 line per state of <num disabled> <list of ints> 
std::ostream&
operator<<(std::ostream& os, const ActionSubsetMap& m)
{
  os << m.num_states << endl;

  for (int s=0; s<m.num_states; s++)
    {
      os << m.v_actions[s].size() << ' ';

      for (ActionSubsetMap::ActionSet::const_iterator iter = m.v_actions[s].begin();
	   iter != m.v_actions[s].end();
	   iter++)
	{
	  os << *iter << ' ';
	}
      
      os << endl;
    }
  
  return os;
}

// Reads the format written by operator<< into m, replacing its contents.
// On any read error, the stream's failbit is set and m may hold a partially
// read map.
// A negative state count or a negative per-state subset size is malformed
// input and sets failbit; operator<< never writes either. A negative state
// count also leaves m with no subsets.
std::istream&
operator>>(std::istream& is, ActionSubsetMap& m)
{
  is >> m.num_states;

  if (is.fail())
    return is;

  if (m.num_states < 0)
    {
      // resize() refuses negative sizes and would leave stale subsets behind.
      m.v_actions.clear();
      is.setstate(std::ios::failbit);
      return is;
    }

  //also clears anything existing
  m.resize();

  for (int s=0; s<m.num_states; s++)
    {
      int size;
      is >> size;

      if (is.fail())
        return is;

      if (size < 0)
        {
          is.setstate(std::ios::failbit);
          return is;
        }

      for (int i=0; i<size; i++)
        {
          int a;
          is >> a;

          if (is.fail())
            return is;

          m.v_actions[s].insert(a);
        }
    }

  return is;
}

