#include <petscsys.h> /*I  "petscsys.h"  I*/

PETSC_SINGLE_LIBRARY_INTERN PetscErrorCode PetscGatherNumberOfMessages_Private(MPI_Comm, const PetscMPIInt[], const PetscInt[], PetscMPIInt *);
PETSC_SINGLE_LIBRARY_INTERN PetscErrorCode PetscGatherMessageLengths_Private(MPI_Comm, PetscMPIInt, PetscMPIInt, const PetscInt[], PetscMPIInt **, PetscInt **);

/*@C
  PetscGatherNumberOfMessages -  Computes the number of messages an MPI rank expects to receive during a neighbor communication

  Collective, No Fortran Support

  Input Parameters:
+ comm     - Communicator
. iflags   - an array of integers of length equal to the size of `comm`. A '1' in `iflags`[i] represents a
             message from the current rank to rank i. Optionally `NULL`
- ilengths - a nonzero `ilengths`[i] represents a message to rank i of length `ilengths`[i].
             Optionally `NULL`

  Output Parameter:
. nrecvs - number of messages this MPI rank will receive

  Level: developer

  Notes:
  With this info, the correct message lengths can be determined using
  `PetscGatherMessageLengths()`

  Either `iflags` or `ilengths` should be provided.  If `iflags` is not
  provided (`NULL`) it can be computed from `ilengths`. If `iflags` is
  provided, `ilengths` is not required.
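
  Example Usage:
  A minimal sketch, assuming the caller has already filled an array `lens` (of length equal to the size of `comm`) with its per-destination send lengths
.vb
  PetscMPIInt nrecvs;

  // lens[i] != 0 means this rank will send a message of length lens[i] to rank i
  PetscCall(PetscGatherNumberOfMessages(comm, NULL, lens, &nrecvs));
  // nrecvs now holds the number of messages this rank should expect to receive
.ve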

.seealso: `PetscGatherMessageLengths()`, `PetscGatherMessageLengths2()`, `PetscCommBuildTwoSided()`
@*/
PetscErrorCode PetscGatherNumberOfMessages(MPI_Comm comm, const PetscMPIInt iflags[], const PetscMPIInt ilengths[], PetscMPIInt *nrecvs)
{
  PetscMPIInt size, rank, *recv_buf, i, *iflags_local = NULL, *iflags_localm;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(comm, &size));
  PetscCallMPI(MPI_Comm_rank(comm, &rank));

  PetscCall(PetscMalloc2(size, &recv_buf, size, &iflags_localm));

  /* If iflags not provided, compute iflags from ilengths */
  if (!iflags) {
    PetscCheck(ilengths, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Either iflags or ilengths should be provided");
    iflags_local = iflags_localm;
    for (i = 0; i < size; i++) {
      if (ilengths[i]) iflags_local[i] = 1;
      else iflags_local[i] = 0;
    }
  } else iflags_local = (PetscMPIInt *)iflags;

  /* Post an allreduce to determine the number of messages the current MPI rank will receive */
  PetscCallMPI(MPIU_Allreduce(iflags_local, recv_buf, size, MPI_INT, MPI_SUM, comm));
  *nrecvs = recv_buf[rank];

  PetscCall(PetscFree2(recv_buf, iflags_localm));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  PetscGatherMessageLengths - Computes information about messages that an MPI rank will receive,
  including (from-id,length) pairs for each message.

  Collective, No Fortran Support

  Input Parameters:
+ comm     - Communicator
. nsends   - number of messages that are to be sent
. nrecvs   - number of messages being received
- ilengths - an array of integers of length equal to the size of `comm`;
              a nonzero `ilengths`[i] represents a message to rank i of length `ilengths`[i]

  Output Parameters:
+ onodes   - list of ranks from which messages are expected
- olengths - corresponding message lengths

  Level: developer

  Notes:
  With this info, the correct `MPIU_Irecv()` can be posted with the correct
  from-id and with a receive buffer of the required size.

  The calling function should `PetscFree()` the memory in `onodes` and `olengths`

  To determine `nrecvs`, one can use `PetscGatherNumberOfMessages()`
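
  Example Usage:
  A minimal sketch, assuming `lens` holds the per-destination send lengths and `nsends` counts its nonzero entries
.vb
  PetscMPIInt nrecvs, *onodes, *olengths;

  PetscCall(PetscGatherNumberOfMessages(comm, NULL, lens, &nrecvs));
  PetscCall(PetscGatherMessageLengths(comm, nsends, nrecvs, lens, &onodes, &olengths));
  // onodes[i] is the rank that will send message i and olengths[i] its length;
  // post one MPIU_Irecv() per message, then free the arrays
  PetscCall(PetscFree(onodes));
  PetscCall(PetscFree(olengths));
.ve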

.seealso: `PetscGatherNumberOfMessages()`, `PetscGatherMessageLengths2()`, `PetscCommBuildTwoSided()`
@*/
PetscErrorCode PetscGatherMessageLengths(MPI_Comm comm, PetscMPIInt nsends, PetscMPIInt nrecvs, const PetscMPIInt ilengths[], PetscMPIInt **onodes, PetscMPIInt **olengths)
{
  PetscMPIInt  size, rank, tag, i, j;
  MPI_Request *s_waits, *r_waits;
  MPI_Status  *w_status;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(comm, &size));
  PetscCallMPI(MPI_Comm_rank(comm, &rank));
  PetscCall(PetscCommGetNewTag(comm, &tag));

  /* cannot use PetscMalloc3() here because r_waits and s_waits must be contiguous for the call to MPI_Waitall() */
  PetscCall(PetscMalloc2(nrecvs + nsends, &r_waits, nrecvs + nsends, &w_status));
  s_waits = PetscSafePointerPlusOffset(r_waits, nrecvs);

  /* Post the Irecv to get the message length-info */
  PetscCall(PetscMalloc1(nrecvs, olengths));
  for (i = 0; i < nrecvs; i++) PetscCallMPI(MPIU_Irecv((*olengths) + i, 1, MPI_INT, MPI_ANY_SOURCE, tag, comm, r_waits + i));

  /* Post the Isends with the message length-info */
  for (i = 0, j = 0; i < size; ++i) {
    if (ilengths[i]) {
      PetscCallMPI(MPIU_Isend((void *)(ilengths + i), 1, MPI_INT, i, tag, comm, s_waits + j));
      j++;
    }
  }

  /* Post waits on sends and receives */
  if (nrecvs + nsends) PetscCallMPI(MPI_Waitall(nrecvs + nsends, r_waits, w_status));

  /* Pack up the received data */
  PetscCall(PetscMalloc1(nrecvs, onodes));
  for (i = 0; i < nrecvs; ++i) {
    (*onodes)[i] = w_status[i].MPI_SOURCE;
#if defined(PETSC_HAVE_OPENMPI)
    /* This line is a workaround for a bug in Open MPI 2.1.1 as distributed by Ubuntu 18.04.2 LTS.
       The bug occurs in self-to-self MPI_Send/Recv using MPI_ANY_SOURCE for message matching: Open MPI
       does not put the correct value in the receive buffer. See also
       https://lists.mcs.anl.gov/pipermail/petsc-dev/2019-July/024803.html
       https://www.mail-archive.com/users@lists.open-mpi.org//msg33383.html
     */
    if (w_status[i].MPI_SOURCE == rank) (*olengths)[i] = ilengths[rank];
#endif
  }
  PetscCall(PetscFree2(r_waits, w_status));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Same as PetscGatherNumberOfMessages(), except using PetscInt for ilengths[] */
PetscErrorCode PetscGatherNumberOfMessages_Private(MPI_Comm comm, const PetscMPIInt iflags[], const PetscInt ilengths[], PetscMPIInt *nrecvs)
{
  PetscMPIInt size, rank, *recv_buf, i, *iflags_local = NULL, *iflags_localm;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(comm, &size));
  PetscCallMPI(MPI_Comm_rank(comm, &rank));

  PetscCall(PetscMalloc2(size, &recv_buf, size, &iflags_localm));

  /* If iflags not provided, compute iflags from ilengths */
  if (!iflags) {
    PetscCheck(ilengths, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Either iflags or ilengths should be provided");
    iflags_local = iflags_localm;
    for (i = 0; i < size; i++) {
      if (ilengths[i]) iflags_local[i] = 1;
      else iflags_local[i] = 0;
    }
  } else iflags_local = (PetscMPIInt *)iflags;

  /* Post an allreduce to determine the number of messages the current MPI rank will receive */
  PetscCallMPI(MPIU_Allreduce(iflags_local, recv_buf, size, MPI_INT, MPI_SUM, comm));
  *nrecvs = recv_buf[rank];

  PetscCall(PetscFree2(recv_buf, iflags_localm));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Same as PetscGatherMessageLengths(), except using PetscInt for message lengths */
PetscErrorCode PetscGatherMessageLengths_Private(MPI_Comm comm, PetscMPIInt nsends, PetscMPIInt nrecvs, const PetscInt ilengths[], PetscMPIInt **onodes, PetscInt **olengths)
{
  PetscMPIInt  size, rank, tag, i, j;
  MPI_Request *s_waits, *r_waits;
  MPI_Status  *w_status;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(comm, &size));
  PetscCallMPI(MPI_Comm_rank(comm, &rank));
  PetscCall(PetscCommGetNewTag(comm, &tag));

  /* cannot use PetscMalloc3() here because r_waits and s_waits must be contiguous for the call to MPI_Waitall() */
  PetscCall(PetscMalloc2(nrecvs + nsends, &r_waits, nrecvs + nsends, &w_status));
  s_waits = PetscSafePointerPlusOffset(r_waits, nrecvs);

  /* Post the Irecv to get the message length-info */
  PetscCall(PetscMalloc1(nrecvs, olengths));
  for (i = 0; i < nrecvs; i++) PetscCallMPI(MPIU_Irecv((*olengths) + i, 1, MPIU_INT, MPI_ANY_SOURCE, tag, comm, r_waits + i));

  /* Post the Isends with the message length-info */
  for (i = 0, j = 0; i < size; ++i) {
    if (ilengths[i]) {
      PetscCallMPI(MPIU_Isend((void *)(ilengths + i), 1, MPIU_INT, i, tag, comm, s_waits + j));
      j++;
    }
  }

  /* Post waits on sends and receives */
  if (nrecvs + nsends) PetscCallMPI(MPI_Waitall(nrecvs + nsends, r_waits, w_status));

  /* Pack up the received data */
  PetscCall(PetscMalloc1(nrecvs, onodes));
  for (i = 0; i < nrecvs; ++i) {
    (*onodes)[i] = w_status[i].MPI_SOURCE;
    if (w_status[i].MPI_SOURCE == rank) (*olengths)[i] = ilengths[rank]; /* See comments in PetscGatherMessageLengths */
  }
  PetscCall(PetscFree2(r_waits, w_status));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  PetscGatherMessageLengths2 - Computes information about messages that an MPI rank will receive,
  including (from-id,length) pairs for each message. Same functionality as `PetscGatherMessageLengths()`
  except it takes two input length arrays and produces two output length arrays.

  Collective, No Fortran Support

  Input Parameters:
+ comm      - Communicator
. nsends    - number of messages that are to be sent.
. nrecvs    - number of messages being received
. ilengths1 - first array of integers of length equal to the size of `comm`
- ilengths2 - second array of integers of length equal to the size of `comm`

  Output Parameters:
+ onodes    - list of ranks from which messages are expected
. olengths1 - first set of corresponding message lengths
- olengths2 - second set of corresponding message lengths

  Level: developer

  Notes:
  With this info, the correct `MPIU_Irecv()` can be posted with the correct
  from-id and with a receive buffer of the required size.

  The calling function should `PetscFree()` the memory in `onodes`, `olengths1`, and `olengths2`

  To determine `nrecvs`, one can use `PetscGatherNumberOfMessages()`
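
  Example Usage:
  A minimal sketch, assuming `lens1` and `lens2` hold the two per-destination length arrays and `nsends` counts the nonzero entries of `lens1`
.vb
  PetscMPIInt nrecvs, *onodes, *olengths1, *olengths2;

  PetscCall(PetscGatherNumberOfMessages(comm, NULL, lens1, &nrecvs));
  PetscCall(PetscGatherMessageLengths2(comm, nsends, nrecvs, lens1, lens2, &onodes, &olengths1, &olengths2));
  // post the receives using (onodes[i], olengths1[i], olengths2[i]), then free the arrays
  PetscCall(PetscFree(onodes));
  PetscCall(PetscFree(olengths1));
  PetscCall(PetscFree(olengths2));
.ve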

.seealso: `PetscGatherMessageLengths()`, `PetscGatherNumberOfMessages()`, `PetscCommBuildTwoSided()`
@*/
PetscErrorCode PetscGatherMessageLengths2(MPI_Comm comm, PetscMPIInt nsends, PetscMPIInt nrecvs, const PetscMPIInt ilengths1[], const PetscMPIInt ilengths2[], PetscMPIInt **onodes, PetscMPIInt **olengths1, PetscMPIInt **olengths2)
{
  PetscMPIInt  size, tag, i, j, *buf_s, *buf_r, *buf_j = NULL;
  MPI_Request *s_waits, *r_waits;
  MPI_Status  *w_status;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(comm, &size));
  PetscCall(PetscCommGetNewTag(comm, &tag));

  /* cannot use PetscMalloc5() because r_waits and s_waits must be contiguous for the call to MPI_Waitall() */
  PetscCall(PetscMalloc4(nrecvs + nsends, &r_waits, 2 * nrecvs, &buf_r, 2 * nsends, &buf_s, nrecvs + nsends, &w_status));
  s_waits = PetscSafePointerPlusOffset(r_waits, nrecvs);

  /* Post the Irecv to get the message length-info */
  PetscCall(PetscMalloc1(nrecvs + 1, olengths1));
  PetscCall(PetscMalloc1(nrecvs + 1, olengths2));
  for (i = 0; i < nrecvs; i++) {
    buf_j = buf_r + (2 * i);
    PetscCallMPI(MPIU_Irecv(buf_j, 2, MPI_INT, MPI_ANY_SOURCE, tag, comm, r_waits + i));
  }

  /* Post the Isends with the message length-info */
  for (i = 0, j = 0; i < size; ++i) {
    if (ilengths1[i]) {
      buf_j    = buf_s + (2 * j);
      buf_j[0] = *(ilengths1 + i);
      buf_j[1] = *(ilengths2 + i);
      PetscCallMPI(MPIU_Isend(buf_j, 2, MPI_INT, i, tag, comm, s_waits + j));
      j++;
    }
  }
  PetscCheck(j == nsends, PETSC_COMM_SELF, PETSC_ERR_PLIB, "j %d not equal to expected number of sends %d", j, nsends);

  /* Post waits on sends and receives */
  if (nrecvs + nsends) PetscCallMPI(MPI_Waitall(nrecvs + nsends, r_waits, w_status));

  /* Pack up the received data */
  PetscCall(PetscMalloc1(nrecvs + 1, onodes));
  for (i = 0; i < nrecvs; ++i) {
    (*onodes)[i]    = w_status[i].MPI_SOURCE;
    buf_j           = buf_r + (2 * i);
    (*olengths1)[i] = buf_j[0];
    (*olengths2)[i] = buf_j[1];
  }

  PetscCall(PetscFree4(r_waits, buf_r, buf_s, w_status));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  Allocate buffers sufficient to hold the messages whose sizes are specified in olengths,
  and post Irecvs on these buffers using the source ranks from onodes
 */
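/*
  A caller-side sketch (illustrative only; onodes, olengths and nrecvs are assumed to come from
  PetscGatherMessageLengths(), and the matching sends are posted elsewhere):

    PetscInt    **rbuf;
    MPI_Request  *r_waits;

    PetscCall(PetscPostIrecvInt(comm, tag, nrecvs, onodes, olengths, &rbuf, &r_waits));
    // ... post the matching sends ...
    if (nrecvs) PetscCallMPI(MPI_Waitall(nrecvs, r_waits, MPI_STATUSES_IGNORE));
    // rbuf[i] now holds the message from onodes[i]; free what PetscPostIrecvInt() allocated
    PetscCall(PetscFree(rbuf[0]));
    PetscCall(PetscFree(rbuf));
    PetscCall(PetscFree(r_waits));
*/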
PetscErrorCode PetscPostIrecvInt(MPI_Comm comm, PetscMPIInt tag, PetscMPIInt nrecvs, const PetscMPIInt onodes[], const PetscMPIInt olengths[], PetscInt ***rbuf, MPI_Request **r_waits)
{
  PetscInt   **rbuf_t, i, len = 0;
  MPI_Request *r_waits_t;

  PetscFunctionBegin;
  /* compute memory required for recv buffers */
  for (i = 0; i < nrecvs; i++) len += olengths[i]; /* each message length */

  /* allocate memory for recv buffers */
  PetscCall(PetscMalloc1(nrecvs + 1, &rbuf_t));
  PetscCall(PetscMalloc1(len, &rbuf_t[0]));
  for (i = 1; i < nrecvs; ++i) rbuf_t[i] = rbuf_t[i - 1] + olengths[i - 1];

  /* Post the receives */
  PetscCall(PetscMalloc1(nrecvs, &r_waits_t));
  for (i = 0; i < nrecvs; ++i) PetscCallMPI(MPIU_Irecv(rbuf_t[i], olengths[i], MPIU_INT, onodes[i], tag, comm, r_waits_t + i));

  *rbuf    = rbuf_t;
  *r_waits = r_waits_t;
  PetscFunctionReturn(PETSC_SUCCESS);
}

PetscErrorCode PetscPostIrecvScalar(MPI_Comm comm, PetscMPIInt tag, PetscMPIInt nrecvs, const PetscMPIInt onodes[], const PetscMPIInt olengths[], PetscScalar ***rbuf, MPI_Request **r_waits)
{
  PetscMPIInt   i;
  PetscScalar **rbuf_t;
  MPI_Request  *r_waits_t;
  PetscInt      len = 0;

  PetscFunctionBegin;
  /* compute memory required for recv buffers */
  for (i = 0; i < nrecvs; i++) len += olengths[i]; /* each message length */

  /* allocate memory for recv buffers */
  PetscCall(PetscMalloc1(nrecvs + 1, &rbuf_t));
  PetscCall(PetscMalloc1(len, &rbuf_t[0]));
  for (i = 1; i < nrecvs; ++i) rbuf_t[i] = rbuf_t[i - 1] + olengths[i - 1];

  /* Post the receives */
  PetscCall(PetscMalloc1(nrecvs, &r_waits_t));
  for (i = 0; i < nrecvs; ++i) PetscCallMPI(MPIU_Irecv(rbuf_t[i], olengths[i], MPIU_SCALAR, onodes[i], tag, comm, r_waits_t + i));

  *rbuf    = rbuf_t;
  *r_waits = r_waits_t;
  PetscFunctionReturn(PETSC_SUCCESS);
}
