/*
 * Copyright 1995,96,97 Thierry Bousch
 * Licensed under the Gnu Public License, Version 2
 *
 * $Id: Ratio.c,v 2.8 1997/04/15 22:45:08 bousch Exp $
 *
 * The set of all (a/b) objects, where a and b belong to some ring R
 * without divisors of zero, and b != 0, is a field. We do not suppose
 * that R is euclidean or factorial; if it's not the case, the fractions
 * won't be completely simplified, but that's all.
 */

#include <stdlib.h>
#include <string.h>
#include "saml.h"
#include "saml-errno.h"
#include "mnode.h"
#include "builtin.h"

/*
 * A fraction num/den over some ring R (see the header comment).
 * A NULL den stands for a denominator of 1: the routines below test for
 * NULL rather than storing an explicit one.
 */
typedef struct {
	struct mnode_header hdr;	/* shared mnode header (presumably type tag + refcount) */
	s_mnode* num;		/* Numerator */
	s_mnode* den;		/* Denominator, or NULL if none */
} ratio_mnode;

/* Methods installed in the MathType_Rational table */
static void ratio_free (ratio_mnode *R);
static s_mnode* ratio_build (const char *str);
static gr_string* ratio_stringify (ratio_mnode *r);
static s_mnode* ratio_make (s_mnode *numerator);
static s_mnode* ratio_add (ratio_mnode *r1, ratio_mnode *r2);
static s_mnode* ratio_mul (ratio_mnode *r1, ratio_mnode *r2);
static int ratio_notzero (ratio_mnode *R);
static int ratio_isneg (ratio_mnode *R);
static int ratio_differ (ratio_mnode *R1, ratio_mnode *R2);
static int ratio_lessthan (ratio_mnode *R1, ratio_mnode *R2);
static s_mnode* ratio_zero (ratio_mnode *model);
static s_mnode* ratio_negate (ratio_mnode *r);
static s_mnode* ratio_one (ratio_mnode *model);
static s_mnode* ratio_invert (ratio_mnode *r);

/* Conversion routines registered by init_MathType_Rational() */
static s_mnode* ratio2integer (ratio_mnode *R, s_mnode *model);
static s_mnode* ratio2float (ratio_mnode *R, s_mnode *model);
static s_mnode* ratio2ratio (ratio_mnode *R, ratio_mnode *model);

/*
 * Method table for the "Rational" type; init_MathType_Rational() registers
 * it under ST_RATIONAL. Entries are positional, in the order the project's
 * unsafe_s_mtype structure declares them. Judging from the function names,
 * the groups are:
 *   - destructor, parser, printer;
 *   - make;
 *   - add / sub / mul / div / gcd;
 *   - predicates and comparisons;
 *   - the zero / negate / one / invert constructors.
 * NOTE(review): the NULL entries are slots this type leaves unset; confirm
 * their meaning against the unsafe_s_mtype declaration.
 * Subtraction, division and gcd use generic helpers (mn_std_sub,
 * mn_std_div, mn_field_gcd). These are presumably built on add/negate,
 * mul/invert and field structure respectively (verify).
 */
static unsafe_s_mtype MathType_Rational = {
	"Rational",
	ratio_free, ratio_build, ratio_stringify,
	ratio_make, NULL,
	ratio_add, mn_std_sub, ratio_mul, mn_std_div, mn_field_gcd,
	ratio_notzero, ratio_isneg, NULL, ratio_differ, ratio_lessthan,
	ratio_zero, ratio_negate, ratio_one, ratio_invert, NULL
};

/*
 * Register the Rational type and its conversion routines.
 * The type itself is registered first, then the conversions from
 * ST_RATIONAL to Integer, Float and Rational (the second argument of
 * register_CV_routine is the target type, matching each routine's result).
 * NOTE(review): presumably called once from the library's global init.
 */
void init_MathType_Rational (void)
{
	register_mtype(ST_RATIONAL, &MathType_Rational);
	register_CV_routine(ST_RATIONAL, ST_INTEGER, ratio2integer);
	register_CV_routine(ST_RATIONAL, ST_FLOAT, ratio2float);
	register_CV_routine(ST_RATIONAL, ST_RATIONAL, ratio2ratio);
}

/*
 * Allocate a fresh ST_RATIONAL node through the mnode allocator.
 * The num/den fields are filled in by every caller.
 */
static inline ratio_mnode* ratio_new (void)
{
	ratio_mnode *node;

	node = (ratio_mnode*) __mnalloc(ST_RATIONAL, sizeof *node);
	return node;
}

/*
 * Destructor for rational nodes: drop the references held on the
 * numerator and (if present) the denominator, then release the node.
 */
static void ratio_free (ratio_mnode *R)
{
	s_mnode *den = R->den;

	unlink_mnode(R->num);
	if (den != NULL)
		unlink_mnode(den);	/* NULL means "denominator 1": nothing held */
	free(R);
}

/*
 * Build the fraction n/d, reduced as far as mnode_gcd allows.
 *
 * Steps:
 *   - If d is zero, return an SE_DIVZERO error node.
 *   - If d is negative, rebuild the fraction from (-n)/(-d), so that
 *     whenever mnode_isneg can tell, the denominator ends up non-negative.
 *   - Divide n and d by their gcd.
 *   - Drop a denominator equal to that gcd (it becomes 1, stored as NULL).
 * The caller keeps its own references to n and d.
 */
static s_mnode* simplified_ratio (s_mnode *n, s_mnode *d)
{
	ratio_mnode *frac;
	s_mnode *g, *unit, *probe, *flipped;

	if (mnode_notzero(d) == 0)
		return mnode_error(SE_DIVZERO, "simplified_ratio");

	if (mnode_isneg(d) == 1) {
		/* n/d == (-n)/(-d): retry with a positive denominator */
		s_mnode *neg_n = mnode_negate(n);
		s_mnode *neg_d = mnode_negate(d);
		s_mnode *res = simplified_ratio(neg_n, neg_d);

		unlink_mnode(neg_n);
		unlink_mnode(neg_d);
		return res;
	}

	g = mnode_gcd(n, d);
	frac = ratio_new();
	frac->num = mnode_div(n, g);
	frac->den = NULL;	/* stays NULL when d == g, i.e. d/g == 1 */

	if (mnode_differ(d, g)) {
		frac->den = mnode_div(d, g);
		/*
		 * Ideally a denominator would be dropped whenever it is
		 * invertible (checked via mnode_invert). That is too costly
		 * for rationals, and mnode_invert does not yet handle integers
		 * or polynomials. Until then, this quick hack for rational
		 * functions only catches a denominator of -1 (detected as
		 * 1 + den == 0) and folds that sign into the numerator.
		 */
		unit = mnode_one(g);
		probe = mnode_add(unit, frac->den);
		unlink_mnode(unit);
		if (!mnode_notzero(probe)) {
			unlink_mnode(frac->den);
			frac->den = NULL;
			flipped = mnode_negate(frac->num);
			unlink_mnode(frac->num);
			frac->num = flipped;
		}
		unlink_mnode(probe);
	}
	unlink_mnode(g);
	return (mn_ptr)frac;
}
		
/*
 * Parse a rational constant written as "N" or "N/D".
 * N and D are integer literals accepted by mnode_build(ST_INTEGER, ...).
 * The result is reduced with simplified_ratio().
 *
 * Returns an error node (ST_VOID) in these cases:
 *   - str is empty (SE_STRING);
 *   - either part fails to parse (that part's error node);
 *   - D is zero (SE_DIVZERO, from simplified_ratio).
 * No intermediate node is leaked on any error path.
 */
static s_mnode* ratio_build (const char *str)
{
	s_mnode *n1, *n2, *quot;
	const char *p;
	char *s2;
	size_t len;

	if (str[0] == 0)
		return mnode_error(SE_STRING, "ratio_build");

	/* The search starts at str+1, so a '/' in first position is never a separator */
	if ((p = strchr(str+1, '/')) == NULL) {
		n1 = mnode_build(ST_INTEGER, str);
		if (n1->type == ST_VOID)
			return n1;
		n2 = mnode_one(n1);
	} else {
		/* Copy the numerator text into a NUL-terminated scratch buffer */
		len = (size_t)(p - str);
		s2 = alloca(len+1);
		memcpy(s2, str, len);
		s2[len] = 0;
		n1 = mnode_build(ST_INTEGER, s2);
		if (n1->type == ST_VOID)
			return n1;
		n2 = mnode_build(ST_INTEGER, p+1);
		if (n2->type == ST_VOID) {
			/* Release the numerator we already built before bailing out */
			unlink_mnode(n1);
			return n2;
		}
	}
	quot = simplified_ratio(n1, n2);
	unlink_mnode(n1);
	unlink_mnode(n2);
	return quot;
}

/*
 * Render r as "num", or as "num/den" when there is a denominator.
 * Any non-integer part (e.g. a polynomial in a rational function) is
 * wrapped in parentheses. Otherwise, low-precedence operators inside it
 * would be misread once the '/' is added.
 */
static gr_string* ratio_stringify (ratio_mnode* r)
{
	gr_string *out, *below;

	out = mnode_stringify(r->num);
	if (!r->den)
		return out;

	if (r->num->type != ST_INTEGER) {
		out = grs_prepend1(out, '(');
		out = grs_append1(out, ')');
	}
	below = mnode_stringify(r->den);
	if (r->den->type != ST_INTEGER) {
		below = grs_prepend1(below, '(');
		below = grs_append1(below, ')');
	}
	out = grs_append1(out, '/');
	out = grs_append(out, below->s, below->len);
	free(below);	/* its text has been copied into out */
	return out;
}
		
/*
 * Wrap an element of the base ring as the fraction numerator/1.
 * A new reference to numerator is taken; the caller keeps its own.
 */
static s_mnode* ratio_make (s_mnode* numerator)
{
	ratio_mnode *wrapped = ratio_new();

	wrapped->num = copy_mnode(numerator);
	wrapped->den = NULL;	/* implicit denominator 1 */
	return (mn_ptr)wrapped;
}

/*
 * Convert a rational to an Integer. This succeeds only when R is a bare
 * integer: no denominator and an ST_INTEGER numerator. Anything else
 * yields an SE_ICAST error node. The model argument is not used.
 */
static s_mnode* ratio2integer (ratio_mnode* R, s_mnode* model)
{
	(void)model;
	if (R->den == NULL && R->num->type == ST_INTEGER)
		return copy_mnode(R->num);
	return mnode_error(SE_ICAST, "ratio2integer");
}

/*
 * Convert a rational to another rational type. Numerator and denominator
 * are promoted to the type of model->num. Without a model, a new
 * reference to R is returned. If a promotion fails, its error node is
 * returned and any partial result is released.
 */
static s_mnode* ratio2ratio (ratio_mnode* R, ratio_mnode* model)
{
	s_mnode *top, *bottom = NULL;
	ratio_mnode *conv;

	if (model == NULL)
		return copy_mnode((s_mnode*)R);

	top = mnode_promote(R->num, model->num);
	if (top->type == ST_VOID)
		return top;

	if (R->den != NULL) {
		bottom = mnode_promote(R->den, model->num);
		if (bottom->type == ST_VOID) {
			unlink_mnode(top);
			return bottom;
		}
	}

	conv = ratio_new();
	conv->num = top;
	conv->den = bottom;
	return (s_mnode*)conv;
}

/*
 * Convert a fraction with an integer numerator to a Float node.
 *
 * model: optional. When non-NULL, num and den are promoted to its type
 *        with mnode_promote (presumably this fixes the result's precision;
 *        confirm against mnode_promote). When NULL, a Float model is
 *        synthesized from an integer whose absolute value is at least that
 *        of both num and den.
 * Returns one of:
 *   - an SE_ICAST error tagged "ratio2float" if num is not an integer;
 *   - the error node from a failed cast or promotion;
 *   - otherwise the float quotient num/den.
 */
static s_mnode* ratio2float (ratio_mnode* R, s_mnode* model)
{
	s_mnode *num = R->num, *den = R->den, *num1, *den1, *result;

	if (num->type != ST_INTEGER)
		return mnode_error(SE_ICAST, "ratio2float");
	if (den == NULL) {
		/* No denominator; simply promote the numerator */
		if (model)
			return mnode_promote(num, model);
		else
			return mnode_cast(num, ST_FLOAT);
	}
	if (!model) {
		s_mnode *tmp;
		/*
		 * Be careful here: pick a number whose absolute magnitude is
		 * at least that of both num and den, i.e. |num| + |den|.
		 * That is num+den when the signs agree, num-den otherwise.
		 */
		if (mnode_isneg(num) == mnode_isneg(den))
			tmp = mnode_add(num, den);
		else
			tmp = mnode_sub(num, den);

		model = mnode_cast(tmp, ST_FLOAT);
		unlink_mnode(tmp);
		if (model->type == ST_VOID)
			return model;	/* the cast failed: propagate its error */
	} else
		copy_mnode(model);	/* take a reference so both paths unlink it below */

	num1 = mnode_promote(num, model);
	den1 = mnode_promote(den, model);
	unlink_mnode(model);

	/* Propagate promotion errors rather than dividing error nodes */
	if (num1->type == ST_VOID) {
		unlink_mnode(den1);
		return num1;
	}
	if (den1->type == ST_VOID) {
		unlink_mnode(num1);
		return den1;
	}
	result = mnode_div(num1, den1);
	unlink_mnode(num1);
	unlink_mnode(den1);
	return result;
}

/*
 * Return the fraction 0/1, with the zero taken from the ring of
 * model's numerator.
 */
static s_mnode* ratio_zero (ratio_mnode* model)
{
	ratio_mnode *zero_frac = ratio_new();

	zero_frac->num = mnode_zero(model->num);
	zero_frac->den = NULL;
	return (mn_ptr)zero_frac;
}

/*
 * Return the fraction 1/1, with the unit taken from the ring of
 * model's numerator.
 */
static s_mnode* ratio_one (ratio_mnode* model)
{
	ratio_mnode *unit_frac = ratio_new();

	unit_frac->num = mnode_one(model->num);
	unit_frac->den = NULL;
	return (mn_ptr)unit_frac;
}

/*
 * A fraction is non-zero exactly when its numerator is. Denominators are
 * kept non-zero: simplified_ratio() rejects a zero one with SE_DIVZERO.
 */
static int ratio_notzero (ratio_mnode *R)
{
	s_mnode *top = R->num;

	return mnode_notzero(top);
}

/*
 * The sign of a fraction is read off its numerator. This relies on
 * denominators being kept non-negative (simplified_ratio() moves a
 * negative sign into the numerator).
 */
static int ratio_isneg (ratio_mnode *R)
{
	s_mnode *top = R->num;

	return mnode_isneg(top);
}

static int ratio_differ (ratio_mnode *R1, ratio_mnode *R2)
{
	s_mnode *ta, *tb;
	int result;

	if (R2->den)
		ta = mnode_mul(R1->num, R2->den);
	else ta = copy_mnode(R1->num);

	if (R1->den)
		tb = mnode_mul(R2->num, R1->den);
	else tb = copy_mnode(R2->num);

	result = mnode_differ(ta, tb);
	unlink_mnode(ta); unlink_mnode(tb);
	return result;
}

static int ratio_lessthan (ratio_mnode *R1, ratio_mnode *R2)
{
	s_mnode *ta, *tb;
	int result;

	if (R2->den)
		ta = mnode_mul(R1->num, R2->den);
	else ta = copy_mnode(R1->num);

	if (R1->den)
		tb = mnode_mul(R2->num, R1->den);
	else tb = copy_mnode(R2->num);

	result = mnode_lessthan(ta, tb);
	unlink_mnode(ta); unlink_mnode(tb);
	return result;
}

/*
 * -(A/B) == (-A)/B: negate the numerator and share the denominator
 * (via a new reference). The result stays reduced if r was.
 */
static s_mnode* ratio_negate (ratio_mnode *r)
{
	ratio_mnode *neg = ratio_new();

	neg->num = mnode_negate(r->num);
	if (r->den != NULL)
		neg->den = copy_mnode(r->den);
	else
		neg->den = NULL;
	return (mn_ptr)neg;
}

/*
 * 1/(A/B) == B/A, rebuilt through simplified_ratio() so the sign and any
 * -1 denominator are normalized. A zero r yields SE_DIVZERO.
 */
static s_mnode* ratio_invert (ratio_mnode *r)
{
	s_mnode *newnum, *inv;

	if (r->den)
		newnum = copy_mnode(r->den);
	else
		newnum = mnode_one(r->num);	/* implicit denominator 1 */
	inv = simplified_ratio(newnum, r->num);
	unlink_mnode(newnum);
	return inv;
}

/*
 * Sum of a ring element and a fraction:
 *
 *        B      AC+B
 *   A + --- == ------
 *        C       C
 *
 * If B/C is irreducible, so is the result: a common factor of AC+B and C
 * would also divide B. No further simplification is therefore done.
 * Returns a new fraction; A, B and C keep their references.
 */
static s_mnode* add_int_frac (s_mnode* A, s_mnode* B, s_mnode* C)
{
	ratio_mnode *sum = ratio_new();
	s_mnode *prod = mnode_mul(A, C);

	sum->num = mnode_add(prod, B);
	unlink_mnode(prod);
	sum->den = copy_mnode(C);
	return (mn_ptr)sum;
}

/*
 * Sum of two fractions with explicit denominators:
 *
 *    A     C      A(D/g) + C(B/g)
 *   --- + --- == -----------------  where g = gcd(B,D)
 *    B     D           BD/g
 *
 * When A/B and C/D are irreducible and g=1 the sum is already
 * irreducible. That property is not exploited: the sum always goes
 * through simplified_ratio(). Two shortcuts skip a division when g
 * equals B or D.
 */
static s_mnode* add_frac2 (s_mnode* A, s_mnode* B, s_mnode* C, s_mnode* D)
{
	s_mnode *g, *q, *left, *right, *den, *sum, *res;

	g = mnode_gcd(B, D);
	if (!mnode_differ(g, B)) {
		/* B == g: B/g is 1, so C(B/g) == C and BD/g == D */
		right = copy_mnode(C);
		den = copy_mnode(D);
		q = mnode_div(D, g);
		left = mnode_mul(q, A);
		unlink_mnode(q);
	} else if (!mnode_differ(g, D)) {
		/* D == g: D/g is 1, so A(D/g) == A and BD/g == B */
		q = mnode_div(B, g);
		right = mnode_mul(q, C);
		unlink_mnode(q);
		den = copy_mnode(B);
		left = copy_mnode(A);
	} else {
		/* General case: both cofactors are needed */
		q = mnode_div(B, g);
		right = mnode_mul(q, C);
		den = mnode_mul(q, D);
		unlink_mnode(q);
		q = mnode_div(D, g);
		left = mnode_mul(q, A);
		unlink_mnode(q);
	}
	unlink_mnode(g);

	sum = mnode_add(left, right);
	unlink_mnode(left);
	unlink_mnode(right);
	res = simplified_ratio(sum, den);
	unlink_mnode(sum);
	unlink_mnode(den);
	return res;
}

/*
 * Add two fractions r1 = A/B and r2 = C/D (a NULL denominator means 1).
 * A zero operand short-circuits to a new reference to the other one.
 * Otherwise the work is dispatched on which denominators are present.
 */
static s_mnode* ratio_add (ratio_mnode *r1, ratio_mnode *r2)
{
	s_mnode *A = r1->num, *B = r1->den;
	s_mnode *C = r2->num, *D = r2->den;
	ratio_mnode *sum;

	if (!mnode_notzero(A))
		return copy_mnode((mn_ptr)r2);
	if (!mnode_notzero(C))
		return copy_mnode((mn_ptr)r1);

	if (B && D)
		return add_frac2(A, B, C, D);
	if (B)
		return add_int_frac(C, A, B);	/* C + A/B */
	if (D)
		return add_int_frac(A, C, D);	/* A + C/D */

	/* Neither has a denominator: plain ring addition */
	sum = ratio_new();
	sum->num = mnode_add(A, C);
	sum->den = NULL;
	return (mn_ptr)sum;
}

/*
 * Multiplication of two fractions:
 *
 *    A     C      (A/g)(C/h)
 *   --- x --- == ------------  where g=gcd(A,D) and h=gcd(B,C),
 *    B     D      (B/h)(D/g)
 *
 * and the fraction is irreducible if A/B and C/D are.
 */

/*
 * Compute g = gcd(x, y) and store new references to x/g and y/g in *xq
 * and *yq. When g equals `one`, the divisions are skipped and x and y are
 * copied instead.
 */
static void reduce_pair (s_mnode *x, s_mnode *y, s_mnode *one,
			 s_mnode **xq, s_mnode **yq)
{
	s_mnode *g = mnode_gcd(x, y);

	if (mnode_differ(g, one)) {
		*xq = mnode_div(x, g);
		*yq = mnode_div(y, g);
	} else {
		*xq = copy_mnode(x);
		*yq = copy_mnode(y);
	}
	unlink_mnode(g);
}

static s_mnode* ratio_mul (ratio_mnode *r1, ratio_mnode *r2)
{
	s_mnode *A, *B, *C, *D, *A1, *B1, *C1, *D1, *one, *den;
	ratio_mnode *prod;

	/* A zero factor makes the product that same zero */
	A = r1->num;
	if (!mnode_notzero(A))
		return copy_mnode((mn_ptr)r1);
	C = r2->num;
	if (!mnode_notzero(C))
		return copy_mnode((mn_ptr)r2);
	B = r1->den;
	D = r2->den;
	one = mnode_one(A);

	/* Cancel common factors of A and D (a NULL D stands for 1) */
	if (D != NULL)
		reduce_pair(A, D, one, &A1, &D1);
	else {
		A1 = copy_mnode(A);
		D1 = copy_mnode(one);
	}
	/* Cancel common factors of B and C (a NULL B stands for 1) */
	if (B != NULL)
		reduce_pair(B, C, one, &B1, &C1);
	else {
		C1 = copy_mnode(C);
		B1 = copy_mnode(one);
	}

	prod = ratio_new();
	prod->num = mnode_mul(A1, C1);
	unlink_mnode(A1);
	unlink_mnode(C1);

	/* Denominator B1*D1; skip the multiplication when a factor is 1 */
	if (B != NULL && D != NULL)
		den = mnode_mul(B1, D1);
	else if (B != NULL)
		den = copy_mnode(B1);
	else
		den = copy_mnode(D1);	/* D1 is `one` when D is NULL */
	unlink_mnode(B1);
	unlink_mnode(D1);

	/* A denominator of exactly 1 is stored as NULL */
	if (!mnode_differ(den, one)) {
		unlink_mnode(den);
		den = NULL;
	}
	unlink_mnode(one);
	prod->den = den;
	return (mn_ptr)prod;
}
