#include <math.h>
#include <iostream>
#include <fstream>
#include <sstream>
#include <list>

#include "AI_Tree.h"

using namespace std;

// Descend from the root to the leaf that matches `state` and return that
// leaf's move -> q-value map together with its current learning rate.
// The root is created lazily on the first call (branching on state[0]);
// next() creates any missing intermediate nodes on the way down.
// Every call counts as one visit of the leaf and, while the leaf's learning
// rate is still above the "END_LR" parameter, multiplies it by "DL".
q_info AI_Tree::get_qinfo(const AI_State& state)
{
	if (root == 0)
	{
		root = new AI_Tree_Node(0, m_size++, state[0].branch_var,
								AI_Interval(0, 0, true, true),
								ai_player->get_param("START_LR"));
	}

	AI_Tree_Node* leaf = root;
	for (int depth = 0; !leaf->is_leaf(); ++depth)
		leaf = next(leaf, state, depth);

	++leaf->num_visits;

	// Decay the learning rate; the check happens before the multiplication,
	// so the final step may take it slightly below END_LR.
	if (leaf->lr > ai_player->get_param("END_LR"))
		leaf->lr *= ai_player->get_param("DL");

	return q_info(leaf->mqv_map, leaf->lr);
}

// Blend the new estimate `nq` into the stored q-value of `move` at the leaf
// matching `state`, using that leaf's learning rate:
//     q <- q*(1-lr) + nq*lr
// and return the updated q-value. The path to the leaf is expected to exist
// already (get_qinfo() is what creates missing nodes); the map entry for
// `move` is looked up with operator[].
STATE_VALUE AI_Tree::update_qvalue(const AI_State& state, MOVE_ID move, STATE_VALUE nq)
{
	assert(root != 0);

	AI_Tree_Node* leaf = root;
	for (int depth = 0; !leaf->is_leaf(); ++depth)
	{
		std::list<AI_Interval> hits;
		leaf->children.find_intervals(state[depth].value, front_inserter(hits));
		// NOTE(review): no check that a child matched; front() of an empty
		// list is undefined behaviour.
		leaf = hits.front().element();
	}

	float rate = leaf->lr;
	STATE_VALUE& q = leaf->mqv_map[move].qvalue;
	q = q*(1-rate) + nq*rate;
	return q;
}

// Add `rdelta` to the accumulated reward of `move` at the leaf matching
// `state` and return the new total. The path to the leaf is expected to exist
// already (get_qinfo() is what creates missing nodes); the map entry for
// `move` is looked up with operator[].
STATE_VALUE AI_Tree::update_reward(const AI_State& state, MOVE_ID move, STATE_VALUE rdelta)
{
	assert(root != 0);

	AI_Tree_Node* leaf = root;
	for (int depth = 0; !leaf->is_leaf(); ++depth)
	{
		std::list<AI_Interval> found;
		leaf->children.find_intervals(state[depth].value, front_inserter(found));
		// NOTE(review): no check that a child matched; front() of an empty
		// list is undefined behaviour.
		leaf = found.front().element();
	}

	STATE_VALUE& total = leaf->mqv_map[move].reward;
	total += rdelta;
	return total;
}

// Load the tree stored in "conf\p<player_id>.qvalues", falling back to
// "conf\default.qvalues" when the per-player file cannot be opened.
// Also sets ai_player from player->ai_player.
//
// The file is the raw binary stream produced by save(): a sequence of node
// records, each made of
//   id (unsigned int), parent id (unsigned int), branch_var (STATE_VAR_ID),
//   interval inf and sup (STATE_VALUE), learning rate (AI_PARAM),
// followed by zero or more (MOVE_ID, qvalue, reward) entries and the
// terminator MOVE_ID(-5). The record with id 0 becomes the root; every other
// record is attached to the previously loaded node with id == parent id.
// Reading stops at the first failed read, so a missing or truncated file
// leaves whatever was loaded up to that point. m_size is incremented once per
// node loaded.
void AI_Tree::load(PLAYER_ID player_id)
{
	map<int, AI_Tree_Node*> nodes;
	ostringstream filename;
	filename << "conf\\p" << player_id << ".qvalues";
	ifstream qvalues(filename.str().c_str(), ios::binary);

	if (!qvalues)
	{
		qvalues.close();
		qvalues.clear();
		// The fallback must be opened in binary mode too: in text mode
		// (e.g. on Windows) newline translation and Ctrl-Z handling corrupt
		// the raw records.
		qvalues.open("conf\\default.qvalues", ios::binary);
	}

	ai_player = player->ai_player;

	unsigned int id, pid;
	STATE_VAR_ID branch_var;
	STATE_VALUE inf, sup;
	AI_PARAM lr;
	MOVE_ID m_id;
	STATE_VALUE qvalue, reward;
	AI_Tree_Node *node=0, *parent=0;

	while (1)
	{
		// Fixed-size node header.
		qvalues.read((char*)&id, sizeof(unsigned int))
			   .read((char*)&pid, sizeof(unsigned int))
			   .read((char*)&branch_var, sizeof(STATE_VAR_ID))
			   .read((char*)&inf, sizeof(STATE_VALUE))
			   .read((char*)&sup, sizeof(STATE_VALUE))
			   .read((char*)&lr, sizeof(AI_PARAM));

		if (!qvalues)
			break;

		// id 0 is the root. NOTE(review): an unknown pid also yields a null
		// parent (operator[] default-inserts 0), which makes the record
		// replace root.
		if (id != 0)
			parent = nodes[pid];
		else
			parent = 0;

		// The last interval flag is set only for a degenerate interval
		// (inf == sup).
		nodes[id] = node = new AI_Tree_Node(parent, id, branch_var,
											AI_Interval(inf, sup, true, inf == sup), lr);

		if (parent != 0)
			parent->children.insert(node->interval);
		else
			root = node;

		// Move entries up to the terminator. Every read is checked: on a
		// truncated file m_id would otherwise keep its previous value and
		// this loop would never terminate.
		while (qvalues.read((char*)&m_id, sizeof(MOVE_ID)) && m_id != MOVE_ID(-5))
		{
			if (!qvalues.read((char*)&qvalue, sizeof(STATE_VALUE))
					   .read((char*)&reward, sizeof(STATE_VALUE)))
				break;

			node->mqv_map[m_id] = qr_info(qvalue, reward);
		}

		m_size++;
	}

	qvalues.close();
}

// Write the whole tree to "conf\p<player_id>.qvalues" in the binary format
// that load() reads back. The file is always (re)created; for an empty tree
// (root == 0, e.g. after clear()) it is left empty instead of dereferencing a
// null root. Note that load() falls back to default.qvalues only when the
// per-player file cannot be opened, so an empty file loads as an empty tree.
void AI_Tree::save(PLAYER_ID player_id)
{
	ostringstream filename;
	filename << "conf\\p" << player_id << ".qvalues";
	ofstream qvalues(filename.str().c_str(), ios::binary);

	// The recursive overload dereferences its node unconditionally.
	if (root != 0)
		save(root, qvalues);

	qvalues.close();
}

// Serialise `cn` followed by its whole subtree (pre-order) to `qvalues` as raw
// binary. Each node record is: id, parent id (0 when there is no parent),
// branch_var, interval inf and sup, learning rate, then one (move, qvalue,
// reward) entry per element of mqv_map, closed by the terminator MOVE_ID(-5).
// Parents are always written before their children, which load() relies on
// when it resolves parent ids. `cn` must be non-null.
void AI_Tree::save(AI_Tree_Node* cn, ostream& qvalues)
{
	const MOVE_ID terminator = MOVE_ID(-5);
	const unsigned int parent_id = (cn->parent == 0) ? 0 : cn->parent->id;
	const STATE_VALUE lower = cn->interval.inf();
	const STATE_VALUE upper = cn->interval.sup();

	// Node header.
	qvalues.write((const char*)&cn->id, sizeof(unsigned int));
	qvalues.write((const char*)&parent_id, sizeof(unsigned int));
	qvalues.write((const char*)&cn->branch_var, sizeof(STATE_VAR_ID));
	qvalues.write((const char*)&lower, sizeof(STATE_VALUE));
	qvalues.write((const char*)&upper, sizeof(STATE_VALUE));
	qvalues.write((const char*)&cn->lr, sizeof(AI_PARAM));

	// Per-move entries, then the end-of-node marker.
	MQV_iter stop = cn->mqv_map.end();
	for (MQV_iter entry = cn->mqv_map.begin(); entry != stop; ++entry)
	{
		qvalues.write((const char*)&(*entry).first, sizeof(MOVE_ID));
		qvalues.write((const char*)&(*entry).second.qvalue, sizeof(STATE_VALUE));
		qvalues.write((const char*)&(*entry).second.reward, sizeof(STATE_VALUE));
	}
	qvalues.write((const char*)&terminator, sizeof(MOVE_ID));

	// Children follow their parent.
	il_iter last = cn->children.end();
	for (il_iter child = cn->children.begin(); child != last; ++child)
		save((*child).element(), qvalues);
}

// Delete every node of the tree and reset it to the empty state
// (root == 0, m_size == 0).
void AI_Tree::clear()
{
	if (root != 0)
	{
		rec_delete(root);
		root = 0;
	}

	m_size = 0;
}

// Delete `cn` and its entire subtree, children first. `cn` must be non-null.
void AI_Tree::rec_delete(AI_Tree_Node* cn)
{
	il_iter last = cn->children.end();
	for (il_iter child = cn->children.begin(); child != last; ++child)
		rec_delete((*child).element());

	delete cn;
}

// Step one level down from the internal node `cn` and count a visit of `cn`.
// The child whose interval contains state[level].value is returned. If no
// child matches, a new child is created: its interval is the bucket that
// interval_rep() computes for that value on cn's branch variable, it branches
// on state[level + 1].branch_var, and it starts with the "START_LR" learning
// rate. When that new child is a leaf, the search heuristic initialises its
// move table.
// NOTE(review): reads state[level + 1]; assumes AI_State has an entry past the
// last branching level — confirm.
AI_Tree::AI_Tree_Node* AI_Tree::next(AI_Tree_Node* cn, const AI_State& state, int level)
{
	assert(cn->branch_var != 0);
	++cn->num_visits;

	STATE_VALUE value = state[level].value;

	std::list<AI_Interval> matches;
	cn->children.find_intervals(value, front_inserter(matches));
	if (!matches.empty())
		return matches.front().element();

	AI_Tree_Node* child = new AI_Tree_Node(cn, m_size++, state[level + 1].branch_var,
										   interval_rep(cn->branch_var, value),
										   ai_player->get_param("START_LR"));
	cn->children.insert(child->interval);

	if (child->is_leaf())
		player->search_heuristic->initial_heuristic(state, q_info(child->mqv_map, child->lr));

	return child;
}

// Return the bucket of state variable `sid` that contains `sv`. With a
// positive resolution the bucket is [start, start + resolution), where the
// buckets are aligned to the variable's minimum_value. With resolution 0 it is
// the single point [sv, sv] (both ends flagged closed).
AI_Tree::AI_Interval AI_Tree::interval_rep(STATE_VAR_ID sid, STATE_VALUE sv)
{
	state_var_info& info = player->state_inspector->get_state_info(sid);

	if (info.resolution > 0)
	{
		// fmod() keeps the sign of its first argument, so for sv below
		// minimum_value the remainder is negative. Without wrapping it into
		// [0, resolution) the bucket would start above sv and not contain it,
		// and next() would keep creating duplicate children for such values.
		double offset = fmod((sv - info.minimum_value), info.resolution);
		if (offset < 0)
			offset += info.resolution;
		sv -= offset;
	}

	return AI_Interval(sv, sv + info.resolution, true, info.resolution == 0);
}