#include "CFDSolver.h"
#include <string.h>
#include <stdio.h>
#include <iostream>

// Flattened index of cell (i,j) in a field stored with a row stride of m_N+2;
// i is the fastest-varying coordinate. Requires m_N to be in scope.
#define IX(i,j) ((i)+(m_N+2)*(j))
// Exchanges the two float* variables x0 and x.
#define SWAP(x0,x) {float * tmp=x0;x0=x;x=tmp;}
// Opens a nested loop with i as the outer and j as the inner index, each
// running from 1 to m_N inclusive. Must be closed with END_FOR.
#define FOR_EACH_CELL for (int i=1 ; i<=m_N ; i++ ) { for (int j=1 ; j<=m_N ; j++ ) {
// Closes the two loop bodies opened by FOR_EACH_CELL.
#define END_FOR }}

// Builds a solver whose fields each hold (N+2)*(N+2) floats and start out
// zero-filled.
//   N         - grid resolution per side (loops in this file run 1..N)
//   dt        - initial time step
//   diffusion - coefficient handed to diffuse() for the density field
//   viscosity - coefficient handed to diffuse() for the velocity fields
CFDSolver::CFDSolver(int N, float dt, float diffusion, float viscosity)
	: m_dt(dt),
	  m_diff(diffusion),
	  m_visc(viscosity),
	  m_N(N),
	  m_size((N+2)*(N+2))
{
	// Density: current field and its companion buffer.
	m_density   = newClean(m_size);
	m_dens_prev = newClean(m_size);

	// Velocity components: current fields.
	m_vecu = newClean(m_size);
	m_vecv = newClean(m_size);

	// Velocity components: companion buffers.
	m_vecu_prev = newClean(m_size);
	m_vecv_prev = newClean(m_size);
}

// Releases the six field arrays that the constructor obtained from
// newClean(). They are allocated with new float[], so delete[] is the
// matching form; previously they were never freed and leaked on destruction.
// NOTE(review): the class declaration is not visible here. If CFDSolver is
// copyable, copies would share these pointers and free them twice - confirm
// that copy construction/assignment is disabled or deep-copies in the header.
CFDSolver::~CFDSolver()
{
	delete[] m_density;
	delete[] m_dens_prev;
	delete[] m_vecu;
	delete[] m_vecv;
	delete[] m_vecu_prev;
	delete[] m_vecv_prev;
}

// Allocates an array of `length` floats with every element set to zero.
// The array comes from new[], so it has to be released with delete[].
float* CFDSolver::newClean(int length)
{
	// The trailing () value-initialises the elements, which is zero for float.
	float* zeroed = new float[length]();
	return zeroed;
}

void CFDSolver::advect (int b, float * d, float * d0, float * u, float * v)
{
	int i0, j0, i1, j1;
	float x, y, s0, t0, s1, t1, dt0;

	dt0 = m_dt*m_N;
	FOR_EACH_CELL
		x = i-dt0*u[IX(i,j)];
		y = j-dt0*v[IX(i,j)];
		//perform a clamp
		if (x<0.5f) x=0.5f; 
		if (x>m_N+0.5f) x=m_N+0.5f;
		i0=(int)x; 
		i1=i0+1;
		if (y<0.5f) y=0.5f;
		if (y>m_N+0.5f) y=m_N+0.5f; 
		j0=(int)y; 
		j1=j0+1;
		s1 = x-i0; s0 = 1-s1; t1 = y-j0; t0 = 1-t1;
		d[IX(i,j)] = s0*(t0*d0[IX(i0,j0)]+t1*d0[IX(i0,j1)])+
					 s1*(t0*d0[IX(i1,j0)]+t1*d0[IX(i1,j1)]);
	END_FOR
	set_bnd (b, d);
}

// Advances a scalar field by one step: diffusion followed by advection.
//   x    - field to advance; read by the diffusion stage, overwritten by advection
//   x0   - intermediate buffer that receives the diffused field
//   u, v - velocity components used for the advection stage
void CFDSolver::dens_step (float * x, float * x0, float * u, float * v)
{
	// Both stages use boundary mode 0.
	const int boundaryMode = 0;

	// Stage 1: diffuse x into x0 using the density diffusion coefficient.
	diffuse(boundaryMode, x0, x, m_diff);

	// Stage 2: carry the diffused field along (u, v) back into x.
	advect(boundaryMode, x, x0, u, v);
}

// Diffuses x0 into x by handing the system to lin_solve().
//   b    - boundary mode forwarded to lin_solve()
//   x    - destination field (also used as the solver's working estimate)
//   x0   - source field
//   diff - diffusion coefficient
void CFDSolver::diffuse (int b, float * x, float * x0, float diff)
{
	// Neighbour weight, scaled by the time step and the square of m_N.
	const float rate = m_dt*diff*m_N*m_N;
	// Divisor: the cell itself plus its four weighted neighbours.
	const float divisor = 1+4*rate;
	lin_solve(b, x, x0, rate, divisor);
}

void CFDSolver::lin_solve (int b, float * x, float * x0, float a, float c )
{
	for (int k=0 ; k<20 ; k++ ) {
		FOR_EACH_CELL
			x[IX(i,j)] = (x0[IX(i,j)] + a*(x[IX(i-1,j)]+x[IX(i+1,j)]+x[IX(i,j-1)]+x[IX(i,j+1)]))/c;
		END_FOR
		set_bnd (b, x);
	}
}


// Corrects the velocity field (u, v) in three stages:
//   1. store a scaled central-difference divergence of (u, v) in div and zero p,
//   2. solve for p from div with lin_solve(),
//   3. subtract the central-difference gradient of p from (u, v).
//   u, v - velocity components, updated in place
//   p    - work buffer for the solved field (overwritten)
//   div  - work buffer for the divergence (overwritten)
void CFDSolver::project (float * u, float * v, float * p, float * div )
{
	// Stage 1: divergence and initial guess.
	for (int i = 1; i <= m_N; ++i) {
		for (int j = 1; j <= m_N; ++j) {
			div[IX(i,j)] = -0.5f*(u[IX(i+1,j)]-u[IX(i-1,j)]+v[IX(i,j+1)]-v[IX(i,j-1)])/m_N;
			p[IX(i,j)] = 0;
		}
	}
	set_bnd(0, div);
	set_bnd(0, p);

	// Stage 2: relax p against div.
	lin_solve(0, p, div, 1, 4);

	// Stage 3: remove the gradient of p from the velocity.
	const float halfN = 0.5f*m_N;
	for (int i = 1; i <= m_N; ++i) {
		for (int j = 1; j <= m_N; ++j) {
			u[IX(i,j)] -= halfN*(p[IX(i+1,j)]-p[IX(i-1,j)]);
			v[IX(i,j)] -= halfN*(p[IX(i,j+1)]-p[IX(i,j-1)]);
		}
	}
	set_bnd(1, u);
	set_bnd(2, v);
}


// Fills the border cells (index 0 and m_N+1 in either coordinate) of x from
// the adjacent cells.
//   b == 1 : the i = 0 and i = m_N+1 edges take the negated neighbour value
//   b == 2 : the j = 0 and j = m_N+1 edges take the negated neighbour value
//   other  : every edge copies its neighbour unchanged
// Each corner becomes the average of its two adjacent border cells.
void CFDSolver::set_bnd (int b, float * x )
{
	const bool negateAlongI = (b == 1);
	const bool negateAlongJ = (b == 2);
	const int last = m_N+1;

	for (int k = 1; k <= m_N; ++k)
	{
		// Edges at the extremes of the first coordinate.
		if (negateAlongI) {
			x[IX(0,k)]    = -x[IX(1,k)];
			x[IX(last,k)] = -x[IX(m_N,k)];
		} else {
			x[IX(0,k)]    = x[IX(1,k)];
			x[IX(last,k)] = x[IX(m_N,k)];
		}

		// Edges at the extremes of the second coordinate.
		if (negateAlongJ) {
			x[IX(k,0)]    = -x[IX(k,1)];
			x[IX(k,last)] = -x[IX(k,m_N)];
		} else {
			x[IX(k,0)]    = x[IX(k,1)];
			x[IX(k,last)] = x[IX(k,m_N)];
		}
	}

	// Corners: mean of the two neighbouring border cells.
	x[IX(0,0)]       = 0.5f*(x[IX(1,0)]+x[IX(0,1)]);
	x[IX(0,last)]    = 0.5f*(x[IX(1,last)]+x[IX(0,m_N)]);
	x[IX(last,0)]    = 0.5f*(x[IX(m_N,0)]+x[IX(last,1)]);
	x[IX(last,last)] = 0.5f*(x[IX(m_N,last)]+x[IX(last,m_N)]);
}


// Advances the velocity field by one step:
// diffuse -> project -> advect -> project.
//   u, v   - velocity components; hold the result on return
//   u0, v0 - companion buffers; they receive the diffused velocity and are
//            reused as work buffers by the final project() call
void CFDSolver::vel_step (float * u, float * v, float * u0, float * v0)
{
	// Diffuse each component into its companion buffer using the viscosity.
	diffuse(1, u0, u, m_visc);
	diffuse(2, v0, v, m_visc);

	// Project the diffused velocity; u and v serve as work buffers here.
	project(u0, v0, u, v);

	// Move the velocity along itself, writing back into u and v.
	advect(1, u, u0, u0, v0);
	advect(2, v, v0, u0, v0);

	// Project the advected velocity; u0 and v0 serve as work buffers here.
	project(u, v, u0, v0);
}

// Returns the density stored at cell (x, y). No bounds checking is done.
float CFDSolver::getDensityAt(int x, int y)
{
	const int cell = IX(x,y);
	return m_density[cell];
}

// Returns the u velocity component stored at cell (x, y). No bounds checking
// is done.
float CFDSolver::getUAt(int x, int y)
{
	const int cell = IX(x,y);
	return m_vecu[cell];
}

// Returns the v velocity component stored at cell (x, y). No bounds checking
// is done.
float CFDSolver::getVAt(int x, int y)
{
	const int cell = IX(x,y);
	return m_vecv[cell];
}

// Adds `source` to the density stored at cell (x, y). No bounds checking is
// done.
void CFDSolver::addDensityAt(float source, int x, int y)
{
	float& cell = m_density[IX(x,y)];
	cell += source;
}

// Advances the whole simulation by one step: velocity update, density
// update, then density decay.
//   dt - time step for this and subsequent calls; pass 0 to keep using the
//        time step already stored in m_dt
void CFDSolver::runTimeStep(float dt)
{
	// A non-zero dt replaces the stored step; zero means "reuse m_dt".
	if(dt!=0)
	{
		m_dt = dt;
	}
	vel_step(m_vecu,m_vecv,m_vecu_prev,m_vecv_prev);
	dens_step(m_density,m_dens_prev,m_vecu,m_vecv);
	// Decay with the step that was actually simulated. Passing the raw dt
	// here meant a call with dt == 0 advanced the fields by m_dt but applied
	// no decay at all.
	remove_from_all(m_dt);
}

// Adds `force` to the u velocity component at cell (x, y). No bounds checking
// is done.
void CFDSolver::addUForce(float force, int x, int y)
{
	float& cell = m_vecu[IX(x,y)];
	cell += force;
}

// Adds `force` to the v velocity component at cell (x, y). No bounds checking
// is done.
void CFDSolver::addVForce(float force, int x, int y)
{
	float& cell = m_vecv[IX(x,y)];
	cell += force;
}

// Zeroes all six field arrays (density, both velocity components, and their
// companion buffers), border cells included.
void CFDSolver::clearAll()
{
	const size_t bytes = sizeof(float)*m_size;
	float* const fields[] = {
		m_density, m_dens_prev,
		m_vecu, m_vecv,
		m_vecu_prev, m_vecv_prev
	};
	const size_t fieldCount = sizeof(fields)/sizeof(fields[0]);

	for (size_t f = 0; f < fieldCount; ++f)
	{
		memset(fields[f], 0, bytes);
	}
}

void CFDSolver::remove_from_all(float dt) {
	FOR_EACH_CELL
		m_density[IX(i,j)] = m_density[IX(i,j)] * (1.0-(dt/50.0));
	END_FOR
}

