/*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/

#include <ctf.hpp>
#include <assert.h>
#include <stdlib.h>
  
int test_dft(int64_t const  n,
             cCTF_World    &wrld){
  int numPes, myRank;
  int64_t  np, i;
  int64_t * idx;
  std::complex<double> * data;
  std::complex<double> imag(0,1);
  MPI_Comm_size(MPI_COMM_WORLD, &numPes);
  MPI_Comm_rank(MPI_COMM_WORLD, &myRank);
  cCTF_Matrix DFT(n, n, SY, wrld);
  cCTF_Matrix IDFT(n, n, SY, wrld);

  DFT.get_local_data(&np, &idx, &data);

  for (i=0; i<np; i++){
    data[i] = exp(-2.*(idx[i]/n)*(idx[i]%n)*(M_PI/n)*imag);
  }
  DFT.write_remote_data(np, idx, data);
  //DFT.print(stdout);
  free(idx);
  free(data); 
  
  IDFT.get_local_data(&np, &idx, &data);

  for (i=0; i<np; i++){
    data[i] = (1./n)*exp(2.*(idx[i]/n)*(idx[i]%n)*(M_PI/n)*imag);
  }
  IDFT.write_remote_data(np, idx, data);
  //IDFT.print(stdout);
  free(idx);
  free(data); 

  /*DFT.contract(std::complex<double> (1.0, 0.0), DFT, "ij", IDFT, "jk", 
               std::complex<double> (0.0, 0.0), "ik");*/
  DFT["ik"] = DFT["ij"]*IDFT["jk"];

 
  DFT.get_local_data(&np, &idx, &data);
  int pass = 1;
  //DFT.print(stdout);
  for (i=0; i<np; i++){
    //printf("data[%lld] = %lf\n",idx[i],data[i].real());
    if (idx[i]/n == idx[i]%n){
      if (fabs(data[i].real() - 1.)>=1.E-9)
        pass = 0;
    } else  {
      if (fabs(data[i].real())>=1.E-9)
        pass = 0;
    }
  }
  MPI_Allreduce(MPI_IN_PLACE, &pass, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
  
  if (myRank == 0) {
    MPI_Reduce(MPI_IN_PLACE, &pass, 1, MPI_INT, MPI_MIN, 0, MPI_COMM_WORLD);
    if (pass)
      printf("{ DFT[\"ik\"] = DFT[\"ij\"]*IDFT[\"jk\"] } passed\n");
    else
      printf("{ DFT[\"ik\"] = DFT[\"ij\"]*IDFT[\"jk\"] } failed\n");
  } else 
    MPI_Reduce(&pass, MPI_IN_PLACE, 1, MPI_INT, MPI_MIN, 0, MPI_COMM_WORLD);

  MPI_Barrier(MPI_COMM_WORLD);

  free(idx);
  free(data);
  return pass;
}

#ifndef TEST_SUITE
/**
 * \brief Forms N-by-N DFT matrix A and inverse-dft iA and checks A*iA=I
 */
int main(int argc, char ** argv){
  int myRank, numPes, logn;
  int64_t n;
  
  MPI_Init(&argc, &argv);

  if (argc > 1){
    logn = atoi(argv[1]);
    if (logn<0) logn = 5;
  } else {
    logn = 5;
  }
  n = 1<<logn;


  {
    cCTF_World dw(argc, argv);
    int pass = test_dft(n, dw);
    assert(pass);
  }

  MPI_Finalize();
  
}

#endif