#include <petsc/private/dmpleximpl.h> /*I      "petscdmplex.h"   I*/

/* TODO PetscArrayExchangeBegin/End */
/* TODO blocksize */
/* TODO move to API ? */
/*
  ExchangeArrayByRank_Private - Point-to-point exchange of variable-length arrays with neighbor ranks.

  Input Parameters:
+ obj     - object whose communicator and tag space are used for the messages
. dt      - MPI datatype of the array entries
. nsranks - number of ranks this process sends to
. sranks  - ranks to send to
. ssize   - number of entries (of type dt) sent to each rank in sranks
. sarr    - array sent to each rank in sranks
. nrranks - number of ranks this process receives from
- rranks  - ranks to receive from

  Output Parameters:
+ rsize_out - number of entries received from each rank in rranks
- rarr_out  - received arrays, one per rank in rranks

  Notes:
  Two messages are exchanged per neighbor pair: first the array size, then the array itself, each under a fresh tag.
  Every send must be matched by a receive posted on the destination rank (and vice versa), otherwise the waits never complete.
  rsize_out and rarr_out are allocated together with PetscMalloc2() and each rarr_out[r] with PetscMalloc(); the caller owns all of them.
*/
static PetscErrorCode ExchangeArrayByRank_Private(PetscObject obj, MPI_Datatype dt, PetscInt nsranks, const PetscMPIInt sranks[], PetscInt ssize[], const void *sarr[], PetscInt nrranks, const PetscMPIInt rranks[], PetscInt *rsize_out[], void **rarr_out[])
{
  PetscInt     r;
  PetscInt    *rsize;
  void       **rarr;
  MPI_Request *sreq, *rreq;
  PetscMPIInt  tag, unitsize, count;
  MPI_Comm     comm;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Type_size(dt, &unitsize));
  PetscCall(PetscObjectGetComm(obj, &comm));
  PetscCall(PetscMalloc2(nrranks, &rsize, nrranks, &rarr));
  PetscCall(PetscMalloc2(nrranks, &rreq, nsranks, &sreq));
  /* exchange array sizes: one PetscInt per neighbor */
  PetscCall(PetscObjectGetNewTag(obj, &tag));
  for (r = 0; r < nrranks; r++) PetscCallMPI(MPI_Irecv(&rsize[r], 1, MPIU_INT, rranks[r], tag, comm, &rreq[r]));
  for (r = 0; r < nsranks; r++) PetscCallMPI(MPI_Isend(&ssize[r], 1, MPIU_INT, sranks[r], tag, comm, &sreq[r]));
  PetscCallMPI(MPI_Waitall(nrranks, rreq, MPI_STATUSES_IGNORE));
  PetscCallMPI(MPI_Waitall(nsranks, sreq, MPI_STATUSES_IGNORE));
  /* exchange the arrays themselves; a new tag keeps them from matching the size messages */
  PetscCall(PetscObjectGetNewTag(obj, &tag));
  for (r = 0; r < nrranks; r++) {
    /* MPI counts are int: error out instead of silently truncating a PetscInt size (64-bit-index builds) */
    PetscCall(PetscMPIIntCast(rsize[r], &count));
    /* byte count computed in size_t so that rsize * unitsize cannot overflow PetscInt */
    PetscCall(PetscMalloc((size_t)rsize[r] * (size_t)unitsize, &rarr[r]));
    PetscCallMPI(MPI_Irecv(rarr[r], count, dt, rranks[r], tag, comm, &rreq[r]));
  }
  for (r = 0; r < nsranks; r++) {
    PetscCall(PetscMPIIntCast(ssize[r], &count));
    PetscCallMPI(MPI_Isend(sarr[r], count, dt, sranks[r], tag, comm, &sreq[r]));
  }
  PetscCallMPI(MPI_Waitall(nrranks, rreq, MPI_STATUSES_IGNORE));
  PetscCallMPI(MPI_Waitall(nsranks, sreq, MPI_STATUSES_IGNORE));
  PetscCall(PetscFree2(rreq, sreq));
  *rsize_out = rsize;
  *rarr_out  = rarr;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* TODO VecExchangeBegin/End */
/* TODO move to API ? */
/*
  ExchangeVecByRank_Private - Send the local entries of svecs[r] to sranks[r] and receive an array from each rranks[r],
  wrapping every received array in a sequential Vec.

  Input Parameters:
+ obj     - object whose communicator and tag space are used for the messages
. nsranks - number of ranks to send to
. sranks  - ranks to send to
. svecs   - Vec sent to each rank in sranks (only its local entries are sent)
. nrranks - number of ranks to receive from
- rranks  - ranks to receive from

  Output Parameter:
. rvecs - PetscMalloc1()'d array of nrranks Vecs on PETSC_COMM_SELF (block size 1); each Vec owns its received array.
          The caller destroys every Vec and frees the array.
*/
static PetscErrorCode ExchangeVecByRank_Private(PetscObject obj, PetscInt nsranks, const PetscMPIInt sranks[], Vec svecs[], PetscInt nrranks, const PetscMPIInt rranks[], Vec *rvecs[])
{
  PetscInt            r;
  PetscInt           *ssize, *rsize;
  PetscScalar       **rarr;
  const PetscScalar **sarr;
  Vec                *rvecs_;

  PetscFunctionBegin;
  /* MPI requests are handled inside ExchangeArrayByRank_Private(); only the send sizes and arrays are needed here */
  PetscCall(PetscMalloc2(nsranks, &ssize, nsranks, &sarr));
  for (r = 0; r < nsranks; r++) {
    PetscCall(VecGetLocalSize(svecs[r], &ssize[r]));
    PetscCall(VecGetArrayRead(svecs[r], &sarr[r]));
  }
  PetscCall(ExchangeArrayByRank_Private(obj, MPIU_SCALAR, nsranks, sranks, ssize, (const void **)sarr, nrranks, rranks, &rsize, (void ***)&rarr));
  PetscCall(PetscMalloc1(nrranks, &rvecs_));
  for (r = 0; r < nrranks; r++) {
    /* set array in two steps to mimic PETSC_OWN_POINTER: VecReplaceArray() hands rarr[r] over to the Vec */
    PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, rsize[r], NULL, &rvecs_[r]));
    PetscCall(VecReplaceArray(rvecs_[r], rarr[r]));
  }
  for (r = 0; r < nsranks; r++) PetscCall(VecRestoreArrayRead(svecs[r], &sarr[r]));
  /* frees only the containers; the received arrays now belong to the Vecs */
  PetscCall(PetscFree2(rsize, rarr));
  PetscCall(PetscFree2(ssize, sarr));
  *rvecs = rvecs_;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  SortByRemote_Private - Copy the root-rank-grouped leaf indices (rmine) and remote root indices (rremote) of sf,
  then sort each rank's segment by remote index, permuting the leaf indices alongside, so that the order agrees
  with the one used on the other side of the SF.

  Output Parameters:
+ rmine1   - sorted copy of rmine, length equal to the number of leaves
- rremote1 - sorted copy of rremote, same length

  Both outputs are allocated together with PetscMalloc2(); release them with PetscFree2(*rmine1, *rremote1).
*/
static PetscErrorCode SortByRemote_Private(PetscSF sf, PetscInt *rmine1[], PetscInt *rremote1[])
{
  PetscInt           nranks, total, rank;
  const PetscMPIInt *ranks;
  const PetscInt    *offsets, *leafIdx, *rootIdx;
  PetscInt          *mineSorted, *remoteSorted;

  PetscFunctionBegin;
  PetscCall(PetscSFGetRootRanks(sf, &nranks, &ranks, &offsets, &leafIdx, &rootIdx));
  total = offsets[nranks];
  PetscCall(PetscMalloc2(total, rmine1, total, rremote1));
  mineSorted   = *rmine1;
  remoteSorted = *rremote1;
  for (rank = 0; rank < nranks; rank++) {
    const PetscInt start = offsets[rank];
    const PetscInt len   = offsets[rank + 1] - start;

    PetscCall(PetscArraycpy(mineSorted + start, leafIdx + start, len));
    PetscCall(PetscArraycpy(remoteSorted + start, rootIdx + start, len));
    /* key: remote index; the local index travels with it */
    PetscCall(PetscSortIntWithArray(len, remoteSorted + start, mineSorted + start));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  GetRecursiveConeCoordinatesPerRank_Private - For each root rank r of sf, take the points
  rmine[roffset[r]] .. rmine[roffset[r+1]-1], expand them into the vertices of their recursive cones,
  and gather the local coordinates of those vertices into one Vec.

  Output Parameter:
. coordinatesPerRank - PetscMalloc1()'d array with one Vec per root rank of sf; the caller destroys the Vecs and frees the array
*/
static PetscErrorCode GetRecursiveConeCoordinatesPerRank_Private(DM dm, PetscSF sf, PetscInt rmine[], Vec *coordinatesPerRank[])
{
  PetscInt           nranks, r;
  const PetscMPIInt *ranks;
  const PetscInt    *roffset;
  Vec               *coords;

  PetscFunctionBegin;
  /* local coordinates must be available before DMGetCoordinatesLocalTuple() is used */
  PetscCall(DMGetCoordinatesLocalSetUp(dm));
  PetscCall(PetscSFGetRootRanks(sf, &nranks, &ranks, &roffset, NULL, NULL));
  PetscCall(PetscMalloc1(nranks, coordinatesPerRank));
  coords = *coordinatesPerRank;
  for (r = 0; r < nranks; r++) {
    IS points, vertices;

    /* the IS borrows the caller's index memory; no copy is made */
    PetscCall(ISCreateGeneral(PETSC_COMM_SELF, roffset[r + 1] - roffset[r], rmine + roffset[r], PETSC_USE_POINTER, &points));
    PetscCall(DMPlexGetConeRecursiveVertices(dm, points, &vertices));
    PetscCall(DMGetCoordinatesLocalTuple(dm, vertices, NULL, &coords[r]));
    PetscCall(ISDestroy(&points));
    PetscCall(ISDestroy(&vertices));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PetscSFComputeMultiRootOriginalNumberingByRank_Private - Translate the root-rank-grouped leaf indices of imsf,
  which are used as indices into the multi-root numbering of sf, into the original root numbering of sf.

  Input Parameters:
+ sf   - the SF whose multi-roots are referenced
- imsf - SF whose leaves index the multi-roots of sf (the caller passes the inverse of sf's multi-SF)

  Output Parameter:
. irmine1 - PetscMalloc1()'d array with one entry per leaf of imsf, in imsf's root-rank order; the caller frees it with PetscFree()
*/
static PetscErrorCode PetscSFComputeMultiRootOriginalNumberingByRank_Private(PetscSF sf, PetscSF imsf, PetscInt *irmine1[])
{
  PetscInt       *mRootsOrigNumbering;
  PetscInt        nileaves, niranks, i;
  const PetscInt *iroffset, *irmine, *degree;

  PetscFunctionBegin;
  PetscCall(PetscSFGetGraph(imsf, NULL, &nileaves, NULL, NULL));
  PetscCall(PetscSFGetRootRanks(imsf, &niranks, NULL, &iroffset, &irmine, NULL));
  PetscCheck(nileaves == iroffset[niranks], PETSC_COMM_SELF, PETSC_ERR_PLIB, "nileaves %" PetscInt_FMT " != iroffset[niranks] %" PetscInt_FMT, nileaves, iroffset[niranks]);
  PetscCall(PetscSFComputeDegreeBegin(sf, &degree));
  PetscCall(PetscSFComputeDegreeEnd(sf, &degree));
  PetscCall(PetscSFComputeMultiRootOriginalNumbering(sf, degree, NULL, &mRootsOrigNumbering));
  PetscCall(PetscMalloc1(nileaves, irmine1));
  /* the per-rank segments [iroffset[r], iroffset[r+1]) are contiguous, so one sweep covers all of them */
  for (i = iroffset[0]; i < iroffset[niranks]; i++) (*irmine1)[i] = mRootsOrigNumbering[irmine[i]];
  PetscCall(PetscFree(mRootsOrigNumbering));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  DMPlexCheckInterfaceCones - Check that points on inter-partition interfaces have conforming order of cone points.

  Input Parameter:
. dm - The `DMPLEX` object

  Options Database Key:
. -dm_plex_check_cones_conform_on_interfaces_verbose - print, per neighbor rank, the coordinates that are sent, referenced, and received during the check

  Level: developer

  Notes:
  For example, if there is an edge (rank,index)=(0,2) connecting points cone(0,2)=[(0,0),(0,1)] in this order, and the point SF contains connections 0 <- (1,0), 1 <- (1,1) and 2 <- (1,2),
  then this check would pass if the edge (1,2) has cone(1,2)=[(1,0),(1,1)]. By contrast, if cone(1,2)=[(1,1),(1,0)], then this check would fail.

  This is mainly intended for debugging/testing purposes. It does not check cone orientation; for that purpose use `DMPlexCheckFaces()`.

  The check is skipped on a single process, when the `DM` has no point `PetscSF`, or when the graph of the point `PetscSF` has not been set.
  The coordinates of the `DM` must be set.

  For the complete list of DMPlexCheck* functions, see `DMSetFromOptions()`.

  Developer Note:
  Interface cones are expanded into vertices and then their coordinates are compared.

.seealso: [](chapter_unstructured), `DM`, `DMPLEX`, `DMPlexGetCone()`, `DMPlexGetConeSize()`, `DMGetPointSF()`, `DMGetCoordinates()`, `DMSetFromOptions()`
@*/
PetscErrorCode DMPlexCheckInterfaceCones(DM dm)
{
  PetscSF            sf, msf, imsf;
  PetscInt           nroots, nranks, niranks;
  const PetscMPIInt *ranks, *iranks;
  PetscInt          *rmine1, *rremote1; /* rmine and rremote copies simultaneously sorted by rank and rremote */
  PetscInt          *mine_orig_numbering;
  Vec               *sntCoordinatesPerRank;
  Vec               *refCoordinatesPerRank;
  Vec               *recCoordinatesPerRank = NULL;
  PetscInt           r;
  PetscMPIInt        commsize, myrank;
  PetscBool          same;
  PetscBool          verbose = PETSC_FALSE;
  MPI_Comm           comm;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(dm, DM_CLASSID, 1);
  PetscCall(PetscObjectGetComm((PetscObject)dm, &comm));
  PetscCallMPI(MPI_Comm_rank(comm, &myrank));
  PetscCallMPI(MPI_Comm_size(comm, &commsize));
  if (commsize < 2) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(DMGetPointSF(dm, &sf));
  if (!sf) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscSFGetGraph(sf, &nroots, NULL, NULL, NULL));
  /* a negative root count means the SF graph has not been set */
  if (nroots < 0) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCheck(dm->coordinates[0].x || dm->coordinates[0].xl, PetscObjectComm((PetscObject)dm), PETSC_ERR_ARG_WRONGSTATE, "DM coordinates must be set");
  PetscCall(PetscSFSetUp(sf));
  PetscCall(PetscSFGetRootRanks(sf, &nranks, &ranks, NULL, NULL, NULL));

  /* Expand sent cones per rank */
  PetscCall(SortByRemote_Private(sf, &rmine1, &rremote1));
  PetscCall(GetRecursiveConeCoordinatesPerRank_Private(dm, sf, rmine1, &sntCoordinatesPerRank));

  /* Create inverse SF; msf is owned by sf and must not be destroyed here */
  PetscCall(PetscSFGetMultiSF(sf, &msf));
  PetscCall(PetscSFCreateInverseSF(msf, &imsf));
  PetscCall(PetscSFSetUp(imsf));
  PetscCall(PetscSFGetRootRanks(imsf, &niranks, &iranks, NULL, NULL, NULL));

  /* Compute original numbering of multi-roots (referenced points) */
  PetscCall(PetscSFComputeMultiRootOriginalNumberingByRank_Private(sf, imsf, &mine_orig_numbering));

  /* Expand coordinates of the referred cones per rank */
  PetscCall(GetRecursiveConeCoordinatesPerRank_Private(dm, imsf, mine_orig_numbering, &refCoordinatesPerRank));

  /* Send the coordinates */
  PetscCall(ExchangeVecByRank_Private((PetscObject)sf, nranks, ranks, sntCoordinatesPerRank, niranks, iranks, &recCoordinatesPerRank));

  /* verbose output */
  PetscCall(PetscOptionsGetBool(((PetscObject)dm)->options, ((PetscObject)dm)->prefix, "-dm_plex_check_cones_conform_on_interfaces_verbose", &verbose, NULL));
  if (verbose) {
    /* synchronized ASCII output is collective over the viewer's communicator, so it must be the DM's communicator
       (a WORLD viewer would hang or misorder output when dm lives on a subcommunicator) */
    PetscViewer sv, v = PETSC_VIEWER_STDOUT_(comm);
    PetscCall(PetscViewerASCIIPrintf(v, "============\nDMPlexCheckInterfaceCones output\n============\n"));
    PetscCall(PetscViewerASCIIPushSynchronized(v));
    PetscCall(PetscViewerASCIISynchronizedPrintf(v, "[%d] --------\n", myrank));
    for (r = 0; r < nranks; r++) {
      PetscCall(PetscViewerASCIISynchronizedPrintf(v, "  r=%" PetscInt_FMT " ranks[r]=%d sntCoordinatesPerRank[r]:\n", r, ranks[r]));
      PetscCall(PetscViewerASCIIPushTab(v));
      PetscCall(PetscViewerGetSubViewer(v, PETSC_COMM_SELF, &sv));
      PetscCall(VecView(sntCoordinatesPerRank[r], sv));
      PetscCall(PetscViewerRestoreSubViewer(v, PETSC_COMM_SELF, &sv));
      PetscCall(PetscViewerASCIIPopTab(v));
    }
    PetscCall(PetscViewerASCIISynchronizedPrintf(v, "  ----------\n"));
    for (r = 0; r < niranks; r++) {
      PetscCall(PetscViewerASCIISynchronizedPrintf(v, "  r=%" PetscInt_FMT " iranks[r]=%d refCoordinatesPerRank[r]:\n", r, iranks[r]));
      PetscCall(PetscViewerASCIIPushTab(v));
      PetscCall(PetscViewerGetSubViewer(v, PETSC_COMM_SELF, &sv));
      PetscCall(VecView(refCoordinatesPerRank[r], sv));
      PetscCall(PetscViewerRestoreSubViewer(v, PETSC_COMM_SELF, &sv));
      PetscCall(PetscViewerASCIIPopTab(v));
    }
    PetscCall(PetscViewerASCIISynchronizedPrintf(v, "  ----------\n"));
    for (r = 0; r < niranks; r++) {
      PetscCall(PetscViewerASCIISynchronizedPrintf(v, "  r=%" PetscInt_FMT " iranks[r]=%d recCoordinatesPerRank[r]:\n", r, iranks[r]));
      PetscCall(PetscViewerASCIIPushTab(v));
      PetscCall(PetscViewerGetSubViewer(v, PETSC_COMM_SELF, &sv));
      PetscCall(VecView(recCoordinatesPerRank[r], sv));
      PetscCall(PetscViewerRestoreSubViewer(v, PETSC_COMM_SELF, &sv));
      PetscCall(PetscViewerASCIIPopTab(v));
    }
    PetscCall(PetscViewerFlush(v));
    PetscCall(PetscViewerASCIIPopSynchronized(v));
  }

  /* Compare recCoordinatesPerRank with refCoordinatesPerRank */
  for (r = 0; r < niranks; r++) {
    PetscCall(VecEqual(refCoordinatesPerRank[r], recCoordinatesPerRank[r], &same));
    PetscCheck(same, PETSC_COMM_SELF, PETSC_ERR_PLIB, "interface cones do not conform for remote rank %d", iranks[r]);
  }

  /* destroy sent stuff */
  for (r = 0; r < nranks; r++) PetscCall(VecDestroy(&sntCoordinatesPerRank[r]));
  PetscCall(PetscFree(sntCoordinatesPerRank));
  PetscCall(PetscFree2(rmine1, rremote1));
  PetscCall(PetscSFDestroy(&imsf));

  /* destroy referenced stuff */
  for (r = 0; r < niranks; r++) PetscCall(VecDestroy(&refCoordinatesPerRank[r]));
  PetscCall(PetscFree(refCoordinatesPerRank));
  PetscCall(PetscFree(mine_orig_numbering));

  /* destroy received stuff */
  for (r = 0; r < niranks; r++) PetscCall(VecDestroy(&recCoordinatesPerRank[r]));
  PetscCall(PetscFree(recCoordinatesPerRank));
  PetscFunctionReturn(PETSC_SUCCESS);
}
