/* sparse polynomial multiplication */
/* using a chained heap of pointers */
/* Roman Pearce, CECM/SFU, May 2010 */

/* output mode:                                                    */
/* 1 = also print the input polynomials and every product term    */
/*     (the cpu time is then omitted)                              */
/* 0 = print only the summary line, including the cpu time        */
#define PRINTING 0

/* tested on Linux and Mac OS X */
/*
   example: heapmul 1000 2000 3 5
   generates two polynomials and multiplies them

   the first polynomial has 1000 terms and the difference
   between consecutive exponents is a random integer 1 <= r <= 3

   the second polynomial has 2000 terms and the difference
   between consecutive exponents is a random integer 1 <= r <= 5

   all arguments are optional but the defaults are tiny

   the program creates a heap the size of the first polynomial
   and merges while incrementing along the terms of the second;
   this is inefficient if the first polynomial has more terms,
   so in practice use a heap the size of the smaller polynomial
*/

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/* PRINT(...) forwards to printf only when PRINTING is enabled;    */
/* otherwise it expands to nothing and its arguments are not      */
/* evaluated at all                                                */
#if !PRINTING
#define PRINT(...)
#else
#define PRINT printf
#endif


/* convert a clock() tick count to seconds; the extra              */
/* CLOCKS_PER_SEC/10000 ticks add 0.0001 s to every result         */
/* NOTE(review): presumably a rounding / non-zero offset — confirm */
#define TIMER(x) (((double)(x) + CLOCKS_PER_SEC/10000.)/CLOCKS_PER_SEC)


/* machine word: monomials and coefficients are both stored as M_INT */
#define M_INT long
#define U_INT unsigned M_INT
/* casts are fully parenthesized so they behave as a single operand */
/* in any context (e.g. sizeof I(x), postfix operators)              */
#define I(x)  ((M_INT)(x))
#define UI(x) ((U_INT)(x))


/* monomial operations                                               */
/* monomials are compared as unsigned words, so a word with the high */
/* bit set compares greater than any word without it                 */
#define MEQ(p,q)   (UI(p) == UI(q))
#define MGT(p,q)   (UI(p) > UI(q))
/* three-way compare: 1 if p > q, 0 if equal, -1 if p < q */
#define MCOMP(p,q) (UI(p) > UI(q) ? 1 : UI(p) == UI(q) ? 0 : -1)
/* the product of two monomials is the sum of their words; the add  */
/* is done in unsigned arithmetic so it cannot overflow a signed     */
/* type (converting an out-of-range sum back to M_INT is             */
/* implementation-defined)                                           */
#define MMUL(p,q)  I(UI(p) + UI(q))


/* one running product: a fixed term qt of one polynomial times the */
/* current term pt of the other, which advances as terms are merged */
typedef struct prod_t {
	struct prod_t *next;  /* next product in the same heap chain or queue */
	M_INT         *pt;    /* current term of the incrementing polynomial   */
	M_INT         *qt;    /* fixed multiplier term                         */
} prod_t;


/* one heap node: a monomial and the chain of products that share it */
typedef struct heap_t {
	M_INT   mon;   /* monomial of pt*qt for every chained product */
	prod_t *ppd;   /* head of the product chain                   */
} heap_t;


/* remove the top (maximum) element of the 1-indexed max-heap       */
/* heap[1..*n]: the last element heap[*n] is sifted down from the    */
/* root into the vacated slot and *n is decremented                  */
/* must only be called on a non-empty heap (*n >= 1)                 */
/* written with goto because gcc 4.x messes up the equivalent loop   */
/* n = number of heap elements                                       */
static inline void heap_shrink(heap_t *heap, M_INT *n)
{
	M_INT i, j, s;
	/* s indexes the last element, the one being moved */
	s = (*n);
	/* i is the current hole, starting at the root */
	i = 1;
	looptop:
		/* j = left child; j == s means the only child is the moved element */
		j = 2*i;
		if (j >= s) goto done;
		/* pick the larger child (j+1 <= s here) */
		if (MGT(heap[j+1].mon, heap[j].mon)) j++;
		/* stop once the moved element is not smaller than that child */
		if (!MGT(heap[j].mon, heap[s].mon)) goto done;
		heap[i] = heap[j];
		i = j;
		goto looptop;
	done:
	heap[i] = heap[s];
	(*n)--;
}


/* insert product ppd with monomial mon into the 1-indexed max-heap   */
/* heap[1..*n]. If the root, or any node on the path from the new      */
/* leaf up to the root, already holds an equal monomial, ppd is        */
/* chained onto that node and the heap size does not change. Otherwise */
/* ppd takes a new node and *n grows by one, so heap must have room    */
/* for index *n + 1.                                                   */
static inline void heap_insert(heap_t *heap, M_INT *n, M_INT mon, prod_t *ppd)
{
	M_INT hole, parent, slot, cmp = 1;

	/* fast path: same monomial as the current maximum */
	if (*n != 0 && MEQ(mon, heap[1].mon)) {
		ppd->next = heap[1].ppd;
		heap[1].ppd = ppd;
		return;
	}

	/* pass 1: climb from the new leaf while mon beats the parent; */
	/* nothing is moved yet                                         */
	hole = *n + 1;
	for (parent = hole / 2; parent > 0; parent /= 2) {
		cmp = MCOMP(mon, heap[parent].mon);
		if (cmp <= 0)
			break;
		hole = parent;
	}

	/* equal monomial on the path: chain instead of adding a node */
	if (cmp == 0) {
		ppd->next = heap[parent].ppd;
		heap[parent].ppd = ppd;
		return;
	}

	/* pass 2: shift the path down one level to open the target slot */
	slot = hole;
	hole = *n + 1;
	for (parent = hole / 2; hole > slot; hole = parent, parent /= 2)
		heap[hole] = heap[parent];

	ppd->next = NULL;
	heap[hole].mon = mon;
	heap[hole].ppd = ppd;
	(*n)++;
}


/* advance product ppd by one term (two words) of the incrementing    */
/* polynomial and push it back into the heap; bound is the sentinel   */
/* address past that polynomial's terms, and a product that reaches   */
/* it is not reinserted                                               */
static inline void reinsert_mul(heap_t *heap, M_INT *hsize, prod_t *ppd, M_INT *bound)
{
	M_INT *term;

	ppd->pt += 2;
	term = ppd->pt;
	if (!(term < bound))
		return;
	heap_insert(heap, hsize, MMUL(term[0], ppd->qt[0]), ppd);
}


/* multiply poly0 (n0 terms) by poly1 (n1 terms)                        */
/* each polynomial is an array of (monomial, coefficient) word pairs    */
/* with monomials in descending order (as MGT compares them)            */
/* the nonzero terms of the product are stored in res as (monomial,     */
/* coefficient) pairs in descending order, stopping after rmax terms;   */
/* res must have room for 2*rmax words                                  */
/* returns the number of terms stored (0 if n0, n1 or rmax is <= 0),    */
/* or -1 if the working memory cannot be allocated                      */
/* the heap holds up to n0 products, so it is cheaper to pass the       */
/* polynomial with fewer terms as poly0 (no swap is done here)          */
static M_INT sdmp_multiply(M_INT *poly0, M_INT n0, M_INT *poly1, M_INT n1, M_INT *res, M_INT rmax)
{
	heap_t *heap;
	prod_t *ppd, *next, *prev, *queue, *maxp;
	M_INT  *bound, *first;
	M_INT  hsize, i, n2 = 0;

	/* coefficient and monomial of the term being assembled */
	M_INT a, m;

	/* nothing to multiply, or nowhere to store the result */
	if (n0 <= 0 || n1 <= 0 || rmax <= 0)
		return 0;

	/* the heap is 1-indexed, so it needs n0+1 slots */
	heap = malloc((n0+1)*sizeof(heap_t));
	ppd  = malloc(n0*sizeof(prod_t));
	if (heap == NULL || ppd == NULL) {
		free(heap);
		free(ppd);
		return -1;
	}
	hsize = 0;

	/* product i is poly0[i] times the terms of poly1; reinsert_mul    */
	/* advances pt by one term before inserting, so pt starts one term */
	/* before poly1                                                    */
	/* NOTE(review): first - 2 points before the poly1 array, which C  */
	/* leaves undefined even though it is never dereferenced           */
	first = poly1;
	bound = first + 2*n1;
	for (next=ppd, i=0; i < n0; i++) {
		next->next = NULL;
		next->pt   = first - 2;
		next->qt   = poly0 + 2*i;
		next++;
	}
	/* last valid product */
	maxp = next-1;

	/* seed the heap with poly0[0]*poly1[0] */
	reinsert_mul(heap, &hsize, ppd, bound);
	while (hsize) {
		/* extract every product whose monomial equals the maximum m, */
		/* chaining them into queue and summing their coefficients    */
		a = 0;
		m = heap[1].mon;
		queue = heap[1].ppd;	/* chain of all products we extract */
		next  = queue;		/* current product in current chain */
		extract_next:
			a += (*(next->pt+1)) * (*(next->qt+1));
			prev = next;
			next = next->next;
			/* more terms in chain ? */
			if (next) goto extract_next;
			heap_shrink(heap, &hsize);
			/* heap_insert only merges equal monomials along one path, */
			/* so another node may still hold monomial m               */
			if (hsize > 0 && MEQ(m, heap[1].mon)) {
				/* append this chain to last chain */
				prev->next = heap[1].ppd;
				next = prev->next;
				goto extract_next;
			}

		/* advance every extracted product poly0[i]*poly1[j] */
		reinsert_next:
			next = queue;
			queue = queue->next;
			/* merging poly0[i]*poly1[0] admits poly0[i+1]*poly1[0] */
			if (next->pt == first && next != maxp)
				reinsert_mul(heap, &hsize, next+1, bound);
			/* insert poly0[i]*poly1[j+1] if poly1 has such a term */
			reinsert_mul(heap, &hsize, next, bound);
		if (queue) goto reinsert_next;

		/* store the computed term (zero terms are skipped) */
		PRINT("+%ld*x^%ld\n", (long)a, (long)m);
		if (a) {
			res[2*n2]   = m;
			res[2*n2+1] = a;
			if (++n2 == rmax) break;
		}
	}
	free(heap);
	free(ppd);
	return n2;
}


/* generate a random integer A <= r < B (assumes 0 <= A < B)          */
/* dividing by RAND_MAX + 1 keeps the scaled fraction strictly below */
/* 1, so truncation can never produce B itself                       */
#define RAND(A,B)  ((M_INT)((((double)rand() / ((double)RAND_MAX + 1.0)) * ((B)-(A))) + (A)))


int main(int argc, char *argv[]) {
	M_INT *poly0, *poly1, *res;
	M_INT n0, n1, n2, rmax;
	M_INT r, s, t, i;
	clock_t c0, c1;

	n0 = 10;
	n1 = 10;
	r = 1;
	s = 1;
	switch (argc) {
		default:
		case 5: s  = atoi(argv[4]);
		case 4:	r  = atoi(argv[3]);
		case 3: n1 = atoi(argv[2]);
		case 2: n0 = atoi(argv[1]);
		case 1:
		break;
	}
	rmax  = n0*n1;
	poly0 = malloc(2*n0*sizeof(M_INT));
	poly1 = malloc(2*n1*sizeof(M_INT));
	res   = malloc(2*rmax*sizeof(M_INT));

	/* generate polynomials */
	/* terms must be in descending order */
	t = 0;
	PRINT("poly0\n");
	for (i=0; i < n0; i++) {
		poly0[2*(n0-i-1)]   = t;	/* monomial */
		poly0[2*(n0-i-1)+1] = 1;	/* coefficient */
		PRINT("+%d*x^%d", poly0[2*(n0-i-1)+1], poly0[2*(n0-i-1)]);
		t += RAND(1,r+1);
	}
	PRINT("\n");
	PRINT("poly1\n");
	t = 0;
	for (i=0; i < n1; i++) {
		poly1[2*(n0-i-1)]   = t;	/* monomial */
		poly1[2*(n0-i-1)+1] = 1;	/* coefficient */
		PRINT("+%d*x^%d", poly1[2*(n0-i-1)+1], poly1[2*(n0-i-1)]);
		t += RAND(1,s+1);
	}
	PRINT("\n");

	PRINT("result\n");
c0 = clock();
	n2 = sdmp_multiply(poly0, n0, poly1, n1, res, rmax);
c1 = clock();
	printf("%d x %d = %d terms, W(f,g) = %.2f", n0, n1, n2, ((double)n0*n1)/n2);
	/* print times if no console output */
	if (!PRINTING) printf(", cpu time = %.6f s", TIMER(c1-c0));
	printf("\n");

	return 0;
}