/*
 *             Automatically Tuned Linear Algebra Software v2.0
 *                  This routine written by R. Clint Whaley                  
 *                     (C) 1997 All Rights Reserved
 *
 *                              NOTICE
 *
 * Permission to use, copy, modify, and distribute this software and
 * its documentation for any purpose and without fee is hereby granted
 * provided that the above copyright notice appear in all copies and
 * that both the copyright notice and this permission notice appear in
 * supporting documentation.
 *
 * Neither the University of Tennessee nor the Author make any
 * representations about the suitability of this software for any
 * purpose.  This software is provided ``as is'' without express or
 * implied warranty.
 *
 */
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

/*
 * SAFE_ALPHA: value of ALPHA that selects the _aXX kernel-name suffix.
 */
#define SAFE_ALPHA -3
/*
 * Overridable tuning constants (may be supplied on the compiler command line).
 */
#ifndef REPS
   #define REPS 1500
#endif

#ifndef L2SIZE
   #define L2SIZE 4194304
#endif

/*
 * tname(pre, nam) : paste two tokens AFTER macro-expanding both arguments
 *                   (my_join does the raw ## paste).
 * Mstr(m)         : turn the macro-expanded argument into a string literal
 *                   (Mstr2 does the raw # stringification).
 * Mmin(x, y)      : smaller of two values; each argument may be evaluated
 *                   twice, so do not pass expressions with side effects.
 */
#define tname(pre, nam) my_join(pre, nam)
#define my_join(pre, nam) pre ## nam
#define Mstr2(m) # m
#define Mstr(m) Mstr2(m)
#define Mmin(x, y) ( (x) > (y) ? (y) : (x) )

/*
 * LANG: single-character tag chosen at compile time:
 *    'M' if FULLMM is defined, else 'F' if LangF77 is defined, else 'C'.
 */
#ifdef FULLMM
   #define LANG 'M'
#elif !defined(LangF77)
   #define LANG 'C'
#else
   #define LANG 'F'
#endif
/*
 * Precision / data-type selection.  Exactly one branch is taken, based on
 * which of sREAL, dREAL, qREAL, sCPLX/cCPLX/cREAL is defined; when none is
 * defined the final #else selects double-precision complex.  Each branch sets:
 *    PRE        : precision prefix as a character constant
 *    pre        : the same prefix as a bare token, pasted into kernel names
 *    TYPE       : C type of one stored floating-point value
 *    SCALAR     : type used to pass alpha/beta -- by value for real types,
 *                 pointer to TYPE for complex types
 *    ATL_sizeof : bytes per matrix element (two TYPEs for complex)
 *    TREAL/TCPLX: flag tested by the rest of the file
 *    SHIFT      : text appended after an integer expression; empty for real
 *                 types, "<<1" (i.e. doubling) for complex types
 *    EPS        : tolerance unit used by mmtst when forming its error bound
 */
#if defined(sREAL)
   #include "atlas_sfc.h"
   #define PRE 's'
   #define pre s
   #define TYPE float
   #define SCALAR float
   #define ATL_sizeof sizeof(TYPE)
   #define TREAL
   #define SHIFT
   #define EPS 1.0e-7
#elif defined(dREAL)
   #include "atlas_dfc.h"
   #define PRE 'd'
   #define pre d
   #define TYPE double
   #define SCALAR double
   #define ATL_sizeof sizeof(TYPE)
   #define TREAL
   #define SHIFT
   #define EPS 1.0e-16
/*
 * NOTE(review): the qREAL branch does not define EPS, yet mmtst uses EPS.
 * Unless atlas_qfc.h supplies it, a qREAL build will not compile -- confirm.
 */
#elif defined (qREAL)
   #include "atlas_qfc.h"
   #define PRE 'q'
   #define pre q
   #define TYPE long double
   #define SCALAR long double
   #define ATL_sizeof sizeof(TYPE)
   #define TREAL
   #define SHIFT
#elif defined(sCPLX) || defined(cCPLX) || defined(cREAL)
   #include "atlas_sfc.h"
   #define PRE 'c'
   #define pre c
   #define TYPE float
   #define ATL_sizeof (sizeof(TYPE)<<1)
   #define SCALAR float *
   #define TCPLX
   #define SHIFT <<1
   #define EPS 1.0e-7
#else
   #include "atlas_dfc.h"
   #define PRE 'z'
   #define pre z
   #define TYPE double
   #define ATL_sizeof (sizeof(TYPE)<<1)
   #define SCALAR double*
   #define TCPLX
   #define SHIFT <<1
   #define EPS 1.0e-16
#endif

/*
 * Fallback for environments whose <stdlib.h> does not provide RAND_MAX.
 */
#ifndef RAND_MAX
   #define RAND_MAX ((unsigned long)(1<<30))
#endif
/*
 * Returns the next rand() value mapped linearly onto [-0.5, 0.5]:
 * rand() == 0 yields 0.5 and rand() == RAND_MAX yields -0.5.
 * The sequence is whatever rand() produces; this routine never seeds it.
 */
double dumb_rand(void)
{
   return( 0.5 - ((double)rand())/((double)RAND_MAX) );
}

/*
 * Mabs(x): absolute value; the argument may be evaluated twice.
 */
#define Mabs(x) ( (x) < 0 ? (x) * -1 : (x) )
/*
 * Block dimensions: MB and KB default to NB when not supplied.
 */
#ifndef MB
   #define MB NB
#endif
#ifndef KB
   #define KB NB
#endif

/*
 * csA / csB: complex storage selectors for A and B, defaulting to 2.
 * Elsewhere in this file the value 1 is handled as real and imaginary parts
 * held in separate planes, and 2 as interleaved (real, imaginary) pairs.
 * NOTE(review): csC is tested later in this file but gets no default here --
 * confirm the build always defines it for complex types.
 */
#ifndef csA
   #define csA 2
#endif
#ifndef csB
   #define csB 2
#endif

/*
 * Normalize the transpose request macros (tranAt/tranAT, tranAc/tranAC, ...)
 * into exactly one of NoTransX / TransX / ConjTransX for each of A and B.
 * No request means no transpose.
 */
#if defined(tranAt) || defined(tranAT)
   #define TransA
#elif defined(tranAc) || defined(tranAC)
   #define ConjTransA
#else
   #define NoTransA
#endif
#if defined(tranBt) || defined(tranBT)
   #define TransB
#elif defined(tranBc) || defined(tranBC)
   #define ConjTransB
#else
   #define NoTransB
#endif
/*
 * Ma x Na and Mb x Nb: row and column counts of A and B as stored, i.e. with
 * the roles of the two dimensions swapped when the matrix is (conj)transposed.
 */
#ifdef NoTransA
   #define Ma MB
   #define Na KB
#else
   #define Ma KB
   #define Na MB
#endif
#ifdef NoTransB
   #define Mb KB
   #define Nb NB
#else
   #define Mb NB
   #define Nb KB
#endif

/*
 * LOOPO: loop-order token pasted into the kernel name -- IJK when IJK is
 * defined, JIK otherwise.  The #undef makes sure LOOPO expands to the bare
 * identifier rather than to whatever value the selector macro was given.
 */
#ifdef IJK
   #undef IJK
   #define LOOPO IJK
#else
   #undef JIK
   #define  LOOPO JIK
#endif
/*
 * TA / TB  : transpose letter (T, C or N) pasted into the kernel name.
 * LDA2/LDB2: row count of the stored matrix; mmtst doubles it to obtain a
 *            default leading dimension when LDA/LDB is absent or zero.
 */
#ifdef TransA
   #define TA T
   #define LDA2 KB
#elif defined(ConjTransA)
   #define TA C
   #define LDA2 KB
#else
   #define TA N
   #define LDA2 MB
#endif
#ifdef TransB
   #define TB T
   #define LDB2 NB
#elif defined(ConjTransB)
   #define TB C
   #define LDB2 NB
#else
   #define TB N
   #define LDB2 KB
#endif
/*
 * ALPHAnam: kernel-name suffix derived from ALPHA
 *    1 -> _a1, -1 -> _an1, SAFE_ALPHA -> _aXX, anything else -> _aX.
 */
#if (ALPHA == 1)
   #define ALPHAnam _a1
#elif (ALPHA == -1)
   #define ALPHAnam _an1
#elif (ALPHA == SAFE_ALPHA)
   #define ALPHAnam _aXX
#else
   #define ALPHAnam _aX
#endif
/*
 * BETAnam : kernel-name suffix derived from BETA (1, 0, -1, or X for other).
 * NBETAnam: suffix corresponding to the negated BETA.
 * NOTE(review): inside #if an undefined BETA evaluates as 0 and so selects
 * _b0, whereas mmtst defaults an undefined BETA to 1.0 for real types --
 * confirm the build always defines BETA.
 */
#if (BETA == 1)
   #define BETAnam _b1
   #define NBETAnam _bn1
#elif (BETA == 0)
   #define BETAnam _b0
   #define NBETAnam _b0
#elif (BETA == -1)
   #define BETAnam _bn1
   #define NBETAnam _b1
#else
   #define BETAnam _bX
   #define NBETAnam _bX
#endif


/*
 * Pieces of the name of the kernel under test, assembled by token pasting:
 *    ppre   : ATL_ <pre> <LOOPO>
 *    MNKnam : <MB0> x <NB0> x <KB0>
 *    TRnam  : <TA> <TB>
 *    ldnam  : <LDA> x <LDB> x <LDC>
 */
#define ppre tname(tname(ATL_,pre),LOOPO)
#define MNKnam tname(tname(tname(tname(MB0,x),NB0),x),KB0)
#define TRnam tname(TA, TB)
#define ldnam tname(tname(tname(tname(LDA,x),LDB),x),LDC)

/*
 * Real types: NBmm is the full kernel name
 *    ppre MNKnam TRnam ldnam ALPHAnam BETAnam
 * and the kernel is declared here so mmtst can call it directly.
 */
#ifdef TREAL
   #define NBmm tname(tname(tname(tname(tname(ppre,MNKnam), TRnam),ldnam), ALPHAnam), BETAnam)

   void NBmm(const int, const int, const int, const SCALAR, const TYPE*, 
             const int, const TYPE*, const int, const SCALAR, TYPE*, const int);
#else
/*
 * Complex types: NBmm0 is the kernel name without its beta suffix.  The
 * kernels declared below take real (TYPE) alpha and beta; NBmm is a
 * function-like macro that issues four such calls with different A/B/C
 * offsets and beta values.  The macros refer to rone and rnone (and, for
 * split storage, inca, incb and possibly incc), which must be in scope at
 * the point of use.  Only the real parts *(alp_) and *(bet_) are passed on.
 */
   #define NBmm0 tname(tname(tname(tname(ppre,MNKnam), TRnam),ldnam), ALPHAnam)

   void tname(NBmm0,BETAnam)(const int, const int, const int, const TYPE,
                             const TYPE*, const int, const TYPE*, const int, 
                             const TYPE, TYPE*, const int);
   void tname(NBmm0,_bn1)(const int, const int, const int, const TYPE, 
                        const TYPE*, const int, const TYPE*, const int, 
                        const TYPE, TYPE*, const int);
   void tname(NBmm0,_b1)(const int, const int, const int, const TYPE, 
                       const TYPE*, const int, const TYPE*, const int, 
                       const TYPE, TYPE*, const int);
/*
 * csA == 1 && csB == 1: the second plane of A and B is reached through the
 * offsets inca / incb.  With csC == 2 the second component of C is at
 * (c_)+1; with csC == 1 it is at (c_)+incc.
 * If csA/csB/csC match none of the combinations handled in this chain, no
 * NBmm macro is defined.
 */
#if csA == 1 && csB == 1
   #if csC == 2
      #define NBmm(m_, n_, k_, alp_, a_, lda_, b_, ldb_, bet_, c_, ldc_) \
      { \
         tname(NBmm0,NBETAnam)(m_, n_, k_, *(alp_), (a_), lda_, (b_), \
                               ldb_, -(*(bet_)), c_, ldc_); \
         tname(NBmm0,BETAnam)(m_, n_, k_, *(alp_), a_, lda_, (b_)+incb, ldb_, \
                              *(bet_), (c_)+1, ldc_); \
         tname(NBmm0,_bn1)(m_, n_, k_, *(alp_), (a_)+inca, lda_, (b_)+incb, \
                           ldb_, rnone, c_, ldc_); \
         tname(NBmm0,_b1)(m_, n_, k_, *(alp_), (a_)+inca, lda_, (b_), ldb_, \
                          rone, (c_)+1, ldc_); \
      }
   #elif csC == 1
      #define NBmm(m_, n_, k_, alp_, a_, lda_, b_, ldb_, bet_, c_, ldc_) \
      { \
         tname(NBmm0,NBETAnam)(m_, n_, k_, *(alp_), (a_), lda_, (b_), \
                               ldb_, -(*(bet_)), c_, ldc_); \
         tname(NBmm0,BETAnam)(m_, n_, k_, *(alp_), a_, lda_, (b_)+incb, ldb_, \
                              *(bet_), (c_)+incc, ldc_); \
         tname(NBmm0,_bn1)(m_, n_, k_, *(alp_), (a_)+inca, lda_, (b_)+incb, \
                           ldb_, rnone, c_, ldc_); \
         tname(NBmm0,_b1)(m_, n_, k_, *(alp_), (a_)+inca, lda_, (b_), ldb_, \
                          rone, (c_)+incc, ldc_); \
      }
   #endif
/*
 * All three matrices interleaved: the second component of each element is
 * reached with a +1 offset on the base pointer.
 */
#elif csA == 2 && csB == 2 && csC == 2
   #define NBmm(m_, n_, k_, alp_, a_, lda_, b_, ldb_, bet_, c_, ldc_) \
   { \
      tname(NBmm0,NBETAnam)(m_, n_, k_, *(alp_), (a_)+1, lda_, (b_)+1, ldb_, \
                            -(*(bet_)), c_, ldc_); \
      tname(NBmm0,BETAnam)(m_, n_, k_, *(alp_), (a_)+1, lda_, b_, ldb_, \
                           *(bet_), (c_)+1, ldc_); \
      tname(NBmm0,_bn1)(m_, n_, k_, *(alp_), a_, lda_, b_, ldb_, \
                        rnone, c_, ldc_); \
      tname(NBmm0,_b1)(m_, n_, k_, *(alp_), a_, lda_, (b_)+1, ldb_, \
                       rone, (c_)+1, ldc_); \
   }
#endif
/*
 * Disabled alternative: route NBmm to an external zgemm_ instead of the
 * generated kernels.
 */
#if 0
   #undef NBmm
   #define NBmm(m_, n_, k_, alp_, a_, lda_, b_, ldb_, bet_, c_, ldc_) \
   { \
      zgemm_("T", "N", &(m_), &(n_), &(k_), alp_, a_, &(lda_), b_, \
             &(ldb_), bet_, c_, &(ldc_)); \
   }
#endif
#endif

/*
 * Reference matrix multiply used to validate the kernel under test.
 * Computes C <- alpha * op(A) * op(B) + beta * C with plain triple loops;
 * op() is fixed at compile time by NoTransA/TransA/ConjTransA and
 * NoTransB/TransB/ConjTransB.
 *
 * M, N, K    : op(A) is M x K, op(B) is K x N, C is M x N.
 * alpha, beta: scalars -- passed by value for real types, and as a pointer
 *              to a (real, imaginary) pair for complex types.
 * A, B       : column-major inputs.  For complex types built with
 *              csA == 1 && csB == 1 they are read as two separate planes
 *              and interleaved into temporary copies before use.
 * lda0, ldb0, ldc0 : leading dimensions in matrix elements (doubled
 *              internally for complex types through SHIFT).
 * C          : updated in place; for complex types it is always indexed as
 *              interleaved (real, imaginary) pairs.
 */
void tst_mm(const int M, const int N, const int K, const SCALAR alpha, 
            const TYPE *A, const int lda0, const TYPE *B, const int ldb0,
            const SCALAR beta, TYPE *C, const int ldc0)
{
   int i, j, k;
   const int lda = lda0 SHIFT, ldb = ldb0 SHIFT, ldc = ldc0 SHIFT;
   #ifdef TREAL
      register TYPE c0;

      for (j=0; j < N; j++)
      {
         for (i=0; i < M; i++)
         {
            c0 = 0.0;
            for (k=0; k < K; k++)
            {
/*
 *             For real data a conjugate transpose is just a transpose, so
 *             only NoTrans vs. not-NoTrans has to be distinguished
 */
               #if defined(NoTransA) && defined(NoTransB)
                  c0 += A[i+k*lda] * B[j*ldb+k];
               #elif defined(NoTransA)
                  c0 += A[i+k*lda] * B[j+k*ldb];
               #elif defined(NoTransB)
                  c0 += A[i*lda+k] * B[j*ldb+k];
               #else
                  c0 += A[i*lda+k] * B[j+k*ldb];
               #endif
            }
            C[i+j*ldc] = beta*C[i+j*ldc] + alpha*c0;
         }
      }
   #else
      register TYPE cr, ci, ar, ai, br, bi;
/*
 *    If matrix is stored split into real & imaginary parts, allocate some
 *    matrices and intermix them for f77-like imaginary matrices
 */
      #if csA == 1 && csB == 1
/*
 *       Number of stored columns of A and B
 */
         #ifdef NoTransA
            const int ncA = K;
         #else
            const int ncA = M;
         #endif
         #ifdef NoTransB
            const int ncB = N;
         #else
            const int ncB = K;
         #endif
/*
 *       lda2/ldb2: column stride inside one plane; incA/incB: offset of the
 *       second plane, whose values go into the even (real) slots below
 */
         const int lda2 = lda/2, ldb2 = ldb/2;
         const int incA = ncA*lda2, incB = ncB*ldb2;
         TYPE *aa, *bb;
/*
 *       The copies are indexed with the full leading dimension, so they must
 *       hold ld * ncols values, not merely 2*rows*cols
 */
         aa = malloc((size_t)ncA * lda * sizeof(TYPE));
         bb = malloc((size_t)ncB * ldb * sizeof(TYPE));
         assert(aa && bb);
         #ifdef NoTransA
            for (k=0; k < K; k++)
            {
               for (i=0; i < M; i++) 
               {
                  aa[k*lda+2*i]   = A[incA+k*lda2+i];
                  aa[k*lda+2*i+1] = A[k*lda2+i];
               }
            }
         #else
            for (i=0; i < M; i++) 
            {
               for (k=0; k < K; k++)
               {
                  aa[i*lda+2*k]   = A[incA+i*lda2+k];
                  aa[i*lda+2*k+1] = A[i*lda2+k];
               }
            }
         #endif
/*
 *       B must be laid out with its own leading dimension, since the
 *       multiply loop below reads it with stride ldb
 */
         #ifdef NoTransB
            for (j=0; j < N; j++) 
            {
               for (k=0; k < K; k++)
               {
                  bb[j*ldb+2*k]   = B[incB+j*ldb2+k];
                  bb[j*ldb+2*k+1] = B[j*ldb2+k];
               }
            }
         #else
            for (k=0; k < K; k++)
            {
               for (j=0; j < N; j++) 
               {
                  bb[k*ldb+2*j]   = B[incB+k*ldb2+j];
                  bb[k*ldb+2*j+1] = B[k*ldb2+j];
               }
            }
         #endif
         A = (const TYPE *) aa;
         B = (const TYPE *) bb;
      #endif

      for (j=0; j < N; j++)
      {
         for (i=0; i < M; i++)
         {
            cr = ci = 0.0;
            for (k=0; k < K; k++)
            {
               #if defined(NoTransA) && defined(NoTransB)
                  ar = A[2*i+k*lda];
                  ai = A[2*i+k*lda+1];
                  br = B[j*ldb+2*k];
                  bi = B[j*ldb+2*k+1];
               #elif defined(NoTransA) && !defined(NoTransB)
                  ar = A[2*i+k*lda] ;
                  ai = A[2*i+k*lda+1];
                  br = B[2*j+k*ldb];
                  bi = B[2*j+k*ldb+1];
               #elif !defined(NoTransA) && defined(NoTransB)
                  ar = A[i*lda+k*2];
                  ai = A[i*lda+k*2+1];
                  br = B[j*ldb+k*2];
                  bi = B[j*ldb+k*2+1];
               #elif !defined(NoTransA) && !defined(NoTransB)
                  ar = A[i*lda+k*2];
                  ai = A[i*lda+k*2+1];
                  br = B[2*j+k*ldb];
                  bi = B[2*j+k*ldb+1];
               #endif
/*
 *             Conjugation only flips the sign of the imaginary part
 */
               #ifdef ConjTransA
                  ai = -ai;
               #endif
               #ifdef ConjTransB
                  bi = -bi;
               #endif
               cr += ar * br - ai * bi;
               ci += ar * bi + ai * br;
            }
/*
 *          Scale by alpha
 */
            ar = *alpha;
            ai = alpha[1];
            br = cr;
            bi = ci;
            cr =  br * ar;
            ci =  bi * ar;
            cr -= bi * ai;
            ci += br * ai;
/*
 *          Scale C by beta
 */
            br = *beta;
            bi = beta[1];
            ar = C[2*i+j*ldc];
            ai = C[2*i+j*ldc+1];
            C[2*i+j*ldc]   = ar*br - ai * bi;
            C[2*i+j*ldc+1] = ai*br + ar * bi;
/*
 *          Store answer back to C
 */
            C[2*i+j*ldc]   += cr;
            C[2*i+j*ldc+1] += ci;

         }
      }
      #if csA == 1 && csB == 1
         free(aa);
         free(bb);
      #endif
   #endif
}
/*
 * Runs one comparison of the kernel under test (NBmm) against the reference
 * tst_mm on randomly filled operands.
 *
 * A single workspace holds, in order, the reference result C0, the kernel
 * result C1, A and B.  C0 and C1 start with identical contents; every value
 * within the M x N result whose two versions differ by more than ErrBound is
 * reported on stderr.
 *
 * Returns the number of mismatching values (0 means the kernel agrees with
 * the reference).
 */
int mmtst(void)
{
/*
 * Leading dimensions: use LDA/LDB/LDC when defined and nonzero, otherwise
 * twice the stored row count
 */
#if defined(LDA) && LDA != 0
      const int lda=LDA;
#else
      const int lda=2*LDA2;
#endif
#if defined(LDB) && LDB != 0
   const int ldb=LDB;
#else
   const int ldb=2*LDB2;
#endif
#if defined(LDC) && LDC != 0
   const int ldc=LDC;
#else
   const int ldc=2*MB;
#endif
   int nA, nB;   /* stored column counts of A and B */
   #ifdef TCPLX
/*
 *    inca/incb/incc are referenced by the split-storage NBmm macros
 */
      int inca, incb, incc;
      #if (ALPHA == 1)
         TYPE alpha[2] = {1.0, 0.0};
      #elif (ALPHA == -1)
         TYPE alpha[2] = {-1.0, 0.0};
      #else
         TYPE alpha[2] = {2.3, 0.0};
      #endif
      #if (BETA == 1)
         TYPE beta[2] = {1.0, 0.0};
      #elif (BETA == -1)
         TYPE beta[2] = {-1.0, 0.0};
      #elif (BETA == 0)
         TYPE beta[2] = {0.0, 0.0};
      #else
         TYPE beta[2] = {1.3, 0.0};
      #endif
   #else
      #ifdef ALPHA
         TYPE alpha=ALPHA;
      #else
         TYPE alpha=1.0;
      #endif
      #ifdef BETA
         TYPE beta=BETA;
      #else
         TYPE beta=1.0;
      #endif
   #endif
   const TYPE rone=1.0, rnone=(-1.0);   /* referenced by the complex NBmm macros */
   TYPE *C0, *C1, *A, *B;
   TYPE diff, tmp;
   int i, j, k, n, nerr;
   int M=MB, N=NB, K=KB;
   TYPE ErrBound;

/*
 * A zero block dimension falls back to the corresponding MB0/NB0/KB0
 */
   if (!M) M = MB0;
   if (!N) N = NB0;
   if (!K) K = KB0;
   #ifdef TREAL
      ErrBound = 2.0 * (Mabs(alpha) * 2.0*K*EPS + Mabs(beta) * EPS) + EPS;
   #else
      diff = Mabs(*alpha) + Mabs(alpha[1]);
      tmp = Mabs(*beta) + Mabs(beta[1]);
      ErrBound =  2.0 * (diff*8.0*K*EPS + tmp*EPS) + EPS;
   #endif
   #ifdef NoTransA
      nA = K;
   #else
      nA = M;
   #endif
   #ifdef NoTransB
      nB = N;
   #else
      nB = K;
   #endif
   #ifdef TCPLX
      inca = lda*nA;
      incb = ldb*nB;
/*
 *    Plane offset for C, by analogy with inca/incb, so the csC == 1 NBmm
 *    macro never reads an uninitialized value.
 *    NOTE(review): tst_mm always writes C interleaved, so comparing it with
 *    a split-storage C1 probably needs further work -- confirm.
 */
      incc = ldc*N;
   #endif
   C0 = malloc( (2*ldc*N + lda*nA + ldb*nB) * ATL_sizeof);
   assert(C0);
   C1 = C0 + (ldc * N SHIFT);
   A = C1 + (ldc * N SHIFT);
   B = A + (lda * nA SHIFT);
   for (n=lda*nA SHIFT, i=0; i < n; i++) A[i] = dumb_rand();
   for (n=ldb*nB SHIFT, i=0; i < n; i++) B[i] = dumb_rand();
   for (n=ldc*N SHIFT, i=0; i < n; i++) C0[i] = C1[i] = dumb_rand();
   tst_mm(M, N, K, alpha, A, lda, B, ldb, beta, C0, ldc);
   NBmm(M, N, K, alpha, A, lda, B, ldb, beta, C1, ldc);
   nerr = 0;
   for (j=0; j < N; j++)
   {
/*
 *    For complex types i runs over TYPE values (2*M per column), not over
 *    complex elements
 */
      for (i=0; i < M SHIFT; i++)
      {
         k = i + j*(ldc SHIFT);
         diff = C0[k] - C1[k];
         if (diff < 0.0) diff = -diff;
         if (diff > ErrBound) 
         {
/*
 *          %f expects double; the casts keep this valid when TYPE is
 *          long double
 */
            fprintf(stderr, "C(%d,%d) : expected=%f, got=%f\n", 
                    i, j, (double)C0[k], (double)C1[k]);
            nerr++;
         }
      }
   }
   free(C0);   /* C1, A and B all live inside this one allocation */
   return(nerr);
}

/*
 * Program entry point: runs mmtst once, prints "PASSED TEST" on stdout when
 * no mismatches were found, and reports the mismatch count as the exit
 * status.
 */
int main(void)
{
   const int ierr = mmtst();

   if (!ierr) fprintf(stdout, "PASSED TEST\n");
/*
 * On POSIX systems only the low 8 bits of the status reach the parent, so an
 * error count that is a multiple of 256 would otherwise look like success
 */
   return( ierr > 255 ? 255 : ierr );
}

