/*
    Provides a PETSc TS interface to the external RADAU5 ODE solver.

*/
#include <petsc/private/tsimpl.h> /*I   "petscts.h"   I*/

/* Private context of the TSRADAU5 implementation */
typedef struct {
  Vec work;  /* storage-less wrapper; the callbacks place RADAU5's state array Y into it */
  Vec workf; /* storage-less wrapper; FVPOL places RADAU5's right-hand-side output array F into it */
} TS_Radau5;

/*
  FVPOL - RADAU5 right-hand-side callback: writes f(X, Y) into F for the explicit ODE Y' = f(X, Y).

  Y and F are arrays owned by RADAU5; they are temporarily placed into the wrapper vectors
  cvode->work and cvode->workf so the PETSc TS evaluation routines can act on them directly.
  IPAR carries the TS; N and RPAR are not used. Because this callback has no error return,
  any PETSc error aborts.
*/
static void FVPOL(int *N, double *X, double *Y, double *F, double *RPAR, void *IPAR)
{
  TS             ts    = (TS)IPAR;
  TS_Radau5     *cvode = (TS_Radau5 *)ts->data;
  DM             dm;
  TSIFunctionFn *ifunction;

  PetscCallAbort(PETSC_COMM_SELF, VecPlaceArray(cvode->work, Y));
  PetscCallAbort(PETSC_COMM_SELF, VecPlaceArray(cvode->workf, F));

  /* Compute the right-hand side via the IFunction unless only the more efficient RHSFunction is set
     (DMTSGetIFunction() obtains the DMTS itself, so no separate DMGetDMTS() call is needed) */
  PetscCallAbort(PETSC_COMM_SELF, TSGetDM(ts, &dm));
  PetscCallAbort(PETSC_COMM_SELF, DMTSGetIFunction(dm, &ifunction, NULL));
  if (!ifunction) {
    PetscCallAbort(PETSC_COMM_SELF, TSComputeRHSFunction(ts, *X, cvode->work, cvode->workf));
  } else {
    Vec yydot;

    /* Evaluate the implicit form with udot = 0 (non-IMEX, so an RHSFunction, if also set, is folded in),
       then negate to obtain the explicit right-hand side */
    PetscCallAbort(PETSC_COMM_SELF, VecDuplicate(cvode->work, &yydot));
    PetscCallAbort(PETSC_COMM_SELF, VecZeroEntries(yydot));
    PetscCallAbort(PETSC_COMM_SELF, TSComputeIFunction(ts, *X, cvode->work, yydot, cvode->workf, PETSC_FALSE));
    PetscCallAbort(PETSC_COMM_SELF, VecScale(cvode->workf, -1.));
    PetscCallAbort(PETSC_COMM_SELF, VecDestroy(&yydot));
  }

  /* Detach RADAU5's arrays so the wrappers never retain pointers into its memory */
  PetscCallAbort(PETSC_COMM_SELF, VecResetArray(cvode->work));
  PetscCallAbort(PETSC_COMM_SELF, VecResetArray(cvode->workf));
}

/*
  JVPOL - RADAU5 Jacobian callback: fills DFY with the Jacobian of the right-hand side at (X, Y).

  DFY is wrapped, without copying, as an n x n sequential dense PETSc matrix, so the Jacobian is
  written straight into RADAU5's storage. It is computed with TSComputeIJacobian() at shift 0 and
  udot = 0 and then negated, matching the sign convention FVPOL uses for the right-hand side.
  N, LDFY and RPAR are not used; IPAR carries the TS.
  NOTE(review): assumes LDFY == n (full Jacobian storage) - confirm if banded Jacobians are ever enabled.
*/
static void JVPOL(PetscInt *N, PetscScalar *X, PetscScalar *Y, PetscScalar *DFY, int *LDFY, PetscScalar *RPAR, void *IPAR)
{
  TS         ts    = (TS)IPAR;
  TS_Radau5 *radau = (TS_Radau5 *)ts->data;
  PetscInt   size;
  Vec        udot;
  Mat        J;

  PetscCallAbort(PETSC_COMM_SELF, VecPlaceArray(radau->work, Y));

  /* zero time derivative for the implicit-form evaluation */
  PetscCallAbort(PETSC_COMM_SELF, VecDuplicate(radau->work, &udot));
  PetscCallAbort(PETSC_COMM_SELF, VecZeroEntries(udot));

  PetscCallAbort(PETSC_COMM_SELF, VecGetSize(udot, &size));
  PetscCallAbort(PETSC_COMM_SELF, MatCreateSeqDense(PETSC_COMM_SELF, size, size, DFY, &J));
  PetscCallAbort(PETSC_COMM_SELF, TSComputeIJacobian(ts, *X, radau->work, udot, 0, J, J, PETSC_FALSE));
  PetscCallAbort(PETSC_COMM_SELF, MatScale(J, -1.0));

  PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&J));
  PetscCallAbort(PETSC_COMM_SELF, VecDestroy(&udot));
  PetscCallAbort(PETSC_COMM_SELF, VecResetArray(radau->work));
}

/*
  SOLOUT - RADAU5 output callback: forwards each output point to the TS monitors.

  Records the length of the step just taken (X - XOLD) in ts->time_step, then calls TSMonitor()
  with step index NR - 1 (presumably NR counts from 1), time X, and the state Y viewed through
  the wrapper vector. CONT, LRC, N, RPAR and IRTRN are not used; IPAR carries the TS.
*/
static void SOLOUT(int *NR, double *XOLD, double *X, double *Y, double *CONT, double *LRC, int *N, double *RPAR, void *IPAR, int *IRTRN)
{
  TS         ts    = (TS)IPAR;
  TS_Radau5 *radau = (TS_Radau5 *)ts->data;

  ts->time_step = *X - *XOLD;

  PetscCallAbort(PETSC_COMM_SELF, VecPlaceArray(radau->work, Y));
  PetscCallAbort(PETSC_COMM_SELF, TSMonitor(ts, *NR - 1, *X, radau->work));
  PetscCallAbort(PETSC_COMM_SELF, VecResetArray(radau->work));
}

/* Driver routine provided by the external RADAU5 library (--download-radau5); it is not defined in
   this file, so it must be declared with external linkage (a static declaration cannot be linked) */
extern void radau5_(int *, void *, double *, double *, double *, double *, double *, double *, int *, void *, int *, int *, int *, void *, int *, int *, int *, void *, int *, double *, int *, int *, int *, double *, void *, int *);

/*
  TSSolve_Radau5 - Integrates from ts->ptime to ts->max_time with a single call to RADAU5.

  The array of ts->vec_sol is handed to RADAU5, which overwrites it in place with the solution at
  the final time reached. RADAU5 chooses its own step sizes; only the TS tolerances and the initial
  time step are passed through. On return, ts->ptime holds the time actually reached and
  ts->reason reflects RADAU5's status flag IDID (negative means failure).
*/
static PetscErrorCode TSSolve_Radau5(TS ts)
{
  TS_Radau5   *cvode = (TS_Radau5 *)ts->data;
  PetscScalar *Y, *WORK, X, XEND, RTOL, ATOL, H, RPAR;
  PetscInt     ND, *IWORK, LWORK, LIWORK, ITOL;
  /* MUJAC/MLMAS/MUMAS only matter for banded Jacobians or a mass matrix (neither is used here), but
     they are passed by address, so give them defined values; IDID receives RADAU5's return status */
  PetscInt     MUJAC = 0, MLMAS = 0, MUMAS = 0, IDID = 0;
  int          IJAC, MLJAC, IMAS, IOUT;

  PetscFunctionBegin;
  PetscCall(VecGetSize(ts->vec_sol, &ND));

  /* The callbacks view RADAU5's arrays through these storage-less wrapper vectors; destroy any left
     over from a previous TSSolve() so that repeated solves do not leak them */
  PetscCall(VecDestroy(&cvode->work));
  PetscCall(VecDestroy(&cvode->workf));
  PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, ND, NULL, &cvode->work));
  PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, ND, NULL, &cvode->workf));

  /* Minimum work-array sizes for a full Jacobian without a mass matrix */
  LWORK  = 4 * ND * ND + 12 * ND + 20;
  LIWORK = 3 * ND + 20;

  /* Zero-filled so RADAU5's optional control entries in WORK/IWORK select its defaults */
  PetscCall(PetscCalloc2(LWORK, &WORK, LIWORK, &IWORK));

  /* RPAR is forwarded to the callbacks, which ignore it (value presumably left over from the RADAU5 example driver) */
  RPAR = 1.0e-6;
  /* C --- COMPUTE THE JACOBIAN ANALYTICALLY (via JVPOL) */
  IJAC = 1;
  /* C --- JACOBIAN IS A FULL MATRIX */
  MLJAC = ND;
  /* C --- DIFFERENTIAL EQUATION IS IN EXPLICIT FORM (no mass matrix) */
  IMAS = 0;
  /* C --- OUTPUT ROUTINE (SOLOUT) IS USED DURING INTEGRATION */
  IOUT = 1;
  /* C --- INITIAL TIME */
  X = ts->ptime;
  /* C --- ENDPOINT OF INTEGRATION */
  XEND = ts->max_time;
  /* C --- REQUIRED TOLERANCE: scalar RTOL and ATOL (ITOL = 0) */
  RTOL = ts->rtol;
  ATOL = ts->atol;
  ITOL = 0;
  /* C --- INITIAL STEP SIZE */
  H = ts->time_step;

  /* RADAU5 integrates in place on the solution array; restoring the array afterwards records the
     modification of ts->vec_sol */
  PetscCall(VecGetArray(ts->vec_sol, &Y));
  radau5_(&ND, FVPOL, &X, Y, &XEND, &H, &RTOL, &ATOL, &ITOL, JVPOL, &IJAC, &MLJAC, &MUJAC, FVPOL, &IMAS, &MLMAS, &MUMAS, SOLOUT, &IOUT, WORK, &LWORK, IWORK, &LIWORK, &RPAR, (void *)ts, &IDID);
  PetscCall(VecRestoreArray(ts->vec_sol, &Y));

  /* On return X holds the time reached by the integration */
  ts->ptime  = PetscRealPart(X);
  ts->reason = IDID < 0 ? TS_DIVERGED_STEP_REJECTED : TS_CONVERGED_TIME;

  PetscCall(PetscFree2(WORK, IWORK));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  TSDestroy_Radau5 - Releases the TSRADAU5 private context.

  Destroys the wrapper vectors created by TSSolve_Radau5() (VecDestroy() accepts vectors that
  were never created) and frees ts->data.
*/
static PetscErrorCode TSDestroy_Radau5(TS ts)
{
  TS_Radau5 *radau = (TS_Radau5 *)ts->data;

  PetscFunctionBegin;
  PetscCall(VecDestroy(&radau->workf));
  PetscCall(VecDestroy(&radau->work));
  PetscCall(PetscFree(ts->data));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*MC
      TSRADAU5 - ODE solver using the external RADAU5 package, requires ./configure --download-radau5

    Level: beginner

    Notes:
    This uses its own nonlinear solver and dense matrix direct solver so PETSc `SNES` and `KSP` options do not apply.

    Uses its own time-step adaptivity (but uses the TS rtol and atol, and initial timestep)

    Uses its own memory for the dense matrix storage and factorization

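    Can only handle ODEs of the form $\dot{u} = -F(t,u) + G(t,u)$, where F is the function set with `TSSetIFunction()` evaluated with $\dot{u} = 0$ and G is the function set with `TSSetRHSFunction()`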

.seealso: [](ch_ts), `TSCreate()`, `TS`, `TSSetType()`, `TSType`
M*/
PETSC_EXTERN PetscErrorCode TSCreate_Radau5(TS ts)
{
  TS_Radau5 *radau;

  PetscFunctionBegin;
  /* RADAU5 performs the whole integration itself, so only solve and destroy are provided */
  ts->ops->destroy = TSDestroy_Radau5;
  ts->ops->solve   = TSSolve_Radau5;
  /* step-size control is internal to RADAU5, so no TSAdapt is used by default */
  ts->default_adapt_type = TSADAPTNONE;

  /* zero-initialized private context; its work vectors are created in TSSolve_Radau5() */
  PetscCall(PetscNew(&radau));
  ts->data = (void *)radau;
  PetscFunctionReturn(PETSC_SUCCESS);
}
