/*
 * Copyright 1995,96 Thierry Bousch
 * Licensed under the Gnu Public License, Version 2
 *
 * $Id: Mono.c,v 2.4 1996/08/18 09:25:26 bousch Exp $
 *
 * Operations on Monomials
 */

#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "saml.h"
#include "saml-errno.h"
#include "mnode.h"
#include "builtin.h"

/*
 * One factor lit^exp of a monomial.  Literals are compared by pointer
 * address throughout this file (mono_compare, mono_mul, mono_div, ...),
 * so equal literals are expected to be the same node.
 * NOTE(review): presumably literal nodes are shared/interned — confirm.
 */
typedef struct {
	s_mnode* lit;	/* literal node; the monomial owns one reference (see mono_free) */
	int exp;	/* exponent; zero exponents are dropped by mono_div/divide_power_umono */
} lit_exp;

/*
 * A monomial  coeff * lit_1^exp_1 * ... * lit_n^exp_n.
 *
 * `factors` is the number n of entries in le[], with one exception:
 * -1 marks the zero monomial, which has no entries.  factors == 0 is
 * a (nonzero) constant.  le[] is sorted by increasing literal address,
 * and each literal appears at most once; mono_mul and mono_div rely on
 * this order when merging.  Nodes are allocated by mono_new() with
 * room for exactly the number of entries they hold.
 */
typedef struct {
	struct mnode_header hdr;
	s_mnode* coeff;
	int factors;
	lit_exp le[];	/* C99 flexible array member */
} monomial;

/* View a generic node as a monomial (unchecked cast). */
#define MONO(n)  ((monomial*)(n))

/*
 * Forward declarations.  The static routines fill the MathType_Monomial
 * method table or are registered as conversion routines by
 * init_MathType_Monomial().  mono_compare, mono_add_sim,
 * decompose_powers_umono and mono_unpack have external linkage
 * (presumably used by other SAML modules — not visible here).
 * mono_make is declared static here; its later definition omits
 * `static` but keeps internal linkage from this declaration.
 */
static void mono_free (s_mnode*);
static s_mnode* mono_build (const char*);
static gr_string* mono_stringify (s_mnode*);
static s_mnode* mono_make (s_mnode*);
static s_mnode* mono_add (s_mnode*, s_mnode*);
static s_mnode* mono_sub (s_mnode*, s_mnode*);
static s_mnode* mono_mul (s_mnode*, s_mnode*);
static s_mnode* mono_div (s_mnode*, s_mnode*);
static s_mnode* mono_gcd (s_mnode*, s_mnode*);
static int mono_notzero (s_mnode*);
int mono_compare (s_mnode*, s_mnode*);
static s_mnode* mono_zero (s_mnode*);
static s_mnode* mono_negate (s_mnode*);
static s_mnode* mono_one (s_mnode*);
static s_mnode* literal2mono (s_mnode*, s_mnode*);
static s_mnode* mono2mono (s_mnode*, s_mnode*);
s_mnode* mono_add_sim (s_mnode*, s_mnode*);
s_mnode* decompose_powers_umono (std_mnode*, s_mnode*);
s_mnode* mono_unpack (s_mnode*);

/*
 * Method table of the Monomial type, registered under ST_MONO by
 * init_MathType_Monomial().  Entries are positional; their meaning is
 * fixed by the s_mtype declaration (not visible here).  NULL entries
 * are operations this type does not implement (presumably handled or
 * rejected by the generic dispatcher — confirm).
 */
static s_mtype MathType_Monomial = {
	"Monomial",
	mono_free, mono_build, mono_stringify,
	mono_make, NULL,
	mono_add, mono_sub, mono_mul, mono_div, mono_gcd,
	mono_notzero, NULL, NULL, NULL, NULL,
	mono_zero, mono_negate, mono_one, NULL, NULL
};

/*
 * Register the Monomial type and its two conversion routines:
 * literal -> monomial (literal2mono) and monomial -> monomial with a
 * different coefficient type (mono2mono).
 */
void init_MathType_Monomial (void)
{
	register_mtype(ST_MONO, &MathType_Monomial);
	register_CV_routine(ST_LITERAL, ST_MONO, literal2mono);
	register_CV_routine(ST_MONO, ST_MONO, mono2mono);
}

static inline s_mnode* mono_new (int var)
{
	/* Fixed header followed by `var` trailing lit_exp slots. */
	size_t bytes = sizeof(monomial) + var * sizeof(lit_exp);

	return __mnalloc(ST_MONO, bytes);
}

static void mono_free (s_mnode *mn)
{
	/* Drop the references held on the coefficient and every literal,
	 * then release the node itself. */
	monomial *body = MONO(mn);
	lit_exp *slot;
	int left;

	unlink_mnode(body->coeff);
	for (slot = body->le, left = body->factors; left > 0; left--, slot++)
		unlink_mnode(slot->lit);
	free(mn);
}

static s_mnode* mono_build (const char* str)
{
	/* Only a bare literal can be parsed so far: builds 1 * str^1
	 * with an integer coefficient. */
	s_mnode *res = mono_new(1);
	monomial *body = MONO(res);

	body->coeff = mnode_build(ST_INTEGER, "1");
	body->factors = 1;
	body->le[0].lit = mnode_build(ST_LITERAL, str);
	body->le[0].exp = 1;
	return res;
}

static s_mnode* mono_zero (s_mnode* model)
{
	/* The zero monomial (factors == -1), whose coefficient is the zero
	 * built from model's coefficient. */
	s_mnode *z = mono_new(0);

	MONO(z)->coeff = mnode_zero(MONO(model)->coeff);
	MONO(z)->factors = -1;
	return z;
}

static s_mnode* mono_one (s_mnode* model)
{
	/* The constant monomial 1 (no literal factors), whose coefficient
	 * is the unit built from model's coefficient. */
	s_mnode *u = mono_new(0);

	MONO(u)->coeff = mnode_one(MONO(model)->coeff);
	MONO(u)->factors = 0;
	return u;
}

s_mnode* mono_make (s_mnode* constant)
{
	/* Wrap a coefficient as a literal-free monomial; a zero constant
	 * yields the zero monomial (factors == -1). */
	s_mnode *res = mono_new(0);
	monomial *body = MONO(res);

	body->coeff = copy_mnode(constant);
	if (mnode_notzero(constant))
		body->factors = 0;
	else
		body->factors = -1;
	return res;
}

static s_mnode* mono_negate (s_mnode* mn1)
{
	/* Returns -mn1 as a new node.  The zero monomial is returned as a
	 * new reference to itself; otherwise the literal part is duplicated
	 * (taking a new reference on each literal) and the coefficient
	 * negated. */
	monomial *src = MONO(mn1), *dst;
	lit_exp *from, *to;
	s_mnode *res;
	int n = src->factors, left;

	if (n < 0)
		return copy_mnode(mn1);
	res = mono_new(n);
	dst = MONO(res);
	dst->coeff = mnode_negate(src->coeff);
	dst->factors = n;
	for (from = src->le, to = dst->le, left = n; left > 0;
	     left--, from++, to++) {
		to->lit = copy_mnode(from->lit);
		to->exp = from->exp;
	}
	return res;
}
	
static gr_string* mono_stringify (s_mnode *mn)
{
	/*
	 * Render as  coeff*lit1^e1*lit2*...  (exponents of 1 are omitted).
	 * A coefficient printed as "1", "+1" or "-1" loses its '1' and the
	 * following '*', giving e.g. "x^2*y" or "-x".  The zero monomial
	 * prints as its coefficient alone.
	 */
	monomial *m = MONO(mn);
	gr_string *out, *lit_str;
	char expbuf[24];
	int k, n, drop_one;

	out = mnode_stringify(m->coeff);
	if (m->factors < 0)
		return out;
	drop_one = (out->len == 1 && out->s[0] == '1')
	        || (out->len == 2 && out->s[1] == '1'
	            && (out->s[0] == '+' || out->s[0] == '-'));
	for (k = 0; k < m->factors; k++) {
		if (drop_one) {
			out->len--;	/* erase the unit '1'; no '*' needed */
			drop_one = 0;
		} else
			out = grs_append1(out, '*');
		lit_str = mnode_stringify(m->le[k].lit);
		out = grs_append(out, lit_str->s, lit_str->len);
		free(lit_str);
		if (m->le[k].exp > 1) {
			n = snprintf(expbuf, sizeof expbuf, "^%d", m->le[k].exp);
			out = grs_append(out, expbuf, n);
		}
	}
	return out;
}

static int mono_notzero (s_mnode* mn)
{
	/* Only the zero monomial carries the factors == -1 marker. */
	return MONO(mn)->factors >= 0;
}

int mono_compare (s_mnode* mn1, s_mnode* mn2)
{
	/*
	 * Lexicographic monomial order, compatible with multiplication.
	 * Literals are ranked by address: at the first position where the
	 * literals differ, the monomial holding the lower-addressed literal
	 * is the larger.  For the same literal, the larger exponent wins.
	 * If one factor list is a prefix of the other, the longer one is
	 * larger.  The zero monomial (factors == -1) compares below every
	 * nonzero monomial.  Returns <0, 0 or >0.
	 */
	const monomial *a = MONO(mn1), *b = MONO(mn2);
	const lit_exp *pa, *pb, *stop;
	int fa = a->factors, fb = b->factors;

	if (fa >= 0 && fb >= 0) {
		stop = a->le + (fa < fb ? fa : fb);
		for (pa = a->le, pb = b->le; pa < stop; pa++, pb++) {
			if (pa->lit != pb->lit)
				return (pa->lit < pb->lit) ? 1 : -1;
			if (pa->exp != pb->exp)
				return pa->exp - pb->exp;
		}
	}
	return fa - fb;
}

/*
 * Sum of two "similar" monomials.
 *
 * Two monomials are similar when they have identical literal/exponent
 * lists, or when either one is zero.  Returns a new reference to the
 * sum, or NULL when the sum is not a single monomial.  If the
 * coefficients cancel, the zero monomial is returned.
 */
s_mnode* mono_add_sim (s_mnode* mn1, s_mnode* mn2)
{
	monomial *m1 = MONO(mn1), *m2 = MONO(mn2), *m;
	s_mnode *mn, *coeff;
	int i, f;

	if (m1->factors < 0)
		return copy_mnode(mn2);
	if (m2->factors < 0)
		return copy_mnode(mn1);
	if ((f = m1->factors) != m2->factors)
		return NULL;
	/*
	 * Compare member by member.  A memcmp() over the lit_exp array
	 * would also compare structure padding (e.g. after `exp` on LP64).
	 * The routines that build monomials never write that padding, so
	 * its contents are indeterminate, and equal monomials could be
	 * reported as different.
	 */
	for (i = 0; i < f; i++) {
		if (m1->le[i].lit != m2->le[i].lit ||
		    m1->le[i].exp != m2->le[i].exp)
			return NULL;
	}
	coeff = mnode_add(m1->coeff, m2->coeff);
	if (!mnode_notzero(coeff)) {
		unlink_mnode(coeff);
		return mono_zero(mn1);
	}
	mn = mono_new(f);
	m = MONO(mn);
	m->coeff = coeff;
	m->factors = f;
	for (i = 0; i < f; i++) {
		m->le[i].lit = copy_mnode(m1->le[i].lit);
		m->le[i].exp = m1->le[i].exp;
	}
	return mn;
}

/*
 * Difference mn1 - mn2 of two "similar" monomials.
 *
 * Two monomials are similar when they have identical literal/exponent
 * lists, or when either one is zero.  Returns a new reference to the
 * difference, or NULL when it is not a single monomial.  If the
 * coefficients cancel, the zero monomial is returned.
 */
static s_mnode* mono_sub_sim (s_mnode* mn1, s_mnode* mn2)
{
	monomial *m1 = MONO(mn1), *m2 = MONO(mn2), *m;
	s_mnode *mn, *coeff;
	int i, f;

	if (m1->factors < 0)
		return mono_negate(mn2);
	if (m2->factors < 0)
		return copy_mnode(mn1);
	if ((f = m1->factors) != m2->factors)
		return NULL;
	/*
	 * Compare member by member; memcmp() would also compare the
	 * indeterminate padding bytes inside each lit_exp entry.
	 */
	for (i = 0; i < f; i++) {
		if (m1->le[i].lit != m2->le[i].lit ||
		    m1->le[i].exp != m2->le[i].exp)
			return NULL;
	}
	coeff = mnode_sub(m1->coeff, m2->coeff);
	if (!mnode_notzero(coeff)) {
		unlink_mnode(coeff);
		return mono_zero(mn1);
	}
	mn = mono_new(f);
	m = MONO(mn);
	m->coeff = coeff;
	m->factors = f;
	for (i = 0; i < f; i++) {
		m->le[i].lit = copy_mnode(m1->le[i].lit);
		m->le[i].exp = m1->le[i].exp;
	}
	return mn;
}

static s_mnode* mono_add (s_mnode* mn1, s_mnode* mn2)
{
	/* Sum of two monomials; adding dissimilar ones is an SE_HETERO error. */
	s_mnode *sum = mono_add_sim(mn1, mn2);

	if (sum == NULL)
		sum = mnode_error(SE_HETERO, "mono_add");
	return sum;
}

static s_mnode* mono_sub (s_mnode* mn1, s_mnode* mn2)
{
	/* Difference of two monomials; dissimilar operands are an SE_HETERO error. */
	s_mnode *diff = mono_sub_sim(mn1, mn2);

	if (diff == NULL)
		diff = mnode_error(SE_HETERO, "mono_sub");
	return diff;
}

static s_mnode* mono_mul (s_mnode* mn1, s_mnode* mn2)
{
	/*
	 * Product of two monomials.  The sorted factor lists are merged,
	 * and exponents of a literal present in both are added.  The zero
	 * monomial is returned if either operand is zero or if the
	 * coefficient product is zero.
	 */
	monomial *a = MONO(mn1), *b = MONO(mn2), *res;
	lit_exp *pa, *pa_end, *pb, *pb_end, *buf, *out;
	s_mnode *result, *coeff;
	int n, k;

	if (a->factors < 0 || b->factors < 0)
		return mono_zero(mn1);
	coeff = mnode_mul(a->coeff, b->coeff);
	if (!mnode_notzero(coeff)) {
		/* Nonzero coefficients with a zero product: zero divisors */
		unlink_mnode(coeff);
		return mono_zero(mn1);
	}
	buf = alloca((a->factors + b->factors) * sizeof(lit_exp));
	pa = a->le; pa_end = pa + a->factors;
	pb = b->le; pb_end = pb + b->factors;
	out = buf;
	while (pa < pa_end && pb < pb_end) {
		if (pa->lit < pb->lit) {
			*out++ = *pa++;
		} else if (pb->lit < pa->lit) {
			*out++ = *pb++;
		} else {
			out->lit = pa->lit;
			out->exp = pa->exp + pb->exp;
			out++, pa++, pb++;
		}
	}
	/* At most one of the two tails is non-empty */
	while (pa < pa_end)
		*out++ = *pa++;
	while (pb < pb_end)
		*out++ = *pb++;

	n = out - buf;
	assert(n <= a->factors + b->factors);
	result = mono_new(n);
	res = MONO(result);
	res->coeff = coeff;
	res->factors = n;
	for (k = 0; k < n; k++) {
		res->le[k].lit = copy_mnode(buf[k].lit);
		res->le[k].exp = buf[k].exp;
	}
	return result;
}

static s_mnode* mono_div (s_mnode* mn1, s_mnode* mn2)
{
	/*
	 * Monomial quotient mn1 / mn2.  Dividing by the zero monomial is an
	 * SE_DIVZERO error, and 0 / x is 0.  The zero monomial is returned
	 * when the coefficient quotient is zero, or when some literal of
	 * mn2 is missing from mn1 or has a larger exponent there.  Literals
	 * whose exponent drops to 0 are removed from the result.
	 */
	monomial *num = MONO(mn1), *den = MONO(mn2), *res;
	lit_exp *p, *p_end, *q, *q_end, *buf, *out;
	s_mnode *result, *coeff;
	int count, k, e;

	if (den->factors < 0)
		return mnode_error(SE_DIVZERO, "mono_div");
	if (num->factors < 0)
		return copy_mnode(mn1);
	coeff = mnode_div(num->coeff, den->coeff);
	if (!mnode_notzero(coeff))
		goto zero_result;

	/* The quotient never has more factors than the numerator */
	buf = alloca((num->factors)*sizeof(lit_exp));
	p = num->le; p_end = p + num->factors;
	q = den->le; q_end = q + den->factors;
	out = buf;
	for (; q < q_end; q++) {
		/* Numerator literals ranked below q->lit pass through unchanged */
		while (p < p_end && p->lit < q->lit)
			*out++ = *p++;
		if (p == p_end || p->lit != q->lit)
			goto zero_result;	/* q->lit does not occur in mn1 */
		e = p->exp - q->exp;
		if (e < 0)
			goto zero_result;
		if (e > 0) {
			out->lit = p->lit;
			out->exp = e;
			out++;
		}
		p++;
	}
	while (p < p_end)
		*out++ = *p++;

	count = out - buf;
	assert(count <= num->factors);
	result = mono_new(count);
	res = MONO(result);
	res->coeff = coeff;
	res->factors = count;
	for (k = 0; k < count; k++) {
		res->le[k].lit = copy_mnode(buf[k].lit);
		res->le[k].exp = buf[k].exp;
	}
	return result;

zero_result:
	unlink_mnode(coeff);
	return mono_zero(mn1);
}

/*
 * Returns the biggest e such that N^e divides P. We assume that N and P
 * are monic.
 */

static inline int biggest_power (s_mnode* N, s_mnode* P)
{
	/* Only literal parts are inspected; coefficients are ignored.
	 * Returns INT_MAX when N has no literal factors. */
	const monomial *mn = MONO(N), *mp = MONO(P);
	const lit_exp *n = mn->le, *n_end = mn->le + mn->factors;
	const lit_exp *p = mp->le, *p_end = mp->le + mp->factors;
	int best = INT_MAX;

	for (; n < n_end; n++) {
		int q;

		/* Skip the literals of P that rank below N's current one */
		while (p < p_end && p->lit < n->lit)
			p++;
		if (p == p_end || p->lit != n->lit)
			return 0;	/* this literal of N is absent from P */
		q = p->exp / n->exp;
		if (q == 0)
			return 0;	/* no positive power can divide */
		if (q < best)
			best = q;
	}
	return best;
}

/*
 * Returns P / N^e, which is supposed to exist. We assume N is monic.
 */

static s_mnode* divide_power_umono (s_mnode* N, s_mnode* P, int e)
{
	/* Each literal of N must occur in P (asserted).  P's coefficient is
	 * reused unchanged, and literals whose exponent reaches 0 are
	 * dropped.  The result holds fresh references on every literal. */
	monomial *mN = MONO(N), *mP = MONO(P);
	lit_exp *src, *src_end, *div, *div_end, *buf, *dst, *walk;
	s_mnode *result;
	int count, rest;

	if (!e)
		return copy_mnode(P);
	buf = dst = alloca(mP->factors * sizeof(lit_exp));
	src = mP->le; src_end = src + mP->factors;
	div = mN->le; div_end = div + mN->factors;
	for (; div < div_end; ++div, ++src) {
		/* Literals of P not in N are copied through */
		while (src < src_end && src->lit < div->lit)
			*dst++ = *src++;
		assert(div->lit == src->lit);
		rest = src->exp - e * div->exp;
		if (rest != 0) {
			dst->lit = src->lit;
			dst->exp = rest;
			dst++;
		}
	}
	while (src < src_end)
		*dst++ = *src++;
	/* Take a reference on each literal kept in the quotient */
	for (walk = buf; walk != dst; ++walk)
		copy_mnode(walk->lit);
	count = dst - buf;
	result = mono_new(count);
	MONO(result)->coeff = copy_mnode(mP->coeff);
	MONO(result)->factors = count;
	memcpy(MONO(result)->le, buf, count * sizeof(lit_exp));
	return result;
}

/*
 * Decompose polynomial "poly" into powers of a monic monomial M, i.e.,
 *  poly = p[0] + p[1]*M + ... + p[len-1]*M^(len-1)
 * It is guaranteed that len >= 1.
 */

s_mnode* decompose_powers_umono (std_mnode* poly, s_mnode* M)
{
	int terms, maxpower, *nbt, *powers, i, j;
	s_mnode ***ptrs; std_mnode *mn;

	/*
	 * The result is an ST_UPOLY node whose entry x[j] is p[j], an
	 * ST_POLY node.  Trivial case (M has no literal factors or is
	 * zero, or poly has no terms beyond x[0]): a one-entry ST_UPOLY
	 * holding a new reference to poly itself.  The same exit is used
	 * below when no term is divisible by M.
	 */
	if (MONO(M)->factors <= 0 || poly->length == 1) {
dont_decompose:
		mn = mstd_alloc(ST_UPOLY, 1);
		mn->x[0] = copy_mnode((mn_ptr)poly);
		return (mn_ptr)mn;
	}
	/* poly->x[0] is copied into each p[j], not decomposed; the terms
	 * processed are x[1] .. x[length-1]. */
	terms = poly->length - 1;
	powers = alloca(terms * sizeof(int));
	/* First step, determine the powers */
	/* powers[i]: largest e such that the literal part of M^e divides
	 * term i+1 (see biggest_power) */
	maxpower = 0;
	for (i = 0; i < terms; i++) {
		powers[i] = biggest_power(M, poly->x[i+1]);
		if (powers[i] > maxpower)
			maxpower = powers[i];
	}
	/* Guards the maxpower+1 size computations below */
	assert(maxpower < INT_MAX);
	if (maxpower == 0)
		goto dont_decompose;
	/* NOTE(review): maxpower+1 slots are allocated here and on the
	 * stack below; a very large exponent could exhaust the stack. */
	mn = mstd_alloc(ST_UPOLY, maxpower+1);
	/* Second, count the number of terms in each p[j] */
	nbt = alloca((maxpower+1) * sizeof(int));
	memset(nbt, 0, (maxpower+1) * sizeof(int));
	for (i = 0; i < terms; i++)
		++nbt[powers[i]];
	/* Allocate space and pointers for each p[j] */
	/* Slot 0 of each p[j] receives a copy of poly->x[0]; ptrs[j] is
	 * the fill cursor for p[j]'s terms, starting at slot 1. */
	ptrs = alloca((maxpower+1) * sizeof(s_mnode**));
	for (j = 0; j <= maxpower; j++) {
		mn->x[j] = (mn_ptr) mstd_alloc(ST_POLY, nbt[j]+1);
		((smn_ptr)(mn->x[j]))->x[0] = copy_mnode(poly->x[0]);
		ptrs[j] = &((smn_ptr)(mn->x[j]))->x[1];
	}
	/* And reap all terms after division */
	/* Term i+1 goes to p[powers[i]] after division by M^powers[i];
	 * within each p[j], terms keep their relative order from poly. */
	for (i = 0; i < terms; i++) {
		j = powers[i];
		*(ptrs[j])++ = divide_power_umono(M, poly->x[i+1], j);
	}
	return (mn_ptr)mn;
}

static s_mnode* mono2mono (s_mnode* mn1, s_mnode* model)
{
	/*
	 * Convert mn1 to the coefficient type of `model` by promoting its
	 * coefficient; the literal part is duplicated.  Without a model,
	 * returns a copy of mn1.  An ST_VOID promotion result is returned
	 * unchanged, and a coefficient that becomes zero yields the zero
	 * monomial.
	 */
	monomial *src = MONO(mn1), *dst;
	s_mnode *res, *c;
	int k, n;

	if (model == NULL)
		return copy_mnode(mn1);
	c = mnode_promote(src->coeff, MONO(model)->coeff);
	if (c->type == ST_VOID)
		return c;
	if (!mnode_notzero(c)) {
		res = mono_new(0);
		MONO(res)->coeff = c;
		MONO(res)->factors = -1;
		return res;
	}
	n = src->factors;
	res = mono_new(n);
	dst = MONO(res);
	dst->factors = n;
	dst->coeff = c;
	for (k = 0; k < n; k++) {
		dst->le[k].lit = copy_mnode(src->le[k].lit);
		dst->le[k].exp = src->le[k].exp;
	}
	return res;
}

static s_mnode* literal2mono (s_mnode* mn1, s_mnode* model)
{
	/* Literal -> monomial 1 * mn1^1.  The unit coefficient is built from
	 * model's coefficient, so a model is mandatory (SE_ICAST otherwise). */
	s_mnode *res;
	monomial *body;

	if (model == NULL)
		return mnode_error(SE_ICAST, "literal2mono");
	res = mono_new(1);
	body = MONO(res);
	body->coeff = mnode_one(MONO(model)->coeff);
	body->factors = 1;
	body->le[0].lit = copy_mnode(mn1);
	body->le[0].exp = 1;
	return res;
}

static s_mnode* mono_gcd (s_mnode *mn1, s_mnode *mn2)
{
	/* Not implemented: always reports SE_NOTRDY. */
	(void)mn1;
	(void)mn2;
	return mnode_error(SE_NOTRDY, "mono_gcd");
}

s_mnode* mono_unpack (s_mnode *mn1)
{
	/*
	 * Flatten a monomial into an ST_LINE node:
	 *   [coeff, lit_1, exp_1, lit_2, exp_2, ...]
	 * The exponents are built with mint_ibuild.  The zero monomial
	 * unpacks to [coeff] alone.
	 */
	extern s_mnode* mint_ibuild (int);
	monomial *m = MONO(mn1);
	int n = (m->factors < 0) ? 0 : m->factors;
	std_mnode *list = mstd_alloc(ST_LINE, 2 * n + 1);
	s_mnode **slot = list->x;
	lit_exp *p;

	*slot++ = copy_mnode(m->coeff);
	for (p = m->le; p < m->le + n; p++) {
		*slot++ = copy_mnode(p->lit);
		*slot++ = mint_ibuild(p->exp);
	}
	return (mn_ptr)list;
}
