#include <petsc/private/pcmgimpl.h>
#include <petsc/private/pcbddcimpl.h>
#include <petsc/private/pcbddcprivateimpl.h>

/*
  PCMGGDSWSetUp - Builds the per-subdomain ingredients used to assemble the GDSW coarse space.

  Input parameters:
+ pc     - the multigrid PC; only pc->kspnestlevel is read here
. l      - level index (not referenced in this routine)
. dm     - DM; only used with -gdsw_reduced, to query "DMComputeLocalRGDSWSets"
. smooth - the level smoother; provides the options prefix and the PC that is inspected for PCBDDC
. Nc     - not referenced in this routine
- A      - the level operator; converted to MATIS if it is not already of that type

  Output parameters (all arrays have length *ns and are allocated here with PetscMalloc1):
+ ns      - number of subdomains handled by this process; currently always 1
. sA_IG_n - interior-to-interface coupling blocks (ipcis->A_IB, referenced)
. sksp_n  - solvers created on the interior blocks (ipcis->A_II, ipcis->pA_II)
. sI_n    - interior index sets (ipcis->is_I_global, referenced)
. sG_n    - interface index sets (ipcis->is_B_global, referenced)
. sGf_n   - sequential dense matrices holding the interface functions
. sGi_n   - index sets obtained by mapping the reference indices returned by PCBDDCComputeFakeChange through the graph l2g map
- sGiM_n  - index sets returned by PCBDDCComputeFakeChange (used by the caller in this file as multiplicities for ISRenumber)

  Options (read with the prefix of smooth):
  -gdsw_tolerance, -gdsw_userdefined, -gdsw_minmax, -gdsw_vertex_size, -gdsw_reuse, -gdsw_reduced, -gdsw_debug

  Notes:
  The caller visible in this file destroys every entry of the returned arrays and frees the arrays.
*/
static PetscErrorCode PCMGGDSWSetUp(PC pc, PetscInt l, DM dm, KSP smooth, PetscInt Nc, Mat A, PetscInt *ns, Mat **sA_IG_n, KSP **sksp_n, IS **sI_n, IS **sG_n, Mat **sGf_n, IS **sGi_n, IS **sGiM_n)
{
  KSP                   *sksp;
  PC                     pcbddc = NULL, smoothpc;
  PC_BDDC               *ipcbddc;
  PC_IS                 *ipcis;
  Mat                   *sA_IG, *sGf, cmat, lA;
  ISLocalToGlobalMapping l2g;
  IS                    *sI, *sG, *sGi, *sGiM, cref;
  PCBDDCSubSchurs        sub_schurs = NULL;
  PCBDDCGraph            graph;
  const char            *prefix;
  const PetscScalar     *tdata;
  PetscScalar           *data, *cdata;
  PetscReal              tol = 0.0, otol;
  const PetscInt        *ia, *ja;
  PetscInt              *ccii, *cridx;
  PetscInt               i, j, ngct, ng, dbg = 0, odbg, minmax[2] = {0, PETSC_INT_MAX}, ominmax[2], vsize;
  PetscBool              flg, userdefined = PETSC_TRUE, reuse_solver = PETSC_TRUE, reduced = PETSC_FALSE;

  PetscFunctionBegin;
  /* the default vertex size is the block size of A; it can be overridden with -gdsw_vertex_size */
  PetscCall(MatGetBlockSize(A, &vsize));
  PetscCall(KSPGetOptionsPrefix(smooth, &prefix));
  PetscOptionsBegin(PetscObjectComm((PetscObject)smooth), prefix, "GDSW options", "PC");
  PetscCall(PetscOptionsReal("-gdsw_tolerance", "Tolerance for eigenvalue problem", NULL, tol, &tol, NULL));
  PetscCall(PetscOptionsBool("-gdsw_userdefined", "Use user-defined functions in addition to those adaptively generated", NULL, userdefined, &userdefined, NULL));
  /* i is used as the in/out length of the minmax array (at most 2 values are read) */
  PetscCall(PetscOptionsIntArray("-gdsw_minmax", "Minimum and maximum number of basis functions per connected component for adaptive GDSW", NULL, minmax, (i = 2, &i), NULL));
  PetscCall(PetscOptionsInt("-gdsw_vertex_size", "Connected components smaller or equal to vertex size will be considered as vertices", NULL, vsize, &vsize, NULL));
  PetscCall(PetscOptionsBool("-gdsw_reuse", "Reuse interior solver from Schur complement computations", NULL, reuse_solver, &reuse_solver, NULL));
  PetscCall(PetscOptionsBool("-gdsw_reduced", "Reduced GDSW", NULL, reduced, &reduced, NULL));
  PetscCall(PetscOptionsInt("-gdsw_debug", "Debug output", NULL, dbg, &dbg, NULL));
  PetscOptionsEnd();

  /* Work on a MATIS operator: either convert A (carrying over its near nullspace) or take a reference.
     In both cases the local handle A is released with MatDestroy() at the end of this routine */
  PetscCall(PetscObjectTypeCompare((PetscObject)A, MATIS, &flg));
  if (!flg) {
    MatNullSpace nnsp;

    PetscCall(MatGetNearNullSpace(A, &nnsp));
    PetscCall(PetscObjectReference((PetscObject)nnsp));
    PetscCall(MatConvert(A, MATIS, MAT_INITIAL_MATRIX, &A));
    PetscCall(MatSetNearNullSpace(A, nnsp));
    PetscCall(MatNullSpaceDestroy(&nnsp));
  } else PetscCall(PetscObjectReference((PetscObject)A));

  /* TODO Multi sub */
  *ns = 1;
  PetscCall(PetscMalloc1(*ns, &sA_IG));
  PetscCall(PetscMalloc1(*ns, &sksp));
  PetscCall(PetscMalloc1(*ns, &sI));
  PetscCall(PetscMalloc1(*ns, &sG));
  PetscCall(PetscMalloc1(*ns, &sGf));
  PetscCall(PetscMalloc1(*ns, &sGi));
  PetscCall(PetscMalloc1(*ns, &sGiM));
  *sA_IG_n = sA_IG;
  *sksp_n  = sksp;
  *sI_n    = sI;
  *sG_n    = sG;
  *sGf_n   = sGf;
  *sGi_n   = sGi;
  *sGiM_n  = sGiM;

  /* submatrices and solvers */
  /* If the smoother PC is already a PCBDDC, use it (with an extra reference); otherwise create a
     temporary PCBDDC on (smoothA, A) and set up only its PCIS part. From here on flg records which
     of the two cases applies; it is tested again in the interface analysis below */
  PetscCall(KSPGetPC(smooth, &smoothpc));
  PetscCall(PetscObjectTypeCompareAny((PetscObject)smoothpc, &flg, PCBDDC, ""));
  if (!flg) {
    Mat smoothA;

    PetscCall(PCGetOperators(smoothpc, &smoothA, NULL));
    PetscCall(PCCreate(PetscObjectComm((PetscObject)A), &pcbddc));
    PetscCall(PCSetType(pcbddc, PCBDDC));
    PetscCall(PCSetOperators(pcbddc, smoothA, A));
    PetscCall(PCISSetUp(pcbddc, PETSC_TRUE, PETSC_FALSE));
  } else {
    PetscCall(PetscObjectReference((PetscObject)smoothpc));
    pcbddc = smoothpc;
  }
  /* the same data pointer is viewed both as PC_IS and as PC_BDDC */
  ipcis   = (PC_IS *)pcbddc->data;
  ipcbddc = (PC_BDDC *)pcbddc->data;
  /* the returned objects are borrowed from the PCIS data, so take a reference for each of them */
  PetscCall(PetscObjectReference((PetscObject)ipcis->A_IB));
  PetscCall(PetscObjectReference((PetscObject)ipcis->is_I_global));
  PetscCall(PetscObjectReference((PetscObject)ipcis->is_B_global));
  sA_IG[0] = ipcis->A_IB;
  sI[0]    = ipcis->is_I_global;
  sG[0]    = ipcis->is_B_global;

  /* interior solver, configurable through the smoother prefix followed by "gdsw_" */
  PetscCall(KSPCreate(PetscObjectComm((PetscObject)ipcis->A_II), &sksp[0]));
  PetscCall(KSPSetNestLevel(sksp[0], pc->kspnestlevel));
  PetscCall(KSPSetOperators(sksp[0], ipcis->A_II, ipcis->pA_II));
  PetscCall(KSPSetOptionsPrefix(sksp[0], prefix));
  PetscCall(KSPAppendOptionsPrefix(sksp[0], "gdsw_"));
  PetscCall(KSPSetFromOptions(sksp[0]));

  /* analyze interface */
  PetscCall(MatISGetLocalMat(A, &lA));
  graph = ipcbddc->mat_graph;
  /* the graph is initialized and analyzed here only when the PCBDDC was created above */
  if (!flg) {
    PetscInt N;

    PetscCall(MatISGetLocalToGlobalMapping(A, &l2g, NULL));
    PetscCall(MatGetSize(A, &N, NULL));
    PetscCall(PCBDDCGraphInit(graph, l2g, N, PETSC_INT_MAX));
    /* the CSR arrays of the local matrix are lent to the graph only between Get and Restore;
       NOTE(review): flg is overwritten by these calls and its value is not checked */
    PetscCall(MatGetRowIJ(lA, 0, PETSC_TRUE, PETSC_FALSE, &graph->nvtxs_csr, (const PetscInt **)&graph->xadj, (const PetscInt **)&graph->adjncy, &flg));
    PetscCall(PCBDDCGraphSetUp(graph, vsize, NULL, NULL, 0, NULL, NULL));
    PetscCall(MatRestoreRowIJ(lA, 0, PETSC_TRUE, PETSC_FALSE, &graph->nvtxs_csr, (const PetscInt **)&graph->xadj, (const PetscInt **)&graph->adjncy, &flg));
    PetscCall(PCBDDCGraphComputeConnectedComponents(graph));
  }
  l2g = graph->l2gmap;
  if (reduced) {
    PetscContainer        gcand;
    PCBDDCGraphCandidates cand;
    PetscErrorCode (*rgdsw)(DM, PetscInt *, IS **);

    /* the candidate sets are supplied by a function composed on the DM; error out if it is missing */
    PetscCall(PetscObjectQueryFunction((PetscObject)dm, "DMComputeLocalRGDSWSets", &rgdsw));
    PetscCheck(rgdsw, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not supported");
    PetscCall(PetscNew(&cand));
    PetscCall((*rgdsw)(dm, &cand->nfc, &cand->Faces));
    /* filter interior (if any) and guarantee IS are ordered by global numbering */
    for (i = 0; i < cand->nfc; i++) {
      IS is, is2;

      /* local -> global, sort, global -> local, then round trip through BtoNmap with
         IS_GTOLM_DROP so that indices not present in that map are removed */
      PetscCall(ISLocalToGlobalMappingApplyIS(l2g, cand->Faces[i], &is));
      PetscCall(ISDestroy(&cand->Faces[i]));
      PetscCall(ISSort(is));
      PetscCall(ISGlobalToLocalMappingApplyIS(l2g, IS_GTOLM_DROP, is, &is2));
      PetscCall(ISDestroy(&is));
      PetscCall(ISGlobalToLocalMappingApplyIS(ipcis->BtoNmap, IS_GTOLM_DROP, is2, &is));
      PetscCall(ISDestroy(&is2));
      PetscCall(ISLocalToGlobalMappingApplyIS(ipcis->BtoNmap, is, &cand->Faces[i]));
      PetscCall(ISDestroy(&is));
    }
    /* attach the candidates to the l2g map inside a container; the container owns cand and
       releases it through PCBDDCDestroyGraphCandidatesIS */
    PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &gcand));
    PetscCall(PetscContainerSetPointer(gcand, cand));
    PetscCall(PetscContainerSetCtxDestroy(gcand, PCBDDCDestroyGraphCandidatesIS));
    PetscCall(PetscObjectCompose((PetscObject)l2g, "_PCBDDCGraphCandidatesIS", (PetscObject)gcand));
    PetscCall(PetscContainerDestroy(&gcand));
  }

  /* interface functions */
  /* temporarily override the adaptive settings of the PCBDDC data with the GDSW options;
     the original values are saved here and restored after PCBDDCComputeFakeChange() */
  otol                           = ipcbddc->adaptive_threshold[1];
  odbg                           = ipcbddc->dbg_flag;
  ominmax[0]                     = ipcbddc->adaptive_nmin;
  ominmax[1]                     = ipcbddc->adaptive_nmax;
  ipcbddc->adaptive_threshold[1] = tol;
  ipcbddc->dbg_flag              = dbg;
  ipcbddc->adaptive_nmin         = minmax[0];
  ipcbddc->adaptive_nmax         = minmax[1];
  if (tol != 0.0) { /* adaptive */
    Mat lS;

    PetscCall(MatCreateSchurComplement(ipcis->A_II, ipcis->pA_II, ipcis->A_IB, ipcis->A_BI, ipcis->A_BB, &lS));
    PetscCall(KSPGetOptionsPrefix(sksp[0], &prefix));
    PetscCall(PCBDDCSubSchursCreate(&sub_schurs));
    PetscCall(PCBDDCSubSchursInit(sub_schurs, prefix, ipcis->is_I_local, ipcis->is_B_local, graph, ipcis->BtoNmap, PETSC_FALSE, PETSC_TRUE));
    /* with -gdsw_userdefined, a first (non-adaptive) change is computed and passed to the sub Schurs setup */
    if (userdefined) PetscCall(PCBDDCComputeFakeChange(pcbddc, PETSC_FALSE, graph, NULL, &cmat, &cref, NULL, &flg));
    else {
      cmat = NULL;
      cref = NULL;
    }
    PetscCall(PCBDDCSubSchursSetUp(sub_schurs, lA, lS, PETSC_TRUE, NULL, NULL, -1, NULL, PETSC_TRUE, reuse_solver, PETSC_FALSE, 0, NULL, NULL, cmat, cref));
    PetscCall(MatDestroy(&lS));
    PetscCall(MatDestroy(&cmat));
    PetscCall(ISDestroy(&cref));
    /* hand the interior solver over to sksp[0] and detach the reuse data from sub_schurs;
       NOTE(review): presumably the solver's own destroy routine releases the reuse data - confirm */
    if (sub_schurs->reuse_solver) {
      PetscCall(KSPSetPC(sksp[0], sub_schurs->reuse_solver->interior_solver));
      PetscCall(PCDestroy(&sub_schurs->reuse_solver->interior_solver));
      sub_schurs->reuse_solver = NULL;
    }
  }
  PetscCall(PCBDDCComputeFakeChange(pcbddc, PETSC_TRUE, graph, sub_schurs, &cmat, &cref, &sGiM[0], NULL));
  PetscCall(PCBDDCSubSchursDestroy(&sub_schurs));
  ipcbddc->adaptive_threshold[1] = otol;
  ipcbddc->dbg_flag              = odbg;
  ipcbddc->adaptive_nmin         = ominmax[0];
  ipcbddc->adaptive_nmax         = ominmax[1];

  PetscCall(ISLocalToGlobalMappingApplyIS(l2g, cref, &sGi[0]));
  PetscCall(ISDestroy(&cref));

  /* copy the CSR structure and values of cmat (accessed as SeqAIJ) so that it can be destroyed */
  PetscCall(MatSeqAIJGetArrayRead(cmat, &tdata));
  PetscCall(MatGetRowIJ(cmat, 0, PETSC_FALSE, PETSC_FALSE, &ngct, &ia, &ja, &flg));
  PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Error in MatGetRowIJ");

  PetscCall(PetscMalloc1(ngct + 1, &ccii));
  PetscCall(PetscMalloc1(ia[ngct], &cridx));
  PetscCall(PetscMalloc1(ia[ngct], &cdata));

  PetscCall(PetscArraycpy(ccii, ia, ngct + 1));
  PetscCall(PetscArraycpy(cdata, tdata, ia[ngct]));
  /* map the column indices through BtoNmap; every index must be found, otherwise it is an error */
  PetscCall(ISGlobalToLocalMappingApply(ipcis->BtoNmap, IS_GTOLM_DROP, ia[ngct], ja, &i, cridx));
  PetscCheck(i == ia[ngct], PETSC_COMM_SELF, PETSC_ERR_PLIB, "Error in G2L");

  PetscCall(MatRestoreRowIJ(cmat, 0, PETSC_FALSE, PETSC_FALSE, &i, &ia, &ja, &flg));
  PetscCall(MatSeqAIJRestoreArrayRead(cmat, &tdata));
  PetscCall(MatDestroy(&cmat));

  /* populate dense matrix */
  PetscCall(ISGetLocalSize(sG[0], &ng));
  PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, ng, ngct, NULL, &sGf[0]));
  PetscCall(MatDenseGetArrayWrite(sGf[0], &data));
  /* row i of cmat is scattered into the i-th chunk of length ng of the dense array, i.e. column i
     of sGf[0] assuming column-major storage with leading dimension ng; only the nonzeros of cmat
     are written here */
  for (i = 0; i < ngct; i++)
    for (j = ccii[i]; j < ccii[i + 1]; j++) data[ng * i + cridx[j]] = cdata[j];
  PetscCall(MatDenseRestoreArrayWrite(sGf[0], &data));

  PetscCall(PetscFree(cdata));
  PetscCall(PetscFree(ccii));
  PetscCall(PetscFree(cridx));
  /* drop the local references taken (or objects created) at the beginning of the routine */
  PetscCall(PCDestroy(&pcbddc));
  PetscCall(MatDestroy(&A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PCMGGDSWCreateCoarseSpace_Private - Assembles the GDSW coarse space for level l as a parallel sparse matrix.

  Input parameters:
+ pc     - the multigrid PC; pc->useAmat selects which operator of the smoother is used,
           pc->setupcalled decides whether KSPSetFromOptions() is called on the smoother
. l      - level index; nothing is computed when l is 0
. dm     - forwarded to PCMGGDSWSetUp()
. smooth - the level smoother; it is set up here and its operator defines the layout of the rows
. Nc     - forwarded to PCMGGDSWSetUp()
- guess  - not referenced in this routine

  Output parameter:
. cspace - the coarse space matrix (NULL when l is 0); its local row size is the local column size
           of the operator and its global column size is the number of functions returned by ISRenumber()

  Notes:
  Errors with PETSC_ERR_ORDER if the requested operator has not been set on the smoother.

  For every subdomain, the interface functions Gf are inserted in the interface rows, and
  X = -inv(A_II) A_IG Gf (computed with KSPMatSolve()) is inserted in the interior rows.

  All the objects returned by PCMGGDSWSetUp() are destroyed before returning.
*/
PetscErrorCode PCMGGDSWCreateCoarseSpace_Private(PC pc, PetscInt l, DM dm, KSP smooth, PetscInt Nc, Mat guess, Mat *cspace)
{
  KSP            *sksp;
  Mat             A, *sA_IG, *sGf, preallocator;
  IS              Gidx, GidxMult, cG;
  IS             *sI, *sG, *sGi, *sGiM;
  const PetscInt *cidx;
  PetscInt        NG, ns, n, i, c, rbs, cbs[2];
  PetscBool       flg;
  MatType         ptype;

  PetscFunctionBegin;
  *cspace = NULL;
  if (!l) PetscFunctionReturn(PETSC_SUCCESS);
  if (pc->useAmat) {
    PetscCall(KSPGetOperatorsSet(smooth, &flg, NULL));
    PetscCheck(flg, PetscObjectComm((PetscObject)smooth), PETSC_ERR_ORDER, "Amat not set");
    PetscCall(KSPGetOperators(smooth, &A, NULL));
  } else {
    PetscCall(KSPGetOperatorsSet(smooth, NULL, &flg));
    PetscCheck(flg, PetscObjectComm((PetscObject)smooth), PETSC_ERR_ORDER, "Pmat not set");
    PetscCall(KSPGetOperators(smooth, NULL, &A));
  }

  /* Setup (also setup smoother here) */
  if (!pc->setupcalled) PetscCall(KSPSetFromOptions(smooth));
  PetscCall(KSPSetUp(smooth));
  PetscCall(KSPSetUpOnBlocks(smooth));
  PetscCall(PCMGGDSWSetUp(pc, l, dm, smooth, Nc, A, &ns, &sA_IG, &sksp, &sI, &sG, &sGf, &sGi, &sGiM));

  /* Number GDSW basis functions */
  PetscCall(ISConcatenate(PetscObjectComm((PetscObject)A), ns, sGi, &Gidx));
  PetscCall(ISConcatenate(PetscObjectComm((PetscObject)A), ns, sGiM, &GidxMult));
  PetscCall(ISRenumber(Gidx, GidxMult, &NG, &cG));
  PetscCall(ISDestroy(&Gidx));

  /* Detect column block size */
  /* a column block size is set only if the multiplicity is the same everywhere (global min == global max) */
  PetscCall(ISGetMinMax(GidxMult, &cbs[0], &cbs[1]));
  PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)A), cbs, cbs));
  PetscCall(ISDestroy(&GidxMult));

  /* Construct global interpolation matrix */
  /* first pass: record the nonzero pattern in a MATPREALLOCATOR (no values are passed) */
  PetscCall(MatGetLocalSize(A, NULL, &n));
  PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &preallocator));
  PetscCall(MatSetSizes(preallocator, n, PETSC_DECIDE, PETSC_DECIDE, NG));
  PetscCall(MatSetType(preallocator, MATPREALLOCATOR));
  PetscCall(MatSetUp(preallocator));
  PetscCall(ISGetIndices(cG, &cidx));
  for (i = 0, c = 0; i < ns; i++) {
    const PetscInt *ri, *rg;
    PetscInt        nri, nrg, ncg;

    PetscCall(ISGetLocalSize(sI[i], &nri));
    PetscCall(ISGetLocalSize(sG[i], &nrg));
    PetscCall(ISGetIndices(sI[i], &ri));
    PetscCall(ISGetIndices(sG[i], &rg));
    PetscCall(MatGetSize(sGf[i], NULL, &ncg));
    PetscCall(MatSetValues(preallocator, nri, ri, ncg, cidx + c, NULL, INSERT_VALUES));
    PetscCall(MatSetValues(preallocator, nrg, rg, ncg, cidx + c, NULL, INSERT_VALUES));
    PetscCall(ISRestoreIndices(sI[i], &ri));
    PetscCall(ISRestoreIndices(sG[i], &rg));
    /* advance to the columns of the next subdomain: cidx stores the renumbered columns of
       all the subdomains one after the other (no-op as long as ns is 1) */
    c += ncg;
  }
  PetscCall(MatAssemblyBegin(preallocator, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(preallocator, MAT_FINAL_ASSEMBLY));

  /* use a CUSPARSE matrix only if the operator is not bound to the CPU and its vector type contains "cuda" */
  ptype = MATAIJ;
  if (PetscDefined(HAVE_DEVICE)) {
    PetscCall(MatBoundToCPU(A, &flg));
    if (!flg) {
      VecType vtype;
      char   *found = NULL;

      PetscCall(MatGetVecType(A, &vtype));
      PetscCall(PetscStrstr(vtype, "cuda", &found));
      if (found) ptype = MATAIJCUSPARSE;
    }
  }
  PetscCall(MatCreate(PetscObjectComm((PetscObject)A), cspace));
  PetscCall(MatSetSizes(*cspace, n, PETSC_DECIDE, PETSC_DECIDE, NG));
  PetscCall(MatSetType(*cspace, ptype));
  PetscCall(MatGetBlockSizes(A, NULL, &rbs));
  PetscCall(MatSetBlockSizes(*cspace, rbs, cbs[0] == cbs[1] ? cbs[0] : 1));
  PetscCall(MatPreallocatorPreallocate(preallocator, PETSC_FALSE, *cspace));
  PetscCall(MatDestroy(&preallocator));
  /* values are passed to MatSetValues() directly from dense matrix arrays, hence column oriented input */
  PetscCall(MatSetOption(*cspace, MAT_ROW_ORIENTED, PETSC_FALSE));

  /* second pass: compute and insert the values */
  for (i = 0, c = 0; i < ns; i++) {
    Mat                X, Y;
    const PetscScalar *v;
    const PetscInt    *ri, *rg;
    PetscInt           nri, nrg, ncg;

    /* X = -inv(A_II) A_IG Gf */
    PetscCall(MatMatMult(sA_IG[i], sGf[i], MAT_INITIAL_MATRIX, PETSC_CURRENT, &Y));
    PetscCall(MatScale(Y, -1.0));
    PetscCall(MatDuplicate(Y, MAT_DO_NOT_COPY_VALUES, &X));
    PetscCall(KSPMatSolve(sksp[i], Y, X));

    PetscCall(ISGetLocalSize(sI[i], &nri));
    PetscCall(ISGetLocalSize(sG[i], &nrg));
    PetscCall(ISGetIndices(sI[i], &ri));
    PetscCall(ISGetIndices(sG[i], &rg));
    PetscCall(MatGetSize(sGf[i], NULL, &ncg));

    /* interior rows receive X, interface rows receive Gf */
    PetscCall(MatDenseGetArrayRead(X, &v));
    PetscCall(MatSetValues(*cspace, nri, ri, ncg, cidx + c, v, INSERT_VALUES));
    PetscCall(MatDenseRestoreArrayRead(X, &v));
    PetscCall(MatDenseGetArrayRead(sGf[i], &v));
    PetscCall(MatSetValues(*cspace, nrg, rg, ncg, cidx + c, v, INSERT_VALUES));
    PetscCall(MatDenseRestoreArrayRead(sGf[i], &v));
    PetscCall(ISRestoreIndices(sI[i], &ri));
    PetscCall(ISRestoreIndices(sG[i], &rg));
    PetscCall(MatDestroy(&Y));
    PetscCall(MatDestroy(&X));
    /* same column offset update as in the preallocation pass */
    c += ncg;
  }
  PetscCall(ISRestoreIndices(cG, &cidx));
  PetscCall(ISDestroy(&cG));
  PetscCall(MatAssemblyBegin(*cspace, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(*cspace, MAT_FINAL_ASSEMBLY));

  /* release everything returned by PCMGGDSWSetUp() */
  for (i = 0; i < ns; i++) {
    PetscCall(KSPDestroy(&sksp[i]));
    PetscCall(ISDestroy(&sI[i]));
    PetscCall(ISDestroy(&sG[i]));
    PetscCall(ISDestroy(&sGi[i]));
    PetscCall(ISDestroy(&sGiM[i]));
    PetscCall(MatDestroy(&sGf[i]));
    PetscCall(MatDestroy(&sA_IG[i]));
  }
  PetscCall(PetscFree(sksp));
  PetscCall(PetscFree(sI));
  PetscCall(PetscFree(sG));
  PetscCall(PetscFree(sGi));
  PetscCall(PetscFree(sGiM));
  PetscCall(PetscFree(sGf));
  PetscCall(PetscFree(sA_IG));
  PetscFunctionReturn(PETSC_SUCCESS);
}
