#include <petscdmda.h> /*I "petscdmda.h" I*/
#include <petsc/private/dmimpl.h>
#include <petsc/private/tsimpl.h> /*I "petscts.h" I*/
#include <petscdraw.h>

/*
  DMTS_DA - per-DMTS storage for the DMDA local callbacks registered by the DMDATSSet*Local() routines.
  Each callback is kept next to its user context; the two residual callbacks also record the
  InsertMode that determines how their output is combined into the global residual.
*/
typedef struct {
  /* implicit residual, set by DMDATSSetIFunctionLocal() */
  PetscErrorCode (*ifunctionlocal)(DMDALocalInfo *, PetscReal, void *, void *, void *, void *);
  void      *ifunctionlocalctx;
  InsertMode ifunctionlocalimode;

  /* implicit Jacobian, set by DMDATSSetIJacobianLocal() */
  PetscErrorCode (*ijacobianlocal)(DMDALocalInfo *, PetscReal, void *, void *, PetscReal, Mat, Mat, void *);
  void *ijacobianlocalctx;

  /* right-hand side, set by DMDATSSetRHSFunctionLocal() */
  PetscErrorCode (*rhsfunctionlocal)(DMDALocalInfo *, PetscReal, void *, void *, void *);
  void      *rhsfunctionlocalctx;
  InsertMode rhsfunctionlocalimode;

  /* right-hand side Jacobian, set by DMDATSSetRHSJacobianLocal() */
  PetscErrorCode (*rhsjacobianlocal)(DMDALocalInfo *, PetscReal, void *, Mat, Mat, void *);
  void *rhsjacobianlocalctx;
} DMTS_DA;

/* Release the DMTS_DA attached to sdm; PetscFree() also resets sdm->data to NULL */
static PetscErrorCode DMTSDestroy_DMDA(DMTS sdm)
{
  PetscFunctionBegin;
  PetscCall(PetscFree(sdm->data)); /* safe when sdm->data is already NULL */
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  Give sdm its own zeroed DMTS_DA; if oldsdm already carries one, copy every stored
  callback, context and insert mode into the new storage.
*/
static PetscErrorCode DMTSDuplicate_DMDA(DMTS oldsdm, DMTS sdm)
{
  const DMTS_DA *src = (const DMTS_DA *)oldsdm->data;

  PetscFunctionBegin;
  PetscCall(PetscNew((DMTS_DA **)&sdm->data));
  if (src) *(DMTS_DA *)sdm->data = *src;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  Return (in *dmdats) the DMTS_DA stored in sdm. On first use the storage is allocated
  (zero-initialized) and the DMDA destroy/duplicate hooks are installed on sdm.
  The dm argument is not referenced.
*/
static PetscErrorCode DMDATSGetContext(DM dm, DMTS sdm, DMTS_DA **dmdats)
{
  PetscFunctionBegin;
  *dmdats = NULL;
  if (sdm->data == NULL) {
    PetscCall(PetscNew((DMTS_DA **)&sdm->data));
    sdm->ops->duplicate = DMTSDuplicate_DMDA;
    sdm->ops->destroy   = DMTSDestroy_DMDA;
  }
  *dmdats = (DMTS_DA *)sdm->data;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSComputeIFunction_DMDA - IFunction callback installed by DMDATSSetIFunctionLocal()

  Scatters X and Xdot into ghosted local vectors, exposes them as DMDA-indexed arrays and
  calls the user's ifunctionlocal.
    INSERT_VALUES - the user writes directly into the array of the global vector F
    ADD_VALUES    - the user accumulates into a zeroed ghosted local vector, which is then
                    summed into a zeroed F with DMLocalToGlobal()
  Any other insert mode is rejected before any work vector or array is borrowed.

  ctx is the DMTS_DA registered by DMDATSSetIFunctionLocal().
*/
static PetscErrorCode TSComputeIFunction_DMDA(TS ts, PetscReal ptime, Vec X, Vec Xdot, Vec F, PetscCtx ctx)
{
  DM            dm;
  DMTS_DA      *dmdats = (DMTS_DA *)ctx;
  DMDALocalInfo info;
  Vec           Xloc, Xdotloc;
  void         *x, *f, *xdot;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(X, VEC_CLASSID, 3);
  PetscValidHeaderSpecific(Xdot, VEC_CLASSID, 4);
  PetscValidHeaderSpecific(F, VEC_CLASSID, 5);
  PetscCheck(dmdats->ifunctionlocal, PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Corrupt context");
  /* Validate the insert mode first so the error path does not leave local vectors/arrays checked out */
  PetscCheck(dmdats->ifunctionlocalimode == INSERT_VALUES || dmdats->ifunctionlocalimode == ADD_VALUES, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_INCOMP, "Cannot use imode=%d", (int)dmdats->ifunctionlocalimode);
  PetscCall(TSGetDM(ts, &dm));
  PetscCall(DMGetLocalVector(dm, &Xdotloc));
  PetscCall(DMGlobalToLocalBegin(dm, Xdot, INSERT_VALUES, Xdotloc));
  PetscCall(DMGlobalToLocalEnd(dm, Xdot, INSERT_VALUES, Xdotloc));
  PetscCall(DMGetLocalVector(dm, &Xloc));
  PetscCall(DMGlobalToLocalBegin(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMGlobalToLocalEnd(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMDAGetLocalInfo(dm, &info));
  PetscCall(DMDAVecGetArray(dm, Xloc, &x));
  PetscCall(DMDAVecGetArray(dm, Xdotloc, &xdot));
  if (dmdats->ifunctionlocalimode == INSERT_VALUES) {
    PetscCall(DMDAVecGetArray(dm, F, &f));
    CHKMEMQ;
    PetscCall((*dmdats->ifunctionlocal)(&info, ptime, x, xdot, f, dmdats->ifunctionlocalctx));
    CHKMEMQ;
    PetscCall(DMDAVecRestoreArray(dm, F, &f));
  } else { /* ADD_VALUES */
    Vec Floc;

    PetscCall(DMGetLocalVector(dm, &Floc));
    PetscCall(VecZeroEntries(Floc));
    PetscCall(DMDAVecGetArray(dm, Floc, &f));
    CHKMEMQ;
    PetscCall((*dmdats->ifunctionlocal)(&info, ptime, x, xdot, f, dmdats->ifunctionlocalctx));
    CHKMEMQ;
    PetscCall(DMDAVecRestoreArray(dm, Floc, &f));
    PetscCall(VecZeroEntries(F));
    PetscCall(DMLocalToGlobalBegin(dm, Floc, ADD_VALUES, F));
    PetscCall(DMLocalToGlobalEnd(dm, Floc, ADD_VALUES, F));
    PetscCall(DMRestoreLocalVector(dm, &Floc));
  }
  PetscCall(DMDAVecRestoreArray(dm, Xloc, &x));
  PetscCall(DMRestoreLocalVector(dm, &Xloc));
  PetscCall(DMDAVecRestoreArray(dm, Xdotloc, &xdot));
  PetscCall(DMRestoreLocalVector(dm, &Xdotloc));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSComputeIJacobian_DMDA - IJacobian callback installed by DMDATSSetIJacobianLocal()

  Scatters X and Xdot into ghosted local vectors, exposes them as DMDA-indexed arrays and
  calls the user's ijacobianlocal with the shift and the matrices A and B. If A and B
  differ, A is given a final assembly afterwards in case the user only assembled B.

  ctx is the DMTS_DA registered by DMDATSSetIJacobianLocal(). Only the Jacobian callback is
  required; the IFunction may have been provided by some other means.
*/
static PetscErrorCode TSComputeIJacobian_DMDA(TS ts, PetscReal ptime, Vec X, Vec Xdot, PetscReal shift, Mat A, Mat B, PetscCtx ctx)
{
  DM            dm;
  DMTS_DA      *dmdats = (DMTS_DA *)ctx;
  DMDALocalInfo info;
  Vec           Xloc, Xdotloc;
  void         *x, *xdot;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(X, VEC_CLASSID, 3);
  PetscValidHeaderSpecific(Xdot, VEC_CLASSID, 4);
  /* Only ijacobianlocal is used below; do not also require ifunctionlocal to be set */
  PetscCheck(dmdats->ijacobianlocal, PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "TSComputeIJacobian_DMDA() called without calling DMDATSSetIJacobianLocal()");
  PetscCall(TSGetDM(ts, &dm));

  PetscCall(DMGetLocalVector(dm, &Xloc));
  PetscCall(DMGlobalToLocalBegin(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMGlobalToLocalEnd(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMGetLocalVector(dm, &Xdotloc));
  PetscCall(DMGlobalToLocalBegin(dm, Xdot, INSERT_VALUES, Xdotloc));
  PetscCall(DMGlobalToLocalEnd(dm, Xdot, INSERT_VALUES, Xdotloc));
  PetscCall(DMDAGetLocalInfo(dm, &info));
  PetscCall(DMDAVecGetArray(dm, Xloc, &x));
  PetscCall(DMDAVecGetArray(dm, Xdotloc, &xdot));
  CHKMEMQ;
  PetscCall((*dmdats->ijacobianlocal)(&info, ptime, x, xdot, shift, A, B, dmdats->ijacobianlocalctx));
  CHKMEMQ;
  PetscCall(DMDAVecRestoreArray(dm, Xloc, &x));
  PetscCall(DMDAVecRestoreArray(dm, Xdotloc, &xdot));
  PetscCall(DMRestoreLocalVector(dm, &Xloc));
  PetscCall(DMRestoreLocalVector(dm, &Xdotloc));
  /* This will be redundant if the user called both, but it's too common to forget. */
  if (A != B) {
    PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
    PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSComputeRHSFunction_DMDA - RHSFunction callback installed by DMDATSSetRHSFunctionLocal()

  Scatters X into a ghosted local vector, exposes it as a DMDA-indexed array and calls the
  user's rhsfunctionlocal.
    INSERT_VALUES - the user writes directly into the array of the global vector F
    ADD_VALUES    - the user accumulates into a zeroed ghosted local vector, which is then
                    summed into a zeroed F with DMLocalToGlobal()
  Any other insert mode is rejected before any work vector or array is borrowed.

  ctx is the DMTS_DA registered by DMDATSSetRHSFunctionLocal().
*/
static PetscErrorCode TSComputeRHSFunction_DMDA(TS ts, PetscReal ptime, Vec X, Vec F, PetscCtx ctx)
{
  DM            dm;
  DMTS_DA      *dmdats = (DMTS_DA *)ctx;
  DMDALocalInfo info;
  Vec           Xloc;
  void         *x, *f;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(X, VEC_CLASSID, 3);
  PetscValidHeaderSpecific(F, VEC_CLASSID, 4);
  PetscCheck(dmdats->rhsfunctionlocal, PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Corrupt context");
  /* Validate the insert mode first so the error path does not leave local vectors/arrays checked out */
  PetscCheck(dmdats->rhsfunctionlocalimode == INSERT_VALUES || dmdats->rhsfunctionlocalimode == ADD_VALUES, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_INCOMP, "Cannot use imode=%d", (int)dmdats->rhsfunctionlocalimode);
  PetscCall(TSGetDM(ts, &dm));
  PetscCall(DMGetLocalVector(dm, &Xloc));
  PetscCall(DMGlobalToLocalBegin(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMGlobalToLocalEnd(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMDAGetLocalInfo(dm, &info));
  PetscCall(DMDAVecGetArray(dm, Xloc, &x));
  if (dmdats->rhsfunctionlocalimode == INSERT_VALUES) {
    PetscCall(DMDAVecGetArray(dm, F, &f));
    CHKMEMQ;
    PetscCall((*dmdats->rhsfunctionlocal)(&info, ptime, x, f, dmdats->rhsfunctionlocalctx));
    CHKMEMQ;
    PetscCall(DMDAVecRestoreArray(dm, F, &f));
  } else { /* ADD_VALUES */
    Vec Floc;

    PetscCall(DMGetLocalVector(dm, &Floc));
    PetscCall(VecZeroEntries(Floc));
    PetscCall(DMDAVecGetArray(dm, Floc, &f));
    CHKMEMQ;
    PetscCall((*dmdats->rhsfunctionlocal)(&info, ptime, x, f, dmdats->rhsfunctionlocalctx));
    CHKMEMQ;
    PetscCall(DMDAVecRestoreArray(dm, Floc, &f));
    PetscCall(VecZeroEntries(F));
    PetscCall(DMLocalToGlobalBegin(dm, Floc, ADD_VALUES, F));
    PetscCall(DMLocalToGlobalEnd(dm, Floc, ADD_VALUES, F));
    PetscCall(DMRestoreLocalVector(dm, &Floc));
  }
  PetscCall(DMDAVecRestoreArray(dm, Xloc, &x));
  PetscCall(DMRestoreLocalVector(dm, &Xloc));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSComputeRHSJacobian_DMDA - RHSJacobian callback installed by DMDATSSetRHSJacobianLocal()

  Scatters X into a ghosted local vector, exposes it as a DMDA-indexed array and calls the
  user's rhsjacobianlocal with the matrices A and B. If A and B differ, A is given a final
  assembly afterwards in case the user only assembled B.

  ctx is the DMTS_DA registered by DMDATSSetRHSJacobianLocal(). Only the Jacobian callback is
  required; the RHS function may have been provided by some other means.
*/
static PetscErrorCode TSComputeRHSJacobian_DMDA(TS ts, PetscReal ptime, Vec X, Mat A, Mat B, PetscCtx ctx)
{
  DM            dm;
  DMTS_DA      *dmdats = (DMTS_DA *)ctx;
  DMDALocalInfo info;
  Vec           Xloc;
  void         *x;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(X, VEC_CLASSID, 3);
  /* Only rhsjacobianlocal is used below; do not also require rhsfunctionlocal to be set */
  PetscCheck(dmdats->rhsjacobianlocal, PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "TSComputeRHSJacobian_DMDA() called without calling DMDATSSetRHSJacobianLocal()");
  PetscCall(TSGetDM(ts, &dm));

  PetscCall(DMGetLocalVector(dm, &Xloc));
  PetscCall(DMGlobalToLocalBegin(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMGlobalToLocalEnd(dm, X, INSERT_VALUES, Xloc));
  PetscCall(DMDAGetLocalInfo(dm, &info));
  PetscCall(DMDAVecGetArray(dm, Xloc, &x));
  CHKMEMQ;
  PetscCall((*dmdats->rhsjacobianlocal)(&info, ptime, x, A, B, dmdats->rhsjacobianlocalctx));
  CHKMEMQ;
  PetscCall(DMDAVecRestoreArray(dm, Xloc, &x));
  PetscCall(DMRestoreLocalVector(dm, &Xloc));
  /* This will be redundant if the user called both, but it's too common to forget. */
  if (A != B) {
    PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
    PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  DMDATSSetRHSFunctionLocal - set a local residual evaluation function for use with `DMDA`

  Logically Collective

  Input Parameters:
+ dm    - `DM` to associate callback with
. imode - insert mode for the residual
. func  - local residual evaluation, see `DMDATSRHSFunctionLocalFn` for the calling sequence
- ctx   - optional context for local residual evaluation

  Level: beginner

  Note:
  With `INSERT_VALUES` the callback writes directly into the locally owned entries of the right-hand side vector.
  With `ADD_VALUES` it accumulates into a zeroed ghosted local array, and those contributions are then summed into
  the right-hand side vector. Any other `imode` produces an error when the right-hand side is evaluated.

.seealso: [](ch_ts), `DMDA`, `DMDATSRHSFunctionLocalFn`, `TS`, `TSSetRHSFunction()`, `DMTSSetRHSFunction()`, `DMDATSSetRHSJacobianLocal()`, `DMDASNESSetFunctionLocal()`
@*/
PetscErrorCode DMDATSSetRHSFunctionLocal(DM dm, InsertMode imode, DMDATSRHSFunctionLocalFn *func, PetscCtx ctx)
{
  DMTS     tsdm;
  DMTS_DA *da;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(dm, DM_CLASSID, 1);
  PetscCall(DMGetDMTSWrite(dm, &tsdm));
  PetscCall(DMDATSGetContext(dm, tsdm, &da));
  da->rhsfunctionlocal      = func;
  da->rhsfunctionlocalctx   = ctx;
  da->rhsfunctionlocalimode = imode;
  /* route TS right-hand-side evaluations through the DMDA wrapper, which reads the fields set above */
  PetscCall(DMTSSetRHSFunction(dm, TSComputeRHSFunction_DMDA, da));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  DMDATSSetRHSJacobianLocal - set a local right-hand side Jacobian evaluation function for use with `DMDA`

  Logically Collective

  Input Parameters:
+ dm   - `DM` to associate callback with
. func - local RHS Jacobian evaluation routine, see `DMDATSRHSJacobianLocalFn` for the calling sequence
- ctx  - optional context for local Jacobian evaluation

  Level: beginner

  Note:
  The routine receives the ghosted local array of the current state. If the Jacobian matrix and the matrix used to
  construct the preconditioner differ, the former is given a final assembly after the routine returns.

.seealso: [](ch_ts), `DMDA`, `DMDATSRHSJacobianLocalFn`, `DMTSSetRHSJacobian()`,
`DMDATSSetRHSFunctionLocal()`, `DMDASNESSetJacobianLocal()`
@*/
PetscErrorCode DMDATSSetRHSJacobianLocal(DM dm, DMDATSRHSJacobianLocalFn *func, PetscCtx ctx)
{
  DMTS     sdm;
  DMTS_DA *dmdats;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(dm, DM_CLASSID, 1);
  PetscCall(DMGetDMTSWrite(dm, &sdm));
  PetscCall(DMDATSGetContext(dm, sdm, &dmdats));
  dmdats->rhsjacobianlocal    = func;
  dmdats->rhsjacobianlocalctx = ctx;
  PetscCall(DMTSSetRHSJacobian(dm, TSComputeRHSJacobian_DMDA, dmdats));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  DMDATSSetIFunctionLocal - set a local residual evaluation function for use with `DMDA`

  Logically Collective

  Input Parameters:
+ dm    - `DM` to associate callback with
. imode - the insert mode of the function
. func  - local residual evaluation, see `DMDATSIFunctionLocalFn` for the calling sequence
- ctx   - optional context for local residual evaluation

  Level: beginner

  Note:
  With `INSERT_VALUES` the callback writes directly into the locally owned entries of the residual vector.
  With `ADD_VALUES` it accumulates into a zeroed ghosted local array, and those contributions are then summed into
  the residual vector. Any other `imode` produces an error when the residual is evaluated.

.seealso: [](ch_ts), `DMDA`, `DMDATSIFunctionLocalFn`, `DMTSSetIFunction()`,
`DMDATSSetIJacobianLocal()`, `DMDASNESSetFunctionLocal()`
@*/
PetscErrorCode DMDATSSetIFunctionLocal(DM dm, InsertMode imode, DMDATSIFunctionLocalFn *func, PetscCtx ctx)
{
  DMTS     tsdm;
  DMTS_DA *da;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(dm, DM_CLASSID, 1);
  PetscCall(DMGetDMTSWrite(dm, &tsdm));
  PetscCall(DMDATSGetContext(dm, tsdm, &da));
  da->ifunctionlocal      = func;
  da->ifunctionlocalctx   = ctx;
  da->ifunctionlocalimode = imode;
  /* route TS implicit-function evaluations through the DMDA wrapper, which reads the fields set above */
  PetscCall(DMTSSetIFunction(dm, TSComputeIFunction_DMDA, da));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  DMDATSSetIJacobianLocal - set a local implicit Jacobian evaluation function for use with `DMDA`

  Logically Collective

  Input Parameters:
+ dm   - `DM` to associate callback with
. func - local implicit Jacobian evaluation, see `DMDATSIJacobianLocalFn` for the calling sequence
- ctx  - optional context for local Jacobian evaluation

  Level: beginner

  Note:
  The routine receives the ghosted local arrays of the current state and its time derivative together with the shift.
  If the Jacobian matrix and the matrix used to construct the preconditioner differ, the former is given a final
  assembly after the routine returns.

.seealso: [](ch_ts), `DMDA`, `DMDATSIJacobianLocalFn`, `DMTSSetIJacobian()`,
`DMDATSSetIFunctionLocal()`, `DMDASNESSetJacobianLocal()`
@*/
PetscErrorCode DMDATSSetIJacobianLocal(DM dm, DMDATSIJacobianLocalFn *func, PetscCtx ctx)
{
  DMTS     sdm;
  DMTS_DA *dmdats;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(dm, DM_CLASSID, 1);
  PetscCall(DMGetDMTSWrite(dm, &sdm));
  PetscCall(DMDATSGetContext(dm, sdm, &dmdats));
  dmdats->ijacobianlocal    = func;
  dmdats->ijacobianlocalctx = ctx;
  PetscCall(DMTSSetIJacobian(dm, TSComputeIJacobian_DMDA, dmdats));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSMonitorDMDARayDestroy - free a TSMonitorDMDARayCtx and the objects it owns
  (line-graph context if present, ray vector, scatter and viewer).

  mctx is the address of the TSMonitorDMDARayCtx pointer.
*/
PetscErrorCode TSMonitorDMDARayDestroy(PetscCtxRt mctx)
{
  TSMonitorDMDARayCtx **handle = (TSMonitorDMDARayCtx **)mctx;
  TSMonitorDMDARayCtx  *ray    = *handle;

  PetscFunctionBegin;
  if (ray->lgctx) PetscCall(TSMonitorLGCtxDestroy(&ray->lgctx));
  PetscCall(VecDestroy(&ray->ray));
  PetscCall(VecScatterDestroy(&ray->scatter));
  PetscCall(PetscViewerDestroy(&ray->viewer));
  PetscCall(PetscFree(ray));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSMonitorDMDARay - TS monitor that scatters the monitored state u into the context's ray
  vector and, when a viewer is attached, views that ray.

  The vector handed to the monitor is used (rather than TSGetSolution()), so the ray matches
  the reported time even when TS monitors a vector other than its internal solution, for
  example an interpolated final state. This matches TSMonitorLGDMDARay().

  mctx is a TSMonitorDMDARayCtx.
*/
PetscErrorCode TSMonitorDMDARay(TS ts, PetscInt steps, PetscReal time, Vec u, void *mctx)
{
  TSMonitorDMDARayCtx *rayctx = (TSMonitorDMDARayCtx *)mctx;

  PetscFunctionBegin;
  PetscCall(VecScatterBegin(rayctx->scatter, u, rayctx->ray, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall(VecScatterEnd(rayctx->scatter, u, rayctx->ray, INSERT_VALUES, SCATTER_FORWARD));
  if (rayctx->viewer) PetscCall(VecView(rayctx->ray, rayctx->viewer));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSMonitorLGDMDARay - TS monitor that plots the ray of the state u as line-graph curves versus time.

  On step 0 the axis labels are set and the graph is reset with one curve per local ray entry.
  Every call appends the ray values (real parts in complex builds) at ptime. The graph is drawn
  and saved when lgctx->howoften > 0 and step is a multiple of it, or when howoften == -1 and
  the TS has a converged reason.

  ctx is a TSMonitorDMDARayCtx whose lgctx is already created.
*/
PetscErrorCode TSMonitorLGDMDARay(TS ts, PetscInt step, PetscReal ptime, Vec u, PetscCtx ctx)
{
  TSMonitorDMDARayCtx *rayctx = (TSMonitorDMDARayCtx *)ctx;
  TSMonitorLGCtx       lg     = rayctx->lgctx;
  Vec                  ray    = rayctx->ray;
  const PetscScalar   *vals;
  PetscBool            draw;

  PetscFunctionBegin;
  PetscCall(VecScatterBegin(rayctx->scatter, u, ray, INSERT_VALUES, SCATTER_FORWARD));
  PetscCall(VecScatterEnd(rayctx->scatter, u, ray, INSERT_VALUES, SCATTER_FORWARD));
  if (step == 0) {
    PetscDrawAxis axis;
    PetscInt      ncurves;

    PetscCall(PetscDrawLGGetAxis(lg->lg, &axis));
    PetscCall(PetscDrawAxisSetLabels(axis, "Solution Ray as function of time", "Time", "Solution"));
    PetscCall(VecGetLocalSize(ray, &ncurves));
    PetscCall(PetscDrawLGSetDimension(lg->lg, ncurves));
    PetscCall(PetscDrawLGReset(lg->lg));
  }
  PetscCall(VecGetArrayRead(ray, &vals));
#if defined(PETSC_USE_COMPLEX)
  {
    /* the line graph takes real values, so plot only the real parts */
    PetscReal *re;
    PetscInt   n;

    PetscCall(VecGetLocalSize(ray, &n));
    PetscCall(PetscMalloc1(n, &re));
    for (PetscInt k = 0; k < n; k++) re[k] = PetscRealPart(vals[k]);
    PetscCall(PetscDrawLGAddCommonPoint(lg->lg, ptime, re));
    PetscCall(PetscFree(re));
  }
#else
  PetscCall(PetscDrawLGAddCommonPoint(lg->lg, ptime, vals));
#endif
  PetscCall(VecRestoreArrayRead(ray, &vals));
  draw = (PetscBool)((lg->howoften > 0 && step % lg->howoften == 0) || (lg->howoften == -1 && ts->reason));
  if (draw) {
    PetscCall(PetscDrawLGDraw(lg->lg));
    PetscCall(PetscDrawLGSave(lg->lg));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
