#include <../src/snes/impls/fas/fasimpls.h> /*I  "petscsnes.h"  I*/

/*@
  SNESFASGetGalerkin - Gets whether the coarse problems are formed by projection to the fine problem

  Not Collective but the result would be the same on all MPI processes

  Input Parameter:
. snes - the `SNESFAS` nonlinear solver context

  Output Parameter:
. flg - `PETSC_TRUE` if the coarse problem is formed by projection

  Level: advanced

.seealso: [](ch_snes), `SNES`, `SNESFAS`, `SNESFASSetLevels()`, `SNESFASSetGalerkin()`
@*/
PetscErrorCode SNESFASGetGalerkin(SNES snes, PetscBool *flg)
{
  SNES_FAS *fas;

  PetscFunctionBegin;
  PetscValidHeaderSpecificType(snes, SNES_CLASSID, 1, SNESFAS);
  /* validate the output pointer as well, so that (where argument checking is enabled)
     a NULL flg is reported as a PETSc error instead of being dereferenced below */
  PetscAssertPointer(flg, 2);
  fas  = (SNES_FAS *)snes->data;
  *flg = fas->galerkin;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  SNESFASSetGalerkin - Sets coarse problems as formed by projection to the fine problem

  Logically Collective

  Input Parameters:
+ snes - the `SNESFAS` nonlinear solver context
- flg  - `PETSC_TRUE` to use the projection process

  Level: advanced

.seealso: [](ch_snes), `SNES`, `SNESFAS`, `SNESFASSetLevels()`, `SNESFASGetGalerkin()`
@*/
PetscErrorCode SNESFASSetGalerkin(SNES snes, PetscBool flg)
{
  SNES_FAS *fas;

  PetscFunctionBegin;
  PetscValidHeaderSpecificType(snes, SNES_CLASSID, 1, SNESFAS);
  /* the routine is documented as logically collective, so flg must agree on all processes of snes */
  PetscValidLogicalCollectiveBool(snes, flg, 2);
  fas           = (SNES_FAS *)snes->data;
  fas->galerkin = flg;
  /* apply the same setting to every level reachable through fas->next */
  if (fas->next) PetscCall(SNESFASSetGalerkin(fas->next, flg));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@C
  SNESFASGalerkinFunctionDefault - Computes the Galerkin FAS function

  Collective

  Input Parameters:
+ snes - the `SNESFAS` nonlinear solver context
. X    - input vector
- ctx  - the application context

  Output Parameter:
. F - output vector

  Level: developer

  Note:
  The Galerkin FAS function evaluation is defined as

  $$
  F^l(x^l) = I^l_0 F^0(P^0_l x^l)
  $$

.seealso: [](ch_snes), `SNES`, `SNESFAS`, `SNESFASGetGalerkin()`, `SNESFASSetGalerkin()`
@*/
PetscErrorCode SNESFASGalerkinFunctionDefault(SNES snes, Vec X, Vec F, PetscCtx ctx)
{
  SNES           fassnes;
  SNES_FAS      *fas;
  SNES_FAS      *prevfas;
  SNES           prevsnes;
  Vec            b_temp;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  /* prolong to the fine level and evaluate there; note that the level is taken
     from ctx (which must be a SNESFAS solver), not from the snes argument */
  fassnes  = (SNES)ctx;
  fas      = (SNES_FAS *)fassnes->data;
  prevsnes = fas->previous;
  prevfas  = (SNES_FAS *)prevsnes->data;
  /* interpolate X into the work vector Xg of the previous level */
  PetscCall(MatInterpolate(prevfas->interpolate, X, prevfas->Xg));
  /* evaluate the previous level's function with its RHS temporarily detached
     (presumably so the RHS does not enter the Galerkin function -- the RHS we care
     about is at the coarsest level). The RHS is put back BEFORE the error code is
     checked so a failed evaluation cannot leave prevsnes->vec_rhs set to NULL. */
  b_temp            = prevsnes->vec_rhs;
  prevsnes->vec_rhs = NULL;
  ierr              = SNESComputeFunction(prevsnes, prevfas->Xg, prevfas->Fg);
  prevsnes->vec_rhs = b_temp;
  PetscCall(ierr);
  /* restrict the function computed on the previous level into F */
  PetscCall(MatRestrict(prevfas->restrct, prevfas->Fg, F));
  PetscFunctionReturn(PETSC_SUCCESS);
}
