/*                                  tab:4
 * fish1.c - particle-like fish under current
 *
 * Author:          Xiaoye Li
 * Version:         1
 * Creation Date:   Sat Jun 10 13:14:19 PDT 1995
 * Filename:        fish1.c
 * History:
 */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <thread.h>

#include "barrier.h"

/* Simulation parameters. */
#define NTHREADS          4       /* threads taking part in the simulation */
#define NFISH             10000   /* total number of fish */
#define T_FINAL           10.0    /* simulated time at which the run stops */

/* Display parameters. */
#define DISPLAY_SIZE      256     /* display window size, in pixels */
#define DISPLAY_RANGE     20      /* extent of fish space shown */
#define STEPS_PER_DISPLAY 10      /* time steps between two displays */

/* Thread ID of the calling thread. */
#define myID (thr_self())

/* Utility macros used in defining the current.
   Both may evaluate their arguments more than once, so do not pass
   expressions with side effects. */

#define DISTANCE(X,Y)  (sqrt((X)*(X) + (Y)*(Y))) /* distance from origin */
#define MAX(X,Y)       ((X) > (Y) ? (X) : (Y))   /* maximum of two numbers */


/* Two functions of x and y describe the external force due to the current.
   The functions here describe a whirlpool-like current: a term
   perpendicular to the position vector, scaled by 3/distance, minus the
   position itself.  The distance is clamped to at least 0.01 so that the
   division stays finite at the origin.
   Every use of a macro argument is parenthesized so that expression
   arguments such as X_CURRENT(a + b, c) expand with the intended
   precedence. */

#define X_CURRENT(X,Y) (- 3.0*(Y)/MAX(DISTANCE(X,Y), 0.01) - (X))
#define Y_CURRENT(X,Y) (+ 3.0*(X)/MAX(DISTANCE(X,Y), 0.01) - (Y))


/*
    State of one fish: where it is, how fast it is moving, and its mass.
*/
typedef struct {
    double x_pos;   /* position, x component */
    double y_pos;   /* position, y component */
    double x_vel;   /* velocity, x component */
    double y_vel;   /* velocity, y component */
    double mass;    /* divides the current force to give acceleration */
} fish_t;

/*
 * Per-thread record: identifies a thread and tells it which part of the
 * fish array it works on.
 */
typedef struct
{
    /* thread ID */
    thread_t tid;
    /* index of the chunk of the global array assigned to this thread */
    int      chunk;
} thrinfo_t;

/*
 * Shared accumulator used by the all-reduce routines.
 */
typedef struct {
    /* 1 once accum has been reset to zero for the current reduction,
       0 otherwise */
    volatile int  zeroed;
    /* running result of the reduction */
    double        accum;
} g_reduce_t;

/*
   Globally shared variables.
   All of them are visible to every simulation thread; main() sets up
   fishes, thread_ptr, mul_lock and ba before any other thread is created.
*/
fish_t      *fishes;             /* array of NFISH fish, allocated in main() */
thrinfo_t   *thread_ptr = NULL;  /* NTHREADS records; entry 0 is the main thread */
mutex_t     mul_lock;            /* guards updates of g_dmax.accum and g_dsum.accum */
barrier_t   ba;                  /* barrier initialised for NTHREADS participants */
g_reduce_t  g_dmax;              /* accumulator for the max reduction */
g_reduce_t  g_dsum;              /* accumulator for the sum reduction */

/*
    Set the starting state of the fish belonging to one chunk.

    mychunk   index of the chunk to initialise
    num_fish  number of fish per chunk
    fishes    the whole fish array; entries mychunk*num_fish up to
              (mychunk+1)*num_fish - 1 are written

    Each fish is placed on the x axis at 2*idx/NFISH - 1, where idx is its
    index in the whole array.  It starts with zero x velocity, a y velocity
    equal to its x position, and mass 1 + idx/NFISH.
*/
void all_init_fish(int mychunk, int num_fish, fish_t fishes[])
{
    const double total_fish = NFISH;
    const int first = mychunk * num_fish;   /* global index of first fish */
    int k;

    for (k = 0; k < num_fish; k++) {
        const int idx = first + k;
        fish_t *f = &fishes[idx];

        f->x_pos = idx*2.0/total_fish - 1.0;
        f->y_pos = 0.0;
        f->x_vel = 0.0;
        f->y_vel = f->x_pos;
        f->mass  = 1.0 + idx/total_fish;
    }
}

/*
    Max all-reduce: every simulation thread calls this with its local
    value and every caller gets back the maximum over all of them.

    dmax     this thread's contribution
    returns  the largest contribution; never less than 0 because the
             shared accumulator starts each reduction at 0

    Must be called by all NTHREADS threads, since it waits on the shared
    barrier three times.

    The reset of the accumulator is published with a barrier instead of a
    busy-wait on the volatile flag: volatile gives no ordering guarantee
    between the store to accum and the store to zeroed, and the spin burns
    CPU.  This assumes barrier_wait() blocks until all participants arrive
    and makes earlier writes visible, which the other barriers in this
    routine already rely on.
*/
double all_reduce_to_all_dmax(double dmax)
{
    int is_lead = (myID == thread_ptr[0].tid);

    /* Nobody may still be reading the previous result when it is reset. */
    barrier_wait(&ba);
    if ( is_lead ) {
        g_dmax.accum = 0.;
        g_dmax.zeroed = 1;
    }
    /* The reset must be complete before anyone accumulates. */
    barrier_wait(&ba);

    mutex_lock(&mul_lock);
    g_dmax.accum = MAX(g_dmax.accum, dmax);
    mutex_unlock(&mul_lock);

    /* All contributions must be in before the result is read. */
    barrier_wait(&ba);
    if ( is_lead )
        g_dmax.zeroed = 0;

    return (g_dmax.accum);
}

/*
    Sum all-reduce: every simulation thread calls this with its local
    value and every caller gets back the sum over all of them.

    data     this thread's contribution
    returns  the sum of all contributions

    Must be called by all NTHREADS threads, since it waits on the shared
    barrier three times.

    The reset of the accumulator is published with a barrier instead of a
    busy-wait on the volatile flag: volatile gives no ordering guarantee
    between the store to accum and the store to zeroed, and the spin burns
    CPU.  This assumes barrier_wait() blocks until all participants arrive
    and makes earlier writes visible, which the other barriers in this
    routine already rely on.
*/
double all_reduce_to_all_dadd(double data)
{
    int is_lead = (myID == thread_ptr[0].tid);

    /* Nobody may still be reading the previous result when it is reset. */
    barrier_wait(&ba);
    if ( is_lead ) {
        g_dsum.accum = 0.;
        g_dsum.zeroed = 1;
    }
    /* The reset must be complete before anyone accumulates. */
    barrier_wait(&ba);

    mutex_lock(&mul_lock);
    g_dsum.accum += data;
    mutex_unlock(&mul_lock);

    /* All contributions must be in before the result is read. */
    barrier_wait(&ba);
    if ( is_lead )
        g_dsum.zeroed = 0;

    return (g_dsum.accum);
}

/*
    Move the fish of one chunk by one time step.

    mychunk           index of the chunk to advance
    num_fish          number of fish per chunk
    fish_list         the whole fish array; entries mychunk*num_fish up to
                      (mychunk+1)*num_fish - 1 are updated
    step              length of the time step
    max_acc_ptr       out: largest acceleration magnitude in the chunk
    max_speed_ptr     out: largest speed in the chunk
    sum_speed_sq_ptr  out: sum of squared speeds over the chunk

    For each fish the position is advanced with the current velocity, the
    acceleration is evaluated from the current at the new position, and
    the velocity is then advanced with that acceleration.  The three
    outputs are local to the chunk; all are 0 when num_fish is 0.
*/
void all_move_fish(int mychunk, int num_fish, fish_t fish_list[],
		   double step, double *max_acc_ptr, double *max_speed_ptr,
                   double *sum_speed_sq_ptr)
{
    int i;
    fish_t *fish;
    double x_acc, y_acc;
    double cur_acc, max_acc = 0.0;
    double cur_speed, max_speed = 0.0;
    double speed_sq, sum_speed_sq  = 0.0;

    /* Work on the array the caller passed in, not on the global. */
    fish = &fish_list[mychunk * num_fish];
    for (i = 0; i < num_fish; i++, fish++) {
        /* Update fish positions, calculate acceleration, and update
           velocity. */
        fish->x_pos += (fish->x_vel)*step;
        fish->y_pos += (fish->y_vel)*step;
        x_acc = X_CURRENT(fish->x_pos, fish->y_pos)/fish->mass;
        y_acc = Y_CURRENT(fish->x_pos, fish->y_pos)/fish->mass;
        fish->x_vel += x_acc*step;
        fish->y_vel += y_acc*step;

        /* Accumulate local max speed, accel and contribution to
           mean square velocity. */
        cur_acc = sqrt(x_acc*x_acc + y_acc*y_acc);
        max_acc = MAX(max_acc, cur_acc);
        speed_sq = (fish->x_vel)*(fish->x_vel) + (fish->y_vel)*(fish->y_vel);
        sum_speed_sq += speed_sq;
        cur_speed = sqrt(speed_sq);
        max_speed = MAX(max_speed, cur_speed);
    }

    /* Return local computation results. */
    *max_acc_ptr      = max_acc;
    *max_speed_ptr    = max_speed;
    *sum_speed_sq_ptr = sum_speed_sq;
}

/*
   Each thread begins execution by calling this function.

   arg_ptr  points to an int holding the index of the chunk of the fish
            array this thread owns
   returns  always 0

   The thread initialises its chunk, then steps the simulation until the
   simulated time reaches T_FINAL.  After every step the local maxima and
   the sum of squared speeds are combined across all threads, so every
   thread derives the same next time step and the threads stay in step at
   the barriers inside the reductions.  Only the thread recorded in
   thread_ptr[0] prints.
*/
void *move_fish(void *arg_ptr)
{
    int    mychunk, num_fish;
    double t = 0.0, dt = 0.01;
    double max_acc, max_speed, sum_speed_sq, mnsqvel;

    mychunk  = *(int*)arg_ptr;
    num_fish = NFISH / NTHREADS;
    all_init_fish(mychunk, num_fish, fishes);

    while (t < T_FINAL) {

        /* Update time. */
        t += dt;

        /* Move fish with the current and compute rms velocity. */
        all_move_fish(mychunk, num_fish, fishes, dt,
                      &max_acc, &max_speed, &sum_speed_sq);
        max_acc      = all_reduce_to_all_dmax(max_acc);
        max_speed    = all_reduce_to_all_dmax(max_speed);
        sum_speed_sq = all_reduce_to_all_dadd(sum_speed_sq);
        mnsqvel      = sqrt(sum_speed_sq/NFISH);

        /* Adjust dt based on maximum speed and acceleration--this
           simple rule tries to ensure that no velocity will change
           by more than 10%.  When either extreme is zero the ratio
           would be 0, infinite or NaN (a zero dt would stall the loop
           forever), so the previous dt is kept.  Every thread sees the
           same reduced values and therefore makes the same choice. */
        if (max_speed > 0.0 && max_acc > 0.0)
            dt = 0.1*max_speed/max_acc;

        /* Print out time and rms velocity for this step. */
        if ( myID == thread_ptr[0].tid ) {
            printf("%15.6lf %15.6lf;\n", t, mnsqvel);
        }
    }

    return 0;
}

/* 
   Simulate the movement of NFISH fish under a current.
*/

main()
{
    int    sync_type, i;
    
    /* Allocate a global shared array for the fish data set. */
    fishes = (fish_t *) malloc(NFISH * sizeof(fish_t));

    /* Initialize thread data structures */
    thread_ptr = (thrinfo_t *) malloc(NTHREADS * sizeof(thrinfo_t));
    sync_type = USYNC_PROCESS;
    mutex_init(&mul_lock, sync_type, NULL);
    barrier_init(&ba, NTHREADS, sync_type, NULL);

    thread_ptr[0].chunk = 0;
    thread_ptr[0].tid = myID;
    for (i = 1; i < NTHREADS; i++) {
        thread_ptr[i].chunk = i;
        if (thr_create(0, 0, move_fish, (void*)&thread_ptr[i].chunk,
                       0, &thread_ptr[i].tid)) {
            perror("thr_create");
            exit(1);
        }
    }

    /* Main thread starts simulation ... */
    i = 0;
    move_fish(&i);

    /* Termination */
    for (i = 1; i < NTHREADS; ++i)
	thr_join(thread_ptr[i].tid, NULL, NULL);

    return 0;
}