/* A simple class to do table-based Q-learning.
   States and actions are identified by non-negative integer indices. */

#include <cmath>
#include <iomanip>
#include "QTableFlat.hpp"
#include "Logger.hpp"

using namespace spades;
using namespace std;

/****************************************************************************************/

/* Build a table for num_states states with num_actions actions each.
   num_states and gamma are handed to the QTable base class; the flat
   storage is then sized to num_states * num_actions entries. */
QTableFlat::QTableFlat(int num_states, int num_actions, float gamma)
  : QTable(num_states, gamma),
    num_actions(num_actions),
    vQTable()
{
  // The storage starts out empty; give it one slot per (state, action) pair.
  this->resizeTable();
}

/* Build a table by reading it from the text file named fn, using
   operator>> for QTableFlat.

   The object starts out empty (base constructed with 0 states and gamma 0,
   num_actions set to -1).  Problems are reported through errorlog only:
     - if the file cannot be opened, the object is left in that empty state;
     - if reading fails for any reason, including an empty or truncated
       file, an error is logged and whatever was read so far is kept;
     - finally the storage size is compared against
       num_states * num_actions and a mismatch is logged. */
QTableFlat::QTableFlat(const char* fn)
  : QTable(0, 0), num_actions(-1), vQTable()
{
  ifstream in(fn);
  if (!in)
    {
      errorlog << "QTableFlat: could not open file '" << fn << "'" << ende;
      return;
    }

  in >> *this;
  // operator>> extracts exactly the header plus one Q/visits pair per entry
  // and never tries to read past the last one, so failbit always means that
  // data was missing or malformed -- also when eofbit is set alongside it
  // (the file ended before everything was read).
  if (in.fail())
    errorlog << "QTableFlat: error reading '" << fn << "'" << ende;

  if (static_cast<int>(vQTable.size()) != num_states * num_actions)
    errorlog << "QTableFlat: size looks wrong: "
             << vQTable.size() << ' '
             << num_states * num_actions << ' '
             << ende;
}


/* Empty destructor: no explicit cleanup is performed here. */
QTableFlat::~QTableFlat() {}

/* Resize the flat storage to one entry per (state, action) pair,
   i.e. num_states * num_actions entries, using the current values of
   those two members. */
void
QTableFlat::resizeTable()
{
  const int entry_count = num_states * num_actions;
  vQTable.resize(entry_count);
}

/* Reset the Q value and the visit count of every enabled entry to zero.
   Entries for which isEnabled() is false are left untouched. */
void
QTableFlat::zero()
{
  // Walk the table from the last entry down to the first.
  int idx = num_states * num_actions;
  while (idx-- > 0)
    {
      if (vQTable[idx].isEnabled())
        {
          vQTable[idx].Q = 0.0;
          vQTable[idx].visits = 0;
        }
    }
}

/* Return a pointer to the table entry for (state, action), or NULL when
   the state or the action is outside the table. */
QTable::StateActionEntry*
QTableFlat::getSA(int state, int action)
{
  // checkValid() treats action == -1 as "only verify the state" and accepts
  // it, but -1 does not name an entry: getIdx() would land on the slot just
  // before this state's row (or wrap around for state 0).  Reject it here.
  if (action < 0 || !checkValid(state, action))
    return NULL;

  return &(vQTable[getIdx(state, action)]);
}

/* Const overload: return a read-only pointer to the table entry for
   (state, action), or NULL when the state or the action is outside the
   table. */
const QTable::StateActionEntry*
QTableFlat::getSA(int state, int action) const
{
  // checkValid() treats action == -1 as "only verify the state" and accepts
  // it, but -1 does not name an entry: getIdx() would land on the slot just
  // before this state's row (or wrap around for state 0).  Reject it here.
  if (action < 0 || !checkValid(state, action))
    return NULL;

  return &(vQTable[getIdx(state, action)]);
}

/* Tell whether (state, action) lies inside the table.
   state must be in [0, num_states).  action must be in [0, num_actions),
   except that an action of -1 is accepted and means "only verify the
   state". */
bool
QTableFlat::checkValid(int state, int action) const
{
  const bool state_ok = (state >= 0) && (state < num_states);
  const bool action_ok =
    (action == -1) || ((action >= 0) && (action < num_actions));

  return state_ok && action_ok;
}

/* Map (state, action) to its position in the flat storage: entries are
   laid out state by state, num_actions consecutive slots per state.

   The arguments are not validated here.  If the computed position is
   outside the storage an error is written to errorlog, but the position
   is still returned as computed. */
unsigned int
QTableFlat::getIdx(int state, int action) const
{
  unsigned int idx = state * num_actions + action;
  // Valid positions are 0 .. size()-1, so idx == size() is already out of
  // range and must be reported as well.
  if (idx >= vQTable.size())
    errorlog << "QTableFlat: idx out of range: "
             << idx << ' ' << vQTable.size() << ' '
             << num_states << ' ' << num_actions << ' '
             << ende;

  return idx;
}


/****************************************************************************************/

/* Write the table as text.

   Format: a header "num_states num_actions gamma", then one line per state
   holding, for each action, the Q value (width 11, precision 5) followed by
   the visit count (width 6).  The output ends with a newline and the stream
   is flushed once at the end.

   The stream's precision setting is restored before returning, so the
   caller's formatting is not affected. */
std::ostream&
operator<<(std::ostream& os, const QTableFlat& qt)
{
  os << qt.num_states << ' '
     << qt.num_actions << ' '
     << qt.gamma << ' ';

  // setprecision() is sticky (unlike setw()); remember the caller's value so
  // it can be put back once the entries have been written.
  const std::streamsize saved_precision = os.precision();

  const int num_entries = qt.num_states * qt.num_actions;
  for (int i = 0; i < num_entries; i++)
    {
      // Start a new line at the first action of every state.  A plain '\n'
      // is used instead of endl to avoid flushing once per state.
      if (i % qt.num_actions == 0)
        os << '\n';
      os << setw(11) << setprecision(5) << qt.vQTable[i].Q << ' '
         << setw(6) << qt.vQTable[i].visits << ' ';
    }

  os.precision(saved_precision);

  // Final newline plus the single flush.
  os << endl;

  return os;
}

/* Read a table from text in the format produced by operator<<:
   a header "num_states num_actions gamma" followed by one "Q visits" pair
   per entry, num_states * num_actions pairs in total.

   On a failed extraction the function returns immediately with the stream
   in its failed state; values read up to that point stay in qt.  A header
   with a negative dimension is rejected by setting failbit, before the
   storage is resized. */
std::istream&
operator>>(std::istream& is, QTableFlat& qt)
{
  is >> qt.num_states
     >> qt.num_actions
     >> qt.gamma;

  if (!is)
    return is;

  // A negative dimension cannot describe a table, and passing it on to
  // resizeTable() would turn it into an enormous unsigned element count.
  if (qt.num_states < 0 || qt.num_actions < 0)
    {
      is.setstate(std::ios_base::failbit);
      return is;
    }

  qt.resizeTable();

  const int num_entries = qt.num_states * qt.num_actions;
  for (int i = 0; i < num_entries; i++)
    {
      is >> qt.vQTable[i].Q >> qt.vQTable[i].visits;
      if (!is)
        return is;
    }

  return is;
}


