/******************************************************************
** Matrix Matrix Multiply Contest Test Driver:
** U.C. Berkeley, Department of EECS, Computer Science Division
** CS 267, Applications of Parallel Processing
** Spring 1996
** $Id: mm_contest.c,v 1.1 1995/01/23 12:29:27 bilmes Exp borisv $
**
*******************************************************************
*/


#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/resource.h>


/*
** Select the multiply routine under test.  Every variant is invoked as
** MUL_MFMF_MF(size,A,B,C) on size x size matrices and is expected to
** accumulate A*B into C (main() zeroes C beforehand and compares the
** result with naive_mm, which also accumulates into its output).
**   DOOURS  - mul_mfmf_mf taking the three dimensions separately.
**   DOESSL  - ESSL's dgemm with alpha = beta = 1.
**   default - the contest entry point mul_mfmf_mf(matdim,A,B,C).
**
** NOTE(review): in both mul_mfmf_mf prototypes the output matrix C is
** declared `const double *` even though it is written to -- confirm this
** matches the external definition before tightening it.
*/
#ifdef DOOURS
extern void mul_mfmf_mf(int,int,int,
			const double *A,
			const double *B,
			const double *C);
#define MUL_MFMF_MF(size,A,B,C) mul_mfmf_mf(size,size,size,A,B,C)
#elif defined(DOESSL)
/*
** dgemm receives every argument by address, hence the file-scope N and
** one objects and the &size arguments (size must therefore be an lvalue).
** A and B are passed in swapped order, presumably because BLAS dgemm is
** column-major: computing B^T*A^T there yields A*B in the row-major
** layout used by naive_mm -- confirm against the ESSL documentation.
*/
const char N = 'N';
const double one = 1.0;
#define MUL_MFMF_MF(size,A,B,C) dgemm(&N,&N, \
				      &size,&size,&size, \
				      &one, \
				      B,&size, \
				      A,&size, \
				      &one, \
				      C,&size)
#else
extern void mul_mfmf_mf(int matdim,
			const double *A,
			const double *B,
			const double *C);
#define MUL_MFMF_MF(size,A,B,C) mul_mfmf_mf(size,A,B,C)
#endif


/* Correctness phase: number of random-size checks, and the largest value
   of error() accepted before the entry is disqualified. */
#define NUM_CORRECTNESS_CHECKS 10
#define MAX_ERROR 2.0

#define TEST_RUNS 10  /* number of runs of each size */

/* Arbitrary sized tests */
/* NUM_ATESTS counts atest_sizes; main() indexes atest_iters with the same
   index, so both arrays must have the same length. */
#define NUM_ATESTS (sizeof(atest_sizes)/sizeof(int))
int atest_sizes[] = {   23,  43,  61,  79,  99, 119, 151 };
int atest_iters[] = { 4000, 900, 500, 220, 130,  70,  35 };
/* The above iters are chosen so that dgemm from ESSL runs in >= .75 seconds. */
/* They are not yet set in stone, and they may change before the contest */

/* quad-word aligned tests (main() places these operands on 16-byte
   boundaries) */
/* NUM_QTESTS counts qtest_sizes; qtest_iters must have the same length. */
#define NUM_QTESTS (sizeof(qtest_sizes)/sizeof(int))
int qtest_sizes[] = {    16,   32,  64, 128, 256 };
int qtest_iters[] = { 11000, 5000, 760, 110,  11 };
/* The above iters are chosen so that dgemm from ESSL runs in >= .75 seconds. */
/* They are not yet set in stone, and they may change before the contest */


/*
** drand48()/seed48() are POSIX rather than ISO C, so <stdlib.h> may not
** declare them under a strict -std= mode; declare them explicitly here.
** getrusage() is already declared by <sys/resource.h>.
*/
extern double drand48();
extern unsigned short* seed48();
struct rusage rus; /* starting time */
struct rusage rue; /* ending time */
/*
** Snapshot this process's resource usage into rus / rue (see reportTiming).
** Each expands to an expression (getrusage's int result), so write
** `START_TIMING;` as a statement or test the value directly.
*/
#define START_TIMING  getrusage(RUSAGE_SELF,&rus)
#define STOP_TIMING   getrusage(RUSAGE_SELF,&rue)

/* Small arithmetic helpers.  These are macros, so an argument may be
   evaluated more than once: pass only side-effect-free expressions. */
#define ABS(x)   ((x) > 0 ? (x) : -(x))
#define MIN(x,y) ((x) < (y) ? (x) : (y))
#define MAX(x,y) ((x) > (y) ? (x) : (y))
#define SQ(x)    ((x) * (x))

/*
** Return the user CPU time, in seconds, between the START_TIMING and
** STOP_TIMING snapshots stored in rus and rue.  Only ru_utime is read,
** so system time is not included.
*/
double reportTiming() {
  long secs  = (long)(rue.ru_utime.tv_sec  - rus.ru_utime.tv_sec);
  long usecs = (long)(rue.ru_utime.tv_usec - rus.ru_utime.tv_usec);

  /* borrow one second when the microsecond field wrapped around */
  if (usecs < 0) {
    secs -= 1;
    usecs += 1000000l;
  }
  return (double)secs + (double)usecs*1e-6;
}

/*
** Seed drand48() from the current wall-clock time, so that each run of
** the driver picks different random matrix sizes and contents.
**
** The time value is split across the three 16-bit seed words so that all
** of its bits contribute to the seed, instead of repeating only the low
** 16 bits three times.
*/
void
myseed()
{
  unsigned long now = (unsigned long)time(NULL);
  unsigned short seed16v[3];

  seed16v[0] = (unsigned short)(now & 0xFFFFu);
  seed16v[1] = (unsigned short)((now >> 16) & 0xFFFFu);
  /* two 16-bit shifts avoid UB when unsigned long is only 32 bits wide */
  seed16v[2] = (unsigned short)(((now >> 16) >> 16) & 0xFFFFu);
  seed48(seed16v);
}


/*
** Reference multiply used only to check correctness: C += A*B, where
** A is Sm x Sk, B is Sk x Sn and C is Sm x Sn, all stored row-major.
** C is accumulated into, not overwritten.
*/
void
naive_mm(int Sm,int Sk,int Sn,
	 const double *A,const double *B,double *C)
{
  int row;
  for (row = 0; row < Sm; ++row) {
    const int a_base = row*Sk;
    const int c_base = row*Sn;
    int col = 0;
    while (col < Sn) {
      double sum = C[c_base + col];
      int k;
      for (k = 0; k < Sk; ++k)
        sum += A[a_base + k]*B[k*Sn + col];
      C[c_base + col] = sum;
      ++col;
    }
  }
}

/*
** Return a pseudo-random integer drawn uniformly from [l,u] inclusive
** (l <= u), using drand48() as the underlying generator.
*/
int rrand(int l, int u)
{
  int span = 1 + u - l;
  return l + (int)(span*drand48());
}


/*
** Entrywise 1-norm: the sum of |x| over the rows*cols contiguously
** stored entries of mat.
*/
double l1_norm(double *mat,int rows,int cols)
{
  const int count = rows*cols;
  double total = 0;
  int idx = 0;

  while (idx < count) {
    const double entry = mat[idx++];
    total += ABS(entry);
  }
  return total;
}

/*
** Entrywise 1-norm of the difference mat1 - mat2: the sum of
** |mat1[i] - mat2[i]| over all rows*cols contiguously stored entries.
*/
double l1_norm_diff(double *mat1,double *mat2,int rows,int cols)
{
  const int count = rows*cols;
  double total = 0;
  int idx;

  for (idx = 0; idx < count; ++idx) {
    const double delta = mat1[idx] - mat2[idx];
    total += ABS(delta);
  }
  return total;
}


/*
** error: relative error used to compare two matrix multiplies.
**
** Returns  l1(mat1 - mat2) / (macheps * l1(mat1) * l1(mat2)),
** where l1() is the sum of the absolute values of all rows*cols entries
** and macheps = 2^(-53), the double-precision unit roundoff
** (it would be 2^(-24) in single precision).
**
** NOTE(review): the customary bound for two computed products
** Ci = float(A*B) is norm(C1-C2)/(macheps*norm(A)*norm(B)).  This routine
** divides by the norms of the two products themselves, so its value is not
** scale-invariant (scaling A by s scales the result by 1/s).  main()
** compares the result against MAX_ERROR, so changing the formula would
** change the pass/fail threshold -- confirm before altering.
*/
double error(double *mat1,double *mat2,int rows,int cols)
{
  const double macheps = 1.110223024625157e-16; /* = 2^(-53) */
  double numer = l1_norm_diff(mat1,mat2,rows,cols);
  double denom = macheps*l1_norm(mat1,rows,cols)*l1_norm(mat2,rows,cols);
  return numer/denom;
}


/*
** Fill the rows*cols contiguously stored entries of mat with values
** 2*drand48()-1, i.e. drawn uniformly from [-1,1).
*/
void mat_init(double *mat,int rows,int cols)
{
  const int count = rows*cols;
  int idx;
  for (idx = 0; idx < count; ++idx)
    mat[idx] = 2.0*drand48()-1.0;
}



/*
** Allocate `count` doubles, exiting with a diagnostic if malloc fails.
*/
static double *
alloc_doubles(size_t count)
{
  double *p = malloc(count * sizeof *p);
  if (p == NULL) {
    fprintf(stderr,"\nout of memory allocating %lu doubles\n",
	    (unsigned long)count);
    exit(EXIT_FAILURE);
  }
  return p;
}

/*
** Return the first 16-byte (quad-word) aligned address at or after `raw`.
** The caller must have allocated at least 15 spare bytes beyond the data
** it needs, and must later free() `raw`, never the returned pointer.
*/
static double *
align16(double *raw)
{
  uintptr_t addr = ((uintptr_t)raw + 15u) & ~(uintptr_t)15u;
  return (double *)addr;
}

/*
** Time TEST_RUNS runs of `num_iters` back-to-back MUL_MFMF_MF calls on
** matdim x matdim operands and return the best MFLOP/s rate observed.
** C is zeroed at the start of each run and accumulates within that run.
*/
static double
best_mflops(int matdim, int num_iters,
	    const double *A, const double *B, double *C)
{
  const size_t nelem = SQ((size_t)matdim);
  double max_mflops = 0.0;
  int run;

  for (run=0;run<TEST_RUNS;run++) {
    int iter;
    double utime, mflops;

    memset(C,0,nelem*sizeof *C);
    START_TIMING;
    for (iter=0;iter<num_iters;iter++) {
      /* iteratively accumulate into C */
      MUL_MFMF_MF(matdim,A,B,C);
    }
    STOP_TIMING;
    utime = reportTiming();
    /* 2*n^3 floating-point operations per multiply */
    mflops = 2.0*matdim*matdim*matdim*num_iters*1e-6/utime;
    if (mflops > max_mflops)
      max_mflops = mflops;
  }
  return max_mflops;
}

/*
** Contest driver:
**  1. Correctness: compare MUL_MFMF_MF with naive_mm on
**     NUM_CORRECTNESS_CHECKS random sizes in [1,256]; if error() exceeds
**     MAX_ERROR, print a DISQUALIFIED message and exit with EXIT_FAILURE.
**  2. Performance: print "size best_mflops" on stdout for each
**     16-byte-aligned size in qtest_sizes, then for each size in
**     atest_sizes.
*/
int main()
{
  double *A,*B,*C,*cC;
  double *rawA,*rawB,*rawC;
  size_t test;
  int i;

  myseed();

  /* check for correctness */
  fprintf(stderr,"Checking for correctness on sizes:"); fflush(stderr);
  for (i=0;i<NUM_CORRECTNESS_CHECKS;i++) {
    double err;
    int matdim = rrand(1,256);
    size_t nelem = SQ((size_t)matdim);

    fprintf(stderr," %d",matdim); fflush(stderr);
    A = alloc_doubles(nelem);
    B = alloc_doubles(nelem);
    C = alloc_doubles(nelem);
    cC = alloc_doubles(nelem);
    /* both products accumulate into their output, so start from zero */
    memset(C,0,nelem*sizeof *C);
    memset(cC,0,nelem*sizeof *cC);

    mat_init(A,matdim,matdim);
    mat_init(B,matdim,matdim);
    naive_mm(matdim,matdim,matdim,
	     A,B,cC);
    MUL_MFMF_MF(matdim,A,B,C);
    if ((err = error(C,cC,matdim,matdim)) > MAX_ERROR) {
      printf("Error for test case %dx%d is %f > %f. DISQUALIFIED!!!\n",
	     matdim,matdim,err,MAX_ERROR);
      /* a failed correctness check must not report success */
      exit(EXIT_FAILURE);
    }
    free(A); free(B); free(C); free(cC);
  }
  fprintf(stderr,"\n"); fflush(stderr);

  /* quad-word aligned sizes */
  fprintf(stderr,"Checking quad-word aligned sizes\n"); fflush(stderr);
  for (test=0;test<NUM_QTESTS;test++) {
    int matdim = qtest_sizes[test];
    size_t nelem = SQ((size_t)matdim);

    /* Over-allocate by two doubles (16 bytes) so each operand can be moved
       onto a 16-byte boundary; the raw pointers are what get freed. */
    rawA = alloc_doubles(nelem+2);
    rawB = alloc_doubles(nelem+2);
    rawC = alloc_doubles(nelem+2);
    A = align16(rawA);
    B = align16(rawB);
    C = align16(rawC);
    mat_init(A,matdim,matdim);
    mat_init(B,matdim,matdim);

    printf("%d %f\n",matdim,best_mflops(matdim,qtest_iters[test],A,B,C));
    fflush(stdout);
    free(rawA); free(rawB); free(rawC);
  }

  /* arbitrary sizes */
  fprintf(stderr,"Checking arbitrary sizes\n"); fflush(stderr);
  for (test=0;test<NUM_ATESTS;test++) {
    int matdim = atest_sizes[test];
    size_t nelem = SQ((size_t)matdim);

    A = alloc_doubles(nelem);
    B = alloc_doubles(nelem);
    C = alloc_doubles(nelem);
    mat_init(A,matdim,matdim);
    mat_init(B,matdim,matdim);

    printf("%d %f\n",matdim,best_mflops(matdim,atest_iters[test],A,B,C));
    fflush(stdout);
    free(A); free(B); free(C);
  }
  return 0;
}