/*
   This provides a matrix that applies a VecScatter to a vector.
*/

#include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
#include <petsc/private/vecimpl.h>

/* Private data of a MATSCATTER matrix; stored in Mat->data by MatCreate_Scatter() */
typedef struct {
  VecScatter scatter; /* operator applied by the MatMult*_Scatter() kernels; the matrix holds its own reference
                         (taken in MatScatterSetVecScatter()) and releases it in MatDestroy_Scatter(); NULL until set */
} Mat_Scatter;

/*@
  MatScatterGetVecScatter - Returns the user-provided scatter set with `MatScatterSetVecScatter()` in a `MATSCATTER` matrix

  Not Collective

  Input Parameter:
. mat - the matrix, should have been created with `MatCreateScatter()` or have type `MATSCATTER`

  Output Parameter:
. scatter - the scatter context, or `NULL` if no scatter has been set yet

  Level: intermediate

  Note:
  No new reference is taken on the returned scatter; it remains owned by `mat` and must not be
  destroyed by the caller.

.seealso: [](ch_matrices), `Mat`, `MATSCATTER`, `MatCreateScatter()`, `MatScatterSetVecScatter()`
@*/
PetscErrorCode MatScatterGetVecScatter(Mat mat, VecScatter *scatter)
{
  Mat_Scatter *mscatter;
  PetscBool    isscatter;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscAssertPointer(scatter, 2);
  /* mat->data is only a Mat_Scatter for MATSCATTER matrices; reject other types instead of misreading their data */
  PetscCall(PetscObjectTypeCompare((PetscObject)mat, MATSCATTER, &isscatter));
  PetscCheck(isscatter, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_WRONG, "Matrix must be of type %s", MATSCATTER);
  mscatter = (Mat_Scatter *)mat->data;
  *scatter = mscatter->scatter; /* borrowed reference: no PetscObjectReference() here */
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatDestroy_Scatter - Type-specific destructor for MATSCATTER.

  Releases the reference the matrix holds on its scatter, then frees the Mat_Scatter
  private data stored in mat->data.
*/
static PetscErrorCode MatDestroy_Scatter(Mat mat)
{
  Mat_Scatter *ctx;

  PetscFunctionBegin;
  ctx = (Mat_Scatter *)mat->data;
  /* drop the reference taken by MatScatterSetVecScatter() */
  PetscCall(VecScatterDestroy(&ctx->scatter));
  /* free through mat->data (not ctx) so the matrix's own pointer is the one cleared */
  PetscCall(PetscFree(mat->data));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMult_Scatter - Computes y = A*x by applying the attached scatter forward (from x into y).

  y is zeroed first and the scatter uses ADD_VALUES, so each entry of y ends up as the sum of the
  entries of x scattered into it, and entries that receive nothing stay zero.

  Errors with PETSC_ERR_ARG_WRONGSTATE if no scatter has been attached.
*/
static PetscErrorCode MatMult_Scatter(Mat A, Vec x, Vec y)
{
  Mat_Scatter *scatter = (Mat_Scatter *)A->data;

  PetscFunctionBegin;
  PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetVecScatter()");
  PetscCall(VecZeroEntries(y));
  PetscCall(VecScatterBegin(scatter->scatter, x, y, ADD_VALUES, SCATTER_FORWARD));
  PetscCall(VecScatterEnd(scatter->scatter, x, y, ADD_VALUES, SCATTER_FORWARD));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMultAdd_Scatter - Computes z = y + A*x.

  y is copied into z (skipped when they are the same vector), then x is scattered forward into z
  with ADD_VALUES so the scattered entries are summed onto the copy of y.

  Errors with PETSC_ERR_ARG_WRONGSTATE if no scatter has been attached.
*/
static PetscErrorCode MatMultAdd_Scatter(Mat A, Vec x, Vec y, Vec z)
{
  Mat_Scatter *scatter = (Mat_Scatter *)A->data;

  PetscFunctionBegin;
  PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetVecScatter()");
  if (z != y) PetscCall(VecCopy(y, z));
  PetscCall(VecScatterBegin(scatter->scatter, x, z, ADD_VALUES, SCATTER_FORWARD));
  PetscCall(VecScatterEnd(scatter->scatter, x, z, ADD_VALUES, SCATTER_FORWARD));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMultTranspose_Scatter - Computes y = A^T*x by applying the attached scatter in reverse
  (SCATTER_REVERSE, from x into y).

  y is zeroed first and the scatter uses ADD_VALUES, so contributions landing on the same entry
  of y are summed and entries that receive nothing stay zero.

  Errors with PETSC_ERR_ARG_WRONGSTATE if no scatter has been attached.
*/
static PetscErrorCode MatMultTranspose_Scatter(Mat A, Vec x, Vec y)
{
  Mat_Scatter *scatter = (Mat_Scatter *)A->data;

  PetscFunctionBegin;
  PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetVecScatter()");
  PetscCall(VecZeroEntries(y));
  PetscCall(VecScatterBegin(scatter->scatter, x, y, ADD_VALUES, SCATTER_REVERSE));
  PetscCall(VecScatterEnd(scatter->scatter, x, y, ADD_VALUES, SCATTER_REVERSE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMultTransposeAdd_Scatter - Computes z = y + A^T*x.

  y is copied into z (skipped when they are the same vector), then x is scattered in reverse
  into z with ADD_VALUES so the contributions are summed onto the copy of y.

  Errors with PETSC_ERR_ARG_WRONGSTATE if no scatter has been attached.
*/
static PetscErrorCode MatMultTransposeAdd_Scatter(Mat A, Vec x, Vec y, Vec z)
{
  Mat_Scatter *scatter = (Mat_Scatter *)A->data;

  PetscFunctionBegin;
  PetscCheck(scatter->scatter, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Need to first call MatScatterSetVecScatter()");
  if (z != y) PetscCall(VecCopy(y, z));
  PetscCall(VecScatterBegin(scatter->scatter, x, z, ADD_VALUES, SCATTER_REVERSE));
  PetscCall(VecScatterEnd(scatter->scatter, x, z, ADD_VALUES, SCATTER_REVERSE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  Function table for MATSCATTER; MatCreate_Scatter() copies it into A->ops.
  The initializer is positional and must follow the slot order of struct _MatOps;
  the numbered comments mark slot indices. Only the four multiply kernels (slots 3-6),
  MatShift_Basic (slot 46) and the destructor (slot 60) are provided; every other slot is NULL.
*/
static struct _MatOps MatOps_Values = {NULL,
                                       NULL,
                                       NULL,
                                       MatMult_Scatter,
                                       /*  4*/ MatMultAdd_Scatter,
                                       MatMultTranspose_Scatter,
                                       MatMultTransposeAdd_Scatter,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 10*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 15*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 20*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 24*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 29*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 34*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 39*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 44*/ NULL,
                                       NULL,
                                       /* NOTE(review): generic shift fallback; confirm it works for this type, since slot 0
                                          (MatSetValues) is NULL here */
                                       MatShift_Basic,
                                       NULL,
                                       NULL,
                                       /* 49*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 54*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 59*/ NULL,
                                       MatDestroy_Scatter, /* 60: releases the scatter and the private data */
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 64*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 69*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 74*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 79*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 84*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 89*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 94*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /* 99*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*104*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*109*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*114*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*119*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*124*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*129*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*134*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       /*139*/ NULL,
                                       NULL,
                                       NULL,
                                       NULL,
                                       NULL};

/*MC
   MATSCATTER - "scatter" - A matrix type that simply applies a `VecScatterBegin()` and `VecScatterEnd()` to perform `MatMult()`.
   `MatMultTranspose()` applies the same scatter in reverse (`SCATTER_REVERSE`).

  Level: advanced

.seealso: [](ch_matrices), `Mat`, `MatCreateScatter()`, `MatScatterSetVecScatter()`, `MatScatterGetVecScatter()`
M*/

/*
  MatCreate_Scatter - Type constructor registered for MATSCATTER (invoked through MatSetType()).

  Installs the MATSCATTER function table, allocates the private Mat_Scatter data, fixes the row and
  column layouts, and marks the matrix assembled. The scatter itself is attached later with
  MatScatterSetVecScatter().
*/
PETSC_EXTERN PetscErrorCode MatCreate_Scatter(Mat A)
{
  Mat_Scatter *mscatter;

  PetscFunctionBegin;
  *A->ops = MatOps_Values;
  PetscCall(PetscNew(&mscatter));
  A->data = mscatter;

  /* finalize the layouts now so local sizes are available to MatScatterSetVecScatter() */
  PetscCall(PetscLayoutSetUp(A->rmap));
  PetscCall(PetscLayoutSetUp(A->cmap));

  /* there are no entries to insert, so the matrix is usable as soon as a scatter is attached */
  A->preallocated = PETSC_FALSE;
  A->assembled    = PETSC_TRUE;

  PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSCATTER));
  PetscFunctionReturn(PETSC_SUCCESS);
}

#include <petsc/private/sfimpl.h>
/*@
  MatCreateScatter - Creates a new matrix of `MatType` `MATSCATTER`, based on a `VecScatter`

  Collective

  Input Parameters:
+ comm    - MPI communicator
- scatter - a `VecScatter`

  Output Parameter:
. A - the matrix

  Level: intermediate

  Notes:
  PETSc requires that matrices and vectors used together in certain operations are partitioned
  consistently. This routine takes the local sizes of `A` from `scatter`: the number of local rows
  equals the local size of the scatter's destination vector, and the number of local columns equals
  the local size of its source vector. Therefore, for `MatMult`(A, x, y), `x` must be laid out like the
  scatter's source vector and `y` like its destination vector. This information is required by the
  matrix interface routines even though the scatter matrix is not physically partitioned.

  The matrix takes its own reference to `scatter`, so the caller may destroy its reference after this call.

  Developer Note:
  This directly accesses information inside the `VecScatter` associated with the matrix-vector product
  for this matrix. This is not desirable.

.seealso: [](ch_matrices), `Mat`, `MatScatterSetVecScatter()`, `MatScatterGetVecScatter()`, `MATSCATTER`
@*/
PetscErrorCode MatCreateScatter(MPI_Comm comm, VecScatter scatter, Mat *A)
{
  PetscFunctionBegin;
  /* validate before scatter->vscat is dereferenced below */
  PetscValidHeaderSpecific(scatter, PETSCSF_CLASSID, 2);
  PetscAssertPointer(A, 3);
  PetscCall(MatCreate(comm, A));
  /* local rows = destination size, local columns = source size; global sizes follow by summation */
  PetscCall(MatSetSizes(*A, scatter->vscat.to_n, scatter->vscat.from_n, PETSC_DETERMINE, PETSC_DETERMINE));
  PetscCall(MatSetType(*A, MATSCATTER));
  PetscCall(MatScatterSetVecScatter(*A, scatter));
  PetscCall(MatSetUp(*A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  MatScatterSetVecScatter - Sets the scatter that the matrix is to apply as its linear operator in a `MATSCATTER`

  Logically Collective

  Input Parameters:
+ mat     - the `MATSCATTER` matrix
- scatter - the scatter context created with `VecScatterCreate()`

  Level: advanced

  Notes:
  The number of local rows of `mat` must equal the local size of the scatter's destination vector, and
  the number of local columns must equal the local size of its source vector.

  The matrix takes a new reference to `scatter` and releases any scatter previously set, so the caller
  keeps ownership of its own reference.

.seealso: [](ch_matrices), `Mat`, `MATSCATTER`, `MatCreateScatter()`, `MatScatterGetVecScatter()`
@*/
PetscErrorCode MatScatterSetVecScatter(Mat mat, VecScatter scatter)
{
  Mat_Scatter *mscatter;
  PetscBool    isscatter;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(mat, MAT_CLASSID, 1);
  PetscValidHeaderSpecific(scatter, PETSCSF_CLASSID, 2);
  PetscCheckSameComm(scatter, 2, mat, 1);
  /* mat->data is only a Mat_Scatter for MATSCATTER matrices */
  PetscCall(PetscObjectTypeCompare((PetscObject)mat, MATSCATTER, &isscatter));
  PetscCheck(isscatter, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_WRONG, "Matrix must be of type %s", MATSCATTER);
  /* sizes are checked per process, so a mismatch is reported on PETSC_COMM_SELF */
  PetscCheck(mat->rmap->n == scatter->vscat.to_n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Number of local rows in matrix %" PetscInt_FMT " not equal local scatter size %" PetscInt_FMT, mat->rmap->n, scatter->vscat.to_n);
  PetscCheck(mat->cmap->n == scatter->vscat.from_n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Number of local columns in matrix %" PetscInt_FMT " not equal local scatter size %" PetscInt_FMT, mat->cmap->n, scatter->vscat.from_n);

  mscatter = (Mat_Scatter *)mat->data;
  /* reference the new scatter before releasing the old one so re-setting the same scatter is safe */
  PetscCall(PetscObjectReference((PetscObject)scatter));
  PetscCall(VecScatterDestroy(&mscatter->scatter));

  mscatter->scatter = scatter;
  PetscFunctionReturn(PETSC_SUCCESS);
}
