#include <petsc/private/matimpl.h> /*I "petscmat.h" I*/
#include <petsc/private/vecimpl.h> /* for Vec->ops->setvalues */

/*
  MatConvert_Shell - Builds an explicitly stored matrix of type newtype out of oldmat by
  applying MatMult() to every column of the identity and inserting the nonzero results.

  Input Parameters:
+ oldmat  - the matrix to convert; its sizes, block sizes, vectors and MatMult() are used
. newtype - type given to the matrix that is created when reuse is not MAT_REUSE_MATRIX
- reuse   - MAT_REUSE_MATRIX zeroes and refills the matrix passed in *newmat,
            MAT_INPLACE_MATRIX replaces the header of oldmat with the result,
            any other value returns a newly created matrix in *newmat

  Output Parameter:
. newmat - the converted matrix (not written for MAT_INPLACE_MATRIX)

  Notes:
  One MatMult() is performed for each of the N global columns. A new matrix is
  preallocated as if it were dense. Entries that are exactly zero are not inserted.
*/
PetscErrorCode MatConvert_Shell(Mat oldmat, MatType newtype, MatReuse reuse, Mat *newmat)
{
  Mat          mat;
  Vec          in, out;
  PetscScalar *array;
  PetscInt    *dnnz, *onnz, *dnnzu, *onnzu;
  PetscInt     cst, cen, cstbs, Nbs, mbs, nbs, rbs, cbs;
  PetscInt     im, i, m, n, M, N, *rows, start;

  PetscFunctionBegin;
  PetscCall(MatGetOwnershipRange(oldmat, &start, NULL));
  PetscCall(MatGetOwnershipRangeColumn(oldmat, &cst, &cen));
  PetscCall(MatCreateVecs(oldmat, &in, &out));
  PetscCall(MatGetLocalSize(oldmat, &m, &n));
  PetscCall(MatGetSize(oldmat, &M, &N));
  PetscCall(PetscMalloc1(m, &rows));
  if (reuse != MAT_REUSE_MATRIX) {
    PetscCall(MatCreate(PetscObjectComm((PetscObject)oldmat), &mat));
    PetscCall(MatSetSizes(mat, m, n, M, N));
    PetscCall(MatSetType(mat, newtype));
    PetscCall(MatSetBlockSizesFromMats(mat, oldmat, oldmat));
    PetscCall(MatGetBlockSizes(mat, &rbs, &cbs));
    mbs = m / rbs;
    nbs = n / cbs;
    Nbs = N / cbs;
    /* block-column start is kept apart from cst, which must stay in scalar units for the column loop below */
    cstbs = cst / cbs;
    PetscCall(PetscMalloc4(mbs, &dnnz, mbs, &onnz, mbs, &dnnzu, mbs, &onnzu));
    for (i = 0; i < mbs; i++) {
      /* dense preallocation: every block row may couple with every block column */
      dnnz[i]  = nbs;
      onnz[i]  = Nbs - nbs;
      dnnzu[i] = PetscMax(nbs - i, 0);
      onnzu[i] = PetscMax(Nbs - (cstbs + nbs), 0);
    }
    PetscCall(MatXAIJSetPreallocation(mat, PETSC_DECIDE, dnnz, onnz, dnnzu, onnzu));
    PetscCall(PetscFree4(dnnz, onnz, dnnzu, onnzu));
    PetscCall(MatSetUp(mat));
  } else {
    mat = *newmat;
    PetscCall(MatZeroEntries(mat));
  }
  /* every process sets the same global entry below; let only the owner keep it, whatever the reuse mode */
  PetscCall(VecSetOption(in, VEC_IGNORE_OFF_PROC_ENTRIES, PETSC_TRUE));
  for (i = 0; i < N; i++) {
    PetscInt j;

    /* in = i-th column of the identity */
    PetscCall(VecZeroEntries(in));
    if (in->ops->setvalues) {
      PetscCall(VecSetValue(in, i, 1., INSERT_VALUES));
    } else {
      /* no VecSetValues() implementation: write directly into the locally owned part */
      if (i >= cst && i < cen) {
        PetscCall(VecGetArray(in, &array));
        array[i - cst] = 1.0;
        PetscCall(VecRestoreArray(in, &array));
      }
    }
    PetscCall(VecAssemblyBegin(in));
    PetscCall(VecAssemblyEnd(in));
    PetscCall(MatMult(oldmat, in, out));
    PetscCall(VecGetArray(out, &array));
    /* compress the nonzeros of the local part of column i to the front of array, recording their global rows */
    for (j = 0, im = 0; j < m; j++) {
      if (PetscAbsScalar(array[j]) == 0.0) continue;
      rows[im]  = j + start;
      array[im] = array[j];
      im++;
    }
    PetscCall(MatSetValues(mat, im, rows, 1, &i, array, INSERT_VALUES));
    PetscCall(VecRestoreArray(out, &array));
  }
  PetscCall(PetscFree(rows));
  PetscCall(VecDestroy(&in));
  PetscCall(VecDestroy(&out));
  PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
  if (reuse == MAT_INPLACE_MATRIX) {
    PetscCall(MatHeaderReplace(oldmat, &mat));
  } else {
    *newmat = mat;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* MatGetDiagonal_CF - extracts the diagonal of the matrix stored as the shell context of A into X */
static PetscErrorCode MatGetDiagonal_CF(Mat A, Vec X)
{
  Mat inner;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(A, &inner));
  PetscCheck(inner != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
  /* forward the request to the wrapped matrix */
  PetscCall(MatGetDiagonal(inner, X));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* MatMult_CF - computes Y = inner * X, where inner is the matrix stored as the shell context of A */
static PetscErrorCode MatMult_CF(Mat A, Vec X, Vec Y)
{
  Mat inner;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(A, &inner));
  PetscCheck(inner != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
  /* forward the product to the wrapped matrix */
  PetscCall(MatMult(inner, X, Y));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* MatMultTranspose_CF - computes Y = inner^T * X, where inner is the matrix stored as the shell context of A */
static PetscErrorCode MatMultTranspose_CF(Mat A, Vec X, Vec Y)
{
  Mat inner;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(A, &inner));
  PetscCheck(inner != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
  /* forward the transpose product to the wrapped matrix */
  PetscCall(MatMultTranspose(inner, X, Y));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatDestroy_CF - destroy callback of the wrapper shell: destroys the matrix stored as the
  shell context, clears the context, and removes the composed product-dispatch function
*/
static PetscErrorCode MatDestroy_CF(Mat A)
{
  Mat inner;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(A, &inner));
  PetscCheck(inner != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing user matrix");
  PetscCall(MatDestroy(&inner));
  /* the context must not keep pointing at the matrix that was just destroyed */
  PetscCall(MatShellSetContext(A, NULL));
  PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_anytype_C", NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Product state captured from the product matrix in MatProductSymbolicPhase_CF() and put back in MatProductNumericPhase_CF() */
typedef struct {
  void *userdata;                       /* product data taken from C->product->data */
  PetscErrorCode (*ctxdestroy)(void *); /* destructor taken from C->product->destroy, applied to userdata; may be NULL */
  PetscErrorCode (*numeric)(Mat);       /* numeric routine taken from C->ops->productnumeric */
  MatProductType ptype;                 /* product type taken from C->product->type */
  Mat            Dwork;                 /* work matrix taken from C->product->Dwork */
} MatMatCF;

/* MatProductDestroy_CF - frees a MatMatCF: the captured user data (when a destructor is available), the work matrix, and the struct itself */
static PetscErrorCode MatProductDestroy_CF(void *data)
{
  MatMatCF *ctx = (MatMatCF *)data;

  PetscFunctionBegin;
  if (ctx->ctxdestroy) {
    /* the captured product data is released by the destructor that was captured with it */
    PetscCall(ctx->ctxdestroy(ctx->userdata));
  }
  PetscCall(MatDestroy(&ctx->Dwork));
  PetscCall(PetscFree(ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatProductNumericPhase_CF - numeric phase of a product whose first factor is the wrapper shell A

  A temporary product description is attached to C from the state saved in data (a MatMatCF),
  with the matrix stored as the shell context of A as first factor and B as second factor;
  the saved numeric routine is then run on C and the temporary description is freed.
*/
static PetscErrorCode MatProductNumericPhase_CF(Mat A, Mat B, Mat C, void *data)
{
  MatMatCF *ctx = (MatMatCF *)data;
  Mat       inner;

  PetscFunctionBegin;
  PetscCheck(ctx != NULL, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing data");
  PetscCheck(ctx->numeric != NULL, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Missing numeric operation");
  /* the MATSHELL interface allows us to play with the product data */
  PetscCall(PetscNew(&C->product));
  C->product->Dwork = ctx->Dwork;
  C->product->data  = ctx->userdata;
  C->product->type  = ctx->ptype;
  PetscCall(MatShellGetContext(A, &inner));
  C->product->A = inner;
  C->product->B = B;
  PetscCall(ctx->numeric(C));
  /* only the temporary description is released; the saved state stays owned by ctx */
  PetscCall(PetscFree(C->product));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatProductSymbolicPhase_CF - symbolic phase of a product whose first factor is the wrapper shell A

  The matrix stored as the shell context of A is used as first factor to run the regular
  MatProductSetFromOptions()/MatProductSymbolic() on C. The resulting numeric routine, product
  type, product data, destructor and work matrix are then moved into a newly allocated
  MatMatCF returned in data, and A is put back as first factor of C.
*/
static PetscErrorCode MatProductSymbolicPhase_CF(Mat A, Mat B, Mat C, void **data)
{
  MatMatCF *ctx;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(A, &C->product->A));
  PetscCall(MatProductSetFromOptions(C));
  PetscCall(MatProductSymbolic(C));
  /* the MATSHELL interface does not allow non-empty product data */
  PetscCall(PetscNew(&ctx));

  /* values that are only copied */
  ctx->numeric = C->ops->productnumeric;
  ctx->ptype   = C->product->type;

  /* values whose ownership moves from the product of C to ctx */
  ctx->userdata     = C->product->data;
  C->product->data  = NULL;
  ctx->ctxdestroy     = C->product->destroy;
  C->product->destroy = NULL;
  ctx->Dwork        = C->product->Dwork;
  C->product->Dwork = NULL;

  /* put the shell back as first factor */
  C->product->A = A;

  *data = ctx;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatProductSetFromOptions_CF - product dispatch for a first factor that is a wrapper shell
  (mainly used for MatMat operations of shells with AXPYs)

  Does nothing for MATPRODUCT_ABC, when the first factor is not a shell, or when the shell
  does not have this routine composed as "MatProductSetFromOptions_anytype_C". Otherwise the
  product is probed with the matrix stored as the shell context as first factor; when that
  yields a symbolic routine, the matching MATSHELL product operation is registered on the
  shell and the product options are set again.
*/
static PetscErrorCode MatProductSetFromOptions_CF(Mat D)
{
  Mat shell, second, inner;
  PetscErrorCode (*composed)(Mat) = NULL;
  PetscBool isshell;

  PetscFunctionBegin;
  MatCheckProduct(D, 1);
  if (D->product->type == MATPRODUCT_ABC) PetscFunctionReturn(PETSC_SUCCESS);
  shell  = D->product->A;
  second = D->product->B;
  PetscCall(MatIsShell(shell, &isshell));
  if (!isshell) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscObjectQueryFunction((PetscObject)shell, "MatProductSetFromOptions_anytype_C", &composed));
  /* only shells that carry this very routine are handled here */
  if (composed != MatProductSetFromOptions_CF) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(MatShellGetContext(shell, &inner));

  /* probe the product with the wrapped matrix as first factor, then put the shell back */
  D->product->A = inner;
  PetscCall(MatProductSetFromOptions(D));
  D->product->A = shell;
  if (D->ops->productsymbolic) { /* we have a symbolic match, now populate the MATSHELL operations */
    PetscCall(MatShellSetMatProductOperation(shell, D->product->type, MatProductSymbolicPhase_CF, MatProductNumericPhase_CF, MatProductDestroy_CF, ((PetscObject)second)->type_name, NULL));
    PetscCall(MatProductSetFromOptions(D));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatConvertFrom_Shell - wraps A into a new MATSHELL that forwards to A

  Input Parameters:
+ A       - the matrix to wrap; an extra reference to it is taken and stored as the shell context
. newtype - must be MATSHELL
- reuse   - must be MAT_INITIAL_MATRIX

  Output Parameter:
. B - the new shell, with the local/global sizes, block sizes and default vector type of A

  Notes:
  The shell gets the multiply, transpose multiply, get-diagonal and destroy callbacks defined
  in this file, and MatProductSetFromOptions_CF() composed as "MatProductSetFromOptions_anytype_C".
*/
PetscErrorCode MatConvertFrom_Shell(Mat A, MatType newtype, MatReuse reuse, Mat *B)
{
  Mat       shell;
  PetscBool toshell;

  PetscFunctionBegin;
  PetscCall(PetscStrcmp(newtype, MATSHELL, &toshell));
  PetscCheck(toshell, PETSC_COMM_SELF, PETSC_ERR_SUP, "Only conversion to MATSHELL");
  PetscCheck(reuse == MAT_INITIAL_MATRIX, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Not implemented");

  /* the shell context keeps its own reference to A; MatDestroy_CF() drops it */
  PetscCall(PetscObjectReference((PetscObject)A));
  PetscCall(MatCreateShell(PetscObjectComm((PetscObject)A), A->rmap->n, A->cmap->n, A->rmap->N, A->cmap->N, A, &shell));
  PetscCall(MatSetBlockSizesFromMats(shell, A, A));

  /* operations forwarded to A */
  PetscCall(MatShellSetOperation(shell, MATOP_MULT, (void (*)(void))MatMult_CF));
  PetscCall(MatShellSetOperation(shell, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_CF));
  PetscCall(MatShellSetOperation(shell, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_CF));
  PetscCall(MatShellSetOperation(shell, MATOP_DESTROY, (void (*)(void))MatDestroy_CF));
  PetscCall(PetscObjectComposeFunction((PetscObject)shell, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_CF));

  /* vectors created from the shell must be of the same type as those created from A */
  PetscCall(PetscFree(shell->defaultvectype));
  PetscCall(PetscStrallocpy(A->defaultvectype, &shell->defaultvectype));
#if defined(PETSC_HAVE_DEVICE)
  PetscCall(MatBindToCPU(shell, A->boundtocpu));
#endif
  *B = shell;
  PetscFunctionReturn(PETSC_SUCCESS);
}
