#include <math.h>
#include <stdlib.h>

#include <cm/cmmd.h>
#include <cm/timers.h>

#include "fish.h"

#define SQR(X)		((X)*(X))
#define ABS(X)		((X>=0) ? X : -(X))
#define MYPROC 		CMMD_self_address()

/*
 * State of a single fish.  main() sends arrays of these to the host as raw
 * bytes, so the field order and types form part of the host/node protocol:
 * do not reorder or retype fields without changing the host side too.
 */
typedef struct {
    float x_pos;	/* position, x component */
    float y_pos;	/* position, y component */
    float x_vel;	/* velocity, x component */
    float y_vel;	/* velocity, y component */
    float mass;
} fish_t;	

/*
 * Evaluate the water current at the point (x, y) and store it in *cx, *cy.
 *
 * The field is the sum of two terms:
 *   - a swirl: the vector (-y, x) (i.e. (x, y) rotated 90 degrees
 *     counter-clockwise), scaled to length 3; the normalising length is
 *     clamped below at 0.01 so the origin does not cause a division by zero;
 *   - a restoring term -(x, y) pulling the point back toward the origin.
 */
void current(float x, float y, float *cx, float *cy)
{
    float rot_x = -y;
    float rot_y = x;
    float len = sqrt(SQR(rot_x) + SQR(rot_y));
    /* kept in double so the division matches using MAX(...) inline */
    double scale = MAX(len, 0.01);

    *cx = 3*rot_x / scale - x;
    *cy = 3*rot_y / scale - y;
}

/* initialize fish position, velocity, and mass */
/*
 * Initialize this node's share of the fish.
 *
 * Fish i on this node gets the global index MYPROC*num_fish + i out of
 * num_fish * CMMD_partition_size() fish in total.  Each fish is placed on the
 * diagonal y == x at 2*index/total - 1.  It is given the velocity (-y, x),
 * which is perpendicular to its position vector, and a mass of
 * 1 + index/total.  Assuming MYPROC runs over 0..partition_size-1, positions
 * lie in [-1, 1) and masses in [1, 2).
 *
 * num_fish  - number of fish owned by this node (length of fish_list)
 * fish_list - array of at least num_fish entries to fill in
 */
void fish_init (int num_fish, fish_t fish_list[])
{
    int total_fish = num_fish * CMMD_partition_size();
    int first = MYPROC * num_fish;	/* global index of this node's first fish */
    int n;

    for (n = 0; n < num_fish; n++) {
	fish_t *f = &fish_list[n];
	int g = first + n;
	/* float, not double, so x_pos/y_pos round exactly as a direct store would */
	float pos = (float)g * 2.0 / total_fish - 1.0;

	f->x_pos = pos;
	f->y_pos = pos;
	f->x_vel = -f->y_pos;
	f->y_vel = f->x_pos;
	f->mass = 1.0 + (float)g / total_fish;
    }
}

/*
 * Advance every fish in fish[0..n-1] by one step of length dt and collect
 * this node's statistics for that step.
 *
 * Each fish first moves by its current velocity.  Its velocity is then
 * updated with the acceleration current()/mass, evaluated at the new position.
 *
 * maxv - out: largest fish speed after the update (0 if n <= 0)
 * maxa - out: largest acceleration magnitude applied (0 if n <= 0)
 * vsum - out: sum of squared fish speeds after the update (0 if n <= 0)
 */
static void move_fish(fish_t *fish, int n, float dt,
                      float *maxv, float *maxa, float *vsum)
{
    float vmax = 0, amax = 0, v2sum = 0;
    float cx, cy, v, a, v2;
    int i;

    for (i = 0; i < n; ++i) {
        fish_t *f = &fish[i];

        /* compute the next position of each fish */
        f->x_pos += f->x_vel * dt;
        f->y_pos += f->y_vel * dt;
        current(f->x_pos, f->y_pos, &cx, &cy);
        f->x_vel += cx / f->mass * dt;
        f->y_vel += cy / f->mass * dt;

        /* accumulate velocity and acceleration of fish */
        v2 = SQR(f->x_vel) + SQR(f->y_vel);
        v = sqrt(v2);
        a = sqrt(SQR(cx / f->mass) + SQR(cy / f->mass));
        vmax = MAX(vmax, v);
        amax = MAX(amax, a);
        v2sum += v2;
    }

    *maxv = vmax;
    *maxa = amax;
    *vsum = v2sum;
}

/*
 * Node program for the fish simulation.
 *
 * Startup:
 *   - Receives the run parameters (struct arg_data) broadcast by the host.
 *   - Splits the fish evenly over the partition.
 *
 * Main loop, which runs until simulated time t reaches args.time_steps:
 *   - Advances the fish with an adaptive time step.
 *   - Sends this node's fish to the host every steps_per_display+1 steps.
 *   - Reduces the step statistics (sum of v^2, max speed, max acceleration)
 *     to the host.
 *
 * Finally it reduces this node's timer-0 busy time to the host with a max
 * combiner.
 */
int main(void)
{
    struct arg_data args;
    int total_fish;
    int time_steps;
    int steps_per_display;
    int myfish;
    fish_t *Fish;
    float t;
    float vsum, maxv, maxa;
    float dt = 1;
    int k;

    CMMD_enable_host ();
    CMMD_receive_bc_from_host (&args, sizeof (struct arg_data));
    total_fish = args.total_fish;
    time_steps = args.time_steps;
    steps_per_display = args.steps_per_display;

    CMNA_timer_clear (0);
    CMNA_timer_start (0);

    /* Even split; the remainder total_fish % partition_size is not simulated. */
    myfish = total_fish / CMMD_partition_size();

    Fish = malloc(sizeof *Fish * myfish);
    if (Fish == NULL && myfish > 0)
        abort();	/* fish_init would otherwise write through NULL */

    fish_init(myfish, Fish);

    /* loop over all time steps */
    for (k = 0, t = 0; t < time_steps; t += dt) {
        move_fish(Fish, myfish, dt, &maxv, &maxa, &vsum);

        /*
         * Send info to the host for display.  k counts up to
         * steps_per_display+1 before a send, so the period is
         * steps_per_display+1 steps.
         * NOTE(review): possibly off by one, but left unchanged because the
         * host's receive loop presumably mirrors this cadence; confirm.
         */
        if (++k > steps_per_display) {
            k = 0;
            CMMD_send(CMMD_host_node(), 0, Fish, sizeof(fish_t)*myfish);
        }

        /* sum of squared speeds; presumably the host turns it into an RMS velocity */
        vsum = CMMD_reduce_to_host_float(vsum, CMMD_combiner_fadd);

        /*
         * Compute the next time step from the global max speed/acceleration.
         *
         * NOTE(review): confirm that CMMD_reduce_to_host_float also returns
         * the combined value on the nodes, not only on the host.  If it does
         * not, each node derives a different dt.  The nodes could then run
         * different numbers of iterations, and the collective calls in this
         * loop would stop matching up.  Fixing that needs a node-wide
         * reduction and a matching change on the host side.
         */
        maxv = CMMD_reduce_to_host_float(maxv, CMMD_combiner_fmax);
        maxa = CMMD_reduce_to_host_float(maxa, CMMD_combiner_fmax);
        if (maxa > 0)
            dt = MIN(0.1*maxv/maxa, 1);
        else
            dt = 1;	/* no acceleration anywhere (e.g. no fish): avoid 0/0 */
    }

    CMNA_timer_stop (0);
    CMMD_reduce_to_host_double (CMNA_timer_busy (0), CMMD_combiner_dmax);

    free(Fish);
    return 0;
}