#ifndef __FUNCTIONS_H__
#define __FUNCTIONS_H__

#include <cassert>
#include <functional>
#include <utility>

#include "../scaling/scaling.h"
#include "../summation/summation.h"
#include "../contraction/contraction.h"


namespace CTF {

/**
 * @defgroup CTF_func CTF functions
 * \brief user-defined function interface
 * @addtogroup CTF_func
 * @{
 */
  class Idx_Tensor;

  /**
   * \brief custom scalar function on tensor: e.g. A["ij"] = f(A["ij"])
   *
   * Wraps a callable that updates a single element of type dtype in place.
   */
  template<typename dtype=double>
  class Endomorphism : public CTF_int::endomorphism {
    public:
      /**
       * \brief element-wise operation, applied in place as f(a)
       */
      std::function<void(dtype&)> f;

      /**
       * \brief constructor takes the element-wise function
       * \param[in] f_ scalar function that modifies its (type&) argument in place
       */
      Endomorphism(std::function<void(dtype&)> f_) : f(std::move(f_)) {}

      /**
       * \brief default constructor; leaves f empty, so apply_f would throw
       *        std::bad_function_call until f is assigned
       */
      Endomorphism(){}

      /**
       * \brief apply function f to value stored at a
       * \param[in,out] a pointer to operand that will be cast to dtype;
       *                  the value at a is replaced by the result of f
       */
      void apply_f(char * a) const { f(reinterpret_cast<dtype*>(a)[0]); }
  };


  /**
   * \brief custom function f : X -> Y to be applied to tensor elements:
   *          e.g. B["ij"] = f(A["ij"])
   */
  template<typename dtype_A=double, typename dtype_B=dtype_A>
  class Univar_Function : public CTF_int::univar_function {
    public:
      /**
       * \brief element-wise function, computes b=f(a)
       */
      std::function<dtype_B(dtype_A)> f;

      /**
       * \brief constructor takes the function used to compute B=f(A)
       * \param[in] f_ function (type_A)->(type_B)
       */
      Univar_Function(std::function<dtype_B(dtype_A)> f_) : f(std::move(f_)) {}

      /**
       * \brief evaluate B=f(A)
       * \param[in] A operand tensor with pre-defined indices
       * return f(A) output tensor
       */
      //Idx_Tensor operator()(Idx_Tensor const  & A);

      /**
       * \brief compute b = f(a)
       * \param[in] a pointer to operand that will be cast to dtype_A (read only)
       * \param[out] b pointer to a dtype_B that is overwritten with f(*a)
       */
      void apply_f(char const * a, char * b) const {
        // a is read through a const pointer; the original cast const away
        reinterpret_cast<dtype_B*>(b)[0] = f(reinterpret_cast<dtype_A const*>(a)[0]);
      }

      /**
       * \brief compute b = b+f(a)
       * \param[in] a pointer to operand that will be cast to dtype_A (read only)
       * \param[in,out] b pointer to the dtype_B accumulator
       * \param[in] sr_B algebraic structure for b, needed to do add
       */
      void acc_f(char const * a, char * b, CTF_int::algstrct const * sr_B) const {
        dtype_B tb = f(reinterpret_cast<dtype_A const*>(a)[0]);
        sr_B->add(b, reinterpret_cast<char const *>(&tb), b);
      }

  };


  /**
   * \brief custom function f : (X * Y) -> Y applied on two tensors as summation,
   *          updating the output element in place: e.g. B["ij"] = f(A["ij"],B["ij"])
   */
  template<typename dtype_A=double, typename dtype_B=dtype_A>
  class Univar_Transform : public CTF_int::univar_function {
    public:
      /**
       * \brief element-wise transform: reads a by value and modifies b in place
       */
      std::function<void(dtype_A, dtype_B &)> f;

      /**
       * \brief constructor takes the in-place transform f(a, b)
       * \param[in] f_ function (type_A, type_B&) that updates its second argument
       */
      Univar_Transform(std::function<void(dtype_A, dtype_B &)> f_) : f(std::move(f_)) {}

      /**
       * \brief evaluate B=f(A)
       * \param[in] A operand tensor with pre-defined indices
       * return f(A) output tensor
       */
      //Idx_Tensor operator()(Idx_Tensor const  & A);

      /**
       * \brief apply f to the values at a and b; for an accumulator this is the same as acc_f
       * \param[in] a pointer to operand that will be cast to dtype_A (read only)
       * \param[in,out] b pointer to the dtype_B value that f updates in place
       */
      void apply_f(char const * a, char * b) const { acc_f(a, b, nullptr); }

      /**
       * \brief compute f(a,b), updating *b in place
       * \param[in] a pointer to the accumulated operand (read only)
       * \param[in,out] b value that is accumulated to
       * \param[in] sr_B algebraic structure for b, here is ignored
       */
      void acc_f(char const * a, char * b, CTF_int::algstrct const * sr_B) const {
        f(reinterpret_cast<dtype_A const*>(a)[0], reinterpret_cast<dtype_B*>(b)[0]);
      }

      bool is_accumulator() const { return true; }
  };


  /**
   * \brief custom bilinear function on two tensors:
   *          e.g. C["ij"] = f(A["ik"],B["kj"])
   */
  template<typename dtype_A=double, typename dtype_B=dtype_A, typename dtype_C=dtype_A>
  class Bivar_Function : public CTF_int::bivar_function {
    public:
      /**
       * \brief element-wise function, computes c=f(a,b)
       */
      std::function<dtype_C (dtype_A, dtype_B)> f;

      /**
       * \brief constructor takes the function used to compute C=f(A,B)
       * \param[in] f_ function (type_A,type_B)->(type_C)
       */
      Bivar_Function(std::function<dtype_C (dtype_A, dtype_B)> f_) : f(std::move(f_)) {}

      /**
       * \brief default constructor; f is expected to be left empty
       * NOTE(review): only declared here, no definition is visible in this
       * header -- confirm one exists before relying on it
       */
      Bivar_Function();

      /**
       * \brief evaluate C=f(A,B)
       * \param[in] A left operand tensor with pre-defined indices
       * \param[in] B right operand tensor with pre-defined indices
       * \return C output tensor
      */
      //Idx_Tensor operator()(Idx_Tensor const  & A,
      //                      Idx_Tensor const  & B);

      /**
       * \brief compute c = f(a,b)
       * \param[in] a pointer to operand that will be cast to dtype_A
       * \param[in] b pointer to operand that will be cast to dtype_B
       * \param[out] c pointer to a dtype_C that is overwritten with f(*a,*b)
       */
      void apply_f(char const * a, char const * b, char * c) const {
        reinterpret_cast<dtype_C*>(c)[0] =
          f(reinterpret_cast<dtype_A const*>(a)[0], reinterpret_cast<dtype_B const*>(b)[0]);
      }

      /**
       * \brief compute c = c+f(a,b)
       * \param[in] a pointer to operand that will be cast to dtype_A
       * \param[in] b pointer to operand that will be cast to dtype_B
       * \param[in,out] c pointer to the dtype_C accumulator
       * \param[in] sr_C algebraic structure for c, needed to do add
       */
      void acc_f(char const * a, char const * b, char * c, CTF_int::algstrct const * sr_C) const {
        // initialize directly so dtype_C need not be default-constructible
        dtype_C tmp = f(reinterpret_cast<dtype_A const*>(a)[0], reinterpret_cast<dtype_B const*>(b)[0]);
        sr_C->add(c, reinterpret_cast<char const *>(&tmp), c);
      }


  };

  /**
   * \brief custom function f : (X * Y * Z) -> Z applied on three tensors,
   *          updating the output element in place: e.g. C["ij"] = f(A["ik"],B["kj"],C["ij"])
   */
  template<typename dtype_A=double, typename dtype_B=dtype_A, typename dtype_C=dtype_A>
  class Bivar_Transform : public CTF_int::bivar_function {
    public:
      /**
       * \brief element-wise transform: reads a and b by value and modifies c in place
       */
      std::function<void(dtype_A, dtype_B, dtype_C &)> f;

      /**
       * \brief constructor takes the in-place transform f(a, b, c)
       * \param[in] f_ function (type_A, type_B, type_C&) that updates its third argument
       */
      Bivar_Transform(std::function<void(dtype_A, dtype_B, dtype_C &)> f_) : f(std::move(f_)) {}

      /**
       * \brief evaluate B=f(A)
       * \param[in] A operand tensor with pre-defined indices
       * return f(A) output tensor
       */
      //Idx_Tensor operator()(Idx_Tensor const  & A);

      /**
       * \brief compute f(a,b,c), updating *c in place
       * \param[in] a pointer to the first operand (read only)
       * \param[in] b pointer to the second operand (read only)
       * \param[in,out] c value that is accumulated to
       * \param[in] sr_C algebraic structure for c, here is ignored
       */
      void acc_f(char const * a, char const * b, char * c, CTF_int::algstrct const * sr_C) const {
        // a and b are read through const pointers; the original cast const away
        f(reinterpret_cast<dtype_A const*>(a)[0],
          reinterpret_cast<dtype_B const*>(b)[0],
          reinterpret_cast<dtype_C*>(c)[0]);
      }

      /**
       * \brief apply f to the values at a, b and c; for an accumulator this is the same as acc_f
       * \param[in] a pointer to operand that will be cast to dtype_A
       * \param[in] b pointer to operand that will be cast to dtype_B
       * \param[in,out] c pointer to the dtype_C value that f updates in place
       */
      void apply_f(char const * a, char const * b, char * c) const { acc_f(a, b, c, nullptr); }


      bool is_accumulator() const { return true; }
  };




  template<typename dtype_A=double, typename dtype_B=dtype_A, typename dtype_C=dtype_A>
  class Function {
    public:
      bool is_univar;
      Univar_Function<dtype_A, dtype_B> * univar;
      bool is_bivar;
      Bivar_Function<dtype_A, dtype_B, dtype_C> * bivar;

      Function(std::function<dtype_B(dtype_A)> f_){
        is_univar = true;
        is_bivar = false;
        univar = new Univar_Function<dtype_A, dtype_B>(f_);
      }

 
      Function(std::function<dtype_C(dtype_A,dtype_B)> f_){
        is_univar = false;
        is_bivar = true;
        bivar = new Bivar_Function<dtype_A, dtype_B, dtype_C>(f_);
      }

      CTF_int::Unifun_Term operator()(CTF_int::Term const & A) const {
        assert(is_univar);
        return univar->operator()(A);
      }
 
      CTF_int::Bifun_Term operator()(CTF_int::Term const & A, CTF_int::Term const & B) const {
        assert(is_bivar);
        return bivar->operator()(A,B);
      }
      
      operator Univar_Function<dtype_A, dtype_B>() const {
        assert(is_univar);
        return *univar;
      }
      
      operator Bivar_Function<dtype_A, dtype_B, dtype_C>() const {
        assert(is_bivar);
        return *bivar;
      }

      ~Function(){
        if (is_univar) delete(univar);
        if (is_bivar) delete(bivar);
      }
  };

  
  template<typename dtype_A=double, typename dtype_B=dtype_A, typename dtype_C=dtype_A>
  class Transform {
    public:
      bool is_endo;
      Endomorphism<dtype_A> * endo;
      bool is_univar;
      Univar_Transform<dtype_A, dtype_B> * univar;
      bool is_bivar;
      Bivar_Transform<dtype_A, dtype_B, dtype_C> * bivar;

      Transform(std::function<void(dtype_A&)> f_){
        is_endo = true;
        is_univar = false;
        is_bivar = false;
        endo = new Endomorphism<dtype_A>(f_);
      }
      
      Transform(std::function<void(dtype_A, dtype_B&)> f_){
        is_endo = false;
        is_univar = true;
        is_bivar = false;
        univar = new Univar_Transform<dtype_A, dtype_B>(f_);
      }
      
      Transform(std::function<void(dtype_A, dtype_B, dtype_C&)> f_){
        is_endo = false;
        is_univar = false;
        is_bivar = true;
        bivar = new Bivar_Transform<dtype_A, dtype_B, dtype_C>(f_);
      }


      ~Transform(){
        if (is_endo) delete endo;
        if (is_univar) delete univar;
        if (is_bivar) delete bivar;
      }

      void operator()(CTF_int::Term const & A) const {
        assert(is_endo);
        endo->operator()(A);
      }
 
      void operator()(CTF_int::Term const & A, CTF_int::Term const & B) const {
        assert(is_univar);
        univar->operator()(A,B);
      }
 
      void operator()(CTF_int::Term const & A, CTF_int::Term const & B, CTF_int::Term const & C) const {
        assert(is_bivar);
        bivar->operator()(A,B,C);
      }
      
      operator Bivar_Transform<dtype_A, dtype_B, dtype_C>(){
        assert(is_bivar);
        return *bivar;
      }

      operator Univar_Transform<dtype_A, dtype_B>(){
        assert(is_univar);
        return *univar;
      }
      
      operator Endomorphism<dtype_A>(){
        assert(is_endo);
        return *endo;
      }
      
      bool is_accumulator() const { return true; }
  };
  

/**
 * @}
 */
}

#endif