#include <petsc/private/tsimpl.h>        /*I "petscts.h"  I*/

/*@C
   TSRHSSplitSetIS - Set the index set for the specified split

   Logically Collective on TS

   Input Parameters:
+  ts        - the TS context obtained from TSCreate()
.  splitname - name of this split, if NULL the number of the split is used
-  is        - the index set for part of the solution vector

   Notes:
   Each call creates a new split holding its own reference to is and a new sub-TS.
   The options prefix of the sub-TS is the prefix of ts followed by "rhsplit_<splitname>_".
   At most MAXRHSSPLITS splits can be created; exceeding this is an error.

   Level: intermediate

.seealso: TSRHSSplitGetIS(), TSRHSSplitSetRHSFunction(), TSRHSSplitGetSubTS()

.keywords: TS, TSRHSSplit
@*/
PetscErrorCode TSRHSSplitSetIS(TS ts,const char splitname[],IS is)
{
  TS_RHSSplit    newsplit;
  char           prefix[128];
  PetscErrorCode ierr;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts,TS_CLASSID,1);
  PetscValidHeaderSpecific(is,IS_CLASSID,3);
  if (ts->num_rhs_splits == MAXRHSSPLITS) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_USER,"MAXIMUM number of splits reached");
  ierr = PetscNew(&newsplit);CHKERRQ(ierr);
  if (splitname) {
    ierr = PetscStrallocpy(splitname,&newsplit->splitname);CHKERRQ(ierr);
  } else {
    /* no name given: name the split after its index; pass the full buffer size (8) */
    ierr = PetscMalloc1(8,&newsplit->splitname);CHKERRQ(ierr);
    ierr = PetscSNPrintf(newsplit->splitname,8,"%D",ts->num_rhs_splits);CHKERRQ(ierr);
  }
  /* the split keeps its own reference to the index set */
  ierr = PetscObjectReference((PetscObject)is);CHKERRQ(ierr);
  newsplit->is = is;
  ierr = TSCreate(PetscObjectComm((PetscObject)ts),&newsplit->ts);CHKERRQ(ierr);
  ierr = PetscObjectIncrementTabLevel((PetscObject)newsplit->ts,(PetscObject)ts,1);CHKERRQ(ierr);
  ierr = PetscLogObjectParent((PetscObject)ts,(PetscObject)newsplit->ts);CHKERRQ(ierr);
  /* sub-TS options are addressed as -<tsprefix>rhsplit_<splitname>_<option> */
  ierr = PetscSNPrintf(prefix,sizeof(prefix),"%srhsplit_%s_",((PetscObject)ts)->prefix ? ((PetscObject)ts)->prefix : "",newsplit->splitname);CHKERRQ(ierr);
  ierr = TSSetOptionsPrefix(newsplit->ts,prefix);CHKERRQ(ierr);
  ts->tsrhssplit[ts->num_rhs_splits++] = newsplit;
  PetscFunctionReturn(0);
}

/*@C
   TSRHSSplitGetIS - Retrieves the elements for a split as an IS

   Logically Collective on TS

   Input Parameters:
+  ts        - the TS context obtained from TSCreate()
-  splitname - name of this split

   Output Parameters:
.  is        - the index set for part of the solution vector, or NULL if no split named splitname exists

   Notes:
   The returned IS is the one stored in the split; no additional reference is taken, so the caller must not destroy it.

   Level: intermediate

.seealso: TSRHSSplitSetIS()

.keywords: TS, TSRHSSplit
@*/
PetscErrorCode TSRHSSplitGetIS(TS ts,const char splitname[],IS *is)
{
  PetscInt       i;
  PetscBool      found = PETSC_FALSE;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts,TS_CLASSID,1);
  PetscValidPointer(is,3);
  /* give the caller a defined result even when the split does not exist */
  *is = NULL;
  /* look up the split by name */
  for (i=0; i<ts->num_rhs_splits; i++) {
    ierr = PetscStrcmp(ts->tsrhssplit[i]->splitname,splitname,&found);CHKERRQ(ierr);
    if (found) {
      *is = ts->tsrhssplit[i]->is;
      break;
    }
  }
  PetscFunctionReturn(0);
}

/*@C
   TSRHSSplitSetRHSFunction - Set the split right-hand-side functions.

   Logically Collective on TS

   Input Parameters:
+  ts        - the TS context obtained from TSCreate()
.  splitname - name of this split
.  r         - vector to hold the residual (or NULL to have it created internally)
.  rhsfunc   - the RHS function evaluation routine
-  ctx       - user-defined context for private data for the split function evaluation routine (may be NULL)

 Calling sequence of rhsfunc:
$  rhsfunc(TS ts,PetscReal t,Vec u,Vec f,ctx);

+  ts   - the sub-TS of the split
.  t    - time at step/stage being solved
.  u    - state vector
.  f    - function vector
-  ctx  - [optional] user-defined context for the function evaluation routine (may be NULL)

 Notes:
 The split must already have been created with TSRHSSplitSetIS(); otherwise an error is raised.
 When r is NULL, ts already has a solution vector, and no DM was attached to the split's sub-TS,
 a residual vector with the layout of the split's part of the solution is created internally.

 Level: beginner

.keywords: TS, TSRHSSplit, timestep, set, ODE, Function

.seealso: TSRHSSplitSetIS(), TSRHSSplitGetSubTS()
@*/
PetscErrorCode TSRHSSplitSetRHSFunction(TS ts,const char splitname[],Vec r,TSRHSFunction rhsfunc,void *ctx)
{
  DM             dm;
  SNES           snes;
  Vec            subvec,ralloc = NULL;
  PetscBool      found = PETSC_FALSE,hadDM;
  PetscInt       i;
  TS             subts;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts,TS_CLASSID,1);
  if (r) PetscValidHeaderSpecific(r,VEC_CLASSID,3);

  /* look up the split */
  for (i=0; i<ts->num_rhs_splits; i++) {
    ierr = PetscStrcmp(ts->tsrhssplit[i]->splitname,splitname,&found);CHKERRQ(ierr);
    if (found) break;
  }
  if (!found) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_USER,"The split %s is not created, check the split name or call TSRHSSplitSetIS() to create one",splitname);

  subts = ts->tsrhssplit[i]->ts;
  /* TSGetDM() attaches a DM to subts when it has none, so record beforehand whether the user provided one;
     testing subts->dm after TSGetDM() would always see a DM and never create the residual vector */
  hadDM = subts->dm ? PETSC_TRUE : PETSC_FALSE;

  ierr = TSGetDM(subts,&dm);CHKERRQ(ierr);
  ierr = DMTSSetRHSFunction(dm,rhsfunc,ctx);CHKERRQ(ierr);

  ierr = TSGetSNES(subts,&snes);CHKERRQ(ierr);
  if (!r && !hadDM && ts->vec_sol) {
    /* build a residual vector with the layout of this split's part of the solution */
    ierr = VecGetSubVector(ts->vec_sol,ts->tsrhssplit[i]->is,&subvec);CHKERRQ(ierr);
    ierr = VecDuplicate(subvec,&ralloc);CHKERRQ(ierr);
    r    = ralloc;
    ierr = VecRestoreSubVector(ts->vec_sol,ts->tsrhssplit[i]->is,&subvec);CHKERRQ(ierr);
  }
  ierr = SNESSetFunction(snes,r,SNESTSFormFunction,subts);CHKERRQ(ierr);
  /* drop the local reference; ralloc is NULL (no-op) when nothing was allocated */
  ierr = VecDestroy(&ralloc);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*@C
   TSRHSSplitGetSubTS - Get the array of sub-TS contexts, one per RHS split.

   Not Collective

   Input Parameter:
.  ts - the TS context obtained from TSCreate()

   Output Parameters:
+  n     - the number of splits (may be NULL if not needed)
-  subts - the array of TS contexts, in the order the splits were created (may be NULL if not needed)

   Note:
   After TSRHSSplitGetSubTS() the array of TSs is to be freed by the user with PetscFree()
   (not the TS just the array that contains them).

   Level: advanced

.seealso: TSRHSSplitSetIS(), TSRHSSplitSetRHSFunction()
@*/
PetscErrorCode TSRHSSplitGetSubTS(TS ts,PetscInt *n,TS **subts)
{
  PetscInt       i;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(ts,TS_CLASSID,1);
  if (subts) {
    /* the array is newly allocated and owned by the caller; the TS objects it points to are not */
    ierr = PetscMalloc1(ts->num_rhs_splits,subts);CHKERRQ(ierr);
    for (i=0; i<ts->num_rhs_splits; i++) (*subts)[i] = ts->tsrhssplit[i]->ts;
  }
  if (n) *n = ts->num_rhs_splits;
  PetscFunctionReturn(0);
}


