#include <petsc/private/snesimpl.h> /*I  "petscsnes.h" I*/
#include <petscdm.h>                /*I  "petscdm.h"   I*/
#include <../src/mat/impls/mffd/mffdimpl.h>
#include <petsc/private/matimpl.h>

/*@
  MatMFFDComputeJacobian - Tells the matrix-free Jacobian object the new location at which
  Jacobian matrix-vector products will be computed, i.e. J(x) * a. The x is obtained
  from the `SNES` object (using `SNESGetSolution()`).

  Collective

  Input Parameters:
+ snes  - the nonlinear solver context (not referenced directly by this routine)
. x     - the point at which the Jacobian-vector products will be performed (not read by this routine, see Notes)
. jac   - the matrix-free Jacobian object of `MatType` `MATMFFD`, likely obtained with `MatCreateSNESMF()`
. B     - either the same as `jac` or another matrix type (ignored)
- dummy - the user context (ignored)

  Options Database Key:
. -snes_mf - use the matrix created with `MatCreateSNESMF()` to setup the Jacobian for each new solution in the Newton process

  Level: developer

  Notes:
  If `MatMFFDSetBase()` is ever called on `jac` then this routine will NO longer get
  the `x` from the `SNES` object and `MatMFFDSetBase()` must from that point on be used to
  change the base vector `x`.

  This can be passed into `SNESSetJacobian()` as the Jacobian evaluation function argument
  when using a completely matrix-free solver,
  that is the B matrix is also the same matrix operator. This is used when you select
  -snes_mf but rarely used directly by users. (All this routine does is call `MatAssemblyBegin/End()` on
  the `Mat` `jac`.)

  Only `jac` is used by this routine; the remaining arguments exist so that the calling sequence
  matches the Jacobian evaluation function expected by `SNESSetJacobian()`.

.seealso: [](ch_snes), `MatMFFDGetH()`, `MatCreateSNESMF()`, `MatMFFDSetBase()`, `MatCreateMFFD()`, `MATMFFD`,
          `MatMFFDSetHHistory()`, `MatMFFDSetFunctionError()`, `SNESSetJacobian()`
@*/
PetscErrorCode MatMFFDComputeJacobian(SNES snes, Vec x, Mat jac, Mat B, void *dummy)
{
  PetscFunctionBegin;
  /* Assembling jac is the whole job: for a matrix from MatCreateSNESMF() the assembly-end routine
     installed there (MatAssemblyEnd_SNESMF() or its UseBase variant) is what refreshes the base vector */
  PetscCall(MatAssemblyBegin(jac, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(jac, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Prototypes of the plain MATMFFD routines that the SNES-aware wrappers below call or install;
   their definitions are not part of this section of the source */
PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL PetscErrorCode MatAssemblyEnd_MFFD(Mat, MatAssemblyType);
PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL PetscErrorCode MatMFFDSetBase_MFFD(Mat, Vec, Vec);

/*@
  MatSNESMFGetSNES - returns the `SNES` associated with a matrix created with `MatCreateSNESMF()`

  Not Collective

  Input Parameter:
. J - the matrix, which must have been created with `MatCreateSNESMF()`

  Output Parameter:
. snes - the `SNES` object

  Level: advanced

  Notes:
  The `SNES` is read from the context slot that `MatCreateSNESMF()` fills in; no check is made that
  `J` really came from `MatCreateSNESMF()`, so passing any other matrix yields a meaningless result.

  The reference count of the returned `SNES` is not increased by this routine, so the caller
  should not destroy it.

.seealso: [](ch_snes), `Mat`, `SNES`, `MatCreateSNESMF()`
@*/
PetscErrorCode MatSNESMFGetSNES(Mat J, SNES *snes)
{
  MatMFFD j;

  PetscFunctionBegin;
  /* validate the matrix up front, as the other public MatSNESMF routines in this file do */
  PetscValidHeaderSpecific(J, MAT_CLASSID, 1);
  PetscCall(MatShellGetContext(J, &j));
  /* MatCreateSNESMF() stored the SNES in the ctx field of the MATMFFD context */
  *snes = (SNES)j->ctx;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
   MatAssemblyEnd_SNESMF - Finishes the assembly with MatAssemblyEnd_MFFD() and afterwards
    refreshes the differencing base from the SNES kept in the MATMFFD context.

    The function vector held by the SNES is passed along as the base function value only when
    the differencing function is SNESComputeFunction() and the DMSNES has no computemffunction;
    in every other situation NULL is passed instead.
*/
static PetscErrorCode MatAssemblyEnd_SNESMF(Mat J, MatAssemblyType mt)
{
  MatMFFD   mfctx;
  SNES      owner;
  Vec       sol, fbase = NULL;
  DM        dm;
  DMSNES    sdm;
  PetscBool snes_f_usable;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(J, &mfctx));
  owner = (SNES)mfctx->ctx;
  PetscCall(MatAssemblyEnd_MFFD(J, mt));

  PetscCall(SNESGetSolution(owner, &sol));
  PetscCall(SNESGetDM(owner, &dm));
  PetscCall(DMGetDMSNES(dm, &sdm));

  /* f value known by SNES is not correct for other differencing function, hence fbase stays NULL then */
  snes_f_usable = ((mfctx->func == (PetscErrorCode (*)(void *, Vec, Vec))SNESComputeFunction) && !sdm->ops->computemffunction) ? PETSC_TRUE : PETSC_FALSE;
  if (snes_f_usable) PetscCall(SNESGetFunction(owner, &fbase, NULL, NULL));
  PetscCall(MatMFFDSetBase_MFFD(J, sol, fbase));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
   MatAssemblyEnd_SNESMF_UseBase - Finishes the assembly with MatAssemblyEnd_MFFD() and afterwards
    sets the base from the SNES kept in the MATMFFD context. In contrast to MatAssemblyEnd_SNESMF()
    the function vector of the SNES is always handed over as the base function value, so it is used
    for differencing even if the func is not SNESComputeFunction. See: MatSNESMFSetReuseBase()
*/
static PetscErrorCode MatAssemblyEnd_SNESMF_UseBase(Mat J, MatAssemblyType mt)
{
  MatMFFD mfctx;
  SNES    owner;
  Vec     sol, fsol;

  PetscFunctionBegin;
  PetscCall(MatAssemblyEnd_MFFD(J, mt));

  /* the SNES was placed in the MATMFFD context by MatCreateSNESMF() */
  PetscCall(MatShellGetContext(J, &mfctx));
  owner = (SNES)mfctx->ctx;

  /* current solution and the function vector of the SNES become the new base */
  PetscCall(SNESGetSolution(owner, &sol));
  PetscCall(SNESGetFunction(owner, &fsol, NULL, NULL));
  PetscCall(MatMFFDSetBase_MFFD(J, sol, fsol));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
    MatMFFDSetBase_SNESMF - The "MatMFFDSetBase_C" method composed on matrices created with MatCreateSNESMF().

    It sets the base with MatMFFDSetBase_MFFD() and then resets the MatAssemblyEnd() of the matrix to
  MatAssemblyEnd_MFFD() so that it NO longer uses the solution in the SNES object to update the base.
  See the warning in MatCreateSNESMF().
*/
static PetscErrorCode MatMFFDSetBase_SNESMF(Mat J, Vec U, Vec F)
{
  PetscFunctionBegin;
  PetscCall(MatMFFDSetBase_MFFD(J, U, F));
  /* drop the SNES-aware assembly-end routine: from now on the base only changes through explicit set-base calls */
  J->ops->assemblyend = MatAssemblyEnd_MFFD;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
    MatSNESMFSetReuseBase_SNESMF - The "MatSNESMFSetReuseBase_C" method: selects which assembly-end
  routine the matrix uses, MatAssemblyEnd_SNESMF_UseBase() when use is true and MatAssemblyEnd_SNESMF() otherwise.
*/
static PetscErrorCode MatSNESMFSetReuseBase_SNESMF(Mat J, PetscBool use)
{
  PetscFunctionBegin;
  J->ops->assemblyend = use ? MatAssemblyEnd_SNESMF_UseBase : MatAssemblyEnd_SNESMF;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  MatSNESMFSetReuseBase - Causes the base vector to be used for differencing even if the function provided to `SNESSetFunction()` is not the
  same as that provided to `MatMFFDSetFunction()`.

  Logically Collective

  Input Parameters:
+ J   - the `MATMFFD` matrix
- use - if true always reuse the base vector instead of recomputing f(u) even if the function in the `MATMFFD` is
          not `SNESComputeFunction()`

  Level: advanced

  Note:
  Care must be taken when using this routine to ensure that the function provided to `MatMFFDSetFunction()`, call it F_MF(), is compatible
  with that provided to `SNESSetFunction()`, call it F_SNES(). That is, (F_MF(u + h*d) - F_SNES(u))/h has to approximate the derivative.

  Developer Notes:
  This was provided for the MOOSE team who desired to have a `SNESSetFunction()` function that could change configurations (similar to variable
  switching) to contacts while the function provided to `MatMFFDSetFunction()` cannot. Except for the possibility of changing the configuration
  both functions compute the same mathematical function so the differencing makes sense.

.seealso: [](ch_snes), `SNES`, `MATMFFD`, `MatMFFDSetFunction()`, `SNESSetFunction()`, `MatCreateSNESMF()`, `MatSNESMFGetReuseBase()`
@*/
PetscErrorCode MatSNESMFSetReuseBase(Mat J, PetscBool use)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(J, MAT_CLASSID, 1);
  /* the routine is Logically Collective, so in debug builds verify that all ranks passed the same flag */
  PetscValidLogicalCollectiveBool(J, use, 2);
  /* dispatch to the method composed by MatCreateSNESMF() (MatSNESMFSetReuseBase_SNESMF()) */
  PetscTryMethod(J, "MatSNESMFSetReuseBase_C", (Mat, PetscBool), (J, use));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
    MatSNESMFGetReuseBase_SNESMF - The "MatSNESMFGetReuseBase_C" method: reports PETSC_TRUE exactly when the
  assembly-end routine currently installed on the matrix is MatAssemblyEnd_SNESMF_UseBase().
*/
static PetscErrorCode MatSNESMFGetReuseBase_SNESMF(Mat J, PetscBool *use)
{
  PetscFunctionBegin;
  *use = (J->ops->assemblyend == MatAssemblyEnd_SNESMF_UseBase) ? PETSC_TRUE : PETSC_FALSE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  MatSNESMFGetReuseBase - Determines if the base vector is to be used for differencing even if the function provided to `SNESSetFunction()` is not the
  same as that provided to `MatMFFDSetFunction()`.

  Not Collective

  Input Parameter:
. J - the `MATMFFD` matrix, created with `MatCreateSNESMF()`

  Output Parameter:
. use - if true always reuse the base vector instead of recomputing f(u) even if the function in the `MATMFFD` is
          not `SNESComputeFunction()`

  Level: advanced

  Notes:
  See `MatSNESMFSetReuseBase()`

  The value reflects which assembly-end routine is currently installed on `J`; in particular it is false
  after `MatMFFDSetBase()` has been called on `J`, since that resets the assembly-end routine.

.seealso: [](ch_snes), `Mat`, `SNES`, `MatSNESMFSetReuseBase()`, `MatCreateSNESMF()`
@*/
PetscErrorCode MatSNESMFGetReuseBase(Mat J, PetscBool *use)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(J, MAT_CLASSID, 1);
  /* PetscUseMethod() (not PetscTryMethod()): J is expected to provide the method composed by MatCreateSNESMF() */
  PetscUseMethod(J, "MatSNESMFGetReuseBase_C", (Mat, PetscBool *), (J, use));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  MatCreateSNESMF - Creates a finite differencing based matrix-free matrix context for use with
  a `SNES` solver.  This matrix can be used as the Jacobian argument for
  the routine `SNESSetJacobian()`. See `MatCreateMFFD()` for details on how
  the finite difference computation is done.

  Collective

  Input Parameter:
. snes - the `SNES` context

  Output Parameter:
. J - the matrix-free matrix which is of type `MATMFFD`

  Level: advanced

  Notes:
  You can call `SNESSetJacobian()` with `MatMFFDComputeJacobian()` if you are not using a different
  matrix to construct the preconditioner.

  If you wish to provide a different function to do differencing on to compute the matrix-free operator than
  that provided to `SNESSetFunction()` then call `MatMFFDSetFunction()` with your function after this call.

  The difference between this routine and `MatCreateMFFD()` is that this matrix
  automatically gets the current base vector from the `SNES` object and not from an
  explicit call to `MatMFFDSetBase()`.

  If `MatMFFDSetBase()` is ever called on `J` then the matrix will NO longer get
  the x from the `SNES` object and `MatMFFDSetBase()` must from that point on be used to
  change the base vector `x`.

  Using a different function for the differencing will not work if you are using non-linear left preconditioning.

  This uses finite-differencing to apply the operator. To create a matrix-free `Mat` whose matrix-vector operator you
  provide with your own function use `MatCreateShell()`.

  The sizes of `J` are taken from the function vector of the `SNES` if it has one, otherwise from a global vector of
  its `DM`; hence `SNESSetFunction()` or `SNESSetDM()` must have been called first.

  The matrix stores a pointer to `snes` without taking a reference to it.

  Developer Note:
  This function should really be called `MatCreateSNESMFFD()` in correspondence to `MatCreateMFFD()` to clearly indicate
  that this is for using finite differences to apply the operator matrix-free.

.seealso: [](ch_snes), `SNES`, `MATMFFD`, `MatDestroy()`, `MatMFFDSetFunction()`, `MatMFFDSetFunctionError()`, `MatMFFDDSSetUmin()`,
          `MatMFFDSetHHistory()`, `MatMFFDResetHHistory()`, `MatCreateMFFD()`, `MatCreateShell()`,
          `MatMFFDGetH()`, `MatMFFDRegister()`, `MatMFFDComputeJacobian()`, `MatSNESMFSetReuseBase()`, `MatSNESMFGetReuseBase()`
@*/
PetscErrorCode MatCreateSNESMF(SNES snes, Mat *J)
{
  PetscInt n, N;
  MatMFFD  mf;

  PetscFunctionBegin;
  /* snes is dereferenced directly below, so make sure it is a valid SNES before touching its fields */
  PetscValidHeaderSpecific(snes, SNES_CLASSID, 1);

  /* determine the local and global sizes of the operator */
  if (snes->vec_func) {
    PetscCall(VecGetLocalSize(snes->vec_func, &n));
    PetscCall(VecGetSize(snes->vec_func, &N));
  } else if (snes->dm) {
    Vec tmp;
    PetscCall(DMGetGlobalVector(snes->dm, &tmp));
    PetscCall(VecGetLocalSize(tmp, &n));
    PetscCall(VecGetSize(tmp, &N));
    PetscCall(DMRestoreGlobalVector(snes->dm, &tmp));
  } else SETERRQ(PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "Must call SNESSetFunction() or SNESSetDM() first");
  PetscCall(MatCreateMFFD(PetscObjectComm((PetscObject)snes), n, n, N, N, J));

  /* remember the SNES in the MATMFFD context; the assembly-end routines and MatSNESMFGetSNES() read it back */
  PetscCall(MatShellGetContext(*J, &mf));
  mf->ctx = snes;

  /* choose the function that gets differenced */
  if (snes->npc && snes->npcside == PC_LEFT) {
    PetscCall(MatMFFDSetFunction(*J, (PetscErrorCode (*)(void *, Vec, Vec))SNESComputeFunctionDefaultNPC, snes));
  } else {
    DM     dm;
    DMSNES dms;

    PetscCall(SNESGetDM(snes, &dm));
    PetscCall(DMGetDMSNES(dm, &dms));
    PetscCall(MatMFFDSetFunction(*J, (PetscErrorCode (*)(void *, Vec, Vec))(dms->ops->computemffunction ? SNESComputeMFFunction : SNESComputeFunction), snes));
  }
  /* assembly of J now pulls the base vector from the SNES */
  (*J)->ops->assemblyend = MatAssemblyEnd_SNESMF;

  /* SNES-aware replacements and additions to the MATMFFD methods */
  PetscCall(PetscObjectComposeFunction((PetscObject)*J, "MatMFFDSetBase_C", MatMFFDSetBase_SNESMF));
  PetscCall(PetscObjectComposeFunction((PetscObject)*J, "MatSNESMFSetReuseBase_C", MatSNESMFSetReuseBase_SNESMF));
  PetscCall(PetscObjectComposeFunction((PetscObject)*J, "MatSNESMFGetReuseBase_C", MatSNESMFGetReuseBase_SNESMF));
  PetscFunctionReturn(PETSC_SUCCESS);
}
