/*
  Code for time stepping with the multirate Partitioned Runge-Kutta (MPRK) method

  Notes:
  1) For the nonsplit-RHS multirate RK method the general system is written as
     Udot = F(t,U);
     the user must provide the index sets of both the slow and the fast components.
  2) For the split-RHS multirate RK method the general system is written as
     Usdot = Fs(t,Us,Uf)
     Ufdot = Ff(t,Us,Uf);
     the user must partition the RHS themselves and also provide the index sets of both the slow and the fast components.
*/

#include <petsc/private/tsimpl.h>                /*I   "petscts.h"   I*/
#include <petscdm.h>

/* Scheme selected by TSCreate_MPRK() before any user choice */
static TSMPRKType TSMPRKDefault = TSMPRKPM2;
/* Set once TSMPRKRegisterAll() has run; cleared again by TSMPRKRegisterDestroy() */
static PetscBool TSMPRKRegisterAllCalled;
/* Set once TSMPRKInitializePackage() has run; cleared again by TSMPRKFinalizePackage() */
static PetscBool TSMPRKPackageInitialized;

/* One MPRK scheme: a pair of s-stage Butcher tableaux (row-major s*s matrix A, weights b, abscissae c),
   one applied to the fast components and one applied to the slow components */
typedef struct _MPRKTableau *MPRKTableau;
struct _MPRKTableau {
  char       *name;                          /* Name under which the scheme was registered     */
  PetscInt   order;                          /* Classical approximation order of the method    */
  PetscInt   s;                              /* Number of stages                               */
  PetscReal  *Af,*bf,*cf;                    /* Tableau for fast components                    */
  PetscReal  *As,*bs,*cs;                    /* Tableau for slow components                    */
};
/* Node of the singly linked list of registered schemes */
typedef struct _MPRKTableauLink *MPRKTableauLink;
struct _MPRKTableauLink {
  struct _MPRKTableau tab;
  MPRKTableauLink     next;
};
/* Head of the list of registered schemes; TSMPRKRegister() prepends new entries */
static MPRKTableauLink MPRKTableauList;

/* Per-TS data of the TSMPRK implementation */
typedef struct {
  MPRKTableau          tableau;                  /* Currently selected scheme (points into MPRKTableauList)    */
  TSMPRKMultirateType mprkmtype;                 /* TSMPRKNONSPLIT or TSMPRKSPLIT                              */
  Vec                 *Y;                          /* States computed during the step                           */
  Vec                 Ytmp;                        /* Full-length work vector, allocated only for TSMPRKNONSPLIT  */
  Vec                 *YdotRHS_fast;               /* Function evaluations by fast tableau for fast components;
                                                      full-length (nonsplit) or fast-subvector sized (split)   */
  Vec                 *YdotRHS_slow;               /* Function evaluations by slow tableau for slow components;
                                                      full-length (nonsplit) or slow-subvector sized (split)   */
  PetscScalar         *work_fast;                  /* Scratch for h*Af / h*bf weights, length s                 */
  PetscScalar         *work_slow;                  /* Scratch for h*As / h*bs weights, length s                 */
  PetscReal           stage_time;                  /* Time of the stage currently being computed                */
  TSStepStatus        status;                      /* NOTE(review): not referenced in the code shown here       */
  PetscReal           ptime;                       /* NOTE(review): not referenced in the code shown here       */
  PetscReal           time_step;                   /* NOTE(review): not referenced in the code shown here       */
  IS                  is_slow,is_fast;             /* Index sets of the 'slow'/'fast' splits (TSRHSSplitGetIS)  */
  TS                  subts_slow,subts_fast;       /* Sub-TS of the 'slow'/'fast' splits, split variant only    */
} TS_MPRK;

/*MC
     TSMPRKPM2 - Second Order Partitioned Runge Kutta scheme.

     This method has four stages for both slow and fast parts.

     Options database:
.     -ts_mprk_type pm2

     Level: advanced

.seealso: TSMPRK, TSMPRKType, TSMPRKSetType()
M*/
/*MC
     TSMPRKPM3 - Third Order Partitioned Runge Kutta scheme.

     This method has eight stages for both slow and fast parts.

     Options database:
.     -ts_mprk_type pm3  (not yet available)

     Notes:
     This scheme is currently disabled: its registration in TSMPRKRegisterAll() is commented out.

     Level: advanced

.seealso: TSMPRK, TSMPRKType, TSMPRKSetType()
M*/
/*MC
     TSMPRKP2 - Second Order Partitioned Runge Kutta scheme.

     This method has five stages for both slow and fast parts.

     Options database:
.     -ts_mprk_type p2

     Level: advanced

.seealso: TSMPRK, TSMPRKType, TSMPRKSetType()
M*/
/*MC
     TSMPRKP3 - Third Order Partitioned Runge Kutta scheme.

     This method has ten stages for both slow and fast parts.

     Options database:
.     -ts_mprk_type p3

     Level: advanced

.seealso: TSMPRK, TSMPRKType, TSMPRKSetType()
M*/

/*@C
  TSMPRKRegisterAll - Registers all of the Partitioned Runge-Kutta explicit methods in TSMPRK

  Not Collective, but should be called by all processes which will need the schemes to be registered

  Notes:
  Registers TSMPRKPM2, TSMPRKP2 and TSMPRKP3 (TSMPRKPM3 is currently not registered).
  Calls after the first do nothing until TSMPRKRegisterDestroy() has been called.

  Level: advanced

.keywords: TS, TSMPRK, register, all

.seealso:  TSMPRKRegisterDestroy()
@*/
PetscErrorCode TSMPRKRegisterAll(void)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  if (TSMPRKRegisterAllCalled) PetscFunctionReturn(0);
  TSMPRKRegisterAllCalled = PETSC_TRUE;

  /* Shorthand for real literals in the tables below; undefined again at the end of this function */
#define RC PetscRealConstant
  /* TSMPRKPM2: 4 stages, order 2. As/bs is the slow tableau, A/b the fast one. The abscissae are
     passed as NULL, so TSMPRKRegister() uses the row sums of As and A */
  {
    const PetscReal
      As[4][4] = {{0,0,0,0},
                  {RC(1.0),0,0,0},
                  {0,0,0,0},
                  {0,0,RC(1.0),0}},
      A[4][4]  = {{0,0,0,0},
                  {RC(0.5),0,0,0},
                  {RC(0.25),RC(0.25),0,0},
                  {RC(0.25),RC(0.25),RC(0.5),0}},
      bs[4]    = {RC(0.25),RC(0.25),RC(0.25),RC(0.25)},
      b[4]     = {RC(0.25),RC(0.25),RC(0.25),RC(0.25)};
    ierr = TSMPRKRegister(TSMPRKPM2,2,4,&As[0][0],bs,NULL,&A[0][0],b,NULL);CHKERRQ(ierr);
  }

  /* TSMPRKPM3 (8 stages, order 3) is disabled: its registration is kept commented out below */
  /*{
      const PetscReal
        As[8][8] = {{0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(2.0),0,0,0,0,0,0,0},
                    {RC(-1.0)/RC(6.0),RC(2.0)/RC(3.0),0,0,0,0,0,0},
                    {RC(1.0)/RC(3.0),RC(-1.0)/RC(3.0),RC(1.0),0,0,0,0,0},
                    {0,0,0,0,0,0,0,0},
                    {0,0,0,0,RC(1.0)/RC(2.0),0,0,0},
                    {0,0,0,0,RC(-1.0)/RC(6.0),RC(2.0)/RC(3.0),0,0},
                    {0,0,0,0,RC(1.0)/RC(3.0),RC(-1.0)/RC(3.0),RC(1.0),0}},
         A[8][8] = {{0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(4.0),0,0,0,0,0,0,0},
                    {RC(-1.0)/RC(12.0),RC(1.0)/RC(3.0),0,0,0,0,0,0},
                    {RC(1.0)/RC(6.0),RC(-1.0)/RC(6.0),RC(1.0)/RC(2.0),0,0,0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),RC(1.0)/RC(4.0),0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),RC(-1.0)/RC(12.0),RC(1.0)/RC(3.0),0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(-1.0)/RC(6.0),RC(1.0)/RC(2.0),0}},
          bs[8] = {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0)},
           b[8] = {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0)};
           ierr = TSMPRKRegister(TSMPRKPM3,3,8,&As[0][0],bs,NULL,&A[0][0],b,NULL);CHKERRQ(ierr);
  }*/

  /* TSMPRKP2: 5 stages, order 2. The slow weights bs use only stages 0 and 4; the fast
     weights b use stages 0-3. Abscissae are the row sums of As and A */
  {
    const PetscReal
      As[5][5] = {{0,0,0,0,0},
                  {RC(1.0)/RC(2.0),0,0,0,0},
                  {RC(1.0)/RC(2.0),0,0,0,0},
                  {RC(1.0),0,0,0,0},
                  {RC(1.0),0,0,0,0}},
      A[5][5]  = {{0,0,0,0,0},
                  {RC(1.0)/RC(2.0),0,0,0,0},
                  {RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),0,0,0},
                  {RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),RC(1.0)/RC(2.0),0,0},
                  {RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),0}},
      bs[5]     = {RC(1.0)/RC(2.0),0,0,0,RC(1.0)/RC(2.0)},
      b[5]      = {RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),RC(1.0)/RC(4.0),0};
    ierr = TSMPRKRegister(TSMPRKP2,2,5,&As[0][0],bs,NULL,&A[0][0],b,NULL);CHKERRQ(ierr);
  }

  /* TSMPRKP3: 10 stages, order 3. The slow weights bs are nonzero only at stages 0, 4, 5 and 9;
     the fast weights b are zero at stages 4 and 9. Abscissae are the row sums of As and A */
  {
    const PetscReal
      As[10][10] = {{0,0,0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(4.0),0,0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(4.0),0,0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(2.0),0,0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(2.0),0,0,0,0,0,0,0,0,0},
                    {RC(-1.0)/RC(6.0),0,0,0,RC(2.0)/RC(3.0),0,0,0,0,0},
                    {RC(1.0)/RC(12.0),0,0,0,RC(1.0)/RC(6.0),RC(1.0)/RC(2.0),0,0,0,0},
                    {RC(1.0)/RC(12.0),0,0,0,RC(1.0)/RC(6.0),RC(1.0)/RC(2.0),0,0,0,0},
                    {RC(1.0)/RC(3.0),0,0,0,RC(-1.0)/RC(3.0),RC(1.0),0,0,0,0},
                    {RC(1.0)/RC(3.0),0,0,0,RC(-1.0)/RC(3.0),RC(1.0),0,0,0,0}},
      A[10][10]  = {{0,0,0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(4.0),0,0,0,0,0,0,0,0,0},
                    {RC(-1.0)/RC(12.0),RC(1.0)/RC(3.0),0,0,0,0,0,0,0,0},
                    {RC(1.0)/RC(6.0),RC(-1.0)/RC(6.0),RC(1.0)/RC(2.0),0,0,0,0,0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,0,0,0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,0,0,0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,RC(1.0)/RC(4.0),0,0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,RC(-1.0)/RC(12.0),RC(1.0)/RC(3.0),0,0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,RC(1.0)/RC(6.0),RC(-1.0)/RC(6.0),RC(1.0)/RC(2.0),0,0},
                    {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0}},
      bs[10]     = {RC(1.0)/RC(6.0),0,0,0,RC(1.0)/RC(3.0),RC(1.0)/RC(3.0),0,0,0,RC(1.0)/RC(6.0)},
      b[10]      = {RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0,RC(1.0)/RC(12.0),RC(1.0)/RC(6.0),RC(1.0)/RC(6.0),RC(1.0)/RC(12.0),0};
    ierr = TSMPRKRegister(TSMPRKP3,3,10,&As[0][0],bs,NULL,&A[0][0],b,NULL);CHKERRQ(ierr);
  }
#undef RC
  PetscFunctionReturn(0);
}

/*@C
   TSMPRKRegisterDestroy - Releases every scheme that was added with TSMPRKRegister() and
   empties the list of registered schemes.

   Not Collective

   Level: advanced

.keywords: TSMPRK, register, destroy
.seealso: TSMPRKRegister(), TSMPRKRegisterAll()
@*/
PetscErrorCode TSMPRKRegisterDestroy(void)
{
  MPRKTableauLink head;
  PetscErrorCode  ierr;

  PetscFunctionBegin;
  /* Pop entries off the front of the list until it is empty */
  for (head = MPRKTableauList; head; head = MPRKTableauList) {
    MPRKTableau tab = &head->tab;

    MPRKTableauList = head->next;
    /* Each half of the tableau was allocated with a single PetscMalloc3() in TSMPRKRegister() */
    ierr = PetscFree3(tab->Af,tab->bf,tab->cf);CHKERRQ(ierr);
    ierr = PetscFree3(tab->As,tab->bs,tab->cs);CHKERRQ(ierr);
    ierr = PetscFree(tab->name);CHKERRQ(ierr);
    ierr = PetscFree(head);CHKERRQ(ierr);
  }
  /* Let a later TSMPRKRegisterAll() repopulate the list */
  TSMPRKRegisterAllCalled = PETSC_FALSE;
  PetscFunctionReturn(0);
}

/*@C
  TSMPRKInitializePackage - Initializes the TSMPRK package: registers the built-in schemes and
  arranges for TSMPRKFinalizePackage() to run at PetscFinalize(). It is called from
  PetscDLLibraryRegister() when using dynamic libraries, and on the first call to TSCreate_MPRK()
  when using static libraries. Calls after the first do nothing.

  Level: developer

.keywords: TS, TSMPRK, initialize, package
.seealso: PetscInitialize()
@*/
PetscErrorCode TSMPRKInitializePackage(void)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  if (!TSMPRKPackageInitialized) {
    TSMPRKPackageInitialized = PETSC_TRUE;
    ierr = TSMPRKRegisterAll();CHKERRQ(ierr);
    ierr = PetscRegisterFinalize(TSMPRKFinalizePackage);CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}

/*@C
  TSMPRKFinalizePackage - This function destroys everything in the TSMPRK package. It is
  registered by TSMPRKInitializePackage() to be called from PetscFinalize().

  Level: developer

.keywords: Petsc, destroy, package
.seealso: PetscFinalize(), TSMPRKInitializePackage()
@*/
PetscErrorCode TSMPRKFinalizePackage(void)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  /* Allow TSMPRKInitializePackage() to run again after finalization */
  TSMPRKPackageInitialized = PETSC_FALSE;
  ierr = TSMPRKRegisterDestroy();CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
  TSMPRKTableauCopy_Private - Copy one (A,b,c) Butcher tableau with s stages into newly allocated storage.

  A is s*s and row-major. When b is NULL the last row of A is used as the completion weights;
  when c is NULL the abscissae are the row sums of A. The three output arrays come from a single
  PetscMalloc3() and must be released together with PetscFree3() (as TSMPRKRegisterDestroy() does).
*/
static PetscErrorCode TSMPRKTableauCopy_Private(PetscInt s,const PetscReal A[],const PetscReal b[],const PetscReal c[],PetscReal **tA,PetscReal **tb,PetscReal **tc)
{
  PetscReal      *Ac,*bc,*cc;
  PetscInt       i,j;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscMalloc3(s*s,&Ac,s,&bc,s,&cc);CHKERRQ(ierr);
  ierr = PetscMemcpy(Ac,A,s*s*sizeof(A[0]));CHKERRQ(ierr);
  if (b) {
    ierr = PetscMemcpy(bc,b,s*sizeof(b[0]));CHKERRQ(ierr);
  } else {
    for (i=0; i<s; i++) bc[i] = A[(s-1)*s+i];
  }
  if (c) {
    ierr = PetscMemcpy(cc,c,s*sizeof(c[0]));CHKERRQ(ierr);
  } else {
    for (i=0; i<s; i++) {
      cc[i] = 0;
      for (j=0; j<s; j++) cc[i] += A[i*s+j];
    }
  }
  *tA = Ac;
  *tb = bc;
  *tc = cc;
  PetscFunctionReturn(0);
}

/*@C
   TSMPRKRegister - register a MPRK scheme by providing the entries in the Butcher tableau

   Not Collective, but the same schemes should be registered on all processes on which they will be used

   Input Parameters:
+  name - identifier for method
.  order - approximation order of method
.  s  - number of stages, this is the dimension of the matrices below; must be positive
.  As - stage coefficients for slow components (dimension s*s, row-major)
.  bs - step completion table for slow components (dimension s), or NULL to use the last row of As
.  cs - abscissa for slow components (dimension s), or NULL to use the row sums of As
.  Af - stage coefficients for fast components (dimension s*s, row-major)
.  bf - step completion table for fast components (dimension s), or NULL to use the last row of Af
-  cf - abscissa for fast components (dimension s), or NULL to use the row sums of Af

   Notes:
   Several MPRK methods are provided, this function is only needed to create new methods.
   The new scheme is placed at the front of the list, so it is the one found by TSMPRKSetType()
   if a scheme with the same name was registered earlier.

   Level: advanced

.keywords: TS, register

.seealso: TSMPRK
@*/
PetscErrorCode TSMPRKRegister(TSMPRKType name,PetscInt order,PetscInt s,
                              const PetscReal As[],const PetscReal bs[],const PetscReal cs[],
                              const PetscReal Af[],const PetscReal bf[],const PetscReal cf[])
{
  MPRKTableauLink link;
  MPRKTableau     t;
  PetscErrorCode  ierr;

  PetscFunctionBegin;
  PetscValidCharPointer(name,1);
  if (s < 1) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Number of stages %D must be positive",s);
  PetscValidRealPointer(Af,7);
  if (bf) PetscValidRealPointer(bf,8);
  if (cf) PetscValidRealPointer(cf,9);
  PetscValidRealPointer(As,4);
  if (bs) PetscValidRealPointer(bs,5);
  if (cs) PetscValidRealPointer(cs,6);

  ierr = PetscNew(&link);CHKERRQ(ierr);
  t = &link->tab;

  ierr = PetscStrallocpy(name,&t->name);CHKERRQ(ierr);
  t->order = order;
  t->s     = s;
  ierr = TSMPRKTableauCopy_Private(s,Af,bf,cf,&t->Af,&t->bf,&t->cf);CHKERRQ(ierr);
  ierr = TSMPRKTableauCopy_Private(s,As,bs,cs,&t->As,&t->bs,&t->cs);CHKERRQ(ierr);
  /* Prepend: the most recently registered scheme is found first when looking up by name */
  link->next      = MPRKTableauList;
  MPRKTableauList = link;
  PetscFunctionReturn(0);
}

/*
  Fetch the 'slow' and 'fast' sub-TS of the RHS splits and give each one its own clone of the
  parent TS's DM. The clone receives the sub-TS's existing DMTS and DMSNES contexts before it
  replaces the sub-TS's DM. Used only by the split-RHS variant (composed as "TSMPRKSetSplits_C").
*/
static PetscErrorCode TSMPRKSetSplits(TS ts)
{
  TS_MPRK        *mprk = (TS_MPRK*)ts->data;
  TS             subts[2];
  DM             dm,subdm,newdm;
  PetscInt       k;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = TSRHSSplitGetSubTS(ts,"slow",&mprk->subts_slow);CHKERRQ(ierr);
  ierr = TSRHSSplitGetSubTS(ts,"fast",&mprk->subts_fast);CHKERRQ(ierr);
  if (!mprk->subts_slow || !mprk->subts_fast) SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_USER,"Must set up the RHSFunctions for 'slow' and 'fast' components using TSRHSSplitSetRHSFunction() or calling TSSetRHSFunction() for each sub-TS");

  /* Process the fast sub-TS first, then the slow one */
  subts[0] = mprk->subts_fast;
  subts[1] = mprk->subts_slow;
  ierr = TSGetDM(ts,&dm);CHKERRQ(ierr);
  for (k=0; k<2; k++) {
    ierr = DMClone(dm,&newdm);CHKERRQ(ierr);
    ierr = TSGetDM(subts[k],&subdm);CHKERRQ(ierr);
    ierr = DMCopyDMTS(subdm,newdm);CHKERRQ(ierr);
    ierr = DMCopyDMSNES(subdm,newdm);CHKERRQ(ierr);
    ierr = TSSetDM(subts[k],newdm);CHKERRQ(ierr);
    /* drop the local reference; the sub-TS now holds the clone */
    ierr = DMDestroy(&newdm);CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}

/*
  Step completion for the nonsplit-RHS MPRK method. Each partition is completed with the
  weights of its own tableau,

    X[slow] = x0[slow] + h bs^T YdotRHS_slow
    X[fast] = x0[fast] + h bf^T YdotRHS_fast

  where x0 = ts->vec_sol. The full-length work vector Ytmp holds each update, and only the
  entries selected by the corresponding index set are copied into X. The order argument is
  ignored (only the tableau's own order is available); *done, if requested, is set to PETSC_TRUE.
*/
static PetscErrorCode TSEvaluateStep_MPRK(TS ts,PetscInt order,Vec X,PetscBool *done)
{
  TS_MPRK        *mprk = (TS_MPRK*)ts->data;
  MPRKTableau    tab = mprk->tableau;
  Vec            Xtmp = mprk->Ytmp,Xslow,Xfast;
  PetscScalar    *wf = mprk->work_fast,*ws = mprk->work_slow;
  PetscReal      h = ts->time_step;
  PetscInt       s = tab->s,j;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = VecCopy(ts->vec_sol,X);CHKERRQ(ierr);
  for (j=0; j<s; j++) wf[j] = h*tab->bf[j];
  for (j=0; j<s; j++) ws[j] = h*tab->bs[j];

  /* Update slow part of X */
  ierr = VecCopy(X,Xtmp);CHKERRQ(ierr);
  ierr = VecMAXPY(Xtmp,s,ws,mprk->YdotRHS_slow);CHKERRQ(ierr);
  ierr = VecGetSubVector(Xtmp,mprk->is_slow,&Xslow);CHKERRQ(ierr);
  ierr = VecISCopy(X,mprk->is_slow,SCATTER_FORWARD,Xslow);CHKERRQ(ierr);
  ierr = VecRestoreSubVector(Xtmp,mprk->is_slow,&Xslow);CHKERRQ(ierr);

  /* Update fast part of X. The slow entries of X have already changed, but only the fast
     entries of Xtmp are copied back, so the slow entries of Xtmp are simply discarded */
  ierr = VecCopy(X,Xtmp);CHKERRQ(ierr);
  ierr = VecMAXPY(Xtmp,s,wf,mprk->YdotRHS_fast);CHKERRQ(ierr);
  ierr = VecGetSubVector(Xtmp,mprk->is_fast,&Xfast);CHKERRQ(ierr);
  ierr = VecISCopy(X,mprk->is_fast,SCATTER_FORWARD,Xfast);CHKERRQ(ierr);
  ierr = VecRestoreSubVector(Xtmp,mprk->is_fast,&Xfast);CHKERRQ(ierr);

  /* TSEvaluateStep() forwards the caller's done pointer here; never leave it unset */
  if (done) *done = PETSC_TRUE;
  PetscFunctionReturn(0);
}

/*
  One step of the nonsplit-RHS MPRK method. For every stage the slow entries of Y[stage] are
  built from the slow tableau and the slow stage derivatives, and the fast entries from the fast
  tableau and fast stage derivatives. The full RHS is then evaluated twice on Y[stage], once at
  the slow abscissa and once at the fast one. The step is completed by TSEvaluateStep().
*/
static PetscErrorCode TSStep_MPRK(TS ts)
{
  TS_MPRK         *mprk    = (TS_MPRK*)ts->data;
  MPRKTableau     tab      = mprk->tableau;
  const PetscInt  nstages  = tab->s;
  const PetscReal t0       = ts->ptime;
  const PetscReal dt       = ts->time_step;
  const PetscReal dt_next  = ts->time_step;
  Vec             *Ystage  = mprk->Y,work = mprk->Ytmp;
  Vec             *Fslow   = mprk->YdotRHS_slow,*Ffast = mprk->YdotRHS_fast;
  Vec             part;
  PetscScalar     *wslow   = mprk->work_slow,*wfast = mprk->work_fast;
  PetscInt        stage,k;
  PetscErrorCode  ierr;

  PetscFunctionBegin;
  for (stage=0; stage<nstages; stage++) {
    const PetscReal *rowslow = tab->As + stage*nstages;
    const PetscReal *rowfast = tab->Af + stage*nstages;

    mprk->stage_time = t0 + dt*tab->cf[stage];
    ierr = TSPreStage(ts,mprk->stage_time);CHKERRQ(ierr);

    /* Weights of the previously computed stages, one set per tableau */
    for (k=0; k<stage; k++) {
      wslow[k] = dt*rowslow[k];
      wfast[k] = dt*rowfast[k];
    }

    /* Slow entries of Y[stage]: x0 + sum_k dt*As[stage][k]*Fslow[k], restricted to is_slow */
    ierr = VecCopy(ts->vec_sol,work);CHKERRQ(ierr);
    ierr = VecMAXPY(work,stage,wslow,Fslow);CHKERRQ(ierr);
    ierr = VecGetSubVector(work,mprk->is_slow,&part);CHKERRQ(ierr);
    ierr = VecISCopy(Ystage[stage],mprk->is_slow,SCATTER_FORWARD,part);CHKERRQ(ierr);
    ierr = VecRestoreSubVector(work,mprk->is_slow,&part);CHKERRQ(ierr);

    /* Fast entries of Y[stage]: x0 + sum_k dt*Af[stage][k]*Ffast[k], restricted to is_fast */
    ierr = VecCopy(ts->vec_sol,work);CHKERRQ(ierr);
    ierr = VecMAXPY(work,stage,wfast,Ffast);CHKERRQ(ierr);
    ierr = VecGetSubVector(work,mprk->is_fast,&part);CHKERRQ(ierr);
    ierr = VecISCopy(Ystage[stage],mprk->is_fast,SCATTER_FORWARD,part);CHKERRQ(ierr);
    ierr = VecRestoreSubVector(work,mprk->is_fast,&part);CHKERRQ(ierr);

    ierr = TSPostStage(ts,mprk->stage_time,stage,Ystage);CHKERRQ(ierr);

    /* Full RHS on the stage vector, at the slow and at the fast abscissa */
    ierr = TSComputeRHSFunction(ts,t0+dt*tab->cs[stage],Ystage[stage],Fslow[stage]);CHKERRQ(ierr);
    ierr = TSComputeRHSFunction(ts,t0+dt*tab->cf[stage],Ystage[stage],Ffast[stage]);CHKERRQ(ierr);
  }
  ierr = TSEvaluateStep(ts,tab->order,ts->vec_sol,NULL);CHKERRQ(ierr);
  ts->ptime    += ts->time_step;
  ts->time_step = dt_next;
  PetscFunctionReturn(0);
}

/*
  Step completion for the split-RHS (partitioned) MPRK method. Here YdotRHS_slow[] and
  YdotRHS_fast[] have the layouts of the slow and fast subvectors, so the updates are applied
  directly to the corresponding subvectors of X:

    X[slow] = x0[slow] + h bs^T YdotRHS_slow
    X[fast] = x0[fast] + h bf^T YdotRHS_fast

  where x0 = ts->vec_sol. The order argument is ignored (only the tableau's own order is
  available); *done, if requested, is set to PETSC_TRUE.
*/
static PetscErrorCode TSEvaluateStep_MPRKSPLIT(TS ts,PetscInt order,Vec X,PetscBool *done)
{
  TS_MPRK        *mprk = (TS_MPRK*)ts->data;
  MPRKTableau    tab  = mprk->tableau;
  Vec            Xslow,Xfast; /* subvectors for slow and fast components in X respectively */
  PetscScalar    *wf = mprk->work_fast,*ws = mprk->work_slow;
  PetscReal      h = ts->time_step;
  PetscInt       s = tab->s,j;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = VecCopy(ts->vec_sol,X);CHKERRQ(ierr);
  for (j=0; j<s; j++) wf[j] = h*tab->bf[j];
  for (j=0; j<s; j++) ws[j] = h*tab->bs[j];
  ierr = VecGetSubVector(X,mprk->is_slow,&Xslow);CHKERRQ(ierr);
  ierr = VecGetSubVector(X,mprk->is_fast,&Xfast);CHKERRQ(ierr);
  ierr = VecMAXPY(Xslow,s,ws,mprk->YdotRHS_slow);CHKERRQ(ierr);
  ierr = VecMAXPY(Xfast,s,wf,mprk->YdotRHS_fast);CHKERRQ(ierr);
  /* Each subvector must be restored with the same index set it was obtained with */
  ierr = VecRestoreSubVector(X,mprk->is_slow,&Xslow);CHKERRQ(ierr);
  ierr = VecRestoreSubVector(X,mprk->is_fast,&Xfast);CHKERRQ(ierr);

  /* TSEvaluateStep() forwards the caller's done pointer here; never leave it unset */
  if (done) *done = PETSC_TRUE;
  PetscFunctionReturn(0);
}

/*
  One step of the split-RHS MPRK method. Each stage starts from the current solution, and its
  slow and fast subvectors are advanced with their own tableau and stage derivatives. Then the
  'slow' and 'fast' sub-TS right-hand sides are evaluated on the full stage vector.
  The step is completed by TSEvaluateStep().
*/
static PetscErrorCode TSStep_MPRKSPLIT(TS ts)
{
  TS_MPRK         *mprk    = (TS_MPRK*)ts->data;
  MPRKTableau     tab      = mprk->tableau;
  const PetscInt  nstages  = tab->s;
  const PetscReal t0       = ts->ptime;
  const PetscReal dt       = ts->time_step;
  const PetscReal dt_next  = ts->time_step;
  Vec             *Ystage  = mprk->Y;
  Vec             *Fslow   = mprk->YdotRHS_slow,*Ffast = mprk->YdotRHS_fast;
  Vec             Yslow,Yfast;
  PetscScalar     *wslow   = mprk->work_slow,*wfast = mprk->work_fast;
  PetscInt        stage,k;
  PetscErrorCode  ierr;

  PetscFunctionBegin;
  for (stage=0; stage<nstages; stage++) {
    mprk->stage_time = t0 + dt*tab->cf[stage];
    ierr = TSPreStage(ts,mprk->stage_time);CHKERRQ(ierr);

    /* Start from x0, then add the weighted earlier stage derivatives partition by partition */
    ierr = VecCopy(ts->vec_sol,Ystage[stage]);CHKERRQ(ierr);
    for (k=0; k<stage; k++) {
      wfast[k] = dt*tab->Af[stage*nstages+k];
      wslow[k] = dt*tab->As[stage*nstages+k];
    }
    ierr = VecGetSubVector(Ystage[stage],mprk->is_slow,&Yslow);CHKERRQ(ierr);
    ierr = VecGetSubVector(Ystage[stage],mprk->is_fast,&Yfast);CHKERRQ(ierr);
    ierr = VecMAXPY(Yslow,stage,wslow,Fslow);CHKERRQ(ierr);
    ierr = VecMAXPY(Yfast,stage,wfast,Ffast);CHKERRQ(ierr);
    ierr = VecRestoreSubVector(Ystage[stage],mprk->is_slow,&Yslow);CHKERRQ(ierr);
    ierr = VecRestoreSubVector(Ystage[stage],mprk->is_fast,&Yfast);CHKERRQ(ierr);
    ierr = TSPostStage(ts,mprk->stage_time,stage,Ystage);CHKERRQ(ierr);

    /* Partitioned right-hand sides, each at its own abscissa */
    ierr = TSComputeRHSFunction(mprk->subts_slow,t0+dt*tab->cs[stage],Ystage[stage],Fslow[stage]);CHKERRQ(ierr);
    ierr = TSComputeRHSFunction(mprk->subts_fast,t0+dt*tab->cf[stage],Ystage[stage],Ffast[stage]);CHKERRQ(ierr);
  }
  ierr = TSEvaluateStep(ts,tab->order,ts->vec_sol,NULL);CHKERRQ(ierr);
  ts->ptime    += ts->time_step;
  ts->time_step = dt_next;
  PetscFunctionReturn(0);
}

/*
  Release the work arrays and vectors allocated by TSMPRKTableauSetUp(). Does nothing if no
  tableau has been selected yet.

  Ytmp is destroyed regardless of the current multirate type (VecDestroy() of a NULL handle is a
  no-op). This keeps a vector allocated under TSMPRKNONSPLIT from leaking if the type was switched
  to TSMPRKSPLIT afterwards.
*/
static PetscErrorCode TSMPRKTableauReset(TS ts)
{
  TS_MPRK        *mprk = (TS_MPRK*)ts->data;
  MPRKTableau    tab = mprk->tableau;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  if (!tab) PetscFunctionReturn(0);
  ierr = PetscFree(mprk->work_fast);CHKERRQ(ierr);
  ierr = PetscFree(mprk->work_slow);CHKERRQ(ierr);
  ierr = VecDestroyVecs(tab->s,&mprk->Y);CHKERRQ(ierr);
  ierr = VecDestroy(&mprk->Ytmp);CHKERRQ(ierr);
  ierr = VecDestroyVecs(tab->s,&mprk->YdotRHS_fast);CHKERRQ(ierr);
  ierr = VecDestroyVecs(tab->s,&mprk->YdotRHS_slow);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* TSReset implementation: releases everything allocated by TSMPRKTableauSetUp() */
static PetscErrorCode TSReset_MPRK(TS ts)
{
  PetscErrorCode err;

  PetscFunctionBegin;
  err = TSMPRKTableauReset(ts);CHKERRQ(err);
  PetscFunctionReturn(0);
}

/*
  DM hooks registered in TSSetUp_MPRK() and removed in TSDestroy_MPRK().
  All four are intentionally empty: nothing is transferred between DMs.
*/
static PetscErrorCode DMCoarsenHook_TSMPRK(DM dmfine,DM dmcoarse,void *hookctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(0);
}

static PetscErrorCode DMRestrictHook_TSMPRK(DM dmfine,Mat Restrict,Vec Rscale,Mat Inject,DM dmcoarse,void *hookctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(0);
}

static PetscErrorCode DMSubDomainHook_TSMPRK(DM dmglobal,DM dmsub,void *hookctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(0);
}

static PetscErrorCode DMSubDomainRestrictHook_TSMPRK(DM dmglobal,VecScatter gscatter,VecScatter lscatter,DM dmsub,void *hookctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(0);
}

/*
  Allocate the per-step storage for the current tableau. Both variants need s weight entries per
  tableau and s full-length stage vectors. The nonsplit variant uses full-length stage derivatives
  plus the work vector Ytmp; the split variant sizes the slow/fast stage derivatives like the
  corresponding subvectors of the solution.
*/
static PetscErrorCode TSMPRKTableauSetUp(TS ts)
{
  TS_MPRK        *mprk   = (TS_MPRK*)ts->data;
  const PetscInt nstages = mprk->tableau->s;
  Vec            xslow,xfast;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscMalloc1(nstages,&mprk->work_fast);CHKERRQ(ierr);
  ierr = PetscMalloc1(nstages,&mprk->work_slow);CHKERRQ(ierr);
  ierr = VecDuplicateVecs(ts->vec_sol,nstages,&mprk->Y);CHKERRQ(ierr);
  switch (mprk->mprkmtype) {
  case TSMPRKNONSPLIT:
    ierr = VecDuplicateVecs(ts->vec_sol,nstages,&mprk->YdotRHS_slow);CHKERRQ(ierr);
    ierr = VecDuplicateVecs(ts->vec_sol,nstages,&mprk->YdotRHS_fast);CHKERRQ(ierr);
    ierr = VecDuplicate(ts->vec_sol,&mprk->Ytmp);CHKERRQ(ierr);
    break;
  case TSMPRKSPLIT:
    ierr = VecGetSubVector(ts->vec_sol,mprk->is_slow,&xslow);CHKERRQ(ierr);
    ierr = VecGetSubVector(ts->vec_sol,mprk->is_fast,&xfast);CHKERRQ(ierr);
    ierr = VecDuplicateVecs(xslow,nstages,&mprk->YdotRHS_slow);CHKERRQ(ierr);
    ierr = VecDuplicateVecs(xfast,nstages,&mprk->YdotRHS_fast);CHKERRQ(ierr);
    ierr = VecRestoreSubVector(ts->vec_sol,mprk->is_slow,&xslow);CHKERRQ(ierr);
    ierr = VecRestoreSubVector(ts->vec_sol,mprk->is_fast,&xfast);CHKERRQ(ierr);
    break;
  default:
    break;
  }
  PetscFunctionReturn(0);
}

/*
  TSSetUp implementation. Looks up the 'slow' and 'fast' index sets (required), allocates the
  tableau storage, attaches the DM hooks and, for the split variant, prepares the sub-TS.
*/
static PetscErrorCode TSSetUp_MPRK(TS ts)
{
  TS_MPRK        *mprk = (TS_MPRK*)ts->data;
  DM             dm;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  /* The index sets must be known before the storage is allocated: the split variant sizes from them */
  ierr = TSRHSSplitGetIS(ts,"slow",&mprk->is_slow);CHKERRQ(ierr);
  ierr = TSRHSSplitGetIS(ts,"fast",&mprk->is_fast);CHKERRQ(ierr);
  if (!(mprk->is_slow && mprk->is_fast)) SETERRQ(PetscObjectComm((PetscObject)ts),PETSC_ERR_USER,"Must set up RHSSplits with TSRHSSplitSetIS() using split names 'slow' and 'fast' respectively in order to use -ts_type mprk");
  ierr = TSCheckImplicitTerm(ts);CHKERRQ(ierr);
  ierr = TSMPRKTableauSetUp(ts);CHKERRQ(ierr);

  ierr = TSGetDM(ts,&dm);CHKERRQ(ierr);
  ierr = DMCoarsenHookAdd(dm,DMCoarsenHook_TSMPRK,DMRestrictHook_TSMPRK,ts);CHKERRQ(ierr);
  ierr = DMSubDomainHookAdd(dm,DMSubDomainHook_TSMPRK,DMSubDomainRestrictHook_TSMPRK,ts);CHKERRQ(ierr);

  /* "TSMPRKSetSplits_C" is composed only for TSMPRKSPLIT; otherwise this is a no-op */
  ierr = PetscTryMethod(ts,"TSMPRKSetSplits_C",(TS),(ts));CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* Names for TSMPRKMultirateType (nonsplit-RHS or split-RHS MPRK), laid out as a PETSc enum string list:
   the value names, then the enum type name, then the common prefix of the enum constants, then a NULL terminator */
const char *const TSMPRKMultirateTypes[] = {"NONSPLIT","SPLIT","TSMPRKMultirateType","TSMPRK",0};

/*
  Options for TSMPRK:
    -ts_mprk_type <name>                       - any registered scheme (default: the current one)
    -ts_mprk_multirate_type <nonsplit,split>   - multirate configuration (default: the current one)
*/
static PetscErrorCode TSSetFromOptions_MPRK(PetscOptionItems *PetscOptionsObject,TS ts)
{
  TS_MPRK        *mprk = (TS_MPRK*)ts->data;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscOptionsHead(PetscOptionsObject,"PRK ODE solver options");CHKERRQ(ierr);
  {
    MPRKTableauLink link;
    PetscInt        count,choice;
    PetscBool       flg;
    const char      **namelist;
    PetscInt        mprkmtype = (PetscInt)mprk->mprkmtype;

    /* Collect the names of all registered schemes as the choices for -ts_mprk_type */
    for (link=MPRKTableauList,count=0; link; link=link->next,count++) ;
    ierr = PetscMalloc1(count,(char***)&namelist);CHKERRQ(ierr);
    for (link=MPRKTableauList,count=0; link; link=link->next,count++) namelist[count] = link->tab.name;
    ierr = PetscOptionsEList("-ts_mprk_type","Family of MPRK method","TSMPRKSetType",(const char*const*)namelist,count,mprk->tableau->name,&choice,&flg);CHKERRQ(ierr);
    if (flg) {ierr = TSMPRKSetType(ts,namelist[choice]);CHKERRQ(ierr);}
    ierr = PetscFree(namelist);CHKERRQ(ierr);

    /* Only the first two entries of TSMPRKMultirateTypes are choices; advertise the current type as the default */
    ierr = PetscOptionsEList("-ts_mprk_multirate_type","Use Combined RHS Multirate or Partitioned RHS Multirate MPRK method","TSMPRKSetMultirateType",TSMPRKMultirateTypes,2,TSMPRKMultirateTypes[mprkmtype],&mprkmtype,&flg);CHKERRQ(ierr);
    if (flg) {ierr = TSMPRKSetMultirateType(ts,(TSMPRKMultirateType)mprkmtype);CHKERRQ(ierr);}
  }
  ierr = PetscOptionsTail();CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* ASCII viewer: prints the scheme name, its order and both abscissa vectors; other viewers print nothing */
static PetscErrorCode TSView_MPRK(TS ts,PetscViewer viewer)
{
  TS_MPRK        *mprk = (TS_MPRK*)ts->data;
  MPRKTableau    tab;
  TSMPRKType     mprktype;
  PetscBool      iascii;
  char           buf[512];
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii);CHKERRQ(ierr);
  if (!iascii) PetscFunctionReturn(0);

  tab  = mprk->tableau;
  ierr = TSMPRKGetType(ts,&mprktype);CHKERRQ(ierr);
  ierr = PetscViewerASCIIPrintf(viewer,"  MPRK type %s\n",mprktype);CHKERRQ(ierr);
  ierr = PetscViewerASCIIPrintf(viewer,"  Order: %D\n",tab->order);CHKERRQ(ierr);
  /* one buffer suffices: each formatted array is printed before the next is formatted */
  ierr = PetscFormatRealArray(buf,sizeof(buf),"% 8.6f",tab->s,tab->cf);CHKERRQ(ierr);
  ierr = PetscViewerASCIIPrintf(viewer,"  Abscissa cf = %s\n",buf);CHKERRQ(ierr);
  ierr = PetscFormatRealArray(buf,sizeof(buf),"% 8.6f",tab->s,tab->cs);CHKERRQ(ierr);
  ierr = PetscViewerASCIIPrintf(viewer,"  Abscissa cs = %s\n",buf);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* TSLoad implementation: the only state loaded here is that of the TS's adaptor */
static PetscErrorCode TSLoad_MPRK(TS ts,PetscViewer viewer)
{
  TSAdapt        tsadapt;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = TSGetAdapt(ts,&tsadapt);CHKERRQ(ierr);
  ierr = TSAdaptLoad(tsadapt,viewer);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*@C
  TSMPRKSetType - Select which registered MPRK scheme the TS uses

  Logically collective

  Input Parameters:
+  ts - timestepping context
-  mprktype - name of a registered MPRK scheme

  Options Database:
.   -ts_mprk_type - <pm2,p2,p3>

  Notes:
  Has no effect when ts is not of type TSMPRK.

  Level: intermediate

.seealso: TSMPRKGetType(), TSMPRK, TSMPRKType
@*/
PetscErrorCode TSMPRKSetType(TS ts,TSMPRKType mprktype)
{
  PetscErrorCode err;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts,TS_CLASSID,1);
  PetscValidCharPointer(mprktype,2);
  err = PetscTryMethod(ts,"TSMPRKSetType_C",(TS,TSMPRKType),(ts,mprktype));CHKERRQ(err);
  PetscFunctionReturn(0);
}

/*@C
  TSMPRKGetType - Get the type of MPRK scheme

  Logically collective

  Input Parameter:
.  ts - timestepping context

  Output Parameter:
.  mprktype - type of MPRK-scheme

  Notes:
  Raises an error if ts is not of type TSMPRK.

  Level: intermediate

.seealso: TSMPRKSetType()
@*/
PetscErrorCode TSMPRKGetType(TS ts,TSMPRKType *mprktype)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts,TS_CLASSID,1);
  PetscValidPointer(mprktype,2);
  ierr = PetscUseMethod(ts,"TSMPRKGetType_C",(TS,TSMPRKType*),(ts,mprktype));CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*@C
  TSMPRKSetMultirateType - Set the type of MPRK multirate scheme

  Logically collective

  Input Parameters:
+  ts - timestepping context
-  mprkmtype - type of the multirate configuration, TSMPRKNONSPLIT or TSMPRKSPLIT

  Options Database:
.   -ts_mprk_multirate_type - <nonsplit,split>

  Notes:
  TSMPRKNONSPLIT evaluates the full right-hand side of ts for both partitions; TSMPRKSPLIT
  evaluates the right-hand sides of the 'slow' and 'fast' sub-TS separately.
  Must only be called on a TS of type TSMPRK, since it accesses the TSMPRK data directly.

  Level: intermediate

.seealso: TSMPRK, TSMPRKSetType()
@*/
PetscErrorCode TSMPRKSetMultirateType(TS ts, TSMPRKMultirateType mprkmtype)
{
  TS_MPRK        *mprk;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts,TS_CLASSID,1);
  PetscValidLogicalCollectiveEnum(ts,mprkmtype,2);
  mprk = (TS_MPRK*)ts->data;
  switch (mprkmtype) {
  case TSMPRKNONSPLIT:
    ts->ops->step         = TSStep_MPRK;
    ts->ops->evaluatestep = TSEvaluateStep_MPRK;
    /* no sub-TS in this variant: drop the split setup a previous TSMPRKSPLIT call may have composed */
    ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKSetSplits_C",NULL);CHKERRQ(ierr);
    break;
  case TSMPRKSPLIT:
    ts->ops->step         = TSStep_MPRKSPLIT;
    ts->ops->evaluatestep = TSEvaluateStep_MPRKSPLIT;
    ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKSetSplits_C",TSMPRKSetSplits);CHKERRQ(ierr);
    break;
  default:
    /* mprkmtype is an integer enum, so it must be formatted with %D, not %s */
    SETERRQ1(PetscObjectComm((PetscObject)ts),PETSC_ERR_ARG_UNKNOWN_TYPE,"Unknown type '%D'",(PetscInt)mprkmtype);
  }
  mprk->mprkmtype = mprkmtype;
  PetscFunctionReturn(0);
}

/* Implementation of TSMPRKGetType(): report the name of the active tableau */
static PetscErrorCode TSMPRKGetType_MPRK(TS ts,TSMPRKType *mprktype)
{
  TS_MPRK *impl = (TS_MPRK*)ts->data;

  PetscFunctionBegin;
  *mprktype = impl->tableau->name;
  PetscFunctionReturn(0);
}

/*
  Implementation of TSMPRKSetType(). Selects the registered tableau named mprktype. If the TS
  is already set up, the storage of the old tableau is released and reallocated for the new one.
  Raises PETSC_ERR_ARG_UNKNOWN_TYPE if no scheme of that name is registered.
*/
static PetscErrorCode TSMPRKSetType_MPRK(TS ts,TSMPRKType mprktype)
{
  TS_MPRK         *mprk = (TS_MPRK*)ts->data;
  MPRKTableauLink link;
  PetscBool       same;
  PetscErrorCode  ierr;

  PetscFunctionBegin;
  /* Nothing to do if this scheme is already active */
  if (mprk->tableau) {
    ierr = PetscStrcmp(mprk->tableau->name,mprktype,&same);CHKERRQ(ierr);
    if (same) PetscFunctionReturn(0);
  }

  /* Search the registry; the first match wins */
  for (link = MPRKTableauList; link; link = link->next) {
    ierr = PetscStrcmp(link->tab.name,mprktype,&same);CHKERRQ(ierr);
    if (same) break;
  }
  if (!link) SETERRQ1(PetscObjectComm((PetscObject)ts),PETSC_ERR_ARG_UNKNOWN_TYPE,"Could not find '%s'",mprktype);

  if (ts->setupcalled) {ierr = TSMPRKTableauReset(ts);CHKERRQ(ierr);}
  mprk->tableau = &link->tab;
  if (ts->setupcalled) {ierr = TSMPRKTableauSetUp(ts);CHKERRQ(ierr);}
  PetscFunctionReturn(0);
}

/* TSGetStages implementation: number of stages of the active tableau and, optionally, the stage vectors */
static PetscErrorCode TSGetStages_MPRK(TS ts,PetscInt *ns,Vec **Y)
{
  TS_MPRK *impl = (TS_MPRK*)ts->data;

  PetscFunctionBegin;
  *ns = impl->tableau->s;
  if (Y != NULL) *Y = impl->Y;
  PetscFunctionReturn(0);
}

/*
  TSDestroy implementation: frees the per-step storage, detaches the DM hooks, frees the
  TS_MPRK context and removes every method this implementation composed on the TS.
  This includes "TSMPRKSetSplits_C", which TSMPRKSetMultirateType() composes for TSMPRKSPLIT;
  leaving it behind would make a later TSSetUp() of the same TS try to set up sub-TS splits.
*/
static PetscErrorCode TSDestroy_MPRK(TS ts)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = TSReset_MPRK(ts);CHKERRQ(ierr);
  if (ts->dm) {
    ierr = DMCoarsenHookRemove(ts->dm,DMCoarsenHook_TSMPRK,DMRestrictHook_TSMPRK,ts);CHKERRQ(ierr);
    ierr = DMSubDomainHookRemove(ts->dm,DMSubDomainHook_TSMPRK,DMSubDomainRestrictHook_TSMPRK,ts);CHKERRQ(ierr);
  }
  ierr = PetscFree(ts->data);CHKERRQ(ierr);
  ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKGetType_C",NULL);CHKERRQ(ierr);
  ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKSetType_C",NULL);CHKERRQ(ierr);
  ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKSetMultirateType_C",NULL);CHKERRQ(ierr);
  ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKSetSplits_C",NULL);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*MC
      TSMPRK - ODE solver using Multirate Partitioned Runge-Kutta schemes

  The user should provide the right hand side of the equation
  using TSSetRHSFunction(), and must mark the slow and fast components with TSRHSSplitSetIS()
  using the split names "slow" and "fast".

  Notes:
  The default is TSMPRKPM2, it can be changed with TSMPRKSetType() or -ts_mprk_type.
  The multirate configuration (nonsplit or split RHS) is chosen with TSMPRKSetMultirateType() or
  -ts_mprk_multirate_type; the split configuration additionally needs a right-hand side for each
  sub-TS, e.g. set with TSRHSSplitSetRHSFunction().

  Level: beginner

.seealso:  TSCreate(), TS, TSSetType(), TSMPRKSetType(), TSMPRKGetType(), TSMPRKType, TSMPRKRegister(), TSMPRKSetMultirateType(),
           TSMPRKPM2, TSMPRKP2, TSMPRKP3

M*/
PETSC_EXTERN PetscErrorCode TSCreate_MPRK(TS ts)
{
  TS_MPRK        *mprk;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = TSMPRKInitializePackage();CHKERRQ(ierr);

  /* The nonsplit step/evaluate routines match the zero-initialized mprkmtype (TSMPRKNONSPLIT) */
  ts->ops->reset          = TSReset_MPRK;
  ts->ops->destroy        = TSDestroy_MPRK;
  ts->ops->view           = TSView_MPRK;
  ts->ops->load           = TSLoad_MPRK;
  ts->ops->setup          = TSSetUp_MPRK;
  ts->ops->step           = TSStep_MPRK;
  ts->ops->evaluatestep   = TSEvaluateStep_MPRK;
  ts->ops->setfromoptions = TSSetFromOptions_MPRK;
  ts->ops->getstages      = TSGetStages_MPRK;

  ierr = PetscNewLog(ts,&mprk);CHKERRQ(ierr);
  ts->data = (void*)mprk;

  ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKGetType_C",TSMPRKGetType_MPRK);CHKERRQ(ierr);
  ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKSetType_C",TSMPRKSetType_MPRK);CHKERRQ(ierr);
  ierr = PetscObjectComposeFunction((PetscObject)ts,"TSMPRKSetMultirateType_C",TSMPRKSetMultirateType);CHKERRQ(ierr);

  ierr = TSMPRKSetType(ts,TSMPRKDefault);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
