#include <petsc/private/tsimpl.h> /*I <petscts.h>  I*/
#include <petscdraw.h>

PetscLogEvent TS_AdjointStep, TS_ForwardStep, TS_JacobianPEval;

/* #define TSADJOINT_STAGE */

/* ------------------------ Sensitivity Context ---------------------------*/

/*@C
  TSSetRHSJacobianP - Sets the function that computes the Jacobian of G w.r.t. the parameters P where U_t = G(U,P,t), as well as the location to store the matrix.

  Logically Collective

  Input Parameters:
+ ts   - `TS` context obtained from `TSCreate()`
. Amat - JacobianP matrix
. func - function
- ctx  - [optional] user-defined function context

  Level: intermediate

  Note:
  `Amat` has the same number of rows and the same row parallel layout as `u`, `Amat` has the same number of columns and parallel layout as `p`

.seealso: [](ch_ts), `TS`, `TSRHSJacobianP`, `TSGetRHSJacobianP()`
@*/
PetscErrorCode TSSetRHSJacobianP(TS ts, Mat Amat, TSRHSJacobianP func, void *ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(Amat, MAT_CLASSID, 2);

  if (Amat) {
    /* reference the incoming matrix before releasing any previously stored one,
       so passing the already-stored matrix is safe */
    PetscCall(PetscObjectReference((PetscObject)Amat));
    PetscCall(MatDestroy(&ts->Jacprhs));
    ts->Jacprhs = Amat;
  }
  ts->rhsjacobianpctx = ctx;
  ts->rhsjacobianp    = func;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSGetRHSJacobianP - Gets the function that computes the Jacobian of G w.r.t. the parameters P where U_t = G(U,P,t), as well as the location to store the matrix.

  Logically Collective

  Input Parameter:
. ts - `TS` context obtained from `TSCreate()`

  Output Parameters:
+ Amat - JacobianP matrix
. func - function
- ctx  - [optional] user-defined function context

  Level: intermediate

  Note:
  `Amat` has the same number of rows and the same row parallel layout as `u`, `Amat` has the same number of columns and parallel layout as `p`

.seealso: [](ch_ts), `TSSetRHSJacobianP()`, `TS`, `TSRHSJacobianP`
@*/
PetscErrorCode TSGetRHSJacobianP(TS ts, Mat *Amat, TSRHSJacobianP *func, void **ctx)
{
  PetscFunctionBegin;
  /* validate ts before dereferencing it, consistent with the other TS accessors in this file */
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  if (func) *func = ts->rhsjacobianp;
  if (ctx) *ctx = ts->rhsjacobianpctx;
  if (Amat) *Amat = ts->Jacprhs;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeRHSJacobianP - Runs the user-defined JacobianP function.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
- U  - the solution at which to compute the Jacobian

  Output Parameter:
. Amat - the computed Jacobian

  Level: developer

.seealso: [](ch_ts), `TSSetRHSJacobianP()`, `TS`
@*/
PetscErrorCode TSComputeRHSJacobianP(TS ts, PetscReal t, Vec U, Mat Amat)
{
  PetscFunctionBegin;
  /* a NULL Amat means there is nowhere to compute into; silently succeed */
  if (!Amat) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  if (ts->rhsjacobianp) PetscCallBack("TS callback JacobianP for sensitivity analysis", (*ts->rhsjacobianp)(ts, t, U, Amat, ts->rhsjacobianpctx));
  else {
    PetscBool assembled;
    /* no user callback was set: treat dG/dP as zero, making sure the matrix ends up assembled */
    PetscCall(MatZeroEntries(Amat));
    PetscCall(MatAssembled(Amat, &assembled));
    if (!assembled) {
      PetscCall(MatAssemblyBegin(Amat, MAT_FINAL_ASSEMBLY));
      PetscCall(MatAssemblyEnd(Amat, MAT_FINAL_ASSEMBLY));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSSetIJacobianP - Sets the function that computes the Jacobian of F w.r.t. the parameters P where F(Udot,U,t) = G(U,P,t), as well as the location to store the matrix.

  Logically Collective

  Input Parameters:
+ ts   - `TS` context obtained from `TSCreate()`
. Amat - JacobianP matrix
. func - function
- ctx  - [optional] user-defined function context

  Calling sequence of `func`:
+ ts    - the `TS` context
. t     - current timestep
. U     - input vector (current ODE solution)
. Udot  - time derivative of state vector
. shift - shift to apply, forwarded to the user callback
. A     - output matrix
- ctx   - [optional] user-defined function context

  Level: intermediate

  Note:
  Amat has the same number of rows and the same row parallel layout as u, Amat has the same number of columns and parallel layout as p

.seealso: [](ch_ts), `TSSetRHSJacobianP()`, `TS`
@*/
PetscErrorCode TSSetIJacobianP(TS ts, Mat Amat, PetscErrorCode (*func)(TS ts, PetscReal t, Vec U, Vec Udot, PetscReal shift, Mat A, void *ctx), void *ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(Amat, MAT_CLASSID, 2);

  if (Amat) {
    /* bump the reference first so replacing ts->Jacp with itself is safe */
    PetscCall(PetscObjectReference((PetscObject)Amat));
    PetscCall(MatDestroy(&ts->Jacp));
    ts->Jacp = Amat;
  }
  ts->ijacobianpctx = ctx;
  ts->ijacobianp    = func;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeIJacobianP - Runs the user-defined IJacobianP function.

  Collective

  Input Parameters:
+ ts    - the `TS` context
. t     - current timestep
. U     - state vector
. Udot  - time derivative of state vector
. shift - shift to apply, forwarded to the user IJacobianP callback
- imex  - flag indicates if the method is IMEX so that the RHSJacobianP should be kept separate

  Output Parameter:
. Amat - Jacobian matrix

  Level: developer

.seealso: [](ch_ts), `TS`, `TSSetIJacobianP()`
@*/
PetscErrorCode TSComputeIJacobianP(TS ts, PetscReal t, Vec U, Vec Udot, PetscReal shift, Mat Amat, PetscBool imex)
{
  PetscFunctionBegin;
  /* nothing to compute into; silently succeed */
  if (!Amat) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);
  PetscValidHeaderSpecific(Udot, VEC_CLASSID, 4);

  PetscCall(PetscLogEventBegin(TS_JacobianPEval, ts, U, Amat, 0));
  if (ts->ijacobianp) PetscCallBack("TS callback JacobianP for sensitivity analysis", (*ts->ijacobianp)(ts, t, U, Udot, shift, Amat, ts->ijacobianpctx));
  if (imex) {
    if (!ts->ijacobianp) { /* system was written as Udot = G(t,U) */
      PetscBool assembled;
      /* IMEX keeps the RHS part separate, so the implicit JacobianP is just zero here */
      PetscCall(MatZeroEntries(Amat));
      PetscCall(MatAssembled(Amat, &assembled));
      if (!assembled) {
        PetscCall(MatAssemblyBegin(Amat, MAT_FINAL_ASSEMBLY));
        PetscCall(MatAssemblyEnd(Amat, MAT_FINAL_ASSEMBLY));
      }
    }
  } else {
    /* fold the RHS JacobianP into Amat with a sign flip (the RHS enters F with a minus sign) */
    if (ts->rhsjacobianp) PetscCall(TSComputeRHSJacobianP(ts, t, U, ts->Jacprhs));
    if (ts->Jacprhs == Amat) { /* No IJacobian, so we only have the RHS matrix */
      PetscCall(MatScale(Amat, -1));
    } else if (ts->Jacprhs) { /* Both IJacobian and RHSJacobian */
      MatStructure axpy = DIFFERENT_NONZERO_PATTERN;
      if (!ts->ijacobianp) { /* No IJacobianp provided, but we have a separate RHS matrix */
        PetscCall(MatZeroEntries(Amat));
      }
      PetscCall(MatAXPY(Amat, -1, ts->Jacprhs, axpy));
    }
  }
  PetscCall(PetscLogEventEnd(TS_JacobianPEval, ts, U, Amat, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSSetCostIntegrand - Sets the routine for evaluating the integral term in one or more cost functions

  Logically Collective

  Input Parameters:
+ ts           - the `TS` context obtained from `TSCreate()`
. numcost      - number of gradients to be computed, this is the number of cost functions
. costintegral - vector that stores the integral values
. rf           - routine for evaluating the integrand function
. drduf        - function that computes the gradients of the r's with respect to u
. drdpf        - function that computes the gradients of the r's with respect to p, can be `NULL` if parametric sensitivity is not desired (`mu` = `NULL`)
. fwd          - flag indicating whether to evaluate cost integral in the forward run or the adjoint run
- ctx          - [optional] user-defined context for private data for the function evaluation routine (may be `NULL`)

  Calling sequence of `rf`:
$   PetscErrorCode rf(TS ts, PetscReal t, Vec U, Vec F, void *ctx)

  Calling sequence of `drduf`:
$   PetscErrorCode drduf(TS ts, PetscReal t, Vec U, Vec *dRdU, void *ctx)

  Calling sequence of `drdpf`:
$   PetscErrorCode drdpf(TS ts, PetscReal t, Vec U, Vec *dRdP, void *ctx)

  Level: deprecated

  Note:
  For optimization there is usually a single cost function (numcost = 1). For sensitivities there may be multiple cost functions

.seealso: [](ch_ts), `TS`, `TSSetRHSJacobianP()`, `TSGetCostGradients()`, `TSSetCostGradients()`
@*/
PetscErrorCode TSSetCostIntegrand(TS ts, PetscInt numcost, Vec costintegral, PetscErrorCode (*rf)(TS, PetscReal, Vec, Vec, void *), PetscErrorCode (*drduf)(TS, PetscReal, Vec, Vec *, void *), PetscErrorCode (*drdpf)(TS, PetscReal, Vec, Vec *, void *), PetscBool fwd, void *ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  if (costintegral) PetscValidHeaderSpecific(costintegral, VEC_CLASSID, 3);
  /* numcost may already have been fixed by TSSetCostGradients()/TSForwardSetIntegralGradients(); it must agree */
  PetscCheck(!ts->numcost || ts->numcost == numcost, PetscObjectComm((PetscObject)ts), PETSC_ERR_USER, "The number of cost functions (2nd parameter of TSSetCostIntegrand()) is inconsistent with the one set by TSSetCostGradients() or TSForwardSetIntegralGradients()");
  if (!ts->numcost) ts->numcost = numcost;

  if (costintegral) {
    PetscCall(PetscObjectReference((PetscObject)costintegral));
    PetscCall(VecDestroy(&ts->vec_costintegral));
    ts->vec_costintegral = costintegral;
  } else {
    if (!ts->vec_costintegral) { /* Create a seq vec if user does not provide one */
      PetscCall(VecCreateSeq(PETSC_COMM_SELF, numcost, &ts->vec_costintegral));
    } else {
      PetscCall(VecSet(ts->vec_costintegral, 0.0));
    }
  }
  if (!ts->vec_costintegrand) {
    PetscCall(VecDuplicate(ts->vec_costintegral, &ts->vec_costintegrand));
  } else {
    PetscCall(VecSet(ts->vec_costintegrand, 0.0));
  }
  ts->costintegralfwd  = fwd; /* Evaluate the cost integral in forward run if fwd is true */
  ts->costintegrand    = rf;
  ts->costintegrandctx = ctx;
  ts->drdufunction     = drduf;
  ts->drdpfunction     = drdpf;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSGetCostIntegral - Returns the values of the integral term in the cost functions.
  It is valid to call the routine after a backward run.

  Not Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Output Parameter:
. v - the vector containing the integrals for each cost function

  Level: intermediate

.seealso: [](ch_ts), `TS`, `TSAdjointSolve()`, `TSSetCostIntegrand()`
@*/
PetscErrorCode TSGetCostIntegral(TS ts, Vec *v)
{
  TS quadts;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscAssertPointer(v, 2);
  /* NOTE(review): quadts is dereferenced without a NULL check; presumably
     TSGetQuadratureTS() always returns a valid quadrature TS here — confirm */
  PetscCall(TSGetQuadratureTS(ts, NULL, &quadts));
  *v = quadts->vec_sol;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeCostIntegrand - Evaluates the integral function in the cost functions.

  Input Parameters:
+ ts - the `TS` context
. t  - current time
- U  - state vector, i.e. current solution

  Output Parameter:
. Q - vector of size numcost to hold the outputs

  Level: deprecated

  Note:
  Most users should not need to explicitly call this routine, as it
  is used internally within the sensitivity analysis context.

.seealso: [](ch_ts), `TS`, `TSAdjointSolve()`, `TSSetCostIntegrand()`
@*/
PetscErrorCode TSComputeCostIntegrand(TS ts, PetscReal t, Vec U, Vec Q)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);
  PetscValidHeaderSpecific(Q, VEC_CLASSID, 4);

  PetscCall(PetscLogEventBegin(TS_FunctionEval, ts, U, Q, 0));
  /* without a user integrand the integral term is identically zero */
  if (!ts->costintegrand) PetscCall(VecZeroEntries(Q));
  else PetscCallBack("TS callback integrand in the cost function", (*ts->costintegrand)(ts, t, U, Q, ts->costintegrandctx));
  PetscCall(PetscLogEventEnd(TS_FunctionEval, ts, U, Q, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-*
/*@C
  TSComputeDRDUFunction - Deprecated, use `TSGetQuadratureTS()` then `TSComputeRHSJacobian()`

  Level: deprecated

@*/
PetscErrorCode TSComputeDRDUFunction(TS ts, PetscReal t, Vec U, Vec *DRDU)
{
  PetscFunctionBegin;
  /* nothing requested, nothing to do */
  if (!DRDU) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* NOTE(review): ts->drdufunction is invoked without a NULL check; calling this before
     TSSetCostIntegrand() supplied drduf would dereference NULL — confirm callers guarantee it is set */
  PetscCallBack("TS callback DRDU for sensitivity analysis", (*ts->drdufunction)(ts, t, U, DRDU, ts->costintegrandctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-*
/*@C
  TSComputeDRDPFunction - Deprecated, use `TSGetQuadratureTS()` then `TSComputeRHSJacobianP()`

  Level: deprecated

@*/
PetscErrorCode TSComputeDRDPFunction(TS ts, PetscReal t, Vec U, Vec *DRDP)
{
  PetscFunctionBegin;
  /* nothing requested, nothing to do */
  if (!DRDP) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* NOTE(review): ts->drdpfunction is invoked without a NULL check; calling this before
     TSSetCostIntegrand() supplied drdpf would dereference NULL — confirm callers guarantee it is set */
  PetscCallBack("TS callback DRDP for sensitivity analysis", (*ts->drdpfunction)(ts, t, U, DRDP, ts->costintegrandctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-param-list-func-parameter-documentation
// PetscClangLinter pragma disable: -fdoc-section-header-unknown
/*@C
  TSSetIHessianProduct - Sets the function that computes the vector-Hessian-vector product. The Hessian is the second-order derivative of F (IFunction) w.r.t. the state variable.

  Logically Collective

  Input Parameters:
+ ts   - `TS` context obtained from `TSCreate()`
. ihp1 - an array of vectors storing the result of vector-Hessian-vector product for F_UU
. hessianproductfunc1 - vector-Hessian-vector product function for F_UU
. ihp2 - an array of vectors storing the result of vector-Hessian-vector product for F_UP
. hessianproductfunc2 - vector-Hessian-vector product function for F_UP
. ihp3 - an array of vectors storing the result of vector-Hessian-vector product for F_PU
. hessianproductfunc3 - vector-Hessian-vector product function for F_PU
. ihp4 - an array of vectors storing the result of vector-Hessian-vector product for F_PP
- hessianproductfunc4 - vector-Hessian-vector product function for F_PP

  Calling sequence of `ihessianproductfunc1`:
+ ts  - the `TS` context
. t   - current timestep
. U   - input vector (current ODE solution)
. Vl  - an array of input vectors to be left-multiplied with the Hessian
. Vr  - input vector to be right-multiplied with the Hessian
. VHV - an array of output vectors for vector-Hessian-vector product
- ctx - [optional] user-defined function context

  Level: intermediate

  Notes:
  All other functions have the same calling sequence as `ihessianproductfunc1`, so their
  descriptions are omitted for brevity.

  The first Hessian function and the working array are required.
  As an example to implement the callback functions, the second callback function calculates the vector-Hessian-vector product
  $ Vl_n^T*F_UP*Vr
  where the vector Vl_n (n-th element in the array Vl) and Vr are of size N and M respectively, and the Hessian F_UP is of size N x N x M.
  Each entry of F_UP corresponds to the derivative
  $ F_UP[i][j][k] = \frac{\partial^2 F[i]}{\partial U[j] \partial P[k]}.
  The result of the vector-Hessian-vector product for Vl_n needs to be stored in vector VHV_n with the j-th entry being
  $ VHV_n[j] = \sum_i \sum_k {Vl_n[i] * F_UP[i][j][k] * Vr[k]}
  If the cost function is a scalar, there will be only one vector in Vl and VHV.

.seealso: [](ch_ts), `TS`
@*/
PetscErrorCode TSSetIHessianProduct(TS ts, Vec *ihp1, PetscErrorCode (*ihessianproductfunc1)(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV, void *ctx), Vec *ihp2, PetscErrorCode (*ihessianproductfunc2)(TS, PetscReal, Vec, Vec *, Vec, Vec *, void *), Vec *ihp3, PetscErrorCode (*ihessianproductfunc3)(TS, PetscReal, Vec, Vec *, Vec, Vec *, void *), Vec *ihp4, PetscErrorCode (*ihessianproductfunc4)(TS, PetscReal, Vec, Vec *, Vec, Vec *, void *), void *ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* ihp1 (the F_UU work array) is mandatory; the other work arrays are optional */
  PetscAssertPointer(ihp1, 2);

  ts->ihessianproductctx = ctx;
  if (ihp1) ts->vecs_fuu = ihp1;
  if (ihp2) ts->vecs_fup = ihp2;
  if (ihp3) ts->vecs_fpu = ihp3;
  if (ihp4) ts->vecs_fpp = ihp4;
  ts->ihessianproduct_fuu = ihessianproductfunc1;
  ts->ihessianproduct_fup = ihessianproductfunc2;
  ts->ihessianproduct_fpu = ihessianproductfunc3;
  ts->ihessianproduct_fpp = ihessianproductfunc4;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeIHessianProductFunctionUU - Runs the user-defined vector-Hessian-vector product function for Fuu.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeIHessianProductFunctionUU()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TSSetIHessianProduct()`
@*/
PetscErrorCode TSComputeIHessianProductFunctionUU(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  if (ts->ihessianproduct_fuu) PetscCallBack("TS callback IHessianProduct 1 for sensitivity analysis", (*ts->ihessianproduct_fuu)(ts, t, U, Vl, Vr, VHV, ts->ihessianproductctx));

  /* IMEX is not handled here: if an RHS Hessian product is set it is computed into the same
     VHV and negated to match the IFunction sign convention */
  if (ts->rhshessianproduct_guu) {
    PetscCall(TSComputeRHSHessianProductFunctionUU(ts, t, U, Vl, Vr, VHV));
    for (PetscInt icost = 0; icost < ts->numcost; icost++) PetscCall(VecScale(VHV[icost], -1));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeIHessianProductFunctionUP - Runs the user-defined vector-Hessian-vector product function for Fup.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeIHessianProductFunctionUP()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TSSetIHessianProduct()`
@*/
PetscErrorCode TSComputeIHessianProductFunctionUP(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  if (ts->ihessianproduct_fup) PetscCallBack("TS callback IHessianProduct 2 for sensitivity analysis", (*ts->ihessianproduct_fup)(ts, t, U, Vl, Vr, VHV, ts->ihessianproductctx));

  /* IMEX is not handled here: if an RHS Hessian product is set it is computed into the same
     VHV and negated to match the IFunction sign convention */
  if (ts->rhshessianproduct_gup) {
    PetscCall(TSComputeRHSHessianProductFunctionUP(ts, t, U, Vl, Vr, VHV));
    for (PetscInt icost = 0; icost < ts->numcost; icost++) PetscCall(VecScale(VHV[icost], -1));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeIHessianProductFunctionPU - Runs the user-defined vector-Hessian-vector product function for Fpu.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeIHessianProductFunctionPU()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TSSetIHessianProduct()`
@*/
PetscErrorCode TSComputeIHessianProductFunctionPU(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  if (ts->ihessianproduct_fpu) PetscCallBack("TS callback IHessianProduct 3 for sensitivity analysis", (*ts->ihessianproduct_fpu)(ts, t, U, Vl, Vr, VHV, ts->ihessianproductctx));

  /* IMEX is not handled here: if an RHS Hessian product is set it is computed into the same
     VHV and negated to match the IFunction sign convention */
  if (ts->rhshessianproduct_gpu) {
    PetscCall(TSComputeRHSHessianProductFunctionPU(ts, t, U, Vl, Vr, VHV));
    for (PetscInt icost = 0; icost < ts->numcost; icost++) PetscCall(VecScale(VHV[icost], -1));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeIHessianProductFunctionPP - Runs the user-defined vector-Hessian-vector product function for Fpp.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeIHessianProductFunctionPP()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TSSetIHessianProduct()`
@*/
PetscErrorCode TSComputeIHessianProductFunctionPP(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* this is the 4th (F_PP) callback set by TSSetIHessianProduct(); the error-context string
     previously said "3", a copy-paste from the PU variant */
  if (ts->ihessianproduct_fpp) PetscCallBack("TS callback IHessianProduct 4 for sensitivity analysis", (*ts->ihessianproduct_fpp)(ts, t, U, Vl, Vr, VHV, ts->ihessianproductctx));

  /* does not consider IMEX for now, so either IHessian or RHSHessian will be calculated, using the same output VHV */
  if (ts->rhshessianproduct_gpp) {
    PetscInt nadj;
    PetscCall(TSComputeRHSHessianProductFunctionPP(ts, t, U, Vl, Vr, VHV));
    /* negate to match the IFunction sign convention */
    for (nadj = 0; nadj < ts->numcost; nadj++) PetscCall(VecScale(VHV[nadj], -1));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-param-list-func-parameter-documentation
// PetscClangLinter pragma disable: -fdoc-section-header-unknown
/*@C
  TSSetRHSHessianProduct - Sets the function that computes the vector-Hessian-vector
  product. The Hessian is the second-order derivative of G (RHSFunction) w.r.t. the state
  variable.

  Logically Collective

  Input Parameters:
+ ts     - `TS` context obtained from `TSCreate()`
. rhshp1 - an array of vectors storing the result of vector-Hessian-vector product for G_UU
. hessianproductfunc1 - vector-Hessian-vector product function for G_UU
. rhshp2 - an array of vectors storing the result of vector-Hessian-vector product for G_UP
. hessianproductfunc2 - vector-Hessian-vector product function for G_UP
. rhshp3 - an array of vectors storing the result of vector-Hessian-vector product for G_PU
. hessianproductfunc3 - vector-Hessian-vector product function for G_PU
. rhshp4 - an array of vectors storing the result of vector-Hessian-vector product for G_PP
. hessianproductfunc4 - vector-Hessian-vector product function for G_PP
- ctx    - [optional] user-defined function context

  Calling sequence of `rhshessianproductfunc1`:
+ ts  - the `TS` context
. t   - current timestep
. U   - input vector (current ODE solution)
. Vl  - an array of input vectors to be left-multiplied with the Hessian
. Vr  - input vector to be right-multiplied with the Hessian
. VHV - an array of output vectors for vector-Hessian-vector product
- ctx - [optional] user-defined function context

  Level: intermediate

  Notes:
  All other functions have the same calling sequence as `rhshessianproductfunc1`, so their
  descriptions are omitted for brevity.

  The first Hessian function and the working array are required.

  As an example to implement the callback functions, the second callback function calculates the vector-Hessian-vector product
  $ Vl_n^T*G_UP*Vr
  where the vector Vl_n (n-th element in the array Vl) and Vr are of size N and M respectively, and the Hessian G_UP is of size N x N x M.
  Each entry of G_UP corresponds to the derivative
  $ G_UP[i][j][k] = \frac{\partial^2 G[i]}{\partial U[j] \partial P[k]}.
  The result of the vector-Hessian-vector product for Vl_n needs to be stored in vector VHV_n with j-th entry being
  $ VHV_n[j] = \sum_i \sum_k {Vl_n[i] * G_UP[i][j][k] * Vr[k]}
  If the cost function is a scalar, there will be only one vector in Vl and VHV.

.seealso: `TS`, `TSAdjoint`
@*/
PetscErrorCode TSSetRHSHessianProduct(TS ts, Vec *rhshp1, PetscErrorCode (*rhshessianproductfunc1)(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV, void *ctx), Vec *rhshp2, PetscErrorCode (*rhshessianproductfunc2)(TS, PetscReal, Vec, Vec *, Vec, Vec *, void *), Vec *rhshp3, PetscErrorCode (*rhshessianproductfunc3)(TS, PetscReal, Vec, Vec *, Vec, Vec *, void *), Vec *rhshp4, PetscErrorCode (*rhshessianproductfunc4)(TS, PetscReal, Vec, Vec *, Vec, Vec *, void *), void *ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscAssertPointer(rhshp1, 2);

  ts->rhshessianproduct_guu = rhshessianproductfunc1;
  ts->rhshessianproduct_gup = rhshessianproductfunc2;
  ts->rhshessianproduct_gpu = rhshessianproductfunc3;
  ts->rhshessianproduct_gpp = rhshessianproductfunc4;
  /* only overwrite the work arrays that were actually supplied */
  if (rhshp1) ts->vecs_guu = rhshp1;
  if (rhshp2) ts->vecs_gup = rhshp2;
  if (rhshp3) ts->vecs_gpu = rhshp3;
  if (rhshp4) ts->vecs_gpp = rhshp4;
  ts->rhshessianproductctx = ctx;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeRHSHessianProductFunctionUU - Runs the user-defined vector-Hessian-vector product function for Guu.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeRHSHessianProductFunctionUU()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TS`, `TSSetRHSHessianProduct()`
@*/
PetscErrorCode TSComputeRHSHessianProductFunctionUU(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* NOTE(review): unlike the IHessian variants above, the callback pointer is invoked without a
     NULL check; presumably callers only reach this after TSSetRHSHessianProduct() — confirm */
  PetscCallBack("TS callback RHSHessianProduct 1 for sensitivity analysis", (*ts->rhshessianproduct_guu)(ts, t, U, Vl, Vr, VHV, ts->rhshessianproductctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeRHSHessianProductFunctionUP - Runs the user-defined vector-Hessian-vector product function for Gup.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeRHSHessianProductFunctionUP()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TS`, `TSSetRHSHessianProduct()`
@*/
PetscErrorCode TSComputeRHSHessianProductFunctionUP(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* NOTE(review): the callback pointer is invoked without a NULL check; presumably callers only
     reach this after TSSetRHSHessianProduct() — confirm */
  PetscCallBack("TS callback RHSHessianProduct 2 for sensitivity analysis", (*ts->rhshessianproduct_gup)(ts, t, U, Vl, Vr, VHV, ts->rhshessianproductctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeRHSHessianProductFunctionPU - Runs the user-defined vector-Hessian-vector product function for Gpu.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeRHSHessianProductFunctionPU()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TSSetRHSHessianProduct()`
@*/
PetscErrorCode TSComputeRHSHessianProductFunctionPU(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* NOTE(review): the callback pointer is invoked without a NULL check; presumably callers only
     reach this after TSSetRHSHessianProduct() — confirm */
  PetscCallBack("TS callback RHSHessianProduct 3 for sensitivity analysis", (*ts->rhshessianproduct_gpu)(ts, t, U, Vl, Vr, VHV, ts->rhshessianproductctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSComputeRHSHessianProductFunctionPP - Runs the user-defined vector-Hessian-vector product function for Gpp.

  Collective

  Input Parameters:
+ ts - The `TS` context obtained from `TSCreate()`
. t  - the time
. U  - the solution at which to compute the Hessian product
. Vl - the array of input vectors to be multiplied with the Hessian from the left
- Vr - the input vector to be multiplied with the Hessian from the right

  Output Parameter:
. VHV - the array of output vectors that store the Hessian product

  Level: developer

  Note:
  `TSComputeRHSHessianProductFunctionPP()` is typically used for sensitivity implementation,
  so most users would not generally call this routine themselves.

.seealso: [](ch_ts), `TSSetRHSHessianProduct()`
@*/
PetscErrorCode TSComputeRHSHessianProductFunctionPP(TS ts, PetscReal t, Vec U, Vec *Vl, Vec Vr, Vec *VHV)
{
  PetscFunctionBegin;
  if (!VHV) PetscFunctionReturn(PETSC_SUCCESS);
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* this is the 4th (G_PP) callback set by TSSetRHSHessianProduct(); the error-context string
     previously said "3", a copy-paste from the PU variant */
  PetscCallBack("TS callback RHSHessianProduct 4 for sensitivity analysis", (*ts->rhshessianproduct_gpp)(ts, t, U, Vl, Vr, VHV, ts->rhshessianproductctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* --------------------------- Adjoint sensitivity ---------------------------*/

/*@
  TSSetCostGradients - Sets the initial value of the gradients of the cost function w.r.t. initial values and w.r.t. the problem parameters
  for use by the `TS` adjoint routines.

  Logically Collective

  Input Parameters:
+ ts      - the `TS` context obtained from `TSCreate()`
. numcost - number of gradients to be computed, this is the number of cost functions
. lambda  - gradients with respect to the initial condition variables, the dimension and parallel layout of these vectors is the same as the ODE solution vector
- mu      - gradients with respect to the parameters, the number of entries in these vectors is the same as the number of parameters

  Level: beginner

  Notes:
  the entries in these vectors must be correctly initialized with the values lambda_i = df/dy|finaltime  mu_i = df/dp|finaltime

  After `TSAdjointSolve()` is called the lambda and the mu contain the computed sensitivities

.seealso: `TS`, `TSAdjointSolve()`, `TSGetCostGradients()`
@*/
PetscErrorCode TSSetCostGradients(TS ts, PetscInt numcost, Vec *lambda, Vec *mu)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscAssertPointer(lambda, 3);
  /* store the caller-owned arrays; no references are taken on these vectors */
  ts->vecs_sensi  = lambda;
  ts->vecs_sensip = mu;
  /* numcost may already have been fixed elsewhere (e.g. by TSSetCostIntegrand()); it must not change */
  PetscCheck(!ts->numcost || ts->numcost == numcost, PetscObjectComm((PetscObject)ts), PETSC_ERR_USER, "The number of cost functions (2nd parameter of TSSetCostIntegrand()) is inconsistent with the one set by TSSetCostIntegrand");
  ts->numcost = numcost;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSGetCostGradients - Returns the gradients from the `TSAdjointSolve()`

  Not Collective, but the vectors returned are parallel if `TS` is parallel

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Output Parameters:
+ numcost - size of returned arrays
. lambda  - vectors containing the gradients of the cost functions with respect to the ODE/DAE solution variables
- mu      - vectors containing the gradients of the cost functions with respect to the problem parameters

  Level: intermediate

.seealso: [](ch_ts), `TS`, `TSAdjointSolve()`, `TSSetCostGradients()`
@*/
PetscErrorCode TSGetCostGradients(TS ts, PetscInt *numcost, Vec **lambda, Vec **mu)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* every output argument is optional; fill only what the caller asked for */
  if (mu) *mu = ts->vecs_sensip;
  if (lambda) *lambda = ts->vecs_sensi;
  if (numcost) *numcost = ts->numcost;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSSetCostHessianProducts - Sets the initial value of the Hessian-vector products of the cost function w.r.t. initial values and w.r.t. the problem parameters
  for use by the `TS` adjoint routines.

  Logically Collective

  Input Parameters:
+ ts      - the `TS` context obtained from `TSCreate()`
. numcost - number of cost functions
. lambda2 - Hessian-vector product with respect to the initial condition variables, the dimension and parallel layout of these vectors is the same as the ODE solution vector
. mu2     - Hessian-vector product with respect to the parameters, the number of entries in these vectors is the same as the number of parameters
- dir     - the direction vector that is multiplied with the Hessian of the cost functions

  Level: beginner

  Notes:
  Hessian of the cost function is completely different from Hessian of the ODE/DAE system

  For second-order adjoint, one needs to call this function and then `TSAdjointSetForward()` before `TSSolve()`.

  After `TSAdjointSolve()` is called, the lambda2 and the mu2 will contain the computed second-order adjoint sensitivities, and can be used to produce Hessian-vector product (not the full Hessian matrix). Users must provide a direction vector; it is usually generated by an optimization solver.

  Passing `NULL` for `lambda2` disables the second-order calculation.

.seealso: [](ch_ts), `TS`, `TSAdjointSolve()`, `TSAdjointSetForward()`
@*/
PetscErrorCode TSSetCostHessianProducts(TS ts, PetscInt numcost, Vec *lambda2, Vec *mu2, Vec dir)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* the number of cost functions must agree with any previously set value */
  PetscCheck(!ts->numcost || ts->numcost == numcost, PetscObjectComm((PetscObject)ts), PETSC_ERR_USER, "The number of cost functions (2nd parameter of TSSetCostIntegrand()) is inconsistent with the one set by TSSetCostIntegrand");
  ts->numcost = numcost;
  /* the caller retains ownership of all of these vectors; no references are taken */
  ts->vec_dir      = dir;
  ts->vecs_sensi2p = mu2;
  ts->vecs_sensi2  = lambda2;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSGetCostHessianProducts - Returns the Hessian-vector products and direction vector set with `TSSetCostHessianProducts()` and updated by `TSAdjointSolve()`

  Not Collective, but vectors returned are parallel if `TS` is parallel

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Output Parameters:
+ numcost - number of cost functions
. lambda2 - Hessian-vector product with respect to the initial condition variables, the dimension and parallel layout of these vectors is the same as the ODE solution vector
. mu2     - Hessian-vector product with respect to the parameters, the number of entries in these vectors is the same as the number of parameters
- dir     - the direction vector that is multiplied with the Hessian of the cost functions

  Level: intermediate

.seealso: [](ch_ts), `TSAdjointSolve()`, `TSSetCostHessianProducts()`
@*/
PetscErrorCode TSGetCostHessianProducts(TS ts, PetscInt *numcost, Vec **lambda2, Vec **mu2, Vec *dir)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* every output argument is optional; copy out only what was requested */
  if (dir) *dir = ts->vec_dir;
  if (mu2) *mu2 = ts->vecs_sensi2p;
  if (lambda2) *lambda2 = ts->vecs_sensi2;
  if (numcost) *numcost = ts->numcost;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointSetForward - Trigger the tangent linear solver and initialize the forward sensitivities

  Logically Collective

  Input Parameters:
+ ts   - the `TS` context obtained from `TSCreate()`
- didp - the derivative of initial values w.r.t. parameters

  Level: intermediate

  Notes:
  When computing sensitivities w.r.t. initial condition, set didp to `NULL` so that the solver will take it as an identity matrix mathematically.
  `TSAdjoint` does not reset the tangent linear solver automatically, `TSAdjointResetForward()` should be called to reset the tangent linear solver.

.seealso: [](ch_ts), `TSAdjointSolve()`, `TSSetCostHessianProducts()`, `TSAdjointResetForward()`
@*/
PetscErrorCode TSAdjointSetForward(TS ts, Mat didp)
{
  Mat          A;
  Vec          sp;
  PetscScalar *xarr;
  PetscInt     lsize;

  PetscFunctionBegin;
  /* validate arguments and preconditions BEFORE mutating any state, so that a
     failed call leaves ts unchanged (previously forward_solve was set first,
     leaving the TS in tangent linear mode even on error) */
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  if (didp) PetscValidHeaderSpecific(didp, MAT_CLASSID, 2);
  PetscCheck(ts->vecs_sensi2, PetscObjectComm((PetscObject)ts), PETSC_ERR_USER, "Must call TSSetCostHessianProducts() first");
  PetscCheck(ts->vec_dir, PetscObjectComm((PetscObject)ts), PETSC_ERR_USER, "Directional vector is missing. Call TSSetCostHessianProducts() to set it.");
  ts->forward_solve = PETSC_TRUE; /* turn on tangent linear mode */
  /* create a single-column dense matrix to hold the initial tangent linear variable */
  PetscCall(VecGetLocalSize(ts->vec_sol, &lsize));
  PetscCall(MatCreateDense(PetscObjectComm((PetscObject)ts), lsize, PETSC_DECIDE, PETSC_DECIDE, 1, NULL, &A));

  /* view the matrix column as a vector so it can be filled with Vec operations */
  PetscCall(VecDuplicate(ts->vec_sol, &sp));
  PetscCall(MatDenseGetColumn(A, 0, &xarr));
  PetscCall(VecPlaceArray(sp, xarr));
  if (ts->vecs_sensi2p) { /* tangent linear variable initialized as 2*dIdP*dir */
    if (didp) {
      PetscCall(MatMult(didp, ts->vec_dir, sp));
      PetscCall(VecScale(sp, 2.));
    } else {
      PetscCall(VecZeroEntries(sp));
    }
  } else { /* tangent linear variable initialized as dir */
    PetscCall(VecCopy(ts->vec_dir, sp));
  }
  PetscCall(VecResetArray(sp));
  PetscCall(MatDenseRestoreColumn(A, &xarr));
  PetscCall(VecDestroy(&sp));

  PetscCall(TSForwardSetInitialSensitivities(ts, A)); /* if didp is NULL, identity matrix is assumed */

  PetscCall(MatDestroy(&A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointResetForward - Reset the tangent linear solver and destroy the tangent linear context

  Logically Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Level: intermediate

.seealso: [](ch_ts), `TSAdjointSetForward()`
@*/
PetscErrorCode TSAdjointResetForward(TS ts)
{
  PetscFunctionBegin;
  /* validate the input as every other TSAdjoint entry point in this file does */
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  ts->forward_solve = PETSC_FALSE; /* turn off tangent linear mode */
  /* release the tangent linear data structures created during the forward run */
  PetscCall(TSForwardReset(ts));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointSetUp - Sets up the internal data structures for the later use
  of an adjoint solver

  Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Level: advanced

.seealso: [](ch_ts), `TSCreate()`, `TSAdjointStep()`, `TSSetCostGradients()`
@*/
PetscErrorCode TSAdjointSetUp(TS ts)
{
  TSTrajectory tj;
  PetscBool    match;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* idempotent: a second call is a no-op */
  if (ts->adjointsetupcalled) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCheck(ts->vecs_sensi, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_WRONGSTATE, "Must call TSSetCostGradients() first");
  /* parameter sensitivities (vecs_sensip) require a Jacobian w.r.t. parameters from either the implicit (Jacp) or RHS (Jacprhs) form */
  PetscCheck(!ts->vecs_sensip || ts->Jacp || ts->Jacprhs, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_WRONGSTATE, "Must call TSSetRHSJacobianP() or TSSetIJacobianP() first");
  PetscCall(TSGetTrajectory(ts, &tj));
  PetscCall(PetscObjectTypeCompare((PetscObject)tj, TSTRAJECTORYBASIC, &match));
  if (match) {
    PetscBool solution_only;
    PetscCall(TSTrajectoryGetSolutionOnly(tj, &solution_only));
    /* the backward sweep needs more than the solution from the trajectory, so solution-only mode is rejected */
    PetscCheck(!solution_only, PetscObjectComm((PetscObject)ts), PETSC_ERR_USER, "TSAdjoint cannot use the solution-only mode when choosing the Basic TSTrajectory type. Turn it off with -ts_trajectory_solution_only 0");
  }
  PetscCall(TSTrajectorySetUseHistory(tj, PETSC_FALSE)); /* not use TSHistory */

  if (ts->quadraturets) { /* if there is integral in the cost function */
    /* work vectors holding one column of the quadrature derivatives during the backward sweep */
    PetscCall(VecDuplicate(ts->vecs_sensi[0], &ts->vec_drdu_col));
    if (ts->vecs_sensip) PetscCall(VecDuplicate(ts->vecs_sensip[0], &ts->vec_drdp_col));
  }

  /* give the TS implementation a chance to set up its own adjoint data */
  PetscTryTypeMethod(ts, adjointsetup);
  ts->adjointsetupcalled = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointReset - Resets a `TS` adjoint context and removes any allocated `Vec`s and `Mat`s.

  Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Level: beginner

.seealso: [](ch_ts), `TSCreate()`, `TSAdjointSetUp()`, `TSDestroy()`
@*/
PetscErrorCode TSAdjointReset(TS ts)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* let the TS implementation tear down its own adjoint data first */
  PetscTryTypeMethod(ts, adjointreset);
  /* release the quadrature (cost integral) work vectors created by TSAdjointSetUp() */
  if (ts->quadraturets) {
    PetscCall(VecDestroy(&ts->vec_drdu_col));
    if (ts->vecs_sensip) PetscCall(VecDestroy(&ts->vec_drdp_col));
  }
  /* forget the caller-owned sensitivity arrays; they were never reference counted */
  ts->vec_dir      = NULL;
  ts->vecs_sensi2p = NULL;
  ts->vecs_sensi2  = NULL;
  ts->vecs_sensip  = NULL;
  ts->vecs_sensi   = NULL;
  /* allow TSAdjointSetUp() to run again */
  ts->adjointsetupcalled = PETSC_FALSE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointSetSteps - Sets the number of steps the adjoint solver should take backward in time

  Logically Collective

  Input Parameters:
+ ts    - the `TS` context obtained from `TSCreate()`
- steps - number of steps to use

  Level: intermediate

  Notes:
  Normally one does not call this and `TSAdjointSolve()` integrates back to the original timestep. One can call this
  so as to integrate back to less than the original timestep

.seealso: [](ch_ts), `TSAdjointSolve()`, `TS`, `TSSetExactFinalTime()`
@*/
PetscErrorCode TSAdjointSetSteps(TS ts, PetscInt steps)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidLogicalCollectiveInt(ts, steps, 2);
  /* the adjoint can only walk back over steps the forward solve actually took */
  PetscCheck(steps >= 0, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_OUTOFRANGE, "Cannot step back a negative number of steps");
  PetscCheck(steps <= ts->steps, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_OUTOFRANGE, "Cannot step back more than the total number of forward steps");
  ts->adjoint_max_steps = steps;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-*
/*@C
  TSAdjointSetRHSJacobian - Deprecated, use `TSSetRHSJacobianP()`

  Level: deprecated
@*/
PetscErrorCode TSAdjointSetRHSJacobian(TS ts, Mat Amat, PetscErrorCode (*func)(TS, PetscReal, Vec, Mat, void *), void *ctx)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(Amat, MAT_CLASSID, 2);

  ts->rhsjacobianp    = func;
  ts->rhsjacobianpctx = ctx;
  if (Amat) {
    /* take a reference to the new matrix and drop any previously stored one.
       NOTE(review): this deprecated path stores into ts->Jacp while TSSetRHSJacobianP()
       stores into ts->Jacprhs -- presumably intentional for backward compatibility, confirm */
    PetscCall(PetscObjectReference((PetscObject)Amat));
    PetscCall(MatDestroy(&ts->Jacp));
    ts->Jacp = Amat;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-*
/*@C
  TSAdjointComputeRHSJacobian - Deprecated, use `TSComputeRHSJacobianP()`

  Level: deprecated
@*/
PetscErrorCode TSAdjointComputeRHSJacobian(TS ts, PetscReal t, Vec U, Mat Amat)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);
  PetscValidHeaderSpecific(Amat, MAT_CLASSID, 4);

  /* invoke the user-provided JacobianP callback stored by TSSetRHSJacobianP()/TSAdjointSetRHSJacobian() */
  PetscCallBack("TS callback JacobianP for sensitivity analysis", (*ts->rhsjacobianp)(ts, t, U, Amat, ts->rhsjacobianpctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-*
/*@
  TSAdjointComputeDRDYFunction - Deprecated, use `TSGetQuadratureTS()` then `TSComputeRHSJacobian()`

  Level: deprecated
@*/
PetscErrorCode TSAdjointComputeDRDYFunction(TS ts, PetscReal t, Vec U, Vec *DRDU)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* invoke the user-provided dR/dU callback with the cost-integrand context */
  PetscCallBack("TS callback DRDY for sensitivity analysis", (*ts->drdufunction)(ts, t, U, DRDU, ts->costintegrandctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-*
/*@
  TSAdjointComputeDRDPFunction - Deprecated, use `TSGetQuadratureTS()` then `TSComputeRHSJacobianP()`

  Level: deprecated
@*/
PetscErrorCode TSAdjointComputeDRDPFunction(TS ts, PetscReal t, Vec U, Vec *DRDP)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(U, VEC_CLASSID, 3);

  /* invoke the user-provided dR/dP callback with the cost-integrand context */
  PetscCallBack("TS callback DRDP for sensitivity analysis", (*ts->drdpfunction)(ts, t, U, DRDP, ts->costintegrandctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// PetscClangLinter pragma disable: -fdoc-param-list-func-parameter-documentation
/*@C
  TSAdjointMonitorSensi - monitors the first lambda sensitivity

  Level: intermediate

.seealso: [](ch_ts), `TSAdjointMonitorSet()`
@*/
static PetscErrorCode TSAdjointMonitorSensi(TS ts, PetscInt step, PetscReal ptime, Vec v, PetscInt numcost, Vec *lambda, Vec *mu, PetscViewerAndFormat *vf)
{
  PetscViewer viewer = vf->viewer;

  PetscFunctionBegin;
  /* only lambda[0] is displayed; void the unused parameters required by the
     monitor signature, consistent with TSAdjointMonitorDefault() */
  (void)ts;
  (void)step;
  (void)ptime;
  (void)v;
  (void)numcost;
  (void)mu;
  PetscValidHeaderSpecific(viewer, PETSC_VIEWER_CLASSID, 8);
  PetscCall(PetscViewerPushFormat(viewer, vf->format));
  PetscCall(VecView(lambda[0], viewer));
  PetscCall(PetscViewerPopFormat(viewer));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSAdjointMonitorSetFromOptions - Sets a monitor function and viewer appropriate for the type indicated by the user

  Collective

  Input Parameters:
+ ts           - `TS` object you wish to monitor
. name         - the monitor type one is seeking
. help         - message indicating what monitoring is done
. manual       - manual page for the monitor
. monitor      - the monitor function
- monitorsetup - a function that is called once ONLY if the user selected this monitor that may set additional features of the `TS` or `PetscViewer` objects

  Level: developer

.seealso: [](ch_ts), `PetscOptionsGetViewer()`, `PetscOptionsGetReal()`, `PetscOptionsHasName()`, `PetscOptionsGetString()`,
          `PetscOptionsGetIntArray()`, `PetscOptionsGetRealArray()`, `PetscOptionsBool()`
          `PetscOptionsInt()`, `PetscOptionsString()`, `PetscOptionsReal()`,
          `PetscOptionsName()`, `PetscOptionsBegin()`, `PetscOptionsEnd()`, `PetscOptionsHeadBegin()`,
          `PetscOptionsStringArray()`, `PetscOptionsRealArray()`, `PetscOptionsScalar()`,
          `PetscOptionsBoolGroupBegin()`, `PetscOptionsBoolGroup()`, `PetscOptionsBoolGroupEnd()`,
          `PetscOptionsFList()`, `PetscOptionsEList()`
@*/
PetscErrorCode TSAdjointMonitorSetFromOptions(TS ts, const char name[], const char help[], const char manual[], PetscErrorCode (*monitor)(TS, PetscInt, PetscReal, Vec, PetscInt, Vec *, Vec *, PetscViewerAndFormat *), PetscErrorCode (*monitorsetup)(TS, PetscViewerAndFormat *))
{
  PetscViewer       viewer;
  PetscViewerFormat format;
  PetscBool         flg;

  PetscFunctionBegin;
  /* query the options database for a viewer associated with the given option name */
  PetscCall(PetscOptionsGetViewer(PetscObjectComm((PetscObject)ts), ((PetscObject)ts)->options, ((PetscObject)ts)->prefix, name, &viewer, &format, &flg));
  if (flg) {
    PetscViewerAndFormat *vf;
    PetscCall(PetscViewerAndFormatCreate(viewer, format, &vf));
    /* vf now holds its own reference to the viewer; drop the one returned by the options query */
    PetscCall(PetscObjectDereference((PetscObject)viewer));
    if (monitorsetup) PetscCall((*monitorsetup)(ts, vf));
    /* register the monitor; the viewer-and-format context is destroyed together with the monitor */
    PetscCall(TSAdjointMonitorSet(ts, (PetscErrorCode(*)(TS, PetscInt, PetscReal, Vec, PetscInt, Vec *, Vec *, void *))monitor, vf, (PetscErrorCode(*)(void **))PetscViewerAndFormatDestroy));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSAdjointMonitorSet - Sets an ADDITIONAL function that is to be used at every
  timestep to display the iteration's  progress.

  Logically Collective

  Input Parameters:
+ ts              - the `TS` context obtained from `TSCreate()`
. adjointmonitor  - monitoring routine
. adjointmctx     - [optional] user-defined context for private data for the monitor routine
                    (use `NULL` if no context is desired)
- adjointmdestroy - [optional] routine that frees monitor context (may be `NULL`)

  Calling sequence of `adjointmonitor`:
+ ts          - the `TS` context
. steps       - iteration number (after the final time step the monitor routine is called with
                a step of -1, this is at the final time which may have been interpolated to)
. time        - current time
. u           - current iterate
. numcost     - number of cost functions
. lambda      - sensitivities to initial conditions
. mu          - sensitivities to parameters
- adjointmctx - [optional] adjoint monitoring context

  Calling sequence of `adjointmdestroy`:
. mctx - the monitor context to destroy

  Level: intermediate

  Note:
  This routine adds an additional monitor to the list of monitors that
  already has been loaded.

  Fortran Notes:
  Only a single monitor function can be set for each `TS` object

.seealso: [](ch_ts), `TS`, `TSAdjointSolve()`, `TSAdjointMonitorCancel()`
@*/
PetscErrorCode TSAdjointMonitorSet(TS ts, PetscErrorCode (*adjointmonitor)(TS ts, PetscInt steps, PetscReal time, Vec u, PetscInt numcost, Vec *lambda, Vec *mu, void *adjointmctx), void *adjointmctx, PetscErrorCode (*adjointmdestroy)(void **mctx))
{
  PetscInt  i;
  PetscBool identical;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* skip insertion when an identical monitor (same function, context and destroy routine)
     is already registered.
     Bug fix: the loop must scan the adjoint monitor list (numberadjointmonitors);
     it previously used the forward monitor count (numbermonitors) while indexing
     the adjoint arrays, which could read stale slots or miss duplicates. */
  for (i = 0; i < ts->numberadjointmonitors; i++) {
    PetscCall(PetscMonitorCompare((PetscErrorCode(*)(void))adjointmonitor, adjointmctx, adjointmdestroy, (PetscErrorCode(*)(void))ts->adjointmonitor[i], ts->adjointmonitorcontext[i], ts->adjointmonitordestroy[i], &identical));
    if (identical) PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCheck(ts->numberadjointmonitors < MAXTSMONITORS, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Too many adjoint monitors set");
  /* append the monitor, its destroy routine and its context to the parallel arrays */
  ts->adjointmonitor[ts->numberadjointmonitors]          = adjointmonitor;
  ts->adjointmonitordestroy[ts->numberadjointmonitors]   = adjointmdestroy;
  ts->adjointmonitorcontext[ts->numberadjointmonitors++] = (void *)adjointmctx;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSAdjointMonitorCancel - Clears all the adjoint monitors that have been set on a time-step object.

  Logically Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Notes:
  There is no way to remove a single, specific monitor.

  Level: intermediate

.seealso: [](ch_ts), `TS`, `TSAdjointSolve()`, `TSAdjointMonitorSet()`
@*/
PetscErrorCode TSAdjointMonitorCancel(TS ts)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* let every monitor dispose of its own context before the list is forgotten */
  for (PetscInt i = 0; i < ts->numberadjointmonitors; i++) {
    if (ts->adjointmonitordestroy[i]) PetscCall((*ts->adjointmonitordestroy[i])(&ts->adjointmonitorcontext[i]));
  }
  ts->numberadjointmonitors = 0;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSAdjointMonitorDefault - the default monitor of adjoint computations

  Input Parameters:
+ ts      - the `TS` context
. step    - iteration number (after the final time step the monitor routine is called with a
step of -1, this is at the final time which may have been interpolated to)
. time    - current time
. v       - current iterate
. numcost - number of cost functions
. lambda  - sensitivities to initial conditions
. mu      - sensitivities to parameters
- vf      - the viewer and format

  Level: intermediate

.seealso: [](ch_ts), `TS`, `TSAdjointSolve()`, `TSAdjointMonitorSet()`
@*/
PetscErrorCode TSAdjointMonitorDefault(TS ts, PetscInt step, PetscReal time, Vec v, PetscInt numcost, Vec *lambda, Vec *mu, PetscViewerAndFormat *vf)
{
  PetscViewer viewer = vf->viewer;

  PetscFunctionBegin;
  /* only step/time information is printed; void the parameters required by the monitor signature */
  (void)v;
  (void)numcost;
  (void)lambda;
  (void)mu;
  PetscValidHeaderSpecific(viewer, PETSC_VIEWER_CLASSID, 8);
  PetscCall(PetscViewerPushFormat(viewer, vf->format));
  /* indent the output to the tab level of the TS object; "(r)" marks a rolled-back step */
  PetscCall(PetscViewerASCIIAddTab(viewer, ((PetscObject)ts)->tablevel));
  PetscCall(PetscViewerASCIIPrintf(viewer, "%" PetscInt_FMT " TS dt %g time %g%s", step, (double)ts->time_step, (double)time, ts->steprollback ? " (r)\n" : "\n"));
  PetscCall(PetscViewerASCIISubtractTab(viewer, ((PetscObject)ts)->tablevel));
  PetscCall(PetscViewerPopFormat(viewer));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSAdjointMonitorDrawSensi - Monitors progress of the adjoint `TS` solvers by calling
  `VecView()` for the sensitivities to initial states at each timestep

  Collective

  Input Parameters:
+ ts      - the `TS` context
. step    - current time-step
. ptime   - current time
. u       - current state
. numcost - number of cost functions
. lambda  - sensitivities to initial conditions
. mu      - sensitivities to parameters
- dummy   - either a viewer or `NULL`

  Level: intermediate

.seealso: [](ch_ts), `TSAdjointSolve()`, `TSAdjointMonitorSet()`, `TSAdjointMonitorDefault()`, `VecView()`
@*/
PetscErrorCode TSAdjointMonitorDrawSensi(TS ts, PetscInt step, PetscReal ptime, Vec u, PetscInt numcost, Vec *lambda, Vec *mu, void *dummy)
{
  TSMonitorDrawCtx ictx = (TSMonitorDrawCtx)dummy;
  PetscDraw        draw;
  PetscReal        xl, yl, xr, yr, h;
  char             time[32];

  PetscFunctionBegin;
  /* only lambda[0] is drawn; void the unused parameters required by the
     monitor signature, consistent with TSAdjointMonitorDefault() */
  (void)u;
  (void)numcost;
  (void)mu;
  /* draw every howoften steps, or only once the run has a converged reason when howoften == -1 */
  if (!(((ictx->howoften > 0) && (!(step % ictx->howoften))) || ((ictx->howoften == -1) && ts->reason))) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCall(VecView(lambda[0], ictx->viewer));
  PetscCall(PetscViewerDrawGetDraw(ictx->viewer, 0, &draw));
  PetscCall(PetscSNPrintf(time, 32, "Timestep %d Time %g", (int)step, (double)ptime));
  PetscCall(PetscDrawGetCoordinates(draw, &xl, &yl, &xr, &yr));
  /* place the time stamp near the top (95%) of the drawing area */
  h = yl + .95 * (yr - yl);
  PetscCall(PetscDrawStringCentered(draw, .5 * (xl + xr), h, PETSC_DRAW_BLACK, time));
  PetscCall(PetscDrawFlush(draw));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSAdjointSetFromOptions - Sets various `TS` adjoint parameters from options database.

  Collective

  Input Parameters:
+ ts                 - the `TS` context
- PetscOptionsObject - the options context

  Options Database Keys:
+ -ts_adjoint_solve <yes,no>     - After solving the ODE/DAE solve the adjoint problem (requires `-ts_save_trajectory`)
. -ts_adjoint_monitor            - print information at each adjoint time step
- -ts_adjoint_monitor_draw_sensi - monitor the sensitivity of the first cost function wrt initial conditions (lambda[0]) graphically

  Level: developer

  Note:
  This is not normally called directly by users

.seealso: [](ch_ts), `TSSetSaveTrajectory()`, `TSTrajectorySetUp()`
@*/
PetscErrorCode TSAdjointSetFromOptions(TS ts, PetscOptionItems *PetscOptionsObject)
{
  PetscBool tflg, opt;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscOptionsHeadBegin(PetscOptionsObject, "TS Adjoint options");
  /* default the option to the currently stored adjoint_solve flag */
  tflg = ts->adjoint_solve ? PETSC_TRUE : PETSC_FALSE;
  PetscCall(PetscOptionsBool("-ts_adjoint_solve", "Solve the adjoint problem immediately after solving the forward problem", "", tflg, &tflg, &opt));
  if (opt) {
    /* the adjoint needs the forward trajectory, so saving it is switched on here */
    PetscCall(TSSetSaveTrajectory(ts));
    ts->adjoint_solve = tflg;
  }
  PetscCall(TSAdjointMonitorSetFromOptions(ts, "-ts_adjoint_monitor", "Monitor adjoint timestep size", "TSAdjointMonitorDefault", TSAdjointMonitorDefault, NULL));
  PetscCall(TSAdjointMonitorSetFromOptions(ts, "-ts_adjoint_monitor_sensi", "Monitor sensitivity in the adjoint computation", "TSAdjointMonitorSensi", TSAdjointMonitorSensi, NULL));
  opt = PETSC_FALSE;
  PetscCall(PetscOptionsName("-ts_adjoint_monitor_draw_sensi", "Monitor adjoint sensitivities (lambda only) graphically", "TSAdjointMonitorDrawSensi", &opt));
  if (opt) {
    TSMonitorDrawCtx ctx;
    PetscInt         howoften = 1;

    /* the option value, if given, selects how often (in steps) the drawing monitor fires */
    PetscCall(PetscOptionsInt("-ts_adjoint_monitor_draw_sensi", "Monitor adjoint sensitivities (lambda only) graphically", "TSAdjointMonitorDrawSensi", howoften, &howoften, NULL));
    PetscCall(TSMonitorDrawCtxCreate(PetscObjectComm((PetscObject)ts), NULL, NULL, PETSC_DECIDE, PETSC_DECIDE, 300, 300, howoften, &ctx));
    PetscCall(TSAdjointMonitorSet(ts, TSAdjointMonitorDrawSensi, ctx, (PetscErrorCode(*)(void **))TSMonitorDrawCtxDestroy));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointStep - Steps one time step backward in the adjoint run

  Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Level: intermediate

.seealso: [](ch_ts), `TSAdjointSetUp()`, `TSAdjointSolve()`
@*/
PetscErrorCode TSAdjointStep(TS ts)
{
  DM dm;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* NOTE(review): dm is not used below; TSGetDM() is presumably called for its side
     effect of ensuring the TS has a DM attached -- confirm */
  PetscCall(TSGetDM(ts, &dm));
  PetscCall(TSAdjointSetUp(ts));
  ts->steps--; /* must decrease the step index before the adjoint step is taken. */

  ts->reason     = TS_CONVERGED_ITERATING;
  ts->ptime_prev = ts->ptime;
  PetscCall(PetscLogEventBegin(TS_AdjointStep, ts, 0, 0, 0));
  /* dispatch the actual backward step to the TS implementation (required method) */
  PetscUseTypeMethod(ts, adjointstep);
  PetscCall(PetscLogEventEnd(TS_AdjointStep, ts, 0, 0, 0));
  ts->adjoint_steps++;

  if (ts->reason < 0) {
    /* a negative reason signals failure; optionally escalate it to a hard error */
    PetscCheck(!ts->errorifstepfailed, PetscObjectComm((PetscObject)ts), PETSC_ERR_NOT_CONVERGED, "TSAdjointStep has failed due to %s", TSConvergedReasons[ts->reason]);
  } else if (!ts->reason) {
    /* declare convergence once the requested number of backward steps has been taken */
    if (ts->adjoint_steps >= ts->adjoint_max_steps) ts->reason = TS_CONVERGED_ITS;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointSolve - Solves the discrete adjoint problem for an ODE/DAE

  Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Options Database Key:
. -ts_adjoint_view_solution <viewerinfo> - views the first gradient with respect to the initial values

  Level: intermediate

  Notes:
  This must be called after a call to `TSSolve()` that solves the forward problem

  By default this will integrate back to the initial time, one can use `TSAdjointSetSteps()` to step back to a later time

.seealso: [](ch_ts), `TSCreate()`, `TSSetCostGradients()`, `TSSetSolution()`, `TSAdjointStep()`
@*/
PetscErrorCode TSAdjointSolve(TS ts)
{
  static PetscBool cite = PETSC_FALSE;
#if defined(TSADJOINT_STAGE)
  PetscLogStage adjoint_stage;
#endif

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* register the TSAdjoint citation so -citations reports it */
  PetscCall(PetscCitationsRegister("@article{Zhang2022tsadjoint,\n"
                                   "  title         = {{PETSc TSAdjoint: A Discrete Adjoint ODE Solver for First-Order and Second-Order Sensitivity Analysis}},\n"
                                   "  author        = {Zhang, Hong and Constantinescu, Emil M.  and Smith, Barry F.},\n"
                                   "  journal       = {SIAM Journal on Scientific Computing},\n"
                                   "  volume        = {44},\n"
                                   "  number        = {1},\n"
                                   "  pages         = {C1-C24},\n"
                                   "  doi           = {10.1137/21M140078X},\n"
                                   "  year          = {2022}\n}\n",
                                   &cite));
#if defined(TSADJOINT_STAGE)
  PetscCall(PetscLogStageRegister("TSAdjoint", &adjoint_stage));
  PetscCall(PetscLogStagePush(adjoint_stage));
#endif
  PetscCall(TSAdjointSetUp(ts));

  /* reset time step and iteration counters */
  ts->adjoint_steps     = 0;
  ts->ksp_its           = 0;
  ts->snes_its          = 0;
  ts->num_snes_failures = 0;
  ts->reject            = 0;
  ts->reason            = TS_CONVERGED_ITERATING;

  /* by default walk all the way back over every forward step */
  if (!ts->adjoint_max_steps) ts->adjoint_max_steps = ts->steps;
  if (ts->adjoint_steps >= ts->adjoint_max_steps) ts->reason = TS_CONVERGED_ITS;

  /* backward sweep: restore each forward state from the trajectory, monitor, then step the adjoint */
  while (!ts->reason) {
    PetscCall(TSTrajectoryGet(ts->trajectory, ts, ts->steps, &ts->ptime));
    PetscCall(TSAdjointMonitor(ts, ts->steps, ts->ptime, ts->vec_sol, ts->numcost, ts->vecs_sensi, ts->vecs_sensip));
    PetscCall(TSAdjointEventHandler(ts));
    PetscCall(TSAdjointStep(ts));
    /* accumulate the cost integral backward unless it was evaluated during the forward run */
    if ((ts->vec_costintegral || ts->quadraturets) && !ts->costintegralfwd) PetscCall(TSAdjointCostIntegral(ts));
  }
  /* when the sweep reached the initial time, monitor once more at step 0 */
  if (!ts->steps) {
    PetscCall(TSTrajectoryGet(ts->trajectory, ts, ts->steps, &ts->ptime));
    PetscCall(TSAdjointMonitor(ts, ts->steps, ts->ptime, ts->vec_sol, ts->numcost, ts->vecs_sensi, ts->vecs_sensip));
  }
  ts->solvetime = ts->ptime;
  PetscCall(TSTrajectoryViewFromOptions(ts->trajectory, NULL, "-ts_trajectory_view"));
  PetscCall(VecViewFromOptions(ts->vecs_sensi[0], (PetscObject)ts, "-ts_adjoint_view_solution"));
  /* clear the step limit so a subsequent solve starts fresh */
  ts->adjoint_max_steps = 0;
#if defined(TSADJOINT_STAGE)
  PetscCall(PetscLogStagePop());
#endif
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  TSAdjointMonitor - Runs all user-provided adjoint monitor routines set using `TSAdjointMonitorSet()`

  Collective

  Input Parameters:
+ ts      - time stepping context obtained from `TSCreate()`
. step    - step number that has just completed
. ptime   - model time of the state
. u       - state at the current model time
. numcost - number of cost functions (dimension of lambda  or mu)
. lambda  - vectors containing the gradients of the cost functions with respect to the ODE/DAE solution variables
- mu      - vectors containing the gradients of the cost functions with respect to the problem parameters

  Level: developer

  Note:
  `TSAdjointMonitor()` is typically used automatically within the time stepping implementations.
  Users would almost never call this routine directly.

.seealso: `TSAdjointMonitorSet()`, `TSAdjointSolve()`
@*/
PetscErrorCode TSAdjointMonitor(TS ts, PetscInt step, PetscReal ptime, Vec u, PetscInt numcost, Vec *lambda, Vec *mu)
{
  PetscInt mon, nmon = ts->numberadjointmonitors;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(u, VEC_CLASSID, 4);
  /* monitors may read but must not modify the state vector */
  PetscCall(VecLockReadPush(u));
  for (mon = 0; mon < nmon; mon++) PetscCall((*ts->adjointmonitor[mon])(ts, step, ptime, u, numcost, lambda, mu, ts->adjointmonitorcontext[mon]));
  PetscCall(VecLockReadPop(u));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSAdjointCostIntegral - Evaluate the cost integral in the adjoint run.

  Collective

  Input Parameter:
. ts - time stepping context

  Level: advanced

  Notes:
  This function cannot be called until `TSAdjointStep()` has been completed.

.seealso: [](ch_ts), `TSAdjointSolve()`, `TSAdjointStep()`
 @*/
PetscErrorCode TSAdjointCostIntegral(TS ts)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* dispatch to the implementation-specific adjoint cost-integral evaluation (required method) */
  PetscUseTypeMethod(ts, adjointintegral);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* ------------------ Forward (tangent linear) sensitivity  ------------------*/

/*@
  TSForwardSetUp - Sets up the internal data structures for the later use
  of forward sensitivity analysis

  Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Level: advanced

.seealso: [](ch_ts), `TS`, `TSCreate()`, `TSDestroy()`, `TSSetUp()`
@*/
PetscErrorCode TSForwardSetUp(TS ts)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* idempotent: a second call is a no-op */
  if (ts->forwardsetupcalled) PetscFunctionReturn(PETSC_SUCCESS);
  /* give the TS implementation a chance to set up its own forward-sensitivity data */
  PetscTryTypeMethod(ts, forwardsetup);
  /* work vector holding one column of the sensitivity matrix during the forward sweep */
  PetscCall(VecDuplicate(ts->vec_sol, &ts->vec_sensip_col));
  ts->forwardsetupcalled = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardReset - Reset the internal data structures used by forward sensitivity analysis

  Collective

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Level: advanced

.seealso: [](ch_ts), `TSCreate()`, `TSDestroy()`, `TSForwardSetUp()`
@*/
PetscErrorCode TSForwardReset(TS ts)
{
  TS quadts = ts->quadraturets;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* Optional type-specific teardown; silently skipped if the TS implementation provides none */
  PetscTryTypeMethod(ts, forwardreset);
  PetscCall(MatDestroy(&ts->mat_sensip));
  /* The quadrature sub-TS carries its own sensitivity matrix; release it too */
  if (quadts) PetscCall(MatDestroy(&quadts->mat_sensip));
  PetscCall(VecDestroy(&ts->vec_sensip_col));
  ts->forward_solve      = PETSC_FALSE;
  ts->forwardsetupcalled = PETSC_FALSE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardSetIntegralGradients - Set the vectors holding forward sensitivities of the integral term.

  Input Parameters:
+ ts        - the `TS` context obtained from `TSCreate()`
. numfwdint - number of integrals
- vp        - the vectors containing the gradients for each integral w.r.t. parameters

  Level: deprecated

.seealso: [](ch_ts), `TSForwardGetSensitivities()`, `TSForwardGetIntegralGradients()`, `TSForwardStep()`
@*/
PetscErrorCode TSForwardSetIntegralGradients(TS ts, PetscInt numfwdint, Vec *vp)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* ts->numcost may already be fixed (e.g. by TSSetCostIntegrand()); numfwdint must agree with it */
  PetscCheck(!ts->numcost || ts->numcost == numfwdint, PetscObjectComm((PetscObject)ts), PETSC_ERR_USER, "The number of integrals %" PetscInt_FMT " is inconsistent with the number of cost functions %" PetscInt_FMT " set by TSSetCostIntegrand()", numfwdint, ts->numcost);
  if (!ts->numcost) ts->numcost = numfwdint;

  /* Borrowed reference: the caller retains ownership of the vectors */
  ts->vecs_integral_sensip = vp;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardGetIntegralGradients - Returns the forward sensitivities of the integral term.

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Output Parameters:
+ numfwdint - number of integrals
- vp        - the vectors containing the gradients for each integral w.r.t. parameters

  Level: deprecated

.seealso: [](ch_ts), `TSForwardSetSensitivities()`, `TSForwardSetIntegralGradients()`, `TSForwardStep()`
@*/
PetscErrorCode TSForwardGetIntegralGradients(TS ts, PetscInt *numfwdint, Vec **vp)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscAssertPointer(vp, 3);
  /* numfwdint is optional and may be NULL */
  if (numfwdint) *numfwdint = ts->numcost;
  if (vp) *vp = ts->vecs_integral_sensip;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardStep - Compute the forward sensitivity for one time step.

  Collective

  Input Parameter:
. ts - time stepping context

  Level: advanced

  Notes:
  This function cannot be called until `TSStep()` has been completed.

.seealso: [](ch_ts), `TSForwardSetSensitivities()`, `TSForwardGetSensitivities()`, `TSForwardSetIntegralGradients()`, `TSForwardGetIntegralGradients()`, `TSForwardSetUp()`
@*/
PetscErrorCode TSForwardStep(TS ts)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscCall(PetscLogEventBegin(TS_ForwardStep, ts, 0, 0, 0));
  /* Errors if the TS implementation does not support forward sensitivities */
  PetscUseTypeMethod(ts, forwardstep);
  PetscCall(PetscLogEventEnd(TS_ForwardStep, ts, 0, 0, 0));
  /* fixed typo in the error message: "TSFowardStep" -> "TSForwardStep" */
  PetscCheck(ts->reason >= 0 || !ts->errorifstepfailed, PetscObjectComm((PetscObject)ts), PETSC_ERR_NOT_CONVERGED, "TSForwardStep has failed due to %s", TSConvergedReasons[ts->reason]);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardSetSensitivities - Sets the initial value of the trajectory sensitivities of solution  w.r.t. the problem parameters and initial values.

  Logically Collective

  Input Parameters:
+ ts   - the `TS` context obtained from `TSCreate()`
. nump - number of parameters
- Smat - sensitivities with respect to the parameters, the number of entries in these vectors is the same as the number of parameters

  Level: beginner

  Notes:
  Forward sensitivity is also called 'trajectory sensitivity' in some fields such as power systems.
  This function turns on a flag to trigger `TSSolve()` to compute forward sensitivities automatically.
  You must call this function before `TSSolve()`.
  The entries in the sensitivity matrix must be correctly initialized with the values S = dy/dp|startingtime.

.seealso: [](ch_ts), `TSForwardGetSensitivities()`, `TSForwardSetIntegralGradients()`, `TSForwardGetIntegralGradients()`, `TSForwardStep()`
@*/
PetscErrorCode TSForwardSetSensitivities(TS ts, PetscInt nump, Mat Smat)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(Smat, MAT_CLASSID, 3);
  ts->forward_solve = PETSC_TRUE;
  /* PETSC_DEFAULT means: infer the parameter count from the number of columns of Smat */
  if (nump == PETSC_DEFAULT) {
    PetscCall(MatGetSize(Smat, NULL, &ts->num_parameters));
  } else ts->num_parameters = nump;
  /* Reference before destroy so the call is safe even when Smat == ts->mat_sensip */
  PetscCall(PetscObjectReference((PetscObject)Smat));
  PetscCall(MatDestroy(&ts->mat_sensip));
  ts->mat_sensip = Smat;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardGetSensitivities - Returns the trajectory sensitivities

  Not Collective, but Smat returned is parallel if ts is parallel

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Output Parameters:
+ nump - number of parameters
- Smat - sensitivities with respect to the parameters, the number of entries in these vectors is the same as the number of parameters

  Level: intermediate

.seealso: [](ch_ts), `TSForwardSetSensitivities()`, `TSForwardSetIntegralGradients()`, `TSForwardGetIntegralGradients()`, `TSForwardStep()`
@*/
PetscErrorCode TSForwardGetSensitivities(TS ts, PetscInt *nump, Mat *Smat)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* Both outputs are optional; pass NULL for values not needed */
  if (nump) *nump = ts->num_parameters;
  if (Smat) *Smat = ts->mat_sensip;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardCostIntegral - Evaluate the cost integral in the forward run.

  Collective

  Input Parameter:
. ts - time stepping context

  Level: advanced

  Note:
  This function cannot be called until `TSStep()` has been completed.

.seealso: [](ch_ts), `TS`, `TSSolve()`, `TSAdjointCostIntegral()`
@*/
PetscErrorCode TSForwardCostIntegral(TS ts)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* Dispatch to the TS implementation's forward cost-integral routine; errors if the type does not provide one */
  PetscUseTypeMethod(ts, forwardintegral);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardSetInitialSensitivities - Set initial values for tangent linear sensitivities

  Collective

  Input Parameters:
+ ts   - the `TS` context obtained from `TSCreate()`
- didp - parametric sensitivities of the initial condition

  Level: intermediate

  Notes:
  `TSSolve()` allows users to pass the initial solution directly to `TS`. But the tangent linear variables cannot be initialized in this way.
  This function is used to set initial values for tangent linear variables.

.seealso: [](ch_ts), `TS`, `TSForwardSetSensitivities()`
@*/
PetscErrorCode TSForwardSetInitialSensitivities(TS ts, Mat didp)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscValidHeaderSpecific(didp, MAT_CLASSID, 2);
  /* Only takes effect when no sensitivity matrix is installed yet; otherwise didp is silently ignored */
  if (!ts->mat_sensip) PetscCall(TSForwardSetSensitivities(ts, PETSC_DEFAULT, didp));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSForwardGetStages - Get the number of stages and the tangent linear sensitivities at the intermediate stages

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Output Parameters:
+ ns - number of stages
- S  - tangent linear sensitivities at the intermediate stages

  Level: advanced

.seealso: `TS`
@*/
PetscErrorCode TSForwardGetStages(TS ts, PetscInt *ns, Mat **S)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);

  /* Bug fix: test the forwardgetstages slot that is actually dispatched below; the
     original tested ts->ops->getstages, which is the unrelated solution-stage query.
     When the TS type has no implementation, report zero stages instead of leaving
     the outputs undefined. */
  if (!ts->ops->forwardgetstages) {
    if (ns) *ns = 0;
    if (S) *S = NULL;
  } else PetscUseTypeMethod(ts, forwardgetstages, ns, S);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSCreateQuadratureTS - Create a sub-`TS` that evaluates integrals over time

  Input Parameters:
+ ts  - the `TS` context obtained from `TSCreate()`
- fwd - flag indicating whether to evaluate cost integral in the forward run or the adjoint run

  Output Parameter:
. quadts - the child `TS` context

  Level: intermediate

.seealso: [](ch_ts), `TSGetQuadratureTS()`
@*/
PetscErrorCode TSCreateQuadratureTS(TS ts, PetscBool fwd, TS *quadts)
{
  char quadprefix[128];

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  PetscAssertPointer(quadts, 3);
  /* Replace any previously created quadrature sub-TS */
  PetscCall(TSDestroy(&ts->quadraturets));
  PetscCall(TSCreate(PetscObjectComm((PetscObject)ts), &ts->quadraturets));
  PetscCall(PetscObjectIncrementTabLevel((PetscObject)ts->quadraturets, (PetscObject)ts, 1));
  /* Derive the child's options prefix from the parent's, appending "quad_" */
  PetscCall(PetscSNPrintf(quadprefix, sizeof(quadprefix), "%squad_", ((PetscObject)ts)->prefix ? ((PetscObject)ts)->prefix : ""));
  PetscCall(TSSetOptionsPrefix(ts->quadraturets, quadprefix));
  *quadts = ts->quadraturets;

  /* The quadrature solution holds one entry per cost functional (at least one) */
  PetscCall(VecCreateSeq(PETSC_COMM_SELF, ts->numcost ? ts->numcost : 1, &(*quadts)->vec_sol));
  ts->costintegralfwd = fwd;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSGetQuadratureTS - Return the sub-`TS` that evaluates integrals over time

  Input Parameter:
. ts - the `TS` context obtained from `TSCreate()`

  Output Parameters:
+ fwd    - flag indicating whether to evaluate cost integral in the forward run or the adjoint run
- quadts - the child `TS` context

  Level: intermediate

.seealso: [](ch_ts), `TSCreateQuadratureTS()`
@*/
PetscErrorCode TSGetQuadratureTS(TS ts, PetscBool *fwd, TS *quadts)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts, TS_CLASSID, 1);
  /* Both outputs are optional; pass NULL for values not needed */
  if (fwd) *fwd = ts->costintegralfwd;
  if (quadts) *quadts = ts->quadraturets;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  TSComputeSNESJacobian - Compute the Jacobian needed for the `SNESSolve()` in `TS`

  Collective

  Input Parameters:
+ ts - the `TS` context obtained from `TSCreate()`
- x  - state vector

  Output Parameters:
+ J    - Jacobian matrix
- Jpre - preconditioning matrix for J (may be same as J)

  Level: developer

  Note:
  Uses finite differencing when `TS` Jacobian is not available.

.seealso: `SNES`, `TS`, `SNESSetJacobian()`, `TSSetRHSJacobian()`, `TSSetIJacobian()`
@*/
PetscErrorCode TSComputeSNESJacobian(TS ts, Vec x, Mat J, Mat Jpre)
{
  SNES snes                                          = ts->snes;
  PetscErrorCode (*jac)(SNES, Vec, Mat, Mat, void *) = NULL;

  PetscFunctionBegin;
  /*
    Unlike implicit methods, explicit methods do not have SNESMatFDColoring in the snes object
    because SNESSolve() has not been called yet; so querying SNESMatFDColoring does not work for
    explicit methods. Instead, we check the Jacobian compute function directly to determine if FD
    coloring is used.
  */
  PetscCall(SNESGetJacobian(snes, NULL, NULL, &jac, NULL));
  if (jac == SNESComputeJacobianDefaultColor) {
    Vec f;
    /* Register x as the SNES solution so the FD machinery differences around the right base point */
    PetscCall(SNESSetSolution(snes, x));
    PetscCall(SNESGetFunction(snes, &f, NULL, NULL));
    /* Force MatFDColoringApply to evaluate the SNES residual function for the base vector */
    PetscCall(SNESComputeFunction(snes, x, f));
  }
  PetscCall(SNESComputeJacobian(snes, x, J, Jpre));
  PetscFunctionReturn(PETSC_SUCCESS);
}
