#ifndef NGCORE_MPIWRAPPER_HPP
#define NGCORE_MPIWRAPPER_HPP

#ifdef PARALLEL
#define OMPI_SKIP_MPICXX
#include <mpi.h>
#endif

#include "exception.hpp"

namespace ngcore
{

#ifdef PARALLEL

  /**
     Compile-time map from a C++ type T to the MPI datatype handle used to
     transmit it.  The primary template is deliberately empty: a type without
     a specialization has no MPIType() member, so any expression of the form
     MPI_typetrait<T>::MPIType() is ill-formed for it.
  */
  template <class T> struct MPI_typetrait  { };

  template <> struct MPI_typetrait<int> {
    static MPI_Datatype MPIType () { return MPI_INT; } };

  template <> struct MPI_typetrait<short> {
    static MPI_Datatype MPIType () { return MPI_SHORT; } };

  template <> struct MPI_typetrait<char> {
    static MPI_Datatype MPIType () { return MPI_CHAR; } };

  // unsigned char has its own MPI datatype.  Unlike MPI_CHAR (a character
  // type), MPI_UNSIGNED_CHAR belongs to the integer group and is therefore
  // accepted by the arithmetic reduction operators (MPI_SUM, MPI_MIN, ...).
  template <> struct MPI_typetrait<unsigned char> {
    static MPI_Datatype MPIType () { return MPI_UNSIGNED_CHAR; } };

  // NOTE(review): MPI_UINT64_T assumes size_t is exactly 64 bits wide;
  // confirm before building for a target where that does not hold.
  template <> struct MPI_typetrait<size_t> {
    static MPI_Datatype MPIType () { return MPI_UINT64_T; } };

  template <> struct MPI_typetrait<double> {
    static MPI_Datatype MPIType () { return MPI_DOUBLE; } };

  template <> struct MPI_typetrait<bool> {
    static MPI_Datatype MPIType () { return MPI_C_BOOL; } };

  
  /**
     Returns the MPI datatype handle registered for T in MPI_typetrait.
     The defaulted second template parameter only names the type of
     MPI_typetrait<T>::MPIType(); for a T without a specialization that
     expression is invalid and this overload is not viable.
  */
  template <class T, class T2 = decltype(MPI_typetrait<T>::MPIType())>
  inline MPI_Datatype GetMPIType ()
  {
    const MPI_Datatype mpi_type = MPI_typetrait<T>::MPIType();
    return mpi_type;
  }
  

  class NgMPI_Comm
  {
  protected:
    MPI_Comm comm;
    bool valid_comm;
    int * refcount;
    int rank, size;
  public:
    NgMPI_Comm ()
      : valid_comm(false), refcount(nullptr), rank(0), size(1)
    { ; }

    NgMPI_Comm (MPI_Comm _comm, bool owns = false)
      : comm(_comm), valid_comm(true)
    {
      if (!owns)
        refcount = nullptr;
      else
        refcount = new int{1};
      
      MPI_Comm_rank(comm, &rank);
      MPI_Comm_size(comm, &size);
    }
    
    NgMPI_Comm (const NgMPI_Comm & c)
      : comm(c.comm), valid_comm(c.valid_comm), refcount(c.refcount),
        rank(c.rank), size(c.size)
    {
      if (refcount) (*refcount)++;
    }

    NgMPI_Comm (NgMPI_Comm && c)
      : comm(c.comm), valid_comm(c.valid_comm), refcount(c.refcount),
        rank(c.rank), size(c.size)
    {
      c.refcount = nullptr;
    }
    
    ~NgMPI_Comm()
    {
      if (refcount)
        if (--(*refcount) == 0)
          MPI_Comm_free(&comm);
    }

    NgMPI_Comm & operator= (const NgMPI_Comm & c)
    {
      if (refcount)
        if (--(*refcount) == 0)
          MPI_Comm_free(&comm);

      refcount = c.refcount;
      if (refcount) (*refcount)++;      
      comm = c.comm;
      valid_comm = c.valid_comm;
      size = c.size;
      rank = c.rank;
      return *this;
    }
    
    class InvalidCommException : public Exception {
    public:
      InvalidCommException() : Exception("Do not have a valid communicator") { ; }
    };
    
    operator MPI_Comm() const {
      if (!valid_comm) throw InvalidCommException();
      return comm;
    }

    int Rank() const { return rank; }
    int Size() const { return size; }
    void Barrier() const {
      if (size > 1) MPI_Barrier (comm);
    }
    

    /** --- blocking P2P --- **/

    template<typename T, typename T2 = decltype(GetMPIType<T>())>
    void Send (T & val, int dest, int tag) const {
      MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
    }
    
    template<typename T, typename T2 = decltype(GetMPIType<T>())> 
    void Recv (T & val, int src, int tag) const {
      MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
    }


    /** --- non-blocking P2P --- **/
    
    template<typename T, typename T2 = decltype(GetMPIType<T>())> 
    MPI_Request ISend (T & val, int dest, int tag) const
    {
      MPI_Request request;
      MPI_Isend (&val, 1, GetMPIType<T>(), dest, tag, comm, &request);
      return request;
    }
    
    template<typename T, typename T2 = decltype(GetMPIType<T>())> 
    MPI_Request IRecv (T & val, int dest, int tag) const
    {
      MPI_Request request;
      MPI_Irecv (&val, 1, GetMPIType<T>(), dest, tag, comm, &request);
      return request;
    }
    
    /** --- collectives --- **/

    template <typename T, typename T2 = decltype(GetMPIType<T>())> 
    T Reduce (T d, const MPI_Op & op, int root = 0) const
    {
      if (size == 1) return d;
      
      T global_d;
      MPI_Reduce (&d, &global_d, 1, GetMPIType<T>(), op, root, comm);
      return global_d;
    }
    
    template <typename T, typename T2 = decltype(GetMPIType<T>())> 
    T AllReduce (T d, const MPI_Op & op) const
    {
      if (size == 1) return d;
      
      T global_d;
      MPI_Allreduce ( &d, &global_d, 1, GetMPIType<T>(), op, comm);
      return global_d;
    }

    template <typename T, typename T2 = decltype(GetMPIType<T>())> 
    void Bcast (T & s, int root = 0) const {
      if (size == 1) return ;
      MPI_Bcast (&s, 1, GetMPIType<T>(), root, comm);
    }
    
    void Bcast (std::string & s, int root = 0) const 
    {
      if (size == 1) return;
      int len = s.length();
      Bcast (len, root);
      if (rank != 0) s.resize (len);
      MPI_Bcast (&s[0], len, MPI_CHAR, root, comm);
    }

    
  };

  
#else // PARALLEL
  /**
     Placeholder for the MPI communicator handle in builds without PARALLEL.
     It only stores an integer id, converts back to that id implicitly, and
     compares equal to another placeholder carrying the same id.
  */
  class MPI_Comm
  {
  public:
    MPI_Comm (int _nr = 0)
    {
      nr = _nr;
    }

    operator int() const
    {
      return nr;
    }

    bool operator== (MPI_Comm c2) const
    {
      const bool differs = (nr != c2.nr);
      return !differs;
    }

  private:
    int nr;   // integer id identifying this placeholder communicator
  };
  // Placeholder communicators for builds without PARALLEL; the two ids are
  // distinct so the handles compare unequal.
  static MPI_Comm MPI_COMM_WORLD = 12345;
  static MPI_Comm MPI_COMM_NULL = 10000;

  // Plain-integer placeholders for the MPI operation and request handles.
  using MPI_Op = int;
  using MPI_Request = int;

  // Placeholder reduction operators.
  enum {
    MPI_SUM = 0,
    MPI_MIN = 1,
    MPI_MAX = 2
  };
  
  class NgMPI_Comm
  {
    
  public:
    NgMPI_Comm () { ; } 
    NgMPI_Comm (MPI_Comm _comm, bool owns = false) { ; }

    size_t Rank() const { return 0; }
    size_t Size() const { return 1; }
    void Barrier() const { ; } 
    operator MPI_Comm() const { return MPI_Comm(); }

    template<typename T>
    void Send( T & val, int dest, int tag) const { ; }
    
    template<typename T>
    void MyMPI_Recv (T & val, int src, int tag) const { ; }

    template<typename T>
    MPI_Request ISend (T & val, int dest, int tag) const { return 0; } 
    
    template<typename T>
    MPI_Request IRecv (T & val, int dest, int tag) const { return 0; } 
    
    template <typename T>
    T Reduce (T d, const MPI_Op & op, int root = 0) { return d; }
    
    template <typename T>
    T AllReduce (T d, const MPI_Op & op) const { return d; }

    template <typename T>
    void Bcast (T & s, int root = 0) const { ; } 
  };  
  
#endif // PARALLEL






  

  
}

#endif // NGCORE_MPIWRAPPER_HPP