!     Time-dependent advection-reaction PDE in 1d. Demonstrates IMEX methods
!
!     u_t + a1*u_x = -k1*u + k2*v + s1
!     v_t + a2*v_x = k1*u - k2*v + s2
!     0 < x < 1
!     a1 = 1, k1 = 10^6, s1 = 0,
!     a2 = 0, k2 = 2*k1, s2 = 1
!
!     Initial conditions:
!     u(x,0) = 1 + s2*x
!     v(x,0) = k1/k2*u(x,0) + s2/k2
!
!     Upstream boundary conditions:
!     u(0,t) = 1-sin(12*t)^4
!
#include <petsc/finclude/petscts.h>

module ex22f_mfmodule
  use petscts
  implicit none

  ! Problem parameters, two entries each (one per equation of the system):
  ! a - advection speeds, k - reaction rates, s - source terms.
  type AppCtx
    PetscReal a(2), k(2), s(2)
  end type AppCtx

  ! State shared between the main program, FormIJacobianMF and MyMult.
  ! The shell-matrix multiply routine receives only (A, X, F, ierr), so
  ! everything else it needs is kept here.
  ! Shift most recently stored by FormIJacobianMF
  PetscScalar::PETSC_SHIFT
  ! Time stepper; MyMult uses it to look up the DM
  TS::tscontext
  ! Assembled matrix that MyMult fills and then multiplies with
  Mat::Jmat
  ! Copy of the application context stored by FormIJacobianMF
  type(AppCtx) MFctx
end module ex22f_mfmodule

program main
  use ex22f_mfmodule
  use petscdm
  implicit none

  !
  !     Driver: integrates the advection-reaction system with TSARKIMEX,
  !     using a shell matrix (multiply routine MyMult) as the Jacobian of
  !     the implicit part.
  !
  !     The application context ctx (type AppCtx) holds the problem
  !     parameters a, k and s, two entries each, and is passed to the
  !     application-provided call-back routines.
  !
  TS ts
  Vec X
  Mat J
  PetscInt mx
  PetscBool OptionSaveToDisk
  PetscErrorCode ierr
  DM da
  PetscReal ftime, dt
  PetscReal one, pone
  PetscInt im11, i2
  PetscBool flg

  PetscInt xs, xe, gxs, gxe, dof, ndof
  PetscScalar shell_shift
  Mat A
  type(AppCtx) ctx

  im11 = 11
  i2 = 2
  one = 1.0
  pone = one/10

  PetscCallA(PetscInitialize(ierr))

  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  !  Create distributed array (DMDA) to manage parallel grid and vectors
  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(DMDACreate1d(PETSC_COMM_WORLD, DM_BOUNDARY_NONE, im11, i2, i2, PETSC_NULL_INTEGER_ARRAY, da, ierr))
  PetscCallA(DMSetFromOptions(da, ierr))
  PetscCallA(DMSetUp(da, ierr))

  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  !    Extract global vectors from DMDA
  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(DMCreateGlobalVector(da, X, ierr))

  ! Initialize user application context
  ! Use zero-based indexing for command line parameters to match ex22.c
  ctx%a(1) = 1.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-a0', ctx%a(1), flg, ierr))
  ctx%a(2) = 0.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-a1', ctx%a(2), flg, ierr))
  ctx%k(1) = 1000000.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-k0', ctx%k(1), flg, ierr))
  ! Default for the second rate follows the (possibly overridden) first one
  ctx%k(2) = 2*ctx%k(1)
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-k1', ctx%k(2), flg, ierr))
  ctx%s(1) = 0.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-s0', ctx%s(1), flg, ierr))
  ctx%s(2) = 1.0
  PetscCallA(PetscOptionsGetReal(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-s1', ctx%s(2), flg, ierr))

  OptionSaveToDisk = .false.
  PetscCallA(PetscOptionsGetBool(PETSC_NULL_OPTIONS, PETSC_NULL_CHARACTER, '-sdisk', OptionSaveToDisk, flg, ierr))

  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  !    Create timestepping solver context
  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(TSCreate(PETSC_COMM_WORLD, ts, ierr))
  ! MyMult only receives (A, X, F, ierr); it reaches the DM through this
  ! module-level copy of the time stepper.
  tscontext = ts
  PetscCallA(TSSetDM(ts, da, ierr))
  PetscCallA(TSSetType(ts, TSARKIMEX, ierr))
  PetscCallA(TSSetRHSFunction(ts, PETSC_NULL_VEC, FormRHSFunction, ctx, ierr))

  ! - - - - - - - - -- - - - -
  !   Matrix free setup
  !   Local sizes are the dofs owned by this process. Global sizes are
  !   i2 dofs for each of the mx grid nodes; the ghosted local count is
  !   not a valid global size (it only coincides with it on one process).
  PetscCallA(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))
  dof = i2*(xe - xs + 1)
  ndof = i2*mx
  ! The shell context is not read back anywhere in this file (MyMult
  ! works from module variables), but hand over a defined value.
  shell_shift = 0.0
  PetscCallA(MatCreateShell(PETSC_COMM_WORLD, dof, dof, ndof, ndof, shell_shift, A, ierr))

  PetscCallA(MatShellSetOperation(A, MATOP_MULT, MyMult, ierr))
  ! - - - - - - - - - - - -

  PetscCallA(TSSetIFunction(ts, PETSC_NULL_VEC, FormIFunction, ctx, ierr))
  PetscCallA(DMSetMatType(da, MATAIJ, ierr))
  PetscCallA(DMCreateMatrix(da, J, ierr))

  ! MyMult fills and multiplies with this assembled matrix
  Jmat = J

  ! Both registrations are kept from the original example; the shell
  ! matrix A with FormIJacobianMF is the one registered last.
  PetscCallA(TSSetIJacobian(ts, J, J, FormIJacobian, ctx, ierr))
  PetscCallA(TSSetIJacobian(ts, A, A, FormIJacobianMF, ctx, ierr))

  ftime = 1.0
  PetscCallA(TSSetMaxTime(ts, ftime, ierr))
  PetscCallA(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_STEPOVER, ierr))

  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  !  Set initial conditions
  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(FormInitialSolution(ts, X, ctx, ierr))
  PetscCallA(TSSetSolution(ts, X, ierr))
  ! From here on mx is the global vector length, not the node count
  PetscCallA(VecGetSize(X, mx, ierr))
  !  Advective CFL, I don't know why it needs so much safety factor.
  dt = pone*max(ctx%a(1), ctx%a(2))/mx
  PetscCallA(TSSetTimeStep(ts, dt, ierr))

  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  !   Set runtime options
  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(TSSetFromOptions(ts, ierr))

  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  !  Solve nonlinear system
  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(TSSolve(ts, X, ierr))

  if (OptionSaveToDisk) then
    PetscCallA(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))
    ! X is a global vector, so the count handed over is the owned dof
    ! count rather than the ghosted one.
    dof = i2*(xe - xs + 1)
    call SaveSolutionToDisk(da, X, dof, xs, xe)
  end if

  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  !  Free work space.
  ! - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  PetscCallA(MatDestroy(A, ierr))
  PetscCallA(MatDestroy(J, ierr))
  PetscCallA(VecDestroy(X, ierr))
  PetscCallA(TSDestroy(ts, ierr))
  PetscCallA(DMDestroy(da, ierr))
  PetscCallA(PetscFinalize(ierr))
contains

! Query the DMDA for the value DMDAGetInfo returns in its third slot (mx)
! and for this process's owned and ghosted index ranges. The ranges are
! returned as inclusive bounds shifted to 1-based indexing.
  subroutine GetLayout(da, mx, xs, xe, gxs, gxe, ierr)
    use petscdm
    implicit none

    DM da
    PetscInt mx, xs, xe, gxs, gxe
    PetscErrorCode ierr
    PetscInt nowned, nghosted

    PetscCall(DMDAGetInfo(da, PETSC_NULL_INTEGER, mx, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, PETSC_NULL_ENUM, PETSC_NULL_ENUM, PETSC_NULL_ENUM, PETSC_NULL_ENUM, ierr))
    ! Start index and width of the owned range, as returned by PETSc
    PetscCall(DMDAGetCorners(da, xs, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, nowned, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, ierr))
    ! Same for the range including ghost nodes
    PetscCall(DMDAGetGhostCorners(da, gxs, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, nghosted, PETSC_NULL_INTEGER, PETSC_NULL_INTEGER, ierr))

    ! Inclusive end after the +1 shift equals unshifted start plus width
    xe = xs + nowned
    gxe = gxs + nghosted
    ! Shift the starts by one
    xs = xs + 1
    gxs = gxs + 1
  end subroutine GetLayout

  subroutine FormIFunctionLocal(mx, xs, xe, gxs, gxe, x, xdot, f, a, k, s, ierr)
    ! Evaluate the implicit residual on the nodes xs..xe:
    !   f(1,i) = xdot(1,i) + k(1)*x(1,i) - k(2)*x(2,i) - s(1)
    !   f(2,i) = xdot(2,i) - k(1)*x(1,i) + k(2)*x(2,i) - s(2)
    ! Only the reaction and source terms appear here; mx, gxs, gxe and a
    ! are accepted for a uniform argument list but are not used.
    ! ierr is set to 0; the routine has no failure path.
    implicit none
    PetscInt mx, xs, xe, gxs, gxe
    PetscScalar x(2, xs:xe)
    PetscScalar xdot(2, xs:xe)
    PetscScalar f(2, xs:xe)
    PetscReal a(2), k(2), s(2)
    PetscErrorCode ierr
    PetscInt i

    ! The caller tests ierr after the call, so give it a defined value
    ierr = 0
    do i = xs, xe
      f(1, i) = xdot(1, i) + k(1)*x(1, i) - k(2)*x(2, i) - s(1)
      f(2, i) = xdot(2, i) - k(1)*x(1, i) + k(2)*x(2, i) - s(2)
    end do
  end subroutine FormIFunctionLocal

  subroutine FormIFunction(ts, t, X, Xdot, F, ctx, ierr)
    ! Implicit-function call-back registered with TSSetIFunction.
    ! Looks up the grid layout, exposes X, Xdot and F as flat arrays and
    ! delegates the arithmetic to FormIFunctionLocal.
    use petscdm
    use ex22f_mfmodule
    implicit none

    TS ts
    PetscReal t
    Vec X, Xdot, F
    PetscErrorCode ierr
    type(AppCtx) ctx

    DM grid_dm
    PetscInt mx, xs, xe, gxs, gxe
    PetscScalar, pointer :: u_ptr(:), udot_ptr(:), res_ptr(:)

    PetscCall(TSGetDM(ts, grid_dm, ierr))
    PetscCall(GetLayout(grid_dm, mx, xs, xe, gxs, gxe, ierr))

    ! State and its time derivative are read-only; the residual is written
    PetscCall(VecGetArrayRead(X, u_ptr, ierr))
    PetscCall(VecGetArrayRead(Xdot, udot_ptr, ierr))
    PetscCall(VecGetArray(F, res_ptr, ierr))

    ! The flat arrays are viewed as (2, xs:xe) inside the local routine
    PetscCall(FormIFunctionLocal(mx, xs, xe, gxs, gxe, u_ptr, udot_ptr, res_ptr, ctx%a, ctx%k, ctx%s, ierr))

    ! Hand the arrays back in reverse order of acquisition
    PetscCall(VecRestoreArray(F, res_ptr, ierr))
    PetscCall(VecRestoreArrayRead(Xdot, udot_ptr, ierr))
    PetscCall(VecRestoreArrayRead(X, u_ptr, ierr))
  end subroutine FormIFunction

  subroutine FormRHSFunctionLocal(mx, xs, xe, gxs, gxe, t, x, f, a, k, s, ierr)
    ! Evaluate the explicit right-hand side (the advection terms) on the
    ! owned nodes xs..xe, reading the ghosted state x(2, gxs:gxe).
    !
    ! Interior nodes use a five-point stencil over i-2..i+2. Nodes 1 and 2
    ! substitute the inflow value u0t for the missing upstream neighbours;
    ! nodes mx-1 and mx use one-sided formulas that stay inside the grid.
    ! The inflow value is 1 - sin(12 t)^4 for component 1 and 0 for
    ! component 2.
    !
    ! When EXPLICIT_INTEGRATOR22 is defined the reaction and source terms
    ! are added as well; otherwise k and s are not used here.
    ! ierr is set to 0; the routine has no failure path.
    implicit none
    PetscInt mx, xs, xe, gxs, gxe
    PetscReal t
    PetscScalar x(2, gxs:gxe), f(2, xs:xe)
    PetscReal a(2), k(2), s(2)
    PetscErrorCode ierr
    PetscInt i, j
    PetscReal hx, u0t(2)
    PetscReal one, two, three, six, twelve
    PetscReal half, third, twothird, sixth
    PetscReal twelfth

    ! The caller tests ierr after the call, so give it a defined value
    ierr = 0

    ! Constants are built from PetscReal variables so that the fractions
    ! are formed in the working precision.
    one = 1.0
    two = 2.0
    three = 3.0
    six = 6.0
    twelve = 12.0
    hx = one/mx
    ! Integer exponent: sin() is negative for part of the time interval and
    ! a negative real base may not be raised to a real power.
    u0t(1) = one - sin(twelve*t)**4
    u0t(2) = 0.0
    half = one/two
    third = one/three
    twothird = two/three
    sixth = one/six
    twelfth = one/twelve
    do i = xs, xe
      do j = 1, 2
        if (i == 1) then
          ! First node: inflow value stands in for the upstream neighbour
          f(j, i) = a(j)/hx*(third*u0t(j) + half*x(j, i) - x(j, i + 1) + sixth*x(j, i + 2))
        else if (i == 2) then
          ! Second node: inflow value stands in for x(j, i - 2)
          f(j, i) = a(j)/hx*(-twelfth*u0t(j) + twothird*x(j, i - 1) - twothird*x(j, i + 1) + twelfth*x(j, i + 2))
        else if (i == mx - 1) then
          ! Next-to-last node: no i + 2 neighbour is referenced
          f(j, i) = a(j)/hx*(-sixth*x(j, i - 2) + x(j, i - 1) - half*x(j, i) - third*x(j, i + 1))
        else if (i == mx) then
          ! Last node: two-point one-sided difference
          f(j, i) = a(j)/hx*(-x(j, i) + x(j, i - 1))
        else
          f(j, i) = a(j)/hx*(-twelfth*x(j, i - 2) + twothird*x(j, i - 1) - twothird*x(j, i + 1) + twelfth*x(j, i + 2))
        end if
      end do
    end do

#ifdef EXPLICIT_INTEGRATOR22
    do i = xs, xe
      f(1, i) = f(1, i) - (k(1)*x(1, i) - k(2)*x(2, i) - s(1))
      f(2, i) = f(2, i) - (-k(1)*x(1, i) + k(2)*x(2, i) - s(2))
    end do
#endif

  end subroutine FormRHSFunctionLocal

  subroutine FormRHSFunction(ts, t, X, F, ctx, ierr)
    ! Right-hand-side call-back registered with TSSetRHSFunction.
    ! Copies X into a ghosted local vector, exposes it and F as flat
    ! arrays and delegates the stencil work to FormRHSFunctionLocal.
    use ex22f_mfmodule
    implicit none

    TS ts
    PetscReal t
    Vec X, F
    type(AppCtx) ctx
    PetscErrorCode ierr

    DM grid_dm
    Vec ghosted
    PetscInt mx, xs, xe, gxs, gxe
    PetscScalar, pointer :: u_ptr(:), rhs_ptr(:)

    PetscCall(TSGetDM(ts, grid_dm, ierr))
    PetscCall(GetLayout(grid_dm, mx, xs, xe, gxs, gxe, ierr))

    ! Fill a borrowed local vector, ghost points included, from X.
    ! The transfer is a Begin/End pair; nothing is overlapped with it here.
    PetscCall(DMGetLocalVector(grid_dm, ghosted, ierr))
    PetscCall(DMGlobalToLocalBegin(grid_dm, X, INSERT_VALUES, ghosted, ierr))
    PetscCall(DMGlobalToLocalEnd(grid_dm, X, INSERT_VALUES, ghosted, ierr))

    ! Ghosted state is read-only; F is written
    PetscCall(VecGetArrayRead(ghosted, u_ptr, ierr))
    PetscCall(VecGetArray(F, rhs_ptr, ierr))

    PetscCall(FormRHSFunctionLocal(mx, xs, xe, gxs, gxe, t, u_ptr, rhs_ptr, ctx%a, ctx%k, ctx%s, ierr))

    ! Release the arrays, then return the borrowed local vector
    PetscCall(VecRestoreArray(F, rhs_ptr, ierr))
    PetscCall(VecRestoreArrayRead(ghosted, u_ptr, ierr))
    PetscCall(DMRestoreLocalVector(grid_dm, ghosted, ierr))
  end subroutine FormRHSFunction

! ---------------------------------------------------------------------
!
!  FormIJacobian - assemble IJacobian = dF/dU + shift*dF/dUdot into Jpre
!
!  For the residual of FormIFunctionLocal every node contributes the same
!  2x2 diagonal block
!      [ shift + k1     -k2      ]
!      [   -k1       shift + k2  ]
!  X, Xdot and t are not used.
!
  subroutine FormIJacobian(ts, t, X, Xdot, shift, J, Jpre, ctx, ierr)
    use ex22f_mfmodule
    use petscdm
    implicit none

    TS ts
    PetscReal t, shift
    Vec X, Xdot
    Mat J, Jpre
    type(AppCtx) ctx
    PetscErrorCode ierr

    DM da
    PetscInt mx, xs, xe, gxs, gxe
    PetscInt node, nblk, blk
    PetscScalar bvals(4)

    PetscCall(TSGetDM(ts, da, ierr))
    PetscCall(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))

    ! Block entries listed row by row; they do not depend on the node
    bvals(1) = shift + ctx%k(1)
    bvals(2) = -ctx%k(2)
    bvals(3) = -ctx%k(1)
    bvals(4) = shift + ctx%k(2)

    nblk = 1
    do node = xs, xe
      ! Offset of this node from the first ghosted node; used as both the
      ! block row and the block column
      blk = node - gxs
      PetscCall(MatSetValuesBlockedLocal(Jpre, nblk, [blk], nblk, [blk], bvals, INSERT_VALUES, ierr))
    end do

    PetscCall(MatAssemblyBegin(Jpre, MAT_FINAL_ASSEMBLY, ierr))
    PetscCall(MatAssemblyEnd(Jpre, MAT_FINAL_ASSEMBLY, ierr))
    ! Nothing is inserted into J itself; it is only assembled when it is a
    ! different matrix from Jpre
    if (J /= Jpre) then
      PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY, ierr))
      PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY, ierr))
    end if
  end subroutine FormIJacobian

  subroutine FormInitialSolutionLocal(mx, xs, xe, gxs, gxe, x, a, k, s, ierr)
    ! Fill the owned nodes xs..xe with the initial state. With
    ! r = i/mx and ik = 1/k(2) (or 1 when k(2) is zero):
    !   x(1,i) = 1 + s(2)*r
    !   x(2,i) = k(1)*ik*x(1,i) + s(2)*ik
    ! gxs, gxe and a are accepted for a uniform argument list but are not
    ! used. ierr is set to 0; the routine has no failure path.
    implicit none
    PetscInt mx, xs, xe, gxs, gxe
    PetscScalar x(2, xs:xe)
    PetscReal a(2), k(2), s(2)
    PetscErrorCode ierr

    PetscInt i
    PetscReal one, hx, r, ik

    ! The caller tests ierr after the call, so give it a defined value
    ierr = 0

    one = 1.0
    hx = one/mx

    ! The scaling factor is the same for every node, so compute it once.
    ! A zero k(2) falls back to a factor of one instead of dividing by zero.
    if (k(2) /= 0.0) then
      ik = one/k(2)
    else
      ik = one
    end if

    do i = xs, xe
      r = i*hx
      x(1, i) = one + s(2)*r
      x(2, i) = k(1)*ik*x(1, i) + s(2)*ik
    end do
  end subroutine FormInitialSolutionLocal

  subroutine FormInitialSolution(ts, X, ctx, ierr)
    ! Set X to the initial state: look up the grid layout, expose X as a
    ! flat writable array and let FormInitialSolutionLocal fill it.
    use ex22f_mfmodule
    use petscdm
    implicit none

    TS ts
    Vec X
    type(AppCtx) ctx
    PetscErrorCode ierr

    DM grid_dm
    PetscInt mx, xs, xe, gxs, gxe
    PetscScalar, pointer :: u_ptr(:)

    PetscCall(TSGetDM(ts, grid_dm, ierr))
    PetscCall(GetLayout(grid_dm, mx, xs, xe, gxs, gxe, ierr))

    ! Writable access; the local routine views the array as (2, xs:xe)
    PetscCall(VecGetArray(X, u_ptr, ierr))
    PetscCall(FormInitialSolutionLocal(mx, xs, xe, gxs, gxe, u_ptr, ctx%a, ctx%k, ctx%s, ierr))
    PetscCall(VecRestoreArray(X, u_ptr, ierr))
  end subroutine FormInitialSolution

! ---------------------------------------------------------------------
!
!  FormIJacobianMF - IJacobian call-back for the matrix-free (shell) path
!
!  No matrix entries are computed here. The routine only records the
!  current shift and the application context in the module variables
!  PETSC_SHIFT and MFctx, where MyMult picks them up.
!  t, X, Xdot, J and Jpre are not used. ierr is set to 0.
!
  subroutine FormIJacobianMF(ts, t, X, Xdot, shift, J, Jpre, ctx, ierr)
    use ex22f_mfmodule
    implicit none
    TS ts
    PetscReal t, shift
    Vec X, Xdot
    Mat J, Jpre
    type(AppCtx) ctx
    PetscErrorCode ierr

    PETSC_SHIFT = shift
    MFctx = ctx

    ! Nothing above can fail; report success explicitly rather than leaving
    ! the error code untouched
    ierr = 0

  end subroutine FormIJacobianMF

! -------------------------------------------------------------------
!
!   MyMult - user provided matrix multiply for the shell matrix
!
!   Rebuilds the block-diagonal matrix Jmat from the shift and context
!   last stored by FormIJacobianMF, then computes F = Jmat*X.
!
!   Input Parameters:
!.  A - shell matrix (not referenced; all state comes from the module)
!.  X - input vector
!
!   Output Parameter:
!.  F - product vector
!
  subroutine MyMult(A, X, F, ierr)
    use ex22f_mfmodule
    implicit none

    Mat A
    Vec X, F
    PetscErrorCode ierr

    DM da
    PetscInt mx, xs, xe, gxs, gxe
    PetscInt node, nblk, blk
    PetscScalar bvals(4)

    ! The DM is reached through the module-level copy of the time stepper
    PetscCall(TSGetDM(tscontext, da, ierr))
    PetscCall(GetLayout(da, mx, xs, xe, gxs, gxe, ierr))

    ! 2x2 block listed row by row; identical for every node
    bvals(1) = PETSC_SHIFT + MFctx%k(1)
    bvals(2) = -MFctx%k(2)
    bvals(3) = -MFctx%k(1)
    bvals(4) = PETSC_SHIFT + MFctx%k(2)

    nblk = 1
    do node = xs, xe
      ! Offset of this node from the first ghosted node; used as both the
      ! block row and the block column
      blk = node - gxs
      PetscCall(MatSetValuesBlockedLocal(Jmat, nblk, [blk], nblk, [blk], bvals, INSERT_VALUES, ierr))
    end do

    ! Jmat is refilled and reassembled on every multiply
    PetscCall(MatAssemblyBegin(Jmat, MAT_FINAL_ASSEMBLY, ierr))
    PetscCall(MatAssemblyEnd(Jmat, MAT_FINAL_ASSEMBLY, ierr))

    PetscCall(MatMult(Jmat, X, F, ierr))

  end subroutine MyMult

!
! Write the leading entries of this process's part of X to the text file
! 'solution_out_ex22f_mf.txt', one value per line.
!
!   da       - not used; kept for the existing call signature
!   X        - vector to write
!   gdof     - number of entries requested by the caller
!   xs, xe   - inclusive owned node range, two entries per node
!
! The count written is gdof, capped at 2*(xe - xs + 1).
! NOTE(review): the cap assumes the local array of X holds exactly two
! entries per owned node - confirm against the caller.
  subroutine SaveSolutionToDisk(da, X, gdof, xs, xe)
    use petscdm
    implicit none

    Vec X
    DM da
    PetscInt xs, xe, two
    PetscInt gdof, i, nout
    PetscErrorCode ierr
    PetscScalar, pointer :: xx(:)

    PetscCall(VecGetArrayRead(X, xx, ierr))

    ! Entries are written straight from the array, starting at its first
    ! element; never read past the owned entries even if gdof is larger.
    two = 2
    nout = min(gdof, two*(xe - xs + 1))

    open (1020, file='solution_out_ex22f_mf.txt')
    do i = 1, nout
      write (1020, '(e24.16,1x)') xx(i)
    end do
    close (1020)

    PetscCall(VecRestoreArrayRead(X, xx, ierr))
  end subroutine SaveSolutionToDisk
end program main

!/*TEST
!
!    test:
!      args: -da_grid_x 200 -ts_arkimex_type 4
!      requires: !single
!      output_file: output/empty.out
!
!TEST*/
