#include "includes.h"
#include "knightcap.h"

/* TD(lambda) trace-decay factor: in td_update() the running gradient
   sum is multiplied by this once per move before the next move's
   gradient is added */
#define TD_LAMBDA 0.7
/* step size applied to the accumulated coefficient update dw[] */
#define TD_ALPHA (1.0/(EVAL_SCALE))
/* new_coefficients are only changed when the round counter read from
   rounds.dat equals this value */
#define MAX_ROUNDS 10
/* number of largest conjunction updates tracked (CONJUNCTION_CALC only) */
#define MAX_SIZE 50

/* One entry of the conjunction-update ranking built in td_update():
   the update magnitude (val), the stage (i) and the pair of coefficient
   indices within that stage (j, k). */
struct max_struct {
	double val;
	int i,j,k;
};

/* qsort() comparator ordering struct max_struct entries by ascending val.
   It takes const void * because qsort() calls it through
   int (*)(const void *, const void *); calling a function through a
   pointer of a different type is undefined behaviour.
   Returns 1, 0 or -1 when the first entry's val is greater than, equal to
   or less than the second's. */
static int max_compare(const void *p1, const void *p2)
{
	const struct max_struct *m1 = p1;
	const struct max_struct *m2 = p2;

	if (m1->val > m2->val)
		return 1;
	if (m1->val < m2->val)
		return -1;
	return 0;
}

/* Globals declared extern here and defined in another file. */
extern struct state *state;	/* shared engine state; this file reads and writes its learning buffers, move counters and result fields */
extern int player;	/* not referenced by the code in this file */
/* Coefficient indices the TD code must never change. It is scanned with one
   cursor advancing as i increases, so it presumably must be sorted ascending
   and end with a value no coefficient index reaches -- TODO confirm. */
extern int dont_change[];

/* Printable stage names, indexed by stage value OPENING..MATING (assumes
   OPENING == 0); written as section comments by dump_coeffs(). */
char *stage_name[] = {"OPENING", "MIDDLE", "ENDING", "MATING"};

#include "names.h"

/* Print one coefficient group as a single comma-separated row to `large`
   and, when `small` is not NULL, the same row with every value divided by
   100 to `small`. The group runs from cn->index up to (not including) the
   next entry's index. */
static void p_coeff_vector(struct coefficient_name *cn, FILE *large, FILE *small)
{
	int base = cn->index;
	int len = cn[1].index - base;
	int k;

	fprintf(large, "/* %s */\n", cn->name);
	if (small != NULL)
		fprintf(small, "/* %s */\n", cn->name);

	for (k = 0; k < len; k++) {
		fprintf(large, "%7d,", coefficients[base + k]);
		if (small != NULL)
			fprintf(small, "%7d,", coefficients[base + k] / 100);
	}

	fputc('\n', large);
	if (small != NULL)
		fputc('\n', small);
}

/* Print one coefficient group ten values per line to `large` and, when
   `small` is not NULL, the same values divided by 100 to `small`. The group
   runs from cn->index up to (not including) the next entry's index. */
static void p_coeff_array(struct coefficient_name *cn, FILE *large, FILE *small)
{
	int base = cn->index;
	int len = cn[1].index - base;
	int k;

	fprintf(large, "/* %s */\n", cn->name);
	if (small != NULL)
		fprintf(small, "/* %s */\n", cn->name);

	/* k counts values printed so far (1-based) so a row break follows
	   every tenth one */
	for (k = 1; k <= len; k++) {
		fprintf(large, "%7d,", coefficients[base + k - 1]);
		if (small != NULL)
			fprintf(small, "%7d,", coefficients[base + k - 1] / 100);
		if (k % 10 != 0)
			continue;
		fputc('\n', large);
		if (small != NULL)
			fputc('\n', small);
	}

	fputc('\n', large);
	if (small != NULL)
		fputc('\n', small);
}


/* Print a 64-entry coefficient group as an 8x8 grid (eight values per
   line) to `large` and, when `small` is not NULL, the same grid with values
   divided by 100 to `small`. */
static void p_coeff_board(struct coefficient_name *cn, FILE *large, FILE *small)
{
	int sq;

	fprintf(large, "/* %s */\n", cn->name);
	if (small != NULL)
		fprintf(small, "/* %s */\n", cn->name);

	for (sq = 0; sq < 64; sq++) {
		fprintf(large, "%7d,", coefficients[cn->index + sq]);
		if (small != NULL)
			fprintf(small, "%7d,", coefficients[cn->index + sq] / 100);
		/* end of a rank of 8 squares */
		if ((sq & 7) == 7) {
			fputc('\n', large);
			if (small != NULL)
				fputc('\n', small);
		}
	}
}

/* Print a 32-entry coefficient group as 8 rows of 4 values to `large` and,
   when `small` is not NULL, the same rows with values divided by 100 to
   `small`. */
static void p_coeff_half_board(struct coefficient_name *cn, FILE *large, FILE *small)
{
	int sq;

	fprintf(large, "/* %s */\n", cn->name);
	if (small != NULL)
		fprintf(small, "/* %s */\n", cn->name);

	for (sq = 0; sq < 32; sq++) {
		fprintf(large, "%7d,", coefficients[cn->index + sq]);
		if (small != NULL)
			fprintf(small, "%7d,", coefficients[cn->index + sq] / 100);
		/* four values per row */
		if ((sq & 3) == 3) {
			fputc('\n', large);
			if (small != NULL)
				fputc('\n', small);
		}
	}
}

void dump_coeffs(char *fname)
{
        struct coefficient_name *cn;
        FILE *large, *small;
        int fd;
	int i;

#if LARGE_ETYPE
	large = (FILE *)fopen("large_coeffs.h", "w");
	small = (FILE *)fopen("small_coeffs.h", "w");
#else
	large = (FILE *)fopen("small_coeffs.h", "w");
	small = NULL;
#endif	
	if (large == NULL) {
                perror(fname);
                return;
        }

        fprintf(large, "etype orig_coefficients[] = {\n");
	if (small)
		fprintf(small, "etype orig_coefficients[] = {\n");
	for (i=OPENING; i<=MATING; i++) {
		fprintf(large, "\n/* %%%s%% */\n", stage_name[i]);
		if (small)
			fprintf(small, "\n/* %%%s%% */\n", stage_name[i]);
		cn = &coefficient_names[0];
		coefficients = 	new_coefficients + i*__COEFFS_PER_STAGE__;
		while (cn->name) {
			int n = cn[1].index - cn[0].index;
			if (n == 1) {
				fprintf(large, "/* %s */ %d,\n", cn[0].name, 
					coefficients[cn[0].index]);
				if (small) 
					fprintf(small, "/* %s */ %d,\n", cn[0].name, 
						coefficients[cn[0].index]/100);
			} else if (n == 64) {
				p_coeff_board(cn,large,small);
			} else if (n == 32) {
				p_coeff_half_board(cn,large,small);
			} else if (n % 10 == 0) {
				p_coeff_array(cn,large,small);
			} else {
				p_coeff_vector(cn,large,small);
			}
			cn++;
		}
	}

        fprintf(large, "};\n");
	if (small)
		fprintf(small, "};\n");
        fclose(large);
	if (small)
		fclose(small);

        fd = open(fname, O_WRONLY | O_CREAT | O_TRUNC, 0666);
        if (fd == -1) {
                perror(fname);
                return;
        }

        Write(fd, new_coefficients, __TOTAL_COEFFS__*sizeof(new_coefficients[0]));
        close(fd);

        return;
}
  

/* Dump the coefficients with dump_coeffs(fname), then print via cprintf()
   the sum of absolute differences between new_coefficients[] and the
   compiled-in orig_coefficients[]. Always returns 1. */
int td_dump(char *fname)
{
	etype total = 0.0;
	int k;

	dump_coeffs(fname);

	for (k = 0; k < __TOTAL_COEFFS__; k++)
		total += ABS(new_coefficients[k] - orig_coefficients[k]);

	cprintf(0, "%d\n", total);
	return 1;
}

/* routines for updating the evaluation function according to the
   method of temporal differences */

#if LEARN_EVAL

/* Start a fresh gradient log: close the previous gradient file (if any),
   create or truncate grad.dat (opened O_APPEND, so every write goes to its
   end) and empty the in-memory gradient buffer. */
void td_init() 
{
	/* gradient_fd == 0 means "never opened"; also skip the -1 a failed
	   open leaves behind instead of calling close(-1) */
	if (state->gradient_fd > 0)
		close(state->gradient_fd);
	state->gradient_fd = open("grad.dat", O_RDWR | O_CREAT | O_TRUNC | O_APPEND, 0666);
	state->buffered_moves = 0;
	memset(state->grad_buffer, 0, GRAD_TEMP_SIZE*sizeof(state->grad_buffer[0]));
	if (state->gradient_fd == -1) {
		/* report it now rather than failing silently on every later
		   gradient write/read */
		perror("grad.dat");
		return;
	}
	lprintf(0,"gradient initialised\n");
}

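/* Compute numerically (by forward differences) the partial derivative of the
   eval function with respect to each evaluation coefficient of the position's
   stage and store it in state->grad_buffer, which is flushed to
   state->gradient_fd every BUFFERED_MOVES calls. */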
int td_gradient(Position *b)
{
	etype v, v2, v3, v4;
	int i, j, n, m, error;
	etype delta = 1;
	etype *grad;

	/* I couldn't think of a better place for these */
	if (state->computer != 0) 
		state->colour = state->computer;

	/* sanity check */
	if (b->stage < OPENING || b->stage > MATING) {
		lprintf(0, "**Wrong stage in gradient calc: %d\n", b->stage);
		return 0;
	}

	n = __COEFFS_PER_STAGE__;

	m = state->stored_move_num;

	b->flags &= ~FLAG_EVAL_DONE;
	b->flags &= ~FLAG_DONE_TACTICS;
	v = eval_etype(b, INFINITY, MAX_DEPTH);

	state->leaf_eval[m].v = next_to_play(b)*v;
	if (!state->demo_mode) {
		state->leaf_eval[m].v *= state->computer;
	}

	j=0;
	coefficients = new_coefficients + b->stage*__COEFFS_PER_STAGE__;
	grad = state->grad_buffer + __TOTAL_COEFFS__*state->buffered_moves + 
		b->stage*__COEFFS_PER_STAGE__;

	for (i=0;i<n;i++) {
		if (dont_change && i == dont_change[j]) {
			++j;
			continue;
		}

		coefficients[i] += delta;

		/* material only affects the eval indirectly 
		   via the board, so update the board */
		if (i > IPIECE_VALUES && i < IPIECE_VALUES+KING)
			create_pboard(b);

		b->flags &= ~FLAG_EVAL_DONE;
		b->flags &= ~FLAG_DONE_TACTICS;
		v2 = eval_etype(b, INFINITY, MAX_DEPTH);

		grad[i] = next_to_play(b)*(v2 - v) / delta;
		if (!state->demo_mode) {
			grad[i] *= state->computer;
		} 
		
		if (ABS(grad[i]) > 16)
			dprintf(0,"%d %d %d %d %d\n", b->stage, i, grad[i], v, v2);

#if TEST_GRADIENT
		coefficients[i] += 2*delta;

		if (i > IPIECE_VALUES && i < IPIECE_VALUES+KING)
			create_pboard(b);

		b->flags &= ~FLAG_EVAL_DONE;
		b->flags &= ~FLAG_DONE_TACTICS;

		v3 = eval_etype(b, INFINITY, MAX_DEPTH);

		coefficients[i] -= 3*delta;

		if (i > IPIECE_VALUES && i < IPIECE_VALUES+KING)
			create_pboard(b);

		b->flags &= ~FLAG_EVAL_DONE;
		b->flags &= ~FLAG_DONE_TACTICS;
		v4 = eval_etype(b, INFINITY, MAX_DEPTH);
		error = next_to_play(b)*(v3 - v);
		if (!state->demo_mode)
			error *= state->computer;
		error -= 3*delta*grad[i]; 
		if (ABS(error)>0) {
			lprintf(0,"***coeff: %d grad: %e error: %e %e %e %e %e\n", 
				i, 
				grad[i], 
				next_to_play(b)*(v3 - v) - 3*delta*grad[i], 
				v, v2, v3, v4);
		}
#else
		coefficients[i] -= delta;
#endif
	}

	++state->buffered_moves;
	if (state->buffered_moves == BUFFERED_MOVES) {
		Write(state->gradient_fd, state->grad_buffer, 
		      GRAD_TEMP_SIZE*sizeof(state->grad_buffer[0]));
		state->buffered_moves = 0;
		memset(state->grad_buffer, 0, GRAD_TEMP_SIZE*sizeof(state->grad_buffer[0]));
	}

	++state->stored_move_num;

	return n;
}

/* Append a raw copy of position b1 to the file open on fd, logging an error
   if the whole Position could not be written. */
void td_save_bad(int fd, Position *b1)
{
	int x;

	lseek(fd, 0, SEEK_END);

	x = Write(fd, b1, sizeof(Position));
	if (x < 0 || (size_t)x != sizeof(Position)) {
		/* sizeof yields size_t: cast so the argument matches %d */
		lprintf(0,"***Error saving bad eval position %d %d\n",
			(int)sizeof(Position), x);
	} 
}

/* Updates the coefficients according to the TD(lambda) algorithm. */
/* Apply the TD(lambda) learning rule at the end of a game.

   Reads back the per-move gradients written by td_gradient() (flushing any
   still-buffered moves first), squashes each stored leaf eval v_t to
   tanh(EVAL_SCALE*v_t), and sets d[t] to the difference between successive
   squashed evals. The last difference is taken against the game outcome:
   result 1 -> +1, result 0 -> -1, stalemate -> tanh(EVAL_SCALE*DRAW_VALUE),
   time forfeit or an unrecognised result -> no change (d = 0).
   For every coefficient i not in dont_change[] it accumulates
       dw[i] += sum_t d[t] * sum_{k<=t} TD_LAMBDA^(t-k)
                               * (1 - tanh_k^2) * EVAL_SCALE * grad_k[i]
   then logs the largest |dw[i]|, and logs the angle between the old and
   updated coefficient vectors (also appended to angle.dat).

   new_coefficients are only changed (by TD_ALPHA*dw) when the counter read
   from rounds.dat equals MAX_ROUNDS. With DUMPING_TD_UPDATES, dw is carried
   between calls in update.dat and the counter in rounds.dat is advanced.

   Does nothing if the game was already analysed. Always returns 0. */
int td_update()
{
	enum { TD_MAX_STORED_MOVES = 300 };	/* capacity of grad[] in moves */
#if DUMPING_TD_UPDATES
	int fd;
#endif
        int i,j,n,t; 
	int argmax = -1;	/* stays -1 if every dw[i] is 0 */
	int num_moves;
	int rounds = 0;
	/* static: far too large for the stack (grad alone holds
	   TD_MAX_STORED_MOVES*__TOTAL_COEFFS__ entries); all three are
	   cleared before use on every call */
	static etype grad[TD_MAX_STORED_MOVES*__TOTAL_COEFFS__];
	static double dw[__TOTAL_COEFFS__];
	static double olddw[__TOTAL_COEFFS__];
	double c, max;
	double tanhv[MAX_GAME_MOVES];
	double d[MAX_GAME_MOVES];
	double oldnorm, newnorm, dotprod, angle, cosine;
	FILE *f;

	if (state->analysed) 
		return 0; 

	if ((f = fopen("rounds.dat", "r")) != NULL) {
		if (fscanf(f, "%d", &rounds) != 1)
			rounds = 0;
		fclose(f);
	}

	memset(dw, 0, sizeof(dw));
	memset(olddw, 0, sizeof(olddw));

#if DUMPING_TD_UPDATES
	fd = open("update.dat", O_RDONLY);
	if (fd != -1) {
		if (read(fd, olddw, sizeof(olddw)) != sizeof(olddw)) {
			lprintf(0, "update file corrupt\n");
		} else {
			memcpy(dw, olddw, sizeof(olddw));
		}
		close(fd);
	}
#endif

	if (state->stored_move_num == 0 || state->stored_move_num > TD_MAX_STORED_MOVES) {
		lprintf(0, "no gradient information: %d\n", state->stored_move_num); 
		return 0;
	}

	if (state->buffered_moves > 0) {
		if (Write(state->gradient_fd, state->grad_buffer, 
			  GRAD_TEMP_SIZE*sizeof(state->grad_buffer[0])) !=
		    GRAD_TEMP_SIZE*sizeof(state->grad_buffer[0])) {
			lprintf(0, "***Error writing gradient buffer\n");
		}
		state->buffered_moves = 0;
		memset(state->grad_buffer, 0, GRAD_TEMP_SIZE*sizeof(state->grad_buffer[0]));
	}

	memset(grad, 0, sizeof(grad));

	/* the forfeited move's gradient is not used */
	if (state->ics_robot && result() == TIME_FORFEIT)
		num_moves = state->stored_move_num-1;
	else 
		num_moves = state->stored_move_num;
		
	lprintf(0,"***moves: %d\n", num_moves);

	/* tanhv[] and d[] hold one entry per move and d[num_moves-1] is
	   written below, so num_moves must lie in 1..MAX_GAME_MOVES */
	if (num_moves < 1 || num_moves > MAX_GAME_MOVES) {
		lprintf(0, "no usable moves for TD update: %d\n", num_moves);
		return 0;
	}

	n = __TOTAL_COEFFS__;

	lseek(state->gradient_fd, 0, SEEK_SET);
	if (Read(state->gradient_fd, grad, num_moves*n*sizeof(grad[0])) != 
	    num_moves*n*sizeof(grad[0])) {
		lprintf(0, "error reading gradient information\n");
		return 0;
	}

	lprintf(0,"gradients read\n");
	
	/* Squash the evals and compute the temporal differences */
	tanhv[0] =  tanh(EVAL_SCALE*state->leaf_eval[0].v);
	for (t=0; t<num_moves-1; t++) {
		tanhv[t+1] = tanh(EVAL_SCALE*state->leaf_eval[t+1].v);
		d[t] = tanhv[t+1] - tanhv[t];
		if (!state->predicted_move[t+1] && 
		    !state->demo_mode && 
		    state->rating_change < 0)
			d[t] = RAMP(d[t]);
	}

	/* work out the outcome; every path must set d[num_moves-1] */
	if (state->demo_mode) {
		switch (state->won) {
		case STALEMATE:
			if (NO_STALEMATE_LEARN) 
				return 0;
			d[num_moves-1] = tanh(EVAL_SCALE*DRAW_VALUE)
				- tanhv[num_moves-1];
			break;
		case 1:
			d[num_moves-1] = 1.0 - tanhv[num_moves-1];
			break;
		case 0:
			d[num_moves-1] = -1.0 - tanhv[num_moves-1];
			break;
		default:
			/* unknown outcome: assume the final eval was correct */
			d[num_moves-1] = 0.0;
			break;
		} 
	} else {
		switch (result()) {
		case STALEMATE:
			if (NO_STALEMATE_LEARN) 
				return 0;
			d[num_moves-1] = tanh(EVAL_SCALE*DRAW_VALUE)
				- tanhv[num_moves-1];
			break;
		case 1:
			d[num_moves-1] = 1.0 - tanhv[num_moves-1];
			break;
		case 0:
			d[num_moves-1] = -1.0 - tanhv[num_moves-1];
			break;
		/* for time forfeited or resigned games we just assume the 
		   final eval was correct */
		case TIME_FORFEIT:
			d[num_moves-1] = 0.0;
			break;
		default:
			/* unknown outcome: treat like a forfeit */
			d[num_moves-1] = 0.0;
			break;
		}
	}

	lprintf(0,"outcome: %d %d %d\n", state->won, state->colour, state->position.winner);

	for (i=0; i<num_moves; i++) {
		lprintf(0, "%d %d %lf\n", i, state->leaf_eval[i].v, d[i]);
	}

	/* calculate the coefficient updates */
	max = 0.0;
	j=0;
	for (i=0; i<n; i++) {
		/* "FACTORS" are multiplicative and have disproportionally
		   high derivatives so we don't adjust them */
		if (dont_change && i==dont_change[j]) {
			++j;
			continue;
		}	
		/* c is the eligibility trace: lambda-discounted sum of the
		   squashed-eval gradients up to move t */
		c = (1.0 - tanhv[0]*tanhv[0])*EVAL_SCALE*grad[i];

		for (t=0; t<num_moves; t++) {
			dw[i] += d[t]*c;
			if (t<num_moves-1) {
				c = TD_LAMBDA*c + (1-tanhv[t+1]*tanhv[t+1])*
					EVAL_SCALE*grad[(t+1)*n+i];
			}
		}
		if (ABS(dw[i]) > max) {
			max = ABS(dw[i]);
			argmax = i;
		}
	}

	lprintf(0,"max: %lf %d\n", TD_ALPHA*max, argmax);
			
#if CONJUNCTION_CALC
	/* calculate the conjunction updates: a TD update for every pair of
	   coefficients (j,k) of the same stage (first three stages), treating
	   the pair's gradient as 1 at moves where both coefficients had a
	   nonzero gradient and 0 otherwise. Results are accumulated in
	   conj_update.dat, 100*__COEFFS_PER_STAGE__ entries at a time. */
	{
		struct stat st;
		static double conjdw[100*__COEFFS_PER_STAGE__];
		const int chunk = (int)sizeof(conjdw);
		off_t chunk_start = 0;	/* file offset of the chunk held in conjdw */
		int i,j,k;
		int cfd, index, i1, i2;
		struct max_struct max1[MAX_SIZE];

		/* O_CREAT requires the mode argument */
		cfd = open("/usr/local/chess/conj_update.dat", O_RDWR | O_CREAT, 0666);
		if (cfd == -1 || fstat(cfd, &st) == -1) {
			lprintf(0, "can't open conjugate update file\n");
			if (cfd != -1)
				close(cfd);
		} else {
			memset(conjdw, 0, sizeof(conjdw));
			if (st.st_size > 0) {
				if (read(cfd, conjdw, sizeof(conjdw)) != sizeof(conjdw)) {
					lprintf(0, "conjugate update file corrupt\n");
				}
			}
			/* max1[] is kept sorted ascending, so max1[0] is the smallest
			   of the MAX_SIZE largest updates seen so far */
			memset(max1, 0, sizeof(max1));
			index = 0;
			for (i=0; i<3; i++) {
				for (j=0; j<__COEFFS_PER_STAGE__; j++) {
					i1 = i*__COEFFS_PER_STAGE__+j;
					for (k=j+1; k<__COEFFS_PER_STAGE__; k++) {
						i2 = i*__COEFFS_PER_STAGE__+k;
						if (index == 100*__COEFFS_PER_STAGE__) {
							/* chunk full: write it back where it was
							   read from, then load the next one */
							index = 0;
							lseek(cfd, chunk_start, SEEK_SET);
							if (Write(cfd, conjdw, chunk) != chunk) {
								lprintf(0,"failed to write conjugate updates\n");
							}
							chunk_start += chunk;
							memset(conjdw, 0, sizeof(conjdw));
							if (st.st_size > 0) {
								read(cfd, conjdw, sizeof(conjdw));
							}
						}

						c=0;
						if (grad[i1] != 0 && grad[i2] != 0) {
							c = (1.0 - tanhv[0]*tanhv[0])*EVAL_SCALE;
						}
						for (t=0; t<num_moves-1; t++) {
							conjdw[index] += d[t]*c;
							c = TD_LAMBDA*c;
							if (grad[(t+1)*__TOTAL_COEFFS__+i1] != 0 && 
							    grad[(t+1)*__TOTAL_COEFFS__+i2] != 0) {
								c += (1-tanhv[t+1]*tanhv[t+1])*EVAL_SCALE;
							}
						}
						conjdw[index] += d[num_moves-1]*c;
						if (i == 0 && ABS(conjdw[index]) > max1[0].val) {
							max1[0].val = ABS(conjdw[index]);
							max1[0].i = i;
							max1[0].j = j;
							max1[0].k = k;
							qsort(max1, MAX_SIZE, sizeof(max1[0]), max_compare);
						}
						++index;
					}
				}
			}
			/* the last, partly filled chunk must be written back too */
			if (index > 0) {
				lseek(cfd, chunk_start, SEEK_SET);
				if (Write(cfd, conjdw, chunk) != chunk) {
					lprintf(0,"failed to write conjugate updates\n");
				}
			}
			
			close(cfd);
			for (i=0; i<MAX_SIZE; i++)
				lprintf(0,"%d %lf %d %d %d\n", 
					i, TD_ALPHA*max1[i].val, max1[i].i, max1[i].j, max1[i].k);
		}
	}
#endif
	oldnorm = 0.0;
	newnorm = 0.0;
	dotprod = 0.0;
	for (i=0; i<n; i++) {
		oldnorm += ((double)new_coefficients[i]*(double)new_coefficients[i]);
		newnorm += (new_coefficients[i]+TD_ALPHA*dw[i])*(new_coefficients[i]+TD_ALPHA*dw[i]);
		dotprod += (new_coefficients[i] + TD_ALPHA*dw[i])*new_coefficients[i];
	}
	angle = 0.0;
	if (oldnorm > 0 && newnorm > 0) {
		cosine = dotprod/sqrt(oldnorm*newnorm);
		/* rounding can push |cosine| just past 1, where acos() is NaN */
		if (cosine > 1.0)
			cosine = 1.0;
		else if (cosine < -1.0)
			cosine = -1.0;
		angle = 180*acos(cosine)/PI;
	}
	lprintf(0, "change in angle: %lg\n", angle);
	if ((f = fopen("angle.dat", "a")) != NULL) {
		fprintf(f, "%g\n", angle);
		fclose(f);
	} else {
		perror("angle.dat");
	}

	/* only apply the accumulated update once every MAX_ROUNDS games */
	if (rounds == MAX_ROUNDS) {
		j = 0;
		for (i=0; i<n; i++) {
			if (dont_change && i==dont_change[j]) {
				++j;
				continue;
			}
			new_coefficients[i] += TD_ALPHA*dw[i];
		}
	}

#if DUMPING_TD_UPDATES
	if (rounds == MAX_ROUNDS) {
		/* the update was just applied: start accumulating afresh */
		memset(dw, 0, sizeof(dw));
		rounds = 0;
	}
	fd = open("update.dat", O_WRONLY | O_CREAT | O_TRUNC, 0666);
	if (fd == -1 || Write(fd, dw, sizeof(dw)) != sizeof(dw)) {
		lprintf(0,"failed to write updates\n");
	}
	if (fd != -1)
		close(fd);

	++rounds;
	if ((f = fopen("rounds.dat", "w")) != NULL) {
		fprintf(f, "%d\n", rounds);
		fclose(f);
	} else {
		perror("rounds.dat");
	}
#endif
	lprintf(0,"updated coefficients\n");


	state->analysed = 1;
	return 0;
}

#else
/* Does nothing. Presumably present so this file still defines an external
   symbol (ISO C disallows an empty translation unit) when LEARN_EVAL is
   off. */
void td_dummy(void)
{
	return;
}
#endif


