#ifndef __MONOID_H__
#define __MONOID_H__

namespace CTF {
  /**
   * \brief default addition function, defined through the type's '+' operator
   * \param[in] a left operand
   * \param[in] b right operand
   * \return the value of a+b converted to dtype
   */
  template <typename dtype>
  dtype default_add(dtype a, dtype b)
  {
    return a + b;
  }
  
  template <typename dtype, void (*fxpy)(int, dtype const *, dtype *)>
  void default_mxpy(void * X,
                    void * Y,
                    int    n){
    fxpy(n, (dtype const*)X, (dtype *)Y);
  }

  /**
   * \brief elementwise accumulate Y[i] = X[i] + Y[i] for i in [0,n)
   * \param[in] n number of elements; nothing is done when n <= 0
   * \param[in] X array of left operands
   * \param[in,out] Y array of right operands, overwritten with the sums
   */
  template <typename dtype>
  void default_fxpy(int           n,
                    dtype const * X,
                    dtype *       Y){
    dtype const * src = X;
    dtype *       dst = Y;
    // X's element stays on the left of '+', in case the operator does not commute
    for (int remaining = n; remaining > 0; --remaining, ++src, ++dst){
      *dst = *src + *dst;
    }
  }
  /**
   * \brief builds an MPI datatype for an arbitrary dtype as sizeof(dtype) contiguous MPI_CHARs
   *
   * Every call creates and commits a new derived datatype; this function does
   * not free it, so releasing the handle is left to the caller.
   *
   * \return committed MPI datatype handle
   */
  template <typename dtype>
  MPI_Datatype get_default_mdtype(){
    MPI_Datatype newtype;
    MPI_Type_contiguous(sizeof(dtype), MPI_CHAR, &newtype);
    // a derived datatype has to be committed before it can be used in communication
    MPI_Type_commit(&newtype);
    return newtype;
  }
  // Explicit specializations mapping built-in types to predefined MPI datatypes.
  // They are marked inline because a full specialization of a function template
  // is an ordinary function definition: without inline, including this header
  // from more than one translation unit yields multiple-definition link errors.
  template <>
  inline MPI_Datatype get_default_mdtype<char>(){ return MPI_CHAR; }
  template <>
  inline MPI_Datatype get_default_mdtype<bool>(){ return MPI_C_BOOL; }
  template <>
  inline MPI_Datatype get_default_mdtype<int>(){ return MPI_INT; }
  template <>
  inline MPI_Datatype get_default_mdtype<int64_t>(){ return MPI_INT64_T; }
  template <>
  inline MPI_Datatype get_default_mdtype<unsigned int>(){ return MPI_UNSIGNED; }
  template <>
  inline MPI_Datatype get_default_mdtype<uint64_t>(){ return MPI_UINT64_T; }
  template <>
  inline MPI_Datatype get_default_mdtype<float>(){ return MPI_FLOAT; }
  template <>
  inline MPI_Datatype get_default_mdtype<double>(){ return MPI_DOUBLE; }
  template <>
  inline MPI_Datatype get_default_mdtype<long double>(){ return MPI_LONG_DOUBLE; }
  template <>
  inline MPI_Datatype get_default_mdtype< std::complex<float> >(){ return MPI_COMPLEX; }
  template <>
  inline MPI_Datatype get_default_mdtype< std::complex<double> >(){ return MPI_DOUBLE_COMPLEX; }

  /**
   * \brief creates a user-defined MPI reduction operation that accumulates
   *   dtype elements with default_fxpy<dtype> (i.e. the type's '+' operator)
   * \return handle of the newly created MPI operation
   */
  template <typename dtype>
  MPI_Op get_default_maddop(){
    //FIXME: assumes + operator commutes
    int const commutes = 1;
    MPI_Op addop;
    MPI_Op_create(&default_mxpy<dtype,default_fxpy<dtype>>, commutes, &addop);
    return addop;
  }

  // Explicit specializations: built-in types use the predefined MPI_SUM rather
  // than a user-defined operation. A function template cannot be partially
  // specialized, so every type is listed individually. They are marked inline
  // because a full specialization is an ordinary function definition and would
  // otherwise be multiply defined when this header is included from several
  // translation units.
  // NOTE(review): the MPI standard does not list MPI_CHAR or MPI_C_BOOL (the
  // datatypes get_default_mdtype returns for char and bool) among the types
  // valid with MPI_SUM -- confirm these two work on the targeted MPI libraries.
  template <> inline MPI_Op get_default_maddop<char>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<bool>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<int>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<int64_t>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<unsigned int>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<uint64_t>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<float>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<double>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop<long double>(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop< std::complex<float> >(){ return MPI_SUM; }
  template <> inline MPI_Op get_default_maddop< std::complex<double> >(){ return MPI_SUM; }
  
  /**
   * \brief creates an MPI reduction operation from a custom accumulate kernel
   * \param[in] fxpy kernel with signature (int n, dtype const * X, dtype * Y)
   * \return handle of the newly created MPI operation
   *
   * NOTE(review): fxpy is a run-time function parameter, but it is used below
   * as a non-type template argument of default_mxpy, which requires a constant
   * expression. This template therefore cannot be instantiated as written;
   * making fxpy a template parameter of get_maddop would resolve it, at the
   * cost of changing the call syntax -- confirm with callers before changing.
   */
  template <typename dtype>
  MPI_Op get_maddop(void (*fxpy)(int, dtype const *, dtype *)){
    //FIXME: assumes + operator commutes
    MPI_Op newop;
    // second argument 1 declares the operation commutative to MPI
    MPI_Op_create(&default_mxpy<dtype, fxpy>, 1, &newop);
    return newop;
  }

  /**
   * \brief Monoid: a Set equipped with a binary addition, given either by the
   *   '+' operator or by a custom function.
   *   The addition must be associative and have an identity element; it does
   *   not need to be commutative. Monoid is a special case (parent) of a
   *   semiring, group, and ring.
   */
  template <typename dtype=double, bool is_ord=true>
  class Monoid : public Set<dtype, is_ord> {
    public:
      /** \brief additive identity element */
      dtype        taddid;
      /** \brief addition function applied by add() */
      dtype      (*fadd)(dtype a, dtype b);
      /** \brief MPI datatype handle returned by mdtype() */
      MPI_Datatype tmdtype;
      /** \brief MPI operation handle returned by addmop() */
      MPI_Op       taddmop;

      /**
       * \brief default constructor: identity (dtype)0, addition default_add,
       *   and the default MPI operation and datatype for dtype
       */
      Monoid() : Set<dtype, is_ord>() {
        this->taddid  = (dtype)0;
        this->fadd    = &default_add<dtype>;
        this->taddmop = get_default_maddop<dtype>();
        this->tmdtype = get_default_mdtype<dtype>();
      }

      /**
       * \brief constructor with a user-defined addition
       * \param[in] taddid_ additive identity element
       * \param[in] fadd_ addition function
       * \param[in] addmop_ MPI operation handle to report through addmop()
       */
      Monoid(dtype taddid_,
             dtype (*fadd_)(dtype a, dtype b),
             MPI_Op addmop_)
              : Set<dtype, is_ord>() {
        this->taddid  = taddid_;
        this->fadd    = fadd_;
        this->taddmop = addmop_;
        this->tmdtype = get_default_mdtype<dtype>();
      }

      /**
       * \brief copy constructor; the MPI handles are copied as plain values
       * \param[in] other monoid to copy from
       */
      Monoid(Monoid const & other) : Set<dtype, is_ord>(other) {
        taddid  = other.taddid;
        fadd    = other.fadd;
        tmdtype = other.tmdtype;
        taddmop = other.taddmop;
      }

      /**
       * \brief allocates a copy of this monoid with new
       * \return pointer to the heap-allocated copy
       */
      virtual CTF_int::algstrct * clone() const {
        Monoid * copy = new Monoid(*this);
        return copy;
      }

      /**
       * \brief c = fadd(a, b) on single elements passed as raw bytes
       * \param[in] a pointer to the left operand, read as one dtype
       * \param[in] b pointer to the right operand, read as one dtype
       * \param[out] c pointer to where the result is stored as one dtype
       */
      void add(char const * a,
               char const * b,
               char *       c) const {
        dtype const & lhs = *(dtype const *)a;
        dtype const & rhs = *(dtype const *)b;
        dtype *       out = (dtype *)c;
        out[0] = fadd(lhs, rhs);
      }

      /** \brief pointer to the bytes of the additive identity stored in this object */
      char const * addid() const {
        return (char const *)&taddid;
      }

      /** \brief MPI operation handle held by this monoid */
      MPI_Op addmop() const {
        return taddmop;
      }

      /** \brief MPI datatype handle held by this monoid */
      MPI_Datatype mdtype() const {
        return tmdtype;
      }

      /**
       * \brief strided accumulate: Y[i*incY] = add(X[i*incX], Y[i*incY]) for i in [0,n)
       * \param[in] n number of elements
       * \param[in] alpha scaling factor; asserted to be NULL
       * \param[in] X raw bytes of the input elements
       * \param[in] incX stride between X elements, in units of dtype
       * \param[in,out] Y raw bytes of the accumulated elements
       * \param[in] incY stride between Y elements, in units of dtype
       */
      void axpy(int          n,
                char const * alpha,
                char const * X,
                int          incX,
                char       * Y,
                int          incY) const {
        ASSERT(alpha == NULL);
        int64_t i = 0;
        while (i < n){
          char const * x_i = X+sizeof(dtype)*i*incX;
          char *       y_i = Y+sizeof(dtype)*i*incY;
          add(x_i, y_i, y_i);
          i++;
        }
      }

  };
}

#include "group.h"
#endif