static char help[] = "Shows how to add a new MatOperation to AIJ MatType\n\n"; /* description string handed to PetscInitialize() in main() */

#include <petscmat.h>
#include <petscblaslapack.h>

/* Sequential implementation of MatScaleUserImpl(): scales inA by alpha
   by forwarding directly to MatScale(). RegisterMatScaleUserImpl()
   composes this routine under the name "MatScaleUserImpl_C". */
static PetscErrorCode MatScaleUserImpl_SeqAIJ(Mat inA, PetscScalar alpha)
{
  PetscFunctionBegin;
  PetscCall(MatScale(inA, alpha));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Forward declaration: the MPIAIJ implementation below dispatches back
   through MatScaleUserImpl() for each of its sequential blocks */
extern PetscErrorCode MatScaleUserImpl(Mat, PetscScalar);

/* Parallel implementation of MatScaleUserImpl(): obtains the two
   sequential matrices that MatMPIAIJGetSeqAIJ() exposes for A and scales
   each of them by aa via MatScaleUserImpl(), so the scaling itself is
   carried out by whatever implementation is composed on those blocks */
static PetscErrorCode MatScaleUserImpl_MPIAIJ(Mat A, PetscScalar aa)
{
  Mat blocks[2]; /* [0]: first matrix returned (diagonal part), [1]: second (off-diagonal part) */
  int b;

  PetscFunctionBegin;
  PetscCall(MatMPIAIJGetSeqAIJ(A, &blocks[0], &blocks[1], NULL));
  for (b = 0; b < 2; b++) PetscCall(MatScaleUserImpl(blocks[b], aa));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Attaches the user-defined scaling implementations to mat under the
   composed-function name "MatScaleUserImpl_C", the name that
   MatScaleUserImpl() looks up.

   - communicator of size 1: mat receives MatScaleUserImpl_SeqAIJ()
   - otherwise: mat receives MatScaleUserImpl_MPIAIJ(), and the two
     sequential matrices returned by MatMPIAIJGetSeqAIJ() each receive
     MatScaleUserImpl_SeqAIJ() */
PetscErrorCode RegisterMatScaleUserImpl(Mat mat)
{
  PetscMPIInt nranks;
  Mat         diag, offdiag;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &nranks));

  if (nranks == 1) {
    /* single rank: the matrix is handled as SeqAIJ */
    PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatScaleUserImpl_C", MatScaleUserImpl_SeqAIJ));
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* several ranks: the matrix is handled as MPIAIJ; fetch its blocks
     first so nothing is composed if that query fails */
  PetscCall(MatMPIAIJGetSeqAIJ(mat, &diag, &offdiag, NULL));
  PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatScaleUserImpl_C", MatScaleUserImpl_MPIAIJ));
  PetscCall(PetscObjectComposeFunction((PetscObject)diag, "MatScaleUserImpl_C", MatScaleUserImpl_SeqAIJ));
  PetscCall(PetscObjectComposeFunction((PetscObject)offdiag, "MatScaleUserImpl_C", MatScaleUserImpl_SeqAIJ));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Removes the "MatScaleUserImpl_C" composed function again by composing
   NULL under that name: on mat itself and, when the communicator has
   more than one rank, also on the two sequential matrices returned by
   MatMPIAIJGetSeqAIJ() */
PetscErrorCode DeRegisterMatScaleUserImpl(Mat mat)
{
  Mat         holders[3]; /* every object the name has to be cleared on */
  int         nholders = 1, h;
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));

  holders[0] = mat;
  if (size != 1) { /* MPIAIJ Matrix: include its sequential blocks */
    PetscCall(MatMPIAIJGetSeqAIJ(mat, &holders[1], &holders[2], NULL));
    nholders = 3;
  }

  for (h = 0; h < nholders; h++) PetscCall(PetscObjectComposeFunction((PetscObject)holders[h], "MatScaleUserImpl_C", NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Public entry point. Queries mat for the function composed under the
   name "MatScaleUserImpl_C" and, if one is found, invokes it with
   (mat, a). With the registrations made by RegisterMatScaleUserImpl()
   this ends up in MatScaleUserImpl_SeqAIJ() or MatScaleUserImpl_MPIAIJ().
   When nothing is composed under that name the routine returns success
   without calling anything. */
PetscErrorCode MatScaleUserImpl(Mat mat, PetscScalar a)
{
  PetscErrorCode (*impl)(Mat, PetscScalar) = NULL;

  PetscFunctionBegin;
  PetscCall(PetscObjectQueryFunction((PetscObject)mat, "MatScaleUserImpl_C", &impl));
  if (!impl) {
    /* no implementation registered for this object */
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(impl(mat, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Main user code that uses MatScaleUserImpl() */

int main(int argc, char **args)
{
  const PetscInt    m        = 2;    /* number of grid rows */
  const PetscScalar neighbor = -1.0; /* value coupling a point to each neighbor */
  const PetscScalar center   = 4.0;  /* diagonal value */
  const PetscScalar factor   = -1.0; /* scaling applied through MatScaleUserImpl() */
  Mat               mat;
  PetscInt          n;
  PetscMPIInt       rank, size;

  PetscFunctionBeginUser;
  PetscCall(PetscInitialize(&argc, &args, NULL, help));
  PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
  PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
  n = 2 * size; /* two grid columns per rank */

  /* create the (m*n) x (m*n) matrix */
  PetscCall(MatCreate(PETSC_COMM_WORLD, &mat));
  PetscCall(MatSetSizes(mat, PETSC_DECIDE, PETSC_DECIDE, m * n, m * n));
  PetscCall(MatSetType(mat, MATAIJ));
  PetscCall(MatSetUp(mat));

  /* register the user defined MatScaleUserImpl() operation for both
     SeqAIJ and MPIAIJ types */
  PetscCall(RegisterMatScaleUserImpl(mat));

  /* assemble the matrix: every rank fills the rows belonging to its two
     grid columns, collecting the entries of a row and inserting them
     with a single MatSetValues() call */
  for (PetscInt i = 0; i < m; i++) {
    for (PetscInt j = 2 * rank; j < 2 * rank + 2; j++) {
      const PetscInt row = j + n * i;
      PetscInt       cols[5], nc = 0;
      PetscScalar    vals[5];

      if (i > 0) { /* neighbor in the previous grid row */
        cols[nc]   = row - n;
        vals[nc++] = neighbor;
      }
      if (i < m - 1) { /* neighbor in the next grid row */
        cols[nc]   = row + n;
        vals[nc++] = neighbor;
      }
      if (j > 0) { /* neighbor in the previous grid column */
        cols[nc]   = row - 1;
        vals[nc++] = neighbor;
      }
      if (j < n - 1) { /* neighbor in the next grid column */
        cols[nc]   = row + 1;
        vals[nc++] = neighbor;
      }
      cols[nc]   = row;
      vals[nc++] = center;
      PetscCall(MatSetValues(mat, 1, &row, nc, cols, vals, INSERT_VALUES));
    }
  }
  PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));

  /* view the matrix before and after scaling by -1.0 */
  PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Matrix _before_ MatScaleUserImpl() operation\n"));
  PetscCall(MatView(mat, PETSC_VIEWER_STDOUT_WORLD));
  PetscCall(MatScaleUserImpl(mat, factor));
  PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Matrix _after_ MatScaleUserImpl() operation\n"));
  PetscCall(MatView(mat, PETSC_VIEWER_STDOUT_WORLD));

  /* deregister the user defined MatScaleUserImpl() operation for both
     SeqAIJ and MPIAIJ types, then clean up */
  PetscCall(DeRegisterMatScaleUserImpl(mat));
  PetscCall(MatDestroy(&mat));
  PetscCall(PetscFinalize());
  return 0;
}

/*TEST

   test:

   test:
      suffix: 2
      nsize: 2

TEST*/
