/*
  Code for timestepping with discrete gradient integrators
*/
#include <petsc/private/tsimpl.h> /*I   "petscts.h"   I*/
#include <petscdm.h>

/* Flag handed to PetscCitationsRegister() together with DGCitation in TSCreate_DiscGrad() */
PetscBool  DGCite       = PETSC_FALSE;
/*
  BibTeX entry registered when a TSDISCGRAD is created.
  NOTE(review): the DOI below contains an ISBN-style "978-..." component, which looks like a
  book-chapter DOI rather than the DOI of the journal article named in the entry - TODO confirm.
*/
const char DGCitation[] = "@article{Gonzalez1996,\n"
                          "  title   = {Time integration and discrete Hamiltonian systems},\n"
                          "  author  = {Oscar Gonzalez},\n"
                          "  journal = {Journal of Nonlinear Science},\n"
                          "  volume  = {6},\n"
                          "  pages   = {449--467},\n"
                          "  doi     = {10.1007/978-1-4612-1246-1_10},\n"
                          "  year    = {1996}\n}\n";

/*
  String table passed to PetscOptionsEnum() for -ts_discgrad_type: the three value names, then
  the enum type name, the enumerator prefix, and a NULL terminator.
  NOTE(review): the value names presumably must appear in the same order as the TSDGType
  enumerators - verify against the enum declaration.
*/
const char *DGTypes[] = {"gonzalez", "average", "none", "TSDGType", "DG_", NULL};

/* Private data of the discrete gradient integrator, stored in ts->data */
typedef struct {
  PetscReal stage_time; /* time of the stage solve, ptime + dt/2 (set in TSStep_DiscGrad()) */
  Vec       X0, X, Xdot; /* copy of the solution at the start of the step, stage solution, time-derivative estimate */
  void     *funcCtx; /* user context forwarded to Sfunc, Ffunc and Gfunc */
  TSDGType  discgrad; /* which discrete gradient variant is applied (gonzalez, average or none) */
  PetscErrorCode (*Sfunc)(TS, PetscReal, Vec, Mat, void *); /* user callback filling the matrix S */
  PetscErrorCode (*Ffunc)(TS, PetscReal, Vec, PetscScalar *, void *); /* user callback evaluating the functional F */
  PetscErrorCode (*Gfunc)(TS, PetscReal, Vec, Vec, void *); /* user callback evaluating the gradient of F */
} TS_DiscGrad;

/*
  TSDiscGradGetX0AndXdot - Obtain the step-start state and the time-derivative vector for a DM

  If dm is non-NULL and is not ts->dm, the vectors are checked out of dm as the named global
  vectors "TSDiscGrad_X0" and "TSDiscGrad_Xdot". Otherwise ts->vec_sol and the integrator's
  own Xdot are returned. Either output pointer may be NULL, in which case it is skipped.
  Every call should be matched by TSDiscGradRestoreX0AndXdot() with the same arguments.
*/
static PetscErrorCode TSDiscGradGetX0AndXdot(TS ts, DM dm, Vec *X0, Vec *Xdot)
{
  const PetscBool useNamed = (dm && dm != ts->dm) ? PETSC_TRUE : PETSC_FALSE;

  PetscFunctionBegin;
  if (X0) {
    if (useNamed) {
      PetscCall(DMGetNamedGlobalVector(dm, "TSDiscGrad_X0", X0));
    } else {
      *X0 = ts->vec_sol;
    }
  }
  if (Xdot) {
    if (useNamed) {
      PetscCall(DMGetNamedGlobalVector(dm, "TSDiscGrad_Xdot", Xdot));
    } else {
      *Xdot = ((TS_DiscGrad *)ts->data)->Xdot;
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSDiscGradRestoreX0AndXdot - Counterpart of TSDiscGradGetX0AndXdot()

  Only vectors that were checked out of a DM other than ts->dm are handed back; for the
  integrator's own vectors nothing is done. NULL output pointers are skipped.
*/
static PetscErrorCode TSDiscGradRestoreX0AndXdot(TS ts, DM dm, Vec *X0, Vec *Xdot)
{
  PetscFunctionBegin;
  if (dm && dm != ts->dm) {
    if (X0) PetscCall(DMRestoreNamedGlobalVector(dm, "TSDiscGrad_X0", X0));
    if (Xdot) PetscCall(DMRestoreNamedGlobalVector(dm, "TSDiscGrad_Xdot", Xdot));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Coarsen hook registered in TSSetUp_DiscGrad(); intentionally does nothing, all work happens in the restrict hook */
static PetscErrorCode DMCoarsenHook_TSDiscGrad(DM fine, DM coarse, PetscCtx ctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  DMRestrictHook_TSDiscGrad - Transfer the step-start state and time derivative to a coarse DM

  Both vectors are restricted with restrct and then scaled pointwise by rscale. ctx is the TS
  that registered the hook; inject is not used.
*/
static PetscErrorCode DMRestrictHook_TSDiscGrad(DM fine, Mat restrct, Vec rscale, Mat inject, DM coarse, PetscCtx ctx)
{
  TS  ts = (TS)ctx;
  Vec fineX0, fineXdot, coarseX0, coarseXdot;

  PetscFunctionBegin;
  PetscCall(TSDiscGradGetX0AndXdot(ts, fine, &fineX0, &fineXdot));
  PetscCall(TSDiscGradGetX0AndXdot(ts, coarse, &coarseX0, &coarseXdot));

  /* step-start state */
  PetscCall(MatRestrict(restrct, fineX0, coarseX0));
  PetscCall(VecPointwiseMult(coarseX0, rscale, coarseX0));

  /* time derivative */
  PetscCall(MatRestrict(restrct, fineXdot, coarseXdot));
  PetscCall(VecPointwiseMult(coarseXdot, rscale, coarseXdot));

  PetscCall(TSDiscGradRestoreX0AndXdot(ts, fine, &fineX0, &fineXdot));
  PetscCall(TSDiscGradRestoreX0AndXdot(ts, coarse, &coarseX0, &coarseXdot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Subdomain hook registered in TSSetUp_DiscGrad(); intentionally does nothing, all work happens in the subdomain restrict hook */
static PetscErrorCode DMSubDomainHook_TSDiscGrad(DM dm, DM subdm, PetscCtx ctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  DMSubDomainRestrictHook_TSDiscGrad - Scatter the step-start state and time derivative to a subdomain DM

  Each vector is moved with a forward INSERT_VALUES scatter through gscat, one after the other.
  ctx is the TS that registered the hook; lscat is not used.
*/
static PetscErrorCode DMSubDomainRestrictHook_TSDiscGrad(DM dm, VecScatter gscat, VecScatter lscat, DM subdm, PetscCtx ctx)
{
  TS       ts = (TS)ctx;
  Vec      from[2], to[2]; /* index 0: X0, index 1: Xdot */
  PetscInt v;

  PetscFunctionBegin;
  PetscCall(TSDiscGradGetX0AndXdot(ts, dm, &from[0], &from[1]));
  PetscCall(TSDiscGradGetX0AndXdot(ts, subdm, &to[0], &to[1]));

  for (v = 0; v < 2; ++v) {
    PetscCall(VecScatterBegin(gscat, from[v], to[v], INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(gscat, from[v], to[v], INSERT_VALUES, SCATTER_FORWARD));
  }

  PetscCall(TSDiscGradRestoreX0AndXdot(ts, dm, &from[0], &from[1]));
  PetscCall(TSDiscGradRestoreX0AndXdot(ts, subdm, &to[0], &to[1]));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSSetUp_DiscGrad - Allocate the work vectors and register the DM transfer hooks

  X, X0 and Xdot are duplicated from ts->vec_sol if they do not exist yet. The coarsen and
  subdomain hooks are added to the DM of the TS with the TS itself as their context.
*/
static PetscErrorCode TSSetUp_DiscGrad(TS ts)
{
  TS_DiscGrad *dg      = (TS_DiscGrad *)ts->data;
  Vec         *work[3] = {&dg->X, &dg->X0, &dg->Xdot};
  DM           tsdm;
  PetscInt     w;

  PetscFunctionBegin;
  for (w = 0; w < 3; ++w) {
    if (!*work[w]) PetscCall(VecDuplicate(ts->vec_sol, work[w]));
  }

  PetscCall(TSGetDM(ts, &tsdm));
  PetscCall(DMCoarsenHookAdd(tsdm, DMCoarsenHook_TSDiscGrad, DMRestrictHook_TSDiscGrad, ts));
  PetscCall(DMSubDomainHookAdd(tsdm, DMSubDomainHook_TSDiscGrad, DMSubDomainRestrictHook_TSDiscGrad, ts));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSSetFromOptions_DiscGrad - Process the options of the discrete gradient integrator

  Reads -ts_discgrad_type; the currently stored type is offered as the default.
*/
static PetscErrorCode TSSetFromOptions_DiscGrad(TS ts, PetscOptionItems PetscOptionsObject)
{
  TS_DiscGrad *dg     = (TS_DiscGrad *)ts->data;
  PetscEnum    choice = (PetscEnum)dg->discgrad;

  PetscFunctionBegin;
  PetscOptionsHeadBegin(PetscOptionsObject, "Discrete Gradients ODE solver options");
  PetscCall(PetscOptionsEnum("-ts_discgrad_type", "Type of discrete gradient solver", "TSDiscGradSetDGType", DGTypes, choice, &choice, NULL));
  /* choice still holds the previous value if the option was not given */
  dg->discgrad = (TSDGType)choice;
  PetscOptionsHeadEnd();
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* TSView_DiscGrad - Print a one-line description of the integrator to ASCII viewers; other viewer types are ignored */
static PetscErrorCode TSView_DiscGrad(TS ts, PetscViewer viewer)
{
  PetscBool ascii;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &ascii));
  if (!ascii) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscViewerASCIIPrintf(viewer, "  Discrete Gradients\n"));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Implementation behind TSDiscGradGetType(): report the stored discrete gradient type */
static PetscErrorCode TSDiscGradGetType_DiscGrad(TS ts, TSDGType *dgtype)
{
  PetscFunctionBegin;
  *dgtype = ((TS_DiscGrad *)ts->data)->discgrad;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Implementation behind TSDiscGradSetType(): store the requested discrete gradient type */
static PetscErrorCode TSDiscGradSetType_DiscGrad(TS ts, TSDGType dgtype)
{
  PetscFunctionBegin;
  ((TS_DiscGrad *)ts->data)->discgrad = dgtype;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* TSReset_DiscGrad - Destroy the work vectors X, X0 and Xdot; the private data structure itself is kept */
static PetscErrorCode TSReset_DiscGrad(TS ts)
{
  TS_DiscGrad *dg       = (TS_DiscGrad *)ts->data;
  Vec         *owned[3] = {&dg->X, &dg->X0, &dg->Xdot};
  PetscInt     w;

  PetscFunctionBegin;
  for (w = 0; w < 3; ++w) PetscCall(VecDestroy(owned[w]));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSDestroy_DiscGrad - Tear down everything TSCreate_DiscGrad() and TSSetUp_DiscGrad() installed

  Resets the work vectors, removes the DM hooks (if the TS has a DM), frees the private data
  and un-composes the type-specific methods.
*/
static PetscErrorCode TSDestroy_DiscGrad(TS ts)
{
  static const char *const composed[] = {"TSDiscGradGetFormulation_C", "TSDiscGradSetFormulation_C", "TSDiscGradGetType_C", "TSDiscGradSetType_C"};
  const PetscInt           ncomposed  = (PetscInt)(sizeof(composed) / sizeof(composed[0]));
  DM                       tsdm;
  PetscInt                 k;

  PetscFunctionBegin;
  PetscCall(TSReset_DiscGrad(ts));

  PetscCall(TSGetDM(ts, &tsdm));
  if (tsdm) {
    PetscCall(DMCoarsenHookRemove(tsdm, DMCoarsenHook_TSDiscGrad, DMRestrictHook_TSDiscGrad, ts));
    PetscCall(DMSubDomainHookRemove(tsdm, DMSubDomainHook_TSDiscGrad, DMSubDomainRestrictHook_TSDiscGrad, ts));
  }

  PetscCall(PetscFree(ts->data));
  for (k = 0; k < ncomposed; ++k) PetscCall(PetscObjectComposeFunction((PetscObject)ts, composed[k], NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSInterpolate_DiscGrad - Linear interpolation of the solution at time t

  Computes X = ts->vec_sol + (t - ts->ptime) * Xdot. The current solution is first copied into
  the stage vector dg->X, which is therefore overwritten.
*/
static PetscErrorCode TSInterpolate_DiscGrad(TS ts, PetscReal t, Vec X)
{
  TS_DiscGrad    *dg     = (TS_DiscGrad *)ts->data;
  const PetscReal offset = t - ts->ptime;
  Vec             base   = dg->X;

  PetscFunctionBegin;
  PetscCall(VecCopy(ts->vec_sol, base));
  PetscCall(VecWAXPY(X, offset, dg->Xdot, base));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* TSDiscGrad_SNESSolve - Run the nonlinear solver of the TS and add its iteration counts to the TS totals */
static PetscErrorCode TSDiscGrad_SNESSolve(TS ts, Vec b, Vec x)
{
  SNES     solver;
  PetscInt nonlinear_its, linear_its;

  PetscFunctionBegin;
  PetscCall(TSGetSNES(ts, &solver));
  PetscCall(SNESSolve(solver, b, x));

  PetscCall(SNESGetIterationNumber(solver, &nonlinear_its));
  ts->snes_its += nonlinear_its;

  PetscCall(SNESGetLinearSolveIterations(solver, &linear_its));
  ts->ksp_its += linear_its;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSStep_DiscGrad - Advance the solution by one time step

  Solves for the stage value at ptime + dt/2 with SNES, forms Xdot = (X - X0)/(dt/2) and updates
  ts->vec_sol by dt*Xdot. The step is retried with the time step proposed by the adaptor until
  it is accepted, ts->reason is set, or more than ts->max_reject rejections have occurred.
*/
static PetscErrorCode TSStep_DiscGrad(TS ts)
{
  TS_DiscGrad *dg = (TS_DiscGrad *)ts->data;
  TSAdapt      adapt;
  TSStepStatus status     = TS_STEP_INCOMPLETE;
  PetscInt     rejections = 0;
  PetscBool    stageok, accept = PETSC_TRUE;
  PetscReal    next_time_step = ts->time_step;

  PetscFunctionBegin;
  PetscCall(TSGetAdapt(ts, &adapt));
  /* Remember the state at the start of the step, unless this call follows a step rollback */
  if (!ts->steprollback) PetscCall(VecCopy(ts->vec_sol, dg->X0));

  while (!ts->reason && status != TS_STEP_COMPLETE) {
    /* Recomputed on every attempt because ts->time_step changes after a rejection */
    PetscReal shift = 1 / (0.5 * ts->time_step);

    dg->stage_time = ts->ptime + 0.5 * ts->time_step;

    /* Use the step-start state as the initial guess for the stage solve */
    PetscCall(VecCopy(dg->X0, dg->X));
    PetscCall(TSPreStage(ts, dg->stage_time));
    PetscCall(TSDiscGrad_SNESSolve(ts, NULL, dg->X));
    PetscCall(TSPostStage(ts, dg->stage_time, 0, &dg->X));
    PetscCall(TSAdaptCheckStage(adapt, ts, dg->stage_time, dg->X, &stageok));
    if (!stageok) goto reject_step;

    status = TS_STEP_PENDING;
    /* Xdot = shift * (X - X0) */
    PetscCall(VecAXPBYPCZ(dg->Xdot, -shift, shift, 0, dg->X0, dg->X));
    PetscCall(VecAXPY(ts->vec_sol, ts->time_step, dg->Xdot));
    PetscCall(TSAdaptChoose(adapt, ts, ts->time_step, NULL, &next_time_step, &accept));
    status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
    if (!accept) {
      /* Undo the update of the solution and retry with the proposed time step */
      PetscCall(VecCopy(dg->X0, ts->vec_sol));
      ts->time_step = next_time_step;
      goto reject_step;
    }
    ts->ptime += ts->time_step;
    ts->time_step = next_time_step;
    break;

  reject_step:
    ts->reject++;
    accept = PETSC_FALSE;
    /* A negative ts->max_reject disables the limit on the number of rejections */
    if (!ts->reason && ts->max_reject >= 0 && ++rejections > ts->max_reject) {
      ts->reason = TS_DIVERGED_STEP_REJECTED;
      PetscCall(PetscInfo(ts, "Step=%" PetscInt_FMT ", step rejections %" PetscInt_FMT " greater than current TS allowed, stopping solve\n", ts->steps, rejections));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* TSGetStages_DiscGrad - Report the single stage vector (the stage solution X); either output may be NULL */
static PetscErrorCode TSGetStages_DiscGrad(TS ts, PetscInt *ns, Vec **Y)
{
  TS_DiscGrad *impl = (TS_DiscGrad *)ts->data;

  PetscFunctionBegin;
  if (Y) *Y = &impl->X;
  if (ns) *ns = 1;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  SNESTSFormFunction_DiscGrad - Residual of the nonlinear stage equation solved with SNES

  The unknown x is the stage value, treated as the midpoint x = (x_n + x_{n+1})/2, so the
  end-of-step state is recovered as x_{n+1} = 2 x - x_n. With shift = 1/(dt/2) the residual is

    y = shift (x - x_n) - S(x) Gbar

  where Gbar is the discrete gradient selected by dg->discgrad:
    none     - Gbar = grad F(x)
    gonzalez - Gbar = grad F(x) + [F(x_{n+1}) - F(x_n) - <grad F(x), d>] / |d|^2 * d,  d = x_{n+1} - x_n
    average  - Gbar = \int_0^1 grad F(xi x_n + (1 - xi) x_{n+1}) d xi, using 2-point Gauss quadrature

  Input Parameters:
+ snes - the nonlinear solver; its DM selects which X0/Xdot vectors are used
. x    - current iterate of the stage value
- ts   - the integrator

  Output Parameter:
. y - the residual

  Notes:
  The matrix S and all work vectors are created and destroyed on every call.
  Xdot is stored as a side effect and is later used by SNESTSFormJacobian_DiscGrad().
*/
static PetscErrorCode SNESTSFormFunction_DiscGrad(SNES snes, Vec x, Vec y, TS ts)
{
  TS_DiscGrad *dg = (TS_DiscGrad *)ts->data;
  PetscReal    norm, shift = 1 / (0.5 * ts->time_step);
  PetscInt     n;
  PetscInt    *S_prealloc_arr;
  Vec          X0, Xdot, Xp, Xdiff;
  Mat          S;
  PetscScalar  F = 0, F0 = 0, Gp;
  Vec          G, SgF;
  DM           dm;

  PetscFunctionBegin;
  /* Validate the configuration before anything is allocated so an error cannot leak work objects */
  PetscCheck(dg->Sfunc && dg->Gfunc, PetscObjectComm((PetscObject)ts), PETSC_ERR_ORDER, "Must call TSDiscGradSetFormulation() before solving with TSDISCGRAD");
  PetscCheck(dg->discgrad != TS_DG_GONZALEZ || dg->Ffunc, PetscObjectComm((PetscObject)ts), PETSC_ERR_ORDER, "The Gonzalez discrete gradient requires the functional F; call TSDiscGradSetFormulation()");
  PetscCheck(dg->discgrad == TS_DG_AVERAGE || dg->discgrad == TS_DG_GONZALEZ || dg->discgrad == TS_DG_NONE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "DG type not supported.");

  PetscCall(SNESGetDM(snes, &dm));

  PetscCall(VecDuplicate(y, &Xp));
  PetscCall(VecDuplicate(y, &Xdiff));
  PetscCall(VecDuplicate(y, &SgF));
  PetscCall(VecDuplicate(y, &G));

  PetscCall(PetscObjectSetName((PetscObject)x, "x"));
  PetscCall(VecViewFromOptions(x, NULL, "-x_view"));

  /* S must live on the same communicator as the vectors it is applied to, not on PETSC_COMM_WORLD */
  PetscCall(VecGetLocalSize(y, &n));
  PetscCall(MatCreate(PetscObjectComm((PetscObject)y), &S));
  PetscCall(MatSetSizes(S, n, n, PETSC_DETERMINE, PETSC_DETERMINE));
  PetscCall(MatSetFromOptions(S));
  PetscCall(PetscMalloc1(n, &S_prealloc_arr));
  for (PetscInt i = 0; i < n; ++i) S_prealloc_arr[i] = 2;
  PetscCall(MatXAIJSetPreallocation(S, 1, S_prealloc_arr, NULL, NULL, NULL));
  PetscCall(PetscFree(S_prealloc_arr));
  PetscCall(MatSetUp(S));
  PetscCall((*dg->Sfunc)(ts, dg->stage_time, x, S, dg->funcCtx));
  PetscCall(PetscObjectSetName((PetscObject)S, "S"));
  PetscCall(MatViewFromOptions(S, NULL, "-S_view"));

  PetscCall(TSDiscGradGetX0AndXdot(ts, dm, &X0, &Xdot));
  PetscCall(VecAXPBYPCZ(Xdot, -shift, shift, 0, X0, x)); /* Xdot = shift (x - X0) */

  PetscCall(VecAXPBYPCZ(Xp, -1, 2, 0, X0, x));     /* Xp = 2 x - X0, i.e. x_{n+1} */
  PetscCall(VecAXPBYPCZ(Xdiff, -1, 1, 0, X0, Xp)); /* Xdiff = Xp - X0 */

  PetscCall(PetscObjectSetName((PetscObject)X0, "X0"));
  PetscCall(PetscObjectSetName((PetscObject)Xp, "Xp"));
  PetscCall(VecViewFromOptions(X0, NULL, "-X0_view"));
  PetscCall(VecViewFromOptions(Xp, NULL, "-Xp_view"));

  if (dg->discgrad == TS_DG_AVERAGE) {
    /* Average value DG:
       \overline{\nabla} F (x_{n+1},x_{n}) = \int_0^1 \nabla F (\xi x_{n} + (1-\xi) x_{n+1}) d \xi */
    PetscQuadrature  quad;
    PetscInt         Nq;
    const PetscReal *wq, *xq;
    Vec              Xquad, den;

    PetscCall(PetscObjectSetName((PetscObject)G, "G"));
    PetscCall(VecDuplicate(G, &Xquad));
    PetscCall(VecDuplicate(G, &den));
    PetscCall(VecZeroEntries(G));

    /* The integral is over the scalar parameter \xi, so the rule must be one-dimensional regardless of
       the dimension of the DM; a tensor rule of higher dimension stores several coordinates per point
       and could not be indexed as xq[q] */
    PetscCall(PetscDTGaussTensorQuadrature(1, 1, 2, 0.0, 1.0, &quad));
    PetscCall(PetscQuadratureGetData(quad, NULL, NULL, &Nq, &xq, &wq));
    for (PetscInt q = 0; q < Nq; ++q) {
      PetscReal xi = xq[q], xim1 = 1 - xq[q];
      PetscCall(VecZeroEntries(Xquad));
      PetscCall(VecAXPBYPCZ(Xquad, xi, xim1, 1.0, X0, Xp)); /* Xquad = xi X0 + (1 - xi) Xp */
      PetscCall((*dg->Gfunc)(ts, dg->stage_time, Xquad, den, dg->funcCtx));
      PetscCall(VecAXPY(G, wq[q], den));
      PetscCall(PetscObjectSetName((PetscObject)den, "den"));
      PetscCall(VecViewFromOptions(den, NULL, "-den_view"));
    }
    PetscCall(VecDestroy(&Xquad));
    PetscCall(VecDestroy(&den));
    PetscCall(PetscQuadratureDestroy(&quad));
  } else if (dg->discgrad == TS_DG_GONZALEZ) {
    PetscCall((*dg->Ffunc)(ts, dg->stage_time, Xp, &F, dg->funcCtx));
    PetscCall((*dg->Ffunc)(ts, dg->stage_time, X0, &F0, dg->funcCtx));
    PetscCall((*dg->Gfunc)(ts, dg->stage_time, x, G, dg->funcCtx));

    /* Adding extra Gonzalez term */
    PetscCall(VecDot(Xdiff, G, &Gp));
    PetscCall(VecNorm(Xdiff, NORM_2, &norm));
    if (norm < PETSC_SQRT_MACHINE_EPSILON) {
      /* x_{n+1} and x_n coincide to working precision: skip the correction to avoid dividing by ~0 */
      Gp = 0;
    } else {
      /* Gp = (1/|xn+1 - xn|^2) * (F(xn+1) - F(xn) - Gp) */
      Gp = (F - F0 - Gp) / PetscSqr(norm);
    }
    PetscCall(VecAXPY(G, Gp, Xdiff));
  } else {
    /* TS_DG_NONE: plain gradient at the stage value (the type was validated on entry) */
    PetscCall((*dg->Gfunc)(ts, dg->stage_time, x, G, dg->funcCtx));
  }
  PetscCall(MatMult(S, G, SgF)); /* SgF = S * G */

  PetscCall(PetscObjectSetName((PetscObject)G, "G"));
  PetscCall(VecViewFromOptions(G, NULL, "-G_view"));
  PetscCall(PetscObjectSetName((PetscObject)SgF, "SgF"));
  PetscCall(VecViewFromOptions(SgF, NULL, "-SgF_view"));

  PetscCall(VecAXPBYPCZ(y, 1, -1, 0, Xdot, SgF)); /* y = Xdot - S * G */

  PetscCall(TSDiscGradRestoreX0AndXdot(ts, dm, &X0, &Xdot));

  PetscCall(VecDestroy(&Xp));
  PetscCall(VecDestroy(&Xdiff));
  PetscCall(VecDestroy(&SgF));
  PetscCall(VecDestroy(&G));
  PetscCall(MatDestroy(&S));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  SNESTSFormJacobian_DiscGrad - Jacobian callback for the stage solve

  Delegates to TSComputeIJacobian() at the stage time with shift = 1/(dt/2), using the Xdot
  associated with the DM of the SNES. ts->dm is temporarily replaced by that DM for the
  duration of the call and put back afterwards.
*/
static PetscErrorCode SNESTSFormJacobian_DiscGrad(SNES snes, Vec x, Mat A, Mat B, TS ts)
{
  TS_DiscGrad    *dg         = (TS_DiscGrad *)ts->data;
  const PetscReal stageShift = 1 / (0.5 * ts->time_step);
  Vec             stageXdot;
  DM              snesDM, tsDM;

  PetscFunctionBegin;
  PetscCall(SNESGetDM(snes, &snesDM));
  /* No recomputation here: Xdot is taken as left behind by SNESTSFormFunction_DiscGrad() */
  PetscCall(TSDiscGradGetX0AndXdot(ts, snesDM, NULL, &stageXdot));

  tsDM   = ts->dm;
  ts->dm = snesDM;
  PetscCall(TSComputeIJacobian(ts, dg->stage_time, x, stageXdot, stageShift, A, B, PETSC_FALSE));
  ts->dm = tsDM;

  PetscCall(TSDiscGradRestoreX0AndXdot(ts, snesDM, NULL, &stageXdot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  Implementation behind TSDiscGradGetFormulation(): hand back the three stored callbacks.
  The ctx argument is accepted but neither read nor written here.
*/
static PetscErrorCode TSDiscGradGetFormulation_DiscGrad(TS ts, PetscErrorCode (**Sfunc)(TS, PetscReal, Vec, Mat, void *), PetscErrorCode (**Ffunc)(TS, PetscReal, Vec, PetscScalar *, void *), PetscErrorCode (**Gfunc)(TS, PetscReal, Vec, Vec, void *), PetscCtx ctx)
{
  const TS_DiscGrad *stored = (const TS_DiscGrad *)ts->data;

  PetscFunctionBegin;
  *Gfunc = stored->Gfunc;
  *Ffunc = stored->Ffunc;
  *Sfunc = stored->Sfunc;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Implementation behind TSDiscGradSetFormulation(): store the three callbacks and the user context passed to them */
static PetscErrorCode TSDiscGradSetFormulation_DiscGrad(TS ts, PetscErrorCode (*Sfunc)(TS, PetscReal, Vec, Mat, void *), PetscErrorCode (*Ffunc)(TS, PetscReal, Vec, PetscScalar *, void *), PetscErrorCode (*Gfunc)(TS, PetscReal, Vec, Vec, void *), PetscCtx ctx)
{
  TS_DiscGrad *target = (TS_DiscGrad *)ts->data;

  PetscFunctionBegin;
  target->funcCtx = ctx;
  target->Sfunc   = Sfunc;
  target->Ffunc   = Ffunc;
  target->Gfunc   = Gfunc;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*MC
  TSDISCGRAD - ODE solver using the discrete gradients version of the implicit midpoint method

  Level: intermediate

  Notes:
  This is the implicit midpoint rule, with an optional term that guarantees the discrete
  gradient property. This timestepper applies to systems of the form $u_t = S(u) \nabla F(u)$
  where $S(u)$ is a linear operator, and $F$ is a functional of $u$.

.seealso: [](ch_ts), `TSCreate()`, `TSSetType()`, `TS`, `TSDISCGRAD`, `TSDiscGradSetFormulation()`
M*/
/*
  TSCreate_DiscGrad - Constructor for the discrete gradient integrator

  Registers the citation, installs the method table, allocates the private data (with the
  discrete gradient type initialised to TS_DG_NONE) and composes the type-specific methods.
*/
PETSC_EXTERN PetscErrorCode TSCreate_DiscGrad(TS ts)
{
  TS_DiscGrad *dg;

  PetscFunctionBegin;
  PetscCall(PetscCitationsRegister(DGCitation, &DGCite));

  /* object life cycle */
  ts->ops->setup          = TSSetUp_DiscGrad;
  ts->ops->setfromoptions = TSSetFromOptions_DiscGrad;
  ts->ops->view           = TSView_DiscGrad;
  ts->ops->reset          = TSReset_DiscGrad;
  ts->ops->destroy        = TSDestroy_DiscGrad;

  /* time stepping */
  ts->ops->step        = TSStep_DiscGrad;
  ts->ops->interpolate = TSInterpolate_DiscGrad;
  ts->ops->getstages   = TSGetStages_DiscGrad;

  /* nonlinear solver callbacks */
  ts->ops->snesfunction = SNESTSFormFunction_DiscGrad;
  ts->ops->snesjacobian = SNESTSFormJacobian_DiscGrad;

  ts->default_adapt_type = TSADAPTNONE;
  ts->usessnes           = PETSC_TRUE;

  PetscCall(PetscNew(&dg));
  ts->data     = (void *)dg;
  dg->discgrad = TS_DG_NONE;

  PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSDiscGradGetFormulation_C", TSDiscGradGetFormulation_DiscGrad));
  PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSDiscGradSetFormulation_C", TSDiscGradSetFormulation_DiscGrad));
  PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSDiscGradGetType_C", TSDiscGradGetType_DiscGrad));
  PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSDiscGradSetType_C", TSDiscGradSetType_DiscGrad));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSDiscGradGetFormulation - Get the construction method for S, F, and grad F from the
  formulation $u_t = S \nabla F$ for `TSDISCGRAD`

  Not Collective

  Input Parameter:
. ts - timestepping context

  Output Parameters:
+ Sfunc - constructor for the S matrix from the formulation
. Ffunc - functional F from the formulation
. Gfunc - constructor for the gradient of F from the formulation
- ctx   - the user context

  Calling sequence of `Sfunc`:
+ ts   - the integrator
. time - the current time
. u    - the solution
. S    - the S-matrix from the formulation
- ctx  - the user context

  Calling sequence of `Ffunc`:
+ ts   - the integrator
. time - the current time
. u    - the solution
. F    - the computed function from the formulation
- ctx  - the user context

  Calling sequence of `Gfunc`:
+ ts   - the integrator
. time - the current time
. u    - the solution
. G    - the gradient of the computed function from the formulation
- ctx  - the user context

  Level: intermediate

  Developer Note:
  The `TSDISCGRAD` implementation in this file, `TSDiscGradGetFormulation_DiscGrad()`, does not
  write anything through `ctx`, so the user context is currently not returned.

.seealso: [](ch_ts), `TS`, `TSDISCGRAD`, `TSDiscGradSetFormulation()`
@*/
PetscErrorCode TSDiscGradGetFormulation(TS ts, PetscErrorCode (**Sfunc)(TS ts, PetscReal time, Vec u, Mat S, PetscCtx ctx), PetscErrorCode (**Ffunc)(TS ts, PetscReal time, Vec u, PetscScalar *F, PetscCtx ctx), PetscErrorCode (**Gfunc)(TS ts, PetscReal time, Vec u, Vec G, PetscCtx ctx), PetscCtx ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscAssertPointer(Sfunc, 2);
  PetscAssertPointer(Ffunc, 3);
  PetscAssertPointer(Gfunc, 4);
  /* Dispatch to the method composed on the TS under "TSDiscGradGetFormulation_C" */
  PetscUseMethod(ts, "TSDiscGradGetFormulation_C", (TS, PetscErrorCode (**Sfunc)(TS, PetscReal, Vec, Mat, void *), PetscErrorCode (**Ffunc)(TS, PetscReal, Vec, PetscScalar *, void *), PetscErrorCode (**Gfunc)(TS, PetscReal, Vec, Vec, void *), void *), (ts, Sfunc, Ffunc, Gfunc, ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSDiscGradSetFormulation - Set the construction method for S, F, and grad F from the
  formulation $u_t = S(u) \nabla F(u)$ for `TSDISCGRAD`

  Not Collective

  Input Parameters:
+ ts    - timestepping context
. Sfunc - constructor for the S matrix from the formulation
. Ffunc - functional F from the formulation
. Gfunc - constructor for the gradient of F from the formulation
- ctx   - optional context for the functions

  Calling sequence of `Sfunc`:
+ ts   - the integrator
. time - the current time
. u    - the solution
. S    - the S-matrix from the formulation
- ctx  - the user context

  Calling sequence of `Ffunc`:
+ ts   - the integrator
. time - the current time
. u    - the solution
. F    - the computed function from the formulation
- ctx  - the user context

  Calling sequence of `Gfunc`:
+ ts   - the integrator
. time - the current time
. u    - the solution
. G    - the gradient of the computed function from the formulation
- ctx  - the user context

  Level: intermediate

  Note:
  All three functions are validated on entry, so `Sfunc`, `Ffunc` and `Gfunc` must all be provided.

.seealso: [](ch_ts), `TS`, `TSDISCGRAD`, `TSDiscGradGetFormulation()`
@*/
PetscErrorCode TSDiscGradSetFormulation(TS ts, PetscErrorCode (*Sfunc)(TS ts, PetscReal time, Vec u, Mat S, PetscCtx ctx), PetscErrorCode (*Ffunc)(TS ts, PetscReal time, Vec u, PetscScalar *F, PetscCtx ctx), PetscErrorCode (*Gfunc)(TS ts, PetscReal time, Vec u, Vec G, PetscCtx ctx), PetscCtx ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidFunction(Sfunc, 2);
  PetscValidFunction(Ffunc, 3);
  PetscValidFunction(Gfunc, 4);
  /* Dispatch to the method composed on the TS under "TSDiscGradSetFormulation_C" */
  PetscTryMethod(ts, "TSDiscGradSetFormulation_C", (TS, PetscErrorCode (*Sfunc)(TS, PetscReal, Vec, Mat, void *), PetscErrorCode (*Ffunc)(TS, PetscReal, Vec, PetscScalar *, void *), PetscErrorCode (*Gfunc)(TS, PetscReal, Vec, Vec, void *), void *), (ts, Sfunc, Ffunc, Gfunc, ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSDiscGradGetType - Gets the discrete gradient type used in the formulation for `TSDISCGRAD`

  Not Collective

  Input Parameter:
. ts - timestepping context

  Output Parameter:
. dgtype - Discrete gradient type <none, gonzalez, average>

  Level: advanced

.seealso: [](ch_ts), `TS`, `TSDISCGRAD`, `TSDiscGradSetType()`
@*/
PetscErrorCode TSDiscGradGetType(TS ts, TSDGType *dgtype)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscAssertPointer(dgtype, 2);
  /* Dispatch to the method composed on the TS under "TSDiscGradGetType_C" */
  PetscUseMethod(ts, "TSDiscGradGetType_C", (TS, TSDGType *), (ts, dgtype));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSDiscGradSetType - Sets the discrete gradient type used in the formulation for `TSDISCGRAD`

  Not Collective

  Input Parameters:
+ ts     - timestepping context
- dgtype - Discrete gradient type <none, gonzalez, average>

  Options Database Key:
. -ts_discgrad_type <type> - flag to choose discrete gradient type

  Level: intermediate

  Notes:
  Without `dgtype` or with type `none`, the discrete gradients timestepper is just implicit midpoint.

.seealso: [](ch_ts), `TS`, `TSDISCGRAD`, `TSDiscGradGetType()`
@*/
PetscErrorCode TSDiscGradSetType(TS ts, TSDGType dgtype)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* Dispatch to the method composed on the TS under "TSDiscGradSetType_C" */
  PetscTryMethod(ts, "TSDiscGradSetType_C", (TS, TSDGType), (ts, dgtype));
  PetscFunctionReturn(PETSC_SUCCESS);
}
