#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <cm/cmmd.h>
#include <cm/timers.h>

#include "fish.h"

/* Side length of the square display: passed to openX/openWindow/imageX*
   in main() and used as both dimensions of the `show` buffer. */
#define SIZE	256
/* Position units per display cell: main() plots each coordinate p at cell
   p / SCALE + SIZE / 2, so the window covers SIZE * SCALE units centered
   on the origin. */
#define SCALE   0.3

/*
 * State of a single fish.
 *
 * main() receives raw bytes from each node directly into an array of
 * fish_t (CMMD_receive with sizeof(fish_t) * myfish), so the field order
 * and types here must match whatever the node program sends.
 */
typedef struct {
    float x_pos;    /* x coordinate; main() plots it at x_pos / SCALE + SIZE / 2 */
    float y_pos;    /* y coordinate, same mapping as x_pos */
    float x_vel;    /* x component of velocity */
    float y_vel;    /* y component of velocity */
    float mass;     /* never read by the host code in this file */
} fish_t;

/*
 * Image buffer for the X display window (SIZE x SIZE cells).
 * main() registers it with imageXregister(), sets a cell to 1 for every
 * fish that falls inside the window, and redraws it with imageXdraw().
 */
float show[SIZE][SIZE];

/*
 * Clear the display buffer: set every cell of `show` to 0 so the next
 * frame starts from an empty window.
 *
 * Declared with an explicit `void` return type and `(void)` parameter
 * list; the previous implicit-int definition is invalid since C99.
 */
void zapshow (void)
{
    int row, col;

    for (row = 0; row < SIZE; row++) {
        for (col = 0; col < SIZE; col++) {
            show[row][col] = 0;
        }
    }
}

/*
 * Host program for the fish simulation.
 *
 * Usage: <prog> #fish duration display_interval
 *
 * Broadcasts the run parameters (struct arg_data) to the node program, then
 * advances simulated time t from 0 while t < duration.  On each step:
 *   - every display_interval + 1 steps (every step if display_interval <= 0)
 *     it gathers all fish from the nodes and redraws the X window;
 *   - it receives three reductions from the nodes (sum, max, max), derives
 *     the next step dt = MIN(0.1 * maxv / maxa, 1), and prints t,
 *     sqrt(vsum / #fish) and dt.
 * Finally it prints the value reduced from the nodes with
 * CMMD_combiner_dmax, labeled "Elapsed Time".
 *
 * The order of receives and reductions must stay in step with the node
 * program (not visible here), so the loop structure is left unchanged.
 *
 * Returns EXIT_FAILURE on bad arguments or allocation failure, otherwise
 * EXIT_SUCCESS.
 */
int main (int argc, char **argv)
{
    struct arg_data args;
    int fish, myfish, tsteps, psteps;
    int nodes, received;
    fish_t *allfish;
    int j, k;
    float vsum, maxv, maxa;
    float t;
    float dt = 1;

    if (argc < 4) {
        fprintf (stderr, "syntax: %s #fish duration display_interval\n", argv[0]);
        return EXIT_FAILURE;
    }
    fish = atoi (argv[1]);
    tsteps = atoi (argv[2]);
    psteps = atoi (argv[3]);
    if (fish <= 0) {
        fprintf (stderr, "%s: #fish must be a positive integer\n", argv[0]);
        return EXIT_FAILURE;
    }

    /* Keep every early exit before CMMD_enable, as the argc check already
       does; presumably CMMD_enable is what starts the node program, which
       would otherwise be left waiting on a host that quit.  calloc also
       rejects a fish * sizeof(fish_t) overflow. */
    allfish = calloc ((size_t) fish, sizeof *allfish);
    if (allfish == NULL) {
        fprintf (stderr, "%s: cannot allocate %d fish\n", argv[0], fish);
        return EXIT_FAILURE;
    }

    CMMD_enable ();
    nodes = CMMD_partition_size ();
    myfish = fish / nodes;
    /* Each node sends myfish fish, so only nodes * myfish entries of allfish
       are ever filled.  When fish is not a multiple of the partition size
       the remaining entries were never received and must not be plotted. */
    received = nodes * myfish;

    /* initialize X display */
    openX (SIZE);
    openWindow (SIZE, "Fish1");
    imageXregister (show, SIZE);

    /* start the node program */
    args.total_fish = fish;
    args.time_steps = tsteps;
    args.steps_per_display = psteps;
    CMMD_bc_from_host (&args, sizeof (struct arg_data));

    /* looping over the time steps */
    for (k = 0, t = 0; t < tsteps; t += dt) {
        /* display fish positions */
        if (++k > psteps) {
            k = 0;
            zapshow ();
            for (j = 0; j < nodes; j++) {
                CMMD_receive (j, 0, &allfish[j * myfish], myfish * sizeof *allfish);
            }
            for (j = 0; j < received; j++) {
                /* Range-check in floating point before converting: turning
                   an out-of-range or NaN value into int is undefined
                   behavior.  (-1, SIZE) is exactly the interval that
                   truncates to cells 0..SIZE-1, as the old int
                   assignment did for in-range values. */
                double px = allfish[j].x_pos / SCALE + SIZE / 2;
                double py = allfish[j].y_pos / SCALE + SIZE / 2;

                if (px > -1.0 && px < SIZE && py > -1.0 && py < SIZE) {
                    show[(int) px][(int) py] = 1;
                }
            }
            imageXdraw (show, SIZE);
        }

        /* printout the root-mean-square velocity of all fish
           (presumably vsum is the sum of squared speeds -- confirm against
           the node program; the reduction order must match its sends) */
        vsum = CMMD_reduce_from_nodes_float (0, CMMD_combiner_fadd);
        maxv = CMMD_reduce_from_nodes_float (0, CMMD_combiner_fmax);
        maxa = CMMD_reduce_from_nodes_float (0, CMMD_combiner_fmax);
        /* NOTE(review): if maxv is 0 while maxa > 0, dt becomes 0 and t stops
           advancing.  Left as-is because the node side presumably derives the
           same dt and both loops must run the same number of steps. */
        dt = MIN (0.1 * maxv / maxa, 1);

        printf ("%f mean v= %f, dt = %f\n", t, sqrt (vsum / fish), dt);
    }

    printf ("Elapsed Time = %lf\n",
            CMMD_reduce_from_nodes_double (0, CMMD_combiner_dmax));

    free (allfish);
    return EXIT_SUCCESS;
}