#include <petsctao.h> /*I "petsctao.h" I*/
#include <../src/tao/matrix/submatfree.h>

/*@
  MatCreateSubMatrixFree - Creates a matrix-free "reduced" matrix by masking rows and columns
  of a full matrix.

  Collective

  Input Parameters:
+ mat  - matrix of arbitrary type
. Rows - index set of rows that are masked out; `MatMult()` with `J` sets these entries of the result to zero
- Cols - index set of columns that are masked out; `MatMult()` with `J` zeroes these entries of the input before applying `mat`

  Output Parameter:
. J - New matrix

  Level: developer

  Notes:
  `J` is a `MATSHELL`; no entries of `mat` are extracted. Products with `J` apply `mat` to a masked
  copy of the input and then mask the output.

  `J` takes its own references to `mat`, `Rows`, and `Cols` and releases them when `J` is destroyed.
  The caller still owns, and must destroy, its own references, which it may do at any time.
  Because `J` references `mat` rather than copying it:
  - later changes to `mat` are seen by `J`;
  - `MatShift()`, `MatScale()`, and `MatDiagonalSet()` applied to `J` modify `mat` itself.

  A single work vector with the row layout of `mat` is used by both `MatMult()` and
  `MatMultTranspose()`, so `mat` is assumed to have identical row and column layouts.

  Developer Note:
  This should be moved/supported in `Mat`

.seealso: `MatCreate()`, `MatCreateShell()`, `MatSMFResetRowColumn()`
@*/
PetscErrorCode MatCreateSubMatrixFree(Mat mat, IS Rows, IS Cols, Mat *J)
{
  MPI_Comm         comm = PetscObjectComm((PetscObject)mat);
  MatSubMatFreeCtx ctx;
  PetscInt         mloc, nloc, m, n;

  PetscFunctionBegin;
  PetscCall(PetscNew(&ctx));
  ctx->A = mat;
  PetscCall(MatGetSize(mat, &m, &n));
  PetscCall(MatGetLocalSize(mat, &mloc, &nloc));
  /* One work vector with the row (left) layout of mat serves as both VR (used by MatMult_SMF)
     and VC (used by MatMultTranspose_SMF); only VC is destroyed in MatDestroy_SMF. */
  PetscCall(MatCreateVecs(mat, NULL, &ctx->VC));
  ctx->VR = ctx->VC;
  PetscCall(PetscObjectReference((PetscObject)mat));

  ctx->Rows = Rows;
  ctx->Cols = Cols;
  PetscCall(PetscObjectReference((PetscObject)Rows));
  PetscCall(PetscObjectReference((PetscObject)Cols));
  PetscCall(MatCreateShell(comm, mloc, nloc, m, n, ctx, J));
  /* Route MatShift()/MatScale() to the SMF callbacks below instead of MATSHELL's own bookkeeping */
  PetscCall(MatShellSetManageScalingShifts(*J));
  PetscCall(MatShellSetOperation(*J, MATOP_MULT, (PetscErrorCodeFn *)MatMult_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_DESTROY, (PetscErrorCodeFn *)MatDestroy_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_VIEW, (PetscErrorCodeFn *)MatView_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_MULT_TRANSPOSE, (PetscErrorCodeFn *)MatMultTranspose_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_DIAGONAL_SET, (PetscErrorCodeFn *)MatDiagonalSet_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_SHIFT, (PetscErrorCodeFn *)MatShift_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_EQUAL, (PetscErrorCodeFn *)MatEqual_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_SCALE, (PetscErrorCodeFn *)MatScale_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_TRANSPOSE, (PetscErrorCodeFn *)MatTranspose_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_GET_DIAGONAL, (PetscErrorCodeFn *)MatGetDiagonal_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_CREATE_SUBMATRICES, (PetscErrorCodeFn *)MatCreateSubMatrices_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_NORM, (PetscErrorCodeFn *)MatNorm_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_DUPLICATE, (PetscErrorCodeFn *)MatDuplicate_SMF));
  PetscCall(MatShellSetOperation(*J, MATOP_CREATE_SUBMATRIX, (PetscErrorCodeFn *)MatCreateSubMatrix_SMF));
  /* Must be the row-max callback: a duplicate routine here would receive MatGetRowMax()'s
     (Vec, PetscInt[]) arguments and write a Mat through the index-array pointer.
     MatGetRowMax_SMF ignores the index array (it passes NULL to the wrapped matrix). */
  PetscCall(MatShellSetOperation(*J, MATOP_GET_ROW_MAX, (PetscErrorCodeFn *)MatGetRowMax_SMF));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatSMFResetRowColumn - Replaces the row and column masks of a matrix created with
  MatCreateSubMatrixFree().

  Input Parameters:
+ mat  - the masked matrix
. Rows - new index set of rows to mask out
- Cols - new index set of columns to mask out

  The matrix takes its own references to Rows and Cols and releases those it held before.
*/
PetscErrorCode MatSMFResetRowColumn(Mat mat, IS Rows, IS Cols)
{
  MatSubMatFreeCtx ctx;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &ctx));
  /* Reference the new index sets before releasing the old ones. If the caller passes the very
     index set already stored in ctx, destroying first could free it before it is stored again. */
  PetscCall(PetscObjectReference((PetscObject)Rows));
  PetscCall(PetscObjectReference((PetscObject)Cols));
  PetscCall(ISDestroy(&ctx->Rows));
  PetscCall(ISDestroy(&ctx->Cols));
  ctx->Rows = Rows;
  ctx->Cols = Cols;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMult_SMF - y = (masked A) a.

  The input is copied into the work vector with the entries listed in Cols zeroed, the full
  matrix is applied, and the entries of y listed in Rows are then zeroed.
*/
PetscErrorCode MatMult_SMF(Mat mat, Vec a, Vec y)
{
  MatSubMatFreeCtx smf;
  Vec              work;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  work = smf->VR;
  full = smf->A;

  /* mask the input: drop the columns in Cols */
  PetscCall(VecCopy(a, work));
  PetscCall(VecISSet(work, smf->Cols, 0.0));

  /* apply the full operator, then drop the rows in Rows */
  PetscCall(MatMult(full, work, y));
  PetscCall(VecISSet(y, smf->Rows, 0.0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMultTranspose_SMF - y = (masked A)^T a.

  Mirror of MatMult_SMF: the input has the entries in Rows zeroed, A^T is applied, and the
  entries of y in Cols are zeroed.
*/
PetscErrorCode MatMultTranspose_SMF(Mat mat, Vec a, Vec y)
{
  MatSubMatFreeCtx smf;
  Vec              work;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  work = smf->VC;
  full = smf->A;

  /* mask the input: drop the rows in Rows */
  PetscCall(VecCopy(a, work));
  PetscCall(VecISSet(work, smf->Rows, 0.0));

  /* apply the transposed full operator, then drop the columns in Cols */
  PetscCall(MatMultTranspose(full, work, y));
  PetscCall(VecISSet(y, smf->Cols, 0.0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatDiagonalSet_SMF - Forwards MatDiagonalSet() to the wrapped full matrix.

  Note: this modifies the wrapped matrix itself; the row/column masks are not applied.
*/
PetscErrorCode MatDiagonalSet_SMF(Mat M, Vec D, InsertMode is)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(M, &smf));
  full = smf->A;
  PetscCall(MatDiagonalSet(full, D, is));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatDestroy_SMF - Releases everything held by the shell context and frees the context.
*/
PetscErrorCode MatDestroy_SMF(Mat mat)
{
  MatSubMatFreeCtx smf;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));

  /* drop the references taken when the masked matrix was created */
  PetscCall(MatDestroy(&smf->A));
  PetscCall(ISDestroy(&smf->Rows));
  PetscCall(ISDestroy(&smf->Cols));

  /* VR aliases VC (one shared work vector), so only VC owns a reference */
  smf->VR = NULL;
  PetscCall(VecDestroy(&smf->VC));

  PetscCall(PetscFree(smf));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatView_SMF - Views the wrapped full matrix; the row/column masks are not shown.
*/
PetscErrorCode MatView_SMF(Mat mat, PetscViewer viewer)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  full = smf->A;
  PetscCall(MatView(full, viewer));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatShift_SMF - Forwards MatShift() to the wrapped full matrix (the wrapped matrix is modified).

  NOTE(review): MatShift() passes a PetscScalar, while this prototype takes a PetscReal
  (presumably to match submatfree.h). These differ in complex builds; confirm and fix the
  header and this definition together.
*/
PetscErrorCode MatShift_SMF(Mat Y, PetscReal a)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(Y, &smf));
  full = smf->A;
  PetscCall(MatShift(full, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatDuplicate_SMF - Creates a new masked matrix over the same full matrix and index sets.

  The duplicate option op is ignored: the new matrix references (does not copy) the wrapped
  matrix and the Rows/Cols index sets.
*/
PetscErrorCode MatDuplicate_SMF(Mat mat, MatDuplicateOption op, Mat *M)
{
  MatSubMatFreeCtx smf;
  Mat              full;
  IS               rowmask, colmask;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  full    = smf->A;
  rowmask = smf->Rows;
  colmask = smf->Cols;
  PetscCall(MatCreateSubMatrixFree(full, rowmask, colmask, M));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatEqual_SMF - Two masked matrices are equal when their Rows index sets, their Cols index sets,
  and their wrapped full matrices are all equal.

  Both index-set comparisons are always performed. The full matrices are compared only when both
  index sets match.
*/
PetscErrorCode MatEqual_SMF(Mat A, Mat B, PetscBool *flg)
{
  MatSubMatFreeCtx smfA, smfB;
  PetscBool        sameRows, sameCols, sameFull;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(A, &smfA));
  PetscCall(MatShellGetContext(B, &smfB));
  PetscCall(ISEqual(smfA->Rows, smfB->Rows, &sameRows));
  PetscCall(ISEqual(smfA->Cols, smfB->Cols, &sameCols));

  if (!(sameRows && sameCols)) {
    *flg = PETSC_FALSE;
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  PetscCall(MatEqual(smfA->A, smfB->A, &sameFull));
  *flg = sameFull ? PETSC_TRUE : PETSC_FALSE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatScale_SMF - Forwards MatScale() to the wrapped full matrix (the wrapped matrix is modified).

  NOTE(review): MatScale() passes a PetscScalar, while this prototype takes a PetscReal
  (presumably to match submatfree.h). These differ in complex builds; confirm.
*/
PetscErrorCode MatScale_SMF(Mat mat, PetscReal a)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  full = smf->A;
  PetscCall(MatScale(full, a));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatTranspose_SMF - Explicit transposition is not supported for masked matrices; always errors.

  NOTE(review): MATOP_TRANSPOSE is normally called with (Mat, MatReuse, Mat *). This two-argument
  prototype presumably mirrors submatfree.h and is harmless only because no argument is used.
*/
PetscErrorCode MatTranspose_SMF(Mat mat, Mat *B)
{
  MPI_Comm comm = PetscObjectComm((PetscObject)mat);

  PetscFunctionBegin;
  SETERRQ(comm, PETSC_ERR_SUP, "No support for transpose for MatCreateSubMatrixFree() matrix");
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatGetDiagonal_SMF - Returns the diagonal of the wrapped full matrix; the masks are not applied.
*/
PetscErrorCode MatGetDiagonal_SMF(Mat mat, Vec v)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  full = smf->A;
  PetscCall(MatGetDiagonal(full, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatGetRowMax_SMF - Row maxima of the wrapped full matrix; the masks are not applied.

  No column indices are requested from the wrapped matrix. MatGetRowMax()'s index-array argument
  is not part of this prototype and is therefore never filled.
*/
PetscErrorCode MatGetRowMax_SMF(Mat M, Vec D)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(M, &smf));
  full = smf->A;
  PetscCall(MatGetRowMax(full, D, NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatCreateSubMatrices_SMF - Builds n masked matrices, one per (irow[k], icol[k]) pair, each
  through MatCreateSubMatrix_SMF().

  For MAT_INITIAL_MATRIX, the output array is zero-allocated with n + 1 slots. The extra slot
  presumably serves as a NULL guard, e.g. for n == 0 — NOTE(review): confirm.
*/
PetscErrorCode MatCreateSubMatrices_SMF(Mat A, PetscInt n, IS *irow, IS *icol, MatReuse scall, Mat **B)
{
  PetscFunctionBegin;
  if (scall == MAT_INITIAL_MATRIX) PetscCall(PetscCalloc1(n + 1, B));

  for (PetscInt k = 0; k < n; k++) {
    Mat *slot = *B + k;

    PetscCall(MatCreateSubMatrix_SMF(A, irow[k], icol[k], scall, slot));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatCreateSubMatrix_SMF - Creates a new masked matrix over the same full matrix as mat, using
  isrow/iscol as its Rows/Cols index sets.

  With MAT_REUSE_MATRIX, the previous *newmat is released and replaced by a freshly created
  masked matrix rather than updated in place.
*/
PetscErrorCode MatCreateSubMatrix_SMF(Mat mat, IS isrow, IS iscol, MatReuse cll, Mat *newmat)
{
  MatSubMatFreeCtx ctx;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &ctx));
  /* Only MAT_REUSE_MATRIX guarantees that *newmat holds a valid matrix. For MAT_INITIAL_MATRIX
     it may be uninitialized or a caller-owned matrix, and must not be destroyed here.
     (Checking the out-pointer newmat itself is pointless: it is never NULL.) */
  if (cll == MAT_REUSE_MATRIX) PetscCall(MatDestroy(newmat));
  PetscCall(MatCreateSubMatrixFree(ctx->A, isrow, iscol, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatGetRow_SMF - Returns a row of the wrapped full matrix; masked entries are not removed.
*/
PetscErrorCode MatGetRow_SMF(Mat mat, PetscInt row, PetscInt *ncols, const PetscInt **cols, const PetscScalar **vals)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  full = smf->A;
  PetscCall(MatGetRow(full, row, ncols, cols, vals));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatRestoreRow_SMF - Returns a row obtained with MatGetRow_SMF() back to the wrapped full matrix.
*/
PetscErrorCode MatRestoreRow_SMF(Mat mat, PetscInt row, PetscInt *ncols, const PetscInt **cols, const PetscScalar **vals)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  full = smf->A;
  PetscCall(MatRestoreRow(full, row, ncols, cols, vals));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatGetColumnVector_SMF - Extracts a column of the wrapped full matrix; the masks are not applied.
*/
PetscErrorCode MatGetColumnVector_SMF(Mat mat, Vec Y, PetscInt col)
{
  MatSubMatFreeCtx smf;
  Mat              full;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(mat, &smf));
  full = smf->A;
  PetscCall(MatGetColumnVector(full, Y, col));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatNorm_SMF - Placeholder norm.

  Reports 1.0 for the Frobenius, 1-, and infinity norms without inspecting any matrix data.
  Every other norm type raises PETSC_ERR_SUP.
  NOTE(review): the value is not the norm of the masked operator; callers should not rely on it.
*/
PetscErrorCode MatNorm_SMF(Mat mat, NormType type, PetscReal *norm)
{
  MatSubMatFreeCtx smf;

  PetscFunctionBegin;
  /* context is fetched only to verify mat is a shell with a context */
  PetscCall(MatShellGetContext(mat, &smf));
  switch (type) {
  case NORM_FROBENIUS:
  case NORM_1:
  case NORM_INFINITY:
    *norm = 1.0;
    break;
  default:
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "No two norm");
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
